Mercurial > repos > bgruening > sklearn_nn_classifier
view ml_visualization_ex.py @ 14:9871a634540f draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 02087ce2966cf8b4aac9197a41171e7f986c11d1-dirty"
author | bgruening |
---|---|
date | Wed, 02 Oct 2019 03:58:00 -0400 |
parents | fbc38059bb8f |
children | 699024d5c451 |
line wrap: on
line source
import argparse
import json
import os
import warnings

import numpy as np
import pandas as pd
import plotly
import plotly.graph_objs as go
from keras.models import model_from_json
from keras.utils import plot_model
from sklearn.feature_selection.base import SelectorMixin
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import roc_curve, auc
from sklearn.pipeline import Pipeline

from galaxy_ml.utils import load_model, read_columns, SafeEval


safe_eval = SafeEval()

# Column-selector options whose ``col1`` value carries the column spec that
# is forwarded to ``read_columns``.
_COLUMN_OPTIONS_WITH_SPEC = ('by_index_number', 'all_but_by_index_number',
                             'by_header_name', 'all_but_by_header_name')


def _curve_label(key, area):
    """Return the legend label for one PR/ROC curve trace.

    ``key`` is either a column label of the input table or the string
    ``'micro'`` for the micro-averaged curve.
    """
    if key == 'micro':
        return '%s (area = %.2f)' % (key, area)
    return 'column %s (area = %.2f)' % (key, area)


def _read_pos_label(params):
    """Return the user-supplied positive label, or None when left blank."""
    return params['plotting_selection']['pos_label'].strip() or None


def _feature_importances_figure(params, title, infile_estimator, infile1):
    """Build a bar chart of the final estimator's coefficients/importances.

    Raises
    ------
    RuntimeError
        If the final estimator exposes neither ``coef_`` nor
        ``feature_importances_``.
    ValueError
        If the coefficients are not one value per feature, or their count
        does not match the number of resolved feature names.
    """
    selection = params['plotting_selection']
    with open(infile_estimator, 'rb') as estimator_handler:
        estimator = load_model(estimator_handler)

    column_option = (selection['column_selector_options']
                     ['selected_column_selector_option'])
    if column_option in _COLUMN_OPTIONS_WITH_SPEC:
        c = selection['column_selector_options']['col1']
    else:
        c = None

    _, input_df = read_columns(infile1, c=c, c_option=column_option,
                               return_df=True, sep='\t', header='infer',
                               parse_dates=True)
    feature_names = input_df.columns.values

    if isinstance(estimator, Pipeline):
        # Feature-selection steps drop columns; apply their support masks so
        # the names stay aligned with the final estimator's inputs.
        for step in estimator.steps[:-1]:
            transformer = step[-1]
            if isinstance(transformer, SelectorMixin):
                feature_names = feature_names[transformer.get_support()]
        estimator = estimator.steps[-1][-1]

    if hasattr(estimator, 'coef_'):
        coefs = estimator.coef_
    else:
        coefs = getattr(estimator, 'feature_importances_', None)
    if coefs is None:
        raise RuntimeError('The classifier does not expose '
                           '"coef_" or "feature_importances_" '
                           'attributes')

    # ``coef_`` can be a scipy sparse matrix; densify it so the NumPy
    # masking and sorting below apply.
    if hasattr(coefs, 'toarray'):
        coefs = coefs.toarray()
    coefs = np.asarray(coefs)
    # Linear models fit on a binary target store ``coef_`` with shape
    # (1, n_features); flatten it to one value per feature.
    if coefs.ndim == 2 and coefs.shape[0] == 1:
        coefs = coefs.ravel()
    if coefs.ndim != 1:
        raise ValueError('Cannot plot feature importances for coefficients '
                         'of shape %s; one value per feature is required.'
                         % (coefs.shape,))
    if len(coefs) != len(feature_names):
        # Without this check, mismatched lengths either crash on indexing
        # or silently attach the wrong names to the bars.
        raise ValueError('The estimator has %d coefficients but %d feature '
                         'names were resolved; the pipeline probably '
                         'contains a step that changes the number of '
                         'features.' % (len(coefs), len(feature_names)))

    threshold = selection['threshold']
    if threshold is not None:
        # Keep only features whose magnitude exceeds the threshold.
        mask = (coefs > threshold) | (coefs < -threshold)
        coefs = coefs[mask]
        feature_names = feature_names[mask]

    # Sort by descending (signed) value.
    indices = np.argsort(coefs)[::-1]

    trace = go.Bar(x=feature_names[indices], y=coefs[indices])
    layout = go.Layout(title=title or "Feature Importances")
    return go.Figure(data=[trace], layout=layout)


