Mercurial > repos > bgruening > sklearn_nn_classifier
view ml_visualization_ex.py @ 26:2c58b83c9bad draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit f031d8ddfb73cec24572648666ac44ee47f08aad
author | bgruening |
---|---|
date | Thu, 11 Aug 2022 09:54:23 +0000 |
parents | 823ecc0bce45 |
children | 22f0b9db4ea1 |
line wrap: on
line source
import argparse
import json
import os
import warnings

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
from galaxy_ml.utils import load_model, read_columns, SafeEval
from keras.models import model_from_json
from keras.utils import plot_model
from sklearn.metrics import (
    auc,
    average_precision_score,
    confusion_matrix,
    precision_recall_curve,
    roc_curve,
)
from sklearn.pipeline import Pipeline

try:
    # Public location in current scikit-learn; the private
    # ``sklearn.feature_selection.base`` module was removed in 0.24.
    from sklearn.feature_selection import SelectorMixin
except ImportError:  # very old scikit-learn
    from sklearn.feature_selection.base import SelectorMixin


safe_eval = SafeEval()

# plotly default colors
default_colors = [
    "#1f77b4",  # muted blue
    "#ff7f0e",  # safety orange
    "#2ca02c",  # cooked asparagus green
    "#d62728",  # brick red
    "#9467bd",  # muted purple
    "#8c564b",  # chestnut brown
    "#e377c2",  # raspberry yogurt pink
    "#7f7f7f",  # middle gray
    "#bcbd22",  # curry yellow-green
    "#17becf",  # blue-teal
]

# Column-selector options for which the tool form supplies an explicit
# column list in the ``col1`` field; every other option selects all columns.
_COLUMN_OPTIONS_WITH_COLS = (
    "by_index_number",
    "all_but_by_index_number",
    "by_header_name",
    "all_but_by_header_name",
)


def _column_selection(selector):
    """Return ``(col, column_option)`` from a column-selector parameter dict.

    ``col`` is ``selector["col1"]`` for the options that take an explicit
    column list, otherwise None.
    """
    column_option = selector["selected_column_selector_option"]
    if column_option in _COLUMN_OPTIONS_WITH_COLS:
        return selector["col1"], column_option
    return None, column_option


def _save_plotly_html(fig):
    """Write ``fig`` as standalone HTML and rename it to ``output``.

    plotly writes ``output.html``; the file is renamed to ``output`` so that
    it is discovered by Galaxy's `from_work_dir`.
    """
    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename("output.html", "output")


def _prepare_matplotlib_svg():
    """Select the SVG backend (unless running inline) and a colorblind style."""
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    # matplotlib 3.6 renamed the bundled seaborn styles to "seaborn-v0_8-*"
    # and 3.8 removed the old names, so use whichever one is available.
    for style in ("seaborn-colorblind", "seaborn-v0_8-colorblind"):
        if style in plt.style.available:
            plt.style.use(style)
            break


def _save_matplotlib_svg():
    """Save the current figure as SVG to ``./output`` and close it."""
    folder = os.getcwd()
    svg_path = os.path.join(folder, "output.svg")
    plt.savefig(svg_path, format="svg")
    plt.close()
    os.rename(svg_path, os.path.join(folder, "output"))


def _iter_column_pairs(df1, df2):
    """Yield ``(idx, y_true, y_score)`` for every column position of ``df1``.

    Columns of ``df1`` and ``df2`` are paired by position, not by name.
    """
    for idx in range(df1.shape[1]):
        yield idx, df1.iloc[:, idx].values, df2.iloc[:, idx].values


def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
    """output pr-curve in html using plotly

    df1 : pandas.DataFrame
        Containing y_true
    df2 : pandas.DataFrame
        Containing y_score
    pos_label : None
        The label of positive class
    title : str
        Plot title
    """
    data = []
    for idx, y_true, y_score in _iter_column_pairs(df1, df2):
        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
        data.append(
            go.Scatter(
                x=recall,
                y=precision,
                mode="lines",
                marker=dict(color=default_colors[idx % len(default_colors)]),
                name="%s (area = %.3f)" % (idx, ap),
            )
        )

    layout = go.Layout(
        xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1),
        yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1),
        title=dict(
            text=title or "Precision-Recall Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    _save_plotly_html(go.Figure(data=data, layout=layout))


def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
    """visualize pr-curve using matplotlib and output svg image"""
    _prepare_matplotlib_svg()
    plt.figure()
    for idx, y_true, y_score in _iter_column_pairs(df1, df2):
        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)
        # Plain "-" line style: the former "r-" conflicted with color="black".
        plt.step(
            recall,
            precision,
            "-",
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, ap),
        )
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    plt.title(title or "Precision-Recall Curve")
    if df1.shape[1]:
        # Show the per-curve "(area = ...)" labels computed above.
        plt.legend(loc="lower left")
    _save_matplotlib_svg()


