comparison ml_visualization_ex.py @ 15:2eb5c017958d draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author bgruening
date Wed, 09 Aug 2023 13:15:27 +0000
parents caf7d2b71a48
children adf9f8d5ab9a
comparison
equal deleted inserted replaced
14:4d1637cac794 15:2eb5c017958d
7 import matplotlib.pyplot as plt 7 import matplotlib.pyplot as plt
8 import numpy as np 8 import numpy as np
9 import pandas as pd 9 import pandas as pd
10 import plotly 10 import plotly
11 import plotly.graph_objs as go 11 import plotly.graph_objs as go
12 from galaxy_ml.utils import load_model, read_columns, SafeEval 12 from galaxy_ml.model_persist import load_model_from_h5
13 from keras.models import model_from_json 13 from galaxy_ml.utils import read_columns, SafeEval
14 from keras.utils import plot_model 14 from sklearn.feature_selection._base import SelectorMixin
15 from sklearn.feature_selection.base import SelectorMixin 15 from sklearn.metrics import (
16 from sklearn.metrics import (auc, average_precision_score, confusion_matrix, 16 auc,
17 precision_recall_curve, roc_curve) 17 average_precision_score,
18 precision_recall_curve,
19 roc_curve,
20 )
18 from sklearn.pipeline import Pipeline 21 from sklearn.pipeline import Pipeline
22 from tensorflow.keras.models import model_from_json
23 from tensorflow.keras.utils import plot_model
19 24
20 safe_eval = SafeEval() 25 safe_eval = SafeEval()
21 26
22 # plotly default colors 27 # plotly default colors
23 default_colors = [ 28 default_colors = [
251 folder = os.getcwd() 256 folder = os.getcwd()
252 plt.savefig(os.path.join(folder, "output.svg"), format="svg") 257 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
253 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) 258 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
254 259
255 260
256 def get_dataframe(file_path, plot_selection, header_name, column_name):
257 header = "infer" if plot_selection[header_name] else None
258 column_option = plot_selection[column_name]["selected_column_selector_option"]
259 if column_option in [
260 "by_index_number",
261 "all_but_by_index_number",
262 "by_header_name",
263 "all_but_by_header_name",
264 ]:
265 col = plot_selection[column_name]["col1"]
266 else:
267 col = None
268 _, input_df = read_columns(
269 file_path,
270 c=col,
271 c_option=column_option,
272 return_df=True,
273 sep="\t",
274 header=header,
275 parse_dates=True,
276 )
277 return input_df
278
279
280 def main( 261 def main(
281 inputs, 262 inputs,
282 infile_estimator=None, 263 infile_estimator=None,
283 infile1=None, 264 infile1=None,
284 infile2=None, 265 infile2=None,
288 ref_seq=None, 269 ref_seq=None,
289 intervals=None, 270 intervals=None,
290 targets=None, 271 targets=None,
291 fasta_path=None, 272 fasta_path=None,
292 model_config=None, 273 model_config=None,
293 true_labels=None,
294 predicted_labels=None,
295 plot_color=None,
296 title=None,
297 ): 274 ):
298 """ 275 """
299 Parameter 276 Parameter
300 --------- 277 ---------
301 inputs : str 278 inputs : str
332 fasta_path : str, default is None 309 fasta_path : str, default is None
333 File path to dataset containing fasta file 310 File path to dataset containing fasta file
334 311
335 model_config : str, default is None 312 model_config : str, default is None
336 File path to dataset containing JSON config for neural networks 313 File path to dataset containing JSON config for neural networks
337
338 true_labels : str, default is None
339 File path to dataset containing true labels
340
341 predicted_labels : str, default is None
342 File path to dataset containing true predicted labels
343
344 plot_color : str, default is None
345 Color of the confusion matrix heatmap
346
347 title : str, default is None
348 Title of the confusion matrix heatmap
349 """ 314 """
350 warnings.simplefilter("ignore") 315 warnings.simplefilter("ignore")
351 316
352 with open(inputs, "r") as param_handler: 317 with open(inputs, "r") as param_handler:
353 params = json.load(param_handler) 318 params = json.load(param_handler)
355 title = params["plotting_selection"]["title"].strip() 320 title = params["plotting_selection"]["title"].strip()
356 plot_type = params["plotting_selection"]["plot_type"] 321 plot_type = params["plotting_selection"]["plot_type"]
357 plot_format = params["plotting_selection"]["plot_format"] 322 plot_format = params["plotting_selection"]["plot_format"]
358 323
359 if plot_type == "feature_importances": 324 if plot_type == "feature_importances":
360 with open(infile_estimator, "rb") as estimator_handler: 325 estimator = load_model_from_h5(infile_estimator)
361 estimator = load_model(estimator_handler)
362 326
363 column_option = params["plotting_selection"]["column_selector_options"][ 327 column_option = params["plotting_selection"]["column_selector_options"][
364 "selected_column_selector_option" 328 "selected_column_selector_option"
365 ] 329 ]
366 if column_option in [ 330 if column_option in [
568 plot_model(model, to_file="output.png") 532 plot_model(model, to_file="output.png")
569 os.rename("output.png", "output") 533 os.rename("output.png", "output")
570 534
571 return 0 535 return 0
572 536
573 elif plot_type == "classification_confusion_matrix":
574 plot_selection = params["plotting_selection"]
575 input_true = get_dataframe(
576 true_labels, plot_selection, "header_true", "column_selector_options_true"
577 )
578 header_predicted = "infer" if plot_selection["header_predicted"] else None
579 input_predicted = pd.read_csv(
580 predicted_labels, sep="\t", parse_dates=True, header=header_predicted
581 )
582 true_classes = input_true.iloc[:, -1].copy()
583 predicted_classes = input_predicted.iloc[:, -1].copy()
584 axis_labels = list(set(true_classes))
585 c_matrix = confusion_matrix(true_classes, predicted_classes)
586 fig, ax = plt.subplots(figsize=(7, 7))
587 im = plt.imshow(c_matrix, cmap=plot_color)
588 for i in range(len(c_matrix)):
589 for j in range(len(c_matrix)):
590 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
591 ax.set_ylabel("True class labels")
592 ax.set_xlabel("Predicted class labels")
593 ax.set_title(title)
594 ax.set_xticks(axis_labels)
595 ax.set_yticks(axis_labels)
596 fig.colorbar(im, ax=ax)
597 fig.tight_layout()
598 plt.savefig("output.png", dpi=125)
599 os.rename("output.png", "output")
600
601 return 0
602
603 # save pdf file to disk 537 # save pdf file to disk
604 # fig.write_image("image.pdf", format='pdf') 538 # fig.write_image("image.pdf", format='pdf')
605 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) 539 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
606 540
607 541
617 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") 551 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
618 aparser.add_argument("-b", "--intervals", dest="intervals") 552 aparser.add_argument("-b", "--intervals", dest="intervals")
619 aparser.add_argument("-t", "--targets", dest="targets") 553 aparser.add_argument("-t", "--targets", dest="targets")
620 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 554 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
621 aparser.add_argument("-c", "--model_config", dest="model_config") 555 aparser.add_argument("-c", "--model_config", dest="model_config")
622 aparser.add_argument("-tl", "--true_labels", dest="true_labels")
623 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
624 aparser.add_argument("-pc", "--plot_color", dest="plot_color")
625 aparser.add_argument("-pt", "--title", dest="title")
626 args = aparser.parse_args() 556 args = aparser.parse_args()
627 557
628 main( 558 main(
629 args.inputs, 559 args.inputs,
630 args.infile_estimator, 560 args.infile_estimator,
636 ref_seq=args.ref_seq, 566 ref_seq=args.ref_seq,
637 intervals=args.intervals, 567 intervals=args.intervals,
638 targets=args.targets, 568 targets=args.targets,
639 fasta_path=args.fasta_path, 569 fasta_path=args.fasta_path,
640 model_config=args.model_config, 570 model_config=args.model_config,
641 true_labels=args.true_labels,
642 predicted_labels=args.predicted_labels,
643 plot_color=args.plot_color,
644 title=args.title,
645 ) 571 )