def _pr_curve_figure(params, title, infile1, infile2):
    """Build precision-recall curves: one per column, plus micro-average.

    ``infile1`` holds true labels and ``infile2`` the matching scores.
    Both are headerless tab-separated tables, compared column by column.
    The micro-average is added only when there is more than one column.
    """
    df1 = pd.read_csv(infile1, sep='\t', header=None)
    df2 = pd.read_csv(infile2, sep='\t', header=None)
    pos_label = _read_pos_label(params)

    precision, recall, ap = {}, {}, {}
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values
        precision[col], recall[col], _ = precision_recall_curve(
            y_true, y_score, pos_label=pos_label)
        ap[col] = average_precision_score(
            y_true, y_score, pos_label=pos_label or 1)

    if len(df1.columns) > 1:
        precision['micro'], recall['micro'], _ = precision_recall_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        ap['micro'] = average_precision_score(
            df1.values, df2.values, average='micro',
            pos_label=pos_label or 1)

    data = [go.Scatter(x=recall[key], y=precision[key], mode='lines',
                       name=_curve_label(key, ap[key]))
            for key in precision]
    layout = go.Layout(
        title=title or "Precision-Recall curve",
        xaxis=dict(title='Recall'),
        yaxis=dict(title='Precision')
    )
    return go.Figure(data=data, layout=layout)


