pycaret_predict.py @ 0:1f20fe57fdee (draft)
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
| author | goeckslab |
|---|---|
| date | Wed, 11 Dec 2024 04:59:43 +0000 |
| parents | |
| children | ccd798db5abb |
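
pycaret_predict.py loads a trained PyCaret model from an HDF5 file, scores it against a tabular dataset, and writes the predictions to predictions.csv. When a target column is supplied, it also pulls the evaluation metrics and diagnostic plots and renders them into evaluation_report.html: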
```python
import argparse
import logging
import tempfile

import h5py
import joblib
import pandas as pd
from pycaret.classification import ClassificationExperiment
from pycaret.regression import RegressionExperiment
from sklearn.metrics import average_precision_score

from utils import encode_image_to_base64, get_html_closing, get_html_template

LOG = logging.getLogger(__name__)


class PyCaretModelEvaluator:
    def __init__(self, model_path, task, target):
        self.model_path = model_path
        self.task = task.lower()
        self.model = self.load_h5_model()
        # Galaxy passes the literal string "None" when no target is selected.
        self.target = target if target != "None" else None

    def load_h5_model(self):
        """Load a PyCaret model from an HDF5 file."""
        with h5py.File(self.model_path, 'r') as f:
            model_bytes = bytes(f['model'][()])
        # joblib loads from a path, so round-trip the bytes through a
        # temporary file; flush() ensures they are on disk before reading.
        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
            temp_file.write(model_bytes)
            temp_file.flush()
            return joblib.load(temp_file.name)

    def evaluate(self, data_path):
        """Evaluate the model using the specified data."""
        raise NotImplementedError("Subclasses must implement this method")


class ClassificationEvaluator(PyCaretModelEvaluator):
    def evaluate(self, data_path):
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            exp = ClassificationExperiment()
            names = data.columns.to_list()
            LOG.info(f"Column names: {names}")
            # --target is a 1-based column number.
            target_index = int(self.target) - 1
            target_name = names[target_index]
            exp.setup(data, target=target_name, test_data=data, index=False)
            exp.add_metric(id='PR-AUC-Weighted',
                           name='PR-AUC-Weighted',
                           target='pred_proba',
                           score_func=average_precision_score,
                           average='weighted')
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()
            plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
                     'error', 'class_report', 'learning', 'calibration',
                     'vc', 'dimension', 'manifold', 'rfe', 'feature',
                     'feature_all']
            for plot_name in plots:
                try:
                    # For binary tasks, draw only the single binary ROC
                    # curve instead of per-class/micro/macro averages.
                    if plot_name == 'auc' and not exp.is_multiclass:
                        plot_path = exp.plot_model(self.model,
                                                   plot=plot_name,
                                                   save=True,
                                                   plot_kwargs={
                                                       'micro': False,
                                                       'macro': False,
                                                       'per_class': False,
                                                       'binary': True
                                                   })
                        plot_paths[plot_name] = plot_path
                        continue

                    plot_path = exp.plot_model(self.model,
                                               plot=plot_name, save=True)
                    plot_paths[plot_name] = plot_path
                except Exception as e:
                    LOG.error(f"Error generating plot {plot_name}: {e}")
                    continue
            generate_html_report(plot_paths, metrics)
        else:
            # No target column: prediction-only mode, no metrics or plots.
            exp = ClassificationExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths


class RegressionEvaluator(PyCaretModelEvaluator):
    def evaluate(self, data_path):
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            names = data.columns.to_list()
            # --target is a 1-based column number.
            target_index = int(self.target) - 1
            target_name = names[target_index]
            exp = RegressionExperiment()
            exp.setup(data, target=target_name, test_data=data, index=False)
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()
            plots = ['residuals', 'error', 'cooks',
                     'learning', 'vc', 'manifold',
                     'rfe', 'feature', 'feature_all']
            for plot_name in plots:
                try:
                    plot_path = exp.plot_model(self.model,
                                               plot=plot_name, save=True)
                    plot_paths[plot_name] = plot_path
                except Exception as e:
                    LOG.error(f"Error generating plot {plot_name}: {e}")
                    continue
            generate_html_report(plot_paths, metrics)
        else:
            # No target column: prediction-only mode, no metrics or plots.
            exp = RegressionExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths


def generate_html_report(plots, metrics):
    """Generate an HTML evaluation report."""
    plots_html = ""
    for plot_name, plot_path in plots.items():
        encoded_image = encode_image_to_base64(plot_path)
        plots_html += f"""
        <div class="plot">
            <h3>{plot_name.capitalize()}</h3>
            <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
        </div>
        <hr>
        """

    # DataFrame.to_html() already emits a complete <table> element,
    # so it must not be wrapped in another <table> tag.
    metrics_html = metrics.to_html(index=False, classes="table")

    html_content = f"""
    {get_html_template()}
    <h1>Model Evaluation Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'metrics')">Metrics</div>
        <div class="tab" onclick="openTab(event, 'plots')">Plots</div>
    </div>
    <div id="metrics" class="tab-content">
        <h2>Metrics</h2>
        {metrics_html}
    </div>
    <div id="plots" class="tab-content">
        <h2>Plots</h2>
        {plots_html}
    </div>
    {get_html_closing()}
    """

    # Save HTML report
    with open("evaluation_report.html", "w") as html_file:
        html_file.write(html_content)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a PyCaret model stored in HDF5 format.")
    parser.add_argument("--model_path",
                        type=str,
                        required=True,
                        help="Path to the HDF5 model file.")
    parser.add_argument("--data_path",
                        type=str,
                        required=True,
                        help="Path to the evaluation data CSV file.")
    parser.add_argument("--task",
                        type=str,
                        required=True,
                        choices=["classification", "regression"],
                        help="Specify the task: classification or regression.")
    parser.add_argument("--target",
                        default=None,
                        help="1-based column number of the target.")
    args = parser.parse_args()

    if args.task == "classification":
        evaluator = ClassificationEvaluator(
            args.model_path, args.task, args.target)
    elif args.task == "regression":
        evaluator = RegressionEvaluator(
            args.model_path, args.task, args.target)
    else:
        raise ValueError(
            "Unsupported task type. Use 'classification' or 'regression'.")

    predictions, metrics, plots = evaluator.evaluate(args.data_path)

    predictions.to_csv("predictions.csv", index=False)
```
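
For context, `load_h5_model()` expects the HDF5 file to contain a single dataset named `model` holding the joblib-serialized bytes of a fitted pipeline. Below is a minimal sketch of the matching writer side, assuming that layout; the `save_model_h5` helper and the file names are illustrative, not part of this tool:

```python
import io

import h5py
import joblib
import numpy as np


def save_model_h5(model, path):
    """Illustrative helper: store a fitted estimator in the HDF5 layout
    that load_h5_model() above reads back."""
    buffer = io.BytesIO()
    joblib.dump(model, buffer)  # serialize the pipeline to bytes in memory
    with h5py.File(path, "w") as f:
        # np.void stores the payload as an opaque scalar, so trailing null
        # bytes in the pickle stream survive the round trip.
        f.create_dataset("model", data=np.void(buffer.getvalue()))
```

A file produced this way can then be scored with, for example, `python pycaret_predict.py --model_path model.h5 --data_path test.csv --task classification --target 1`, where `--target` is the 1-based column number of the label in the CSV.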