def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
    """output roc-curve in html using plotly

    df1 : pandas.DataFrame
        Containing y_true
    df2 : pandas.DataFrame
        Containing y_score
    pos_label : None
        The label of positive class
    drop_intermediate : bool
        Whether to drop some suboptimal thresholds
    title : str
        Plot title
    """
    data = []
    for idx, y_true, y_score in _iter_column_pairs(df1, df2):
        fpr, tpr, _ = roc_curve(
            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        roc_auc = auc(fpr, tpr)
        data.append(
            go.Scatter(
                x=fpr,
                y=tpr,
                mode="lines",
                marker=dict(color=default_colors[idx % len(default_colors)]),
                name="%s (area = %.3f)" % (idx, roc_auc),
            )
        )

    layout = go.Layout(
        xaxis=dict(
            title="False Positive Rate", linecolor="lightslategray", linewidth=1
        ),
        yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1),
        title=dict(
            text=title or "Receiver Operating Characteristic (ROC) Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    _save_plotly_html(go.Figure(data=data, layout=layout))


def visualize_roc_curve_matplotlib(
    df1, df2, pos_label, drop_intermediate=True, title=None
):
    """visualize roc-curve using matplotlib and output svg image"""
    _prepare_matplotlib_svg()
    plt.figure()
    for idx, y_true, y_score in _iter_column_pairs(df1, df2):
        fpr, tpr, _ = roc_curve(
            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        roc_auc = auc(fpr, tpr)
        # Plain "-" line style: the former "r-" conflicted with color="black".
        plt.step(
            fpr,
            tpr,
            "-",
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, roc_auc),
        )
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(title or "Receiver Operating Characteristic (ROC) Curve")
    if df1.shape[1]:
        # Show the per-curve "(area = ...)" labels computed above.
        plt.legend(loc="lower right")
    _save_matplotlib_svg()


def get_dataframe(file_path, plot_selection, header_name, column_name):
    """Read selected columns of a tabular file into a DataFrame.

    file_path : str
        Tab-separated input file.
    plot_selection : dict
        The ``plotting_selection`` tool parameters.
    header_name : str
        Key in ``plot_selection`` whose truthiness says the file has a header.
    column_name : str
        Key in ``plot_selection`` holding the column-selector parameters.
    """
    header = "infer" if plot_selection[header_name] else None
    col, column_option = _column_selection(plot_selection[column_name])
    _, input_df = read_columns(
        file_path,
        c=col,
        c_option=column_option,
        return_df=True,
        sep="\t",
        header=header,
        parse_dates=True,
    )
    return input_df


def _plot_feature_importances(params, infile_estimator, infile1, title):
    """Bar plot of ``coef_`` / ``feature_importances_`` of a saved model."""
    plot_selection = params["plotting_selection"]
    with open(infile_estimator, "rb") as estimator_handler:
        estimator = load_model(estimator_handler)

    c, column_option = _column_selection(plot_selection["column_selector_options"])
    _, input_df = read_columns(
        infile1,
        c=c,
        c_option=column_option,
        return_df=True,
        sep="\t",
        header="infer",
        parse_dates=True,
    )
    feature_names = input_df.columns.values

    # Apply the support masks of any feature selectors in a pipeline so the
    # names line up with the features seen by the final estimator.
    if isinstance(estimator, Pipeline):
        for st in estimator.steps[:-1]:
            if isinstance(st[-1], SelectorMixin):
                feature_names = feature_names[st[-1].get_support()]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, "coef_"):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, "feature_importances_", None)
    if coefs is None:
        raise RuntimeError(
            "The classifier does not expose "
            '"coef_" or "feature_importances_" '
            "attributes"
        )

    coefs = np.asarray(coefs)
    # Binary linear models expose coef_ with shape (1, n_features); flatten
    # it so masking and sorting below align with the 1-D feature_names.
    if coefs.ndim == 2 and coefs.shape[0] == 1:
        coefs = coefs[0]

    threshold = plot_selection["threshold"]
    if threshold is not None:
        mask = (coefs > threshold) | (coefs < -threshold)
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # sort descending
    indices = np.argsort(coefs)[::-1]
    trace = go.Bar(x=feature_names[indices], y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    _save_plotly_html(go.Figure(data=[trace], layout=layout))


def _plot_pr_or_roc(params, plot_type, plot_format, infile1, infile2, title):
    """Precision-recall or ROC curves for each (true, score) column pair."""
    plot_selection = params["plotting_selection"]
    df1 = pd.read_csv(infile1, sep="\t", header="infer")
    df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)

    minimum = plot_selection["report_minimum_n_positives"]
    # filter out columns whose n_positives is below the threshold
    if minimum:
        # Positional mask: df1/df2 columns are paired by position, so df2's
        # headers need not match df1's.
        mask = (df1.sum(axis=0) >= minimum).to_numpy()
        df1 = df1.loc[:, mask]
        df2 = df2.loc[:, mask]

    pos_label = plot_selection["pos_label"].strip() or None

    if plot_type == "pr_curve":
        if plot_format == "plotly_html":
            visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
        else:
            visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
        return

    # 'roc_curve'
    drop_intermediate = plot_selection["drop_intermediate"]
    if plot_format == "plotly_html":
        visualizer = visualize_roc_curve_plotly
    else:
        visualizer = visualize_roc_curve_matplotlib
    visualizer(
        df1,
        df2,
        pos_label,
        drop_intermediate=drop_intermediate,
        title=title,
    )


