comparison base_model_trainer.py @ 10:e2a6fed32d54 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 47a5977e074223e92e216efa42969a4056516707
author goeckslab
date Fri, 01 Aug 2025 14:02:26 +0000
parents c6c1f8777aae
children
comparing 9:c6c1f8777aae with 10:e2a6fed32d54
        self.random_seed = random_seed
        self.data = None
        self.target = None
        self.best_model = None
        self.results = None
        self.tuning_results = None
        self.features_name = None
        self.plots = {}
        self.explainer_plots = {}
        self.plots_explainer_html = None
        self.trees = []
        # ... unchanged lines omitted ...
        self.setup_params = {}
        self.test_file = test_file
        self.test_data = None

        if not self.output_dir:
            raise ValueError(
                "output_dir must be specified and not None"
            )

        # Warn about irrelevant kwargs for the task type
        if self.task_type == "regression" and (
            "probability_threshold" in self.user_kwargs
        ):
            LOG.warning(
                "probability_threshold is ignored for regression tasks."
            )

        LOG.info(f"Model kwargs: {self.__dict__}")
    def load_data(self):
        LOG.info(f"Loading data from {self.input_file}")
        self.data = pd.read_csv(
            self.input_file, sep=None, engine="python"
        )
        self.data.columns = self.data.columns.str.replace(".", "_")

        names = self.data.columns.to_list()
        LOG.info(f"Original dataset columns: {names}")

        target_index = int(self.target_col) - 1
        num_cols = len(names)
        if target_index < 0 or target_index >= num_cols:
            raise ValueError(
                f"Target column number {self.target_col} is invalid. "
                f"Please select a number between 1 and {num_cols}."
            )

        self.target = names[target_index]

        # Conditional drop: only if 'prediction_label' exists and is not
        # the target
        if "prediction_label" in self.data.columns and (
            self.data.columns[target_index] != "prediction_label"
        ):
            LOG.info(
                "Dropping 'prediction_label' column as it's not the target."
            )
            self.data = self.data.drop(columns=["prediction_label"])
        else:
            if self.target == "prediction_label":
                LOG.warning(
                    "Using 'prediction_label' as target column. "
                    "This may not be intended if it's a previous prediction."
                )

        numeric_cols = self.data.select_dtypes(
            include=["number"]
        ).columns
        non_numeric_cols = self.data.select_dtypes(
            exclude=["number"]
        ).columns
        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors="coerce"
        )
        if len(non_numeric_cols) > 0:
            LOG.info(
                f"Non-numeric columns found: {non_numeric_cols.tolist()}"
            )

        # Update names after possible drop
        names = self.data.columns.to_list()
        LOG.info(f"Dataset columns after processing: {names}")

        self.features_name = [n for n in names if n != self.target]

        if getattr(self, "missing_value_strategy", None):
            strat = self.missing_value_strategy
            if strat == "mean":
                self.data = self.data.fillna(
                    self.data.mean(numeric_only=True)
                )
            elif strat == "median":
                self.data = self.data.fillna(
                    self.data.median(numeric_only=True)
                )
            elif strat == "drop":
                self.data = self.data.dropna()
            else:
                self.data = self.data.fillna(
                    self.data.median(numeric_only=True)
                )

        if self.test_file:
            LOG.info(f"Loading test data from {self.test_file}")
            df_test = pd.read_csv(
                self.test_file, sep=None, engine="python"
            )
            df_test.columns = df_test.columns.str.replace(".", "_")
            self.test_data = df_test
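For reference, the loading conventions above (delimiter sniffing with sep=None, dot-to-underscore column renaming, and the 1-based target_col lookup) behave roughly as in this standalone sketch; the toy data and column names here are invented for illustration:

import io

import pandas as pd

# Tab-separated toy input; sep=None with the python engine lets pandas sniff
# the delimiter, mirroring how the trainer reads arbitrary CSV/TSV files.
raw = "sample.id\tfeature_a\tlabel\n1\t0.5\tyes\n2\t0.9\tno\n"
df = pd.read_csv(io.StringIO(raw), sep=None, engine="python")
df.columns = df.columns.str.replace(".", "_")  # literal '.' -> '_' (regex=False default in pandas 2)

target_col = 3                            # 1-based, as target_col is interpreted above
target = df.columns[int(target_col) - 1]  # -> "label"
features = [c for c in df.columns if c != target]
print(target, features)                   # label ['sample_id', 'feature_a']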
    def setup_pycaret(self):
        LOG.info("Initializing PyCaret")
        # ... unchanged lines omitted ...
        elif self.task_type == "regression":
            from pycaret.regression import RegressionExperiment

            self.exp = RegressionExperiment()
        else:
            raise ValueError(
                "task_type must be 'classification' or 'regression'"
            )

        self.exp.setup(self.data, **self.setup_params)
        self.setup_params.update(self.user_kwargs)
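For orientation, the experiment objects used throughout this class follow PyCaret 3's object-oriented API; a minimal sketch of that pattern (toy data, and the model shortlist is invented here, assuming PyCaret 3.x) might look like:

import pandas as pd
from pycaret.regression import RegressionExperiment

# Tiny synthetic frame purely for illustration.
df = pd.DataFrame({"x1": range(100), "x2": range(100, 200), "y": range(200, 300)})

exp = RegressionExperiment()
exp.setup(df, target="y", session_id=42, verbose=False)
best = exp.compare_models(include=["lr", "dt"], verbose=False)
results = exp.pull()  # the cross-validated comparison table as a DataFrame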
    def train_model(self):
        # ... unchanged lines omitted ...
        self.best_model = self.exp.compare_models(**compare_kwargs)
        self.results = self.exp.pull()
        if getattr(self, "tune_model", False):
            LOG.info("Tuning hyperparameters of the best model")
            self.best_model = self.exp.tune_model(self.best_model)
            self.tuning_results = self.exp.pull()

        if self.task_type == "classification":
            self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

        prob_thresh = getattr(self, "probability_threshold", None)
        if self.task_type == "classification" and (
            prob_thresh is not None
        ):
            _ = self.exp.predict_model(
                self.best_model, probability_threshold=prob_thresh
            )
        else:
            _ = self.exp.predict_model(self.best_model)

        self.test_result_df = self.exp.pull()
        if self.task_type == "classification":
            self.test_result_df.rename(
                columns={"AUC": "ROC-AUC"}, inplace=True
            )
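The probability_threshold handling above is delegated to PyCaret's predict_model; as a rough, PyCaret-independent illustration of what a custom decision threshold means for a binary classifier (plain scikit-learn, synthetic data):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_samples=200, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X, y)

proba = clf.predict_proba(X)[:, 1]
default_labels = (proba >= 0.5).astype(int)  # what predict() does by default
strict_labels = (proba >= 0.8).astype(int)   # a stricter custom threshold
print(default_labels.sum(), strict_labels.sum())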
    def save_model(self):
        hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
        with h5py.File(hdf5_path, "w") as f:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                # ... unchanged lines omitted ...
                f.create_dataset("model", data=np.void(model_bytes))
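The lines elided above presumably serialize the fitted model into the temporary file before its bytes are wrapped in np.void. Assuming the payload stored under the "model" key is a pickled estimator, reading it back out of the HDF5 container could look like this sketch (swap in joblib if that is what the trainer actually uses):

import pickle

import h5py

with h5py.File("pycaret_model.h5", "r") as f:
    model_bytes = f["model"][()].tobytes()  # np.void scalar back to raw bytes
model = pickle.loads(model_bytes)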
    def generate_plots(self):
        LOG.info("Generating PyCaret diagnostic plots")

        # choose the right plots based on task type
        if self.task_type == "classification":
            plot_names = [
                "learning",
                "vc",
                "calibration",
                # ... unchanged lines omitted ...
                "class_report",
                "pr_auc",
                "roc_auc",
            ]
        else:
            plot_names = ["residuals", "vc", "parameter", "error",
                          "learning"]
        for name in plot_names:
            try:
                ax = self.exp.plot_model(
                    self.best_model, plot=name, save=False
                )
                out_path = Path(self.output_dir) / f"plot_{name}.png"
                fig = ax.get_figure()
                fig.savefig(out_path, bbox_inches="tight")
                self.plots[name] = str(out_path)
            except Exception as e:
        # ... unchanged lines omitted ...
            best_model_name = str(self.results.iloc[0]["Model"])
        except Exception:
            best_model_name = type(self.best_model).__name__
        LOG.info(f"Best model determined as: {best_model_name}")

        # 2) Compute training sample count
        try:
            n_train = self.exp.X_train.shape[0]
        except Exception:
            n_train = getattr(
                self.exp, "X_train_transformed", pd.DataFrame()
            ).shape[0]
        total_rows = self.data.shape[0]

        # 3) Build setup parameters table
        all_params = self.setup_params.copy()
        if self.task_type == "classification" and (
            hasattr(self, "probability_threshold")
        ):
            all_params["probability_threshold"] = (
                self.probability_threshold
            )
        display_keys = [
            "Target",
            "Session ID",
            "Train Size",
            "Normalize",
            # ... unchanged lines omitted ...
            }:
                dv = bool(v)
            elif key == "Cross Validation Folds":
                dv = v if v is not None else "None"
            elif key == "Models":
                dv = ", ".join(map(str, v)) if isinstance(
                    v, (list, tuple)
                ) else "None"
            elif key == "Probability Threshold":
                dv = f"{v:.2f}" if v is not None else "0.5"
            else:
                dv = v if v is not None else "None"
            setup_rows.append([key, dv])
        if hasattr(self.exp, "_fold_metric"):
            setup_rows.append(["best_model_metric", self.exp._fold_metric])

        df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
        df_setup.to_csv(
            Path(self.output_dir) / "setup_params.csv", index=False
        )

        # 4) Persist CSVs
        self.results.to_csv(
            Path(self.output_dir) / "comparison_results.csv",
            index=False
        )
        self.test_result_df.to_csv(
            Path(self.output_dir) / "test_results.csv", index=False
        )
        pd.DataFrame(
            self.best_model.get_params().items(),
            columns=["Parameter", "Value"]
        ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)

        if self.tuning_results is not None:
            self.tuning_results.to_csv(
                Path(self.output_dir) / "tuning_results.csv",
                index=False
            )

        # 5) Header
        header = f"<h2>Best Model: {best_model_name}</h2>"

        # — Validation Summary & Configuration —
        # ... unchanged lines omitted ...
            "pr_auc": "Precision-Recall AUC",
            "roc_auc": "Receiver Operating Characteristic AUC",
            "residuals": "Residuals Distribution",
            "error": "Prediction Error Distribution",
        }
        val_df.drop(
            columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True
        )
        summary_html = (
            header
            + "<h2>Train & Validation Summary</h2>"
            + '<div class="table-wrapper">'
            + val_df.to_html(index=False, classes="table sortable")
            + "</div>"
        )

        if self.tuning_results is not None:
            tuning_df = self.tuning_results.copy()
            tuning_df.drop(
                columns=["TT (Sec)"], errors="ignore", inplace=True
            )
            summary_html += (
                f"<h2>{best_model_name}: Tuning Summary</h2>"
                + '<div class="table-wrapper">'
                + tuning_df.to_html(index=False, classes="table sortable")
                + "</div>"
            )

        summary_html += (
            "<h2>Setup Parameters</h2>"
            + '<div class="table-wrapper">'
            + df_setup.to_html(index=False, classes="table sortable")
            + "</div>"
            # — Hyperparameters
            + "<h2>Best Model Hyperparameters</h2>"
            + '<div class="table-wrapper">'
            + pd.DataFrame(
                self.best_model.get_params().items(),
                columns=["Parameter", "Value"]
            ).to_html(index=False, classes="table sortable")
            + "</div>"
        )

        # choose summary plots based on task type
        # ... unchanged lines omitted ...

        for name in summary_plots:
            if name in self.plots:
                summary_html += "<hr>"
                b64 = encode_image_to_base64(self.plots[name])
                title = plot_title_map.get(
                    name, name.replace("_", " ").title()
                )
                summary_html += (
                    '<div class="plot">'
                    f"<h2>{title}</h2>"
                    f'<img src="data:image/png;base64,{b64}" '
                    'style="max-width:90%;max-height:600px;'
                    'border:1px solid #ddd;"/>'
                    "</div>"
                )
        # — Test Summary —
        test_html = (
            header
            + '<div class="table-wrapper">'
            + self.test_result_df.to_html(
                index=False, classes="table sortable"
            )
            + "</div>"
        )
        if self.task_type == "regression":
            try:
                y_true = (
                    pd.Series(self.exp.y_test_transformed)
                    .reset_index(drop=True)
                    .rename("True")
                )
                y_pred = pd.Series(
                    self.best_model.predict(
                        self.exp.X_test_transformed
                    )
                ).rename("Predicted")
                df_tp = pd.concat([y_true, y_pred], axis=1)
                test_html += "<h2>True vs Predicted Values</h2>"
                test_html += (
                    '<div class="table-wrapper" '
                    'style="max-height:400px; overflow-y:auto;">'
                    + df_tp.head(50).to_html(
                        index=False, classes="table sortable"
                    )
                    + "</div>"
                    + add_hr_to_html()
                )
            except Exception as e:
                LOG.warning(
                    f"Could not generate True vs Predicted table: {e}"
                )
        # 5a) Explainer-substituted plots in order
        if self.task_type == "regression":
            test_order = ["residuals"]
        else:
            # ... unchanged lines omitted ...
            ]
        for key in test_order:
            fig_or_fn = self.explainer_plots.pop(key, None)
            if fig_or_fn is not None:
                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
                title = plot_title_map.get(
                    key, key.replace("_", " ").title()
                )
                test_html += (
                    f"<h2>{title}</h2>" + add_plot_to_html(fig)
                    + add_hr_to_html()
                )
        # 5b) Remaining PyCaret test plots
        for name, path in self.plots.items():
            # classification: include only the small extras, before
            # skipping anything
            if self.task_type == "classification" and (
                name in {
                    "threshold",
                    "pr_auc",
                    "class_report",
                }
            ):
                title = plot_title_map.get(
                    name, name.replace("_", " ").title()
                )
                b64 = encode_image_to_base64(path)
                test_html += (
                    f"<h2>{title}</h2>"
                    "<div class='plot'>"
                    f"<img src='data:image/png;base64,{b64}' "
                    "style='max-width:90%;max-height:600px;"
                    "border:1px solid #ddd;'/>"
                    "</div>" + add_hr_to_html()
                )
                continue

            # regression: explicitly include the 'error' plot,
            # before skipping
            if self.task_type == "regression" and (
                name == "error"
            ):
                title = plot_title_map.get(
                    "error", "Prediction Error Distribution"
                )
                b64 = encode_image_to_base64(path)
                test_html += (
                    f"<h2>{title}</h2>"
                    "<div class='plot'>"
                    f"<img src='data:image/png;base64,{b64}' "
                    "style='max-width:90%;max-height:600px;"
                    "border:1px solid #ddd;'/>"
                    "</div>" + add_hr_to_html()
                )
                continue

            # now skip any plots already rendered via test_order
            # ... unchanged lines omitted ...
                    "Mean Absolute SHAP Value Impact"
                    if key == "shap_mean"
                    else "Permutation Feature Importance"
                )
                feature_html += (
                    f"<h2>{title}</h2>" + add_plot_to_html(fig)
                    + add_hr_to_html()
                )

        # 6c) PDPs last
        pdp_keys = sorted(
            k for k in self.explainer_plots if k.startswith("pdp__")
        )
        for k in pdp_keys:
            fig_or_fn = self.explainer_plots[k]
            fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
            # extract feature name
            feature = k.split("__", 1)[1]
            title = f"Partial Dependence for {feature}"
            feature_html += (
                f"<h2>{title}</h2>" + add_plot_to_html(fig)
                + add_hr_to_html()
            )
        # 7) Assemble final HTML (three tabs)
        html = get_html_template()
        html += "<h1>Tabular Learner Model Report</h1>"
        html += build_tabbed_html(summary_html, test_html, feature_html)
        # ... unchanged lines omitted ...

        # 8) Write out
        (Path(self.output_dir) / "comparison_result.html").write_text(
            html, encoding="utf-8"
        )
        LOG.info(
            f"HTML report generated at: "
            f"{self.output_dir}/comparison_result.html"
        )
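encode_image_to_base64, get_html_template, and build_tabbed_html are helpers from this repository's report utilities and are not shown in this diff. Judging by how the first one's output is dropped into data: URIs above, it most likely reduces to something like the following sketch (an assumption, not the project's actual code):

import base64

def encode_image_to_base64(path: str) -> str:
    # Read the image bytes and return them base64-encoded for a data: URI.
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode("utf-8")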
    def save_dashboard(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_plots_explainer(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_tree_plots(self):
        from sklearn.ensemble import (
            RandomForestClassifier, RandomForestRegressor
        )
        from xgboost import XGBClassifier, XGBRegressor
        from explainerdashboard.explainers import RandomForestExplainer

        LOG.info("Generating tree plots")
        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        if isinstance(
            self.best_model, (RandomForestClassifier, RandomForestRegressor)
        ):
            n_trees = self.best_model.n_estimators
        elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
            n_trees = len(self.best_model.get_booster().get_dump())
        else:
            LOG.warning("Tree plots not supported for this model type.")
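The two tree-count conventions used above can be checked in isolation; a small sketch assuming scikit-learn and xgboost are installed (toy data only):

from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from xgboost import XGBRegressor

X, y = make_regression(n_samples=100, random_state=0)

rf = RandomForestRegressor(n_estimators=25, random_state=0).fit(X, y)
xgb = XGBRegressor(n_estimators=25).fit(X, y)

print(rf.n_estimators)                    # 25 fitted trees in the forest
print(len(xgb.get_booster().get_dump()))  # 25 boosted trees dumped as text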