Mercurial > repos > bgruening > sklearn_train_test_eval
diff ml_visualization_ex.py @ 9:ead7adad8d0e draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit e2a5eade6d0e5ddf3a47630381a0ad90d80e8a04"
author | bgruening |
---|---|
date | Tue, 13 Apr 2021 18:45:35 +0000 |
parents | 1b68acd5ac08 |
children | a9e0b963b7bb |
line wrap: on
line diff
--- a/ml_visualization_ex.py Fri Oct 02 08:43:15 2020 +0000 +++ b/ml_visualization_ex.py Tue Apr 13 18:45:35 2021 +0000 @@ -22,16 +22,16 @@ # plotly default colors default_colors = [ - '#1f77b4', # muted blue - '#ff7f0e', # safety orange - '#2ca02c', # cooked asparagus green - '#d62728', # brick red - '#9467bd', # muted purple - '#8c564b', # chestnut brown - '#e377c2', # raspberry yogurt pink - '#7f7f7f', # middle gray - '#bcbd22', # curry yellow-green - '#17becf' # blue-teal + "#1f77b4", # muted blue + "#ff7f0e", # safety orange + "#2ca02c", # cooked asparagus green + "#d62728", # brick red + "#9467bd", # muted purple + "#8c564b", # chestnut brown + "#e377c2", # raspberry yogurt pink + "#7f7f7f", # middle gray + "#bcbd22", # curry yellow-green + "#17becf", # blue-teal ] @@ -52,46 +52,31 @@ y_true = df1.iloc[:, idx].values y_score = df2.iloc[:, idx].values - precision, recall, _ = precision_recall_curve( - y_true, y_score, pos_label=pos_label) - ap = average_precision_score( - y_true, y_score, pos_label=pos_label or 1) + precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) + ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) trace = go.Scatter( x=recall, y=precision, - mode='lines', - marker=dict( - color=default_colors[idx % len(default_colors)] - ), - name='%s (area = %.3f)' % (idx, ap) + mode="lines", + marker=dict(color=default_colors[idx % len(default_colors)]), + name="%s (area = %.3f)" % (idx, ap), ) data.append(trace) layout = go.Layout( - xaxis=dict( - title='Recall', - linecolor='lightslategray', - linewidth=1 - ), - yaxis=dict( - title='Precision', - linecolor='lightslategray', - linewidth=1 - ), + xaxis=dict(title="Recall", linecolor="lightslategray", linewidth=1), + yaxis=dict(title="Precision", linecolor="lightslategray", linewidth=1), title=dict( - text=title or 'Precision-Recall Curve', + text=title or "Precision-Recall Curve", x=0.5, y=0.92, - xanchor='center', - yanchor='top' + xanchor="center", + yanchor="top", ), - font=dict( - family="sans-serif", - size=11 - ), + font=dict(family="sans-serif", size=11), # control backgroud colors - plot_bgcolor='rgba(255,255,255,0)' + plot_bgcolor="rgba(255,255,255,0)", ) """ legend=dict( @@ -112,45 +97,47 @@ plotly.offline.plot(fig, filename="output.html", auto_open=False) # to be discovered by `from_work_dir` - os.rename('output.html', 'output') + os.rename("output.html", "output") def visualize_pr_curve_matplotlib(df1, df2, pos_label, title=None): - """visualize pr-curve using matplotlib and output svg image - """ + """visualize pr-curve using matplotlib and output svg image""" backend = matplotlib.get_backend() if "inline" not in backend: matplotlib.use("SVG") - plt.style.use('seaborn-colorblind') + plt.style.use("seaborn-colorblind") plt.figure() for idx in range(df1.shape[1]): y_true = df1.iloc[:, idx].values y_score = df2.iloc[:, idx].values - precision, recall, _ = precision_recall_curve( - y_true, y_score, pos_label=pos_label) - ap = average_precision_score( - y_true, y_score, pos_label=pos_label or 1) + precision, recall, _ = precision_recall_curve(y_true, y_score, pos_label=pos_label) + ap = average_precision_score(y_true, y_score, pos_label=pos_label or 1) - plt.step(recall, precision, 'r-', color="black", alpha=0.3, - lw=1, where="post", label='%s (area = %.3f)' % (idx, ap)) + plt.step( + recall, + precision, + "r-", + color="black", + alpha=0.3, + lw=1, + where="post", + label="%s (area = %.3f)" % (idx, ap), + ) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) - plt.xlabel('Recall') - plt.ylabel('Precision') - title = title or 'Precision-Recall Curve' + plt.xlabel("Recall") + plt.ylabel("Precision") + title = title or "Precision-Recall Curve" plt.title(title) folder = os.getcwd() plt.savefig(os.path.join(folder, "output.svg"), format="svg") - os.rename(os.path.join(folder, "output.svg"), - os.path.join(folder, "output")) + os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) -def visualize_roc_curve_plotly(df1, df2, pos_label, - drop_intermediate=True, - title=None): +def visualize_roc_curve_plotly(df1, df2, pos_label, drop_intermediate=True, title=None): """output roc-curve in html using plotly df1 : pandas.DataFrame @@ -169,45 +156,31 @@ y_true = df1.iloc[:, idx].values y_score = df2.iloc[:, idx].values - fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, - drop_intermediate=drop_intermediate) + fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) roc_auc = auc(fpr, tpr) trace = go.Scatter( x=fpr, y=tpr, - mode='lines', - marker=dict( - color=default_colors[idx % len(default_colors)] - ), - name='%s (area = %.3f)' % (idx, roc_auc) + mode="lines", + marker=dict(color=default_colors[idx % len(default_colors)]), + name="%s (area = %.3f)" % (idx, roc_auc), ) data.append(trace) layout = go.Layout( - xaxis=dict( - title='False Positive Rate', - linecolor='lightslategray', - linewidth=1 - ), - yaxis=dict( - title='True Positive Rate', - linecolor='lightslategray', - linewidth=1 - ), + xaxis=dict(title="False Positive Rate", linecolor="lightslategray", linewidth=1), + yaxis=dict(title="True Positive Rate", linecolor="lightslategray", linewidth=1), title=dict( - text=title or 'Receiver Operating Characteristic (ROC) Curve', + text=title or "Receiver Operating Characteristic (ROC) Curve", x=0.5, y=0.92, - xanchor='center', - yanchor='top' + xanchor="center", + yanchor="top", ), - font=dict( - family="sans-serif", - size=11 - ), + font=dict(family="sans-serif", size=11), # control backgroud colors - plot_bgcolor='rgba(255,255,255,0)' + plot_bgcolor="rgba(255,255,255,0)", ) """ # legend=dict( @@ -229,66 +202,84 @@ plotly.offline.plot(fig, filename="output.html", auto_open=False) # to be discovered by `from_work_dir` - os.rename('output.html', 'output') + os.rename("output.html", "output") -def visualize_roc_curve_matplotlib(df1, df2, pos_label, - drop_intermediate=True, - title=None): - """visualize roc-curve using matplotlib and output svg image - """ +def visualize_roc_curve_matplotlib(df1, df2, pos_label, drop_intermediate=True, title=None): + """visualize roc-curve using matplotlib and output svg image""" backend = matplotlib.get_backend() if "inline" not in backend: matplotlib.use("SVG") - plt.style.use('seaborn-colorblind') + plt.style.use("seaborn-colorblind") plt.figure() for idx in range(df1.shape[1]): y_true = df1.iloc[:, idx].values y_score = df2.iloc[:, idx].values - fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, - drop_intermediate=drop_intermediate) + fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=pos_label, drop_intermediate=drop_intermediate) roc_auc = auc(fpr, tpr) - plt.step(fpr, tpr, 'r-', color="black", alpha=0.3, lw=1, - where="post", label='%s (area = %.3f)' % (idx, roc_auc)) + plt.step( + fpr, + tpr, + "r-", + color="black", + alpha=0.3, + lw=1, + where="post", + label="%s (area = %.3f)" % (idx, roc_auc), + ) plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) - plt.xlabel('False Positive Rate') - plt.ylabel('True Positive Rate') - title = title or 'Receiver Operating Characteristic (ROC) Curve' + plt.xlabel("False Positive Rate") + plt.ylabel("True Positive Rate") + title = title or "Receiver Operating Characteristic (ROC) Curve" plt.title(title) folder = os.getcwd() plt.savefig(os.path.join(folder, "output.svg"), format="svg") - os.rename(os.path.join(folder, "output.svg"), - os.path.join(folder, "output")) + os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) def get_dataframe(file_path, plot_selection, header_name, column_name): - header = 'infer' if plot_selection[header_name] else None + header = "infer" if plot_selection[header_name] else None column_option = plot_selection[column_name]["selected_column_selector_option"] - if column_option in ["by_index_number", "all_but_by_index_number", "by_header_name", "all_but_by_header_name"]: + if column_option in [ + "by_index_number", + "all_but_by_index_number", + "by_header_name", + "all_but_by_header_name", + ]: col = plot_selection[column_name]["col1"] else: col = None _, input_df = read_columns(file_path, c=col, - c_option=column_option, - return_df=True, - sep='\t', header=header, - parse_dates=True) + c_option=column_option, + return_df=True, + sep='\t', header=header, + parse_dates=True) return input_df -def main(inputs, infile_estimator=None, infile1=None, - infile2=None, outfile_result=None, - outfile_object=None, groups=None, - ref_seq=None, intervals=None, - targets=None, fasta_path=None, - model_config=None, true_labels=None, - predicted_labels=None, plot_color=None, - title=None): +def main( + inputs, + infile_estimator=None, + infile1=None, + infile2=None, + outfile_result=None, + outfile_object=None, + groups=None, + ref_seq=None, + intervals=None, + targets=None, + fasta_path=None, + model_config=None, + true_labels=None, + predicted_labels=None, + plot_color=None, + title=None, +): """ Parameter --------- @@ -341,34 +332,39 @@ title : str, default is None Title of the confusion matrix heatmap """ - warnings.simplefilter('ignore') + warnings.simplefilter("ignore") - with open(inputs, 'r') as param_handler: + with open(inputs, "r") as param_handler: params = json.load(param_handler) - title = params['plotting_selection']['title'].strip() - plot_type = params['plotting_selection']['plot_type'] - plot_format = params['plotting_selection']['plot_format'] + title = params["plotting_selection"]["title"].strip() + plot_type = params["plotting_selection"]["plot_type"] + plot_format = params["plotting_selection"]["plot_format"] - if plot_type == 'feature_importances': - with open(infile_estimator, 'rb') as estimator_handler: + if plot_type == "feature_importances": + with open(infile_estimator, "rb") as estimator_handler: estimator = load_model(estimator_handler) - column_option = (params['plotting_selection'] - ['column_selector_options'] - ['selected_column_selector_option']) - if column_option in ['by_index_number', 'all_but_by_index_number', - 'by_header_name', 'all_but_by_header_name']: - c = (params['plotting_selection'] - ['column_selector_options']['col1']) + column_option = params["plotting_selection"]["column_selector_options"]["selected_column_selector_option"] + if column_option in [ + "by_index_number", + "all_but_by_index_number", + "by_header_name", + "all_but_by_header_name", + ]: + c = params["plotting_selection"]["column_selector_options"]["col1"] else: c = None - _, input_df = read_columns(infile1, c=c, - c_option=column_option, - return_df=True, - sep='\t', header='infer', - parse_dates=True) + _, input_df = read_columns( + infile1, + c=c, + c_option=column_option, + return_df=True, + sep="\t", + header="infer", + parse_dates=True, + ) feature_names = input_df.columns.values @@ -379,16 +375,14 @@ feature_names = feature_names[mask] estimator = estimator.steps[-1][-1] - if hasattr(estimator, 'coef_'): + if hasattr(estimator, "coef_"): coefs = estimator.coef_ else: - coefs = getattr(estimator, 'feature_importances_', None) + coefs = getattr(estimator, "feature_importances_", None) if coefs is None: - raise RuntimeError('The classifier does not expose ' - '"coef_" or "feature_importances_" ' - 'attributes') + raise RuntimeError("The classifier does not expose " '"coef_" or "feature_importances_" ' "attributes") - threshold = params['plotting_selection']['threshold'] + threshold = params["plotting_selection"]["threshold"] if threshold is not None: mask = (coefs > threshold) | (coefs < -threshold) coefs = coefs[mask] @@ -397,80 +391,74 @@ # sort indices = np.argsort(coefs)[::-1] - trace = go.Bar(x=feature_names[indices], - y=coefs[indices]) + trace = go.Bar(x=feature_names[indices], y=coefs[indices]) layout = go.Layout(title=title or "Feature Importances") fig = go.Figure(data=[trace], layout=layout) - plotly.offline.plot(fig, filename="output.html", - auto_open=False) + plotly.offline.plot(fig, filename="output.html", auto_open=False) # to be discovered by `from_work_dir` - os.rename('output.html', 'output') + os.rename("output.html", "output") return 0 - elif plot_type in ('pr_curve', 'roc_curve'): - df1 = pd.read_csv(infile1, sep='\t', header='infer') - df2 = pd.read_csv(infile2, sep='\t', header='infer').astype(np.float32) + elif plot_type in ("pr_curve", "roc_curve"): + df1 = pd.read_csv(infile1, sep="\t", header="infer") + df2 = pd.read_csv(infile2, sep="\t", header="infer").astype(np.float32) - minimum = params['plotting_selection']['report_minimum_n_positives'] + minimum = params["plotting_selection"]["report_minimum_n_positives"] # filter out columns whose n_positives is beblow the threhold if minimum: mask = df1.sum(axis=0) >= minimum df1 = df1.loc[:, mask] df2 = df2.loc[:, mask] - pos_label = params['plotting_selection']['pos_label'].strip() \ - or None + pos_label = params["plotting_selection"]["pos_label"].strip() or None - if plot_type == 'pr_curve': - if plot_format == 'plotly_html': + if plot_type == "pr_curve": + if plot_format == "plotly_html": visualize_pr_curve_plotly(df1, df2, pos_label, title=title) else: visualize_pr_curve_matplotlib(df1, df2, pos_label, title) - else: # 'roc_curve' - drop_intermediate = (params['plotting_selection'] - ['drop_intermediate']) - if plot_format == 'plotly_html': - visualize_roc_curve_plotly(df1, df2, pos_label, - drop_intermediate=drop_intermediate, - title=title) + else: # 'roc_curve' + drop_intermediate = params["plotting_selection"]["drop_intermediate"] + if plot_format == "plotly_html": + visualize_roc_curve_plotly( + df1, + df2, + pos_label, + drop_intermediate=drop_intermediate, + title=title, + ) else: visualize_roc_curve_matplotlib( - df1, df2, pos_label, + df1, + df2, + pos_label, drop_intermediate=drop_intermediate, - title=title) + title=title, + ) return 0 - elif plot_type == 'rfecv_gridscores': - input_df = pd.read_csv(infile1, sep='\t', header='infer') + elif plot_type == "rfecv_gridscores": + input_df = pd.read_csv(infile1, sep="\t", header="infer") scores = input_df.iloc[:, 0] - steps = params['plotting_selection']['steps'].strip() + steps = params["plotting_selection"]["steps"].strip() steps = safe_eval(steps) data = go.Scatter( x=list(range(len(scores))), y=scores, text=[str(_) for _ in steps] if steps else None, - mode='lines' + mode="lines", ) layout = go.Layout( xaxis=dict(title="Number of features selected"), yaxis=dict(title="Cross validation score"), - title=dict( - text=title or None, - x=0.5, - y=0.92, - xanchor='center', - yanchor='top' - ), - font=dict( - family="sans-serif", - size=11 - ), + title=dict(text=title or None, x=0.5, y=0.92, xanchor="center", yanchor="top"), + font=dict(family="sans-serif", size=11), # control backgroud colors - plot_bgcolor='rgba(255,255,255,0)' + plot_bgcolor="rgba(255,255,255,0)", ) """ # legend=dict( @@ -489,55 +477,43 @@ """ fig = go.Figure(data=[data], layout=layout) - plotly.offline.plot(fig, filename="output.html", - auto_open=False) + plotly.offline.plot(fig, filename="output.html", auto_open=False) # to be discovered by `from_work_dir` - os.rename('output.html', 'output') + os.rename("output.html", "output") return 0 - elif plot_type == 'learning_curve': - input_df = pd.read_csv(infile1, sep='\t', header='infer') - plot_std_err = params['plotting_selection']['plot_std_err'] + elif plot_type == "learning_curve": + input_df = pd.read_csv(infile1, sep="\t", header="infer") + plot_std_err = params["plotting_selection"]["plot_std_err"] data1 = go.Scatter( - x=input_df['train_sizes_abs'], - y=input_df['mean_train_scores'], - error_y=dict( - array=input_df['std_train_scores'] - ) if plot_std_err else None, - mode='lines', + x=input_df["train_sizes_abs"], + y=input_df["mean_train_scores"], + error_y=dict(array=input_df["std_train_scores"]) if plot_std_err else None, + mode="lines", name="Train Scores", ) data2 = go.Scatter( - x=input_df['train_sizes_abs'], - y=input_df['mean_test_scores'], - error_y=dict( - array=input_df['std_test_scores'] - ) if plot_std_err else None, - mode='lines', + x=input_df["train_sizes_abs"], + y=input_df["mean_test_scores"], + error_y=dict(array=input_df["std_test_scores"]) if plot_std_err else None, + mode="lines", name="Test Scores", ) layout = dict( - xaxis=dict( - title='No. of samples' - ), - yaxis=dict( - title='Performance Score' - ), + xaxis=dict(title="No. of samples"), + yaxis=dict(title="Performance Score"), # modify these configurations to customize image title=dict( - text=title or 'Learning Curve', + text=title or "Learning Curve", x=0.5, y=0.92, - xanchor='center', - yanchor='top' + xanchor="center", + yanchor="top", ), - font=dict( - family="sans-serif", - size=11 - ), + font=dict(family="sans-serif", size=11), # control backgroud colors - plot_bgcolor='rgba(255,255,255,0)' + plot_bgcolor="rgba(255,255,255,0)", ) """ # legend=dict( @@ -556,27 +532,26 @@ """ fig = go.Figure(data=[data1, data2], layout=layout) - plotly.offline.plot(fig, filename="output.html", - auto_open=False) + plotly.offline.plot(fig, filename="output.html", auto_open=False) # to be discovered by `from_work_dir` - os.rename('output.html', 'output') + os.rename("output.html", "output") return 0 - elif plot_type == 'keras_plot_model': - with open(model_config, 'r') as f: + elif plot_type == "keras_plot_model": + with open(model_config, "r") as f: model_str = f.read() model = model_from_json(model_str) plot_model(model, to_file="output.png") - os.rename('output.png', 'output') + os.rename("output.png", "output") return 0 - elif plot_type == 'classification_confusion_matrix': + elif plot_type == "classification_confusion_matrix": plot_selection = params["plotting_selection"] input_true = get_dataframe(true_labels, plot_selection, "header_true", "column_selector_options_true") - header_predicted = 'infer' if plot_selection["header_predicted"] else None - input_predicted = pd.read_csv(predicted_labels, sep='\t', parse_dates=True, header=header_predicted) + header_predicted = "infer" if plot_selection["header_predicted"] else None + input_predicted = pd.read_csv(predicted_labels, sep="\t", parse_dates=True, header=header_predicted) true_classes = input_true.iloc[:, -1].copy() predicted_classes = input_predicted.iloc[:, -1].copy() axis_labels = list(set(true_classes)) @@ -586,15 +561,15 @@ for i in range(len(c_matrix)): for j in range(len(c_matrix)): ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") - ax.set_ylabel('True class labels') - ax.set_xlabel('Predicted class labels') + ax.set_ylabel("True class labels") + ax.set_xlabel("Predicted class labels") ax.set_title(title) ax.set_xticks(axis_labels) ax.set_yticks(axis_labels) fig.colorbar(im, ax=ax) fig.tight_layout() plt.savefig("output.png", dpi=125) - os.rename('output.png', 'output') + os.rename("output.png", "output") return 0 @@ -603,7 +578,7 @@ # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) -if __name__ == '__main__': +if __name__ == "__main__": aparser = argparse.ArgumentParser() aparser.add_argument("-i", "--inputs", dest="inputs", required=True) aparser.add_argument("-e", "--estimator", dest="infile_estimator") @@ -623,11 +598,21 @@ aparser.add_argument("-pt", "--title", dest="title") args = aparser.parse_args() - main(args.inputs, args.infile_estimator, args.infile1, args.infile2, - args.outfile_result, outfile_object=args.outfile_object, - groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals, - targets=args.targets, fasta_path=args.fasta_path, - model_config=args.model_config, true_labels=args.true_labels, - predicted_labels=args.predicted_labels, - plot_color=args.plot_color, - title=args.title) + main( + args.inputs, + args.infile_estimator, + args.infile1, + args.infile2, + args.outfile_result, + outfile_object=args.outfile_object, + groups=args.groups, + ref_seq=args.ref_seq, + intervals=args.intervals, + targets=args.targets, + fasta_path=args.fasta_path, + model_config=args.model_config, + true_labels=args.true_labels, + predicted_labels=args.predicted_labels, + plot_color=args.plot_color, + title=args.title, + )