Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_predict.py @ 0:1f20fe57fdee draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author | goeckslab |
---|---|
date | Wed, 11 Dec 2024 04:59:43 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:1f20fe57fdee |
---|---|
1 import argparse | |
2 import logging | |
3 import tempfile | |
4 | |
5 import h5py | |
6 | |
7 import joblib | |
8 | |
9 import pandas as pd | |
10 | |
11 from pycaret.classification import ClassificationExperiment | |
12 from pycaret.regression import RegressionExperiment | |
13 | |
14 from sklearn.metrics import average_precision_score | |
15 | |
16 from utils import encode_image_to_base64, get_html_closing, get_html_template | |
17 | |
# Module-level logger used by the evaluator classes in this file.
LOG = logging.getLogger(name=__name__)
19 | |
20 | |
class PyCaretModelEvaluator:
    """Base class: load a PyCaret model stored in HDF5 and evaluate it.

    Subclasses implement :meth:`evaluate` for a concrete task.
    """

    def __init__(self, model_path, task, target):
        """Load the model and record the evaluation settings.

        Args:
            model_path: path to an HDF5 file whose ``model`` dataset holds
                joblib-serialized model bytes.
            task: task name; stored lower-cased in ``self.task``.
            target: 1-based column number of the target (as passed on the
                command line), or ``None`` / the literal string ``"None"``
                to run predictions without a target.
        """
        self.model_path = model_path
        self.task = task.lower()
        self.model = self.load_h5_model()
        # The CLI may deliver the literal string "None" to mean "no target".
        self.target = target if target != "None" else None

    def load_h5_model(self):
        """Load a PyCaret model from an HDF5 file.

        The bytes stored in the ``model`` dataset are deserialized with
        ``joblib.load`` directly from memory. This avoids leaving an
        undeleted temporary file behind, and avoids re-opening a file that
        is still open (which fails on Windows).

        Security: joblib deserialization is pickle-based and can execute
        arbitrary code. Only load model files from trusted sources.

        Returns:
            The deserialized model object.
        """
        import io

        with h5py.File(self.model_path, 'r') as f:
            model_bytes = bytes(f['model'][()])
        return joblib.load(io.BytesIO(model_bytes))

    def evaluate(self, data_path):
        """Evaluate the model using the specified data.

        Raises:
            NotImplementedError: always; subclasses must override this.
        """
        raise NotImplementedError("Subclasses must implement this method")
41 | |
42 | |
class ClassificationEvaluator(PyCaretModelEvaluator):
    """Evaluate (or only run predictions with) a PyCaret classification model."""

    # Plots attempted when a target column is available. Each one is
    # best-effort: a plot the model does not support is logged and skipped.
    PLOTS = ['confusion_matrix', 'auc', 'threshold', 'pr',
             'error', 'class_report', 'learning', 'calibration',
             'vc', 'dimension', 'manifold', 'rfe', 'feature',
             'feature_all']

    # Binary problems get a single ROC curve instead of micro/macro/per-class.
    _BINARY_AUC_KWARGS = {
        'micro': False,
        'macro': False,
        'per_class': False,
        'binary': True,
    }

    def _target_name(self, data):
        """Map the 1-based column number in ``self.target`` to a column name.

        Raises:
            ValueError: if ``self.target`` is not an integer in
                ``1..len(data.columns)``.
        """
        names = data.columns.to_list()
        try:
            column_number = int(self.target)
        except (TypeError, ValueError):
            raise ValueError(
                f"Target must be a column number, got {self.target!r}"
            ) from None
        # Reject 0 and negatives explicitly: names[-1] would otherwise
        # silently select the last column as the target.
        if not 1 <= column_number <= len(names):
            raise ValueError(
                f"Target column {column_number} is out of range; "
                f"the data has {len(names)} columns"
            )
        return names[column_number - 1]

    def evaluate(self, data_path):
        """Run the model on ``data_path`` and, when a target is set, score it.

        Args:
            data_path: delimited text file; ``sep=None`` with the python
                engine lets pandas sniff the delimiter.

        Returns:
            ``(predictions, metrics, plot_paths)``. Without a target,
            ``metrics`` is ``None`` and ``plot_paths`` is empty. With a
            target, ``metrics`` is the frame returned by ``exp.pull()``,
            ``plot_paths`` maps each plot that rendered to its saved path,
            and an HTML report is written via ``generate_html_report``.

        Raises:
            ValueError: if the target column number is invalid for the data.
        """
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        exp = ClassificationExperiment()

        if not self.target:
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)
            return predictions, metrics, plot_paths

        LOG.debug("Column names: %s", data.columns.to_list())
        target_name = self._target_name(data)
        # test_data=data makes the whole input PyCaret's hold-out set, so
        # predict_model() without data= scores these rows.
        exp.setup(data, target=target_name, test_data=data, index=False)
        exp.add_metric(id='PR-AUC-Weighted',
                       name='PR-AUC-Weighted',
                       target='pred_proba',
                       score_func=average_precision_score,
                       average='weighted')
        predictions = exp.predict_model(self.model)
        metrics = exp.pull()

        for plot_name in self.PLOTS:
            try:
                extra = {}
                if plot_name == 'auc' and not exp.is_multiclass:
                    extra['plot_kwargs'] = dict(self._BINARY_AUC_KWARGS)
                plot_paths[plot_name] = exp.plot_model(
                    self.model, plot=plot_name, save=True, **extra)
            except Exception as e:
                # Best-effort: many plots are unsupported for some models.
                LOG.error(f"Error generating plot {plot_name}: {e}")

        generate_html_report(plot_paths, metrics)
        return predictions, metrics, plot_paths
