goeckslab / pycaret_compare: base_model_trainer.py @ 0:915447b14520 (draft)
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author:   goeckslab
date:     Wed, 11 Dec 2024 05:00:00 +0000
children: 02f7746e7772
import base64
import logging
import os
import tempfile

from feature_importance import FeatureImportanceAnalyzer

import h5py

import joblib

import numpy as np

import pandas as pd

from sklearn.metrics import average_precision_score

from utils import get_html_closing, get_html_template

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs
    ):
        self.exp = None  # This will be set in the subclass
        self.input_file = input_file
        self.target_col = target_col
        self.output_dir = output_dir
        self.task_type = task_type
        self.random_seed = random_seed
        self.data = None
        self.target = None
        self.best_model = None
        self.results = None
        self.features_name = None
        self.plots = {}
        self.explainer = None
        self.plots_explainer_html = None
        self.trees = []
        for key, value in kwargs.items():
            setattr(self, key, value)
        self.setup_params = {}
        self.test_file = test_file
        self.test_data = None

        LOG.info(f"Model kwargs: {self.__dict__}")

    def load_data(self):
        LOG.info(f"Loading data from {self.input_file}")
        self.data = pd.read_csv(self.input_file, sep=None, engine='python')
        # Replace '.' literally (not as a regex wildcard) in column names.
        self.data.columns = self.data.columns.str.replace(
            '.', '_', regex=False)

        numeric_cols = self.data.select_dtypes(include=['number']).columns
        non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns

        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors='coerce')

        if len(non_numeric_cols) > 0:
            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

        names = self.data.columns.to_list()
        target_index = int(self.target_col) - 1
        self.target = names[target_index]
        self.features_name = [name
                              for i, name in enumerate(names)
                              if i != target_index]
        if hasattr(self, 'missing_value_strategy'):
            if self.missing_value_strategy == 'mean':
                self.data = self.data.fillna(
                    self.data.mean(numeric_only=True))
            elif self.missing_value_strategy == 'median':
                self.data = self.data.fillna(
                    self.data.median(numeric_only=True))
            elif self.missing_value_strategy == 'drop':
                self.data = self.data.dropna()
        else:
            # Default strategy if not specified
            self.data = self.data.fillna(self.data.median(numeric_only=True))

        if self.test_file:
            LOG.info(f"Loading test data from {self.test_file}")
            self.test_data = pd.read_csv(
                self.test_file, sep=None, engine='python')
            # Normalize the column names first so they line up with the
            # (already renamed) training columns, then coerce the numeric
            # columns in place rather than dropping the remaining columns.
            self.test_data.columns = self.test_data.columns.str.replace(
                '.', '_', regex=False)
            self.test_data[numeric_cols] = self.test_data[numeric_cols].apply(
                pd.to_numeric, errors='coerce')

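    # Illustrative example (hypothetical data, assuming a three-column CSV):
    #
    #   columns             -> ['age', 'dose', 'response']
    #   target_col          -> 3            # 1-based column index
    #   self.target         -> 'response'
    #   self.features_name  -> ['age', 'dose']
    #
    # Missing values fall back to the per-column median unless a
    # 'missing_value_strategy' of 'mean', 'median', or 'drop' is passed
    # through **kwargs.
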
    def setup_pycaret(self):
        LOG.info("Initializing PyCaret")
        self.setup_params = {
            'target': self.target,
            'session_id': self.random_seed,
            'html': True,
            'log_experiment': False,
            'system_log': False,
            'index': False,
        }

        if self.test_data is not None:
            self.setup_params['test_data'] = self.test_data

        if hasattr(self, 'train_size') and self.train_size is not None \
                and self.test_data is None:
            self.setup_params['train_size'] = self.train_size

        if hasattr(self, 'normalize') and self.normalize is not None:
            self.setup_params['normalize'] = self.normalize

        if hasattr(self, 'feature_selection') and \
                self.feature_selection is not None:
            self.setup_params['feature_selection'] = self.feature_selection

        # Forward cross_validation only when it was explicitly disabled.
        if hasattr(self, 'cross_validation') and \
                self.cross_validation is False:
            self.setup_params['cross_validation'] = self.cross_validation

        if hasattr(self, 'cross_validation') and \
                self.cross_validation is not None:
            if hasattr(self, 'cross_validation_folds'):
                self.setup_params['fold'] = self.cross_validation_folds

        if hasattr(self, 'remove_outliers') and \
                self.remove_outliers is not None:
            self.setup_params['remove_outliers'] = self.remove_outliers

        if hasattr(self, 'remove_multicollinearity') and \
                self.remove_multicollinearity is not None:
            self.setup_params['remove_multicollinearity'] = \
                self.remove_multicollinearity

        if hasattr(self, 'polynomial_features') and \
                self.polynomial_features is not None:
            self.setup_params['polynomial_features'] = self.polynomial_features

        if hasattr(self, 'fix_imbalance') and \
                self.fix_imbalance is not None:
            self.setup_params['fix_imbalance'] = self.fix_imbalance

        LOG.info(self.setup_params)
        self.exp.setup(self.data, **self.setup_params)

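    # Minimal sketch (hypothetical values) of how constructor **kwargs flow
    # into setup_params: a trainer built with normalize=True,
    # cross_validation=True, cross_validation_folds=5 and random_seed=42
    # would end up calling
    #
    #   self.exp.setup(
    #       self.data,
    #       target=self.target,
    #       session_id=42,
    #       html=True,
    #       log_experiment=False,
    #       system_log=False,
    #       index=False,
    #       normalize=True,
    #       fold=5,
    #   )
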
    def train_model(self):
        LOG.info("Training and selecting the best model")
        if self.task_type == "classification":
            average_displayed = "Weighted"
            self.exp.add_metric(id=f'PR-AUC-{average_displayed}',
                                name=f'PR-AUC-{average_displayed}',
                                target='pred_proba',
                                score_func=average_precision_score,
                                average='weighted'
                                )

        if hasattr(self, 'models') and self.models is not None:
            self.best_model = self.exp.compare_models(
                include=self.models)
        else:
            self.best_model = self.exp.compare_models()
        self.results = self.exp.pull()
        if self.task_type == "classification":
            self.results.rename(columns={'AUC': 'ROC-AUC'}, inplace=True)

        _ = self.exp.predict_model(self.best_model)
        self.test_result_df = self.exp.pull()
        if self.task_type == "classification":
            self.test_result_df.rename(
                columns={'AUC': 'ROC-AUC'}, inplace=True)

    def save_model(self):
        hdf5_model_path = "pycaret_model.h5"
        with h5py.File(hdf5_model_path, 'w') as f:
            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                joblib.dump(self.best_model, temp_file.name)
                temp_file.seek(0)
                model_bytes = temp_file.read()
            # Remove the temporary joblib dump once its bytes are in memory.
            os.remove(temp_file.name)
            f.create_dataset('model', data=np.void(model_bytes))

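    # A minimal sketch (not a method of this class) of how the serialized
    # model could be read back from the HDF5 container written above:
    #
    #   import io
    #   with h5py.File("pycaret_model.h5", "r") as f:
    #       model_bytes = f['model'][()].tobytes()
    #   restored_model = joblib.load(io.BytesIO(model_bytes))
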
    def generate_plots(self):
        raise NotImplementedError("Subclasses should implement this method")

    def encode_image_to_base64(self, img_path):
        with open(img_path, 'rb') as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    def save_html_report(self):
        LOG.info("Saving HTML report")

        model_name = type(self.best_model).__name__
        excluded_params = ['html', 'log_experiment', 'system_log', 'test_data']
        filtered_setup_params = {
            k: v
            for k, v in self.setup_params.items() if k not in excluded_params
        }
        setup_params_table = pd.DataFrame(
            list(filtered_setup_params.items()),
            columns=['Parameter', 'Value'])

        best_model_params = pd.DataFrame(
            self.best_model.get_params().items(),
            columns=['Parameter', 'Value'])
        best_model_params.to_csv(
            os.path.join(self.output_dir, 'best_model.csv'),
            index=False)
        self.results.to_csv(os.path.join(
            self.output_dir, "comparison_results.csv"))
        self.test_result_df.to_csv(os.path.join(
            self.output_dir, "test_results.csv"))

        plots_html = ""
        length = len(self.plots)
        for i, (plot_name, plot_path) in enumerate(self.plots.items()):
            encoded_image = self.encode_image_to_base64(plot_path)
            plots_html += f"""
            <div class="plot">
                <h3>{plot_name.capitalize()}</h3>
                <img src="data:image/png;base64,{encoded_image}"
                    alt="{plot_name}">
            </div>
            """
            if i < length - 1:
                plots_html += "<hr>"

        tree_plots = ""
        for i, tree in enumerate(self.trees):
            if tree:
                # Keep the data URI on a single line so no whitespace is
                # injected into the base64 payload.
                tree_plots += f"""
                <div class="plot">
                    <h3>Tree {i + 1}</h3>
                    <img src="data:image/png;base64,{tree}"
                        alt="tree {i + 1}">
                </div>
                """

        analyzer = FeatureImportanceAnalyzer(
            data=self.data,
            target_col=self.target_col,
            task_type=self.task_type,
            output_dir=self.output_dir)
        feature_importance_html = analyzer.run()

        html_content = f"""
        {get_html_template()}
        <h1>PyCaret Model Training Report</h1>
        <div class="tabs">
            <div class="tab" onclick="openTab(event, 'summary')">
                Setup & Best Model</div>
            <div class="tab" onclick="openTab(event, 'plots')">
                Best Model Plots</div>
            <div class="tab" onclick="openTab(event, 'feature')">
                Feature Importance</div>
            <div class="tab" onclick="openTab(event, 'explainer')">
                Explainer
            </div>
        </div>
        <div id="summary" class="tab-content">
            <h2>Setup Parameters</h2>
            <table>
                <tr><th>Parameter</th><th>Value</th></tr>
                {setup_params_table.to_html(
                    index=False, header=False, classes='table')}
            </table>
            <h5>For the full list of experiment setup parameters, see the
                PyCaret documentation for the classification/regression
                experiment <code>setup</code> function.</h5>
            <h2>Best Model: {model_name}</h2>
            <table>
                <tr><th>Parameter</th><th>Value</th></tr>
                {best_model_params.to_html(
                    index=False, header=False, classes='table')}
            </table>
            <h2>Comparison Results on the Cross-Validation Set</h2>
            <table>
                {self.results.to_html(index=False, classes='table')}
            </table>
            <h2>Results on the Test Set for the Best Model</h2>
            <table>
                {self.test_result_df.to_html(index=False, classes='table')}
            </table>
        </div>
        <div id="plots" class="tab-content">
            <h2>Best Model Plots on the Test Set</h2>
            {plots_html}
        </div>
        <div id="feature" class="tab-content">
            {feature_importance_html}
        </div>
        <div id="explainer" class="tab-content">
            {self.plots_explainer_html}
            {tree_plots}
        </div>
        {get_html_closing()}
        """

        with open(os.path.join(
                self.output_dir, "comparison_result.html"), "w") as file:
            file.write(html_content)

    def save_dashboard(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_plots_explainer(self):
        raise NotImplementedError("Subclasses should implement this method")

    # not working now
    def generate_tree_plots(self):
        from sklearn.ensemble import RandomForestClassifier, \
            RandomForestRegressor
        from xgboost import XGBClassifier, XGBRegressor
        from explainerdashboard.explainers import RandomForestExplainer

        LOG.info("Generating tree plots")
        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        is_rf = isinstance(self.best_model,
                           (RandomForestClassifier, RandomForestRegressor))

        is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))

        if not (is_rf or is_xgb):
            LOG.warning("Tree plots are only generated for Random Forest "
                        "and XGBoost models")
            return

        try:
            if is_rf:
                num_trees = self.best_model.n_estimators
            if is_xgb:
                num_trees = len(self.best_model.get_booster().get_dump())
            explainer = RandomForestExplainer(self.best_model, X_test, y_test)
            for i in range(num_trees):
                fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
                LOG.info(f"Tree {i + 1}")
                LOG.info(fig)
                self.trees.append(fig)
        except Exception as e:
            LOG.error(f"Error generating tree plots: {e}")

    def run(self):
        self.load_data()
        self.setup_pycaret()
        self.train_model()
        self.save_model()
        self.generate_plots()
        self.generate_plots_explainer()
        self.generate_tree_plots()
        self.save_html_report()
        # self.save_dashboard()
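

# Usage sketch (illustrative only): BaseModelTrainer is abstract, so it is
# driven through a subclass that sets self.exp to a PyCaret experiment and
# implements the plotting hooks. The subclass and file names below are
# hypothetical.
#
#   trainer = ClassificationModelTrainer(
#       input_file="train.csv",
#       target_col=5,                 # 1-based index of the target column
#       output_dir="results",
#       task_type="classification",
#       random_seed=42,
#       normalize=True,               # forwarded to PyCaret setup()
#   )
#   trainer.run()                     # writes pycaret_model.h5 plus
#                                     # best_model.csv, comparison_results.csv,
#                                     # test_results.csv and
#                                     # comparison_result.html in output_dir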