comparison ml_visualization_ex.py @ 17:980bf31faa05 draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 3c1e6c72303cfd8a5fd014734f18402b97f8ecb5
author bgruening
date Fri, 22 Sep 2023 17:36:36 +0000
parents 3bb1b688b0e4
children
comparison
equal deleted inserted replaced
16:948eca2af157 17:980bf31faa05
13 from galaxy_ml.utils import read_columns, SafeEval 13 from galaxy_ml.utils import read_columns, SafeEval
14 from sklearn.feature_selection._base import SelectorMixin 14 from sklearn.feature_selection._base import SelectorMixin
15 from sklearn.metrics import ( 15 from sklearn.metrics import (
16 auc, 16 auc,
17 average_precision_score, 17 average_precision_score,
18 confusion_matrix,
18 precision_recall_curve, 19 precision_recall_curve,
19 roc_curve, 20 roc_curve,
20 ) 21 )
21 from sklearn.pipeline import Pipeline 22 from sklearn.pipeline import Pipeline
22 from tensorflow.keras.models import model_from_json 23 from tensorflow.keras.models import model_from_json
256 folder = os.getcwd() 257 folder = os.getcwd()
257 plt.savefig(os.path.join(folder, "output.svg"), format="svg") 258 plt.savefig(os.path.join(folder, "output.svg"), format="svg")
258 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) 259 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output"))
259 260
260 261
def get_dataframe(file_path, plot_selection, header_name, column_name):
    """Read a tab-separated dataset using column/header choices from the plot params.

    Parameters
    ----------
    file_path : str
        Path to the tab-separated input file.
    plot_selection : dict
        Plot parameter mapping (e.g. ``params["plotting_selection"]``).
    header_name : str
        Key in ``plot_selection`` whose truthiness decides whether pandas
        should infer a header row (``"infer"``) or read none (``None``).
    column_name : str
        Key in ``plot_selection`` whose value holds
        ``selected_column_selector_option`` and, for index/name based
        selection modes, ``col1``.

    Returns
    -------
    The data frame element (second item) of the ``read_columns`` result
    when called with ``return_df=True``.
    """
    # Resolve the header setting first, matching the original lookup order.
    header_mode = None
    if plot_selection[header_name]:
        header_mode = "infer"

    selector = plot_selection[column_name]
    selector_option = selector["selected_column_selector_option"]

    # Only these selector modes carry an explicit column specification;
    # every other mode is forwarded to read_columns with no columns.
    explicit_column_modes = (
        "by_index_number",
        "all_but_by_index_number",
        "by_header_name",
        "all_but_by_header_name",
    )
    columns = None
    if selector_option in explicit_column_modes:
        columns = selector["col1"]

    _, frame = read_columns(
        file_path,
        c=columns,
        c_option=selector_option,
        return_df=True,
        sep="\t",
        header=header_mode,
        parse_dates=True,
    )
    return frame
