comparison base_model_trainer.py @ 0:915447b14520 draft

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author goeckslab
date Wed, 11 Dec 2024 05:00:00 +0000
parents
children 02f7746e7772
comparison
equal deleted inserted replaced
-1:000000000000 0:915447b14520
1 import base64
2 import logging
3 import os
4 import tempfile
5
6 from feature_importance import FeatureImportanceAnalyzer
7
8 import h5py
9
10 import joblib
11
12 import numpy as np
13
14 import pandas as pd
15
16 from sklearn.metrics import average_precision_score
17
18 from utils import get_html_closing, get_html_template
19
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:
    """Shared PyCaret training pipeline.

    ``run()`` loads the data, configures the experiment, compares models,
    saves the best one and writes an HTML report.  Subclasses are expected
    to assign ``self.exp`` (the PyCaret experiment object) and implement
    ``generate_plots``, ``generate_plots_explainer`` and ``save_dashboard``.
    """

    def __init__(
        self,
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file=None,
        **kwargs
    ):
        """Store configuration; no data is read here.

        Parameters
        ----------
        input_file : str
            Path to the training table (the delimiter is sniffed).
        target_col : int or str
            1-based index of the target column.
        output_dir : str
            Directory that receives the CSV and HTML outputs.
        task_type : str
            ``"classification"`` enables the PR-AUC metric and the
            ``AUC`` -> ``ROC-AUC`` column rename.
        random_seed : int
            Forwarded to PyCaret as ``session_id``.
        test_file : str, optional
            Path to a separate test table.
        **kwargs
            Optional settings (e.g. ``train_size``, ``normalize``,
            ``models``, ``missing_value_strategy``); each is stored as an
            attribute of the same name.
        """
        self.exp = None  # This will be set in the subclass
        self.input_file = input_file
        self.target_col = target_col
        self.output_dir = output_dir
        self.task_type = task_type
        self.random_seed = random_seed
        self.data = None
        self.target = None
        self.best_model = None
        self.results = None
        self.test_result_df = None
        self.features_name = None
        self.plots = {}
        # NOTE: the misspelled attribute name is kept so existing
        # subclasses that reference it keep working.
        self.expaliner = None
        self.plots_explainer_html = None
        self.trees = []
        for key, value in kwargs.items():
            setattr(self, key, value)
        # Assigned after kwargs so these cannot be overridden by them.
        self.setup_params = {}
        self.test_file = test_file
        self.test_data = None

        LOG.info(f"Model kwargs: {self.__dict__}")

    def load_data(self):
        """Read the training table (and the optional test table).

        Dots in column names become underscores, the target and feature
        names are resolved from the 1-based ``target_col``, and missing
        training values are handled per ``missing_value_strategy``
        ('mean', 'median' or 'drop'; median imputation when the attribute
        is absent, nothing for any other value).
        """
        LOG.info(f"Loading data from {self.input_file}")
        self.data = pd.read_csv(self.input_file, sep=None, engine='python')
        # regex=False: '.' must be replaced literally, not as "any char".
        self.data.columns = self.data.columns.str.replace(
            '.', '_', regex=False)

        numeric_cols = self.data.select_dtypes(include=['number']).columns
        non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns

        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors='coerce')

        if len(non_numeric_cols) > 0:
            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")

        names = self.data.columns.to_list()
        target_index = int(self.target_col) - 1
        self.target = names[target_index]
        self.features_name = [
            name for i, name in enumerate(names) if i != target_index
        ]

        if hasattr(self, 'missing_value_strategy'):
            if self.missing_value_strategy == 'mean':
                self.data = self.data.fillna(
                    self.data.mean(numeric_only=True))
            elif self.missing_value_strategy == 'median':
                self.data = self.data.fillna(
                    self.data.median(numeric_only=True))
            elif self.missing_value_strategy == 'drop':
                self.data = self.data.dropna()
        else:
            # Default strategy if not specified
            self.data = self.data.fillna(self.data.median(numeric_only=True))

        if self.test_file:
            LOG.info(f"Loading test data from {self.test_file}")
            self.test_data = pd.read_csv(
                self.test_file, sep=None, engine='python')
            # Rename first so the names match the (renamed) training columns.
            self.test_data.columns = self.test_data.columns.str.replace(
                '.', '_', regex=False)
            # Coerce only the columns that are numeric in the training data,
            # keeping every other column (categorical features, a string
            # target) so the test set has the same schema as training.
            shared_numeric = [
                col for col in numeric_cols if col in self.test_data.columns
            ]
            if shared_numeric:
                self.test_data[shared_numeric] = \
                    self.test_data[shared_numeric].apply(
                        pd.to_numeric, errors='coerce')

    def setup_pycaret(self):
        """Build ``self.setup_params`` and call ``self.exp.setup``.

        Optional attributes are forwarded only when present and not None;
        ``train_size`` is ignored when a separate test set was loaded.
        """
        LOG.info("Initializing PyCaret")
        self.setup_params = {
            'target': self.target,
            'session_id': self.random_seed,
            'html': True,
            'log_experiment': False,
            'system_log': False,
            'index': False,
        }

        if self.test_data is not None:
            self.setup_params['test_data'] = self.test_data
        elif getattr(self, 'train_size', None) is not None:
            # A split ratio only applies when no explicit test set exists.
            self.setup_params['train_size'] = self.train_size

        for name in ('normalize', 'feature_selection'):
            value = getattr(self, name, None)
            if value is not None:
                self.setup_params[name] = value

        cross_validation = getattr(self, 'cross_validation', None)
        if cross_validation is False:
            # NOTE(review): confirm the installed PyCaret version's setup()
            # accepts a ``cross_validation`` keyword.
            self.setup_params['cross_validation'] = cross_validation
        if cross_validation is not None and \
                hasattr(self, 'cross_validation_folds'):
            self.setup_params['fold'] = self.cross_validation_folds

        for name in ('remove_outliers', 'remove_multicollinearity',
                     'polynomial_features', 'fix_imbalance'):
            value = getattr(self, name, None)
            if value is not None:
                self.setup_params[name] = value

        LOG.info(self.setup_params)
        self.exp.setup(self.data, **self.setup_params)

    def train_model(self):
        """Compare candidate models, keep the best and score it.

        Sets ``self.best_model``, ``self.results`` (the model comparison
        table) and ``self.test_result_df`` (metrics from ``predict_model``
        on the best model).  For classification a weighted PR-AUC metric
        is registered and ``AUC`` columns are renamed to ``ROC-AUC``.
        """
        LOG.info("Training and selecting the best model")
        is_classification = self.task_type == "classification"
        auc_rename = {'AUC': 'ROC-AUC'}

        if is_classification:
            average_displayed = "Weighted"
            self.exp.add_metric(
                id=f'PR-AUC-{average_displayed}',
                name=f'PR-AUC-{average_displayed}',
                target='pred_proba',
                score_func=average_precision_score,
                average='weighted',
            )

        models = getattr(self, 'models', None)
        if models is not None:
            self.best_model = self.exp.compare_models(include=models)
        else:
            self.best_model = self.exp.compare_models()
        self.results = self.exp.pull()
        if is_classification:
            self.results.rename(columns=auc_rename, inplace=True)

        self.exp.predict_model(self.best_model)
        self.test_result_df = self.exp.pull()
        if is_classification:
            self.test_result_df.rename(columns=auc_rename, inplace=True)

    def save_model(self):
        """Serialize ``best_model`` with joblib into ``pycaret_model.h5``.

        The pickled bytes are stored as an opaque dataset named ``model``.
        The path is relative, so the file lands in the current working
        directory.
        """
        hdf5_model_path = "pycaret_model.h5"
        # Serialize first so a failing dump never leaves a truncated .h5;
        # TemporaryFile is removed automatically when closed.
        with tempfile.TemporaryFile() as temp_file:
            joblib.dump(self.best_model, temp_file)
            temp_file.seek(0)
            model_bytes = temp_file.read()
        with h5py.File(hdf5_model_path, 'w') as f:
            f.create_dataset('model', data=np.void(model_bytes))

    def generate_plots(self):
        raise NotImplementedError("Subclasses should implement this method")

    def encode_image_to_base64(self, img_path):
        """Return the contents of *img_path* as a base64 text string."""
        with open(img_path, 'rb') as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    def save_html_report(self):
        """Write CSV summaries and a tabbed HTML report to ``output_dir``.

        Produces ``best_model.csv``, ``comparison_results.csv``,
        ``test_results.csv`` and ``comparison_result.html``.
        """
        LOG.info("Saving HTML report")

        model_name = type(self.best_model).__name__
        excluded_params = ['html', 'log_experiment', 'system_log', 'test_data']
        filtered_setup_params = {
            k: v
            for k, v in self.setup_params.items() if k not in excluded_params
        }
        setup_params_table = pd.DataFrame(
            list(filtered_setup_params.items()),
            columns=['Parameter', 'Value'])

        best_model_params = pd.DataFrame(
            list(self.best_model.get_params().items()),
            columns=['Parameter', 'Value'])
        best_model_params.to_csv(
            os.path.join(self.output_dir, 'best_model.csv'),
            index=False)
        self.results.to_csv(os.path.join(
            self.output_dir, "comparison_results.csv"))
        self.test_result_df.to_csv(os.path.join(
            self.output_dir, "test_results.csv"))

        plot_sections = []
        for plot_name, plot_path in self.plots.items():
            encoded_image = self.encode_image_to_base64(plot_path)
            plot_sections.append(f"""
            <div class="plot">
                <h3>{plot_name.capitalize()}</h3>
                <img src="data:image/png;base64,{encoded_image}"
                    alt="{plot_name}">
            </div>
            """)
        plots_html = "<hr>".join(plot_sections)

        # Numbering follows the position in self.trees; empty entries are
        # skipped.  The data URI is kept on one line.
        tree_plots = "".join(
            f"""
            <div class="plot">
                <h3>Tree {i + 1}</h3>
                <img src="data:image/png;base64,{tree}"
                    alt="tree {i + 1}">
            </div>
            """
            for i, tree in enumerate(self.trees) if tree
        )

        analyzer = FeatureImportanceAnalyzer(
            data=self.data,
            target_col=self.target_col,
            task_type=self.task_type,
            output_dir=self.output_dir)
        feature_importance_html = analyzer.run()

        # Avoid rendering the literal text "None" when no explainer
        # output was produced.
        explainer_html = self.plots_explainer_html or ''

        # DataFrame.to_html() emits its own <table>; wrapping it in another
        # <table> would make the browser close the outer one early.
        html_content = f"""
        {get_html_template()}
            <h1>PyCaret Model Training Report</h1>
            <div class="tabs">
                <div class="tab" onclick="openTab(event, 'summary')">
                Setup & Best Model</div>
                <div class="tab" onclick="openTab(event, 'plots')">
                Best Model Plots</div>
                <div class="tab" onclick="openTab(event, 'feature')">
                Feature Importance</div>
                <div class="tab" onclick="openTab(event, 'explainer')">
                Explainer
                </div>
            </div>
            <div id="summary" class="tab-content">
                <h2>Setup Parameters</h2>
                {setup_params_table.to_html(index=False, classes='table')}
                <h5>If you want to know all the experiment setup parameters,
                please check the PyCaret documentation for
                the classification/regression <code>exp</code> function.</h5>
                <h2>Best Model: {model_name}</h2>
                {best_model_params.to_html(index=False, classes='table')}
                <h2>Comparison Results on the Cross-Validation Set</h2>
                {self.results.to_html(index=False, classes='table')}
                <h2>Results on the Test Set for the best model</h2>
                {self.test_result_df.to_html(index=False, classes='table')}
            </div>
            <div id="plots" class="tab-content">
                <h2>Best Model Plots on the testing set</h2>
                {plots_html}
            </div>
            <div id="feature" class="tab-content">
                {feature_importance_html}
            </div>
            <div id="explainer" class="tab-content">
                {explainer_html}
                {tree_plots}
            </div>
        {get_html_closing()}
        """

        with open(os.path.join(
                self.output_dir, "comparison_result.html"), "w") as file:
            file.write(html_content)

    def save_dashboard(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_plots_explainer(self):
        raise NotImplementedError("Subclasses should implement this method")

    # not working now
    def generate_tree_plots(self):
        """Best-effort: render one tree image per estimator into ``self.trees``.

        Only random-forest (sklearn) and XGBoost best models are handled;
        any other model is skipped.  Missing optional packages and runtime
        failures are logged instead of raised.
        """
        LOG.info("Generating tree plots")
        try:
            from sklearn.ensemble import RandomForestClassifier, \
                RandomForestRegressor
            from explainerdashboard.explainers import RandomForestExplainer
        except ImportError as e:
            LOG.error(f"Error generating tree plots: {e}")
            return

        try:
            from xgboost import XGBClassifier, XGBRegressor
            xgb_types = (XGBClassifier, XGBRegressor)
        except ImportError:
            xgb_types = ()  # isinstance(x, ()) is always False

        is_rf = isinstance(
            self.best_model, (RandomForestClassifier, RandomForestRegressor))
        is_xgb = isinstance(self.best_model, xgb_types)
        if not (is_rf or is_xgb):
            LOG.info("Best model is not a random forest or XGBoost model; "
                     "skipping tree plots")
            return

        try:
            X_test = self.exp.X_test_transformed.copy()
            y_test = self.exp.y_test_transformed
            if is_rf:
                num_trees = self.best_model.n_estimators
            else:
                num_trees = len(self.best_model.get_booster().get_dump())
            # NOTE(review): confirm RandomForestExplainer supports XGBoost
            # models in the installed explainerdashboard version.
            explainer = RandomForestExplainer(self.best_model, X_test, y_test)
            for i in range(num_trees):
                fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
                LOG.info(f"Tree {i+1}")
                LOG.info(fig)
                self.trees.append(fig)
        except Exception as e:
            LOG.error(f"Error generating tree plots: {e}")

    def run(self):
        """Execute the full pipeline from loading to the HTML report."""
        self.load_data()
        self.setup_pycaret()
        self.train_model()
        self.save_model()
        self.generate_plots()
        self.generate_plots_explainer()
        self.generate_tree_plots()
        self.save_html_report()
        # self.save_dashboard()