diff ml_visualization_ex.py @ 0:af2624d5ab32 draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ea12f973df4b97a2691d9e4ce6bf6fae59d57717"
author    bgruening
date      Sat, 01 May 2021 01:24:32 +0000
parents   (none)
children  9349ed2749c6
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/ml_visualization_ex.py	Sat May 01 01:24:32 2021 +0000
@@ -0,0 +1,645 @@

import argparse
import json
import os
import warnings

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
from galaxy_ml.utils import load_model, read_columns, SafeEval
from keras.models import model_from_json
from keras.utils import plot_model
from sklearn.feature_selection.base import SelectorMixin
from sklearn.metrics import (auc, average_precision_score, confusion_matrix,
                             precision_recall_curve, roc_curve)
from sklearn.pipeline import Pipeline

safe_eval = SafeEval()

# plotly default colors
default_colors = [
    "#1f77b4",  # muted blue
    "#ff7f0e",  # safety orange
    "#2ca02c",  # cooked asparagus green
    "#d62728",  # brick red
    "#9467bd",  # muted purple
    "#8c564b",  # chestnut brown
    "#e377c2",  # raspberry yogurt pink
    "#7f7f7f",  # middle gray
    "#bcbd22",  # curry yellow-green
    "#17becf",  # blue-teal
]


def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
    """Output a precision-recall curve as HTML using plotly.

    df1 : pandas.DataFrame
        Contains y_true
    df2 : pandas.DataFrame
        Contains y_score
    pos_label : int or str, default is None
        The label of the positive class
    title : str
        Plot title
    """
    data = []
    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)

        trace = go.Scatter(
            x=recall,
            y=precision,
            mode="lines",
            marker=dict(color=default_colors[idx % len(default_colors)]),
            name="%s (area = %.3f)" % (idx, ap),
        )
        data.append(trace)

    layout = go.Layout(
        xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1),
        yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1),
        title=dict(
            text=title or "Precision-Recall Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    # optional legend styling, kept disabled:
    # legend=dict(
    #     x=0.95,
    #     y=0,
    #     traceorder="normal",
    #     font=dict(family="sans-serif", size=9, color="black"),
    #     bgcolor="LightSteelBlue",
    #     bordercolor="Black",
    #     borderwidth=2,
    # ),

    fig = go.Figure(data=data, layout=layout)

    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename("output.html", "output")


def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
    """Visualize a precision-recall curve using matplotlib and output an SVG image."""
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    plt.style.use("seaborn-colorblind")
    plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        precision, recall, _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label
        )
        ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1)

        plt.step(
            recall,
            precision,
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, ap),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("Recall")
    plt.ylabel("Precision")
    title = title or "Precision-Recall Curve"
    plt.title(title)
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))


def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None):
    """Output a ROC curve as HTML using plotly.

    df1 : pandas.DataFrame
        Contains y_true
    df2 : pandas.DataFrame
        Contains y_score
    pos_label : int or str, default is None
        The label of the positive class
    drop_intermediate : bool
        Whether to drop some suboptimal thresholds
    title : str
        Plot title
    """
    data = []
    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(
            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        roc_auc = auc(fpr, tpr)

        trace = go.Scatter(
            x=fpr,
            y=tpr,
            mode="lines",
            marker=dict(color=default_colors[idx % len(default_colors)]),
            name="%s (area = %.3f)" % (idx, roc_auc),
        )
        data.append(trace)

    layout = go.Layout(
        xaxis=dict(
            title="False Positive Rate", linecolor="lightslategray", linewidth=1
        ),
        yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1),
        title=dict(
            text=title or "Receiver Operating Characteristic (ROC) Curve",
            x=0.5,
            y=0.92,
            xanchor="center",
            yanchor="top",
        ),
        font=dict(family="sans-serif", size=11),
        # control background colors
        plot_bgcolor="rgba(255,255,255,0)",
    )
    # optional legend styling, kept disabled:
    # legend=dict(
    #     x=0.95,
    #     y=0,
    #     traceorder="normal",
    #     font=dict(family="sans-serif", size=9, color="black"),
    #     bgcolor="LightSteelBlue",
    #     bordercolor="Black",
    #     borderwidth=2,
    # ),

    fig = go.Figure(data=data, layout=layout)

    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename("output.html", "output")


def visualize_roc_curve_matplotlib(
    df1, df2, pos_label, drop_intermediate=True, title=None
):
    """Visualize a ROC curve using matplotlib and output an SVG image."""
    backend = matplotlib.get_backend()
    if "inline" not in backend:
        matplotlib.use("SVG")
    plt.style.use("seaborn-colorblind")
    plt.figure()

    for idx in range(df1.shape[1]):
        y_true = df1.iloc[:, idx].values
        y_score = df2.iloc[:, idx].values

        fpr, tpr, _ = roc_curve(
            y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate
        )
        roc_auc = auc(fpr, tpr)

        plt.step(
            fpr,
            tpr,
            color="black",
            alpha=0.3,
            lw=1,
            where="post",
            label="%s (area = %.3f)" % (idx, roc_auc),
        )

    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    title = title or "Receiver Operating Characteristic (ROC) Curve"
    plt.title(title)
    folder = os.getcwd()
    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
    os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))


def get_dataframe(file_path, plot_selection, header_name, column_name):
    header = "infer" if plot_selection[header_name] else None
    column_option = plot_selection[column_name]["selected_column_selector_option"]
    if column_option in [
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    ]:
        col = plot_selection[column_name]["col1"]
    else:
        col = None
    _, input_df = read_columns(
        file_path,
        c=col,
        c_option=column_option,
        return_df=True,
        sep="\t",
        header=header,
        parse_dates=True,
    )
    return input_df


def main(
    inputs,
    infile_estimator=None,
    infile1=None,
    infile2=None,
    outfile_result=None,
    outfile_object=None,
    groups=None,
    ref_seq=None,
    intervals=None,
    targets=None,
    fasta_path=None,
    model_config=None,
    true_labels=None,
    predicted_labels=None,
    plot_color=None,
    title=None,
):
    """
    Parameters
    ----------
    inputs : str
        File path to the Galaxy tool parameters (JSON)

    infile_estimator : str, default is None
        File path to estimator

    infile1 : str, default is None
        File path to dataset containing features or true labels.

    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.

    outfile_result : str, default is None
        File path to save the results, either cv_results or test result

    outfile_object : str, default is None
        File path to save the searchCV object

    groups : str, default is None
        File path to dataset containing group labels

    ref_seq : str, default is None
        File path to dataset containing the genome sequence file

    intervals : str, default is None
        File path to dataset containing the interval file

    targets : str, default is None
        File path to dataset containing the compressed target bed file

    fasta_path : str, default is None
        File path to dataset containing the fasta file

    model_config : str, default is None
        File path to dataset containing the JSON config for neural networks

    true_labels : str, default is None
        File path to dataset containing true labels

    predicted_labels : str, default is None
        File path to dataset containing predicted labels

    plot_color : str, default is None
        Color of the confusion matrix heatmap

    title : str, default is None
        Title of the confusion matrix heatmap
    """
    warnings.simplefilter("ignore")

    with open(inputs, "r") as param_handler:
        params = json.load(param_handler)

    title = params["plotting_selection"]["title"].strip()
    plot_type = params["plotting_selection"]["plot_type"]
    plot_format = params["plotting_selection"]["plot_format"]

    if plot_type == "feature_importances":
        with open(infile_estimator, "rb") as estimator_handler:
            estimator = load_model(estimator_handler)

        column_option = params["plotting_selection"]["column_selector_options"][
            "selected_column_selector_option"
        ]
        if column_option in [
            "by_index_number",
            "all_but_by_index_number",
            "by_header_name",
            "all_but_by_header_name",
        ]:
            c = params["plotting_selection"]["column_selector_options"]["col1"]
        else:
            c = None

        _, input_df = read_columns(
            infile1,
            c=c,
            c_option=column_option,
            return_df=True,
            sep="\t",
            header="infer",
            parse_dates=True,
        )

        feature_names = input_df.columns.values

        if isinstance(estimator, Pipeline):
            for st in estimator.steps[:-1]:
                if isinstance(st[-1], SelectorMixin):
                    mask = st[-1].get_support()
                    feature_names = feature_names[mask]
            estimator = estimator.steps[-1][-1]

        if hasattr(estimator, "coef_"):
            coefs = estimator.coef_
        else:
            coefs = getattr(estimator, "feature_importances_", None)
        if coefs is None:
            raise RuntimeError(
                "The classifier does not expose "
                '"coef_" or "feature_importances_" '
                "attributes"
            )

        threshold = params["plotting_selection"]["threshold"]
        if threshold is not None:
            mask = (coefs > threshold) | (coefs < -threshold)
            coefs = coefs[mask]
            feature_names = feature_names[mask]

        # sort by importance, descending
        indices = np.argsort(coefs)[::-1]

        trace = go.Bar(x=feature_names[indices], y=coefs[indices])
        layout = go.Layout(title=title or "Feature Importances")
        fig = go.Figure(data=[trace], layout=layout)

        plotly.offline.plot(fig, filename="output.html", auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename("output.html", "output")

        return 0

    elif plot_type in ("pr_curve", "roc_curve"):
        df1 = pd.read_csv(infile1, sep="\t", header="infer")
        df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32)

        minimum = params["plotting_selection"]["report_minimum_n_positives"]
        # filter out columns whose n_positives is below the threshold
        if minimum:
            mask = df1.sum(axis=0) >= minimum
            df1 = df1.loc[:, mask]
            df2 = df2.loc[:, mask]

        pos_label = params["plotting_selection"]["pos_label"].strip() or None

        if plot_type == "pr_curve":
            if plot_format == "plotly_html":
                visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
            else:
                visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
        else:  # 'roc_curve'
            drop_intermediate = params["plotting_selection"]["drop_intermediate"]
            if plot_format == "plotly_html":
                visualize_roc_curve_plotly(
                    df1,
                    df2,
                    pos_label,
                    drop_intermediate=drop_intermediate,
                    title=title,
                )
            else:
                visualize_roc_curve_matplotlib(
                    df1,
                    df2,
                    pos_label,
                    drop_intermediate=drop_intermediate,
                    title=title,
                )

        return 0

    elif plot_type == "rfecv_gridscores":
        input_df = pd.read_csv(infile1, sep="\t", header="infer")
        scores = input_df.iloc[:, 0]
        steps = params["plotting_selection"]["steps"].strip()
        steps = safe_eval(steps)

        data = go.Scatter(
            x=list(range(len(scores))),
            y=scores,
            text=[str(_) for _ in steps] if steps else None,
            mode="lines",
        )
        layout = go.Layout(
            xaxis=dict(title="Number of features selected"),
            yaxis=dict(title="Cross validation score"),
            title=dict(
                text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"
            ),
            font=dict(family="sans-serif", size=11),
            # control background colors
            plot_bgcolor="rgba(255,255,255,0)",
        )
        # optional legend styling, kept disabled:
        # legend=dict(
        #     x=0.95,
        #     y=0,
        #     traceorder="normal",
        #     font=dict(family="sans-serif", size=9, color="black"),
        #     bgcolor="LightSteelBlue",
        #     bordercolor="Black",
        #     borderwidth=2,
        # ),

        fig = go.Figure(data=[data], layout=layout)
        plotly.offline.plot(fig, filename="output.html", auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename("output.html", "output")

        return 0

    elif plot_type == "learning_curve":
        input_df = pd.read_csv(infile1, sep="\t", header="infer")
        plot_std_err = params["plotting_selection"]["plot_std_err"]
        data1 = go.Scatter(
            x=input_df["train_sizes_abs"],
            y=input_df["mean_train_scores"],
            error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None,
            mode="lines",
            name="Train Scores",
        )
        data2 = go.Scatter(
            x=input_df["train_sizes_abs"],
            y=input_df["mean_test_scores"],
            error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None,
            mode="lines",
            name="Test Scores",
        )
        layout = dict(
            xaxis=dict(title="No. of samples"),
            yaxis=dict(title="Performance Score"),
            # modify these configurations to customize the image
            title=dict(
                text=title or "Learning Curve",
                x=0.5,
                y=0.92,
                xanchor="center",
                yanchor="top",
            ),
            font=dict(family="sans-serif", size=11),
            # control background colors
            plot_bgcolor="rgba(255,255,255,0)",
        )
        # optional legend styling, kept disabled:
        # legend=dict(
        #     x=0.95,
        #     y=0,
        #     traceorder="normal",
        #     font=dict(family="sans-serif", size=9, color="black"),
        #     bgcolor="LightSteelBlue",
        #     bordercolor="Black",
        #     borderwidth=2,
        # ),

        fig = go.Figure(data=[data1, data2], layout=layout)
        plotly.offline.plot(fig, filename="output.html", auto_open=False)
        # to be discovered by `from_work_dir`
        os.rename("output.html", "output")

        return 0

    elif plot_type == "keras_plot_model":
        with open(model_config, "r") as f:
            model_str = f.read()
        model = model_from_json(model_str)
        plot_model(model, to_file="output.png")
        os.rename("output.png", "output")

        return 0

    elif plot_type == "classification_confusion_matrix":
        plot_selection = params["plotting_selection"]
        input_true = get_dataframe(
            true_labels, plot_selection, "header_true", "column_selector_options_true"
        )
        header_predicted = "infer" if plot_selection["header_predicted"] else None
        input_predicted = pd.read_csv(
            predicted_labels, sep="\t", parse_dates=True, header=header_predicted
        )
        true_classes = input_true.iloc[:, -1].copy()
        predicted_classes = input_predicted.iloc[:, -1].copy()
        # confusion_matrix reports classes in sorted order, so keep the
        # tick labels consistent with that ordering
        axis_labels = sorted(set(true_classes))
        c_matrix = confusion_matrix(true_classes, predicted_classes)
        fig, ax = plt.subplots(figsize=(7, 7))
        im = plt.imshow(c_matrix, cmap=plot_color)
        for i in range(len(c_matrix)):
            for j in range(len(c_matrix)):
                ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
        ax.set_ylabel("True class labels")
        ax.set_xlabel("Predicted class labels")
        ax.set_title(title)
        # ticks are matrix positions; label them with the class names
        ax.set_xticks(np.arange(len(axis_labels)))
        ax.set_yticks(np.arange(len(axis_labels)))
        ax.set_xticklabels(axis_labels)
        ax.set_yticklabels(axis_labels)
        fig.colorbar(im, ax=ax)
        fig.tight_layout()
        plt.savefig("output.png", dpi=125)
        os.rename("output.png", "output")

        return 0

    # save pdf file to disk
    # fig.write_image("image.pdf", format='pdf')
    # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)


if __name__ == "__main__":
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    aparser.add_argument("-e", "--estimator", dest="infile_estimator")
    aparser.add_argument("-X", "--infile1", dest="infile1")
    aparser.add_argument("-y", "--infile2", dest="infile2")
    aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
    aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
    aparser.add_argument("-g", "--groups", dest="groups")
    aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
    aparser.add_argument("-b", "--intervals", dest="intervals")
    aparser.add_argument("-t", "--targets", dest="targets")
    aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
    aparser.add_argument("-c", "--model_config", dest="model_config")
    aparser.add_argument("-tl", "--true_labels", dest="true_labels")
    aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
    aparser.add_argument("-pc", "--plot_color", dest="plot_color")
    aparser.add_argument("-pt", "--title", dest="title")
    args = aparser.parse_args()

    main(
        args.inputs,
        args.infile_estimator,
        args.infile1,
        args.infile2,
        args.outfile_result,
        outfile_object=args.outfile_object,
        groups=args.groups,
        ref_seq=args.ref_seq,
        intervals=args.intervals,
        targets=args.targets,
        fasta_path=args.fasta_path,
        model_config=args.model_config,
        true_labels=args.true_labels,
        predicted_labels=args.predicted_labels,
        plot_color=args.plot_color,
        title=args.title,
    )
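
In Galaxy, the wrapper generates the `inputs` JSON that `main()` consumes. The following sketch, which is not part of the commit, hand-crafts a minimal parameter file for the `roc_curve` branch and invokes the script; the file names (`inputs.json`, `y_true.tsv`, `y_score.tsv`) are placeholders, and the keys mirror the `params["plotting_selection"]` lookups above.

import json
import subprocess

# minimal parameter set for the `roc_curve` branch of main()
params = {
    "plotting_selection": {
        "title": "Demo ROC",
        "plot_type": "roc_curve",
        "plot_format": "plotly_html",
        "report_minimum_n_positives": 0,  # falsy -> no column filtering
        "pos_label": "",                  # empty string -> pos_label=None
        "drop_intermediate": True,
    }
}
with open("inputs.json", "w") as handle:
    json.dump(params, handle)

# -X carries the tab-separated true labels, -y the predicted scores;
# the script leaves a plotly HTML report named `output` in the work dir
subprocess.run(
    ["python", "ml_visualization_ex.py", "-i", "inputs.json",
     "-X", "y_true.tsv", "-y", "y_score.tsv"],
    check=True,
)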
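The plotting helpers can also be exercised directly, one curve per DataFrame column. A minimal sketch with synthetic data, assuming the file is importable as `ml_visualization_ex` and plotly is installed:

import numpy as np
import pandas as pd

from ml_visualization_ex import visualize_roc_curve_plotly

rng = np.random.RandomState(0)
y_true = pd.DataFrame({"label": rng.randint(0, 2, size=200)})
# scores loosely correlated with the labels so the curve is informative
y_score = pd.DataFrame({"label": 0.5 * y_true["label"] + 0.5 * rng.rand(200)})

# draws one trace for the single column; writes `output`
# (renamed from output.html) into the current working directory
visualize_roc_curve_plotly(y_true, y_score, pos_label=None, title="Demo ROC")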