Mercurial > repos > bgruening > sklearn_train_test_eval
comparison ml_visualization_ex.py @ 1:cc49634df38f draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author | bgruening |
---|---|
date | Fri, 13 Sep 2019 12:08:44 -0400 |
parents | |
children | e23cfe4be9d4 |
comparison
equal
deleted
inserted
replaced
0:68aaa903052a | 1:cc49634df38f |
---|---|
1 import argparse | |
2 import json | |
3 import numpy as np | |
4 import pandas as pd | |
5 import plotly | |
6 import plotly.graph_objs as go | |
7 import warnings | |
8 | |
9 from keras.models import model_from_json | |
10 from keras.utils import plot_model | |
11 from sklearn.feature_selection.base import SelectorMixin | |
12 from sklearn.metrics import precision_recall_curve, average_precision_score | |
13 from sklearn.metrics import roc_curve, auc | |
14 from sklearn.pipeline import Pipeline | |
15 from galaxy_ml.utils import load_model, read_columns, SafeEval | |
16 | |
17 | |
# Shared SafeEval instance; main() uses it to evaluate the user-supplied
# `steps` parameter string for the 'rfecv_gridscores' plot.
safe_eval = SafeEval()
19 | |
20 | |
def _curve_trace_name(key, area):
    """Return the legend label for one curve ('micro' or a column key)."""
    if key == 'micro':
        return '%s (area = %.2f)' % (key, area)
    return 'column %s (area = %.2f)' % (key, area)


def _plot_feature_importances(selection, title, infile_estimator, infile1):
    """Build a bar chart of an estimator's coefficients / importances."""
    with open(infile_estimator, 'rb') as estimator_handler:
        estimator = load_model(estimator_handler)

    column_option = (selection['column_selector_options']
                     ['selected_column_selector_option'])
    if column_option in ('by_index_number', 'all_but_by_index_number',
                         'by_header_name', 'all_but_by_header_name'):
        c = selection['column_selector_options']['col1']
    else:
        c = None

    _, input_df = read_columns(infile1, c=c,
                               c_option=column_option,
                               return_df=True,
                               sep='\t', header='infer',
                               parse_dates=True)

    feature_names = input_df.columns.values

    if isinstance(estimator, Pipeline):
        # Drop the names of features removed by selector steps, then keep
        # only the final step, which is the one inspected below.
        for st in estimator.steps[:-1]:
            if isinstance(st[-1], SelectorMixin):
                mask = st[-1].get_support()
                feature_names = feature_names[mask]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, 'coef_'):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, 'feature_importances_', None)
    if coefs is None:
        raise RuntimeError('The classifier does not expose '
                           '"coef_" or "feature_importances_" '
                           'attributes')

    # A single-row 2-D coefficient array (e.g. scikit-learn binary linear
    # classifiers, shape (1, n_features)) is flattened so that it lines up
    # one-to-one with the 1-D feature_names array.
    if np.ndim(coefs) == 2 and np.shape(coefs)[0] == 1:
        coefs = np.ravel(coefs)

    threshold = selection['threshold']
    if threshold is not None:
        # Keep only features whose value lies outside [-threshold, threshold].
        mask = (coefs > threshold) | (coefs < -threshold)
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # Sort in descending order of value.
    indices = np.argsort(coefs)[::-1]

    trace = go.Bar(x=feature_names[indices],
                   y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    return go.Figure(data=[trace], layout=layout)


def _plot_pr_curve(selection, title, infile1, infile2):
    """Build per-column (and micro-averaged) precision-recall curves."""
    df1 = pd.read_csv(infile1, sep='\t', header=None)
    df2 = pd.read_csv(infile2, sep='\t', header=None)

    precision = {}
    recall = {}
    ap = {}

    # NOTE(review): pos_label stays a string here while pandas may parse the
    # label columns as integers -- confirm the two compare as intended.
    pos_label = selection['pos_label'].strip() or None
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values

        precision[col], recall[col], _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label)
        ap[col] = average_precision_score(
            y_true, y_score, pos_label=pos_label or 1)

    if len(df1.columns) > 1:
        precision["micro"], recall["micro"], _ = precision_recall_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        ap['micro'] = average_precision_score(
            df1.values, df2.values, average='micro',
            pos_label=pos_label or 1)

    data = [
        go.Scatter(x=recall[key], y=precision[key], mode='lines',
                   name=_curve_trace_name(key, ap[key]))
        for key in precision
    ]

    layout = go.Layout(
        title=title or "Precision-Recall curve",
        xaxis=dict(title='Recall'),
        yaxis=dict(title='Precision')
    )
    return go.Figure(data=data, layout=layout)


