Mercurial > repos > goeckslab > pycaret_predict
view pycaret_predict.py @ 1:4a7df9abe4c4 draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit 9497c4faca7063bcbb6b201ab6d0dd1570f22acb
author | goeckslab |
---|---|
date | Sat, 14 Dec 2024 23:17:48 +0000 |
parents | 1f20fe57fdee |
children |
line wrap: on
line source
"""Evaluate a PyCaret model stored in an HDF5 file against a tabular dataset.

The script loads the serialized model, runs predictions on the supplied data
and, when a target column is given, also collects metrics and plots into
``evaluation_report.html``. Predictions are written to ``predictions.csv``.
"""
import argparse
import logging
import os
import tempfile

import h5py
import joblib
import pandas as pd
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
from sklearn.metrics import average_precision_score

from utils import encode_image_to_base64, get_html_closing, get_html_template

LOG = logging.getLogger(__name__)


class PyCaretModelEvaluator:
    """Base class holding the loaded model and the target-column setting.

    Args:
        model_path: Path to the HDF5 file containing the model bytes under
            the ``model`` dataset.
        task: Task name; stored lower-cased.
        target: 1-based column number of the target (as given on the command
            line), or ``None`` / the string ``"None"`` when there is no
            target column.
    """

    def __init__(self, model_path, task, target):
        self.model_path = model_path
        self.task = task.lower()
        self.model = self.load_h5_model()
        # The literal string "None" is treated the same as a missing target.
        self.target = target if target != "None" else None

    def load_h5_model(self):
        """Load a PyCaret model from an HDF5 file.

        The raw bytes stored in the ``model`` dataset are written to a
        temporary file, deserialized with ``joblib.load`` and the temporary
        file is removed afterwards.

        Returns:
            The object deserialized by ``joblib.load``.
        """
        with h5py.File(self.model_path, 'r') as f:
            model_bytes = bytes(f['model'][()])

        # delete=False so the file can be closed (and therefore fully
        # flushed) before joblib re-opens it by name; we remove it ourselves.
        temp_file = tempfile.NamedTemporaryFile(delete=False)
        try:
            with temp_file:
                temp_file.write(model_bytes)
            # SECURITY NOTE: joblib.load is pickle-based and can execute
            # arbitrary code; only load model files from trusted sources.
            return joblib.load(temp_file.name)
        finally:
            os.remove(temp_file.name)

    def _resolve_target_name(self, data):
        """Translate the 1-based target column number into a column name.

        Args:
            data: DataFrame whose columns are indexed by ``self.target``.

        Returns:
            The name of the target column.

        Raises:
            ValueError: If ``self.target`` cannot be converted to an integer.
            IndexError: If the column number is not within
                ``1..len(data.columns)``.
        """
        names = data.columns.to_list()
        LOG.info("Column names: %s", names)
        position = int(self.target)
        # Reject 0 and negatives explicitly: without this check they would
        # wrap around and silently pick a column counted from the end.
        if not 1 <= position <= len(names):
            raise IndexError(
                f"Target column number {position} is out of range; "
                f"the data has {len(names)} column(s).")
        return names[position - 1]

    def evaluate(self, data_path):
        """Evaluate the model using the specified data."""
        raise NotImplementedError("Subclasses must implement this method")


class ClassificationEvaluator(PyCaretModelEvaluator):
    """Evaluator that runs the model inside a ClassificationExperiment."""

    PLOTS = ('confusion_matrix', 'auc', 'threshold', 'pr', 'error',
             'class_report', 'learning', 'calibration', 'vc', 'dimension',
             'manifold', 'rfe', 'feature', 'feature_all')

    def evaluate(self, data_path):
        """Predict on ``data_path`` and, if a target is set, build a report.

        Args:
            data_path: Path to a delimited text file; the separator is
                sniffed by pandas (``sep=None`` with the python engine).

        Returns:
            A tuple ``(predictions, metrics, plot_paths)``. ``metrics`` is
            ``None`` and ``plot_paths`` is empty when no target is set.
        """
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        exp = ClassificationExperiment()

        if not self.target:
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)
            return predictions, metrics, plot_paths

        target_name = self._resolve_target_name(data)
        exp.setup(data, target=target_name, test_data=data, index=False)
        exp.add_metric(id='PR-AUC-Weighted',
                       name='PR-AUC-Weighted',
                       target='pred_proba',
                       score_func=average_precision_score,
                       average='weighted')
        predictions = exp.predict_model(self.model)
        metrics = exp.pull()

        for plot_name in self.PLOTS:
            # Plotting is best-effort: a plot that fails is logged and
            # skipped so the remaining plots and the report still get built.
            try:
                extra = {}
                if plot_name == 'auc' and not exp.is_multiclass:
                    # For binary problems request only the binary curve.
                    extra['plot_kwargs'] = {
                        'micro': False,
                        'macro': False,
                        'per_class': False,
                        'binary': True,
                    }
                plot_paths[plot_name] = exp.plot_model(
                    self.model, plot=plot_name, save=True, **extra)
            except Exception as e:
                LOG.error("Error generating plot %s: %s", plot_name, e)

        generate_html_report(plot_paths, metrics)
        return predictions, metrics, plot_paths


