Mercurial > repos > bgruening > model_prediction
comparison ml_visualization_ex.py @ 17:980bf31faa05 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 3c1e6c72303cfd8a5fd014734f18402b97f8ecb5
author | bgruening |
---|---|
date | Fri, 22 Sep 2023 17:36:36 +0000 |
parents | 3bb1b688b0e4 |
children |
comparison
equal
deleted
inserted
replaced
16:948eca2af157 | 17:980bf31faa05 |
---|---|
13 from galaxy_ml.utils import read_columns, SafeEval | 13 from galaxy_ml.utils import read_columns, SafeEval |
14 from sklearn.feature_selection._base import SelectorMixin | 14 from sklearn.feature_selection._base import SelectorMixin |
15 from sklearn.metrics import ( | 15 from sklearn.metrics import ( |
16 auc, | 16 auc, |
17 average_precision_score, | 17 average_precision_score, |
18 confusion_matrix, | |
18 precision_recall_curve, | 19 precision_recall_curve, |
19 roc_curve, | 20 roc_curve, |
20 ) | 21 ) |
21 from sklearn.pipeline import Pipeline | 22 from sklearn.pipeline import Pipeline |
22 from tensorflow.keras.models import model_from_json | 23 from tensorflow.keras.models import model_from_json |
256 folder = os.getcwd() | 257 folder = os.getcwd() |
257 plt.savefig(os.path.join(folder, "output.svg"), format="svg") | 258 plt.savefig(os.path.join(folder, "output.svg"), format="svg") |
258 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) | 259 os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) |
259 | 260 |
260 | 261 |
262 def get_dataframe(file_path, plot_selection, header_name, column_name): | |
263 header = "infer" if plot_selection[header_name] else None | |
264 column_option = plot_selection[column_name]["selected_column_selector_option"] | |
265 if column_option in [ | |
266 "by_index_number", | |
267 "all_but_by_index_number", | |
268 "by_header_name", | |
269 "all_but_by_header_name", | |
270 ]: | |
271 col = plot_selection[column_name]["col1"] | |
272 else: | |
273 col = None | |
274 _, input_df = read_columns( | |
275 file_path, | |
276 c=col, | |
277 c_option=column_option, | |
278 return_df=True, | |
279 sep="\t", | |
280 header=header, | |
281 parse_dates=True, | |
282 ) | |
283 return input_df | |
284 | |
285 | |
261 def main( | 286 def main( |
262 inputs, | 287 inputs, |
263 infile_estimator=None, | 288 infile_estimator=None, |
264 infile1=None, | 289 infile1=None, |
265 infile2=None, | 290 infile2=None, |
269 ref_seq=None, | 294 ref_seq=None, |
270 intervals=None, | 295 intervals=None, |
271 targets=None, | 296 targets=None, |
272 fasta_path=None, | 297 fasta_path=None, |
273 model_config=None, | 298 model_config=None, |
299 true_labels=None, | |
300 predicted_labels=None, | |
301 plot_color=None, | |
302 title=None, | |
274 ): | 303 ): |
275 """ | 304 """ |
276 Parameter | 305 Parameter |
277 --------- | 306 --------- |
278 inputs : str | 307 inputs : str |
309 fasta_path : str, default is None | 338 fasta_path : str, default is None |
310 File path to dataset containing fasta file | 339 File path to dataset containing fasta file |
311 | 340 |
312 model_config : str, default is None | 341 model_config : str, default is None |
313 File path to dataset containing JSON config for neural networks | 342 File path to dataset containing JSON config for neural networks |
343 | |
344 true_labels : str, default is None | |
345 File path to dataset containing true labels | |
346 | |
347 predicted_labels : str, default is None | |
348 File path to dataset containing predicted labels | |
349 | |
350 plot_color : str, default is None | |
351 Color of the confusion matrix heatmap | |
352 | |
353 title : str, default is None | |
354 Title of the confusion matrix heatmap | |
314 """ | 355 """ |
315 warnings.simplefilter("ignore") | 356 warnings.simplefilter("ignore") |
316 | 357 |
317 with open(inputs, "r") as param_handler: | 358 with open(inputs, "r") as param_handler: |
318 params = json.load(param_handler) | 359 params = json.load(param_handler) |
532 plot_model(model, to_file="output.png") | 573 plot_model(model, to_file="output.png") |
533 os.rename("output.png", "output") | 574 os.rename("output.png", "output") |
534 | 575 |
535 return 0 | 576 return 0 |
536 | 577 |
578 elif plot_type == "classification_confusion_matrix": | |
579 plot_selection = params["plotting_selection"] | |
580 input_true = get_dataframe( | |
581 true_labels, plot_selection, "header_true", "column_selector_options_true" | |
582 ) | |
583 header_predicted = "infer" if plot_selection["header_predicted"] else None | |
584 input_predicted = pd.read_csv( | |
585 predicted_labels, sep="\t", parse_dates=True, header=header_predicted | |
586 ) | |
587 true_classes = input_true.iloc[:, -1].copy() | |
588 predicted_classes = input_predicted.iloc[:, -1].copy() | |
589 axis_labels = list(set(true_classes)) | |
590 c_matrix = confusion_matrix(true_classes, predicted_classes) | |
591 fig, ax = plt.subplots(figsize=(7, 7)) | |
592 im = plt.imshow(c_matrix, cmap=plot_color) | |
593 for i in range(len(c_matrix)): | |
594 for j in range(len(c_matrix)): | |
595 ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") | |
596 ax.set_ylabel("True class labels") | |
597 ax.set_xlabel("Predicted class labels") | |
598 ax.set_title(title) | |
599 ax.set_xticks(axis_labels) | |
600 ax.set_yticks(axis_labels) | |
601 fig.colorbar(im, ax=ax) | |
602 fig.tight_layout() | |
603 plt.savefig("output.png", dpi=125) | |
604 os.rename("output.png", "output") | |
605 | |
606 return 0 | |
607 | |
537 # save pdf file to disk | 608 # save pdf file to disk |
538 # fig.write_image("image.pdf", format='pdf') | 609 # fig.write_image("image.pdf", format='pdf') |
539 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) | 610 # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) |
540 | 611 |
541 | 612 |
551 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") | 622 aparser.add_argument("-r", "--ref_seq", dest="ref_seq") |
552 aparser.add_argument("-b", "--intervals", dest="intervals") | 623 aparser.add_argument("-b", "--intervals", dest="intervals") |
553 aparser.add_argument("-t", "--targets", dest="targets") | 624 aparser.add_argument("-t", "--targets", dest="targets") |
554 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") | 625 aparser.add_argument("-f", "--fasta_path", dest="fasta_path") |
555 aparser.add_argument("-c", "--model_config", dest="model_config") | 626 aparser.add_argument("-c", "--model_config", dest="model_config") |
627 aparser.add_argument("-tl", "--true_labels", dest="true_labels") | |
628 aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") | |
629 aparser.add_argument("-pc", "--plot_color", dest="plot_color") | |
630 aparser.add_argument("-pt", "--title", dest="title") | |
556 args = aparser.parse_args() | 631 args = aparser.parse_args() |
557 | 632 |
558 main( | 633 main( |
559 args.inputs, | 634 args.inputs, |
560 args.infile_estimator, | 635 args.infile_estimator, |
566 ref_seq=args.ref_seq, | 641 ref_seq=args.ref_seq, |
567 intervals=args.intervals, | 642 intervals=args.intervals, |
568 targets=args.targets, | 643 targets=args.targets, |
569 fasta_path=args.fasta_path, | 644 fasta_path=args.fasta_path, |
570 model_config=args.model_config, | 645 model_config=args.model_config, |
646 true_labels=args.true_labels, | |
647 predicted_labels=args.predicted_labels, | |
648 plot_color=args.plot_color, | |
649 title=args.title, | |
571 ) | 650 ) |