def _plot_roc_curve(selection, title, infile1, infile2):
    """Build per-column (and micro-averaged) ROC curves."""
    df1 = pd.read_csv(infile1, sep='\t', header=None)
    df2 = pd.read_csv(infile2, sep='\t', header=None)

    fpr = {}
    tpr = {}
    roc_auc = {}

    # NOTE(review): pos_label stays a string here while pandas may parse the
    # label columns as integers -- confirm the two compare as intended.
    pos_label = selection['pos_label'].strip() or None
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values

        fpr[col], tpr[col], _ = roc_curve(
            y_true, y_score, pos_label=pos_label)
        roc_auc[col] = auc(fpr[col], tpr[col])

    if len(df1.columns) > 1:
        fpr["micro"], tpr["micro"], _ = roc_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        roc_auc['micro'] = auc(fpr["micro"], tpr["micro"])

    data = [
        go.Scatter(x=fpr[key], y=tpr[key], mode='lines',
                   name=_curve_trace_name(key, roc_auc[key]))
        for key in fpr
    ]

    # Dashed diagonal reference line from (0, 0) to (1, 1).
    data.append(go.Scatter(x=[0, 1], y=[0, 1],
                           mode='lines',
                           line=dict(color='black', dash='dash'),
                           showlegend=False))

    layout = go.Layout(
        title=title or "Receiver operating characteristic curve",
        xaxis=dict(title='False Positive Rate'),
        yaxis=dict(title='True Positive Rate')
    )
    return go.Figure(data=data, layout=layout)


