comparison model_prediction.py @ 37:913bf1c4c7bb draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author bgruening
date Sat, 01 May 2021 01:18:07 +0000
parents 616a241c5b37
children fe181d613429
comparison
equal deleted inserted replaced
36:616a241c5b37 37:913bf1c4c7bb
61 if isinstance(estimator, Pipeline): 61 if isinstance(estimator, Pipeline):
62 main_est = estimator.steps[-1][-1] 62 main_est = estimator.steps[-1][-1]
63 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"): 63 if hasattr(main_est, "config") and hasattr(main_est, "load_weights"):
64 if not infile_weights or infile_weights == "None": 64 if not infile_weights or infile_weights == "None":
65 raise ValueError( 65 raise ValueError(
66 "The selected model skeleton asks for weights, " "but dataset for weights wan not selected!" 66 "The selected model skeleton asks for weights, "
67 "but dataset for weights wan not selected!"
67 ) 68 )
68 main_est.load_weights(infile_weights) 69 main_est.load_weights(infile_weights)
69 70
70 # handle data input 71 # handle data input
71 input_type = params["input_options"]["selected_input"] 72 input_type = params["input_options"]["selected_input"]
72 # tabular input 73 # tabular input
73 if input_type == "tabular": 74 if input_type == "tabular":
74 header = "infer" if params["input_options"]["header1"] else None 75 header = "infer" if params["input_options"]["header1"] else None
75 column_option = params["input_options"]["column_selector_options_1"]["selected_column_selector_option"] 76 column_option = params["input_options"]["column_selector_options_1"][
77 "selected_column_selector_option"
78 ]
76 if column_option in [ 79 if column_option in [
77 "by_index_number", 80 "by_index_number",
78 "all_but_by_index_number", 81 "all_but_by_index_number",
79 "by_header_name", 82 "by_header_name",
80 "all_but_by_header_name", 83 "all_but_by_header_name",
120 klass = try_get_attr("galaxy_ml.preprocessors", seq_type) 123 klass = try_get_attr("galaxy_ml.preprocessors", seq_type)
121 124
122 pred_data_generator = klass(fasta_path, seq_length=seq_length) 125 pred_data_generator = klass(fasta_path, seq_length=seq_length)
123 126
124 if params["method"] == "predict": 127 if params["method"] == "predict":
125 preds = estimator.predict(X, data_generator=pred_data_generator, steps=steps) 128 preds = estimator.predict(
126 else: 129 X, data_generator=pred_data_generator, steps=steps
127 preds = estimator.predict_proba(X, data_generator=pred_data_generator, steps=steps) 130 )
131 else:
132 preds = estimator.predict_proba(
133 X, data_generator=pred_data_generator, steps=steps
134 )
128 135
129 # vcf input 136 # vcf input
130 elif input_type == "variant_effect": 137 elif input_type == "variant_effect":
131 klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator") 138 klass = try_get_attr("galaxy_ml.preprocessors", "GenomicVariantBatchGenerator")
132 139
133 options = params["input_options"] 140 options = params["input_options"]
134 options.pop("selected_input") 141 options.pop("selected_input")
135 if options["blacklist_regions"] == "none": 142 if options["blacklist_regions"] == "none":
136 options["blacklist_regions"] = None 143 options["blacklist_regions"] = None
137 144
138 pred_data_generator = klass(ref_genome_path=ref_seq, vcf_path=vcf_path, **options) 145 pred_data_generator = klass(
146 ref_genome_path=ref_seq, vcf_path=vcf_path, **options
147 )
139 148
140 pred_data_generator.set_processing_attrs() 149 pred_data_generator.set_processing_attrs()
141 150
142 variants = pred_data_generator.variants 151 variants = pred_data_generator.variants
143 152