def _plot_rfecv_gridscores(params, infile1, title):
    """Line plot of RFECV grid scores (first column of ``infile1``)."""
    input_df = pd.read_csv(infile1, sep="\t", header="infer")
    scores = input_df.iloc[:, 0]
    # Parsed with galaxy_ml's SafeEval rather than the builtin eval().
    steps = safe_eval(params["plotting_selection"]["steps"].strip())

    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        text=[str(step) for step in steps] if steps else None,
        mode="lines",
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=dict(
            text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    _save_plotly_html(go.Figure(data=[data], layout=layout))


def _plot_learning_curve(params, infile1, title):
    """Train/test learning curves from a learning-curve result table."""
    input_df = pd.read_csv(infile1, sep="\t", header="infer")
    plot_std_err = params["plotting_selection"]["plot_std_err"]

    data1 = go.Scatter(
        x=input_df["train_sizes_abs"],
        y=input_df["mean_train_scores"],
        error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
        mode="lines",
        name="Train Scores",
    )
    data2 = go.Scatter(
        x=input_df["train_sizes_abs"],
        y=input_df["mean_test_scores"],
        error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
        mode="lines",
        name="Test Scores",
    )
    layout = dict(
        xaxis=dict(title="No. of samples"),
        yaxis=dict(title="Performance Score"),
        # modify these configurations to customize image
        title=dict(
            text=title or "Learning Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    _save_plotly_html(go.Figure(data=[data1, data2], layout=layout))


def _plot_keras_model(model_config):
    """Render a Keras model architecture (JSON config) to ``./output`` (PNG)."""
    with open(model_config, "r") as f:
        model_str = f.read()
    model = model_from_json(model_str)
    plot_model(model, to_file="output.png")
    os.rename("output.png", "output")


def _plot_confusion_matrix(params, true_labels, predicted_labels, plot_color, title):
    """Confusion-matrix heatmap of the last column of each input, as PNG."""
    plot_selection = params["plotting_selection"]
    input_true = get_dataframe(
        true_labels, plot_selection, "header_true", "column_selector_options_true"
    )
    header_predicted = "infer" if plot_selection["header_predicted"] else None
    input_predicted = pd.read_csv(
        predicted_labels, sep="\t", parse_dates=True, header=header_predicted
    )
    true_classes = input_true.iloc[:, -1].copy()
    predicted_classes = input_predicted.iloc[:, -1].copy()

    # Use one sorted label list for both the matrix and the tick labels so
    # rows/columns and their names are guaranteed to line up (the previous
    # ``list(set(...))`` had undefined order and ignored predicted-only labels).
    axis_labels = np.unique(
        np.concatenate([np.asarray(true_classes), np.asarray(predicted_classes)])
    )
    c_matrix = confusion_matrix(true_classes, predicted_classes, labels=axis_labels)
    n_classes = len(axis_labels)

    fig, ax = plt.subplots(figsize=(7, 7))
    im = ax.imshow(c_matrix, cmap=plot_color)
    for i in range(n_classes):
        for j in range(n_classes):
            ax.text(j, i, str(c_matrix[i, j]), ha="center", va="center", color="k")
    ax.set_ylabel("True class labels")
    ax.set_xlabel("Predicted class labels")
    ax.set_title(title)
    # Ticks sit at cell positions 0..n-1; class names go in the tick labels
    # (previously the raw label values were used as positions).
    ax.set_xticks(np.arange(n_classes))
    ax.set_yticks(np.arange(n_classes))
    ax.set_xticklabels([str(label) for label in axis_labels])
    ax.set_yticklabels([str(label) for label in axis_labels])
    fig.colorbar(im, ax=ax)
    fig.tight_layout()
    fig.savefig("output.png", dpi=125)
    plt.close(fig)
    os.rename("output.png", "output")


def main(
    inputs,
    infile_estimator=None,
    infile1=None,
    infile2=None,
    outfile_result=None,
    outfile_object=None,
    groups=None,
    ref_seq=None,
    intervals=None,
    targets=None,
    fasta_path=None,
    model_config=None,
    true_labels=None,
    predicted_labels=None,
    plot_color=None,
    title=None,
):
    """Produce the plot selected in the Galaxy tool parameters.

    The result is always written to ``./output`` in the working directory.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter
    infile_estimator : str, default is None
        File path to estimator
    infile1 : str, default is None
        File path to dataset containing features or true labels.
    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.
    outfile_result : str, default is None
        File path to save the results, either cv_results or test result
        (not used by this script)
    outfile_object : str, default is None
        File path to save searchCV object (not used by this script)
    groups : str, default is None
        File path to dataset containing groups labels (not used)
    ref_seq : str, default is None
        File path to dataset containing genome sequence file (not used)
    intervals : str, default is None
        File path to dataset containing interval file (not used)
    targets : str, default is None
        File path to dataset compressed target bed file (not used)
    fasta_path : str, default is None
        File path to dataset containing fasta file (not used)
    model_config : str, default is None
        File path to dataset containing JSON config for neural networks
    true_labels : str, default is None
        File path to dataset containing true labels
    predicted_labels : str, default is None
        File path to dataset containing predicted labels
    plot_color : str, default is None
        Color map of the confusion matrix heatmap
    title : str, default is None
        Plot title. NOTE: always replaced by ``plotting_selection.title``
        from the tool parameters.

    Returns
    -------
    int
        0 on success.

    Raises
    ------
    ValueError
        If ``plot_type`` is not one of the supported plot types.
    """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    plot_selection = params["plotting_selection"]
    # The tool-form title takes precedence over the ``title`` argument.
    title = plot_selection["title"].strip()
    plot_type = plot_selection["plot_type"]
    plot_format = plot_selection["plot_format"]

    if plot_type == "feature_importances":
        _plot_feature_importances(params, infile_estimator, infile1, title)
    elif plot_type in ("pr_curve", "roc_curve"):
        _plot_pr_or_roc(params, plot_type, plot_format, infile1, infile2, title)
    elif plot_type == "rfecv_gridscores":
        _plot_rfecv_gridscores(params, infile1, title)
    elif plot_type == "learning_curve":
        _plot_learning_curve(params, infile1, title)
    elif plot_type == "keras_plot_model":
        _plot_keras_model(model_config)
    elif plot_type == "classification_confusion_matrix":
        _plot_confusion_matrix(
            params, true_labels, predicted_labels, plot_color, title
        )
    else:
        # Previously fell through silently, leaving no output file behind.
        raise ValueError("Unsupported plot_type: %r" % (plot_type,))
    return 0


if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    aparser.add_argument("-e", "--estimator", dest="infile_estimator")
    aparser.add_argument("-X", "--infile1", dest="infile1")
    aparser.add_argument("-y", "--infile2", dest="infile2")
    aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
    aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
    aparser.add_argument("-g", "--groups", dest="groups")
    aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
    aparser.add_argument("-b", "--intervals", dest="intervals")
    aparser.add_argument("-t", "--targets", dest="targets")
    aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
    aparser.add_argument("-c", "--model_config", dest="model_config")
    aparser.add_argument("-tl", "--true_labels", dest="true_labels")
    aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
    aparser.add_argument("-pc", "--plot_color", dest="plot_color")
    aparser.add_argument("-pt", "--title", dest="title")
    args = aparser.parse_args()

    main(
        args.inputs,
        args.infile_estimator,
        args.infile1,
        args.infile2,
        args.outfile_result,
        outfile_object=args.outfile_object,
        groups=args.groups,
        ref_seq=args.ref_seq,
        intervals=args.intervals,
        targets=args.targets,
        fasta_path=args.fasta_path,
        model_config=args.model_config,
        true_labels=args.true_labels,
        predicted_labels=args.predicted_labels,
        plot_color=args.plot_color,
        title=args.title,
    )