pycaret_predict: comparison of base_model_trainer.py @ 10:e2a6fed32d54 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 47a5977e074223e92e216efa42969a4056516707
author    goeckslab
date      Fri, 01 Aug 2025 14:02:26 +0000
parents   c6c1f8777aae
children  (none)
Changes from 9:c6c1f8777aae to 10:e2a6fed32d54:
@@ -42,10 +42,11 @@
         self.random_seed = random_seed
         self.data = None
         self.target = None
         self.best_model = None
         self.results = None
+        self.tuning_results = None
         self.features_name = None
         self.plots = {}
         self.explainer_plots = {}
         self.plots_explainer_html = None
         self.trees = []
@@ -55,9 +56,19 @@
         self.setup_params = {}
         self.test_file = test_file
         self.test_data = None
 
         if not self.output_dir:
-            raise ValueError("output_dir must be specified and not None")
+            raise ValueError(
+                "output_dir must be specified and not None"
+            )
+
+        # Warn about irrelevant kwargs for the task type
+        if self.task_type == "regression" and (
+            "probability_threshold" in self.user_kwargs
+        ):
+            LOG.warning(
+                "probability_threshold is ignored for regression tasks."
+            )
 
         LOG.info(f"Model kwargs: {self.__dict__}")
 
@@ -64,39 +75,83 @@
     def load_data(self):
         LOG.info(f"Loading data from {self.input_file}")
-        self.data = pd.read_csv(self.input_file, sep=None, engine="python")
+        self.data = pd.read_csv(
+            self.input_file, sep=None, engine="python"
+        )
         self.data.columns = self.data.columns.str.replace(".", "_")
-        if "prediction_label" in self.data.columns:
+
+        names = self.data.columns.to_list()
+        LOG.info(f"Original dataset columns: {names}")
+
+        target_index = int(self.target_col) - 1
+        num_cols = len(names)
+        if target_index < 0 or target_index >= num_cols:
+            raise ValueError(
+                f"Target column number {self.target_col} is invalid. "
+                f"Please select a number between 1 and {num_cols}."
+            )
+
+        self.target = names[target_index]
+
+        # Conditional drop: only if 'prediction_label' exists and is not
+        # the target
+        if "prediction_label" in self.data.columns and (
+            self.data.columns[target_index] != "prediction_label"
+        ):
+            LOG.info(
+                "Dropping 'prediction_label' column as it's not the target."
+            )
             self.data = self.data.drop(columns=["prediction_label"])
-
-        numeric_cols = self.data.select_dtypes(include=["number"]).columns
-        non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
+        else:
+            if self.target == "prediction_label":
+                LOG.warning(
+                    "Using 'prediction_label' as target column. "
+                    "This may not be intended if it's a previous prediction."
+                )
+
+        numeric_cols = self.data.select_dtypes(
+            include=["number"]
+        ).columns
+        non_numeric_cols = self.data.select_dtypes(
+            exclude=["number"]
+        ).columns
         self.data[numeric_cols] = self.data[numeric_cols].apply(
             pd.to_numeric, errors="coerce"
         )
         if len(non_numeric_cols) > 0:
-            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
+            LOG.info(
+                f"Non-numeric columns found: {non_numeric_cols.tolist()}"
+            )
 
+        # Update names after possible drop
         names = self.data.columns.to_list()
-        target_index = int(self.target_col) - 1
-        self.target = names[target_index]
-        self.features_name = [n for i, n in enumerate(names) if i != target_index]
+        LOG.info(f"Dataset columns after processing: {names}")
+
+        self.features_name = [n for n in names if n != self.target]
 
         if getattr(self, "missing_value_strategy", None):
             strat = self.missing_value_strategy
             if strat == "mean":
-                self.data = self.data.fillna(self.data.mean(numeric_only=True))
+                self.data = self.data.fillna(
+                    self.data.mean(numeric_only=True)
+                )
             elif strat == "median":
-                self.data = self.data.fillna(self.data.median(numeric_only=True))
+                self.data = self.data.fillna(
+                    self.data.median(numeric_only=True)
+                )
             elif strat == "drop":
                 self.data = self.data.dropna()
             else:
-                self.data = self.data.fillna(self.data.median(numeric_only=True))
+                self.data = self.data.fillna(
+                    self.data.median(numeric_only=True)
+                )
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
-            df_test = pd.read_csv(self.test_file, sep=None, engine="python")
+            df_test = pd.read_csv(
+                self.test_file, sep=None, engine="python"
+            )
             df_test.columns = df_test.columns.str.replace(".", "_")
             self.test_data = df_test
 
     def setup_pycaret(self):
         LOG.info("Initializing PyCaret")
@@ -135,11 +190,13 @@
         elif self.task_type == "regression":
             from pycaret.regression import RegressionExperiment
 
             self.exp = RegressionExperiment()
         else:
-            raise ValueError("task_type must be 'classification' or 'regression'")
+            raise ValueError(
+                "task_type must be 'classification' or 'regression'"
+            )
 
         self.exp.setup(self.data, **self.setup_params)
         self.setup_params.update(self.user_kwargs)
 
     def train_model(self):
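For context, a minimal sketch, assuming PyCaret 3.x, of the experiment flow these methods drive; the dataset and `session_id` below are illustrative, not taken from the tool:

```python
from pycaret.classification import ClassificationExperiment
from sklearn.datasets import load_breast_cancer

data = load_breast_cancer(as_frame=True).frame  # binary target column "target"

exp = ClassificationExperiment()
exp.setup(data, target="target", session_id=42)

best = exp.compare_models()   # cross-validated leaderboard of candidate models
comparison_df = exp.pull()    # metrics table for the comparison

tuned = exp.tune_model(best)  # hyperparameter search on the winner
tuning_df = exp.pull()        # pulled into its own table, as this commit now does

# Hold-out evaluation; probability_threshold applies to classification only.
exp.predict_model(tuned, probability_threshold=0.5)
holdout_df = exp.pull()
```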
@@ -169,24 +226,30 @@
         self.best_model = self.exp.compare_models(**compare_kwargs)
         self.results = self.exp.pull()
         if getattr(self, "tune_model", False):
             LOG.info("Tuning hyperparameters of the best model")
             self.best_model = self.exp.tune_model(self.best_model)
-            self.results = self.exp.pull()
+            self.tuning_results = self.exp.pull()
 
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
         prob_thresh = getattr(self, "probability_threshold", None)
-        if self.task_type == "classification" and prob_thresh is not None:
-            _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh)
+        if self.task_type == "classification" and (
+            prob_thresh is not None
+        ):
+            _ = self.exp.predict_model(
+                self.best_model, probability_threshold=prob_thresh
+            )
         else:
             _ = self.exp.predict_model(self.best_model)
 
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
-            self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
+            self.test_result_df.rename(
+                columns={"AUC": "ROC-AUC"}, inplace=True
+            )
 
     def save_model(self):
         hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
         with h5py.File(hdf5_path, "w") as f:
             with tempfile.NamedTemporaryFile(delete=False) as tmp:
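save_model embeds the serialized model as raw bytes in an HDF5 dataset; the serializer itself sits in the elided lines 193-195. A hedged sketch of reading the file back, assuming those bytes are a pickle:

```python
import pickle

import h5py

with h5py.File("pycaret_model.h5", "r") as f:
    raw = f["model"][()].tobytes()  # np.void scalar back to raw bytes
model = pickle.loads(raw)           # assumption: the elided lines wrote a pickle
```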
@@ -196,11 +259,11 @@
             f.create_dataset("model", data=np.void(model_bytes))
 
     def generate_plots(self):
         LOG.info("Generating PyCaret diagnostic plots")
 
-        # choose the right plots based on task
+        # choose the right plots based on task type
         if self.task_type == "classification":
             plot_names = [
                 "learning",
                 "vc",
                 "calibration",
212 "class_report", | 275 "class_report", |
213 "pr_auc", | 276 "pr_auc", |
214 "roc_auc", | 277 "roc_auc", |
215 ] | 278 ] |
216 else: | 279 else: |
217 plot_names = ["residuals", "vc", "parameter", "error", "learning"] | 280 plot_names = ["residuals", "vc", "parameter", "error", |
281 "learning"] | |
218 for name in plot_names: | 282 for name in plot_names: |
219 try: | 283 try: |
220 ax = self.exp.plot_model(self.best_model, plot=name, save=False) | 284 ax = self.exp.plot_model( |
285 self.best_model, plot=name, save=False | |
286 ) | |
221 out_path = Path(self.output_dir) / f"plot_{name}.png" | 287 out_path = Path(self.output_dir) / f"plot_{name}.png" |
222 fig = ax.get_figure() | 288 fig = ax.get_figure() |
223 fig.savefig(out_path, bbox_inches="tight") | 289 fig.savefig(out_path, bbox_inches="tight") |
224 self.plots[name] = str(out_path) | 290 self.plots[name] = str(out_path) |
225 except Exception as e: | 291 except Exception as e: |
@@ -237,22 +303,27 @@
             best_model_name = str(self.results.iloc[0]["Model"])
         except Exception:
             best_model_name = type(self.best_model).__name__
         LOG.info(f"Best model determined as: {best_model_name}")
 
         # 2) Compute training sample count
         try:
             n_train = self.exp.X_train.shape[0]
         except Exception:
-            n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
+            n_train = getattr(
+                self.exp, "X_train_transformed", pd.DataFrame()
+            ).shape[0]
         total_rows = self.data.shape[0]
 
         # 3) Build setup parameters table
         all_params = self.setup_params.copy()
-        if self.task_type == "classification" and hasattr(self, "probability_threshold"):
-            all_params["probability_threshold"] = self.probability_threshold
-
+        if self.task_type == "classification" and (
+            hasattr(self, "probability_threshold")
+        ):
+            all_params["probability_threshold"] = (
+                self.probability_threshold
+            )
         display_keys = [
             "Target",
             "Session ID",
             "Train Size",
             "Normalize",
@@ -288,32 +359,44 @@
             }:
                 dv = bool(v)
             elif key == "Cross Validation Folds":
                 dv = v if v is not None else "None"
             elif key == "Models":
-                dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+                dv = ", ".join(map(str, v)) if isinstance(
+                    v, (list, tuple)
+                ) else "None"
             elif key == "Probability Threshold":
-                dv = v if v is not None else "None"
+                dv = f"{v:.2f}" if v is not None else "0.5"
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
         if hasattr(self.exp, "_fold_metric"):
             setup_rows.append(["best_model_metric", self.exp._fold_metric])
 
         df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
-        df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False)
+        df_setup.to_csv(
+            Path(self.output_dir) / "setup_params.csv", index=False
+        )
 
         # 4) Persist CSVs
         self.results.to_csv(
-            Path(self.output_dir) / "comparison_results.csv", index=False
+            Path(self.output_dir) / "comparison_results.csv",
+            index=False
         )
         self.test_result_df.to_csv(
             Path(self.output_dir) / "test_results.csv", index=False
         )
         pd.DataFrame(
-            self.best_model.get_params().items(), columns=["Parameter", "Value"]
+            self.best_model.get_params().items(),
+            columns=["Parameter", "Value"]
         ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)
+
+        if self.tuning_results is not None:
+            self.tuning_results.to_csv(
+                Path(self.output_dir) / "tuning_results.csv",
+                index=False
+            )
 
         # 5) Header
         header = f"<h2>Best Model: {best_model_name}</h2>"
 
         # — Validation Summary & Configuration —
332 "pr_auc": "Precision-Recall AUC", | 415 "pr_auc": "Precision-Recall AUC", |
333 "roc_auc": "Receiver Operating Characteristic AUC", | 416 "roc_auc": "Receiver Operating Characteristic AUC", |
334 "residuals": "Residuals Distribution", | 417 "residuals": "Residuals Distribution", |
335 "error": "Prediction Error Distribution", | 418 "error": "Prediction Error Distribution", |
336 } | 419 } |
337 val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True) | 420 val_df.drop( |
421 columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True | |
422 ) | |
338 summary_html = ( | 423 summary_html = ( |
339 header | 424 header |
340 + "<h2>Train & Validation Summary</h2>" | 425 + "<h2>Train & Validation Summary</h2>" |
341 + '<div class="table-wrapper">' | 426 + '<div class="table-wrapper">' |
342 + val_df.to_html(index=False, classes="table sortable") | 427 + val_df.to_html(index=False, classes="table sortable") |
343 + "</div>" | 428 + "</div>" |
344 + "<h2>Setup Parameters</h2>" | 429 ) |
430 | |
431 if self.tuning_results is not None: | |
432 tuning_df = self.tuning_results.copy() | |
433 tuning_df.drop( | |
434 columns=["TT (Sec)"], errors="ignore", inplace=True | |
435 ) | |
436 summary_html += ( | |
437 f"<h2>{best_model_name}: Tuning Summary</h2>" | |
438 + '<div class="table-wrapper">' | |
439 + tuning_df.to_html(index=False, classes="table sortable") | |
440 + "</div>" | |
441 ) | |
442 | |
443 summary_html += ( | |
444 "<h2>Setup Parameters</h2>" | |
345 + '<div class="table-wrapper">' | 445 + '<div class="table-wrapper">' |
346 + df_setup.to_html(index=False, classes="table sortable") | 446 + df_setup.to_html(index=False, classes="table sortable") |
347 + "</div>" | 447 + "</div>" |
348 # — Hyperparameters | 448 # — Hyperparameters |
349 + "<h2>Best Model Hyperparameters</h2>" | 449 + "<h2>Best Model Hyperparameters</h2>" |
350 + '<div class="table-wrapper">' | 450 + '<div class="table-wrapper">' |
351 + pd.DataFrame( | 451 + pd.DataFrame( |
352 self.best_model.get_params().items(), columns=["Parameter", "Value"] | 452 self.best_model.get_params().items(), |
453 columns=["Parameter", "Value"] | |
353 ).to_html(index=False, classes="table sortable") | 454 ).to_html(index=False, classes="table sortable") |
354 + "</div>" | 455 + "</div>" |
355 ) | 456 ) |
356 | 457 |
357 # choose summary plots based on task type | 458 # choose summary plots based on task type |
@@ -371,46 +472,58 @@
 
         for name in summary_plots:
             if name in self.plots:
                 summary_html += "<hr>"
                 b64 = encode_image_to_base64(self.plots[name])
-                title = plot_title_map.get(name, name.replace("_", " ").title())
+                title = plot_title_map.get(
+                    name, name.replace("_", " ").title()
+                )
                 summary_html += (
                     '<div class="plot">'
                     f"<h2>{title}</h2>"
                     f'<img src="data:image/png;base64,{b64}" '
-                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>'
+                    'style="max-width:90%;max-height:600px;'
+                    'border:1px solid #ddd;"/>'
                     "</div>"
                 )
 
         # — Test Summary —
         test_html = (
             header
             + '<div class="table-wrapper">'
-            + self.test_result_df.to_html(index=False, classes="table sortable")
+            + self.test_result_df.to_html(
+                index=False, classes="table sortable"
+            )
             + "</div>"
         )
         if self.task_type == "regression":
             try:
                 y_true = (
                     pd.Series(self.exp.y_test_transformed)
                     .reset_index(drop=True)
                     .rename("True")
                 )
                 y_pred = pd.Series(
-                    self.best_model.predict(self.exp.X_test_transformed)
+                    self.best_model.predict(
+                        self.exp.X_test_transformed
+                    )
                 ).rename("Predicted")
                 df_tp = pd.concat([y_true, y_pred], axis=1)
                 test_html += "<h2>True vs Predicted Values</h2>"
                 test_html += (
-                    '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">'
-                    + df_tp.head(50).to_html(index=False, classes="table sortable")
+                    '<div class="table-wrapper" '
+                    'style="max-height:400px; overflow-y:auto;">'
+                    + df_tp.head(50).to_html(
+                        index=False, classes="table sortable"
+                    )
                     + "</div>"
                     + add_hr_to_html()
                 )
             except Exception as e:
-                LOG.warning(f"Could not generate True vs Predicted table: {e}")
+                LOG.warning(
+                    f"Could not generate True vs Predicted table: {e}"
+                )
 
         # 5a) Explainer-substituted plots in order
         if self.task_type == "regression":
             test_order = ["residuals"]
         else:
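`encode_image_to_base64` is imported from the tool's utilities and not shown in this diff; given how its output is dropped into `src="data:image/png;base64,..."` attributes, it is presumably equivalent to this sketch:

```python
import base64

def encode_image_to_base64(path: str) -> str:
    # Read an image from disk and return it as a base64 string for inline <img> use.
    with open(path, "rb") as fh:
        return base64.b64encode(fh.read()).decode("utf-8")
```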
@@ -424,42 +537,57 @@
             ]
         for key in test_order:
             fig_or_fn = self.explainer_plots.pop(key, None)
             if fig_or_fn is not None:
                 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
-                title = plot_title_map.get(key, key.replace("_", " ").title())
+                title = plot_title_map.get(
+                    key, key.replace("_", " ").title()
+                )
                 test_html += (
-                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig)
+                    + add_hr_to_html()
                 )
         # 5b) Remaining PyCaret test plots
         for name, path in self.plots.items():
-            # classification: include only the small extras, before skipping anything
-            if self.task_type == "classification" and name in {
-                "threshold",
-                "pr_auc",
-                "class_report",
-            }:
-                title = plot_title_map.get(name, name.replace("_", " ").title())
+            # classification: include only the small extras, before
+            # skipping anything
+            if self.task_type == "classification" and (
+                name in {
+                    "threshold",
+                    "pr_auc",
+                    "class_report",
+                }
+            ):
+                title = plot_title_map.get(
+                    name, name.replace("_", " ").title()
+                )
                 b64 = encode_image_to_base64(path)
                 test_html += (
                     f"<h2>{title}</h2>"
                     "<div class='plot'>"
                     f"<img src='data:image/png;base64,{b64}' "
-                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "style='max-width:90%;max-height:600px;"
+                    "border:1px solid #ddd;'/>"
                     "</div>" + add_hr_to_html()
                 )
                 continue
 
-            # regression: explicitly include the 'error' plot, before skipping
-            if self.task_type == "regression" and name == "error":
-                title = plot_title_map.get("error", "Prediction Error Distribution")
+            # regression: explicitly include the 'error' plot,
+            # before skipping
+            if self.task_type == "regression" and (
+                name == "error"
+            ):
+                title = plot_title_map.get(
+                    "error", "Prediction Error Distribution"
+                )
                 b64 = encode_image_to_base64(path)
                 test_html += (
                     f"<h2>{title}</h2>"
                     "<div class='plot'>"
                     f"<img src='data:image/png;base64,{b64}' "
-                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "style='max-width:90%;max-height:600px;"
+                    "border:1px solid #ddd;'/>"
                     "</div>" + add_hr_to_html()
                 )
                 continue
 
             # now skip any plots already rendered via test_order
489 "Mean Absolute SHAP Value Impact" | 617 "Mean Absolute SHAP Value Impact" |
490 if key == "shap_mean" | 618 if key == "shap_mean" |
491 else "Permutation Feature Importance" | 619 else "Permutation Feature Importance" |
492 ) | 620 ) |
493 feature_html += ( | 621 feature_html += ( |
494 f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() | 622 f"<h2>{title}</h2>" + add_plot_to_html(fig) |
623 + add_hr_to_html() | |
495 ) | 624 ) |
496 | 625 |
497 # 6c) PDPs last | 626 # 6c) PDPs last |
498 pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__")) | 627 pdp_keys = sorted( |
628 k for k in self.explainer_plots if k.startswith("pdp__") | |
629 ) | |
499 for k in pdp_keys: | 630 for k in pdp_keys: |
500 fig_or_fn = self.explainer_plots[k] | 631 fig_or_fn = self.explainer_plots[k] |
501 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn | 632 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn |
502 # extract feature name | 633 # extract feature name |
503 feature = k.split("__", 1)[1] | 634 feature = k.split("__", 1)[1] |
504 title = f"Partial Dependence for {feature}" | 635 title = f"Partial Dependence for {feature}" |
505 feature_html += ( | 636 feature_html += ( |
506 f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html() | 637 f"<h2>{title}</h2>" + add_plot_to_html(fig) |
638 + add_hr_to_html() | |
507 ) | 639 ) |
508 # 7) Assemble final HTML (three tabs) | 640 # 7) Assemble final HTML (three tabs) |
509 html = get_html_template() | 641 html = get_html_template() |
510 html += "<h1>Tabular Learner Model Report</h1>" | 642 html += "<h1>Tabular Learner Model Report</h1>" |
511 html += build_tabbed_html(summary_html, test_html, feature_html) | 643 html += build_tabbed_html(summary_html, test_html, feature_html) |
@@ -514,28 +646,35 @@
 
         # 8) Write out
         (Path(self.output_dir) / "comparison_result.html").write_text(
             html, encoding="utf-8"
         )
-        LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html")
+        LOG.info(
+            f"HTML report generated at: "
+            f"{self.output_dir}/comparison_result.html"
+        )
 
     def save_dashboard(self):
        raise NotImplementedError("Subclasses should implement this method")
 
     def generate_plots_explainer(self):
        raise NotImplementedError("Subclasses should implement this method")
 
     def generate_tree_plots(self):
-        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
+        from sklearn.ensemble import (
+            RandomForestClassifier, RandomForestRegressor
+        )
         from xgboost import XGBClassifier, XGBRegressor
         from explainerdashboard.explainers import RandomForestExplainer
 
         LOG.info("Generating tree plots")
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
-        if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)):
+        if isinstance(
+            self.best_model, (RandomForestClassifier, RandomForestRegressor)
+        ):
             n_trees = self.best_model.n_estimators
         elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
             n_trees = len(self.best_model.get_booster().get_dump())
         else:
             LOG.warning("Tree plots not supported for this model type.")