Mercurial > repos > bgruening > sklearn_mlxtend_association_rules
ml_visualization_ex.py @ 0:af2624d5ab32 (draft)
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author: bgruening
date: Sat, 01 May 2021 01:24:32 +0000
children: 9349ed2749c6
import argparse
import json
import os
import warnings

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
from galaxy_ml.utils import load_model, read_columns, SafeEval
from keras.models import model_from_json
from keras.utils import plot_model
from sklearn.feature_selection.base import SelectorMixin
from sklearn.metrics import (auc, average_precision_score, confusion_matrix,
                             precision_recall_curve, roc_curve)
from sklearn.pipeline import Pipeline

safe_eval = SafeEval()

# plotly default colors
default_colors = [
    "#1f77b4",  # muted blue
    "#ff7f0e",  # safety orange
    "#2ca02c",  # cooked asparagus green
    "#d62728",  # brick red
    "#9467bd",  # muted purple
    "#8c564b",  # chestnut brown
    "#e377c2",  # raspberry yogurt pink
    "#7f7f7f",  # middle gray
    "#bcbd22",  # curry yellow-green
    "#17becf",  # blue-teal
]


def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
    """Output the precision-recall curve as HTML using plotly.

    df1 : pandas.DataFrame
        Dataframe containing y_true
    df2 : pandas.DataFrame
        Dataframe containing y_score
    pos_label : int or str, default is None
        The label of the positive class
    title : str, default is None
        Plot title
    """
    data = []
    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)

        trace = go.Scatter(
            x=recall,
            y=precision,
            mode="lines",
            marker=dict(color=default_colors[idx % len(default_colors)]),
            name="%s (area = %.3f)" % (idx, ap),
        )
        data.append(trace)

    layout = go.Layout(
        xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1),
        yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1),
        title=dict(
            text=title or "Precision-Recall Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    # Optional legend styling, kept commented out for reference:
    # legend=dict(
    #     x=0.95,
    #     y=0,
    #     traceorder="normal",
    #     font=dict(
    #         family="sans-serif",
    #         size=9,
    #         color="black"
    #     ),
    #     bgcolor="LightSteelBlue",
    #     bordercolor="Black",
    #     borderwidth=2
    # ),

    fig = go.Figure(data=data, layout=layout)

    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename("output.html", "output")
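# Hedged usage sketch (illustrative only, not part of the tool's execution
# path): both frames must be column-aligned, one column per curve; the names
# `df_true`/`df_score` are hypothetical.
#
#     df_true = pd.DataFrame({"task_0": [0, 0, 1, 1]})
#     df_score = pd.DataFrame({"task_0": [0.1, 0.4, 0.35, 0.8]})
#     visualize_pr_curve_plotly(df_true, df_score, pos_label=None,
#                               title="PR curve (toy data)")
#
# This writes an HTML file and renames it to the extensionless `output`
# expected by the Galaxy wrapper.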


def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
    """Visualize the precision-recall curve using matplotlib and output an SVG image."""
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    plt.style.use("seaborn-colorblind")
    plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)

        plt.step(
            recall,
            precision,
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, ap),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    title = title or "Precision-Recall Curve"
    plt.title(title)
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
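# Note: as with the plotly branch above, the rendered file is renamed to the
# extensionless `output` so that Galaxy's `from_work_dir` collection can
# discover it.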


def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
    """Output the ROC curve as HTML using plotly.

    df1 : pandas.DataFrame
        Dataframe containing y_true
    df2 : pandas.DataFrame
        Dataframe containing y_score
    pos_label : int or str, default is None
        The label of the positive class
    drop_intermediate : bool
        Whether to drop some suboptimal thresholds
    title : str, default is None
        Plot title
    """
    data = []
    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(
            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        roc_auc = auc(fpr, tpr)

        trace = go.Scatter(
            x=fpr,
            y=tpr,
            mode="lines",
            marker=dict(color=default_colors[idx % len(default_colors)]),
            name="%s (area = %.3f)" % (idx, roc_auc),
        )
        data.append(trace)

    layout = go.Layout(
        xaxis=dict(
            title="False Positive Rate", linecolor="lightslategray", linewidth=1
        ),
        yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1),
        title=dict(
            text=title or "Receiver Operating Characteristic (ROC) Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    # Optional legend styling, kept commented out for reference:
    # legend=dict(
    #     x=0.95,
    #     y=0,
    #     traceorder="normal",
    #     font=dict(
    #         family="sans-serif",
    #         size=9,
    #         color="black"
    #     ),
    #     bgcolor="LightSteelBlue",
    #     bordercolor="Black",
    #     borderwidth=2
    # ),

    fig = go.Figure(data=data, layout=layout)

    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename("output.html", "output")
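# Worked toy example for roc_curve (numbers from the scikit-learn docs),
# useful for sanity-checking the traces built above:
#
#     y_true = np.array([1, 1, 2, 2])
#     y_score = np.array([0.1, 0.4, 0.35, 0.8])
#     fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=2)
#     # fpr -> [0. , 0. , 0.5, 0.5, 1. ]
#     # tpr -> [0. , 0.5, 0.5, 1. , 1. ]
#     auc(fpr, tpr)  # 0.75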


def visualize_roc_curve_matplotlib(
    df1, df2, pos_label, drop_intermediate=True, title=None
):
    """Visualize the ROC curve using matplotlib and output an SVG image."""
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    plt.style.use("seaborn-colorblind")
    plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(
            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        roc_auc = auc(fpr, tpr)

        plt.step(
            fpr,
            tpr,
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, roc_auc),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    title = title or "Receiver Operating Characteristic (ROC) Curve"
    plt.title(title)
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))


def get_dataframe(file_path, plot_selection, header_name, column_name):
    header = "infer" if plot_selection[header_name] else None
    column_option = plot_selection[column_name]["selected_column_selector_option"]
    if column_option in [
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ]:
        col = plot_selection[column_name]["col1"]
    else:
        col = None
    _, input_df = read_columns(
        file_path,
        c=col,
        c_option=column_option,
        return_df=True,
        sep="\t",
        header=header,
        parse_dates=True,
    )
    return input_df
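# Hedged sketch of the `plot_selection` mapping get_dataframe expects,
# inferred from the lookups above; the concrete values are illustrative only:
#
#     plot_selection = {
#         "header_true": True,
#         "column_selector_options_true": {
#             "selected_column_selector_option": "by_header_name",
#             "col1": "label",
#         },
#     }
#     df = get_dataframe("true_labels.tsv", plot_selection,
#                        "header_true", "column_selector_options_true")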


def main(
    inputs,
    infile_estimator=None,
    infile1=None,
    infile2=None,
    outfile_result=None,
    outfile_object=None,
    groups=None,
    ref_seq=None,
    intervals=None,
    targets=None,
    fasta_path=None,
    model_config=None,
    true_labels=None,
    predicted_labels=None,
    plot_color=None,
    title=None,
):
    """
    Parameters
    ----------
    inputs : str
        File path to the Galaxy tool parameter JSON

    infile_estimator : str, default is None
        File path to estimator

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        File path to save the results, either cv_results or test result

    outfile_object : str, default is None
        File path to save the searchCV object

    groups : str, default is None
        File path to dataset containing group labels

    ref_seq : str, default is None
        File path to dataset containing the genome sequence file

    intervals : str, default is None
        File path to dataset containing the interval file

    targets : str, default is None
        File path to dataset containing the compressed target bed file

    fasta_path : str, default is None
        File path to dataset containing the fasta file

    model_config : str, default is None
        File path to dataset containing JSON config for neural networks

    true_labels : str, default is None
        File path to dataset containing true labels

    predicted_labels : str, default is None
        File path to dataset containing predicted labels

    plot_color : str, default is None
        Color of the confusion matrix heatmap

    title : str, default is None
        Title of the confusion matrix heatmap
    """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    title = params["plotting_selection"]["title"].strip()
    plot_type = params["plotting_selection"]["plot_type"]
    plot_format = params["plotting_selection"]["plot_format"]

    if plot_type == "feature_importances":
        with open(infile_estimator, "rb") as estimator_handler:
            estimator = load_model(estimator_handler)

        column_option = params["plotting_selection"]["column_selector_options"][
            "selected_column_selector_option"
        ]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = params["plotting_selection"]["column_selector_options"]["col1"]
        else:
            c = None

        _, input_df = read_columns(
            infile1,
            c=c,
            c_option=column_option,
            return_df=True,
            sep="\t",
            header="infer",
            parse_dates=True,
        )

        feature_names = input_df.columns.values

        if isinstance(estimator, Pipeline):
            for st in estimator.steps[:-1]:
                if isinstance(st[-1], SelectorMixin):
                    mask = st[-1].get_support()
                    feature_names = feature_names[mask]
            estimator = estimator.steps[-1][-1]

        if hasattr(estimator, "coef_"):
            coefs = estimator.coef_
        else:
            coefs = getattr(estimator, "feature_importances_", None)
        if coefs is None:
            raise RuntimeError(
                "The classifier does not expose "
                '"coef_" or "feature_importances_" '
                "attributes"
            )

        threshold = params["plotting_selection"]["threshold"]
        if threshold is not None:
            mask = (coefs > threshold) | (coefs < -threshold)
            coefs = coefs[mask]
            feature_names = feature_names[mask]
        # sort
        indices = np.argsort(coefs)[::-1]

        trace = go.Bar(x=feature_names[indices], y=coefs[indices])
        layout = go.Layout(title=title or "Feature Importances")
        fig = go.Figure(data=[trace], layout=layout)

        plotly.offline.plot(fig, filename="output.html", auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename("output.html", "output")

        return 0

    elif plot_type in ("pr_curve", "roc_curve"):
        df1 = pd.read_csv(infile1, sep="\t", header="infer")
        df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)

        minimum = params["plotting_selection"]["report_minimum_n_positives"]
        # filter out columns whose n_positives is below the threshold
        if minimum:
            mask = df1.sum(axis=0) >= minimum
            df1 = df1.loc[:, mask]
            df2 = df2.loc[:, mask]

        pos_label = params["plotting_selection"]["pos_label"].strip() or None

        if plot_type == "pr_curve":
            if plot_format == "plotly_html":
                visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
            else:
                visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
        else:  # 'roc_curve'
            drop_intermediate = params["plotting_selection"]["drop_intermediate"]
            if plot_format == "plotly_html":
                visualize_roc_curve_plotly(
                    df1,
                    df2,
                    pos_label,
                    drop_intermediate=drop_intermediate,
                    title=title,
                )
            else:
                visualize_roc_curve_matplotlib(
                    df1,
                    df2,
                    pos_label,
                    drop_intermediate=drop_intermediate,
                    title=title,
                )

        return 0

    elif plot_type == "rfecv_gridscores":
        input_df = pd.read_csv(infile1, sep="\t", header="infer")
        scores = input_df.iloc[:, 0]
        steps = params["plotting_selection"]["steps"].strip()
        steps = safe_eval(steps)
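        # `steps` arrives as a string from the Galaxy form; safe_eval is
        # expected to turn a literal such as "[1, 2, 3]" into a Python list
        # (an assumption inferred from its use as per-point hover text below).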

        data = go.Scatter(
            x=list(range(len(scores))),
            y=scores,
            text=[str(_) for _ in steps] if steps else None,
            mode="lines",
        )
        layout = go.Layout(
            xaxis=dict(title="Number of features selected"),
            yaxis=dict(title="Cross validation score"),
            title=dict(
                text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"
            ),
            font=dict(family="sans-serif", size=11),
            # control background colors
            plot_bgcolor="rgba(255,255,255,0)",
        )
        # Optional legend styling, kept commented out for reference:
        # legend=dict(
        #     x=0.95,
        #     y=0,
        #     traceorder="normal",
        #     font=dict(
        #         family="sans-serif",
        #         size=9,
        #         color="black"
        #     ),
        #     bgcolor="LightSteelBlue",
        #     bordercolor="Black",
        #     borderwidth=2
        # ),

        fig = go.Figure(data=[data], layout=layout)
        plotly.offline.plot(fig, filename="output.html", auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename("output.html", "output")

        return 0

    elif plot_type == "learning_curve":
        input_df = pd.read_csv(infile1, sep="\t", header="infer")
        plot_std_err = params["plotting_selection"]["plot_std_err"]
        data1 = go.Scatter(
            x=input_df["train_sizes_abs"],
            y=input_df["mean_train_scores"],
            error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
            mode="lines",
            name="Train Scores",
        )
        data2 = go.Scatter(
            x=input_df["train_sizes_abs"],
            y=input_df["mean_test_scores"],
            error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
            mode="lines",
            name="Test Scores",
        )
        layout = dict(
            xaxis=dict(title="No. of samples"),
            yaxis=dict(title="Performance Score"),
            # modify these configurations to customize image
            title=dict(
                text=title or "Learning Curve",
                x=0.5,
                y=0.92,
                xanchor="center",
                yanchor="top",
            ),
            font=dict(family="sans-serif", size=11),
            # control background colors
            plot_bgcolor="rgba(255,255,255,0)",
        )
        # Optional legend styling, kept commented out for reference:
        # legend=dict(
        #     x=0.95,
        #     y=0,
        #     traceorder="normal",
        #     font=dict(
        #         family="sans-serif",
        #         size=9,
        #         color="black"
        #     ),
        #     bgcolor="LightSteelBlue",
        #     bordercolor="Black",
        #     borderwidth=2
        # ),

        fig = go.Figure(data=[data1, data2], layout=layout)
        plotly.offline.plot(fig, filename="output.html", auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename("output.html", "output")

        return 0

    elif plot_type == "keras_plot_model":
        with open(model_config, "r") as f:
            model_str = f.read()
        model = model_from_json(model_str)
        plot_model(model, to_file="output.png")
        os.rename("output.png", "output")

        return 0

    elif plot_type == "classification_confusion_matrix":
        plot_selection = params["plotting_selection"]
        input_true = get_dataframe(
            true_labels, plot_selection, "header_true", "column_selector_options_true"
        )
        header_predicted = "infer" if plot_selection["header_predicted"] else None
        input_predicted = pd.read_csv(
            predicted_labels, sep="\t", parse_dates=True, header=header_predicted
        )
        true_classes = input_true.iloc[:, -1].copy()
        predicted_classes = input_predicted.iloc[:, -1].copy()
        # sort so the tick labels line up with confusion_matrix's sorted rows
        axis_labels = sorted(set(true_classes))
        c_matrix = confusion_matrix(true_classes, predicted_classes)
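        # Worked toy example for confusion_matrix (from the scikit-learn docs):
        #     confusion_matrix([2, 0, 2, 2, 0, 1], [0, 0, 2, 2, 0, 2])
        #     -> [[2, 0, 0],
        #         [0, 0, 1],
        #         [1, 0, 2]]
        # Rows are true classes in sorted order; columns are predictions.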
        fig, ax = plt.subplots(figsize=(7, 7))
        im = plt.imshow(c_matrix, cmap=plot_color)
        for i in range(len(c_matrix)):
            for j in range(len(c_matrix)):
                ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
        ax.set_ylabel("True class labels")
        ax.set_xlabel("Predicted class labels")
        ax.set_title(title)
        # tick positions must be integers; the class labels become tick text
        ax.set_xticks(np.arange(len(axis_labels)))
        ax.set_xticklabels(axis_labels)
        ax.set_yticks(np.arange(len(axis_labels)))
        ax.set_yticklabels(axis_labels)
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        plt.savefig("output.png", dpi=125)
        os.rename("output.png", "output")

        return 0

    # save pdf file to disk
    # fig.write_image("image.pdf", format='pdf')
    # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)


if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    aparser.add_argument("-e", "--estimator", dest="infile_estimator")
    aparser.add_argument("-X", "--infile1", dest="infile1")
    aparser.add_argument("-y", "--infile2", dest="infile2")
    aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
    aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
    aparser.add_argument("-g", "--groups", dest="groups")
    aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
    aparser.add_argument("-b", "--intervals", dest="intervals")
    aparser.add_argument("-t", "--targets", dest="targets")
    aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
    aparser.add_argument("-c", "--model_config", dest="model_config")
    aparser.add_argument("-tl", "--true_labels", dest="true_labels")
    aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
    aparser.add_argument("-pc", "--plot_color", dest="plot_color")
    aparser.add_argument("-pt", "--title", dest="title")
    args = aparser.parse_args()

    main(
        args.inputs,
        args.infile_estimator,
        args.infile1,
        args.infile2,
        args.outfile_result,
        outfile_object=args.outfile_object,
        groups=args.groups,
        ref_seq=args.ref_seq,
        intervals=args.intervals,
        targets=args.targets,
        fasta_path=args.fasta_path,
        model_config=args.model_config,
        true_labels=args.true_labels,
        predicted_labels=args.predicted_labels,
        plot_color=args.plot_color,
        title=args.title,
    )
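
# Hedged example invocation (flag values are hypothetical; in production the
# Galaxy tool wrapper fills them in from the tool form):
#
#     python ml_visualization_ex.py -i inputs.json -X y_true.tsv -y y_score.tsv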