Mercurial > repos > bgruening > sklearn_train_test_split
comparison ml_visualization_ex.py @ 6:13b9ac5d277c draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 208a8d348e7c7a182cfbe1b6f17868146428a7e2"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 22:24:07 +0000 |
parents | 145208b3579d |
children | 3312fb686ffb |
comparison
equal
deleted
inserted
replaced
5:ce2fd1edbc6e | 6:13b9ac5d277c |
---|---|
1 import argparse | 1 import argparse |
2 import json | 2 import json |
3 import os | |
4 import warnings | |
5 | |
3 import matplotlib | 6 import matplotlib |
4 import matplotlib.pyplot as plt | 7 import matplotlib.pyplot as plt |
5 import numpy as np | 8 import numpy as np |
6 import os | |
7 import pandas as pd | 9 import pandas as pd |
8 import plotly | 10 import plotly |
9 import plotly.graph_objs as go | 11 import plotly.graph_objs as go |
10 import warnings | 12 from galaxy_ml.utils import load_model, read_columns, SafeEval |
11 | |
12 from keras.models import model_from_json | 13 from keras.models import model_from_json |
13 from keras.utils import plot_model | 14 from keras.utils import plot_model |
14 from sklearn.feature_selection.base import SelectorMixin | 15 from sklearn.feature_selection.base import SelectorMixin |
15 from sklearn.metrics import precision_recall_curve, average_precision_score | 16 from sklearn.metrics import auc, average_precision_score, confusion_matrix, precision_recall_curve, roc_curve |
16 from sklearn.metrics import roc_curve, auc, confusion_matrix | |
17 from sklearn.pipeline import Pipeline | 17 from sklearn.pipeline import Pipeline |
18 from galaxy_ml.utils import load_model, read_columns, SafeEval | |
19 | 18 |
20 | 19 |
21 safe_eval = SafeEval() | 20 safe_eval = SafeEval() |
22 | 21 |
23 # plotly default colors | 22 # plotly default colors |
24 default_colors = [ | 23 default_colors = [ |
25 '#1f77b4', # muted blue | 24 "#1f77b4", # muted blue |
26 '#ff7f0e', # safety orange | 25 "#ff7f0e", # safety orange |
27 '#2ca02c', # cooked asparagus green | 26 "#2ca02c", # cooked asparagus green |
28 '#d62728', # brick red | 27 "#d62728", # brick red |
29 '#9467bd', # muted purple | 28 "#9467bd", # muted purple |
30 '#8c564b', # chestnut brown | 29 "#8c564b", # chestnut brown |
31 '#e377c2', # raspberry yogurt pink | 30 "#e377c2", # raspberry yogurt pink |
32 '#7f7f7f', # middle gray | 31 "#7f7f7f", # middle gray |
33 '#bcbd22', # curry yellow-green | 32 "#bcbd22", # curry yellow-green |
34 '#17becf' # blue-teal | 33 "#17becf", # blue-teal |
35 ] | 34 ] |
36 | 35 |
37 | 36 |
38 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): | 37 def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): |
39 """output pr-curve in html using plotly | 38 """output pr-curve in html using plotly |
50 data = [] | 49 data = [] |
51 for idx in range(df1.shape[1]): | 50 for idx in range(df1.shape[1]): |
52 y_true = df1.iloc[:, idx].values | 51 y_true = df1.iloc[:, idx].values |
53 y_score = df2.iloc[:, idx].values | 52 y_score = df2.iloc[:, idx].values |
54 | 53 |
55 precision, recall, _ = precision_recall_curve( | 54 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) |
56 y_true, y_score, pos_label=pos_label) | 55 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) |
57 ap = average_precision_score( | |
58 y_true, y_score, pos_label=pos_label or 1) | |
59 | 56 |
60 trace = go.Scatter( | 57 trace = go.Scatter( |
61 x=recall, | 58 x=recall, |
62 y=precision, | 59 y=precision, |
63 mode='lines', | 60 mode="lines", |
64 marker=dict( | 61 marker=dict(color=default_colors[idx % len(default_colors)]), |
65 color=default_colors[idx % len(default_colors)] | 62 name="%s (area = %.3f)" % (idx, ap), |
66 ), | |
67 name='%s (area = %.3f)' % (idx, ap) | |
68 ) | 63 ) |
69 data.append(trace) | 64 data.append(trace) |
70 | 65 |
71 layout = go.Layout( | 66 layout = go.Layout( |
72 xaxis=dict( | 67 xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1), |
73 title='Recall', | 68 yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1), |
74 linecolor='lightslategray', | |
75 linewidth=1 | |
76 ), | |
77 yaxis=dict( | |
78 title='Precision', | |
79 linecolor='lightslategray', | |
80 linewidth=1 | |
81 ), | |
82 title=dict( | 69 title=dict( |
83 text=title or 'Precision-Recall Curve', | 70 text=title or "Precision-Recall Curve", |
84 x=0.5, | 71 x=0.5, |
85 y=0.92, | 72 y=0.92, |
86 xanchor='center', | 73 xanchor="center", |
87 yanchor='top' | 74 yanchor="top", |
88 ), | 75 ), |
89 font=dict( | 76 font=dict(family="sans-serif", size=11), |
90 family="sans-serif", | |
91 size=11 | |
92 ), | |
93 # control backgroud colors | 77 # control backgroud colors |
94 plot_bgcolor='rgba(255,255,255,0)' | 78 plot_bgcolor="rgba(255,255,255,0)", |
95 ) | 79 ) |
96 """ | 80 """ |
97 legend=dict( | 81 legend=dict( |
98 x=0.95, | 82 x=0.95, |
99 y=0, | 83 y=0, |
110 | 94 |
111 fig = go.Figure(data=data, layout=layout) | 95 fig = go.Figure(data=data, layout=layout) |
112 | 96 |
113 plotly.offline.plot(fig, filename="output.html", auto_open=False) | 97 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
114 # to be discovered by `from_work_dir` | 98 # to be discovered by `from_work_dir` |
115 os.rename('output.html', 'output') | 99 os.rename("output.html", "output") |
116 | 100 |
117 | 101 |
118 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): | 102 def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): |
119 """visualize pr-curve using matplotlib and output svg image | 103 """visualize pr-curve using matplotlib and output svg image""" |
120 """ | |
121 backend = matplotlib.get_backend() | 104 backend = matplotlib.get_backend() |
122 if "inline" not in backend: | 105 if "inline" not in backend: |
123 matplotlib.use("SVG") | 106 matplotlib.use("SVG") |
124 plt.style.use('seaborn-colorblind') | 107 plt.style.use("seaborn-colorblind") |
125 plt.figure() | 108 plt.figure() |
126 | 109 |
127 for idx in range(df1.shape[1]): | 110 for idx in range(df1.shape[1]): |
128 y_true = df1.iloc[:, idx].values | 111 y_true = df1.iloc[:, idx].values |
129 y_score = df2.iloc[:, idx].values | 112 y_score = df2.iloc[:, idx].values |
130 | 113 |
131 precision, recall, _ = precision_recall_curve( | 114 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) |
132 y_true, y_score, pos_label=pos_label) | 115 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) |
133 ap = average_precision_score( | 116 |
134 y_true, y_score, pos_label=pos_label or 1) | 117 plt.step( |
135 | 118 recall, |
136 plt.step(recall, precision, 'r-', color="black", alpha=0.3, | 119 precision, |
137 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap)) | 120 "r-", |
121 color="black", | |
122 alpha=0.3, | |
123 lw=1, | |
124 where="post", | |
125 label="%s (area = %.3f)" % (idx, ap), | |
126 ) | |
138 | 127 |
139 plt.xlim([0.0, 1.0]) | 128 plt.xlim([0.0, 1.0]) |
140 plt.ylim([0.0, 1.05]) | 129 plt.ylim([0.0, 1.05]) |
141 plt.xlabel('Recall') | 130 plt.xlabel("Recall") |
142 plt.ylabel('Precision') | 131 plt.ylabel("Precision") |
143 title = title or 'Precision-Recall Curve' | 132 title = title or "Precision-Recall Curve" |
144 plt.title(title) | 133 plt.title(title) |
145 folder = os.getcwd() | 134 folder = os.getcwd() |
146 plt.savefig(os.path.join(folder, "output.svg"), format="svg") | 135 plt.savefig(os.path.join(folder, "output.svg"), format="svg") |
147 os.rename(os.path.join(folder, "output.svg"), | 136 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) |
148 os.path.join(folder, "output")) | 137 |
149 | 138 |
150 | 139 def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None): |
151 def visualize_roc_curve_plotly(df1, df2, pos_label, | |
152 drop_intermediate=True, | |
153 title=None): | |
154 """output roc-curve in html using plotly | 140 """output roc-curve in html using plotly |
155 | 141 |
156 df1 : pandas.DataFrame | 142 df1 : pandas.DataFrame |
157 Containing y_true | 143 Containing y_true |
158 df2 : pandas.DataFrame | 144 df2 : pandas.DataFrame |
167 data = [] | 153 data = [] |
168 for idx in range(df1.shape[1]): | 154 for idx in range(df1.shape[1]): |
169 y_true = df1.iloc[:, idx].values | 155 y_true = df1.iloc[:, idx].values |
170 y_score = df2.iloc[:, idx].values | 156 y_score = df2.iloc[:, idx].values |
171 | 157 |
172 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, | 158 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) |
173 drop_intermediate=drop_intermediate) | |
174 roc_auc = auc(fpr, tpr) | 159 roc_auc = auc(fpr, tpr) |
175 | 160 |
176 trace = go.Scatter( | 161 trace = go.Scatter( |
177 x=fpr, | 162 x=fpr, |
178 y=tpr, | 163 y=tpr, |
179 mode='lines', | 164 mode="lines", |
180 marker=dict( | 165 marker=dict(color=default_colors[idx % len(default_colors)]), |
181 color=default_colors[idx % len(default_colors)] | 166 name="%s (area = %.3f)" % (idx, roc_auc), |
182 ), | |
183 name='%s (area = %.3f)' % (idx, roc_auc) | |
184 ) | 167 ) |
185 data.append(trace) | 168 data.append(trace) |
186 | 169 |
187 layout = go.Layout( | 170 layout = go.Layout( |
188 xaxis=dict( | 171 xaxis=dict(title="False Positive Rate", linecolor="lightslategray", linewidth=1), |
189 title='False Positive Rate', | 172 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1), |
190 linecolor='lightslategray', | |
191 linewidth=1 | |
192 ), | |
193 yaxis=dict( | |
194 title='True Positive Rate', | |
195 linecolor='lightslategray', | |
196 linewidth=1 | |
197 ), | |
198 title=dict( | 173 title=dict( |
199 text=title or 'Receiver Operating Characteristic (ROC) Curve', | 174 text=title or "Receiver Operating Characteristic (ROC) Curve", |
200 x=0.5, | 175 x=0.5, |
201 y=0.92, | 176 y=0.92, |
202 xanchor='center', | 177 xanchor="center", |
203 yanchor='top' | 178 yanchor="top", |
204 ), | 179 ), |
205 font=dict( | 180 font=dict(family="sans-serif", size=11), |
206 family="sans-serif", | |
207 size=11 | |
208 ), | |
209 # control backgroud colors | 181 # control backgroud colors |
210 plot_bgcolor='rgba(255,255,255,0)' | 182 plot_bgcolor="rgba(255,255,255,0)", |
211 ) | 183 ) |
212 """ | 184 """ |
213 # legend=dict( | 185 # legend=dict( |
214 # x=0.95, | 186 # x=0.95, |
215 # y=0, | 187 # y=0, |
227 | 199 |
228 fig = go.Figure(data=data, layout=layout) | 200 fig = go.Figure(data=data, layout=layout) |
229 | 201 |
230 plotly.offline.plot(fig, filename="output.html", auto_open=False) | 202 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
231 # to be discovered by `from_work_dir` | 203 # to be discovered by `from_work_dir` |
232 os.rename('output.html', 'output') | 204 os.rename("output.html", "output") |
233 | 205 |
234 | 206 |
235 def visualize_roc_curve_matplotlib(df1, df2, pos_label, | 207 def visualize_roc_curve_matplotlib(df1, df2, pos_label, drop_intermediate=True, title=None): |
236 drop_intermediate=True, | 208 """visualize roc-curve using matplotlib and output svg image""" |
237 title=None): | |
238 """visualize roc-curve using matplotlib and output svg image | |
239 """ | |
240 backend = matplotlib.get_backend() | 209 backend = matplotlib.get_backend() |
241 if "inline" not in backend: | 210 if "inline" not in backend: |
242 matplotlib.use("SVG") | 211 matplotlib.use("SVG") |
243 plt.style.use('seaborn-colorblind') | 212 plt.style.use("seaborn-colorblind") |
244 plt.figure() | 213 plt.figure() |
245 | 214 |
246 for idx in range(df1.shape[1]): | 215 for idx in range(df1.shape[1]): |
247 y_true = df1.iloc[:, idx].values | 216 y_true = df1.iloc[:, idx].values |
248 y_score = df2.iloc[:, idx].values | 217 y_score = df2.iloc[:, idx].values |
249 | 218 |
250 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, | 219 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) |
251 drop_intermediate=drop_intermediate) | |
252 roc_auc = auc(fpr, tpr) | 220 roc_auc = auc(fpr, tpr) |
253 | 221 |
254 plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1, | 222 plt.step( |
255 where="post", label='%s (area = %.3f)' % (idx, roc_auc)) | 223 fpr, |
224 tpr, | |
225 "r-", | |
226 color="black", | |
227 alpha=0.3, | |
228 lw=1, | |
229 where="post", | |
230 label="%s (area = %.3f)" % (idx, roc_auc), | |
231 ) | |
256 | 232 |
257 plt.xlim([0.0, 1.0]) | 233 plt.xlim([0.0, 1.0]) |
258 plt.ylim([0.0, 1.05]) | 234 plt.ylim([0.0, 1.05]) |
259 plt.xlabel('False Positive Rate') | 235 plt.xlabel("False Positive Rate") |
260 plt.ylabel('True Positive Rate') | 236 plt.ylabel("True Positive Rate") |
261 title = title or 'Receiver Operating Characteristic (ROC) Curve' | 237 title = title or "Receiver Operating Characteristic (ROC) Curve" |
262 plt.title(title) | 238 plt.title(title) |
263 folder = os.getcwd() | 239 folder = os.getcwd() |
264 plt.savefig(os.path.join(folder, "output.svg"), format="svg") | 240 plt.savefig(os.path.join(folder, "output.svg"), format="svg") |
265 os.rename(os.path.join(folder, "output.svg"), | 241 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) |
266 os.path.join(folder, "output")) | |
267 | 242 |
268 | 243 |
269 def get_dataframe(file_path, plot_selection, header_name, column_name): | 244 def get_dataframe(file_path, plot_selection, header_name, column_name): |
270 header = 'infer' if plot_selection[header_name] else None | 245 header = "infer" if plot_selection[header_name] else None |
271 column_option = plot_selection[column_name]["selected_column_selector_option"] | 246 column_option = plot_selection[column_name]["selected_column_selector_option"] |
272 if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: | 247 if column_option in [ |
248 "by_index_number", | |
249 "all_but_by_index_number", | |
250 "by_header_name", | |
251 "all_but_by_header_name", | |
252 ]: | |
273 col = plot_selection[column_name]["col1"] | 253 col = plot_selection[column_name]["col1"] |
274 else: | 254 else: |
275 col = None | 255 col = None |
276 _, input_df = read_columns(file_path, c=col, | 256 _, input_df = read_columns(file_path, c=col, |
277 c_option=column_option, | 257 c_option=column_option, |
278 return_df=True, | 258 return_df=True, |
279 sep='\t', header=header, | 259 sep='\t', header=header, |
280 parse_dates=True) | 260 parse_dates=True) |
281 return input_df | 261 return input_df |
282 | 262 |
283 | 263 |
284 def main(inputs, infile_estimator=None, infile1=None, | 264 def main( |
285 infile2=None, outfile_result=None, | 265 inputs, |
286 outfile_object=None, groups=None, | 266 infile_estimator=None, |
287 ref_seq=None, intervals=None, | 267 infile1=None, |
288 targets=None, fasta_path=None, | 268 infile2=None, |
289 model_config=None, true_labels=None, | 269 outfile_result=None, |
290 predicted_labels=None, plot_color=None, | 270 outfile_object=None, |
291 title=None): | 271 groups=None, |
272 ref_seq=None, | |
273 intervals=None, | |
274 targets=None, | |
275 fasta_path=None, | |
276 model_config=None, | |
277 true_labels=None, | |
278 predicted_labels=None, | |
279 plot_color=None, | |
280 title=None, | |
281 ): | |
292 """ | 282 """ |
293 Parameter | 283 Parameter |
294 --------- | 284 --------- |
295 inputs : str | 285 inputs : str |
296 File path to galaxy tool parameter | 286 File path to galaxy tool parameter |
339 Color of the confusion matrix heatmap | 329 Color of the confusion matrix heatmap |
340 | 330 |
341 title : str, default is None | 331 title : str, default is None |
342 Title of the confusion matrix heatmap | 332 Title of the confusion matrix heatmap |
343 """ | 333 """ |
344 warnings.simplefilter('ignore') | 334 warnings.simplefilter("ignore") |
345 | 335 |
346 with open(inputs, 'r') as param_handler: | 336 with open(inputs, "r") as param_handler: |
347 params = json.load(param_handler) | 337 params = json.load(param_handler) |
348 | 338 |
349 title = params['plotting_selection']['title'].strip() | 339 title = params["plotting_selection"]["title"].strip() |
350 plot_type = params['plotting_selection']['plot_type'] | 340 plot_type = params["plotting_selection"]["plot_type"] |
351 plot_format = params['plotting_selection']['plot_format'] | 341 plot_format = params["plotting_selection"]["plot_format"] |
352 | 342 |
353 if plot_type == 'feature_importances': | 343 if plot_type == "feature_importances": |
354 with open(infile_estimator, 'rb') as estimator_handler: | 344 with open(infile_estimator, "rb") as estimator_handler: |
355 estimator = load_model(estimator_handler) | 345 estimator = load_model(estimator_handler) |
356 | 346 |
357 column_option = (params['plotting_selection'] | 347 column_option = params["plotting_selection"]["column_selector_options"]["selected_column_selector_option"] |
358 ['column_selector_options'] | 348 if column_option in [ |
359 ['selected_column_selector_option']) | 349 "by_index_number", |
360 if column_option in ['by_index_number', 'all_but_by_index_number', | 350 "all_but_by_index_number", |
361 'by_header_name', 'all_but_by_header_name']: | 351 "by_header_name", |
362 c = (params['plotting_selection'] | 352 "all_but_by_header_name", |
363 ['column_selector_options']['col1']) | 353 ]: |
354 c = params["plotting_selection"]["column_selector_options"]["col1"] | |
364 else: | 355 else: |
365 c = None | 356 c = None |
366 | 357 |
367 _, input_df = read_columns(infile1, c=c, | 358 _, input_df = read_columns( |
368 c_option=column_option, | 359 infile1, |
369 return_df=True, | 360 c=c, |
370 sep='\t', header='infer', | 361 c_option=column_option, |
371 parse_dates=True) | 362 return_df=True, |
363 sep="\t", | |
364 header="infer", | |
365 parse_dates=True, | |
366 ) | |
372 | 367 |
373 feature_names = input_df.columns.values | 368 feature_names = input_df.columns.values |
374 | 369 |
375 if isinstance(estimator, Pipeline): | 370 if isinstance(estimator, Pipeline): |
376 for st in estimator.steps[:-1]: | 371 for st in estimator.steps[:-1]: |
377 if isinstance(st[-1], SelectorMixin): | 372 if isinstance(st[-1], SelectorMixin): |
378 mask = st[-1].get_support() | 373 mask = st[-1].get_support() |
379 feature_names = feature_names[mask] | 374 feature_names = feature_names[mask] |
380 estimator = estimator.steps[-1][-1] | 375 estimator = estimator.steps[-1][-1] |
381 | 376 |
382 if hasattr(estimator, 'coef_'): | 377 if hasattr(estimator, "coef_"): |
383 coefs = estimator.coef_ | 378 coefs = estimator.coef_ |
384 else: | 379 else: |
385 coefs = getattr(estimator, 'feature_importances_', None) | 380 coefs = getattr(estimator, "feature_importances_", None) |
386 if coefs is None: | 381 if coefs is None: |
387 raise RuntimeError('The classifier does not expose ' | 382 raise RuntimeError("The classifier does not expose " '"coef_" or "feature_importances_" ' "attributes") |
388 '"coef_" or "feature_importances_" ' | 383 |
389 'attributes') | 384 threshold = params["plotting_selection"]["threshold"] |
390 | |
391 threshold = params['plotting_selection']['threshold'] | |
392 if threshold is not None: | 385 if threshold is not None: |
393 mask = (coefs > threshold) | (coefs < -threshold) | 386 mask = (coefs > threshold) | (coefs < -threshold) |
394 coefs = coefs[mask] | 387 coefs = coefs[mask] |
395 feature_names = feature_names[mask] | 388 feature_names = feature_names[mask] |
396 | 389 |
397 # sort | 390 # sort |
398 indices = np.argsort(coefs)[::-1] | 391 indices = np.argsort(coefs)[::-1] |
399 | 392 |
400 trace = go.Bar(x=feature_names[indices], | 393 trace = go.Bar(x=feature_names[indices], y=coefs[indices]) |
401 y=coefs[indices]) | |
402 layout = go.Layout(title=title or "Feature Importances") | 394 layout = go.Layout(title=title or "Feature Importances") |
403 fig = go.Figure(data=[trace], layout=layout) | 395 fig = go.Figure(data=[trace], layout=layout) |
404 | 396 |
405 plotly.offline.plot(fig, filename="output.html", | 397 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
406 auto_open=False) | |
407 # to be discovered by `from_work_dir` | 398 # to be discovered by `from_work_dir` |
408 os.rename('output.html', 'output') | 399 os.rename("output.html", "output") |
409 | 400 |
410 return 0 | 401 return 0 |
411 | 402 |
412 elif plot_type in ('pr_curve', 'roc_curve'): | 403 elif plot_type in ("pr_curve", "roc_curve"): |
413 df1 = pd.read_csv(infile1, sep='\t', header='infer') | 404 df1 = pd.read_csv(infile1, sep="\t", header="infer") |
414 df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32) | 405 df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32) |
415 | 406 |
416 minimum = params['plotting_selection']['report_minimum_n_positives'] | 407 minimum = params["plotting_selection"]["report_minimum_n_positives"] |
417 # filter out columns whose n_positives is beblow the threhold | 408 # filter out columns whose n_positives is beblow the threhold |
418 if minimum: | 409 if minimum: |
419 mask = df1.sum(axis=0) >= minimum | 410 mask = df1.sum(axis=0) >= minimum |
420 df1 = df1.loc[:, mask] | 411 df1 = df1.loc[:, mask] |
421 df2 = df2.loc[:, mask] | 412 df2 = df2.loc[:, mask] |
422 | 413 |
423 pos_label = params['plotting_selection']['pos_label'].strip() \ | 414 pos_label = params["plotting_selection"]["pos_label"].strip() or None |
424 or None | 415 |
425 | 416 if plot_type == "pr_curve": |
426 if plot_type == 'pr_curve': | 417 if plot_format == "plotly_html": |
427 if plot_format == 'plotly_html': | |
428 visualize_pr_curve_plotly(df1, df2, pos_label, title=title) | 418 visualize_pr_curve_plotly(df1, df2, pos_label, title=title) |
429 else: | 419 else: |
430 visualize_pr_curve_matplotlib(df1, df2, pos_label, title) | 420 visualize_pr_curve_matplotlib(df1, df2, pos_label, title) |
431 else: # 'roc_curve' | 421 else: # 'roc_curve' |
432 drop_intermediate = (params['plotting_selection'] | 422 drop_intermediate = params["plotting_selection"]["drop_intermediate"] |
433 ['drop_intermediate']) | 423 if plot_format == "plotly_html": |
434 if plot_format == 'plotly_html': | 424 visualize_roc_curve_plotly( |
435 visualize_roc_curve_plotly(df1, df2, pos_label, | 425 df1, |
436 drop_intermediate=drop_intermediate, | 426 df2, |
437 title=title) | 427 pos_label, |
428 drop_intermediate=drop_intermediate, | |
429 title=title, | |
430 ) | |
438 else: | 431 else: |
439 visualize_roc_curve_matplotlib( | 432 visualize_roc_curve_matplotlib( |
440 df1, df2, pos_label, | 433 df1, |
434 df2, | |
435 pos_label, | |
441 drop_intermediate=drop_intermediate, | 436 drop_intermediate=drop_intermediate, |
442 title=title) | 437 title=title, |
438 ) | |
443 | 439 |
444 return 0 | 440 return 0 |
445 | 441 |
446 elif plot_type == 'rfecv_gridscores': | 442 elif plot_type == "rfecv_gridscores": |
447 input_df = pd.read_csv(infile1, sep='\t', header='infer') | 443 input_df = pd.read_csv(infile1, sep="\t", header="infer") |
448 scores = input_df.iloc[:, 0] | 444 scores = input_df.iloc[:, 0] |
449 steps = params['plotting_selection']['steps'].strip() | 445 steps = params["plotting_selection"]["steps"].strip() |
450 steps = safe_eval(steps) | 446 steps = safe_eval(steps) |
451 | 447 |
452 data = go.Scatter( | 448 data = go.Scatter( |
453 x=list(range(len(scores))), | 449 x=list(range(len(scores))), |
454 y=scores, | 450 y=scores, |
455 text=[str(_) for _ in steps] if steps else None, | 451 text=[str(_) for _ in steps] if steps else None, |
456 mode='lines' | 452 mode="lines", |
457 ) | 453 ) |
458 layout = go.Layout( | 454 layout = go.Layout( |
459 xaxis=dict(title="Number of features selected"), | 455 xaxis=dict(title="Number of features selected"), |
460 yaxis=dict(title="Cross validation score"), | 456 yaxis=dict(title="Cross validation score"), |
461 title=dict( | 457 title=dict(text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"), |
462 text=title or None, | 458 font=dict(family="sans-serif", size=11), |
463 x=0.5, | |
464 y=0.92, | |
465 xanchor='center', | |
466 yanchor='top' | |
467 ), | |
468 font=dict( | |
469 family="sans-serif", | |
470 size=11 | |
471 ), | |
472 # control backgroud colors | 459 # control backgroud colors |
473 plot_bgcolor='rgba(255,255,255,0)' | 460 plot_bgcolor="rgba(255,255,255,0)", |
474 ) | 461 ) |
475 """ | 462 """ |
476 # legend=dict( | 463 # legend=dict( |
477 # x=0.95, | 464 # x=0.95, |
478 # y=0, | 465 # y=0, |
487 # borderwidth=2 | 474 # borderwidth=2 |
488 # ), | 475 # ), |
489 """ | 476 """ |
490 | 477 |
491 fig = go.Figure(data=[data], layout=layout) | 478 fig = go.Figure(data=[data], layout=layout) |
492 plotly.offline.plot(fig, filename="output.html", | 479 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
493 auto_open=False) | |
494 # to be discovered by `from_work_dir` | 480 # to be discovered by `from_work_dir` |
495 os.rename('output.html', 'output') | 481 os.rename("output.html", "output") |
496 | 482 |
497 return 0 | 483 return 0 |
498 | 484 |
499 elif plot_type == 'learning_curve': | 485 elif plot_type == "learning_curve": |
500 input_df = pd.read_csv(infile1, sep='\t', header='infer') | 486 input_df = pd.read_csv(infile1, sep="\t", header="infer") |
501 plot_std_err = params['plotting_selection']['plot_std_err'] | 487 plot_std_err = params["plotting_selection"]["plot_std_err"] |
502 data1 = go.Scatter( | 488 data1 = go.Scatter( |
503 x=input_df['train_sizes_abs'], | 489 x=input_df["train_sizes_abs"], |
504 y=input_df['mean_train_scores'], | 490 y=input_df["mean_train_scores"], |
505 error_y=dict( | 491 error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None, |
506 array=input_df['std_train_scores'] | 492 mode="lines", |
507 ) if plot_std_err else None, | |
508 mode='lines', | |
509 name="Train Scores", | 493 name="Train Scores", |
510 ) | 494 ) |
511 data2 = go.Scatter( | 495 data2 = go.Scatter( |
512 x=input_df['train_sizes_abs'], | 496 x=input_df["train_sizes_abs"], |
513 y=input_df['mean_test_scores'], | 497 y=input_df["mean_test_scores"], |
514 error_y=dict( | 498 error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None, |
515 array=input_df['std_test_scores'] | 499 mode="lines", |
516 ) if plot_std_err else None, | |
517 mode='lines', | |
518 name="Test Scores", | 500 name="Test Scores", |
519 ) | 501 ) |
520 layout = dict( | 502 layout = dict( |
521 xaxis=dict( | 503 xaxis=dict(title="No. of samples"), |
522 title='No. of samples' | 504 yaxis=dict(title="Performance Score"), |
523 ), | |
524 yaxis=dict( | |
525 title='Performance Score' | |
526 ), | |
527 # modify these configurations to customize image | 505 # modify these configurations to customize image |
528 title=dict( | 506 title=dict( |
529 text=title or 'Learning Curve', | 507 text=title or "Learning Curve", |
530 x=0.5, | 508 x=0.5, |
531 y=0.92, | 509 y=0.92, |
532 xanchor='center', | 510 xanchor="center", |
533 yanchor='top' | 511 yanchor="top", |
534 ), | 512 ), |
535 font=dict( | 513 font=dict(family="sans-serif", size=11), |
536 family="sans-serif", | |
537 size=11 | |
538 ), | |
539 # control backgroud colors | 514 # control backgroud colors |
540 plot_bgcolor='rgba(255,255,255,0)' | 515 plot_bgcolor="rgba(255,255,255,0)", |
541 ) | 516 ) |
542 """ | 517 """ |
543 # legend=dict( | 518 # legend=dict( |
544 # x=0.95, | 519 # x=0.95, |
545 # y=0, | 520 # y=0, |
554 # borderwidth=2 | 529 # borderwidth=2 |
555 # ), | 530 # ), |
556 """ | 531 """ |
557 | 532 |
558 fig = go.Figure(data=[data1, data2], layout=layout) | 533 fig = go.Figure(data=[data1, data2], layout=layout) |
559 plotly.offline.plot(fig, filename="output.html", | 534 plotly.offline.plot(fig, filename="output.html", auto_open=False) |
560 auto_open=False) | |
561 # to be discovered by `from_work_dir` | 535 # to be discovered by `from_work_dir` |
562 os.rename('output.html', 'output') | 536 os.rename("output.html", "output") |
563 | 537 |
564 return 0 | 538 return 0 |
565 | 539 |
566 elif plot_type == 'keras_plot_model': | 540 elif plot_type == "keras_plot_model": |
567 with open(model_config, 'r') as f: | 541 with open(model_config, "r") as f: |
568 model_str = f.read() | 542 model_str = f.read() |
569 model = model_from_json(model_str) | 543 model = model_from_json(model_str) |
570 plot_model(model, to_file="output.png") | 544 plot_model(model, to_file="output.png") |
571 os.rename('output.png', 'output') | 545 os.rename("output.png", "output") |
572 | 546 |
573 return 0 | 547 return 0 |
574 | 548 |
575 elif plot_type == 'classification_confusion_matrix': | 549 elif plot_type == "classification_confusion_matrix": |
576 plot_selection = params["plotting_selection"] | 550 plot_selection = params["plotting_selection"] |
577 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") | 551 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") |
578 header_predicted = 'infer' if plot_selection["header_predicted"] else None | 552 header_predicted = "infer" if plot_selection["header_predicted"] else None |
579 input_predicted = pd.read_csv(predicted_labels, sep='\t', parse_dates=True, header=header_predicted) | 553 input_predicted = pd.read_csv(predicted_labels, sep="\t", parse_dates=True, header=header_predicted) |
580 true_classes = input_true.iloc[:, -1].copy() | 554 true_classes = input_true.iloc[:, -1].copy() |
581 predicted_classes = input_predicted.iloc[:, -1].copy() | 555 predicted_classes = input_predicted.iloc[:, -1].copy() |
582 axis_labels = list(set(true_classes)) | 556 axis_labels = list(set(true_classes)) |
583 c_matrix = confusion_matrix(true_classes, predicted_classes) | 557 c_matrix = confusion_matrix(true_classes, predicted_classes) |
584 fig, ax = plt.subplots(figsize=(7, 7)) | 558 fig, ax = plt.subplots(figsize=(7, 7)) |
585 im = plt.imshow(c_matrix, cmap=plot_color) | 559 im = plt.imshow(c_matrix, cmap=plot_color) |
586 for i in range(len(c_matrix)): | 560 for i in range(len(c_matrix)): |
587 for j in range(len(c_matrix)): | 561 for j in range(len(c_matrix)): |
588 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") | 562 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") |
589 ax.set_ylabel('True class labels') | 563 ax.set_ylabel("True class labels") |
590 ax.set_xlabel('Predicted class labels') | 564 ax.set_xlabel("Predicted class labels") |
591 ax.set_title(title) | 565 ax.set_title(title) |
592 ax.set_xticks(axis_labels) | 566 ax.set_xticks(axis_labels) |
593 ax.set_yticks(axis_labels) | 567 ax.set_yticks(axis_labels) |
594 fig.colorbar(im, ax=ax) | 568 fig.colorbar(im, ax=ax) |
595 fig.tight_layout() | 569 fig.tight_layout() |
596 plt.savefig("output.png", dpi=125) | 570 plt.savefig("output.png", dpi=125) |
597 os.rename('output.png', 'output') | 571 os.rename("output.png", "output") |
598 | 572 |
599 return 0 | 573 return 0 |
600 | 574 |
601 # save pdf file to disk | 575 # save pdf file to disk |
602 # fig.write_image("image.pdf", format='pdf') | 576 # fig.write_image("image.pdf", format='pdf') |
603 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) | 577 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) |
604 | 578 |
605 | 579 |
606 if __name__ == '__main__': | 580 if __name__ == "__main__": |
607 aparser = argparse.ArgumentParser() | 581 aparser = argparse.ArgumentParser() |
608 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) | 582 aparser.add_argument("-i", "--inputs", dest="inputs", required=True) |
609 aparser.add_argument("-e", "--estimator", dest="infile_estimator") | 583 aparser.add_argument("-e", "--estimator", dest="infile_estimator") |
610 aparser.add_argument("-X", "--infile1", dest="infile1") | 584 aparser.add_argument("-X", "--infile1", dest="infile1") |
611 aparser.add_argument("-y", "--infile2", dest="infile2") | 585 aparser.add_argument("-y", "--infile2", dest="infile2") |
621 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") | 595 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") |
622 aparser.add_argument("-pc", "--plot_color", dest="plot_color") | 596 aparser.add_argument("-pc", "--plot_color", dest="plot_color") |
623 aparser.add_argument("-pt", "--title", dest="title") | 597 aparser.add_argument("-pt", "--title", dest="title") |
624 args = aparser.parse_args() | 598 args = aparser.parse_args() |
625 | 599 |
626 main(args.inputs, args.infile_estimator, args.infile1, args.infile2, | 600 main( |
627 args.outfile_result, outfile_object=args.outfile_object, | 601 args.inputs, |
628 groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, | 602 args.infile_estimator, |
629 targets=args.targets, fasta_path=args.fasta_path, | 603 args.infile1, |
630 model_config=args.model_config, true_labels=args.true_labels, | 604 args.infile2, |
631 predicted_labels=args.predicted_labels, | 605 args.outfile_result, |
632 plot_color=args.plot_color, | 606 outfile_object=args.outfile_object, |
633 title=args.title) | 607 groups=args.groups, |
608 ref_seq=args.ref_seq, | |
609 intervals=args.intervals, | |
610 targets=args.targets, | |
611 fasta_path=args.fasta_path, | |
612 model_config=args.model_config, | |
613 true_labels=args.true_labels, | |
614 predicted_labels=args.predicted_labels, | |
615 plot_color=args.plot_color, | |
616 title=args.title, | |
617 ) |