Mercurial > repos > bgruening > sklearn_train_test_eval
comparison ml_visualization_ex.py @ 11:caf7d2b71a48 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
| author | bgruening | 
|---|---|
| date | Sat, 01 May 2021 01:47:26 +0000 | 
| parents | a9e0b963b7bb | 
| children | 2eb5c017958d | 
Comparison legend (row types):
- equal
- deleted
- inserted
- replaced
| 10:a9e0b963b7bb | 11:caf7d2b71a48 | 
|---|---|
| 11 import plotly.graph_objs as go | 11 import plotly.graph_objs as go | 
| 12 from galaxy_ml.utils import load_model, read_columns, SafeEval | 12 from galaxy_ml.utils import load_model, read_columns, SafeEval | 
| 13 from keras.models import model_from_json | 13 from keras.models import model_from_json | 
| 14 from keras.utils import plot_model | 14 from keras.utils import plot_model | 
| 15 from sklearn.feature_selection.base import SelectorMixin | 15 from sklearn.feature_selection.base import SelectorMixin | 
| 16 from sklearn.metrics import auc, average_precision_score, confusion_matrix, precision_recall_curve, roc_curve | 16 from sklearn.metrics import (auc, average_precision_score, confusion_matrix, | 
| 17 precision_recall_curve, roc_curve) | |
| 17 from sklearn.pipeline import Pipeline | 18 from sklearn.pipeline import Pipeline | 
| 18 | |
| 19 | 19 | 
| 20 safe_eval = SafeEval() | 20 safe_eval = SafeEval() | 
| 21 | 21 | 
| 22 # plotly default colors | 22 # plotly default colors | 
| 23 default_colors = [ | 23 default_colors = [ | 
| 49 data = [] | 49 data = [] | 
| 50 for idx in range(df1.shape[1]): | 50 for idx in range(df1.shape[1]): | 
| 51 y_true = df1.iloc[:, idx].values | 51 y_true = df1.iloc[:, idx].values | 
| 52 y_score = df2.iloc[:, idx].values | 52 y_score = df2.iloc[:, idx].values | 
| 53 | 53 | 
| 54 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) | 54 precision, recall, _ = precision_recall_curve( | 
| 55 y_true, y_score, pos_label=pos_label | |
| 56 ) | |
| 55 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) | 57 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) | 
| 56 | 58 | 
| 57 trace = go.Scatter( | 59 trace = go.Scatter( | 
| 58 x=recall, | 60 x=recall, | 
| 59 y=precision, | 61 y=precision, | 
| 109 | 111 | 
| 110 for idx in range(df1.shape[1]): | 112 for idx in range(df1.shape[1]): | 
| 111 y_true = df1.iloc[:, idx].values | 113 y_true = df1.iloc[:, idx].values | 
| 112 y_score = df2.iloc[:, idx].values | 114 y_score = df2.iloc[:, idx].values | 
| 113 | 115 | 
| 114 precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) | 116 precision, recall, _ = precision_recall_curve( | 
| 117 y_true, y_score, pos_label=pos_label | |
| 118 ) | |
| 115 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) | 119 ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) | 
| 116 | 120 | 
| 117 plt.step( | 121 plt.step( | 
| 118 recall, | 122 recall, | 
| 119 precision, | 123 precision, | 
| 153 data = [] | 157 data = [] | 
| 154 for idx in range(df1.shape[1]): | 158 for idx in range(df1.shape[1]): | 
| 155 y_true = df1.iloc[:, idx].values | 159 y_true = df1.iloc[:, idx].values | 
| 156 y_score = df2.iloc[:, idx].values | 160 y_score = df2.iloc[:, idx].values | 
| 157 | 161 | 
| 158 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) | 162 fpr, tpr, _ = roc_curve( | 
| 163 y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate | |
| 164 ) | |
| 159 roc_auc = auc(fpr, tpr) | 165 roc_auc = auc(fpr, tpr) | 
| 160 | 166 | 
| 161 trace = go.Scatter( | 167 trace = go.Scatter( | 
| 162 x=fpr, | 168 x=fpr, | 
| 163 y=tpr, | 169 y=tpr, | 
| 166 name="%s (area = %.3f)" % (idx, roc_auc), | 172 name="%s (area = %.3f)" % (idx, roc_auc), | 
| 167 ) | 173 ) | 
| 168 data.append(trace) | 174 data.append(trace) | 
| 169 | 175 | 
| 170 layout = go.Layout( | 176 layout = go.Layout( | 
| 171 xaxis=dict(title="False Positive Rate", linecolor="lightslategray", linewidth=1), | 177 xaxis=dict( | 
| 178 title="False Positive Rate", linecolor="lightslategray", linewidth=1 | |
| 179 ), | |
| 172 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1), | 180 yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1), | 
| 173 title=dict( | 181 title=dict( | 
| 174 text=title or "Receiver Operating Characteristic (ROC) Curve", | 182 text=title or "Receiver Operating Characteristic (ROC) Curve", | 
| 175 x=0.5, | 183 x=0.5, | 
| 176 y=0.92, | 184 y=0.92, | 
| 202 plotly.offline.plot(fig, filename="output.html", auto_open=False) | 210 plotly.offline.plot(fig, filename="output.html", auto_open=False) | 
| 203 # to be discovered by `from_work_dir` | 211 # to be discovered by `from_work_dir` | 
| 204 os.rename("output.html", "output") | 212 os.rename("output.html", "output") | 
| 205 | 213 | 
| 206 | 214 | 
| 207 def visualize_roc_curve_matplotlib(df1, df2, pos_label, drop_intermediate=True, title=None): | 215 def visualize_roc_curve_matplotlib( | 
| 216 df1, df2, pos_label, drop_intermediate=True, title=None | |
| 217 ): | |
| 208 """visualize roc-curve using matplotlib and output svg image""" | 218 """visualize roc-curve using matplotlib and output svg image""" | 
| 209 backend = matplotlib.get_backend() | 219 backend = matplotlib.get_backend() | 
| 210 if "inline" not in backend: | 220 if "inline" not in backend: | 
| 211 matplotlib.use("SVG") | 221 matplotlib.use("SVG") | 
| 212 plt.style.use("seaborn-colorblind") | 222 plt.style.use("seaborn-colorblind") | 
| 214 | 224 | 
| 215 for idx in range(df1.shape[1]): | 225 for idx in range(df1.shape[1]): | 
| 216 y_true = df1.iloc[:, idx].values | 226 y_true = df1.iloc[:, idx].values | 
| 217 y_score = df2.iloc[:, idx].values | 227 y_score = df2.iloc[:, idx].values | 
| 218 | 228 | 
| 219 fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) | 229 fpr, tpr, _ = roc_curve( | 
| 230 y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate | |
| 231 ) | |
| 220 roc_auc = auc(fpr, tpr) | 232 roc_auc = auc(fpr, tpr) | 
| 221 | 233 | 
| 222 plt.step( | 234 plt.step( | 
| 223 fpr, | 235 fpr, | 
| 224 tpr, | 236 tpr, | 
| 251 "all_but_by_header_name", | 263 "all_but_by_header_name", | 
| 252 ]: | 264 ]: | 
| 253 col = plot_selection[column_name]["col1"] | 265 col = plot_selection[column_name]["col1"] | 
| 254 else: | 266 else: | 
| 255 col = None | 267 col = None | 
| 256 _, input_df = read_columns(file_path, c=col, | 268 _, input_df = read_columns( | 
| 257 c_option=column_option, | 269 file_path, | 
| 258 return_df=True, | 270 c=col, | 
| 259 sep='\t', header=header, | 271 c_option=column_option, | 
| 260 parse_dates=True) | 272 return_df=True, | 
| 273 sep="\t", | |
| 274 header=header, | |
| 275 parse_dates=True, | |
| 276 ) | |
| 261 return input_df | 277 return input_df | 
| 262 | 278 | 
| 263 | 279 | 
| 264 def main( | 280 def main( | 
| 265 inputs, | 281 inputs, | 
| 342 | 358 | 
| 343 if plot_type == "feature_importances": | 359 if plot_type == "feature_importances": | 
| 344 with open(infile_estimator, "rb") as estimator_handler: | 360 with open(infile_estimator, "rb") as estimator_handler: | 
| 345 estimator = load_model(estimator_handler) | 361 estimator = load_model(estimator_handler) | 
| 346 | 362 | 
| 347 column_option = params["plotting_selection"]["column_selector_options"]["selected_column_selector_option"] | 363 column_option = params["plotting_selection"]["column_selector_options"][ | 
| 364 "selected_column_selector_option" | |
| 365 ] | |
| 348 if column_option in [ | 366 if column_option in [ | 
| 349 "by_index_number", | 367 "by_index_number", | 
| 350 "all_but_by_index_number", | 368 "all_but_by_index_number", | 
| 351 "by_header_name", | 369 "by_header_name", | 
| 352 "all_but_by_header_name", | 370 "all_but_by_header_name", | 
| 377 if hasattr(estimator, "coef_"): | 395 if hasattr(estimator, "coef_"): | 
| 378 coefs = estimator.coef_ | 396 coefs = estimator.coef_ | 
| 379 else: | 397 else: | 
| 380 coefs = getattr(estimator, "feature_importances_", None) | 398 coefs = getattr(estimator, "feature_importances_", None) | 
| 381 if coefs is None: | 399 if coefs is None: | 
| 382 raise RuntimeError("The classifier does not expose " '"coef_" or "feature_importances_" ' "attributes") | 400 raise RuntimeError( | 
| 401 "The classifier does not expose " | |
| 402 '"coef_" or "feature_importances_" ' | |
| 403 "attributes" | |
| 404 ) | |
| 383 | 405 | 
| 384 threshold = params["plotting_selection"]["threshold"] | 406 threshold = params["plotting_selection"]["threshold"] | 
| 385 if threshold is not None: | 407 if threshold is not None: | 
| 386 mask = (coefs > threshold) | (coefs < -threshold) | 408 mask = (coefs > threshold) | (coefs < -threshold) | 
| 387 coefs = coefs[mask] | 409 coefs = coefs[mask] | 
| 452 mode="lines", | 474 mode="lines", | 
| 453 ) | 475 ) | 
| 454 layout = go.Layout( | 476 layout = go.Layout( | 
| 455 xaxis=dict(title="Number of features selected"), | 477 xaxis=dict(title="Number of features selected"), | 
| 456 yaxis=dict(title="Cross validation score"), | 478 yaxis=dict(title="Cross validation score"), | 
| 457 title=dict(text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"), | 479 title=dict( | 
| 480 text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top" | |
| 481 ), | |
| 458 font=dict(family="sans-serif", size=11), | 482 font=dict(family="sans-serif", size=11), | 
| 459 # control background colors | 483 # control background colors | 
| 460 plot_bgcolor="rgba(255,255,255,0)", | 484 plot_bgcolor="rgba(255,255,255,0)", | 
| 461 ) | 485 ) | 
| 462 """ | 486 """ | 
| 546 | 570 | 
| 547 return 0 | 571 return 0 | 
| 548 | 572 | 
| 549 elif plot_type == "classification_confusion_matrix": | 573 elif plot_type == "classification_confusion_matrix": | 
| 550 plot_selection = params["plotting_selection"] | 574 plot_selection = params["plotting_selection"] | 
| 551 input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") | 575 input_true = get_dataframe( | 
| 576 true_labels, plot_selection, "header_true", "column_selector_options_true" | |
| 577 ) | |
| 552 header_predicted = "infer" if plot_selection["header_predicted"] else None | 578 header_predicted = "infer" if plot_selection["header_predicted"] else None | 
| 553 input_predicted = pd.read_csv(predicted_labels, sep="\t", parse_dates=True, header=header_predicted) | 579 input_predicted = pd.read_csv( | 
| 580 predicted_labels, sep="\t", parse_dates=True, header=header_predicted | |
| 581 ) | |
| 554 true_classes = input_true.iloc[:, -1].copy() | 582 true_classes = input_true.iloc[:, -1].copy() | 
| 555 predicted_classes = input_predicted.iloc[:, -1].copy() | 583 predicted_classes = input_predicted.iloc[:, -1].copy() | 
| 556 axis_labels = list(set(true_classes)) | 584 axis_labels = list(set(true_classes)) | 
| 557 c_matrix = confusion_matrix(true_classes, predicted_classes) | 585 c_matrix = confusion_matrix(true_classes, predicted_classes) | 
| 558 fig, ax = plt.subplots(figsize=(7, 7)) | 586 fig, ax = plt.subplots(figsize=(7, 7)) | 
