comparison ml_visualization_ex.py @ 2:e23cfe4be9d4 draft

"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 02087ce2966cf8b4aac9197a41171e7f986c11d1-dirty"
author bgruening
date Wed, 02 Oct 2019 03:46:45 -0400
parents cc49634df38f
children 2b8406e74f9e
comparison
equal deleted inserted replaced
1:cc49634df38f 2:e23cfe4be9d4
144 144
145 if len(df1.columns) > 1: 145 if len(df1.columns) > 1:
146 precision["micro"], recall["micro"], _ = precision_recall_curve( 146 precision["micro"], recall["micro"], _ = precision_recall_curve(
147 df1.values.ravel(), df2.values.ravel(), pos_label=pos_label) 147 df1.values.ravel(), df2.values.ravel(), pos_label=pos_label)
148 ap['micro'] = average_precision_score( 148 ap['micro'] = average_precision_score(
149 df1.values, df2.values, average='micro', pos_label=pos_label or 1) 149 df1.values, df2.values, average='micro',
150 pos_label=pos_label or 1)
150 151
151 data = [] 152 data = []
152 for key in precision.keys(): 153 for key in precision.keys():
153 trace = go.Scatter( 154 trace = go.Scatter(
154 x=recall[key], 155 x=recall[key],
199 name='%s (area = %.2f)' % (key, roc_auc[key]) if key == 'micro' 200 name='%s (area = %.2f)' % (key, roc_auc[key]) if key == 'micro'
200 else 'column %s (area = %.2f)' % (key, roc_auc[key]) 201 else 'column %s (area = %.2f)' % (key, roc_auc[key])
201 ) 202 )
202 data.append(trace) 203 data.append(trace)
203 204
204 trace = go.Scatter(x=[0, 1], y=[0, 1], 205 trace = go.Scatter(x=[0, 1], y=[0, 1],
205 mode='lines', 206 mode='lines',
206 line=dict(color='black', dash='dash'), 207 line=dict(color='black', dash='dash'),
207 showlegend=False) 208 showlegend=False)
208 data.append(trace) 209 data.append(trace)
209 210