Mercurial > repos > bgruening > model_prediction
comparison ml_visualization_ex.py @ 15:3bb1b688b0e4 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 13:06:25 +0000 |
parents | 60c726f5cc5e |
children | 980bf31faa05 |
comparison
equal
deleted
inserted
replaced
14:29e4122c90de | 15:3bb1b688b0e4 |
---|---|
7 import matplotlib.pyplot as plt | 7 import matplotlib.pyplot as plt |
8 import numpy as np | 8 import numpy as np |
9 import pandas as pd | 9 import pandas as pd |
10 import plotly | 10 import plotly |
11 import plotly.graph_objs as go | 11 import plotly.graph_objs as go |
12 from galaxy_ml.utils import load_model, read_columns, SafeEval | 12 from galaxy_ml.model_persist import load_model_from_h5 |
13 from keras.models import model_from_json | 13 from galaxy_ml.utils import read_columns, SafeEval |
14 from keras.utils import plot_model | 14 from sklearn.feature_selection._base import SelectorMixin |
15 from sklearn.feature_selection.base import SelectorMixin | 15 from sklearn.metrics import ( |
16 from sklearn.metrics import (auc, average_precision_score, confusion_matrix, | 16 auc, |
17 precision_recall_curve, roc_curve) | 17 average_precision_score, |
18 precision_recall_curve, | |
19 roc_curve, | |
20 ) | |
18 from sklearn.pipeline import Pipeline | 21 from sklearn.pipeline import Pipeline |
22 from tensorflow.keras.models import model_from_json | |
23 from tensorflow.keras.utils import plot_model | |
19 | 24 |
20 safe_eval = SafeEval() | 25 safe_eval = SafeEval() |
21 | 26 |
22 # plotly default colors | 27 # plotly default colors |
23 default_colors = [ | 28 default_colors = [ |
251 folder = os.getcwd() | 256 folder = os.getcwd() |
252 plt.savefig(os.path.join(folder, "output.svg"), format="svg") | 257 plt.savefig(os.path.join(folder, "output.svg"), format="svg") |
253 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) | 258 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) |
254 | 259 |
255 | 260 |
256 def get_dataframe(file_path, plot_selection, header_name, column_name): | |
257 header = "infer" if plot_selection[header_name] else None | |
258 column_option = plot_selection[column_name]["selected_column_selector_option"] | |
259 if column_option in [ | |
260 "by_index_number", | |
261 "all_but_by_index_number", | |
262 "by_header_name", | |
263 "all_but_by_header_name", | |
264 ]: | |
265 col = plot_selection[column_name]["col1"] | |
266 else: | |
267 col = None | |
268 _, input_df = read_columns( | |
269 file_path, | |
270 c=col, | |
271 c_option=column_option, | |
272 return_df=True, | |
273 sep="\t", | |
274 header=header, | |
275 parse_dates=True, | |
276 ) | |
277 return input_df | |
278 | |
279 | |
280 def main( | 261 def main( |
281 inputs, | 262 inputs, |
282 infile_estimator=None, | 263 infile_estimator=None, |
283 infile1=None, | 264 infile1=None, |
284 infile2=None, | 265 infile2=None, |
288 ref_seq=None, | 269 ref_seq=None, |
289 intervals=None, | 270 intervals=None, |
290 targets=None, | 271 targets=None, |
291 fasta_path=None, | 272 fasta_path=None, |
292 model_config=None, | 273 model_config=None, |
293 true_labels=None, | |
294 predicted_labels=None, | |
295 plot_color=None, | |
296 title=None, | |
297 ): | 274 ): |
298 """ | 275 """ |
299 Parameter | 276 Parameter |
300 --------- | 277 --------- |
301 inputs : str | 278 inputs : str |
332 fasta_path : str, default is None | 309 fasta_path : str, default is None |
333 File path to dataset containing fasta file | 310 File path to dataset containing fasta file |
334 | 311 |
335 model_config : str, default is None | 312 model_config : str, default is None |
336 File path to dataset containing JSON config for neural networks | 313 File path to dataset containing JSON config for neural networks |
337 | |
338 true_labels : str, default is None | |
339 File path to dataset containing true labels | |
340 | |
341 predicted_labels : str, default is None | |
342 File path to dataset containing true predicted labels | |
343 | |
344 plot_color : str, default is None | |
345 Color of the confusion matrix heatmap | |
346 | |
347 title : str, default is None | |
348 Title of the confusion matrix heatmap | |
349 """ | 314 """ |
350 warnings.simplefilter("ignore") | 315 warnings.simplefilter("ignore") |
351 | 316 |
352 with open(inputs, "r") as param_handler: | 317 with open(inputs, "r") as param_handler: |
353 params = json.load(param_handler) | 318 params = json.load(param_handler) |
355 title = params["plotting_selection"]["title"].strip() | 320 title = params["plotting_selection"]["title"].strip() |
356 plot_type = params["plotting_selection"]["plot_type"] | 321 plot_type = params["plotting_selection"]["plot_type"] |
357 plot_format = params["plotting_selection"]["plot_format"] | 322 plot_format = params["plotting_selection"]["plot_format"] |
358 | 323 |
359 if plot_type == "feature_importances": | 324 if plot_type == "feature_importances": |
360 with open(infile_estimator, "rb") as estimator_handler: | 325 estimator = load_model_from_h5(infile_estimator) |
361 estimator = load_model(estimator_handler) | |
362 | 326 |
363 column_option = params["plotting_selection"]["column_selector_options"][ | 327 column_option = params["plotting_selection"]["column_selector_options"][ |
364 "selected_column_selector_option" | 328 "selected_column_selector_option" |
365 ] | 329 ] |
366 if column_option in [ | 330 if column_option in [ |
568 plot_model(model, to_file="output.png") | 532 plot_model(model, to_file="output.png") |
569 os.rename("output.png", "output") | 533 os.rename("output.png", "output") |
570 | 534 |
571 return 0 | 535 return 0 |
572 | 536 |
573 elif plot_type == "classification_confusion_matrix": | |
574 plot_selection = params["plotting_selection"] | |
575 input_true = get_dataframe( | |
576 true_labels, plot_selection, "header_true", "column_selector_options_true" | |
577 ) | |
578 header_predicted = "infer" if plot_selection["header_predicted"] else None | |
579 input_predicted = pd.read_csv( | |
580 predicted_labels, sep="\t", parse_dates=True, header=header_predicted | |
581 ) | |
582 true_classes = input_true.iloc[:, -1].copy() | |
583 predicted_classes = input_predicted.iloc[:, -1].copy() | |
584 axis_labels = list(set(true_classes)) | |
585 c_matrix = confusion_matrix(true_classes, predicted_classes) | |
586 fig, ax = plt.subplots(figsize=(7, 7)) | |
587 im = plt.imshow(c_matrix, cmap=plot_color) | |
588 for i in range(len(c_matrix)): | |
589 for j in range(len(c_matrix)): | |
590 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") | |
591 ax.set_ylabel("True class labels") | |
592 ax.set_xlabel("Predicted class labels") | |
593 ax.set_title(title) | |
594 ax.set_xticks(axis_labels) | |
595 ax.set_yticks(axis_labels) | |
596 fig.colorbar(im, ax=ax) | |
597 fig.tight_layout() | |
598 plt.savefig("output.png", dpi=125) | |
599 os.rename("output.png", "output") | |
600 | |
601 return 0 | |
602 | |
603 # save pdf file to disk | 537 # save pdf file to disk |
604 # fig.write_image("image.pdf", format='pdf') | 538 # fig.write_image("image.pdf", format='pdf') |
605 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) | 539 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) |
606 | 540 |
607 | 541 |
617 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 551 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
618 aparser.add_argument("-b", "--intervals", dest="intervals") | 552 aparser.add_argument("-b", "--intervals", dest="intervals") |
619 aparser.add_argument("-t", "--targets", dest="targets") | 553 aparser.add_argument("-t", "--targets", dest="targets") |
620 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 554 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
621 aparser.add_argument("-c", "--model_config", dest="model_config") | 555 aparser.add_argument("-c", "--model_config", dest="model_config") |
622 aparser.add_argument("-tl", "--true_labels", dest="true_labels") | |
623 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") | |
624 aparser.add_argument("-pc", "--plot_color", dest="plot_color") | |
625 aparser.add_argument("-pt", "--title", dest="title") | |
626 args = aparser.parse_args() | 556 args = aparser.parse_args() |
627 | 557 |
628 main( | 558 main( |
629 args.inputs, | 559 args.inputs, |
630 args.infile_estimator, | 560 args.infile_estimator, |
636 ref_seq=args.ref_seq, | 566 ref_seq=args.ref_seq, |
637 intervals=args.intervals, | 567 intervals=args.intervals, |
638 targets=args.targets, | 568 targets=args.targets, |
639 fasta_path=args.fasta_path, | 569 fasta_path=args.fasta_path, |
640 model_config=args.model_config, | 570 model_config=args.model_config, |
641 true_labels=args.true_labels, | |
642 predicted_labels=args.predicted_labels, | |
643 plot_color=args.plot_color, | |
644 title=args.title, | |
645 ) | 571 ) |