comparison base_model_trainer.py @ 4:4aa511539199 draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit cf47efb521b91a9cb44ae5c5ade860627f9b9030
author goeckslab
date Tue, 03 Jun 2025 19:31:16 +0000
parents 02f7746e7772
children
comparison
equal deleted inserted replaced
3:02f7746e7772 4:4aa511539199
1 import base64 1 import base64
2 import logging 2 import logging
3 import os 3 import os
4 import tempfile 4 import tempfile
5 5
6 import h5py
7 import joblib
8 import numpy as np
9 import pandas as pd
6 from feature_importance import FeatureImportanceAnalyzer 10 from feature_importance import FeatureImportanceAnalyzer
7
8 import h5py
9
10 import joblib
11
12 import numpy as np
13
14 import pandas as pd
15
16 from sklearn.metrics import average_precision_score 11 from sklearn.metrics import average_precision_score
17
18 from utils import get_html_closing, get_html_template 12 from utils import get_html_closing, get_html_template
19 13
20 logging.basicConfig(level=logging.DEBUG) 14 logging.basicConfig(level=logging.DEBUG)
21 LOG = logging.getLogger(__name__) 15 LOG = logging.getLogger(__name__)
22 16
29 target_col, 23 target_col,
30 output_dir, 24 output_dir,
31 task_type, 25 task_type,
32 random_seed, 26 random_seed,
33 test_file=None, 27 test_file=None,
34 **kwargs 28 **kwargs):
35 ):
36 self.exp = None # This will be set in the subclass 29 self.exp = None # This will be set in the subclass
37 self.input_file = input_file 30 self.input_file = input_file
38 self.target_col = target_col 31 self.target_col = target_col
39 self.output_dir = output_dir 32 self.output_dir = output_dir
40 self.task_type = task_type 33 self.task_type = task_type
69 62
70 if len(non_numeric_cols) > 0: 63 if len(non_numeric_cols) > 0:
71 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}") 64 LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
72 65
73 names = self.data.columns.to_list() 66 names = self.data.columns.to_list()
74 target_index = int(self.target_col)-1 67 target_index = int(self.target_col) - 1
75 self.target = names[target_index] 68 self.target = names[target_index]
76 self.features_name = [name 69 self.features_name = [name
77 for i, name in enumerate(names) 70 for i, name in enumerate(names)
78 if i != target_index] 71 if i != target_index]
79 if hasattr(self, 'missing_value_strategy'): 72 if hasattr(self, 'missing_value_strategy'):
95 self.test_file, sep=None, engine='python') 88 self.test_file, sep=None, engine='python')
96 self.test_data = self.test_data[numeric_cols].apply( 89 self.test_data = self.test_data[numeric_cols].apply(
97 pd.to_numeric, errors='coerce') 90 pd.to_numeric, errors='coerce')
98 self.test_data.columns = self.test_data.columns.str.replace( 91 self.test_data.columns = self.test_data.columns.str.replace(
99 '.', '_' 92 '.', '_'
100 ) 93 )
101 94
102 def setup_pycaret(self): 95 def setup_pycaret(self):
103 LOG.info("Initializing PyCaret") 96 LOG.info("Initializing PyCaret")
104 self.setup_params = { 97 self.setup_params = {
105 'target': self.target, 98 'target': self.target,
204 filtered_setup_params = { 197 filtered_setup_params = {
205 k: v 198 k: v
206 for k, v in self.setup_params.items() if k not in excluded_params 199 for k, v in self.setup_params.items() if k not in excluded_params
207 } 200 }
208 setup_params_table = pd.DataFrame( 201 setup_params_table = pd.DataFrame(
209 list(filtered_setup_params.items()), 202 list(filtered_setup_params.items()), columns=['Parameter', 'Value']
210 columns=['Parameter', 'Value']) 203 )
211 204
212 best_model_params = pd.DataFrame( 205 best_model_params = pd.DataFrame(
213 self.best_model.get_params().items(), 206 self.best_model.get_params().items(),
214 columns=['Parameter', 'Value']) 207 columns=['Parameter', 'Value']
208 )
215 best_model_params.to_csv( 209 best_model_params.to_csv(
216 os.path.join(self.output_dir, 'best_model.csv'), 210 os.path.join(self.output_dir, "best_model.csv"), index=False
217 index=False) 211 )
218 self.results.to_csv(os.path.join( 212 self.results.to_csv(
219 self.output_dir, "comparison_results.csv")) 213 os.path.join(self.output_dir, "comparison_results.csv")
220 self.test_result_df.to_csv(os.path.join( 214 )
221 self.output_dir, "test_results.csv")) 215 self.test_result_df.to_csv(
216 os.path.join(self.output_dir, "test_results.csv")
217 )
222 218
223 plots_html = "" 219 plots_html = ""
224 length = len(self.plots) 220 length = len(self.plots)
225 for i, (plot_name, plot_path) in enumerate(self.plots.items()): 221 for i, (plot_name, plot_path) in enumerate(self.plots.items()):
226 encoded_image = self.encode_image_to_base64(plot_path) 222 encoded_image = self.encode_image_to_base64(plot_path)
248 244
249 analyzer = FeatureImportanceAnalyzer( 245 analyzer = FeatureImportanceAnalyzer(
250 data=self.data, 246 data=self.data,
251 target_col=self.target_col, 247 target_col=self.target_col,
252 task_type=self.task_type, 248 task_type=self.task_type,
253 output_dir=self.output_dir) 249 output_dir=self.output_dir,
250 )
254 feature_importance_html = analyzer.run() 251 feature_importance_html = analyzer.run()
255 252
256 html_content = f""" 253 html_content = f"""
257 {get_html_template()} 254 {get_html_template()}
258 <h1>PyCaret Model Training Report</h1> 255 <h1>PyCaret Model Training Report</h1>
261 Setup & Best Model</div> 258 Setup & Best Model</div>
262 <div class="tab" onclick="openTab(event, 'plots')"> 259 <div class="tab" onclick="openTab(event, 'plots')">
263 Best Model Plots</div> 260 Best Model Plots</div>
264 <div class="tab" onclick="openTab(event, 'feature')"> 261 <div class="tab" onclick="openTab(event, 'feature')">
265 Feature Importance</div> 262 Feature Importance</div>
266 """ 263 """
267 if self.plots_explainer_html: 264 if self.plots_explainer_html:
268 html_content += """ 265 html_content += """
269 "<div class="tab" onclick="openTab(event, 'explainer')">" 266 <div class="tab" onclick="openTab(event, 'explainer')">
270 Explainer Plots</div> 267 Explainer Plots</div>
271 """ 268 """
272 html_content += f""" 269 html_content += f"""
273 </div> 270 </div>
274 <div id="summary" class="tab-content"> 271 <div id="summary" class="tab-content">
275 <h2>Setup Parameters</h2> 272 <h2>Setup Parameters</h2>
276 <table> 273 {setup_params_table.to_html(
277 <tr><th>Parameter</th><th>Value</th></tr> 274 index=False,
278 {setup_params_table.to_html( 275 header=True,
279 index=False, header=False, classes='table')} 276 classes='table sortable'
280 </table> 277 )}
281 <h5>If you want to know all the experiment setup parameters, 278 <h5>If you want to know all the experiment setup parameters,
282 please check the PyCaret documentation for 279 please check the PyCaret documentation for
283 the classification/regression <code>exp</code> function.</h5> 280 the classification/regression <code>exp</code> function.</h5>
284 <h2>Best Model: {model_name}</h2> 281 <h2>Best Model: {model_name}</h2>
285 <table> 282 {best_model_params.to_html(
286 <tr><th>Parameter</th><th>Value</th></tr> 283 index=False,
287 {best_model_params.to_html( 284 header=True,
288 index=False, header=False, classes='table')} 285 classes='table sortable'
289 </table> 286 )}
290 <h2>Comparison Results on the Cross-Validation Set</h2> 287 <h2>Comparison Results on the Cross-Validation Set</h2>
291 <table> 288 {self.results.to_html(index=False, classes='table sortable')}
292 {self.results.to_html(index=False, classes='table')}
293 </table>
294 <h2>Results on the Test Set for the best model</h2> 289 <h2>Results on the Test Set for the best model</h2>
295 <table> 290 {self.test_result_df.to_html(
296 {self.test_result_df.to_html(index=False, classes='table')} 291 index=False,
297 </table> 292 classes='table sortable'
293 )}
298 </div> 294 </div>
299 <div id="plots" class="tab-content"> 295 <div id="plots" class="tab-content">
300 <h2>Best Model Plots on the testing set</h2> 296 <h2>Best Model Plots on the testing set</h2>
301 {plots_html} 297 {plots_html}
302 </div> 298 </div>
308 html_content += f""" 304 html_content += f"""
309 <div id="explainer" class="tab-content"> 305 <div id="explainer" class="tab-content">
310 {self.plots_explainer_html} 306 {self.plots_explainer_html}
311 {tree_plots} 307 {tree_plots}
312 </div> 308 </div>
313 {get_html_closing()}
314 """ 309 """
315 else: 310 html_content += """
316 html_content += f""" 311 <script>
317 {get_html_closing()} 312 document.addEventListener("DOMContentLoaded", function() {
318 """ 313 var tables = document.querySelectorAll("table.sortable");
319 with open(os.path.join( 314 tables.forEach(function(table) {
320 self.output_dir, "comparison_result.html"), "w") as file: 315 var headers = table.querySelectorAll("th");
316 headers.forEach(function(header, index) {
317 header.style.cursor = "pointer";
318 // Add initial arrow (up) to indicate sortability
319 header.innerHTML += '<span class="sort-arrow"> ↑</span>';
320 header.addEventListener("click", function() {
321 var direction = this.getAttribute(
322 "data-sort-direction"
323 ) || "asc";
324 // Reset arrows in all headers of this table
325 headers.forEach(function(h) {
326 var arrow = h.querySelector(".sort-arrow");
327 if (arrow) arrow.textContent = " ↑";
328 });
329 // Set arrow for clicked header
330 var arrow = this.querySelector(".sort-arrow");
331 arrow.textContent = direction === "asc" ? " ↓" : " ↑";
332 sortTable(table, index, direction);
333 this.setAttribute("data-sort-direction",
334 direction === "asc" ? "desc" : "asc");
335 });
336 });
337 });
338 });
339
340 function sortTable(table, colNum, direction) {
341 var tb = table.tBodies[0];
342 var tr = Array.prototype.slice.call(tb.rows, 0);
343 var multiplier = direction === "asc" ? 1 : -1;
344 tr = tr.sort(function(a, b) {
345 var aText = a.cells[colNum].textContent.trim();
346 var bText = b.cells[colNum].textContent.trim();
347 // Remove arrow from text comparison
348 aText = aText.replace(/[↑↓]/g, '').trim();
349 bText = bText.replace(/[↑↓]/g, '').trim();
350 if (!isNaN(aText) && !isNaN(bText)) {
351 return multiplier * (
352 parseFloat(aText) - parseFloat(bText)
353 );
354 } else {
355 return multiplier * aText.localeCompare(bText);
356 }
357 });
358 for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);
359 }
360 </script>
361 """
362 html_content += f"""
363 {get_html_closing()}
364 """
365 with open(
366 os.path.join(self.output_dir, "comparison_result.html"),
367 "w"
368 ) as file:
321 file.write(html_content) 369 file.write(html_content)
322 370
323 def save_dashboard(self): 371 def save_dashboard(self):
324 raise NotImplementedError("Subclasses should implement this method") 372 raise NotImplementedError("Subclasses should implement this method")
325 373