diff ml_visualization_ex.py @ 31:a2da4cebc584 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 5b2ac730ec6d3b762faa9034eddd19ad1b347476"
author bgruening
date Mon, 16 Dec 2019 05:43:45 -0500
parents 2a32d1d9c8da
children ecd247e1ea9c
line wrap: on
line diff
--- a/ml_visualization_ex.py	Thu Nov 07 05:45:58 2019 -0500
+++ b/ml_visualization_ex.py	Mon Dec 16 05:43:45 2019 -0500
@@ -1,6 +1,9 @@
 import argparse
 import json
+import matplotlib
+import matplotlib.pyplot as plt
 import numpy as np
+import os
 import pandas as pd
 import plotly
 import plotly.graph_objs as go
@@ -17,6 +20,251 @@
 
 safe_eval = SafeEval()
 
+# plotly default colors
+default_colors = [
+    '#1f77b4',  # muted blue
+    '#ff7f0e',  # safety orange
+    '#2ca02c',  # cooked asparagus green
+    '#d62728',  # brick red
+    '#9467bd',  # muted purple
+    '#8c564b',  # chestnut brown
+    '#e377c2',  # raspberry yogurt pink
+    '#7f7f7f',  # middle gray
+    '#bcbd22',  # curry yellow-green
+    '#17becf'   # blue-teal
+]
+
+
+def visualize_pr_curve_plotly(df1, df2, pos_label, title=None):
+    """output pr-curve in html using plotly
+
+    df1 : pandas.DataFrame
+        Containing y_true
+    df2 : pandas.DataFrame
+        Containing y_score
+    pos_label : None
+        The label of positive class
+    title : str
+        Plot title
+    """
+    data = []
+    for idx in range(df1.shape[1]):
+        y_true = df1.iloc[:, idx].values
+        y_score = df2.iloc[:, idx].values
+
+        precision, recall, _ = precision_recall_curve(
+            y_true, y_score, pos_label=pos_label)
+        ap = average_precision_score(
+            y_true, y_score, pos_label=pos_label or 1)
+
+        trace = go.Scatter(
+            x=recall,
+            y=precision,
+            mode='lines',
+            marker=dict(
+                color=default_colors[idx % len(default_colors)]
+            ),
+            name='%s (area = %.3f)' % (idx, ap)
+        )
+        data.append(trace)
+
+    layout = go.Layout(
+        xaxis=dict(
+            title='Recall',
+            linecolor='lightslategray',
+            linewidth=1
+        ),
+        yaxis=dict(
+            title='Precision',
+            linecolor='lightslategray',
+            linewidth=1
+        ),
+        title=dict(
+            text=title or 'Precision-Recall Curve',
+            x=0.5,
+            y=0.92,
+            xanchor='center',
+            yanchor='top'
+        ),
+        font=dict(
+            family="sans-serif",
+            size=11
+        ),
+        # control backgroud colors
+        plot_bgcolor='rgba(255,255,255,0)'
+    )
+    """
+    legend=dict(
+        x=0.95,
+        y=0,
+        traceorder="normal",
+        font=dict(
+            family="sans-serif",
+            size=9,
+            color="black"
+        ),
+        bgcolor="LightSteelBlue",
+        bordercolor="Black",
+        borderwidth=2
+    ),"""
+
+    fig = go.Figure(data=data, layout=layout)
+
+    plotly.offline.plot(fig, filename="output.html", auto_open=False)
+    # to be discovered by `from_work_dir`
+    os.rename('output.html', 'output')
+
+
+def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None):
+    """visualize pr-curve using matplotlib and output svg image
+    """
+    backend = matplotlib.get_backend()
+    if "inline" not in backend:
+        matplotlib.use("SVG")
+    plt.style.use('seaborn-colorblind')
+    plt.figure()
+
+    for idx in range(df1.shape[1]):
+        y_true = df1.iloc[:, idx].values
+        y_score = df2.iloc[:, idx].values
+
+        precision, recall, _ = precision_recall_curve(
+            y_true, y_score, pos_label=pos_label)
+        ap = average_precision_score(
+            y_true, y_score, pos_label=pos_label or 1)
+
+        plt.step(recall, precision, 'r-', color="black", alpha=0.3,
+                 lw=1, where="post", label='%s (area = %.3f)' % (idx, ap))
+
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel('Recall')
+    plt.ylabel('Precision')
+    title = title or 'Precision-Recall Curve'
+    plt.title(title)
+    folder = os.getcwd()
+    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
+    os.rename(os.path.join(folder, "output.svg"),
+              os.path.join(folder, "output"))
+
+
+def visualize_roc_curve_plotly(df1, df2, pos_label,
+                               drop_intermediate=True,
+                               title=None):
+    """output roc-curve in html using plotly
+
+    df1 : pandas.DataFrame
+        Containing y_true
+    df2 : pandas.DataFrame
+        Containing y_score
+    pos_label : None
+        The label of positive class
+    drop_intermediate : bool
+        Whether to drop some suboptimal thresholds
+    title : str
+        Plot title
+    """
+    data = []
+    for idx in range(df1.shape[1]):
+        y_true = df1.iloc[:, idx].values
+        y_score = df2.iloc[:, idx].values
+
+        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label,
+                                drop_intermediate=drop_intermediate)
+        roc_auc = auc(fpr, tpr)
+
+        trace = go.Scatter(
+            x=fpr,
+            y=tpr,
+            mode='lines',
+            marker=dict(
+                color=default_colors[idx % len(default_colors)]
+            ),
+            name='%s (area = %.3f)' % (idx, roc_auc)
+        )
+        data.append(trace)
+
+    layout = go.Layout(
+        xaxis=dict(
+            title='False Positive Rate',
+            linecolor='lightslategray',
+            linewidth=1
+        ),
+        yaxis=dict(
+            title='True Positive Rate',
+            linecolor='lightslategray',
+            linewidth=1
+        ),
+        title=dict(
+            text=title or 'Receiver Operating Characteristic (ROC) Curve',
+            x=0.5,
+            y=0.92,
+            xanchor='center',
+            yanchor='top'
+        ),
+        font=dict(
+            family="sans-serif",
+            size=11
+        ),
+        # control backgroud colors
+        plot_bgcolor='rgba(255,255,255,0)'
+    )
+    """
+    # legend=dict(
+            # x=0.95,
+            # y=0,
+            # traceorder="normal",
+            # font=dict(
+            #    family="sans-serif",
+            #    size=9,
+            #    color="black"
+            # ),
+            # bgcolor="LightSteelBlue",
+            # bordercolor="Black",
+            # borderwidth=2
+        # ),
+    """
+
+    fig = go.Figure(data=data, layout=layout)
+
+    plotly.offline.plot(fig, filename="output.html", auto_open=False)
+    # to be discovered by `from_work_dir`
+    os.rename('output.html', 'output')
+
+
+def visualize_roc_curve_matplotlib(df1, df2, pos_label,
+                                   drop_intermediate=True,
+                                   title=None):
+    """visualize roc-curve using matplotlib and output svg image
+    """
+    backend = matplotlib.get_backend()
+    if "inline" not in backend:
+        matplotlib.use("SVG")
+    plt.style.use('seaborn-colorblind')
+    plt.figure()
+
+    for idx in range(df1.shape[1]):
+        y_true = df1.iloc[:, idx].values
+        y_score = df2.iloc[:, idx].values
+
+        fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label,
+                                drop_intermediate=drop_intermediate)
+        roc_auc = auc(fpr, tpr)
+
+        plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1,
+                 where="post", label='%s (area = %.3f)' % (idx, roc_auc))
+
+    plt.xlim([0.0, 1.0])
+    plt.ylim([0.0, 1.05])
+    plt.xlabel('False Positive Rate')
+    plt.ylabel('True Positive Rate')
+    title = title or 'Receiver Operating Characteristic (ROC) Curve'
+    plt.title(title)
+    folder = os.getcwd()
+    plt.savefig(os.path.join(folder, "output.svg"), format="svg")
+    os.rename(os.path.join(folder, "output.svg"),
+              os.path.join(folder, "output"))
+
 
 def main(inputs, infile_estimator=None, infile1=None,
          infile2=None, outfile_result=None,
@@ -71,6 +319,8 @@
 
     title = params['plotting_selection']['title'].strip()
     plot_type = params['plotting_selection']['plot_type']
+    plot_format = params['plotting_selection']['plot_format']
+
     if plot_type == 'feature_importances':
         with open(infile_estimator, 'rb') as estimator_handler:
             estimator = load_model(estimator_handler)
@@ -123,98 +373,46 @@
         layout = go.Layout(title=title or "Feature Importances")
         fig = go.Figure(data=[trace], layout=layout)
 
-    elif plot_type == 'pr_curve':
-        df1 = pd.read_csv(infile1, sep='\t', header=None)
-        df2 = pd.read_csv(infile2, sep='\t', header=None)
+        plotly.offline.plot(fig, filename="output.html",
+                            auto_open=False)
+        # to be discovered by `from_work_dir`
+        os.rename('output.html', 'output')
+
+        return 0
 
-        precision = {}
-        recall = {}
-        ap = {}
+    elif plot_type in ('pr_curve', 'roc_curve'):
+        df1 = pd.read_csv(infile1, sep='\t', header='infer')
+        df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32)
+
+        minimum = params['plotting_selection']['report_minimum_n_positives']
+        # filter out columns whose n_positives is beblow the threhold
+        if minimum:
+            mask = df1.sum(axis=0) >= minimum
+            df1 = df1.loc[:, mask]
+            df2 = df2.loc[:, mask]
 
         pos_label = params['plotting_selection']['pos_label'].strip() \
             or None