def _plot_rfecv_gridscores(selection, title, infile1):
    """Build a line plot of the scores in the first column of infile1."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    scores = input_df.iloc[:, 0]
    steps = selection['steps'].strip()
    steps = safe_eval(steps)

    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        text=[str(step) for step in steps] if steps else None,
        mode='lines'
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=title or None
    )
    return go.Figure(data=[data], layout=layout)


def _plot_learning_curve(selection, title, infile1):
    """Build train/test score curves, optionally with std error bars."""
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    plot_std_err = selection['plot_std_err']
    data1 = go.Scatter(
        x=input_df['train_sizes_abs'],
        y=input_df['mean_train_scores'],
        error_y=dict(
            array=input_df['std_train_scores']
        ) if plot_std_err else None,
        mode='lines',
        name="Train Scores",
    )
    data2 = go.Scatter(
        x=input_df['train_sizes_abs'],
        y=input_df['mean_test_scores'],
        error_y=dict(
            array=input_df['std_test_scores']
        ) if plot_std_err else None,
        mode='lines',
        name="Test Scores",
    )
    layout = dict(
        xaxis=dict(
            title='No. of samples'
        ),
        yaxis=dict(
            title='Performance Score'
        ),
        title=title or 'Learning Curve'
    )
    return go.Figure(data=[data1, data2], layout=layout)


def _plot_keras_model(model_config):
    """Render the Keras model described by a JSON config to 'output.png'."""
    with open(model_config, 'r') as f:
        model_str = f.read()
    model = model_from_json(model_str)
    plot_model(model, to_file="output.png")


def main(inputs, infile_estimator=None, infile1=None,
         infile2=None, outfile_result=None,
         outfile_object=None, groups=None,
         ref_seq=None, intervals=None,
         targets=None, fasta_path=None,
         model_config=None):
    """
    Create the plot selected in the tool parameters and save it as the
    file `output` in the current working directory.

    The plot type is read from ``params['plotting_selection']['plot_type']``
    and must be one of 'feature_importances', 'pr_curve', 'roc_curve',
    'rfecv_gridscores', 'learning_curve' or 'keras_plot_model'.

    Parameter
    ---------
    inputs : str
        File path to galaxy tool parameter (JSON)

    infile_estimator : str, default is None
        File path to estimator (used by 'feature_importances')

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        File path to save the results, either cv_results or test result
        (not used by this function)

    outfile_object : str, default is None
        File path to save searchCV object (not used by this function)

    groups : str, default is None
        File path to dataset containing groups labels
        (not used by this function)

    ref_seq : str, default is None
        File path to dataset containing genome sequence file
        (not used by this function)

    intervals : str, default is None
        File path to dataset containing interval file
        (not used by this function)

    targets : str, default is None
        File path to dataset compressed target bed file
        (not used by this function)

    fasta_path : str, default is None
        File path to dataset containing fasta file
        (not used by this function)

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks
        (used by 'keras_plot_model')

    Returns
    -------
    0 for 'keras_plot_model', None for every other plot type.

    Raises
    ------
    ValueError
        If the plot type is not one of the supported values.
    RuntimeError
        If the estimator for 'feature_importances' exposes neither
        "coef_" nor "feature_importances_".
    """
    import os

    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    selection = params['plotting_selection']
    title = selection['title'].strip()
    plot_type = selection['plot_type']

    if plot_type == 'feature_importances':
        fig = _plot_feature_importances(selection, title,
                                        infile_estimator, infile1)
    elif plot_type == 'pr_curve':
        fig = _plot_pr_curve(selection, title, infile1, infile2)
    elif plot_type == 'roc_curve':
        fig = _plot_roc_curve(selection, title, infile1, infile2)
    elif plot_type == 'rfecv_gridscores':
        fig = _plot_rfecv_gridscores(selection, title, infile1)
    elif plot_type == 'learning_curve':
        fig = _plot_learning_curve(selection, title, infile1)
    elif plot_type == 'keras_plot_model':
        _plot_keras_model(model_config)
        os.rename('output.png', 'output')
        return 0
    else:
        # Previously this fell through and failed on an unbound `fig`.
        raise ValueError('Unsupported plot type: %r' % (plot_type,))

    plotly.offline.plot(fig, filename="output.html",
                        auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')
284 | |
if __name__ == '__main__':
    # Optional command-line options as (short flag, long flag, dest).
    # Every dest equals the name of the matching main() parameter, so the
    # parsed namespace can be passed straight through as keyword arguments.
    optional_flags = (
        ("-e", "--estimator", "infile_estimator"),
        ("-X", "--infile1", "infile1"),
        ("-y", "--infile2", "infile2"),
        ("-O", "--outfile_result", "outfile_result"),
        ("-o", "--outfile_object", "outfile_object"),
        ("-g", "--groups", "groups"),
        ("-r", "--ref_seq", "ref_seq"),
        ("-b", "--intervals", "intervals"),
        ("-t", "--targets", "targets"),
        ("-f", "--fasta_path", "fasta_path"),
        ("-c", "--model_config", "model_config"),
    )

    parser = argparse.ArgumentParser()
    # The tool parameter file is the only mandatory option.
    parser.add_argument("-i", "--inputs", dest="inputs", required=True)
    for short_flag, long_flag, dest_name in optional_flags:
        parser.add_argument(short_flag, long_flag, dest=dest_name)
    args = parser.parse_args()

    main(**vars(args))