def _roc_curve_figure(params, title, infile1, infile2):
    """Build ROC curves: one per column, plus micro-average.

    Inputs are read exactly as in ``_pr_curve_figure``. A dashed y = x
    diagonal (chance level) is appended as a reference line.
    """
    df1 = pd.read_csv(infile1, sep='\t', header=None)
    df2 = pd.read_csv(infile2, sep='\t', header=None)
    pos_label = _read_pos_label(params)

    fpr, tpr, roc_auc = {}, {}, {}
    for col in df1.columns:
        y_true = df1[col].values
        y_score = df2[col].values
        fpr[col], tpr[col], _ = roc_curve(
            y_true, y_score, pos_label=pos_label)
        roc_auc[col] = auc(fpr[col], tpr[col])

    if len(df1.columns) > 1:
        fpr['micro'], tpr['micro'], _ = roc_curve(
            df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
        roc_auc['micro'] = auc(fpr['micro'], tpr['micro'])

    data = [go.Scatter(x=fpr[key], y=tpr[key], mode='lines',
                       name=_curve_label(key, roc_auc[key]))
            for key in fpr]
    data.append(go.Scatter(x=[0, 1], y=[0, 1], mode='lines',
                           line=dict(color='black', dash='dash'),
                           showlegend=False))
    layout = go.Layout(
        title=title or "Receiver operating characteristic curve",
        xaxis=dict(title='False Positive Rate'),
        yaxis=dict(title='True Positive Rate')
    )
    return go.Figure(data=data, layout=layout)


def _rfecv_gridscores_figure(params, title, infile1):
    """Plot RFECV grid scores taken from the first column of ``infile1``.

    The ``steps`` parameter is parsed with ``safe_eval``. When it is
    non-empty, its items become per-point text labels.
    """
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    scores = input_df.iloc[:, 0]
    steps = safe_eval(params['plotting_selection']['steps'].strip())

    data = go.Scatter(
        x=list(range(len(scores))),
        y=scores,
        text=[str(step) for step in steps] if steps else None,
        mode='lines'
    )
    layout = go.Layout(
        xaxis=dict(title="Number of features selected"),
        yaxis=dict(title="Cross validation score"),
        title=title or None
    )
    return go.Figure(data=[data], layout=layout)


def _learning_curve_figure(params, title, infile1):
    """Plot mean train/test scores against absolute training-set sizes.

    ``infile1`` must provide the columns ``train_sizes_abs``,
    ``mean_train_scores``, ``mean_test_scores`` and, when ``plot_std_err``
    is set, ``std_train_scores`` / ``std_test_scores`` for error bars.
    """
    input_df = pd.read_csv(infile1, sep='\t', header='infer')
    plot_std_err = params['plotting_selection']['plot_std_err']

    def _scores_trace(kind, label):
        # kind is 'train' or 'test', selecting the matching columns.
        return go.Scatter(
            x=input_df['train_sizes_abs'],
            y=input_df['mean_%s_scores' % kind],
            error_y=dict(
                array=input_df['std_%s_scores' % kind]
            ) if plot_std_err else None,
            mode='lines',
            name=label,
        )

    layout = dict(
        xaxis=dict(title='No. of samples'),
        yaxis=dict(title='Performance Score'),
        title=title or 'Learning Curve'
    )
    return go.Figure(data=[_scores_trace('train', "Train Scores"),
                           _scores_trace('test', "Test Scores")],
                     layout=layout)


def main(inputs, infile_estimator=None, infile1=None, infile2=None,
         outfile_result=None, outfile_object=None, groups=None,
         ref_seq=None, intervals=None, targets=None,
         fasta_path=None, model_config=None):
    """Render the plot selected in the Galaxy tool parameters.

    The result is written to a file named ``output`` in the current working
    directory. It is plotly HTML for most plot types and a PNG for
    ``keras_plot_model``.

    Parameters
    ----------
    inputs : str
        File path to galaxy tool parameter (JSON).
    infile_estimator : str, default is None
        File path to estimator; used by ``feature_importances``.
    infile1 : str, default is None
        File path to dataset containing features or true labels.
    infile2 : str, default is None
        File path to dataset containing target values or predicted
        probabilities.
    outfile_result : str, default is None
        File path to save the results, either cv_results or test result.
        Not used by this function.
    outfile_object : str, default is None
        File path to save searchCV object. Not used by this function.
    groups : str, default is None
        File path to dataset containing groups labels. Not used by this
        function.
    ref_seq : str, default is None
        File path to dataset containing genome sequence file. Not used by
        this function.
    intervals : str, default is None
        File path to dataset containing interval file. Not used by this
        function.
    targets : str, default is None
        File path to dataset compressed target bed file. Not used by this
        function.
    fasta_path : str, default is None
        File path to dataset containing fasta file. Not used by this
        function.
    model_config : str, default is None
        File path to dataset containing JSON config for neural networks;
        used by ``keras_plot_model``.

    Returns
    -------
    int or None
        0 for ``keras_plot_model``; None for every other plot type.

    Raises
    ------
    RuntimeError
        For ``feature_importances`` when the estimator exposes neither
        ``coef_`` nor ``feature_importances_``.
    ValueError
        If ``plot_type`` is not recognised, or feature importances cannot
        be aligned with feature names.
    """
    warnings.simplefilter('ignore')

    with open(inputs, 'r') as param_handler:
        params = json.load(param_handler)

    title = params['plotting_selection']['title'].strip()
    plot_type = params['plotting_selection']['plot_type']

    if plot_type == 'feature_importances':
        fig = _feature_importances_figure(params, title,
                                          infile_estimator, infile1)
    elif plot_type == 'pr_curve':
        fig = _pr_curve_figure(params, title, infile1, infile2)
    elif plot_type == 'roc_curve':
        fig = _roc_curve_figure(params, title, infile1, infile2)
    elif plot_type == 'rfecv_gridscores':
        fig = _rfecv_gridscores_figure(params, title, infile1)
    elif plot_type == 'learning_curve':
        fig = _learning_curve_figure(params, title, infile1)
    elif plot_type == 'keras_plot_model':
        with open(model_config, 'r') as f:
            model_str = f.read()
        model = model_from_json(model_str)
        plot_model(model, to_file="output.png")
        os.rename('output.png', 'output')
        return 0
    else:
        # Previously fell through to an UnboundLocalError on ``fig``.
        raise ValueError('Unsupported plot_type: %r' % (plot_type,))

    plotly.offline.plot(fig, filename="output.html", auto_open=False)
    # to be discovered by `from_work_dir`
    os.rename('output.html', 'output')


if __name__ == '__main__':
    aparser = argparse.ArgumentParser()
    aparser.add_argument("-i", "--inputs", dest="inputs", required=True)
    aparser.add_argument("-e", "--estimator", dest="infile_estimator")
    aparser.add_argument("-X", "--infile1", dest="infile1")
    aparser.add_argument("-y", "--infile2", dest="infile2")
    aparser.add_argument("-O", "--outfile_result", dest="outfile_result")
    aparser.add_argument("-o", "--outfile_object", dest="outfile_object")
    aparser.add_argument("-g", "--groups", dest="groups")
    aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
    aparser.add_argument("-b", "--intervals", dest="intervals")
    aparser.add_argument("-t", "--targets", dest="targets")
    aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
    aparser.add_argument("-c", "--model_config", dest="model_config")
    args = aparser.parse_args()

    main(args.inputs, args.infile_estimator, args.infile1, args.infile2,
         args.outfile_result, outfile_object=args.outfile_object,
         groups=args.groups, ref_seq=args.ref_seq, intervals=args.intervals,
         targets=args.targets, fasta_path=args.fasta_path,
         model_config=args.model_config)