comparison pycaret_predict.py @ 0:1f20fe57fdee draft

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author goeckslab
date Wed, 11 Dec 2024 04:59:43 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:1f20fe57fdee
1 import argparse
2 import logging
3 import tempfile
4
5 import h5py
6
7 import joblib
8
9 import pandas as pd
10
11 from pycaret.classification import ClassificationExperiment
12 from pycaret.regression import RegressionExperiment
13
14 from sklearn.metrics import average_precision_score
15
16 from utils import encode_image_to_base64, get_html_closing, get_html_template
17
# Logger shared by every evaluator and helper in this module, named after the module.
LOG = logging.getLogger(name=__name__)
19
20
class PyCaretModelEvaluator:
    """Base class for evaluating a PyCaret model stored in an HDF5 file.

    Subclasses implement :meth:`evaluate` for a specific task.
    """

    def __init__(self, model_path, task, target):
        """Load the model and remember the evaluation settings.

        Parameters
        ----------
        model_path : str
            Path to an HDF5 file whose ``model`` dataset holds the
            joblib-serialized model bytes.
        task : str
            Task name; stored lower-cased.
        target : str or None
            1-based column number of the target column. ``None`` or the
            literal string ``"None"`` means "no target" (the string form is
            presumably how the tool wrapper passes an unset value - confirm).
        """
        self.model_path = model_path
        self.task = task.lower()
        self.model = self.load_h5_model()
        self.target = target if target != "None" else None

    def load_h5_model(self):
        """Load a PyCaret model from an HDF5 file.

        Reads the raw bytes stored in the ``model`` dataset of
        ``self.model_path`` and deserializes them with :func:`joblib.load`.

        Returns
        -------
        object
            The deserialized model.
        """
        import io

        with h5py.File(self.model_path, 'r') as f:
            model_bytes = bytes(f['model'][()])
        # Deserialize straight from memory. The previous implementation wrote
        # the bytes to a NamedTemporaryFile(delete=False) that was never
        # removed, leaking one temp file per run.
        # NOTE: joblib.load unpickles the data - only load trusted model files.
        return joblib.load(io.BytesIO(model_bytes))

    def evaluate(self, data_path):
        """Evaluate the model using the specified data.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError("Subclasses must implement this method")
41
42
class ClassificationEvaluator(PyCaretModelEvaluator):
    """Evaluate (or just score) a PyCaret classification model."""

    # Plots attempted when a target column is available.
    PLOTS = ['confusion_matrix', 'auc', 'threshold', 'pr',
             'error', 'class_report', 'learning', 'calibration',
             'vc', 'dimension', 'manifold', 'rfe', 'feature',
             'feature_all']

    # For binary problems the 'auc' plot shows only the binary ROC curve.
    BINARY_AUC_KWARGS = {
        'micro': False,
        'macro': False,
        'per_class': False,
        'binary': True,
    }

    def evaluate(self, data_path):
        """Score ``data_path`` with the model and evaluate it if a target is set.

        With a target, metrics (including a weighted PR-AUC) and plots are
        produced and an HTML report is written via ``generate_html_report``.

        Parameters
        ----------
        data_path : str
            Path to a delimited text file; the delimiter is sniffed by pandas
            (``sep=None`` with the python engine).

        Returns
        -------
        tuple
            ``(predictions, metrics, plot_paths)``. ``metrics`` is ``None``
            and ``plot_paths`` is empty when no target column was given.

        Raises
        ------
        ValueError
            If the target is not an integer column number in
            ``1..number_of_columns``.
        """
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            names = data.columns.to_list()
            LOG.debug("Column names: %s", names)
            # The target is a 1-based column number. Reject out-of-range
            # values instead of letting a negative index silently wrap around
            # to a column counted from the end.
            target_index = int(self.target) - 1
            if not 0 <= target_index < len(names):
                raise ValueError(
                    f"Target column {self.target} is out of range; expected "
                    f"a number between 1 and {len(names)}.")
            target_name = names[target_index]

            exp = ClassificationExperiment()
            # The provided data is used as the test set, so metrics describe
            # the model's performance on exactly this file.
            exp.setup(data, target=target_name, test_data=data, index=False)
            exp.add_metric(id='PR-AUC-Weighted',
                           name='PR-AUC-Weighted',
                           target='pred_proba',
                           score_func=average_precision_score,
                           average='weighted')
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()

            for plot_name in self.PLOTS:
                try:
                    extra = {}
                    if plot_name == 'auc' and not exp.is_multiclass:
                        extra['plot_kwargs'] = dict(self.BINARY_AUC_KWARGS)
                    plot_paths[plot_name] = exp.plot_model(
                        self.model, plot=plot_name, save=True, **extra)
                except Exception as e:
                    # Best effort: a plot that cannot be produced for this
                    # model/data is logged and skipped.
                    LOG.error("Error generating plot %s: %s", plot_name, e)
            generate_html_report(plot_paths, metrics)
        else:
            exp = ClassificationExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths
95
96
class RegressionEvaluator(PyCaretModelEvaluator):
    """Evaluate (or just score) a PyCaret regression model."""

    # Plots attempted when a target column is available.
    PLOTS = ['residuals', 'error', 'cooks',
             'learning', 'vc', 'manifold',
             'rfe', 'feature', 'feature_all']

    def evaluate(self, data_path):
        """Score ``data_path`` with the model and evaluate it if a target is set.

        With a target, metrics and plots are produced and an HTML report is
        written via ``generate_html_report``.

        Parameters
        ----------
        data_path : str
            Path to a delimited text file; the delimiter is sniffed by pandas
            (``sep=None`` with the python engine).

        Returns
        -------
        tuple
            ``(predictions, metrics, plot_paths)``. ``metrics`` is ``None``
            and ``plot_paths`` is empty when no target column was given.

        Raises
        ------
        ValueError
            If the target is not an integer column number in
            ``1..number_of_columns``.
        """
        metrics = None
        plot_paths = {}
        data = pd.read_csv(data_path, engine='python', sep=None)
        if self.target:
            names = data.columns.to_list()
            # The target is a 1-based column number. Reject out-of-range
            # values instead of letting a negative index silently wrap around
            # to a column counted from the end.
            target_index = int(self.target) - 1
            if not 0 <= target_index < len(names):
                raise ValueError(
                    f"Target column {self.target} is out of range; expected "
                    f"a number between 1 and {len(names)}.")
            target_name = names[target_index]

            exp = RegressionExperiment()
            # The provided data is used as the test set, so metrics describe
            # the model's performance on exactly this file.
            exp.setup(data, target=target_name, test_data=data, index=False)
            predictions = exp.predict_model(self.model)
            metrics = exp.pull()

            for plot_name in self.PLOTS:
                try:
                    plot_paths[plot_name] = exp.plot_model(
                        self.model, plot=plot_name, save=True)
                except Exception as e:
                    # Best effort: a plot that cannot be produced for this
                    # model/data is logged and skipped.
                    LOG.error("Error generating plot %s: %s", plot_name, e)
            generate_html_report(plot_paths, metrics)
        else:
            exp = RegressionExperiment()
            exp.setup(data, target=None, test_data=data, index=False)
            predictions = exp.predict_model(self.model, data=data)

        return predictions, metrics, plot_paths
128
129
def generate_html_report(plots, metrics, output_path="evaluation_report.html"):
    """Write an HTML evaluation report with a "Metrics" and a "Plots" tab.

    Parameters
    ----------
    plots : dict
        Mapping of plot name to the path of a saved image file. Each image is
        inlined as a base64 ``data:image/png`` URI via
        ``encode_image_to_base64``.
    metrics : pandas.DataFrame or None
        Metrics table shown in the "Metrics" tab. ``None`` renders a short
        placeholder instead of failing.
    output_path : str, optional
        File to write the report to (default ``"evaluation_report.html"``).
    """
    plot_sections = []
    for plot_name, plot_path in plots.items():
        encoded_image = encode_image_to_base64(plot_path)
        plot_sections.append(f"""
        <div class="plot">
            <h3>{plot_name.capitalize()}</h3>
            <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
        </div>
        <hr>
        """)
    plots_html = "".join(plot_sections)

    # DataFrame.to_html already emits a complete <table> element, so it is
    # inserted directly; wrapping it in another <table> produced invalid
    # nested-table markup.
    if metrics is None:
        metrics_html = "<p>No metrics available.</p>"
    else:
        metrics_html = metrics.to_html(index=False, classes="table")

    html_content = f"""
    {get_html_template()}
    <h1>Model Evaluation Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'metrics')">Metrics</div>
        <div class="tab" onclick="openTab(event, 'plots')">Plots</div>
    </div>
    <div id="metrics" class="tab-content">
        <h2>Metrics</h2>
        {metrics_html}
    </div>
    <div id="plots" class="tab-content">
        <h2>Plots</h2>
        {plots_html}
    </div>
    {get_html_closing()}
    """

    # Explicit encoding so the report does not depend on the platform locale.
    with open(output_path, "w", encoding="utf-8") as html_file:
        html_file.write(html_content)
168
169
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a PyCaret model stored in HDF5 format.")
    # The script cannot run without these, so let argparse report a clear
    # usage error instead of failing later on a None path/task.
    parser.add_argument("--model_path",
                        type=str,
                        required=True,
                        help="Path to the HDF5 model file.")
    parser.add_argument("--data_path",
                        type=str,
                        required=True,
                        help="Path to the evaluation data CSV file.")
    parser.add_argument("--task",
                        type=str,
                        required=True,
                        choices=["classification", "regression"],
                        help="Specify the task: classification or regression.")
    parser.add_argument("--target",
                        default=None,
                        help="Column number of the target")
    args = parser.parse_args()

    # Map each supported task to its evaluator class.
    evaluator_classes = {
        "classification": ClassificationEvaluator,
        "regression": RegressionEvaluator,
    }
    evaluator_cls = evaluator_classes.get(args.task)
    if evaluator_cls is None:
        raise ValueError(
            "Unsupported task type. Use 'classification' or 'regression'.")
    evaluator = evaluator_cls(args.model_path, args.task, args.target)

    predictions, metrics, plots = evaluator.evaluate(args.data_path)

    predictions.to_csv("predictions.csv", index=False)