class RegressionEvaluator(PyCaretModelEvaluator):
    """Evaluator that runs the model inside a RegressionExperiment."""

    PLOTS = ('residuals', 'error', 'cooks', 'learning', 'vc', 'manifold',
             'rfe', 'feature', 'feature_all')

    def evaluate(self, data_path):
        """Predict on ``data_path`` and, if a target is set, build a report.

        Args:
            data_path: Path to a delimited text file; the separator is
                sniffed by pandas (``sep=None`` with the python engine).

        Returns:
            A tuple ``(predictions, metrics, plot_paths)``. ``metrics`` is
            ``None`` and ``plot_paths`` is empty when no target is set.
        """
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)

        if not self.target:
            exp = RegressionExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)
            return predictions, metrics, plot_paths

        target_name = self._resolve_target_name(data)
        exp = RegressionExperiment()
        exp.setup(data, target=target_name, test_data=data, index=False)
        predictions = exp.predict_model(self.model)
        metrics = exp.pull()

        for plot_name in self.PLOTS:
            # Plotting is best-effort: a plot that fails is logged and
            # skipped so the remaining plots and the report still get built.
            try:
                plot_paths[plot_name] = exp.plot_model(
                    self.model, plot=plot_name, save=True)
            except Exception as e:
                LOG.error("Error generating plot %s: %s", plot_name, e)

        generate_html_report(plot_paths, metrics)
        return predictions, metrics, plot_paths


def generate_html_report(plots, metrics):
    """Generate an HTML evaluation report.

    Writes ``evaluation_report.html`` in the current working directory with
    a metrics tab and a plots tab.

    Args:
        plots: Mapping of plot name to the path of the saved plot image.
        metrics: DataFrame of metrics; rendered with ``DataFrame.to_html``.
    """
    plots_html = "".join(
        f"""
        <div class="plot">
            <h3>{plot_name.capitalize()}</h3>
            <img src="data:image/png;base64,{encode_image_to_base64(plot_path)}" alt="{plot_name}">
        </div>
        <hr>
        """
        for plot_name, plot_path in plots.items()
    )

    # to_html() already emits a complete <table> element, so it must not be
    # wrapped in another <table> (nested tables are invalid markup here).
    metrics_html = metrics.to_html(index=False, classes="table")

    html_content = f"""
    {get_html_template()}
    <h1>Model Evaluation Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'metrics')">Metrics</div>
        <div class="tab" onclick="openTab(event, 'plots')">Plots</div>
    </div>
    <div id="metrics" class="tab-content">
        <h2>Metrics</h2>
        {metrics_html}
    </div>
    <div id="plots" class="tab-content">
        <h2>Plots</h2>
        {plots_html}
    </div>
    {get_html_closing()}
    """

    # Save HTML report; explicit encoding keeps the output independent of
    # the platform's default locale.
    with open("evaluation_report.html", "w", encoding="utf-8") as html_file:
        html_file.write(html_content)


def main():
    """Parse command-line arguments, run the evaluator, save predictions."""
    parser = argparse.ArgumentParser(
        description="Evaluate a PyCaret model stored in HDF5 format.")
    parser.add_argument("--model_path",
                        type=str,
                        help="Path to the HDF5 model file.")
    parser.add_argument("--data_path",
                        type=str,
                        help="Path to the evaluation data CSV file.")
    parser.add_argument("--task",
                        type=str,
                        choices=["classification", "regression"],
                        help="Specify the task: classification or regression.")
    parser.add_argument("--target",
                        default=None,
                        help="Column number of the target")
    args = parser.parse_args()

    evaluators = {
        "classification": ClassificationEvaluator,
        "regression": RegressionEvaluator,
    }
    evaluator_cls = evaluators.get(args.task)
    # --task is optional for argparse, so a missing value ends up here.
    if evaluator_cls is None:
        raise ValueError(
            "Unsupported task type. Use 'classification' or 'regression'.")

    evaluator = evaluator_cls(args.model_path, args.task, args.target)
    predictions, _metrics, _plots = evaluator.evaluate(args.data_path)
    predictions.to_csv("predictions.csv", index=False)


if __name__ == "__main__":
    main()