Mercurial > repos > bgruening > sklearn_nn_classifier
diff ml_visualization_ex.py @ 27:22f0b9db4ea1 draft
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit 9981e25b00de29ed881b2229a173a8c812ded9bb
author | bgruening |
---|---|
date | Wed, 09 Aug 2023 12:57:05 +0000 |
parents | 823ecc0bce45 |
children | 4db563e8d453 |
line wrap: on
line diff
--- a/ml_visualization_ex.py Thu Aug 11 09:54:23 2022 +0000 +++ b/ml_visualization_ex.py Wed Aug 09 12:57:05 2023 +0000 @@ -9,13 +9,18 @@ import pandas as pd import plotly import plotly.graph_objs as go -from galaxy_ml.utils import load_model, read_columns, SafeEval -from keras.models import model_from_json -from keras.utils import plot_model -from sklearn.feature_selection.base import SelectorMixin -from sklearn.metrics import (auc, average_precision_score, confusion_matrix, - precision_recall_curve, roc_curve) +from galaxy_ml.model_persist import load_model_from_h5 +from galaxy_ml.utils import read_columns, SafeEval +from sklearn.feature_selection._base import SelectorMixin +from sklearn.metrics import ( + auc, + average_precision_score, + precision_recall_curve, + roc_curve, +) from sklearn.pipeline import Pipeline +from tensorflow.keras.models import model_from_json +from tensorflow.keras.utils import plot_model safe_eval = SafeEval() @@ -253,30 +258,6 @@ os.rename(os.path.join(folder, "output.svg"), os.path.join(folder, "output")) -def get_dataframe(file_path, plot_selection, header_name, column_name): - header = "infer" if plot_selection[header_name] else None - column_option = plot_selection[column_name]["selected_column_selector_option"] - if column_option in [ - "by_index_number", - "all_but_by_index_number", - "by_header_name", - "all_but_by_header_name", - ]: - col = plot_selection[column_name]["col1"] - else: - col = None - _, input_df = read_columns( - file_path, - c=col, - c_option=column_option, - return_df=True, - sep="\t", - header=header, - parse_dates=True, - ) - return input_df - - def main( inputs, infile_estimator=None, @@ -290,10 +271,6 @@ targets=None, fasta_path=None, model_config=None, - true_labels=None, - predicted_labels=None, - plot_color=None, - title=None, ): """ Parameter @@ -334,18 +311,6 @@ model_config : str, default is None File path to dataset containing JSON config for neural networks - - true_labels : str, default is None - File path to dataset containing true labels - - predicted_labels : str, default is None - File path to dataset containing true predicted labels - - plot_color : str, default is None - Color of the confusion matrix heatmap - - title : str, default is None - Title of the confusion matrix heatmap """ warnings.simplefilter("ignore") @@ -357,8 +322,7 @@ plot_format = params["plotting_selection"]["plot_format"] if plot_type == "feature_importances": - with open(infile_estimator, "rb") as estimator_handler: - estimator = load_model(estimator_handler) + estimator = load_model_from_h5(infile_estimator) column_option = params["plotting_selection"]["column_selector_options"][ "selected_column_selector_option" @@ -570,36 +534,6 @@ return 0 - elif plot_type == "classification_confusion_matrix": - plot_selection = params["plotting_selection"] - input_true = get_dataframe( - true_labels, plot_selection, "header_true", "column_selector_options_true" - ) - header_predicted = "infer" if plot_selection["header_predicted"] else None - input_predicted = pd.read_csv( - predicted_labels, sep="\t", parse_dates=True, header=header_predicted - ) - true_classes = input_true.iloc[:, -1].copy() - predicted_classes = input_predicted.iloc[:, -1].copy() - axis_labels = list(set(true_classes)) - c_matrix = confusion_matrix(true_classes, predicted_classes) - fig, ax = plt.subplots(figsize=(7, 7)) - im = plt.imshow(c_matrix, cmap=plot_color) - for i in range(len(c_matrix)): - for j in range(len(c_matrix)): - ax.text(j, i, c_matrix[i, j], ha="center", va="center", color="k") - ax.set_ylabel("True class labels") - ax.set_xlabel("Predicted class labels") - ax.set_title(title) - ax.set_xticks(axis_labels) - ax.set_yticks(axis_labels) - fig.colorbar(im, ax=ax) - fig.tight_layout() - plt.savefig("output.png", dpi=125) - os.rename("output.png", "output") - - return 0 - # save pdf file to disk # fig.write_image("image.pdf", format='pdf') # fig.write_image("image.pdf", format='pdf', width=340*2, height=226*2) @@ -619,10 +553,6 @@ aparser.add_argument("-t", "--targets", dest="targets") aparser.add_argument("-f", "--fasta_path", dest="fasta_path") aparser.add_argument("-c", "--model_config", dest="model_config") - aparser.add_argument("-tl", "--true_labels", dest="true_labels") - aparser.add_argument("-pl", "--predicted_labels", dest="predicted_labels") - aparser.add_argument("-pc", "--plot_color", dest="plot_color") - aparser.add_argument("-pt", "--title", dest="title") args = aparser.parse_args() main( @@ -638,8 +568,4 @@ targets=args.targets, fasta_path=args.fasta_path, model_config=args.model_config, - true_labels=args.true_labels, - predicted_labels=args.predicted_labels, - plot_color=args.plot_color, - title=args.title, )