Mercurial > repos > bgruening > sklearn_pca
comparison ml_visualization_ex.py @ 0:2d7016b3ae92 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 2afb24f3c81d625312186750a714d702363012b5"
author | bgruening |
---|---|
date | Fri, 02 Oct 2020 08:45:21 +0000 |
parents | |
children | 132805688fa3 |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:2d7016b3ae92 |
---|---|
1 import argparse | |
2 import json | |
3 import matplotlib | |
4 import matplotlib.pyplot as plt | |
5 import numpy as np | |
6 import os | |
7 import pandas as pd | |
8 import plotly | |
9 import plotly.graph_objs as go | |
10 import warnings | |
11 | |
12 from keras.models import model_from_json | |
13 from keras.utils import plot_model | |
14 from sklearn.feature_selection.base import SelectorMixin | |
15 from sklearn.metrics import precision_recall_curve, average_precision_score | |
16 from sklearn.metrics import roc_curve, auc, confusion_matrix | |
17 from sklearn.pipeline import Pipeline | |
18 from galaxy_ml.utils import load_model, read_columns, SafeEval | |
19 | |
20 | |
# Shared galaxy_ml SafeEval instance; used below to turn the RFECV ``steps``
# parameter string into a Python value.
safe_eval = SafeEval()

# Plotly's default qualitative palette, in plotly's own order:
# muted blue, safety orange, cooked asparagus green, brick red, muted purple,
# chestnut brown, raspberry yogurt pink, middle gray, curry yellow-green,
# blue-teal.
default_colors = [
    '#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd',
    '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf',
]
36 | |
37 | |
def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
    """Write one precision-recall curve per column pair as plotly HTML.

    Parameters
    ----------
    df1 : pandas.DataFrame
        True labels; column ``i`` is paired with column ``i`` of ``df2``.
    df2 : pandas.DataFrame
        Predicted scores, with the same column layout as ``df1``.
    pos_label : str or None
        Label of the positive class. It is forwarded unchanged to
        ``precision_recall_curve``; ``average_precision_score`` receives
        ``1`` when it is falsy.
    title : str, optional
        Plot title; ``'Precision-Recall Curve'`` when empty or None.

    The HTML is written to ``output.html`` and then renamed to ``output``
    in the current working directory.
    """
    n_palette = len(default_colors)
    traces = []
    for col in range(df1.shape[1]):
        truth = df1.iloc[:, col].values
        scores = df2.iloc[:, col].values

        precision, recall, _ = precision_recall_curve(
            truth, scores, pos_label=pos_label)
        avg_precision = average_precision_score(
            truth, scores, pos_label=pos_label or 1)

        traces.append(go.Scatter(
            x=recall,
            y=precision,
            mode='lines',
            marker=dict(color=default_colors[col % n_palette]),
            name='%s (area = %.3f)' % (col, avg_precision),
        ))

    axis_style = {'linecolor': 'lightslategray', 'linewidth': 1}
    layout = go.Layout(
        xaxis=dict(title='Recall', **axis_style),
        yaxis=dict(title='Precision', **axis_style),
        title=dict(text=title or 'Precision-Recall Curve',
                   x=0.5, y=0.92, xanchor='center', yanchor='top'),
        font=dict(family="sans-serif", size=11),
        # fully transparent plot background
        plot_bgcolor='rgba(255,255,255,0)',
    )

    fig = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # rename so Galaxy's `from_work_dir` discovers the result
    os.rename('output.html', 'output')
