Mercurial > repos > bgruening > sklearn_train_test_eval
comparison search_model_validation.py @ 15:2eb5c017958d draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:15:27 +0000 |
parents | caf7d2b71a48 |
children |
comparison
equal
deleted
inserted
replaced
14:4d1637cac794 | 15:2eb5c017958d |
---|---|
1 import argparse | 1 import argparse |
2 import collections | |
3 import json | 2 import json |
4 import os | 3 import os |
5 import pickle | |
6 import sys | 4 import sys |
7 import warnings | 5 import warnings |
6 from distutils.version import LooseVersion as Version | |
8 | 7 |
9 import imblearn | 8 import imblearn |
10 import joblib | 9 import joblib |
11 import numpy as np | 10 import numpy as np |
12 import pandas as pd | 11 import pandas as pd |
13 import skrebate | 12 import skrebate |
14 from galaxy_ml.utils import (clean_params, get_cv, | 13 from galaxy_ml import __version__ as galaxy_ml_version |
15 get_main_estimator, get_module, get_scoring, | 14 from galaxy_ml.binarize_target import IRAPSClassifier |
16 load_model, read_columns, SafeEval, try_get_attr) | 15 from galaxy_ml.model_persist import dump_model_to_h5, load_model_from_h5 |
16 from galaxy_ml.utils import ( | |
17 clean_params, | |
18 get_cv, | |
19 get_main_estimator, | |
20 get_module, | |
21 get_scoring, | |
22 read_columns, | |
23 SafeEval, | |
24 try_get_attr | |
25 ) | |
17 from scipy.io import mmread | 26 from scipy.io import mmread |
18 from sklearn import (cluster, decomposition, feature_selection, | 27 from sklearn import ( |
19 kernel_approximation, model_selection, preprocessing) | 28 cluster, |
29 decomposition, | |
30 feature_selection, | |
31 kernel_approximation, | |
32 model_selection, | |
33 preprocessing, | |
34 ) | |
20 from sklearn.exceptions import FitFailedWarning | 35 from sklearn.exceptions import FitFailedWarning |
21 from sklearn.model_selection import _search, _validation | 36 from sklearn.model_selection import _search, _validation |
22 from sklearn.model_selection._validation import _score, cross_validate | 37 from sklearn.model_selection._validation import _score, cross_validate |
23 | 38 from sklearn.preprocessing import LabelEncoder |
24 _fit_and_score = try_get_attr("galaxy_ml.model_validations", "_fit_and_score") | 39 from skopt import BayesSearchCV |
25 setattr(_search, "_fit_and_score", _fit_and_score) | |
26 setattr(_validation, "_fit_and_score", _fit_and_score) | |
27 | 40 |
28 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1)) | 41 N_JOBS = int(os.environ.get("GALAXY_SLOTS", 1)) |
29 # handle disk cache | 42 # handle disk cache |
30 CACHE_DIR = os.path.join(os.getcwd(), "cached") | 43 CACHE_DIR = os.path.join(os.getcwd(), "cached") |
31 del os | 44 NON_SEARCHABLE = ( |
32 NON_SEARCHABLE = ("n_jobs", "pre_dispatch", "memory", "_path", "nthread", "callbacks") | 45 "n_jobs", |
46 "pre_dispatch", | |
47 "memory", | |
48 "_path", | |
49 "_dir", | |
50 "nthread", | |
51 "callbacks", | |
52 ) | |
33 | 53 |
34 | 54 |
35 def _eval_search_params(params_builder): | 55 def _eval_search_params(params_builder): |
36 search_params = {} | 56 search_params = {} |
37 | 57 |
98 skrebate.MultiSURFstar(n_jobs=N_JOBS), | 118 skrebate.MultiSURFstar(n_jobs=N_JOBS), |
99 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS), | 119 imblearn.under_sampling.ClusterCentroids(random_state=0, n_jobs=N_JOBS), |
100 imblearn.under_sampling.CondensedNearestNeighbour( | 120 imblearn.under_sampling.CondensedNearestNeighbour( |
101 random_state=0, n_jobs=N_JOBS | 121 random_state=0, n_jobs=N_JOBS |
102 ), | 122 ), |
103 imblearn.under_sampling.EditedNearestNeighbours( | 123 imblearn.under_sampling.EditedNearestNeighbours(n_jobs=N_JOBS), |
104 random_state=0, n_jobs=N_JOBS | 124 imblearn.under_sampling.RepeatedEditedNearestNeighbours(n_jobs=N_JOBS), |
105 ), | 125 imblearn.under_sampling.AllKNN(n_jobs=N_JOBS), |
106 imblearn.under_sampling.RepeatedEditedNearestNeighbours( | |
107 random_state=0, n_jobs=N_JOBS | |
108 ), | |
109 imblearn.under_sampling.AllKNN(random_state=0, n_jobs=N_JOBS), | |
110 imblearn.under_sampling.InstanceHardnessThreshold( | 126 imblearn.under_sampling.InstanceHardnessThreshold( |
111 random_state=0, n_jobs=N_JOBS | 127 random_state=0, n_jobs=N_JOBS |
112 ), | 128 ), |
113 imblearn.under_sampling.NearMiss(random_state=0, n_jobs=N_JOBS), | 129 imblearn.under_sampling.NearMiss(n_jobs=N_JOBS), |
114 imblearn.under_sampling.NeighbourhoodCleaningRule( | 130 imblearn.under_sampling.NeighbourhoodCleaningRule(n_jobs=N_JOBS), |
115 random_state=0, n_jobs=N_JOBS | |
116 ), | |
117 imblearn.under_sampling.OneSidedSelection( | 131 imblearn.under_sampling.OneSidedSelection( |
118 random_state=0, n_jobs=N_JOBS | 132 random_state=0, n_jobs=N_JOBS |
119 ), | 133 ), |
120 imblearn.under_sampling.RandomUnderSampler(random_state=0), | 134 imblearn.under_sampling.RandomUnderSampler(random_state=0), |
121 imblearn.under_sampling.TomekLinks(random_state=0, n_jobs=N_JOBS), | 135 imblearn.under_sampling.TomekLinks(n_jobs=N_JOBS), |
122 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), | 136 imblearn.over_sampling.ADASYN(random_state=0, n_jobs=N_JOBS), |
137 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS), | |
138 imblearn.over_sampling.KMeansSMOTE(random_state=0, n_jobs=N_JOBS), | |
123 imblearn.over_sampling.RandomOverSampler(random_state=0), | 139 imblearn.over_sampling.RandomOverSampler(random_state=0), |
124 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), | 140 imblearn.over_sampling.SMOTE(random_state=0, n_jobs=N_JOBS), |
125 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), | 141 imblearn.over_sampling.SMOTEN(random_state=0, n_jobs=N_JOBS), |
126 imblearn.over_sampling.BorderlineSMOTE(random_state=0, n_jobs=N_JOBS), | |
127 imblearn.over_sampling.SMOTENC( | 142 imblearn.over_sampling.SMOTENC( |
128 categorical_features=[], random_state=0, n_jobs=N_JOBS | 143 categorical_features=[], random_state=0, n_jobs=N_JOBS |
129 ), | 144 ), |
145 imblearn.over_sampling.SVMSMOTE(random_state=0, n_jobs=N_JOBS), | |
130 imblearn.combine.SMOTEENN(random_state=0), | 146 imblearn.combine.SMOTEENN(random_state=0), |
131 imblearn.combine.SMOTETomek(random_state=0), | 147 imblearn.combine.SMOTETomek(random_state=0), |
132 ) | 148 ) |
133 newlist = [] | 149 newlist = [] |
134 for obj in ev: | 150 for obj in ev: |
286 else: | 302 else: |
287 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) | 303 infile2 = pd.read_csv(infile2, sep="\t", header=header, parse_dates=True) |
288 loaded_df[df_key] = infile2 | 304 loaded_df[df_key] = infile2 |
289 | 305 |
290 y = read_columns( | 306 y = read_columns( |
291 infile2, c=c, c_option=column_option, sep="\t", header=header, parse_dates=True | 307 infile2, |
308 c=c, | |
309 c_option=column_option, | |
310 sep="\t", | |
311 header=header, | |
312 parse_dates=True, | |
292 ) | 313 ) |
293 if len(y.shape) == 2 and y.shape[1] == 1: | 314 if len(y.shape) == 2 and y.shape[1] == 1: |
294 y = y.ravel() | 315 y = y.ravel() |
295 if input_type == "refseq_and_interval": | 316 if input_type == "refseq_and_interval": |
296 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) | 317 estimator.set_params(data_batch_generator__features=y.ravel().tolist()) |
414 pass | 435 pass |
415 for warning in w: | 436 for warning in w: |
416 print(repr(warning.message)) | 437 print(repr(warning.message)) |
417 | 438 |
418 scorer_ = searcher.scorer_ | 439 scorer_ = searcher.scorer_ |
419 if isinstance(scorer_, collections.Mapping): | |
420 is_multimetric = True | |
421 else: | |
422 is_multimetric = False | |
423 | 440 |
424 best_estimator_ = getattr(searcher, "best_estimator_") | 441 best_estimator_ = getattr(searcher, "best_estimator_") |
425 | 442 |
426 # TODO Solve deep learning models in pipeline | 443 # TODO Solve deep learning models in pipeline |
427 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier": | 444 if best_estimator_.__class__.__name__ == "KerasGBatchClassifier": |
428 test_score = best_estimator_.evaluate( | 445 test_score = best_estimator_.evaluate( |
429 X_test, scorer=scorer_, is_multimetric=is_multimetric | 446 X_test, |
447 scorer=scorer_, | |
430 ) | 448 ) |
431 else: | 449 else: |
432 test_score = _score( | 450 test_score = _score(best_estimator_, X_test, y_test, scorer_) |
433 best_estimator_, X_test, y_test, scorer_, is_multimetric=is_multimetric | 451 |
434 ) | 452 if not isinstance(scorer_, dict): |
435 | |
436 if not is_multimetric: | |
437 test_score = {primary_scoring: test_score} | 453 test_score = {primary_scoring: test_score} |
438 for key, value in test_score.items(): | 454 for key, value in test_score.items(): |
439 test_score[key] = [value] | 455 test_score[key] = [value] |
440 result_df = pd.DataFrame(test_score) | 456 result_df = pd.DataFrame(test_score) |
441 result_df.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False) | 457 result_df.to_csv(path_or_buf=outfile, sep="\t", header=True, index=False) |
442 | 458 |
443 return searcher | 459 return searcher |
460 | |
461 | |
462 def _set_memory(estimator, memory): | |
463 """set memory cache | |
464 | |
465 Parameters | |
466 ---------- | |
467 estimator : python object | |
468 memory : joblib.Memory object | |
469 | |
470 Returns | |
471 ------- | |
472 estimator : estimator object after setting new attributes | |
473 """ | |
474 if isinstance(estimator, IRAPSClassifier): | |
475 estimator.set_params(memory=memory) | |
476 return estimator | |
477 | |
478 estimator_params = estimator.get_params() | |
479 | |
480 new_params = {} | |
481 for k in estimator_params.keys(): | |
482 if k.endswith("irapsclassifier__memory"): | |
483 new_params[k] = memory | |
484 | |
485 estimator.set_params(**new_params) | |
486 | |
487 return estimator | |
444 | 488 |
445 | 489 |
446 def main( | 490 def main( |
447 inputs, | 491 inputs, |
448 infile_estimator, | 492 infile_estimator, |
449 infile1, | 493 infile1, |
450 infile2, | 494 infile2, |
451 outfile_result, | 495 outfile_result, |
452 outfile_object=None, | 496 outfile_object=None, |
453 outfile_weights=None, | |
454 groups=None, | 497 groups=None, |
455 ref_seq=None, | 498 ref_seq=None, |
456 intervals=None, | 499 intervals=None, |
457 targets=None, | 500 targets=None, |
458 fasta_path=None, | 501 fasta_path=None, |
459 ): | 502 ): |
460 """ | 503 """ |
461 Parameter | 504 Parameter |
462 --------- | 505 --------- |
463 inputs : str | 506 inputs : str |
464 File path to galaxy tool parameter | 507 File path to galaxy tool parameter. |
465 | 508 |
466 infile_estimator : str | 509 infile_estimator : str |
467 File path to estimator | 510 File path to estimator. |
468 | 511 |
469 infile1 : str | 512 infile1 : str |
470 File path to dataset containing features | 513 File path to dataset containing features |
471 | 514 |
472 infile2 : str | 515 infile2 : str |
475 outfile_result : str | 518 outfile_result : str |
476 File path to save the results, either cv_results or test result | 519 File path to save the results, either cv_results or test result |
477 | 520 |
478 outfile_object : str, optional | 521 outfile_object : str, optional |
479 File path to save searchCV object | 522 File path to save searchCV object |
480 | |
481 outfile_weights : str, optional | |
482 File path to save model weights | |
483 | 523 |
484 groups : str | 524 groups : str |
485 File path to dataset containing groups labels | 525 File path to dataset containing groups labels |
486 | 526 |
487 ref_seq : str | 527 ref_seq : str |
503 | 543 |
504 with open(inputs, "r") as param_handler: | 544 with open(inputs, "r") as param_handler: |
505 params = json.load(param_handler) | 545 params = json.load(param_handler) |
506 | 546 |
507 # Override the refit parameter | 547 # Override the refit parameter |
508 params["search_schemes"]["options"]["refit"] = ( | 548 params["options"]["refit"] = ( |
509 True if params["save"] != "nope" else False | 549 True |
550 if ( | |
551 params["save"] != "nope" | |
552 or params["outer_split"]["split_mode"] == "nested_cv" | |
553 ) | |
554 else False | |
510 ) | 555 ) |
511 | 556 |
512 with open(infile_estimator, "rb") as estimator_handler: | 557 estimator = load_model_from_h5(infile_estimator) |
513 estimator = load_model(estimator_handler) | 558 |
514 | 559 estimator = clean_params(estimator) |
515 optimizer = params["search_schemes"]["selected_search_scheme"] | 560 |
516 optimizer = getattr(model_selection, optimizer) | 561 if estimator.__class__.__name__ == "KerasGBatchClassifier": |
562 _fit_and_score = try_get_attr( | |
563 "galaxy_ml.model_validations", | |
564 "_fit_and_score", | |
565 ) | |
566 | |
567 setattr(_search, "_fit_and_score", _fit_and_score) | |
568 setattr(_validation, "_fit_and_score", _fit_and_score) | |
569 | |
570 search_algos_and_options = params["search_algos"] | |
571 optimizer = search_algos_and_options.pop("selected_search_algo") | |
572 if optimizer == "skopt.BayesSearchCV": | |
573 optimizer = BayesSearchCV | |
574 else: | |
575 optimizer = getattr(model_selection, optimizer) | |
517 | 576 |
518 # handle gridsearchcv options | 577 # handle gridsearchcv options |
519 options = params["search_schemes"]["options"] | 578 options = params["options"] |
579 options.update(search_algos_and_options) | |
520 | 580 |
521 if groups: | 581 if groups: |
522 header = ( | 582 header = ( |
523 "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None | 583 "infer" if (options["cv_selector"]["groups_selector"]["header_g"]) else None |
524 ) | 584 ) |
551 parse_dates=True, | 611 parse_dates=True, |
552 ) | 612 ) |
553 groups = groups.ravel() | 613 groups = groups.ravel() |
554 options["cv_selector"]["groups_selector"] = groups | 614 options["cv_selector"]["groups_selector"] = groups |
555 | 615 |
556 splitter, groups = get_cv(options.pop("cv_selector")) | 616 cv_selector = options.pop("cv_selector") |
617 if Version(galaxy_ml_version) < Version("0.8.3"): | |
618 cv_selector.pop("n_stratification_bins", None) | |
619 splitter, groups = get_cv(cv_selector) | |
557 options["cv"] = splitter | 620 options["cv"] = splitter |
558 primary_scoring = options["scoring"]["primary_scoring"] | 621 primary_scoring = options["scoring"]["primary_scoring"] |
559 # get_scoring() expects secondary_scoring to be a comma separated string (not a list) | 622 options["scoring"] = get_scoring(options["scoring"]) |
560 # Check if secondary_scoring is specified | 623 # TODO make BayesSearchCV support multiple scoring |
561 secondary_scoring = options["scoring"].get("secondary_scoring", None) | 624 if optimizer == "skopt.BayesSearchCV" and isinstance(options["scoring"], dict): |
562 if secondary_scoring is not None: | 625 options["scoring"] = options["scoring"][primary_scoring] |
563 # If secondary_scoring is specified, convert the list into a comma-separated string | 626 warnings.warn( |
564 options["scoring"]["secondary_scoring"] = ",".join( | 627 "BayesSearchCV doesn't support multiple " |
565 options["scoring"]["secondary_scoring"] | 628 "scorings! Primary scoring is used." |
566 ) | 629 ) |
567 options["scoring"] = get_scoring(options["scoring"]) | |
568 if options["error_score"]: | 630 if options["error_score"]: |
569 options["error_score"] = "raise" | 631 options["error_score"] = "raise" |
570 else: | 632 else: |
571 options["error_score"] = np.nan | 633 options["error_score"] = np.NaN |
572 if options["refit"] and isinstance(options["scoring"], dict): | 634 if options["refit"] and isinstance(options["scoring"], dict): |
573 options["refit"] = primary_scoring | 635 options["refit"] = primary_scoring |
574 if "pre_dispatch" in options and options["pre_dispatch"] == "": | 636 if "pre_dispatch" in options and options["pre_dispatch"] == "": |
575 options["pre_dispatch"] = None | 637 options["pre_dispatch"] = None |
576 | 638 |
577 params_builder = params["search_schemes"]["search_params_builder"] | 639 params_builder = params["search_params_builder"] |
578 param_grid = _eval_search_params(params_builder) | 640 param_grid = _eval_search_params(params_builder) |
579 | |
580 estimator = clean_params(estimator) | |
581 | 641 |
582 # save the SearchCV object without fit | 642 # save the SearchCV object without fit |
583 if params["save"] == "save_no_fit": | 643 if params["save"] == "save_no_fit": |
584 searcher = optimizer(estimator, param_grid, **options) | 644 searcher = optimizer(estimator, param_grid, **options) |
585 print(searcher) | 645 dump_model_to_h5(searcher, outfile_object) |
586 with open(outfile_object, "wb") as output_handler: | |
587 pickle.dump(searcher, output_handler, pickle.HIGHEST_PROTOCOL) | |
588 return 0 | 646 return 0 |
589 | 647 |
590 # read inputs and load new attributes, like paths | 648 # read inputs and load new attributes, like paths |
591 estimator, X, y = _handle_X_y( | 649 estimator, X, y = _handle_X_y( |
592 estimator, | 650 estimator, |
598 intervals=intervals, | 656 intervals=intervals, |
599 targets=targets, | 657 targets=targets, |
600 fasta_path=fasta_path, | 658 fasta_path=fasta_path, |
601 ) | 659 ) |
602 | 660 |
661 label_encoder = LabelEncoder() | |
662 if get_main_estimator(estimator).__class__.__name__ == "XGBClassifier": | |
663 y = label_encoder.fit_transform(y) | |
664 | |
603 # cache iraps_core fits could increase search speed significantly | 665 # cache iraps_core fits could increase search speed significantly |
604 memory = joblib.Memory(location=CACHE_DIR, verbose=0) | 666 memory = joblib.Memory(location=CACHE_DIR, verbose=0) |
605 main_est = get_main_estimator(estimator) | 667 estimator = _set_memory(estimator, memory) |
606 if main_est.__class__.__name__ == "IRAPSClassifier": | |
607 main_est.set_params(memory=memory) | |
608 | 668 |
609 searcher = optimizer(estimator, param_grid, **options) | 669 searcher = optimizer(estimator, param_grid, **options) |
610 | 670 |
611 split_mode = params["outer_split"].pop("split_mode") | 671 split_mode = params["outer_split"].pop("split_mode") |
612 | 672 |
673 # Nested CV | |
613 if split_mode == "nested_cv": | 674 if split_mode == "nested_cv": |
614 # make sure refit is chosen | 675 cv_selector = params["outer_split"]["cv_selector"] |
615 # this could be True for sklearn models, but not the case for | 676 if Version(galaxy_ml_version) < Version("0.8.3"): |
616 # deep learning models | 677 cv_selector.pop("n_stratification_bins", None) |
617 if not options["refit"] and not all( | 678 outer_cv, _ = get_cv(cv_selector) |
618 hasattr(estimator, attr) for attr in ("config", "model_type") | |
619 ): | |
620 warnings.warn("Refit is change to `True` for nested validation!") | |
621 setattr(searcher, "refit", True) | |
622 | |
623 outer_cv, _ = get_cv(params["outer_split"]["cv_selector"]) | |
624 # nested CV, outer cv using cross_validate | 679 # nested CV, outer cv using cross_validate |
625 if options["error_score"] == "raise": | 680 if options["error_score"] == "raise": |
626 rval = cross_validate( | 681 rval = cross_validate( |
627 searcher, | 682 searcher, |
628 X, | 683 X, |
629 y, | 684 y, |
685 groups=groups, | |
630 scoring=options["scoring"], | 686 scoring=options["scoring"], |
631 cv=outer_cv, | 687 cv=outer_cv, |
632 n_jobs=N_JOBS, | 688 n_jobs=N_JOBS, |
633 verbose=options["verbose"], | 689 verbose=options["verbose"], |
690 fit_params={"groups": groups}, | |
634 return_estimator=(params["save"] == "save_estimator"), | 691 return_estimator=(params["save"] == "save_estimator"), |
635 error_score=options["error_score"], | 692 error_score=options["error_score"], |
636 return_train_score=True, | 693 return_train_score=True, |
637 ) | 694 ) |
638 else: | 695 else: |
641 try: | 698 try: |
642 rval = cross_validate( | 699 rval = cross_validate( |
643 searcher, | 700 searcher, |
644 X, | 701 X, |
645 y, | 702 y, |
703 groups=groups, | |
646 scoring=options["scoring"], | 704 scoring=options["scoring"], |
647 cv=outer_cv, | 705 cv=outer_cv, |
648 n_jobs=N_JOBS, | 706 n_jobs=N_JOBS, |
649 verbose=options["verbose"], | 707 verbose=options["verbose"], |
708 fit_params={"groups": groups}, | |
650 return_estimator=(params["save"] == "save_estimator"), | 709 return_estimator=(params["save"] == "save_estimator"), |
651 error_score=options["error_score"], | 710 error_score=options["error_score"], |
652 return_train_score=True, | 711 return_train_score=True, |
653 ) | 712 ) |
654 except ValueError: | 713 except ValueError: |
674 cv_results_ = pd.DataFrame(cv_results_) | 733 cv_results_ = pd.DataFrame(cv_results_) |
675 cv_results_ = cv_results_[sorted(cv_results_.columns)] | 734 cv_results_ = cv_results_[sorted(cv_results_.columns)] |
676 cv_results_.to_csv(target_path, sep="\t", header=True, index=False) | 735 cv_results_.to_csv(target_path, sep="\t", header=True, index=False) |
677 except Exception as e: | 736 except Exception as e: |
678 print(e) | 737 print(e) |
679 finally: | |
680 del os | |
681 | 738 |
682 keys = list(rval.keys()) | 739 keys = list(rval.keys()) |
683 for k in keys: | 740 for k in keys: |
684 if k.startswith("test"): | 741 if k.startswith("test"): |
685 rval["mean_" + k] = np.mean(rval[k]) | 742 rval["mean_" + k] = np.mean(rval[k]) |
687 if k.endswith("time"): | 744 if k.endswith("time"): |
688 rval.pop(k) | 745 rval.pop(k) |
689 rval = pd.DataFrame(rval) | 746 rval = pd.DataFrame(rval) |
690 rval = rval[sorted(rval.columns)] | 747 rval = rval[sorted(rval.columns)] |
691 rval.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False) | 748 rval.to_csv(path_or_buf=outfile_result, sep="\t", header=True, index=False) |
749 | |
750 return 0 | |
751 | |
692 # deprecate train test split mode | 752 # deprecate train test split mode |
693 """searcher = _do_train_test_split_val( | 753 """searcher = _do_train_test_split_val( |
694 searcher, X, y, params, | 754 searcher, X, y, params, |
695 primary_scoring=primary_scoring, | 755 primary_scoring=primary_scoring, |
696 error_score=options['error_score'], | 756 error_score=options['error_score'], |
697 groups=groups, | 757 groups=groups, |
698 outfile=outfile_result)""" | 758 outfile=outfile_result)""" |
699 return 0 | |
700 | 759 |
701 # no outer split | 760 # no outer split |
702 else: | 761 else: |
703 searcher.set_params(n_jobs=N_JOBS) | 762 searcher.set_params(n_jobs=N_JOBS) |
704 if options["error_score"] == "raise": | 763 if options["error_score"] == "raise": |
730 "'best_estimator_', because either it's " | 789 "'best_estimator_', because either it's " |
731 "nested gridsearch or `refit` is False!" | 790 "nested gridsearch or `refit` is False!" |
732 ) | 791 ) |
733 return | 792 return |
734 | 793 |
735 # clean params | 794 dump_model_to_h5(best_estimator_, outfile_object) |
736 best_estimator_ = clean_params(best_estimator_) | |
737 | |
738 main_est = get_main_estimator(best_estimator_) | |
739 | |
740 if hasattr(main_est, "model_") and hasattr(main_est, "save_weights"): | |
741 if outfile_weights: | |
742 main_est.save_weights(outfile_weights) | |
743 del main_est.model_ | |
744 del main_est.fit_params | |
745 del main_est.model_class_ | |
746 del main_est.validation_data | |
747 if getattr(main_est, "data_generator_", None): | |
748 del main_est.data_generator_ | |
749 | |
750 with open(outfile_object, "wb") as output_handler: | |
751 print("Best estimator is saved: %s " % repr(best_estimator_)) | |
752 pickle.dump(best_estimator_, output_handler, pickle.HIGHEST_PROTOCOL) | |
753 | 795 |
754 | 796 |
755 if __name__ == "__main__": | 797 if __name__ == "__main__": |
756 aparser = argparse.ArgumentParser() | 798 aparser = argparse.ArgumentParser() |
757 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 799 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
758 aparser.add_argument("-e", "--estimator", dest="infile_estimator") | 800 aparser.add_argument("-e", "--estimator", dest="infile_estimator") |
759 aparser.add_argument("-X", "--infile1", dest="infile1") | 801 aparser.add_argument("-X", "--infile1", dest="infile1") |
760 aparser.add_argument("-y", "--infile2", dest="infile2") | 802 aparser.add_argument("-y", "--infile2", dest="infile2") |
761 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") | 803 aparser.add_argument("-O", "--outfile_result", dest="outfile_result") |
762 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") | 804 aparser.add_argument("-o", "--outfile_object", dest="outfile_object") |
763 aparser.add_argument("-w", "--outfile_weights", dest="outfile_weights") | |
764 aparser.add_argument("-g", "--groups", dest="groups") | 805 aparser.add_argument("-g", "--groups", dest="groups") |
765 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 806 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
766 aparser.add_argument("-b", "--intervals", dest="intervals") | 807 aparser.add_argument("-b", "--intervals", dest="intervals") |
767 aparser.add_argument("-t", "--targets", dest="targets") | 808 aparser.add_argument("-t", "--targets", dest="targets") |
768 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 809 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
769 args = aparser.parse_args() | 810 args = aparser.parse_args() |
770 | 811 |
771 main( | 812 main(**vars(args)) |
772 args.inputs, | |
773 args.infile_estimator, | |
774 args.infile1, | |
775 args.infile2, | |
776 args.outfile_result, | |
777 outfile_object=args.outfile_object, | |
778 outfile_weights=args.outfile_weights, | |
779 groups=args.groups, | |
780 ref_seq=args.ref_seq, | |
781 intervals=args.intervals, | |
782 targets=args.targets, | |
783 fasta_path=args.fasta_path, | |
784 ) |