Mercurial > repos > bgruening > sklearn_model_validation
comparison model_prediction.py @ 30:4b359039f09f draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author | bgruening |
---|---|
date | Sat, 01 May 2021 01:03:56 +0000 |
parents | de360b57a5ab |
children | 1fe00785190d |
comparison
equal
deleted
inserted
replaced
29:de360b57a5ab | 30:4b359039f09f |
---|---|
61 if isinstance(estimator, Pipeline): | 61 if isinstance(estimator, Pipeline): |
62 main_est = estimator.steps[-1][-1] | 62 main_est = estimator.steps[-1][-1] |
63 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): | 63 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): |
64 if not infile_weights or infile_weights == "None": | 64 if not infile_weights or infile_weights == "None": |
65 raise ValueError( | 65 raise ValueError( |
66 "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!" | 66 "The selected model skeleton asks for weights, " |
| | 67 "but dataset for weights wan not selected!" |
67 ) | 68 ) |
68 main_est.load_weights(infile_weights) | 69 main_est.load_weights(infile_weights) |
69 | 70 |
70 # handle data input | 71 # handle data input |
71 input_type = params["input_options"]["selected_input"] | 72 input_type = params["input_options"]["selected_input"] |
72 # tabular input | 73 # tabular input |
73 if input_type == "tabular": | 74 if input_type == "tabular": |
74 header = "infer" if params["input_options"]["header1"] else None | 75 header = "infer" if params["input_options"]["header1"] else None |
75 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] | 76 column_option = params["input_options"]["column_selector_options_1"][ |
| | 77 "selected_column_selector_option" |
| | 78 ] |
76 if column_option in [ | 79 if column_option in [ |
77 "by_index_number", | 80 "by_index_number", |
78 "all_but_by_index_number", | 81 "all_but_by_index_number", |
79 "by_header_name", | 82 "by_header_name", |
80 "all_but_by_header_name", | 83 "all_but_by_header_name", |
120 klass = try_get_attr("galaxy_ml.preprocessors", seq_type) | 123 klass = try_get_attr("galaxy_ml.preprocessors", seq_type) |
121 | 124 |
122 pred_data_generator = klass(fasta_path, seq_length=seq_length) | 125 pred_data_generator = klass(fasta_path, seq_length=seq_length) |
123 | 126 |
124 if params["method"] == "predict": | 127 if params["method"] == "predict": |
125 preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps) | 128 preds = estimator.predict( |
126 else: | 129 X, data_generator=pred_data_generator, steps=steps |
127 preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps) | 130 ) |
| | 131 else: |
| | 132 preds = estimator.predict_proba( |
| | 133 X, data_generator=pred_data_generator, steps=steps |
| | 134 ) |
128 | 135 |
129 # vcf input | 136 # vcf input |
130 elif input_type == "variant_effect": | 137 elif input_type == "variant_effect": |
131 klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") | 138 klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") |
132 | 139 |
133 options = params["input_options"] | 140 options = params["input_options"] |
134 options.pop("selected_input") | 141 options.pop("selected_input") |
135 if options["blacklist_regions"] == "none": | 142 if options["blacklist_regions"] == "none": |
136 options["blacklist_regions"] = None | 143 options["blacklist_regions"] = None |
137 | 144 |
138 pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options) | 145 pred_data_generator = klass( |
| | 146 ref_genome_path=ref_seq, vcf_path=vcf_path, **options |
| | 147 ) |
139 | 148 |
140 pred_data_generator.set_processing_attrs() | 149 pred_data_generator.set_processing_attrs() |
141 | 150 |
142 variants = pred_data_generator.variants | 151 variants = pred_data_generator.variants |
143 | 152 |