95 | |
96 | |
class RegressionEvaluator(PyCaretModelEvaluator):
    """Evaluate (or only run predictions with) a PyCaret regression model."""

    # Plots attempted when a target column is available. Each one is
    # best-effort: a plot the model does not support is logged and skipped.
    PLOTS = ['residuals', 'error', 'cooks',
             'learning', 'vc', 'manifold',
             'rfe', 'feature', 'feature_all']

    def _target_name(self, data):
        """Map the 1-based column number in ``self.target`` to a column name.

        Raises:
            ValueError: if ``self.target`` is not an integer in
                ``1..len(data.columns)``.
        """
        names = data.columns.to_list()
        try:
            column_number = int(self.target)
        except (TypeError, ValueError):
            raise ValueError(
                f"Target must be a column number, got {self.target!r}"
            ) from None
        # Reject 0 and negatives explicitly: names[-1] would otherwise
        # silently select the last column as the target.
        if not 1 <= column_number <= len(names):
            raise ValueError(
                f"Target column {column_number} is out of range; "
                f"the data has {len(names)} columns"
            )
        return names[column_number - 1]

    def evaluate(self, data_path):
        """Run the model on ``data_path`` and, when a target is set, score it.

        Args:
            data_path: delimited text file; ``sep=None`` with the python
                engine lets pandas sniff the delimiter.

        Returns:
            ``(predictions, metrics, plot_paths)``. Without a target,
            ``metrics`` is ``None`` and ``plot_paths`` is empty. With a
            target, ``metrics`` is the frame returned by ``exp.pull()``,
            ``plot_paths`` maps each plot that rendered to its saved path,
            and an HTML report is written via ``generate_html_report``.

        Raises:
            ValueError: if the target column number is invalid for the data.
        """
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        exp = RegressionExperiment()

        if not self.target:
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)
            return predictions, metrics, plot_paths

        target_name = self._target_name(data)
        # test_data=data makes the whole input PyCaret's hold-out set, so
        # predict_model() without data= scores these rows.
        exp.setup(data, target=target_name, test_data=data, index=False)
        predictions = exp.predict_model(self.model)
        metrics = exp.pull()

        for plot_name in self.PLOTS:
            try:
                plot_paths[plot_name] = exp.plot_model(
                    self.model, plot=plot_name, save=True)
            except Exception as e:
                # Best-effort: many plots are unsupported for some models.
                LOG.error(f"Error generating plot {plot_name}: {e}")

        generate_html_report(plot_paths, metrics)
        return predictions, metrics, plot_paths
128 | |
129 | |
def generate_html_report(plots, metrics, output_path="evaluation_report.html"):
    """Write an HTML evaluation report with a metrics tab and a plots tab.

    Args:
        plots: mapping of plot name -> path of a saved PNG image. Each image
            is inlined as a base64 ``data:`` URI via ``encode_image_to_base64``.
        metrics: DataFrame of metrics (e.g. from PyCaret's ``pull()``), or
            ``None`` when no metrics are available.
        output_path: destination file. Defaults to ``evaluation_report.html``
            in the current working directory.
    """
    plot_sections = []
    for plot_name, plot_path in plots.items():
        encoded_image = encode_image_to_base64(plot_path)
        plot_sections.append(f"""
        <div class="plot">
            <h3>{plot_name.capitalize()}</h3>
            <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
        </div>
        <hr>
        """)
    plots_html = "".join(plot_sections)

    if metrics is None:
        metrics_html = "<p>No metrics available.</p>"
    else:
        # to_html() already emits a complete <table> element (with the
        # "table" class added), so it must not be wrapped in another <table>.
        metrics_html = metrics.to_html(index=False, classes="table")

    html_content = f"""
    {get_html_template()}
    <h1>Model Evaluation Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'metrics')">Metrics</div>
        <div class="tab" onclick="openTab(event, 'plots')">Plots</div>
    </div>
    <div id="metrics" class="tab-content">
        <h2>Metrics</h2>
        {metrics_html}
    </div>
    <div id="plots" class="tab-content">
        <h2>Plots</h2>
        {plots_html}
    </div>
    {get_html_closing()}
    """

    # Explicit encoding so the report does not depend on the platform locale.
    with open(output_path, "w", encoding="utf-8") as html_file:
        html_file.write(html_content)
168 | |
169 | |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a PyCaret model stored in HDF5 format.")
    # These three are mandatory: without them the run cannot proceed, and
    # argparse gives a clearer error than a later crash in h5py/pandas.
    parser.add_argument("--model_path",
                        type=str,
                        required=True,
                        help="Path to the HDF5 model file.")
    parser.add_argument("--data_path",
                        type=str,
                        required=True,
                        help="Path to the evaluation data CSV file.")
    parser.add_argument("--task",
                        type=str,
                        required=True,
                        choices=["classification", "regression"],
                        help="Specify the task: classification or regression.")
    parser.add_argument("--target",
                        default=None,
                        help="Column number of the target")
    args = parser.parse_args()

    evaluator_classes = {
        "classification": ClassificationEvaluator,
        "regression": RegressionEvaluator,
    }
    evaluator_cls = evaluator_classes.get(args.task)
    if evaluator_cls is None:
        # Defensive only: argparse `choices` already restricts --task.
        raise ValueError(
            "Unsupported task type. Use 'classification' or 'regression'.")
    evaluator = evaluator_cls(args.model_path, args.task, args.target)

    predictions, metrics, plots = evaluator.evaluate(args.data_path)

    predictions.to_csv("predictions.csv", index=False)