117 | |
def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
    """Visualize precision-recall curves with matplotlib and save an SVG.

    Parameters
    ----------
    df1 : pandas.DataFrame
        True labels; column ``i`` is paired with column ``i`` of ``df2``.
    df2 : pandas.DataFrame
        Predicted scores, with the same column layout as ``df1``.
    pos_label : str or None
        Label of the positive class. It is forwarded unchanged to
        ``precision_recall_curve``; ``average_precision_score`` receives
        ``1`` when it is falsy.
    title : str, optional
        Plot title; ``'Precision-Recall Curve'`` when empty or None.

    The SVG is written to ``output.svg`` in the current working directory
    and then renamed to ``output``.
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    try:
        plt.style.use('seaborn-colorblind')
    except OSError:
        # matplotlib 3.6 renamed the bundled seaborn styles and 3.8 removed
        # the old names.
        plt.style.use('seaborn-v0_8-colorblind')
    plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label)
        ap = average_precision_score(
            y_true, y_score, pos_label=pos_label or 1)

        # Color is given only via keyword; a format string such as 'r-'
        # would conflict with it and trigger a matplotlib warning.
        plt.step(recall, precision, color="black", alpha=0.3,
                 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap))

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(title or 'Precision-Recall Curve')
    # Without a legend the per-curve labels (with their AP) are never shown.
    plt.legend(loc="lower left")
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    plt.close()
    # rename so Galaxy's `from_work_dir` discovers the result
    os.rename(os.path.join(folder, "output.svg"),
              os.path.join(folder, "output"))
149 | |
150 | |
def visualize_roc_curve_plotly(df1, df2, pos_label,
                               drop_intermediate=True,
                               title=None):
    """Write one ROC curve per column pair to a plotly HTML file.

    Parameters
    ----------
    df1 : pandas.DataFrame
        True labels; column ``i`` is paired with column ``i`` of ``df2``.
    df2 : pandas.DataFrame
        Predicted scores, with the same column layout as ``df1``.
    pos_label : str or None
        Label of the positive class, forwarded to ``roc_curve``.
    drop_intermediate : bool
        Forwarded to ``roc_curve``: whether to drop suboptimal thresholds.
    title : str, optional
        Plot title; a default ROC title is used when empty or None.

    The HTML is written to ``output.html`` and then renamed to ``output``
    in the current working directory.
    """
    def _trace(col):
        # Build the scatter trace for column pair ``col``.
        fpr, tpr, _ = roc_curve(df1.iloc[:, col].values,
                                df2.iloc[:, col].values,
                                pos_label=pos_label,
                                drop_intermediate=drop_intermediate)
        area = auc(fpr, tpr)
        return go.Scatter(
            x=fpr,
            y=tpr,
            mode='lines',
            marker=dict(color=default_colors[col % len(default_colors)]),
            name='%s (area = %.3f)' % (col, area),
        )

    traces = [_trace(col) for col in range(df1.shape[1])]

    axis_style = {'linecolor': 'lightslategray', 'linewidth': 1}
    heading = title or 'Receiver Operating Characteristic (ROC) Curve'
    layout = go.Layout(
        xaxis=dict(title='False Positive Rate', **axis_style),
        yaxis=dict(title='True Positive Rate', **axis_style),
        title=dict(text=heading, x=0.5, y=0.92,
                   xanchor='center', yanchor='top'),
        font=dict(family="sans-serif", size=11),
        # fully transparent plot background
        plot_bgcolor='rgba(255,255,255,0)',
    )

    fig = go.Figure(data=traces, layout=layout)
    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # rename so Galaxy's `from_work_dir` discovers the result
    os.rename('output.html', 'output')
233 | |
234 | |
def visualize_roc_curve_matplotlib(df1, df2, pos_label,
                                   drop_intermediate=True,
                                   title=None):
    """Visualize ROC curves with matplotlib and save an SVG image.

    Parameters
    ----------
    df1 : pandas.DataFrame
        True labels; column ``i`` is paired with column ``i`` of ``df2``.
    df2 : pandas.DataFrame
        Predicted scores, with the same column layout as ``df1``.
    pos_label : str or None
        Label of the positive class, forwarded to ``roc_curve``.
    drop_intermediate : bool
        Forwarded to ``roc_curve``: whether to drop suboptimal thresholds.
    title : str, optional
        Plot title; a default ROC title is used when empty or None.

    The SVG is written to ``output.svg`` in the current working directory
    and then renamed to ``output``.
    """
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    try:
        plt.style.use('seaborn-colorblind')
    except OSError:
        # matplotlib 3.6 renamed the bundled seaborn styles and 3.8 removed
        # the old names.
        plt.style.use('seaborn-v0_8-colorblind')
    plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label,
                                drop_intermediate=drop_intermediate)
        roc_auc = auc(fpr, tpr)

        # Color is given only via keyword; a format string such as 'r-'
        # would conflict with it and trigger a matplotlib warning.
        plt.step(fpr, tpr, color="black", alpha=0.3, lw=1,
                 where="post", label='%s (area = %.3f)' % (idx, roc_auc))

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(title or 'Receiver Operating Characteristic (ROC) Curve')
    # Without a legend the per-curve labels (with their AUC) are never shown.
    plt.legend(loc="lower right")
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    plt.close()
    # rename so Galaxy's `from_work_dir` discovers the result
    os.rename(os.path.join(folder, "output.svg"),
              os.path.join(folder, "output"))
267 | |
268 | |
def get_dataframe(file_path, plot_selection, header_name, column_name):
    """Load the columns picked by a Galaxy column selector from a TSV file.

    Parameters
    ----------
    file_path : str
        Path to the tab-separated input file.
    plot_selection : dict
        The ``plotting_selection`` section of the tool parameters.
    header_name : str
        Key in ``plot_selection`` whose truthiness says whether the file
        has a header row.
    column_name : str
        Key in ``plot_selection`` holding the column selector options
        (``selected_column_selector_option`` and, for some options,
        ``col1``).

    Returns
    -------
    pandas.DataFrame
        The frame returned by ``read_columns(..., return_df=True)``.
    """
    header = 'infer' if plot_selection[header_name] else None
    selector = plot_selection[column_name]
    option = selector["selected_column_selector_option"]

    # Only these selector modes carry an explicit column list in ``col1``.
    explicit_modes = ("by_index_number", "all_but_by_index_number",
                      "by_header_name", "all_but_by_header_name")
    columns = selector["col1"] if option in explicit_modes else None

    return read_columns(file_path, c=columns,
                        c_option=option,
                        return_df=True,
                        sep='\t', header=header,
                        parse_dates=True)[1]
282 | |
283 | |
def main(inputs, infile_estimator=None, infile1=None,
         infile2=None, outfile_result=None,
         outfile_object=None, groups=None,
         ref_seq=None, intervals=None,
         targets=None, fasta_path=None,
         model_config=None, true_labels=None,
         predicted_labels=None, plot_color=None,
         title=None):
    """Produce the plot selected in the Galaxy tool parameters.

    Every plot type writes its result to a file named ``output`` in the
    current working directory, which Galaxy collects via ``from_work_dir``.

    Parameters
    ----------
    inputs : str
        File path to the JSON file holding the Galaxy tool parameters.

    infile_estimator : str, default is None
        File path to estimator

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        File path to save the results, either cv_results or test result

    outfile_object : str, default is None
        File path to save searchCV object

    groups : str, default is None
        File path to dataset containing groups labels

    ref_seq : str, default is None
        File path to dataset containing genome sequence file

    intervals : str, default is None
        File path to dataset containing interval file

    targets : str, default is None
        File path to dataset compressed target bed file

    fasta_path : str, default is None
        File path to dataset containing fasta file

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks

    true_labels : str, default is None
        File path to dataset containing true labels

    predicted_labels : str, default is None
        File path to dataset containing predicted labels

    plot_color : str, default is None
        Matplotlib colormap of the confusion matrix heatmap

    title : str, default is None
        Plot title. Used only when the tool parameters carry no
        ``plotting_selection['title']`` entry.

    Returns
    -------
    int
        0 once the plot has been written.

    Raises
    ------
    RuntimeError
        If the estimator exposes no usable ``coef_`` /
        ``feature_importances_`` for the feature-importance plot.
    ValueError
        If ``plot_type`` is not supported.
    """
    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    selection = params['plotting_selection']
    # The title stored in the tool parameters wins; the ``title`` argument
    # is only a fallback when the parameters do not provide one.
    title = (selection.get('title', title) or '').strip()
    plot_type = selection['plot_type']
    plot_format = selection['plot_format']

    if plot_type == 'feature_importances':
        with open(infile_estimator, 'rb') as estimator_handler:
            estimator = load_model(estimator_handler)

        column_option = (selection['column_selector_options']
                         ['selected_column_selector_option'])
        if column_option in ['by_index_number', 'all_but_by_index_number',
                             'by_header_name', 'all_but_by_header_name']:
            c = selection['column_selector_options']['col1']
        else:
            c = None

        _, input_df = read_columns(infile1, c=c,
                                   c_option=column_option,
                                   return_df=True,
                                   sep='\t', header='infer',
                                   parse_dates=True)

        feature_names = input_df.columns.values

        # Drop the names of features removed by selector steps so that the
        # remaining names line up with the final estimator's coefficients.
        if isinstance(estimator, Pipeline):
            for st in estimator.steps[:-1]:
                if isinstance(st[-1], SelectorMixin):
                    mask = st[-1].get_support()
                    feature_names = feature_names[mask]
            estimator = estimator.steps[-1][-1]

        if hasattr(estimator, 'coef_'):
            coefs = estimator.coef_
        else:
            coefs = getattr(estimator, 'feature_importances_', None)
        if coefs is None:
            raise RuntimeError('The classifier does not expose '
                               '"coef_" or "feature_importances_" '
                               'attributes')

        coefs = np.asarray(coefs)
        # Binary linear classifiers expose ``coef_`` with shape
        # (1, n_features); flatten it so it aligns with ``feature_names``.
        if coefs.ndim == 2 and coefs.shape[0] == 1:
            coefs = coefs[0]
        if coefs.ndim != 1:
            raise RuntimeError(
                'Expected one importance value per feature, got an array '
                'of shape %s (multi-row "coef_" is not supported)'
                % (coefs.shape,))

        threshold = selection['threshold']
        if threshold is not None:
            # keep only features whose magnitude exceeds the threshold
            mask = (coefs > threshold) | (coefs < -threshold)
            coefs = coefs[mask]
            feature_names = feature_names[mask]

        # sort by importance, largest first
        indices = np.argsort(coefs)[::-1]

        trace = go.Bar(x=feature_names[indices],
                       y=coefs[indices])
        layout = go.Layout(title=title or "Feature Importances")
        fig = go.Figure(data=[trace], layout=layout)

        plotly.offline.plot(fig, filename="output.html",
                            auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename('output.html', 'output')

        return 0

    elif plot_type in ('pr_curve', 'roc_curve'):
        df1 = pd.read_csv(infile1, sep='\t', header='infer')
        df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32)

        minimum = selection['report_minimum_n_positives']
        # Filter out columns whose number of positives is below the
        # threshold. Columns of df1/df2 are paired by position, so the mask
        # is applied positionally (.values) rather than aligned by name.
        if minimum:
            mask = (df1.sum(axis=0) >= minimum).values
            df1 = df1.loc[:, mask]
            df2 = df2.loc[:, mask]

        pos_label = selection['pos_label'].strip() or None

        if plot_type == 'pr_curve':
            if plot_format == 'plotly_html':
                visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
            else:
                visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
        else:  # 'roc_curve'
            drop_intermediate = selection['drop_intermediate']
            if plot_format == 'plotly_html':
                visualize_roc_curve_plotly(df1, df2, pos_label,
                                           drop_intermediate=drop_intermediate,
                                           title=title)
            else:
                visualize_roc_curve_matplotlib(
                    df1, df2, pos_label,
                    drop_intermediate=drop_intermediate,
                    title=title)

        return 0

    elif plot_type == 'rfecv_gridscores':
        input_df = pd.read_csv(infile1, sep='\t', header='infer')
        scores = input_df.iloc[:, 0]
        steps = selection['steps'].strip()
        steps = safe_eval(steps)

        data = go.Scatter(
            x=list(range(len(scores))),
            y=scores,
            text=[str(_) for _ in steps] if steps else None,
            mode='lines'
        )
        layout = go.Layout(
            xaxis=dict(title="Number of features selected"),
            yaxis=dict(title="Cross validation score"),
            title=dict(
                text=title or None,
                x=0.5,
                y=0.92,
                xanchor='center',
                yanchor='top'
            ),
            font=dict(
                family="sans-serif",
                size=11
            ),
            # control background colors
            plot_bgcolor='rgba(255,255,255,0)'
        )

        fig = go.Figure(data=[data], layout=layout)
        plotly.offline.plot(fig, filename="output.html",
                            auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename('output.html', 'output')

        return 0

    elif plot_type == 'learning_curve':
        input_df = pd.read_csv(infile1, sep='\t', header='infer')
        plot_std_err = selection['plot_std_err']
        data1 = go.Scatter(
            x=input_df['train_sizes_abs'],
            y=input_df['mean_train_scores'],
            error_y=dict(
                array=input_df['std_train_scores']
            ) if plot_std_err else None,
            mode='lines',
            name="Train Scores",
        )
        data2 = go.Scatter(
            x=input_df['train_sizes_abs'],
            y=input_df['mean_test_scores'],
            error_y=dict(
                array=input_df['std_test_scores']
            ) if plot_std_err else None,
            mode='lines',
            name="Test Scores",
        )
        layout = dict(
            xaxis=dict(
                title='No. of samples'
            ),
            yaxis=dict(
                title='Performance Score'
            ),
            # modify these configurations to customize image
            title=dict(
                text=title or 'Learning Curve',
                x=0.5,
                y=0.92,
                xanchor='center',
                yanchor='top'
            ),
            font=dict(
                family="sans-serif",
                size=11
            ),
            # control background colors
            plot_bgcolor='rgba(255,255,255,0)'
        )

        fig = go.Figure(data=[data1, data2], layout=layout)
        plotly.offline.plot(fig, filename="output.html",
                            auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename('output.html', 'output')

        return 0

    elif plot_type == 'keras_plot_model':
        with open(model_config, 'r') as f:
            model_str = f.read()
        model = model_from_json(model_str)
        plot_model(model, to_file="output.png")
        os.rename('output.png', 'output')

        return 0

    elif plot_type == 'classification_confusion_matrix':
        # local import: only this branch needs it
        from sklearn.utils.multiclass import unique_labels

        plot_selection = selection
        input_true = get_dataframe(true_labels, plot_selection,
                                   "header_true",
                                   "column_selector_options_true")
        header_predicted = ('infer' if plot_selection["header_predicted"]
                            else None)
        input_predicted = pd.read_csv(predicted_labels, sep='\t',
                                      parse_dates=True,
                                      header=header_predicted)
        true_classes = input_true.iloc[:, -1].copy()
        predicted_classes = input_predicted.iloc[:, -1].copy()
        # Sorted union of labels seen in either column: this is the row and
        # column order of the matrix (confusion_matrix's own default order).
        axis_labels = unique_labels(true_classes, predicted_classes)
        c_matrix = confusion_matrix(true_classes, predicted_classes,
                                    labels=axis_labels)
        n_labels = len(axis_labels)

        fig, ax = plt.subplots(figsize=(7, 7))
        im = ax.imshow(c_matrix, cmap=plot_color)
        for i in range(n_labels):
            for j in range(n_labels):
                ax.text(j, i, c_matrix[i, j], ha="center", va="center",
                        color="k")
        ax.set_ylabel('True class labels')
        ax.set_xlabel('Predicted class labels')
        ax.set_title(title)
        # Ticks sit at matrix positions 0..n-1 and carry the class names, so
        # string or non-contiguous labels are placed correctly.
        ax.set_xticks(range(n_labels))
        ax.set_xticklabels(axis_labels)
        ax.set_yticks(range(n_labels))
        ax.set_yticklabels(axis_labels)
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        fig.savefig("output.png", dpi=125)
        plt.close(fig)
        os.rename('output.png', 'output')

        return 0

    # No branch matched: fail loudly instead of silently producing nothing.
    raise ValueError('Unsupported plot_type: %r' % (plot_type,))
604 | |
605 | |
if __name__ == '__main__':
    # Optional command-line flags as (flags, dest) pairs. Every dest is also
    # the name of a keyword parameter of ``main``.
    _optional_args = [
        (("-e", "--estimator"), "infile_estimator"),
        (("-X", "--infile1"), "infile1"),
        (("-y", "--infile2"), "infile2"),
        (("-O", "--outfile_result"), "outfile_result"),
        (("-o", "--outfile_object"), "outfile_object"),
        (("-g", "--groups"), "groups"),
        (("-r", "--ref_seq"), "ref_seq"),
        (("-b", "--intervals"), "intervals"),
        (("-t", "--targets"), "targets"),
        (("-f", "--fasta_path"), "fasta_path"),
        (("-c", "--model_config"), "model_config"),
        (("-tl", "--true_labels"), "true_labels"),
        (("-pl", "--predicted_labels"), "predicted_labels"),
        (("-pc", "--plot_color"), "plot_color"),
        (("-pt", "--title"), "title"),
    ]

    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    for flags, dest in _optional_args:
        aparser.add_argument(*flags, dest=dest)
    args = aparser.parse_args()

    # Namespace attributes map one-to-one onto main's parameters.
    main(**vars(args))