Mercurial > repos > bgruening > sklearn_discriminant_classifier
diff ml_visualization_ex.py @ 31:64b771b1471a draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5b2ac730ec6d3b762faa9034eddd19ad1b347476"
author | bgruening |
---|---|
date | Mon, 16 Dec 2019 05:17:00 -0500 |
parents | 7696d389675c |
children | 237ae72d58a3 |
line wrap: on
line diff
--- a/ml_visualization_ex.py Thu Nov 07 05:25:28 2019 -0500 +++ b/ml_visualization_ex.py Mon Dec 16 05:17:00 2019 -0500 @@ -1,6 +1,9 @@ import argparse import json +import matplotlib +import matplotlib.pyplot as plt import numpy as np +import os import pandas as pd import plotly import plotly.graph_objs as go @@ -17,6 +20,251 @@ safe_eval = SafeEval() +# plotly default colors +default_colors = [ + '#1f77b4', # muted blue + '#ff7f0e', # safety orange + '#2ca02c', # cooked asparagus green + '#d62728', # brick red + '#9467bd', # muted purple + '#8c564b', # chestnut brown + '#e377c2', # raspberry yogurt pink + '#7f7f7f', # middle gray + '#bcbd22', # curry yellow-green + '#17becf' # blue-teal +] + + +def visualize_pr_curve_plotly(df1, df2, pos_label, title=None): + """output pr-curve in html using plotly + + df1 : pandas.DataFrame + Containing y_true + df2 : pandas.DataFrame + Containing y_score + pos_label : None + The label of positive class + title : str + Plot title + """ + data = [] + for idx in range(df1.shape[1]): + y_true = df1.iloc[:, idx].values + y_score = df2.iloc[:, idx].values + + precision, recall, _ = precision_recall_curve( + y_true, y_score, pos_label=pos_label) + ap = average_precision_score( + y_true, y_score, pos_label=pos_label or 1) + + trace = go.Scatter( + x=recall, + y=precision, + mode='lines', + marker=dict( + color=default_colors[idx % len(default_colors)] + ), + name='%s (area = %.3f)' % (idx, ap) + ) + data.append(trace) + + layout = go.Layout( + xaxis=dict( + title='Recall', + linecolor='lightslategray', + linewidth=1 + ), + yaxis=dict( + title='Precision', + linecolor='lightslategray', + linewidth=1 + ), + title=dict( + text=title or 'Precision-Recall Curve', + x=0.5, + y=0.92, + xanchor='center', + yanchor='top' + ), + font=dict( + family="sans-serif", + size=11 + ), + # control backgroud colors + plot_bgcolor='rgba(255,255,255,0)' + ) + """ + legend=dict( + x=0.95, + y=0, + traceorder="normal", + font=dict( + family="sans-serif", + size=9, + color="black" + ), + bgcolor="LightSteelBlue", + bordercolor="Black", + borderwidth=2 + ),""" + + fig = go.Figure(data=data, layout=layout) + + plotly.offline.plot(fig, filename="output.html", auto_open=False) + # to be discovered by `from_work_dir` + os.rename('output.html', 'output') + + +def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): + """visualize pr-curve using matplotlib and output svg image + """ + backend = matplotlib.get_backend() + if "inline" not in backend: + matplotlib.use("SVG") + plt.style.use('seaborn-colorblind') + plt.figure() + + for idx in range(df1.shape[1]): + y_true = df1.iloc[:, idx].values + y_score = df2.iloc[:, idx].values + + precision, recall, _ = precision_recall_curve( + y_true, y_score, pos_label=pos_label) + ap = average_precision_score( + y_true, y_score, pos_label=pos_label or 1) + + plt.step(recall, precision, 'r-', color="black", alpha=0.3, + lw=1, where="post", label='%s (area = %.3f)' % (idx, ap)) + + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('Recall') + plt.ylabel('Precision') + title = title or 'Precision-Recall Curve' + plt.title(title) + folder = os.getcwd() + plt.savefig(os.path.join(folder, "output.svg"), format="svg") + os.rename(os.path.join(folder, "output.svg"), + os.path.join(folder, "output")) + + +def visualize_roc_curve_plotly(df1, df2, pos_label, + drop_intermediate=True, + title=None): + """output roc-curve in html using plotly + + df1 : pandas.DataFrame + Containing y_true + df2 : pandas.DataFrame + Containing y_score + pos_label : None + The label of positive class + drop_intermediate : bool + Whether to drop some suboptimal thresholds + title : str + Plot title + """ + data = [] + for idx in range(df1.shape[1]): + y_true = df1.iloc[:, idx].values + y_score = df2.iloc[:, idx].values + + fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, + drop_intermediate=drop_intermediate) + roc_auc = auc(fpr, tpr) + + trace = go.Scatter( + x=fpr, + y=tpr, + mode='lines', + marker=dict( + color=default_colors[idx % len(default_colors)] + ), + name='%s (area = %.3f)' % (idx, roc_auc) + ) + data.append(trace) + + layout = go.Layout( + xaxis=dict( + title='False Positive Rate', + linecolor='lightslategray', + linewidth=1 + ), + yaxis=dict( + title='True Positive Rate', + linecolor='lightslategray', + linewidth=1 + ), + title=dict( + text=title or 'Receiver Operating Characteristic (ROC) Curve', + x=0.5, + y=0.92, + xanchor='center', + yanchor='top' + ), + font=dict( + family="sans-serif", + size=11 + ), + # control backgroud colors + plot_bgcolor='rgba(255,255,255,0)' + ) + """ + # legend=dict( + # x=0.95, + # y=0, + # traceorder="normal", + # font=dict( + # family="sans-serif", + # size=9, + # color="black" + # ), + # bgcolor="LightSteelBlue", + # bordercolor="Black", + # borderwidth=2 + # ), + """ + + fig = go.Figure(data=data, layout=layout) + + plotly.offline.plot(fig, filename="output.html", auto_open=False) + # to be discovered by `from_work_dir` + os.rename('output.html', 'output') + + +def visualize_roc_curve_matplotlib(df1, df2, pos_label, + drop_intermediate=True, + title=None): + """visualize roc-curve using matplotlib and output svg image + """ + backend = matplotlib.get_backend() + if "inline" not in backend: + matplotlib.use("SVG") + plt.style.use('seaborn-colorblind') + plt.figure() + + for idx in range(df1.shape[1]): + y_true = df1.iloc[:, idx].values + y_score = df2.iloc[:, idx].values + + fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, + drop_intermediate=drop_intermediate) + roc_auc = auc(fpr, tpr) + + plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1, + where="post", label='%s (area = %.3f)' % (idx, roc_auc)) + + plt.xlim([0.0, 1.0]) + plt.ylim([0.0, 1.05]) + plt.xlabel('False Positive Rate') + plt.ylabel('True Positive Rate') + title = title or 'Receiver Operating Characteristic (ROC) Curve' + plt.title(title) + folder = os.getcwd() + plt.savefig(os.path.join(folder, "output.svg"), format="svg") + os.rename(os.path.join(folder, "output.svg"), + os.path.join(folder, "output")) + def main(inputs, infile_estimator=None, infile1=None, infile2=None, outfile_result=None, @@ -71,6 +319,8 @@ title = params['plotting_selection']['title'].strip() plot_type = params['plotting_selection']['plot_type'] + plot_format = params['plotting_selection']['plot_format'] + if plot_type == 'feature_importances': with open(infile_estimator, 'rb') as estimator_handler: estimator = load_model(estimator_handler) @@ -123,98 +373,46 @@ layout = go.Layout(title=title or "Feature Importances") fig = go.Figure(data=[trace], layout=layout) - elif plot_type == 'pr_curve': - df1 = pd.read_csv(infile1, sep='\t', header=None) - df2 = pd.read_csv(infile2, sep='\t', header=None) + plotly.offline.plot(fig, filename="output.html", + auto_open=False) + # to be discovered by `from_work_dir` + os.rename('output.html', 'output') + + return 0 - precision = {} - recall = {} - ap = {} + elif plot_type in ('pr_curve', 'roc_curve'): + df1 = pd.read_csv(infile1, sep='\t', header='infer') + df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32) + + minimum = params['plotting_selection']['report_minimum_n_positives'] + # filter out columns whose n_positives is beblow the threhold + if minimum: + mask = df1.sum(axis=0) >= minimum + df1 = df1.loc[:, mask] + df2 = df2.loc[:, mask] pos_label = params['plotting_selection']['pos_label'].strip() \ or None - for col in df1.columns: - y_true = df1[col].values - y_score = df2[col].values - - precision[col], recall[col], _ = precision_recall_curve( - y_true, y_score, pos_label=pos_label) - ap[col] = average_precision_score( - y_true, y_score, pos_label=pos_label or 1) - - if len(df1.columns) > 1: - precision["micro"], recall["micro"], _ = precision_recall_curve( - df1.values.ravel(), df2.values.ravel(), pos_label=pos_label) - ap['micro'] = average_precision_score( - df1.values, df2.values, average='micro', - pos_label=pos_label or 1) - - data = [] - for key in precision.keys(): - trace = go.Scatter( - x=recall[key], - y=precision[key], - mode='lines', - name='%s (area = %.2f)' % (key, ap[key]) if key == 'micro' - else 'column %s (area = %.2f)' % (key, ap[key]) - ) - data.append(trace) - - layout = go.Layout( - title=title or "Precision-Recall curve", - xaxis=dict(title='Recall'), - yaxis=dict(title='Precision') - ) - - fig = go.Figure(data=data, layout=layout) - - elif plot_type == 'roc_curve': - df1 = pd.read_csv(infile1, sep='\t', header=None) - df2 = pd.read_csv(infile2, sep='\t', header=None) - fpr = {} - tpr = {} - roc_auc = {} - - pos_label = params['plotting_selection']['pos_label'].strip() \ - or None - for col in df1.columns: - y_true = df1[col].values - y_score = df2[col].values - - fpr[col], tpr[col], _ = roc_curve( - y_true, y_score, pos_label=pos_label) - roc_auc[col] = auc(fpr[col], tpr[col]) - - if len(df1.columns) > 1: - fpr["micro"], tpr["micro"], _ = roc_curve( - df1.values.ravel(), df2.values.ravel(), pos_label=pos_label) - roc_auc['micro'] = auc(fpr["micro"], tpr["micro"]) + if plot_type == 'pr_curve': + if plot_format == 'plotly_html': + visualize_pr_curve_plotly(df1, df2, pos_label, title=title) + else: + visualize_pr_curve_matplotlib(df1, df2, pos_label, title) + else: # 'roc_curve' + drop_intermediate = (params['plotting_selection'] + ['drop_intermediate']) + if plot_format == 'plotly_html': + visualize_roc_curve_plotly(df1, df2, pos_label, + drop_intermediate=drop_intermediate, + title=title) + else: + visualize_roc_curve_matplotlib( + df1, df2, pos_label, + drop_intermediate=drop_intermediate, + title=title) - data = [] - for key in fpr.keys(): - trace = go.Scatter( - x=fpr[key], - y=tpr[key], - mode='lines', - name='%s (area = %.2f)' % (key, roc_auc[key]) if key == 'micro' - else 'column %s (area = %.2f)' % (key, roc_auc[key]) - ) - data.append(trace) - - trace = go.Scatter(x=[0, 1], y=[0, 1], - mode='lines', - line=dict(color='black', dash='dash'), - showlegend=False) - data.append(trace) - - layout = go.Layout( - title=title or "Receiver operating characteristic curve", - xaxis=dict(title='False Positive Rate'), - yaxis=dict(title='True Positive Rate') - ) - - fig = go.Figure(data=data, layout=layout) + return 0 elif plot_type == 'rfecv_gridscores': input_df = pd.read_csv(infile1, sep='\t', header='infer') @@ -231,10 +429,43 @@ layout = go.Layout( xaxis=dict(title="Number of features selected"), yaxis=dict(title="Cross validation score"), - title=title or None + title=dict( + text=title or None, + x=0.5, + y=0.92, + xanchor='center', + yanchor='top' + ), + font=dict( + family="sans-serif", + size=11 + ), + # control backgroud colors + plot_bgcolor='rgba(255,255,255,0)' ) + """ + # legend=dict( + # x=0.95, + # y=0, + # traceorder="normal", + # font=dict( + # family="sans-serif", + # size=9, + # color="black" + # ), + # bgcolor="LightSteelBlue", + # bordercolor="Black", + # borderwidth=2 + # ), + """ fig = go.Figure(data=[data], layout=layout) + plotly.offline.plot(fig, filename="output.html", + auto_open=False) + # to be discovered by `from_work_dir` + os.rename('output.html', 'output') + + return 0 elif plot_type == 'learning_curve': input_df = pd.read_csv(infile1, sep='\t', header='infer') @@ -264,23 +495,57 @@ yaxis=dict( title='Performance Score' ), - title=title or 'Learning Curve' + # modify these configurations to customize image + title=dict( + text=title or 'Learning Curve', + x=0.5, + y=0.92, + xanchor='center', + yanchor='top' + ), + font=dict( + family="sans-serif", + size=11 + ), + # control backgroud colors + plot_bgcolor='rgba(255,255,255,0)' ) + """ + # legend=dict( + # x=0.95, + # y=0, + # traceorder="normal", + # font=dict( + # family="sans-serif", + # size=9, + # color="black" + # ), + # bgcolor="LightSteelBlue", + # bordercolor="Black", + # borderwidth=2 + # ), + """ + fig = go.Figure(data=[data1, data2], layout=layout) + plotly.offline.plot(fig, filename="output.html", + auto_open=False) + # to be discovered by `from_work_dir` + os.rename('output.html', 'output') + + return 0 elif plot_type == 'keras_plot_model': with open(model_config, 'r') as f: model_str = f.read() model = model_from_json(model_str) plot_model(model, to_file="output.png") - __import__('os').rename('output.png', 'output') + os.rename('output.png', 'output') return 0 - plotly.offline.plot(fig, filename="output.html", - auto_open=False) - # to be discovered by `from_work_dir` - __import__('os').rename('output.html', 'output') + # save pdf file to disk + # fig.write_image("image.pdf", format='pdf') + # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) if __name__ == '__main__':