-        for col in df1.columns:
-            y_true = df1[col].values
-            y_score = df2[col].values
-
-            precision[col], recall[col], _ = precision_recall_curve(
-                y_true, y_score, pos_label=pos_label)
-            ap[col] = average_precision_score(
-                y_true, y_score, pos_label=pos_label or 1)
-
-        if len(df1.columns) > 1:
-            precision["micro"], recall["micro"], _ = precision_recall_curve(
-                df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
-            ap['micro'] = average_precision_score(
-                df1.values, df2.values, average='micro',
-                pos_label=pos_label or 1)
-
-        data = []
-        for key in precision.keys():
-            trace = go.Scatter(
-                x=recall[key],
-                y=precision[key],
-                mode='lines',
-                name='%s (area = %.2f)' % (key, ap[key]) if key == 'micro'
-                     else 'column %s (area = %.2f)' % (key, ap[key])
-            )
-            data.append(trace)
-
-        layout = go.Layout(
-            title=title or "Precision-Recall curve",
-            xaxis=dict(title='Recall'),
-            yaxis=dict(title='Precision')
-        )
-
-        fig = go.Figure(data=data, layout=layout)
-
-    elif plot_type == 'roc_curve':
-        df1 = pd.read_csv(infile1, sep='\t', header=None)
-        df2 = pd.read_csv(infile2, sep='\t', header=None)
 
-        fpr = {}
-        tpr = {}
-        roc_auc = {}
-
-        pos_label = params['plotting_selection']['pos_label'].strip() \
-            or None
-        for col in df1.columns:
-            y_true = df1[col].values
-            y_score = df2[col].values
-
-            fpr[col], tpr[col], _ = roc_curve(
-                y_true, y_score, pos_label=pos_label)
-            roc_auc[col] = auc(fpr[col], tpr[col])
-
-        if len(df1.columns) > 1:
-            fpr["micro"], tpr["micro"], _ = roc_curve(
-                df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
-            roc_auc['micro'] = auc(fpr["micro"], tpr["micro"])
+        if plot_type == 'pr_curve':
+            if plot_format == 'plotly_html':
+                visualize_pr_curve_plotly(df1, df2, pos_label, title=title)
+            else:
+                visualize_pr_curve_matplotlib(df1, df2, pos_label, title)
+        else:          # 'roc_curve'
+            drop_intermediate = (params['plotting_selection']
+                                       ['drop_intermediate'])
+            if plot_format == 'plotly_html':
+                visualize_roc_curve_plotly(df1, df2, pos_label,
+                                           drop_intermediate=drop_intermediate,
+                                           title=title)
+            else:
+                visualize_roc_curve_matplotlib(
+                    df1, df2, pos_label,
+                    drop_intermediate=drop_intermediate,
+                    title=title)
 
-        data = []
-        for key in fpr.keys():
-            trace = go.Scatter(
-                x=fpr[key],
-                y=tpr[key],
-                mode='lines',
-                name='%s (area = %.2f)' % (key, roc_auc[key]) if key == 'micro'
-                     else 'column %s (area = %.2f)' % (key, roc_auc[key])
-            )
-            data.append(trace)
-
-        trace = go.Scatter(x=[0, 1], y=[0, 1],
-                           mode='lines', 
-                           line=dict(color='black', dash='dash'),
-                           showlegend=False)
-        data.append(trace)
-
-        layout = go.Layout(
-            title=title or "Receiver operating characteristic curve",
-            xaxis=dict(title='False Positive Rate'),
-            yaxis=dict(title='True Positive Rate')
-        )
-
-        fig = go.Figure(data=data, layout=layout)
+        return 0
 
     elif plot_type == 'rfecv_gridscores':
         input_df = pd.read_csv(infile1, sep='\t', header='infer')
@@ -231,10 +429,43 @@
         layout = go.Layout(
             xaxis=dict(title="Number of features selected"),
             yaxis=dict(title="Cross validation score"),
-            title=title or None
+            title=dict(
+                text=title or None,
+                x=0.5,
+                y=0.92,
+                xanchor='center',
+                yanchor='top'
+            ),
+            font=dict(
+                family="sans-serif",
+                size=11
+            ),
+            # control backgroud colors
+            plot_bgcolor='rgba(255,255,255,0)'
         )
+        """
+        # legend=dict(
+                # x=0.95,
+                # y=0,
+                # traceorder="normal",
+                # font=dict(
+                #    family="sans-serif",
+                #    size=9,
+                #    color="black"
+                # ),
+                # bgcolor="LightSteelBlue",
+                # bordercolor="Black",
+                # borderwidth=2
+            # ),
+        """
 
         fig = go.Figure(data=[data], layout=layout)
+        plotly.offline.plot(fig, filename="output.html",
+                            auto_open=False)
+        # to be discovered by `from_work_dir`
+        os.rename('output.html', 'output')
+
+        return 0
 
     elif plot_type == 'learning_curve':
         input_df = pd.read_csv(infile1, sep='\t', header='infer')
@@ -264,23 +495,57 @@
             yaxis=dict(
                 title='Performance Score'
             ),
-            title=title or 'Learning Curve'
+            # modify these configurations to customize image
+            title=dict(
+                text=title or 'Learning Curve',
+                x=0.5,
+                y=0.92,
+                xanchor='center',
+                yanchor='top'
+            ),
+            font=dict(
+                family="sans-serif",
+                size=11
+            ),
+            # control backgroud colors
+            plot_bgcolor='rgba(255,255,255,0)'
         )
+        """
+        # legend=dict(
+                # x=0.95,
+                # y=0,
+                # traceorder="normal",
+                # font=dict(
+                #    family="sans-serif",
+                #    size=9,
+                #    color="black"
+                # ),
+                # bgcolor="LightSteelBlue",
+                # bordercolor="Black",
+                # borderwidth=2
+            # ),
+        """
+
         fig = go.Figure(data=[data1, data2], layout=layout)
+        plotly.offline.plot(fig, filename="output.html",
+                            auto_open=False)
+        # to be discovered by `from_work_dir`
+        os.rename('output.html', 'output')
+
+        return 0
 
     elif plot_type == 'keras_plot_model':
         with open(model_config, 'r') as f:
             model_str = f.read()
         model = model_from_json(model_str)
         plot_model(model, to_file="output.png")
-        __import__('os').rename('output.png', 'output')
+        os.rename('output.png', 'output')
 
         return 0
 
-    plotly.offline.plot(fig, filename="output.html",
-                        auto_open=False)
-    # to be discovered by `from_work_dir`
-    __import__('os').rename('output.html', 'output')
+    # save pdf file to disk
+    # fig.write_image("image.pdf", format='pdf')
+    # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
 
 
 if __name__ == '__main__':