284
285
261 def main( 286 def main(
262 inputs, 287 inputs,
263 infile_estimator=None, 288 infile_estimator=None,
264 infile1=None, 289 infile1=None,
265 infile2=None, 290 infile2=None,
269 ref_seq=None, 294 ref_seq=None,
270 intervals=None, 295 intervals=None,
271 targets=None, 296 targets=None,
272 fasta_path=None, 297 fasta_path=None,
273 model_config=None, 298 model_config=None,
299 true_labels=None,
300 predicted_labels=None,
301 plot_color=None,
302 title=None,
274 ): 303 ):
275 """ 304 """
276 Parameter 305 Parameter
277 --------- 306 ---------
278 inputs : str 307 inputs : str
309 fasta_path : str, default is None 338 fasta_path : str, default is None
310 File path to dataset containing fasta file 339 File path to dataset containing fasta file
311 340
312 model_config : str, default is None 341 model_config : str, default is None
313 File path to dataset containing JSON config for neural networks 342 File path to dataset containing JSON config for neural networks
343
344 true_labels : str, default is None
345 File path to dataset containing true labels
346
347 predicted_labels : str, default is None
348 File path to dataset containing true predicted labels
349
350 plot_color : str, default is None
351 Color of the confusion matrix heatmap
352
353 title : str, default is None
354 Title of the confusion matrix heatmap
314 """ 355 """
315 warnings.simplefilter("ignore") 356 warnings.simplefilter("ignore")
316 357
317 with open(inputs, "r") as param_handler: 358 with open(inputs, "r") as param_handler:
318 params = json.load(param_handler) 359 params = json.load(param_handler)
532 plot_model(model, to_file="output.png") 573 plot_model(model, to_file="output.png")
533 os.rename("output.png", "output") 574 os.rename("output.png", "output")
534 575
535 return 0 576 return 0
536 577
578 elif plot_type == "classification_confusion_matrix":
579 plot_selection = params["plotting_selection"]
580 input_true = get_dataframe(
581 true_labels, plot_selection, "header_true", "column_selector_options_true"
582 )
583 header_predicted = "infer" if plot_selection["header_predicted"] else None
584 input_predicted = pd.read_csv(
585 predicted_labels, sep="\t", parse_dates=True, header=header_predicted
586 )
587 true_classes = input_true.iloc[:, -1].copy()
588 predicted_classes = input_predicted.iloc[:, -1].copy()
589 axis_labels = list(set(true_classes))
590 c_matrix = confusion_matrix(true_classes, predicted_classes)
591 fig, ax = plt.subplots(figsize=(7, 7))
592 im = plt.imshow(c_matrix, cmap=plot_color)
593 for i in range(len(c_matrix)):
594 for j in range(len(c_matrix)):
595 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k")
596 ax.set_ylabel("True class labels")
597 ax.set_xlabel("Predicted class labels")
598 ax.set_title(title)
599 ax.set_xticks(axis_labels)
600 ax.set_yticks(axis_labels)
601 fig.colorbar(im, ax=ax)
602 fig.tight_layout()
603 plt.savefig("output.png", dpi=125)
604 os.rename("output.png", "output")
605
606 return 0
607
537 # save pdf file to disk 608 # save pdf file to disk
538 # fig.write_image("image.pdf", format='pdf') 609 # fig.write_image("image.pdf", format='pdf')
539 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) 610 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2)
540 611
541 612
551 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") 622 aparser.add_argument("-r", "--ref_seq", dest="ref_seq")
552 aparser.add_argument("-b", "--intervals", dest="intervals") 623 aparser.add_argument("-b", "--intervals", dest="intervals")
553 aparser.add_argument("-t", "--targets", dest="targets") 624 aparser.add_argument("-t", "--targets", dest="targets")
554 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") 625 aparser.add_argument("-f", "--fasta_path", dest="fasta_path")
555 aparser.add_argument("-c", "--model_config", dest="model_config") 626 aparser.add_argument("-c", "--model_config", dest="model_config")
627 aparser.add_argument("-tl", "--true_labels", dest="true_labels")
628 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels")
629 aparser.add_argument("-pc", "--plot_color", dest="plot_color")
630 aparser.add_argument("-pt", "--title", dest="title")
556 args = aparser.parse_args() 631 args = aparser.parse_args()
557 632
558 main( 633 main(
559 args.inputs, 634 args.inputs,
560 args.infile_estimator, 635 args.infile_estimator,
566 ref_seq=args.ref_seq, 641 ref_seq=args.ref_seq,
567 intervals=args.intervals, 642 intervals=args.intervals,
568 targets=args.targets, 643 targets=args.targets,
569 fasta_path=args.fasta_path, 644 fasta_path=args.fasta_path,
570 model_config=args.model_config, 645 model_config=args.model_config,
646 true_labels=args.true_labels,
647 predicted_labels=args.predicted_labels,
648 plot_color=args.plot_color,
649 title=args.title,
571 ) 650 )