comparison base_model_trainer.py @ 16:4fee4504646e draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
author goeckslab
date Fri, 28 Nov 2025 22:28:26 +0000
parents 7d78a6afc958
children
comparison
equal deleted inserted replaced
15:a2aeeb754d76 16:4fee4504646e
44 self.target = None 44 self.target = None
45 self.best_model = None 45 self.best_model = None
46 self.results = None 46 self.results = None
47 self.tuning_results = None 47 self.tuning_results = None
48 self.features_name = None 48 self.features_name = None
49 self.plot_feature_names = None
49 self.plots = {} 50 self.plots = {}
50 self.explainer_plots = {} 51 self.explainer_plots = {}
51 self.plots_explainer_html = None 52 self.plots_explainer_html = None
52 self.trees = [] 53 self.trees = []
53 self.user_kwargs = kwargs.copy() 54 self.user_kwargs = kwargs.copy()
54 for key, value in self.user_kwargs.items(): 55 for key, value in self.user_kwargs.items():
55 setattr(self, key, value) 56 setattr(self, key, value)
57 if not hasattr(self, "plot_feature_limit"):
58 self.plot_feature_limit = 30
59 self._shap_row_cap = None
60 if getattr(self, "polynomial_features", False):
61 # Keep feature importance responsive by trimming plots/SHAP rows
62 try:
63 limit_val = int(self.plot_feature_limit)
64 except (TypeError, ValueError):
65 limit_val = 30
66 self.plot_feature_limit = min(limit_val, 15)
67 self._shap_row_cap = 200
68 LOG.info(
69 "Polynomial features enabled; limiting feature plots to %s and SHAP rows to %s",
70 self.plot_feature_limit,
71 self._shap_row_cap,
72 )
73 self.imputed_training_data = None
74 self._best_model_metric_used = None
56 self.setup_params = {} 75 self.setup_params = {}
57 self.test_file = test_file 76 self.test_file = test_file
58 self.test_data = None 77 self.test_data = None
59 78
60 if not self.output_dir: 79 if not self.output_dir:
125 # Update names after possible drop 144 # Update names after possible drop
126 names = self.data.columns.to_list() 145 names = self.data.columns.to_list()
127 LOG.info(f"Dataset columns after processing: {names}") 146 LOG.info(f"Dataset columns after processing: {names}")
128 147
129 self.features_name = [n for n in names if n != self.target] 148 self.features_name = [n for n in names if n != self.target]
130 149 self.plot_feature_names = self._select_plot_features(self.features_name)
131 if getattr(self, "missing_value_strategy", None):
132 strat = self.missing_value_strategy
133 if strat == "mean":
134 self.data = self.data.fillna(
135 self.data.mean(numeric_only=True)
136 )
137 elif strat == "median":
138 self.data = self.data.fillna(
139 self.data.median(numeric_only=True)
140 )
141 elif strat == "drop":
142 self.data = self.data.dropna()
143 else:
144 self.data = self.data.fillna(
145 self.data.median(numeric_only=True)
146 )
147 150
148 if self.test_file: 151 if self.test_file:
149 LOG.info(f"Loading test data from {self.test_file}") 152 LOG.info(f"Loading test data from {self.test_file}")
150 df_test = pd.read_csv( 153 df_test = pd.read_csv(
151 self.test_file, sep=None, engine="python" 154 self.test_file, sep=None, engine="python"
152 ) 155 )
153 df_test.columns = df_test.columns.str.replace(".", "_") 156 df_test.columns = df_test.columns.str.replace(".", "_")
154 self.test_data = df_test 157 self.test_data = df_test
158
159 def _select_plot_features(self, all_features):
160 limit = getattr(self, "plot_feature_limit", 30)
161 if not isinstance(limit, int) or limit <= 0:
162 LOG.info(
163 "Feature plotting limit disabled (plot_feature_limit=%s).", limit
164 )
165 return all_features
166 if len(all_features) <= limit:
167 LOG.info(
168 "Feature plotting limit not needed (%s features <= limit %s).",
169 len(all_features),
170 limit,
171 )
172 return all_features
173 df = self.data[all_features].copy()
174 numeric_cols = df.select_dtypes(include=["number"]).columns
175 ranked = []
176 if len(numeric_cols) > 0:
177 variances = (
178 df[numeric_cols]
179 .var()
180 .fillna(0)
181 .abs()
182 .sort_values(ascending=False)
183 )
184 ranked = variances.index.tolist()
185 selected = []
186 for col in ranked:
187 if len(selected) >= limit:
188 break
189 selected.append(col)
190 if len(selected) < limit:
191 for col in all_features:
192 if col in selected:
193 continue
194 selected.append(col)
195 if len(selected) >= limit:
196 break
197 LOG.info(
198 "Limiting feature-level plots to %s of %s available features (limit=%s).",
199 len(selected),
200 len(all_features),
201 limit,
202 )
203 return selected
155 204
156 def setup_pycaret(self): 205 def setup_pycaret(self):
157 LOG.info("Initializing PyCaret") 206 LOG.info("Initializing PyCaret")
158 self.setup_params = { 207 self.setup_params = {
159 "target": self.target, 208 "target": self.target,
196 raise ValueError( 245 raise ValueError(
197 "task_type must be 'classification' or 'regression'" 246 "task_type must be 'classification' or 'regression'"
198 ) 247 )
199 248
200 self.exp.setup(self.data, **self.setup_params) 249 self.exp.setup(self.data, **self.setup_params)
250 self._capture_imputed_training_data()
201 self.setup_params.update(self.user_kwargs) 251 self.setup_params.update(self.user_kwargs)
202 252
203 def _normalize_metric(self, m: str) -> str: 253 def _capture_imputed_training_data(self):
204 if not m: 254 """
205 return "R2" if self.task_type == "regression" else "Accuracy" 255 Cache the dataset as transformed/imputed by PyCaret so downstream
206 m_low = str(m).strip().lower() 256 components (e.g., feature importance) can operate on the exact data
207 alias = { 257 used for training.
208 "auc": "AUC", "roc_auc": "AUC", "roc-auc": "AUC", 258 """
209 "accuracy": "Accuracy", 259 if self.exp is None:
210 "precision": "Precision", 260 return
211 "recall": "Recall", 261 try:
212 "f1": "F1", 262 X_processed = self.exp.get_config("X_transformed").copy()
213 "kappa": "Kappa", 263 y_processed = self.exp.get_config("y")
214 "logloss": "Log Loss", "log_loss": "Log Loss", 264 if isinstance(y_processed, pd.Series):
215 "pr_auc": "PR-AUC-Weighted", "prauc": "PR-AUC-Weighted", 265 y_series = y_processed.reset_index(drop=True)
216 "r2": "R2", 266 else:
217 "mae": "MAE", 267 y_series = pd.Series(y_processed)
218 "mse": "MSE", 268 y_series.name = self.target
219 "rmse": "RMSE", 269 X_processed = X_processed.reset_index(drop=True)
220 "rmsle": "RMSLE", 270 self.imputed_training_data = pd.concat(
221 "mape": "MAPE", 271 [X_processed, y_series], axis=1
222 } 272 )
223 return alias.get(m_low, m) 273 LOG.info(
274 "Captured imputed training dataset from PyCaret "
275 "(%s rows, %s features).",
276 self.imputed_training_data.shape[0],
277 self.imputed_training_data.shape[1] - 1,
278 )
279 except Exception as exc:
280 LOG.warning(
281 "Unable to capture processed training data from PyCaret: %s",
282 exc,
283 )
284 self.imputed_training_data = None
224 285
225 def train_model(self): 286 def train_model(self):
226 LOG.info("Training and selecting the best model") 287 LOG.info("Training and selecting the best model")
227 if self.task_type == "classification": 288 if self.task_type == "classification":
228 self.exp.add_metric( 289 self.exp.add_metric(
243 304
244 # Respect explicit fold count 305 # Respect explicit fold count
245 if getattr(self, "cross_validation_folds", None) is not None: 306 if getattr(self, "cross_validation_folds", None) is not None:
246 compare_kwargs["fold"] = self.cross_validation_folds 307 compare_kwargs["fold"] = self.cross_validation_folds
247 308
248 chosen_metric = self._normalize_metric(getattr(self, "best_model_metric", None)) 309 best_metric = getattr(self, "best_model_metric", None)
249 if chosen_metric: 310 if best_metric:
250 compare_kwargs["sort"] = chosen_metric 311 compare_kwargs["sort"] = best_metric
251 self.chosen_metric_label = chosen_metric 312 self._best_model_metric_used = best_metric
252 try: 313 LOG.info(f"Ranking models using metric: {best_metric}")
253 setattr(self.exp, "_fold_metric", chosen_metric)
254 except Exception as e:
255 LOG.warning(f"Failed to set '_fold_metric' to '{chosen_metric}': {e}", exc_info=True)
256 314
257 LOG.info(f"compare_models kwargs: {compare_kwargs}") 315 LOG.info(f"compare_models kwargs: {compare_kwargs}")
258 self.best_model = self.exp.compare_models(**compare_kwargs) 316 self.best_model = self.exp.compare_models(**compare_kwargs)
317 if self._best_model_metric_used is None:
318 self._best_model_metric_used = getattr(self.exp, "_fold_metric", None)
259 self.results = self.exp.pull() 319 self.results = self.exp.pull()
260 if getattr(self, "tune_model", False): 320 if getattr(self, "tune_model", False):
261 LOG.info("Tuning hyperparameters of the best model") 321 LOG.info("Tuning hyperparameters of the best model")
262 self.best_model = self.exp.tune_model(self.best_model) 322 self.best_model = self.exp.tune_model(self.best_model)
263 self.tuning_results = self.exp.pull() 323 self.tuning_results = self.exp.pull()
324 LOG.warning(f"Could not generate {name} plot: {e}") 384 LOG.warning(f"Could not generate {name} plot: {e}")
325 385
def encode_image_to_base64(self, img_path: str) -> str:
    """Return the base64-encoded contents of the image file at *img_path*."""
    with open(img_path, "rb") as fh:
        raw = fh.read()
    return base64.b64encode(raw).decode("utf-8")
389
390 def _resolve_plot_callable(self, key, fig_or_fn, section):
391 """
392 Safely execute stored plot callables so a single failure does not
393 abort the entire HTML report generation.
394 """
395 if fig_or_fn is None:
396 return None
397 try:
398 return fig_or_fn() if callable(fig_or_fn) else fig_or_fn
399 except Exception as exc:
400 extra = ""
401 if isinstance(exc, ValueError) and "Input contains NaN" in str(exc):
402 extra = (
403 " (model returned NaN probabilities; "
404 "consider checking data preprocessing)"
405 )
406 LOG.warning(
407 "Skipping %s plot '%s' due to error: %s%s",
408 section,
409 key,
410 exc,
411 extra,
412 )
413 return None
329 414
330 def save_html_report(self): 415 def save_html_report(self):
331 LOG.info("Saving HTML report") 416 LOG.info("Saving HTML report")
332 417
333 # 1) Determine best model name 418 # 1) Determine best model name
399 elif key == "Probability Threshold": 484 elif key == "Probability Threshold":
400 dv = f"{v:.2f}" if v is not None else "0.5" 485 dv = f"{v:.2f}" if v is not None else "0.5"
401 else: 486 else:
402 dv = v if v is not None else "None" 487 dv = v if v is not None else "None"
403 setup_rows.append([key, dv]) 488 setup_rows.append([key, dv])
404 if getattr(self, "chosen_metric_label", None): 489 metric_label = self._best_model_metric_used or getattr(
405 setup_rows.append(["Best Model Metric", self.chosen_metric_label]) 490 self.exp, "_fold_metric", None
491 )
492 if metric_label:
493 setup_rows.append(["Best Model Metric", metric_label])
406 494
407 df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"]) 495 df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
408 df_setup.to_csv( 496 df_setup.to_csv(
409 Path(self.output_dir) / "setup_params.csv", index=False 497 Path(self.output_dir) / "setup_params.csv", index=False
410 ) 498 )
562 test_order = [ 650 test_order = [
563 "confusion_matrix", 651 "confusion_matrix",
564 "roc_auc", 652 "roc_auc",
565 "pr_auc", 653 "pr_auc",
566 "lift_curve", 654 "lift_curve",
567 "threshold",
568 "cumulative_precision", 655 "cumulative_precision",
569 ] 656 ]
570 for key in test_order: 657 for key in test_order:
571 fig_or_fn = self.explainer_plots.pop(key, None) 658 fig_or_fn = self.explainer_plots.pop(key, None)
572 if fig_or_fn is not None: 659 if fig_or_fn is not None:
573 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn 660 fig = self._resolve_plot_callable(
661 key, fig_or_fn, section="test/explainer"
662 )
663 if fig is None:
664 continue
574 title = plot_title_map.get( 665 title = plot_title_map.get(
575 key, key.replace("_", " ").title() 666 key, key.replace("_", " ").title()
576 ) 667 )
577 test_html += ( 668 test_html += (
578 f"<h2>{title}</h2>" + add_plot_to_html(fig) 669 f"<h2>{title}</h2>" + add_plot_to_html(fig)
582 for name, path in self.plots.items(): 673 for name, path in self.plots.items():
583 # classification: include only the small extras, before 674 # classification: include only the small extras, before
584 # skipping anything 675 # skipping anything
585 if self.task_type == "classification" and ( 676 if self.task_type == "classification" and (
586 name in { 677 name in {
587 "threshold",
588 "pr_auc", 678 "pr_auc",
589 "class_report", 679 "class_report",
590 } 680 }
591 ): 681 ):
592 title = plot_title_map.get( 682 title = plot_title_map.get(
628 718
629 # — Feature Importance — 719 # — Feature Importance —
630 feature_html = header 720 feature_html = header
631 721
632 # 6a) PyCaret’s default feature importances 722 # 6a) PyCaret’s default feature importances
633 feature_html += FeatureImportanceAnalyzer( 723 imputed_data = (
634 data=self.data, 724 self.imputed_training_data
725 if self.imputed_training_data is not None
726 else self.data
727 )
728 fi_analyzer = FeatureImportanceAnalyzer(
729 data=imputed_data,
635 target_col=self.target_col, 730 target_col=self.target_col,
636 task_type=self.task_type, 731 task_type=self.task_type,
637 output_dir=self.output_dir, 732 output_dir=self.output_dir,
638 exp=self.exp, 733 exp=self.exp,
639 best_model=self.best_model, 734 best_model=self.best_model,
640 ).run() 735 max_plot_features=self.plot_feature_limit,
736 processed_data=self.imputed_training_data,
737 max_shap_rows=self._shap_row_cap,
738 )
739 fi_html = fi_analyzer.run()
740 # Add a small table to show SHAP feature caps near the Best Model header.
741 cap_rows = []
742 if fi_analyzer.shap_total_features is not None:
743 cap_rows.append(
744 ("Total transformed features", fi_analyzer.shap_total_features)
745 )
746 if fi_analyzer.shap_used_features is not None:
747 cap_rows.append(
748 ("Features used in SHAP", fi_analyzer.shap_used_features)
749 )
750 if cap_rows:
751 cap_table = (
752 "<div class='table-wrapper'>"
753 "<table class='table sortable'>"
754 "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>"
755 "<tbody>"
756 + "".join(
757 f"<tr><td>{label}</td><td>{value}</td></tr>"
758 for label, value in cap_rows
759 )
760 + "</tbody></table></div>"
761 )
762 feature_html += cap_table
763 feature_html += fi_html
641 764
642 # 6b) Explainer SHAP importances 765 # 6b) Explainer SHAP importances
643 for key in ["shap_mean", "shap_perm"]: 766 for key in ["shap_mean", "shap_perm"]:
644 fig_or_fn = self.explainer_plots.pop(key, None) 767 fig_or_fn = self.explainer_plots.pop(key, None)
645 if fig_or_fn is not None: 768 if fig_or_fn is not None:
646 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn 769 fig = self._resolve_plot_callable(
770 key, fig_or_fn, section="feature importance"
771 )
772 if fig is None:
773 continue
647 # give SHAP plots explicit titles 774 # give SHAP plots explicit titles
648 title = ( 775 title = (
649 "Mean Absolute SHAP Value Impact" 776 "Mean Absolute SHAP Value Impact"
650 if key == "shap_mean" 777 if key == "shap_mean"
651 else "Permutation Feature Importance" 778 else "Permutation Feature Importance"
659 pdp_keys = sorted( 786 pdp_keys = sorted(
660 k for k in self.explainer_plots if k.startswith("pdp__") 787 k for k in self.explainer_plots if k.startswith("pdp__")
661 ) 788 )
662 for k in pdp_keys: 789 for k in pdp_keys:
663 fig_or_fn = self.explainer_plots[k] 790 fig_or_fn = self.explainer_plots[k]
664 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn 791 fig = self._resolve_plot_callable(
792 k, fig_or_fn, section="pdp"
793 )
794 if fig is None:
795 continue
665 # extract feature name 796 # extract feature name
666 feature = k.split("__", 1)[1] 797 feature = k.split("__", 1)[1]
667 title = f"Partial Dependence for {feature}" 798 title = f"Partial Dependence for {feature}"
668 feature_html += ( 799 feature_html += (
669 f"<h2>{title}</h2>" + add_plot_to_html(fig) 800 f"<h2>{title}</h2>" + add_plot_to_html(fig)