changeset 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents f4cb41f458fd
children
files base_model_trainer.py feature_help_modal.py feature_importance.py pycaret_classification.py pycaret_predict.py pycaret_regression.py pycaret_train.py utils.py
diffstat 8 files changed, 1035 insertions(+), 777 deletions(-)
--- a/base_model_trainer.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/base_model_trainer.py	Fri Jul 25 19:02:32 2025 +0000
@@ -1,7 +1,7 @@
 import base64
 import logging
-import os
 import tempfile
+from pathlib import Path
 
 import h5py
 import joblib
@@ -10,7 +10,14 @@
 from feature_help_modal import get_feature_metrics_help_modal
 from feature_importance import FeatureImportanceAnalyzer
 from sklearn.metrics import average_precision_score
-from utils import get_html_closing, get_html_template
+from utils import (
+    add_hr_to_html,
+    add_plot_to_html,
+    build_tabbed_html,
+    encode_image_to_base64,
+    get_html_closing,
+    get_html_template,
+)
 
 logging.basicConfig(level=logging.DEBUG)
 LOG = logging.getLogger(__name__)
@@ -27,7 +34,7 @@
         test_file=None,
         **kwargs,
     ):
-        self.exp = None  # This will be set in the subclass
+        self.exp = None
         self.input_file = input_file
         self.target_col = target_col
         self.output_dir = output_dir
@@ -39,10 +46,11 @@
         self.results = None
         self.features_name = None
         self.plots = {}
-        self.expaliner = None
+        self.explainer_plots = {}
         self.plots_explainer_html = None
         self.trees = []
-        for key, value in kwargs.items():
+        self.user_kwargs = kwargs.copy()
+        for key, value in self.user_kwargs.items():
             setattr(self, key, value)
         self.setup_params = {}
         self.test_file = test_file
@@ -57,43 +65,38 @@
         LOG.info(f"Loading data from {self.input_file}")
         self.data = pd.read_csv(self.input_file, sep=None, engine="python")
         self.data.columns = self.data.columns.str.replace(".", "_")
-
-        # Remove prediction_label if present
         if "prediction_label" in self.data.columns:
             self.data = self.data.drop(columns=["prediction_label"])
 
         numeric_cols = self.data.select_dtypes(include=["number"]).columns
         non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
-
         self.data[numeric_cols] = self.data[numeric_cols].apply(
             pd.to_numeric, errors="coerce"
         )
-
         if len(non_numeric_cols) > 0:
             LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
 
         names = self.data.columns.to_list()
         target_index = int(self.target_col) - 1
         self.target = names[target_index]
-        self.features_name = [name for i, name in enumerate(names) if i != target_index]
-        if hasattr(self, "missing_value_strategy"):
-            if self.missing_value_strategy == "mean":
+        self.features_name = [n for i, n in enumerate(names) if i != target_index]
+
+        if getattr(self, "missing_value_strategy", None):
+            strat = self.missing_value_strategy
+            if strat == "mean":
                 self.data = self.data.fillna(self.data.mean(numeric_only=True))
-            elif self.missing_value_strategy == "median":
+            elif strat == "median":
                 self.data = self.data.fillna(self.data.median(numeric_only=True))
-            elif self.missing_value_strategy == "drop":
+            elif strat == "drop":
                 self.data = self.data.dropna()
         else:
-            # Default strategy if not specified
             self.data = self.data.fillna(self.data.median(numeric_only=True))
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
-            self.test_data = pd.read_csv(self.test_file, sep=None, engine="python")
-            self.test_data = self.test_data[numeric_cols].apply(
-                pd.to_numeric, errors="coerce"
-            )
-            self.test_data.columns = self.test_data.columns.str.replace(".", "_")
+            df_test = pd.read_csv(self.test_file, sep=None, engine="python")
+            df_test.columns = df_test.columns.str.replace(".", "_")
+            self.test_data = df_test
 
     def setup_pycaret(self):
         LOG.info("Initializing PyCaret")
@@ -105,59 +108,26 @@
             "system_log": False,
             "index": False,
         }
-
         if self.test_data is not None:
             self.setup_params["test_data"] = self.test_data
-
-        if (
-            hasattr(self, "train_size")
-            and self.train_size is not None
-            and self.test_data is None
-        ):
-            self.setup_params["train_size"] = self.train_size
-
-        if hasattr(self, "normalize") and self.normalize is not None:
-            self.setup_params["normalize"] = self.normalize
-
-        if hasattr(self, "feature_selection") and self.feature_selection is not None:
-            self.setup_params["feature_selection"] = self.feature_selection
-
-        if (
-            hasattr(self, "cross_validation")
-            and self.cross_validation is not None
-            and self.cross_validation is False
-        ):
-            logging.info(
-                "cross_validation is set to False. This will disable cross-validation."
-            )
-
-        if hasattr(self, "cross_validation") and self.cross_validation:
-            if hasattr(self, "cross_validation_folds"):
-                self.setup_params["fold"] = self.cross_validation_folds
-
-        if hasattr(self, "remove_outliers") and self.remove_outliers is not None:
-            self.setup_params["remove_outliers"] = self.remove_outliers
-
-        if (
-            hasattr(self, "remove_multicollinearity")
-            and self.remove_multicollinearity is not None
-        ):
-            self.setup_params["remove_multicollinearity"] = (
-                self.remove_multicollinearity
-            )
-
-        if (
-            hasattr(self, "polynomial_features")
-            and self.polynomial_features is not None
-        ):
-            self.setup_params["polynomial_features"] = self.polynomial_features
-
-        if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None:
-            self.setup_params["fix_imbalance"] = self.fix_imbalance
-
+        for attr in [
+            "train_size",
+            "normalize",
+            "feature_selection",
+            "remove_outliers",
+            "remove_multicollinearity",
+            "polynomial_features",
+            "feature_interaction",
+            "feature_ratio",
+            "fix_imbalance",
+        ]:
+            val = getattr(self, attr, None)
+            if val is not None:
+                self.setup_params[attr] = val
+        if getattr(self, "cross_validation_folds", None) is not None:
+            self.setup_params["fold"] = self.cross_validation_folds
         LOG.info(self.setup_params)
 
-        # Solution: instantiate the correct PyCaret experiment based on task_type
         if self.task_type == "classification":
             from pycaret.classification import ClassificationExperiment
 
@@ -170,246 +140,371 @@
             raise ValueError("task_type must be 'classification' or 'regression'")
 
         self.exp.setup(self.data, **self.setup_params)
+        self.setup_params.update(self.user_kwargs)
 
     def train_model(self):
         LOG.info("Training and selecting the best model")
         if self.task_type == "classification":
-            average_displayed = "Weighted"
             self.exp.add_metric(
-                id=f"PR-AUC-{average_displayed}",
-                name=f"PR-AUC-{average_displayed}",
+                id="PR-AUC-Weighted",
+                name="PR-AUC-Weighted",
                 target="pred_proba",
                 score_func=average_precision_score,
                 average="weighted",
             )
+        # Build arguments for compare_models()
+        compare_kwargs = {}
+        if getattr(self, "models", None):
+            compare_kwargs["include"] = self.models
 
-        if hasattr(self, "models") and self.models is not None:
-            self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation)
-        else:
-            self.best_model = self.exp.compare_models(cross_validation=self.cross_validation)
+        # Respect explicit cross-validation flag
+        if getattr(self, "cross_validation", None) is not None:
+            compare_kwargs["cross_validation"] = self.cross_validation
+
+        # Respect explicit fold count
+        if getattr(self, "cross_validation_folds", None) is not None:
+            compare_kwargs["fold"] = self.cross_validation_folds
+
+        LOG.info(f"compare_models kwargs: {compare_kwargs}")
+        self.best_model = self.exp.compare_models(**compare_kwargs)
         self.results = self.exp.pull()
+        if getattr(self, "tune_model", False):
+            LOG.info("Tuning hyperparameters of the best model")
+            self.best_model = self.exp.tune_model(self.best_model)
+            self.results = self.exp.pull()
 
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
-
         _ = self.exp.predict_model(self.best_model)
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
             self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
     def save_model(self):
-        hdf5_model_path = "pycaret_model.h5"
-        with h5py.File(hdf5_model_path, "w") as f:
-            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-                joblib.dump(self.best_model, temp_file.name)
-                temp_file.seek(0)
-                model_bytes = temp_file.read()
+        hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
+        with h5py.File(hdf5_path, "w") as f:
+            with tempfile.NamedTemporaryFile(delete=False) as tmp:
+                joblib.dump(self.best_model, tmp.name)
+                tmp.seek(0)
+                model_bytes = tmp.read()
             f.create_dataset("model", data=np.void(model_bytes))
 
     def generate_plots(self):
-        raise NotImplementedError("Subclasses should implement this method")
+        LOG.info("Generating PyCaret diagnostic pltos")
 
-    def encode_image_to_base64(self, img_path):
+        # choose the right plots based on task
+        if self.task_type == "classification":
+            plot_names = [
+                "learning",
+                "vc",
+                "calibration",
+                "dimension",
+                "manifold",
+                "rfe",
+                "threshold",
+                "percentage_above_below",
+                "class_report",
+                "pr_auc",
+                "roc_auc",
+            ]
+        else:
+            plot_names = ["residuals", "vc", "parameter", "error", "learning"]
+        for name in plot_names:
+            try:
+                ax = self.exp.plot_model(self.best_model, plot=name, save=False)
+                out_path = Path(self.output_dir) / f"plot_{name}.png"
+                fig = ax.get_figure()
+                fig.savefig(out_path, bbox_inches="tight")
+                self.plots[name] = str(out_path)
+            except Exception as e:
+                LOG.warning(f"Could not generate {name} plot: {e}")
+
+    def encode_image_to_base64(self, img_path: str) -> str:
         with open(img_path, "rb") as img_file:
             return base64.b64encode(img_file.read()).decode("utf-8")
 
     def save_html_report(self):
         LOG.info("Saving HTML report")
 
-        if not self.output_dir:
-            raise ValueError("output_dir must be specified and not None")
+        # 1) Determine best model name
+        try:
+            best_model_name = str(self.results.iloc[0]["Model"])
+        except Exception:
+            best_model_name = type(self.best_model).__name__
+        LOG.info(f"Best model determined as: {best_model_name}")
+
+        # 2) Compute training sample count
+        try:
+            n_train = self.exp.X_train.shape[0]
+        except Exception:
+            n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
+        total_rows = self.data.shape[0]
 
-        model_name = type(self.best_model).__name__
-        excluded_params = ["html", "log_experiment", "system_log", "test_data"]
-        filtered_setup_params = {
-            k: v for k, v in self.setup_params.items() if k not in excluded_params
+        # 3) Build setup parameters table
+        all_params = self.setup_params
+        display_keys = [
+            "Target",
+            "Session ID",
+            "Train Size",
+            "Normalize",
+            "Feature Selection",
+            "Cross Validation",
+            "Cross Validation Folds",
+            "Remove Outliers",
+            "Remove Multicollinearity",
+            "Polynomial Features",
+            "Fix Imbalance",
+            "Models",
+        ]
+        setup_rows = []
+        for key in display_keys:
+            pk = key.lower().replace(" ", "_")
+            v = all_params.get(pk)
+            if key == "Train Size":
+                frac = (
+                    float(v)
+                    if v is not None
+                    else (n_train / total_rows if total_rows else 0)
+                )
+                dv = f"{frac:.2f} ({n_train} rows)"
+            elif key in {
+                "Normalize",
+                "Feature Selection",
+                "Cross Validation",
+                "Remove Outliers",
+                "Remove Multicollinearity",
+                "Polynomial Features",
+                "Fix Imbalance",
+            }:
+                dv = bool(v)
+            elif key == "Cross Validation Folds":
+                dv = v if v is not None else "None"
+            elif key == "Models":
+                dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+            else:
+                dv = v if v is not None else "None"
+            setup_rows.append([key, dv])
+        if hasattr(self.exp, "_fold_metric"):
+            setup_rows.append(["best_model_metric", self.exp._fold_metric])
+
+        df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
+        df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False)
+
+        # 4) Persist CSVs
+        self.results.to_csv(
+            Path(self.output_dir) / "comparison_results.csv", index=False
+        )
+        self.test_result_df.to_csv(
+            Path(self.output_dir) / "test_results.csv", index=False
+        )
+        pd.DataFrame(
+            self.best_model.get_params().items(), columns=["Parameter", "Value"]
+        ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)
+
+        # 5) Header
+        header = f"<h2>Best Model: {best_model_name}</h2>"
+
+        # — Validation Summary & Configuration —
+        val_df = self.results.copy()
+        # mapping raw plot keys to user-friendly titles
+        plot_title_map = {
+            "learning": "Learning Curve",
+            "vc": "Validation Curve",
+            "calibration": "Calibration Curve",
+            "dimension": "Dimensionality Reduction",
+            "manifold": "Manifold Learning",
+            "rfe": "Recursive Feature Elimination",
+            "threshold": "Threshold Plot",
+            "percentage_above_below": "Percentage Above vs. Below Cutoff",
+            "class_report": "Classification Report",
+            "pr_auc": "Precision-Recall AUC",
+            "roc_auc": "Receiver Operating Characteristic AUC",
+            "residuals": "Residuals Distribution",
+            "error": "Prediction Error Distribution",
         }
-        setup_params_table = pd.DataFrame(
-            list(filtered_setup_params.items()), columns=["Parameter", "Value"]
+        val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True)
+        summary_html = (
+            header
+            + "<h2>Train & Validation Summary</h2>"
+            + '<div class="table-wrapper">'
+            + val_df.to_html(index=False, classes="table sortable")
+            + "</div>"
+            + "<h2>Setup Parameters</h2>"
+            + '<div class="table-wrapper">'
+            + df_setup.to_html(index=False, classes="table sortable")
+            + "</div>"
+            # — Hyperparameters
+            + "<h2>Best Model Hyperparameters</h2>"
+            + '<div class="table-wrapper">'
+            + pd.DataFrame(
+                self.best_model.get_params().items(), columns=["Parameter", "Value"]
+            ).to_html(index=False, classes="table sortable")
+            + "</div>"
         )
 
-        best_model_params = pd.DataFrame(
-            self.best_model.get_params().items(), columns=["Parameter", "Value"]
-        )
-        best_model_params.to_csv(
-            os.path.join(self.output_dir, "best_model.csv"), index=False
-        )
-        self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
-        self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv"))
+        # choose summary plots based on task type
+        if self.task_type == "classification":
+            summary_plots = [
+                "learning",
+                "vc",
+                "calibration",
+                "dimension",
+                "manifold",
+                "rfe",
+                "threshold",
+                "percentage_above_below",
+            ]
+        else:
+            summary_plots = ["learning", "vc", "parameter", "residuals"]
 
-        plots_html = ""
-        length = len(self.plots)
-        for i, (plot_name, plot_path) in enumerate(self.plots.items()):
-            encoded_image = self.encode_image_to_base64(plot_path)
-            plots_html += (
-                f'<div class="plot">'
-                f"<h3>{plot_name.capitalize()}</h3>"
-                f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">'
-                f"</div>"
-            )
-            if i < length - 1:
-                plots_html += "<hr>"
-
-        tree_plots = ""
-        for i, tree in enumerate(self.trees):
-            if tree:
-                tree_plots += (
-                    f'<div class="plot">'
-                    f"<h3>Tree {i + 1}</h3>"
-                    f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">'
-                    f"</div>"
+        for name in summary_plots:
+            if name in self.plots:
+                summary_html += "<hr>"
+                b64 = encode_image_to_base64(self.plots[name])
+                title = plot_title_map.get(name, name.replace("_", " ").title())
+                summary_html += (
+                    '<div class="plot">'
+                    f"<h2>{title}</h2>"
+                    f'<img src="data:image/png;base64,{b64}" '
+                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>'
+                    "</div>"
                 )
 
-        analyzer = FeatureImportanceAnalyzer(
+        # — Test Summary —
+        test_html = (
+            header
+            + '<div class="table-wrapper">'
+            + self.test_result_df.to_html(index=False, classes="table sortable")
+            + "</div>"
+        )
+        if self.task_type == "regression":
+            try:
+                y_true = (
+                    pd.Series(self.exp.y_test_transformed)
+                    .reset_index(drop=True)
+                    .rename("True")
+                )
+                y_pred = pd.Series(
+                    self.best_model.predict(self.exp.X_test_transformed)
+                ).rename("Predicted")
+                df_tp = pd.concat([y_true, y_pred], axis=1)
+                test_html += "<h2>True vs Predicted Values</h2>"
+                test_html += (
+                    '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">'
+                    + df_tp.head(50).to_html(index=False, classes="table sortable")
+                    + "</div>"
+                    + add_hr_to_html()
+                )
+            except Exception as e:
+                LOG.warning(f"Could not generate True vs Predicted table: {e}")
+
+        # 5a) Explainer-substituted plots in order
+        if self.task_type == "regression":
+            test_order = ["residuals"]
+        else:
+            test_order = [
+                "confusion_matrix",
+                "roc_auc",
+                "pr_auc",
+                "lift_curve",
+                "threshold",
+                "cumulative_precision",
+            ]
+        for key in test_order:
+            fig_or_fn = self.explainer_plots.pop(key, None)
+            if fig_or_fn is not None:
+                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                title = plot_title_map.get(key, key.replace("_", " ").title())
+                test_html += (
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                )
+        # 5b) Remaining PyCaret test plots
+        for name, path in self.plots.items():
+            # classification: include only the small extras, before skipping anything
+            if self.task_type == "classification" and name in {
+                "threshold",
+                "pr_auc",
+                "class_report",
+            }:
+                title = plot_title_map.get(name, name.replace("_", " ").title())
+                b64 = encode_image_to_base64(path)
+                test_html += (
+                    f"<h2>{title}</h2>"
+                    "<div class='plot'>"
+                    f"<img src='data:image/png;base64,{b64}' "
+                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "</div>" + add_hr_to_html()
+                )
+                continue
+
+            # regression: explicitly include the 'error' plot, before skipping
+            if self.task_type == "regression" and name == "error":
+                title = plot_title_map.get("error", "Prediction Error Distribution")
+                b64 = encode_image_to_base64(path)
+                test_html += (
+                    f"<h2>{title}</h2>"
+                    "<div class='plot'>"
+                    f"<img src='data:image/png;base64,{b64}' "
+                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "</div>" + add_hr_to_html()
+                )
+                continue
+
+            # now skip any plots already rendered via test_order
+            if name in test_order:
+                continue
+
+        # — Feature Importance —
+        feature_html = header
+
+        # 6a) PyCaret’s default feature importances
+        feature_html += FeatureImportanceAnalyzer(
             data=self.data,
             target_col=self.target_col,
             task_type=self.task_type,
             output_dir=self.output_dir,
             exp=self.exp,
             best_model=self.best_model,
-        )
-        feature_importance_html = analyzer.run()
+        ).run()
 
-        # --- Feature Metrics Help Button ---
-        feature_metrics_button_html = (
-            '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">'
-            "Help: Metrics Guide"
-            "</button>"
-            "<style>"
-            ".help-modal-btn {"
-            "background-color: #17623b;"
-            "color: #fff;"
-            "border: none;"
-            "border-radius: 24px;"
-            "padding: 10px 28px;"
-            "font-size: 1.1rem;"
-            "font-weight: bold;"
-            "letter-spacing: 0.03em;"
-            "cursor: pointer;"
-            "transition: background 0.2s, box-shadow 0.2s;"
-            "box-shadow: 0 2px 8px rgba(23,98,59,0.07);"
-            "}"
-            ".help-modal-btn:hover, .help-modal-btn:focus {"
-            "background-color: #21895e;"
-            "outline: none;"
-            "box-shadow: 0 4px 16px rgba(23,98,59,0.14);"
-            "}"
-            "</style>"
-        )
+        # 6b) Explainer SHAP importances
+        for key in ["shap_mean", "shap_perm"]:
+            fig_or_fn = self.explainer_plots.pop(key, None)
+            if fig_or_fn is not None:
+                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                # give SHAP plots explicit titles
+                title = (
+                    "Mean Absolute SHAP Value Impact"
+                    if key == "shap_mean"
+                    else "Permutation Feature Importance"
+                )
+                feature_html += (
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                )
 
-        html_content = (
-            f"{get_html_template()}"
-            "<h1>Tabular Learner Model Report</h1>"
-            f"{feature_metrics_button_html}"
-            '<div class="tabs">'
-            '<div class="tab" onclick="openTab(event, \'summary\')">'
-            "Validation Result Summary & Config</div>"
-            '<div class="tab" onclick="openTab(event, \'plots\')">'
-            "Test Results</div>"
-            '<div class="tab" onclick="openTab(event, \'feature\')">'
-            "Feature Importance</div>"
-        )
-        if self.plots_explainer_html:
-            html_content += (
-                '<div class="tab" onclick="openTab(event, \'explainer\')">'
-                "Explainer Plots</div>"
+        # 6c) PDPs last
+        pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__"))
+        for k in pdp_keys:
+            fig_or_fn = self.explainer_plots[k]
+            fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+            # extract feature name
+            feature = k.split("__", 1)[1]
+            title = f"Partial Dependence for {feature}"
+            feature_html += (
+                f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
             )
-        html_content += (
-            "</div>"
-            '<div id="summary" class="tab-content">'
-            f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>"
-            f"<h2>Best Model: {model_name}</h2>"
-            "<h5>The best model is selected by: Accuracy (Classification)"
-            " or R2 (Regression).</h5>"
-            f"{self.results.to_html(index=False, classes='table sortable')}"
-            "<h2>Best Model's Hyperparameters</h2>"
-            f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}"
-            "<h2>Setup Parameters</h2>"
-            f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}"
-            "<h5>If you want to know all the experiment setup parameters,"
-            " please check the PyCaret documentation for"
-            " the classification/regression <code>exp</code> function.</h5>"
-            "</div>"
-            '<div id="plots" class="tab-content">'
-            f"<h2>Best Model: {model_name}</h2>"
-            "<h5>The best model is selected by: Accuracy (Classification)"
-            " or R2 (Regression).</h5>"
-            "<h2>Test Metrics</h2>"
-            f"{self.test_result_df.to_html(index=False)}"
-            "<h2>Test Results</h2>"
-            f"{plots_html}"
-            "</div>"
-            '<div id="feature" class="tab-content">'
-            f"{feature_importance_html}"
-            "</div>"
+        # 7) Assemble final HTML (three tabs)
+        html = get_html_template()
+        html += "<h1>Tabular Learner Model Report</h1>"
+        html += build_tabbed_html(summary_html, test_html, feature_html)
+        html += get_feature_metrics_help_modal()
+        html += get_html_closing()
+
+        # 8) Write out
+        (Path(self.output_dir) / "comparison_result.html").write_text(
+            html, encoding="utf-8"
         )
-        if self.plots_explainer_html:
-            html_content += (
-                '<div id="explainer" class="tab-content">'
-                f"{self.plots_explainer_html}"
-                f"{tree_plots}"
-                "</div>"
-            )
-        html_content += (
-            "<script>"
-            "document.addEventListener(\"DOMContentLoaded\", function() {"
-            "var tables = document.querySelectorAll(\"table.sortable\");"
-            "tables.forEach(function(table) {"
-            "var headers = table.querySelectorAll(\"th\");"
-            "headers.forEach(function(header, index) {"
-            "header.style.cursor = \"pointer\";"
-            "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)"
-            "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';"
-            "header.addEventListener(\"click\", function() {"
-            "var direction = this.getAttribute("
-            "\"data-sort-direction\""
-            ") || \"asc\";"
-            "// Reset arrows in all headers of this table"
-            "headers.forEach(function(h) {"
-            "var arrow = h.querySelector(\".sort-arrow\");"
-            "if (arrow) arrow.textContent = \" ↑\";"
-            "});"
-            "// Set arrow for clicked header"
-            "var arrow = this.querySelector(\".sort-arrow\");"
-            "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";"
-            "sortTable(table, index, direction);"
-            "this.setAttribute(\"data-sort-direction\","
-            "direction === \"asc\" ? \"desc\" : \"asc\");"
-            "});"
-            "});"
-            "});"
-            "});"
-            "function sortTable(table, colNum, direction) {"
-            "var tb = table.tBodies[0];"
-            "var tr = Array.prototype.slice.call(tb.rows, 0);"
-            "var multiplier = direction === \"asc\" ? 1 : -1;"
-            "tr = tr.sort(function(a, b) {"
-            "var aText = a.cells[colNum].textContent.trim();"
-            "var bText = b.cells[colNum].textContent.trim();"
-            "// Remove arrow from text comparison"
-            "aText = aText.replace(/[↑↓]/g, '').trim();"
-            "bText = bText.replace(/[↑↓]/g, '').trim();"
-            "if (!isNaN(aText) && !isNaN(bText)) {"
-            "return multiplier * ("
-            "parseFloat(aText) - parseFloat(bText)"
-            ");"
-            "} else {"
-            "return multiplier * aText.localeCompare(bText);"
-            "}"
-            "});"
-            "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);"
-            "}"
-            "</script>"
-        )
-        # --- Add the Feature Metrics Help Modal ---
-        html_content += get_feature_metrics_help_modal()
-        html_content += f"{get_html_closing()}"
-        with open(
-            os.path.join(self.output_dir, "comparison_result.html"),
-            "w",
-            encoding="utf-8",
-        ) as file:
-            file.write(html_content)
+        LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html")
 
     def save_dashboard(self):
         raise NotImplementedError("Subclasses should implement this method")
@@ -426,29 +521,18 @@
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
-        is_rf = isinstance(
-            self.best_model, (RandomForestClassifier, RandomForestRegressor)
-        )
-        is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))
-
-        num_trees = None
-        if is_rf:
-            num_trees = self.best_model.n_estimators
-        elif is_xgb:
-            num_trees = len(self.best_model.get_booster().get_dump())
+        if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)):
+            n_trees = self.best_model.n_estimators
+        elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
+            n_trees = len(self.best_model.get_booster().get_dump())
         else:
             LOG.warning("Tree plots not supported for this model type.")
             return
 
-        try:
-            explainer = RandomForestExplainer(self.best_model, X_test, y_test)
-            for i in range(num_trees):
-                fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
-                LOG.info(f"Tree {i + 1}")
-                LOG.info(fig)
-                self.trees.append(fig)
-        except Exception as e:
-            LOG.error(f"Error generating tree plots: {e}")
+        explainer = RandomForestExplainer(self.best_model, X_test, y_test)
+        for i in range(n_trees):
+            fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
+            self.trees.append(fig)
 
     def run(self):
         self.load_data()
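
Note: a minimal standalone sketch of the optional-argument forwarding pattern that the reworked setup_pycaret()/compare_models() code above relies on (illustrative only, not part of the changeset; the Trainer class and attribute names are hypothetical stand-ins). Only options the user explicitly set are forwarded, while None leaves PyCaret's defaults untouched.

    # Illustrative only: mirrors the getattr(...)/None-check loop in setup_pycaret()
    class Trainer:
        def __init__(self, **kwargs):
            for key, value in kwargs.items():
                setattr(self, key, value)

        def build_setup_params(self):
            params = {"target": "label", "html": False}
            for attr in ("train_size", "normalize", "feature_selection", "fix_imbalance"):
                val = getattr(self, attr, None)
                if val is not None:  # None means "fall back to PyCaret's default"
                    params[attr] = val
            return params

    print(Trainer(normalize=True).build_setup_params())
    # {'target': 'label', 'html': False, 'normalize': True}
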
--- a/feature_help_modal.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/feature_help_modal.py	Fri Jul 25 19:02:32 2025 +0000
@@ -1,104 +1,146 @@
 def get_feature_metrics_help_modal() -> str:
     modal_html = """
-<div id="featureMetricsHelpModal" class="modal">
+<div id="metricsHelpModal" class="modal">
   <div class="modal-content">
-    <span class="close-feature-metrics">&times;</span>
+    <span class="close">&times;</span>
     <h2>Help Guide: Common Model Metrics</h2>
-    <div class="metrics-guide" style="max-height:65vh;overflow-y:auto;font-size:1.04em;">
-      <h3>1) General Metrics</h3>
-      <h4>Classification</h4>
-      <p><strong>Accuracy:</strong> The proportion of correct predictions among all predictions. It is calculated as (TP + TN) / (TP + TN + FP + FN). While intuitive, Accuracy can be misleading for imbalanced datasets where one class dominates. For example, in a dataset with 95% negative cases, a model predicting all negatives achieves 95% Accuracy but fails to identify positives.</p>
-      <p><strong>AUC (Area Under the Curve):</strong> Specifically, the Area Under the Receiver Operating Characteristic Curve (ROC-AUC) measures a model’s ability to distinguish between classes. It ranges from 0 to 1, where 1 indicates perfect separation and 0.5 suggests random guessing. ROC-AUC is robust for binary and multiclass problems but may be less informative for highly imbalanced datasets.</p>
-      <h4>Regression</h4>
-      <p><strong>R2 (Coefficient of Determination):</strong> Measures the proportion of variance in the dependent variable explained by the independent variables. It ranges from 0 to 1, with 1 indicating perfect prediction and 0 indicating no explanatory power. Negative values are possible if the model performs worse than a mean-based baseline. R2 is widely used but sensitive to outliers.</p>
-      <p><strong>RMSE (Root Mean Squared Error):</strong> The square root of the average squared differences between predicted and actual values. It penalizes larger errors more heavily and is expressed in the same units as the target variable, making it interpretable. Lower RMSE indicates better model performance.</p>
-      <p><strong>MAE (Mean Absolute Error):</strong> The average of absolute differences between predicted and actual values. It is less sensitive to outliers than RMSE and provides a straightforward measure of average error magnitude. Lower MAE is better.</p>
+    <div class="metrics-guide">
+
+      <!-- Classification Metrics -->
+      <h3>1) Classification Metrics</h3>
+
+      <p><strong>Accuracy:</strong>
+      The proportion of correct predictions over all predictions:<br>
+      <code>(TP + TN) / (TP + TN + FP + FN)</code>.
+      <em>Use when</em> classes are balanced and you want a single easy‐to‐interpret number.</p>
+
+      <p><strong>Precision:</strong>
+      The fraction of positive predictions that are actually positive:<br>
+      <code>TP / (TP + FP)</code>.
+      <em>Use when</em> false positives are costly (e.g. spam filter—better to miss some spam than flag good mail).</p>
 
-      <h3>2) Precision, Recall & Specificity</h3>
-      <h4>Classification</h4>
-      <p><strong>Precision:</strong> The proportion of positive predictions that are correct, calculated as TP / (TP + FP). High Precision is crucial when false positives are costly, such as in spam email detection, where misclassifying legitimate emails as spam disrupts user experience.</p>
-      <p><strong>Recall (Sensitivity):</strong> The proportion of actual positives correctly predicted, calculated as TP / (TP + FN). High Recall is vital when missing positives is risky, such as in disease diagnosis, where failing to identify a sick patient could have severe consequences.</p>
-      <p><strong>Specificity:</strong> The true negative rate, calculated as TN / (TN + FP). It measures how well a model identifies negatives, making it valuable in medical testing to minimize false alarms (e.g., incorrectly diagnosing healthy patients as sick).</p>
+      <p><strong>Recall (Sensitivity):</strong>
+      The fraction of actual positives correctly identified:<br>
+      <code>TP / (TP + FN)</code>
+      <em>Use when</em> false negatives are costly (e.g. disease screening—don’t miss sick patients).</p>
+
+      <p><strong>F1 Score:</strong>
+      The harmonic mean of Precision and Recall:<br>
+      <code>2·(Precision·Recall)/(Precision+Recall)</code>
+      <em>Use when</em> you need a balance between Precision & Recall on an imbalanced dataset.</p>
+
+      <p><strong>ROC-AUC (Area Under ROC Curve):</strong>
+      Measures ability to distinguish classes across all thresholds.
+      Ranges from 0.5 (random) to 1 (perfect).
+      <em>Use when</em> you care about ranking positives above negatives.</p>
+
+      <p><strong>PR-AUC (Area Under Precision-Recall Curve):</strong>
+      Summarizes Precision vs. Recall trade-off.
+      More informative than ROC-AUC when positives are rare.
+      <em>Use when</em> dealing with highly imbalanced data.</p>
 
-      <h3>3) Macro, Micro, and Weighted Averages</h3>
-      <h4>Classification</h4>
-      <p><strong>Macro Precision / Recall / F1:</strong> Computes the metric for each class independently and averages them, treating all classes equally. This is ideal for balanced datasets or when all classes are equally important, such as in multiclass image classification with similar class frequencies.</p>
-      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives (TP), false positives (FP), and false negatives (FN) across all classes before computing the metric. It provides a global perspective and is suitable for imbalanced datasets or multilabel problems, as it accounts for class frequency.</p>
-      <p><strong>Weighted Precision / Recall / F1:</strong> Averages the metric across classes, weighted by the number of true instances per class. This balances the importance of classes based on their frequency, making it useful for imbalanced datasets where larger classes should have more influence but smaller classes are still considered.</p>
+      <p><strong>Log Loss:</strong>
+      Penalizes confident wrong predictions via negative log-likelihood.
+      Lower is better.
+      <em>Use when</em> you need well-calibrated probability estimates.</p>
+
+      <p><strong>Cohen’s Kappa:</strong>
+      Measures agreement between predictions and true labels accounting for chance.
+      1 is perfect, 0 is random.
+      <em>Use when</em> you want to factor out chance agreement.</p>
 
-      <h3>4) Average Precision (PR-AUC Variants)</h3>
-      <h4>Classification</h4>
-      <p><strong>Average Precision:</strong> The Area Under the Precision-Recall Curve (PR-AUC) summarizes the trade-off between Precision and Recall. It is particularly useful for imbalanced datasets, where ROC-AUC may overestimate performance. Average Precision is computed by averaging Precision values at different Recall thresholds, providing a robust measure for ranking tasks or rare class detection.</p>
+      <hr>
+
+      <!-- Regression Metrics -->
+      <h3>2) Regression Metrics</h3>
+
+      <p><strong>R² (Coefficient of Determination):</strong>
+      Proportion of variance in the target explained by features:<br>
+      1 is perfect, 0 means no better than predicting the mean, negative is worse than mean.
+      <em>Use when</em> you want a normalized measure of fit.</p>
 
-      <h3>5) ROC-AUC Variants</h3>
-      <h4>Classification</h4>
-      <p><strong>ROC-AUC:</strong> The Area Under the Receiver Operating Characteristic Curve plots the true positive rate (Recall) against the false positive rate (1 - Specificity) at various thresholds. It quantifies the model’s ability to separate classes, with higher values indicating better performance.</p>
-      <p><strong>Macro ROC-AUC:</strong> Averages the ROC-AUC scores across all classes, treating each class equally. This is suitable for balanced multiclass problems where all classes are of equal importance.</p>
-      <p><strong>Micro ROC-AUC:</strong> Computes a single ROC-AUC by aggregating predictions and true labels across all classes. It is effective for multiclass or multilabel problems with class imbalance, as it accounts for the overall prediction distribution.</p>
+      <p><strong>MAE (Mean Absolute Error):</strong>
+      Average absolute difference between predictions and actual values:<br>
+      <code>mean(|y_pred − y_true|)</code>
+      <em>Use when</em> you need an interpretable “average” error and want to downweight outliers.</p>
+
+      <p><strong>RMSE (Root Mean Squared Error):</strong>
+      Square root of the average squared errors:<br>
+      <code>√mean((y_pred − y_true)²)</code>.
+      Penalizes large errors more heavily.
+      <em>Use when</em> large deviations are especially undesirable.</p>
 
-      <h3>6) Confusion Matrix Stats (Per Class)</h3>
-      <h4>Classification</h4>
-      <p><strong>True Positives (TP):</strong> The number of correct positive predictions for a given class.</p>
-      <p><strong>True Negatives (TN):</strong> The number of correct negative predictions for a given class.</p>
-      <p><strong>False Positives (FP):</strong> The number of incorrect positive predictions for a given class (false alarms).</p>
-      <p><strong>False Negatives (FN):</strong> The number of incorrect negative predictions for a given class (missed detections). These stats are visualized in PyCaret’s confusion matrix plots, aiding class-wise performance analysis.</p>
+      <p><strong>MSE (Mean Squared Error):</strong>
+      The average squared error:<br>
+      <code>mean((y_pred − y_true)²)</code>.
+      Similar to RMSE but in squared units; often used in optimization.</p>
 
-      <h3>7) Other Useful Metrics</h3>
-      <h4>Classification</h4>
-      <p><strong>Cohen’s Kappa:</strong> Measures the agreement between predicted and actual labels, adjusted for chance. It ranges from -1 to 1, where 1 indicates perfect agreement, 0 indicates chance-level agreement, and negative values suggest worse-than-chance performance. Kappa is useful for multiclass problems with imbalanced labels.</p>
-      <p><strong>Matthews Correlation Coefficient (MCC):</strong> A balanced measure that considers TP, TN, FP, and FN, calculated as (TP * TN - FP * FN) / sqrt((TP + FP)(TP + FN)(TN + FP)(TN + FN)). It ranges from -1 to 1, with 1 being perfect prediction. MCC is particularly effective for imbalanced datasets due to its symmetry across classes.</p>
-      <h4>Regression</h4>
-      <p><strong>MSE (Mean Squared Error):</strong> The average of squared differences between predicted and actual values. It amplifies larger errors, making it sensitive to outliers. Lower MSE indicates better performance.</p>
-      <p><strong>MAPE (Mean Absolute Percentage Error):</strong> The average of absolute percentage differences between predicted and actual values, calculated as (1/n) * Σ(|actual - predicted| / |actual|) * 100. It is useful when relative errors are important but can be unstable if actual values are near zero.</p>
+      <p><strong>RMSLE (Root Mean Squared Log Error):</strong>
+      <code>√mean((log(1+y_pred) − log(1+y_true))²)</code>.
+      Less sensitive to large differences when both true and predicted are large.
+      <em>Use when</em> target spans several orders of magnitude.</p>
+
+      <p><strong>MAPE (Mean Absolute Percentage Error):</strong>
+      <code>mean(|(y_true − y_pred)/y_true|)·100</code>.
+      Expresses error as a percentage.
+      <em>Use when</em> relative error matters—but avoid if y_true≈0.</p>
+
     </div>
   </div>
 </div>
 """
+
     modal_css = """
 <style>
-/* Modal Background & Content */
-#featureMetricsHelpModal.modal {
+.modal {
   display: none;
   position: fixed;
-  z-index: 9999;
-  left: 0; top: 0;
-  width: 100%; height: 100%;
+  z-index: 1;
+  left: 0;
+  top: 0;
+  width: 100%;
+  height: 100%;
   overflow: auto;
-  background-color: rgba(0,0,0,0.45);
+  background-color: rgba(0,0,0,0.4);
 }
-#featureMetricsHelpModal .modal-content {
+.modal-content {
   background-color: #fefefe;
-  margin: 5% auto;
-  padding: 24px 28px 20px 28px;
-  border: 1.5px solid #17623b;
-  width: 90%;
+  margin: 15% auto;
+  padding: 20px;
+  border: 1px solid #888;
+  width: 80%;
   max-width: 800px;
-  border-radius: 18px;
-  box-shadow: 0 8px 32px rgba(23,98,59,0.20);
 }
-#featureMetricsHelpModal .close-feature-metrics {
-  color: #17623b;
+.close {
+  color: #aaa;
   float: right;
   font-size: 28px;
   font-weight: bold;
+}
+.close:hover,
+.close:focus {
+  color: black;
+  text-decoration: none;
   cursor: pointer;
-  transition: color 0.2s;
 }
-#featureMetricsHelpModal .close-feature-metrics:hover {
-  color: #21895e;
+.metrics-guide h3 {
+  margin-top: 20px;
 }
-.metrics-guide h3 { margin-top: 20px; }
-.metrics-guide h4 { margin-top: 12px; color: #17623b; }
-.metrics-guide p { margin: 5px 0 10px 0; }
-.metrics-guide ul { margin: 10px 0 10px 24px; }
+.metrics-guide p {
+  margin: 5px 0;
+}
+.metrics-guide ul {
+  margin: 10px 0;
+  padding-left: 20px;
+}
 </style>
 """
     modal_js = """
 <script>
 document.addEventListener("DOMContentLoaded", function() {
-  var modal = document.getElementById("featureMetricsHelpModal");
-  var openBtn = document.getElementById("openFeatureMetricsHelp");
-  var span = document.getElementsByClassName("close-feature-metrics")[0];
+  var modal = document.getElementById("metricsHelpModal");
+  var openBtn = document.getElementById("openMetricsHelp");
+  var span = document.getElementsByClassName("close")[0];
   if (openBtn && modal) {
     openBtn.onclick = function() {
       modal.style.display = "block";
--- a/feature_importance.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/feature_importance.py	Fri Jul 25 19:02:32 2025 +0000
@@ -14,14 +14,15 @@
 
 class FeatureImportanceAnalyzer:
     def __init__(
-            self,
-            task_type,
-            output_dir,
-            data_path=None,
-            data=None,
-            target_col=None,
-            exp=None,
-            best_model=None):
+        self,
+        task_type,
+        output_dir,
+        data_path=None,
+        data=None,
+        target_col=None,
+        exp=None,
+        best_model=None,
+    ):
 
         self.task_type = task_type
         self.output_dir = output_dir
@@ -43,7 +44,11 @@
                 self.data.columns = self.data.columns.str.replace('.', '_')
                 self.data = self.data.fillna(self.data.median(numeric_only=True))
             self.target = self.data.columns[int(target_col) - 1]
-            self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment()
+            self.exp = (
+                ClassificationExperiment()
+                if task_type == "classification"
+                else RegressionExperiment()
+            )
 
         self.plots = {}
 
@@ -57,7 +62,7 @@
             'session_id': 123,
             'html': True,
             'log_experiment': False,
-            'system_log': False
+            'system_log': False,
         }
         self.exp.setup(self.data, **setup_params)
 
@@ -70,14 +75,16 @@
         model_type = model.__class__.__name__
         self.tree_model_name = model_type  # Store the model name for reporting
 
-        if hasattr(model, "feature_importances_"):
+        if hasattr(model, 'feature_importances_'):
             importances = model.feature_importances_
-        elif hasattr(model, "coef_"):
+        elif hasattr(model, 'coef_'):
             # For linear models, flatten coef_ and take abs (importance as magnitude)
             importances = abs(model.coef_).flatten()
         else:
             # Neither attribute exists; skip the plot
-            LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.")
+            LOG.warning(
+                f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot."
+            )
             self.tree_model_name = None  # No plot generated
             return
 
@@ -89,69 +96,77 @@
             self.tree_model_name = None
             return
 
-        feature_importances = pd.DataFrame({
-            'Feature': processed_features,
-            'Importance': importances
-        }).sort_values(by='Importance', ascending=False)
+        feature_importances = pd.DataFrame(
+            {'Feature': processed_features, 'Importance': importances}
+        ).sort_values(by='Importance', ascending=False)
         plt.figure(figsize=(10, 6))
-        plt.barh(
-            feature_importances['Feature'],
-            feature_importances['Importance'])
+        plt.barh(feature_importances['Feature'], feature_importances['Importance'])
         plt.xlabel('Importance')
         plt.title(f'Feature Importance ({model_type})')
-        plot_path = os.path.join(
-            self.output_dir,
-            'tree_importance.png')
+        plot_path = os.path.join(self.output_dir, 'tree_importance.png')
         plt.savefig(plot_path)
         plt.close()
         self.plots['tree_importance'] = plot_path
 
     def save_shap_values(self):
-        model = self.best_model or self.exp.get_config('best_model')
-        X_transformed = self.exp.get_config('X_transformed')
-        tree_classes = (
-            "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting"
-        )
-        model_class_name = model.__class__.__name__
-        self.shap_model_name = model_class_name
+
+        model = self.best_model or self.exp.get_config("best_model")
 
-        # Ensure feature alignment
-        if hasattr(model, "feature_name_"):
-            used_features = model.feature_name_
-        elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"):
+        X_data = None
+        for key in ("X_test_transformed", "X_train_transformed"):
+            try:
+                X_data = self.exp.get_config(key)
+                break
+            except KeyError:
+                continue
+        if X_data is None:
+            raise RuntimeError(
+                "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. "
+                "Make sure PyCaret setup/compare_models was run with feature_selection=True."
+            )
+
+        try:
             used_features = model.booster_.feature_name()
-        elif hasattr(model, "feature_names_in_"):
-            # scikit‐learn's standard attribute for the names of features used during fit
-            used_features = list(model.feature_names_in_)
-        else:
-            used_features = X_transformed.columns
+        except Exception:
+            used_features = getattr(model, "feature_names_in_", X_data.columns.tolist())
+        X_data = X_data[used_features]
+
+        max_bg = min(len(X_data), 100)
+        bg = X_data.sample(max_bg, random_state=42)
+
+        predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict
 
-        if any(tc in model_class_name for tc in tree_classes):
-            explainer = shap.TreeExplainer(model)
-            X_shap = X_transformed[used_features]
-            shap_values = explainer.shap_values(X_shap)
-            plot_X = X_shap
-            plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)"
+        explainer = shap.Explainer(predict_fn, bg)
+        self.shap_model_name = explainer.__class__.__name__
+
+        shap_values = explainer(X_data)
+
+        output_names = getattr(shap_values, "output_names", None)
+        if output_names is None and hasattr(model, "classes_"):
+            output_names = list(model.classes_)
+        if output_names is None:
+            n_out = shap_values.values.shape[-1]
+            output_names = list(map(str, range(n_out)))
+
+        values = shap_values.values
+        if values.ndim == 3:
+            for j, name in enumerate(output_names):
+                safe = name.replace(" ", "_").replace("/", "_")
+                out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png")
+                plt.figure()
+                shap.plots.beeswarm(shap_values[..., j], show=False)
+                plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}")
+                plt.savefig(out_path)
+                plt.close()
+                self.plots[f"shap_summary_{safe}"] = out_path
         else:
-            logging.warning(f"len(X_transformed) = {len(X_transformed)}")
-            max_samples = 100
-            n_samples = min(max_samples, len(X_transformed))
-            sampled_X = X_transformed[used_features].sample(
-                n=n_samples,
-                replace=False,
-                random_state=42
-            )
-            explainer = shap.KernelExplainer(model.predict, sampled_X)
-            shap_values = explainer.shap_values(sampled_X)
-            plot_X = sampled_X
-            plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)"
-
-        shap.summary_plot(shap_values, plot_X, show=False)
-        plt.title(plot_title)
-        plot_path = os.path.join(self.output_dir, "shap_summary.png")
-        plt.savefig(plot_path)
-        plt.close()
-        self.plots["shap_summary"] = plot_path
+            plt.figure()
+            shap.plots.beeswarm(shap_values, show=False)
+            plt.title(f"SHAP Summary for {model.__class__.__name__}")
+            out_path = os.path.join(self.output_dir, "shap_summary.png")
+            plt.savefig(out_path)
+            plt.close()
+            self.plots["shap_summary"] = out_path
 
     def generate_html_report(self):
         LOG.info("Generating HTML report")
@@ -159,11 +174,17 @@
         plots_html = ""
         for plot_name, plot_path in self.plots.items():
             # Special handling for tree importance: skip if no model name (not generated)
-            if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None):
+            if plot_name == 'tree_importance' and not getattr(
+                self, 'tree_model_name', None
+            ):
                 continue
             encoded_image = self.encode_image_to_base64(plot_path)
-            if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None):
-                section_title = f"Feature importance analysis from a trained {self.tree_model_name}"
+            if plot_name == 'tree_importance' and getattr(
+                self, 'tree_model_name', None
+            ):
+                section_title = (
+                    f"Feature importance analysis from a trained {self.tree_model_name}"
+                )
             elif plot_name == 'shap_summary':
                 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}"
             else:
@@ -176,7 +197,6 @@
             """
 
         html_content = f"""
-            <h1>PyCaret Feature Importance Report</h1>
             {plots_html}
         """
 
@@ -187,7 +207,11 @@
             return base64.b64encode(img_file.read()).decode('utf-8')
 
     def run(self):
-        if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup:
+        if (
+            self.exp is None
+            or not hasattr(self.exp, 'is_setup')
+            or not self.exp.is_setup
+        ):
             self.setup_pycaret()
         self.save_tree_importance()
         self.save_shap_values()
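
Note: the rewritten save_shap_values() above switches to the model-agnostic shap.Explainer API with a sampled background set. A minimal sketch of that flow on a toy scikit-learn model (illustrative only; the dataset, model, and column names are stand-ins, not part of the changeset):

    import pandas as pd
    import shap
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression

    X, y = make_classification(n_samples=200, n_features=5, random_state=42)
    X = pd.DataFrame(X, columns=[f"f{i}" for i in range(5)])
    model = LogisticRegression().fit(X, y)

    bg = X.sample(min(len(X), 100), random_state=42)  # background sample, as in the diff
    predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict
    explainer = shap.Explainer(predict_fn, bg)  # SHAP picks a suitable algorithm
    shap_values = explainer(X)
    shap.plots.beeswarm(shap_values[..., 1], show=False)  # one beeswarm per output class
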
--- a/pycaret_classification.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/pycaret_classification.py	Fri Jul 25 19:02:32 2025 +0000
@@ -1,23 +1,27 @@
 import logging
+import types
+from typing import Dict
 
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_classifier_explainer_dashboard
+from plotly.graph_objects import Figure
 from pycaret.classification import ClassificationExperiment
-from utils import add_hr_to_html, add_plot_to_html, predict_proba
+from utils import predict_proba
 
 LOG = logging.getLogger(__name__)
 
 
 class ClassificationModelTrainer(BaseModelTrainer):
     def __init__(
-            self,
-            input_file,
-            target_col,
-            output_dir,
-            task_type,
-            random_seed,
-            test_file=None,
-            **kwargs):
+        self,
+        input_file,
+        target_col,
+        output_dir,
+        task_type,
+        random_seed,
+        test_file=None,
+        **kwargs,
+    ):
         super().__init__(
             input_file,
             target_col,
@@ -25,191 +29,134 @@
             task_type,
             random_seed,
             test_file,
-            **kwargs)
+            **kwargs,
+        )
         self.exp = ClassificationExperiment()
 
     def save_dashboard(self):
         LOG.info("Saving explainer dashboard")
-        dashboard = generate_classifier_explainer_dashboard(self.exp,
-                                                            self.best_model)
+        dashboard = generate_classifier_explainer_dashboard(self.exp, self.best_model)
         dashboard.save_html("dashboard.html")
 
     def generate_plots(self):
         LOG.info("Generating and saving plots")
 
         if not hasattr(self.best_model, "predict_proba"):
-            import types
             self.best_model.predict_proba = types.MethodType(
-                predict_proba, self.best_model)
+                predict_proba, self.best_model
+            )
             LOG.warning(
-                f"The model {type(self.best_model).__name__}\
-                    does not support `predict_proba`. \
-                    Applying monkey patch.")
+                f"The model {type(self.best_model).__name__} does not support `predict_proba`. Applying monkey patch."
+            )
 
-        plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
-                 'error', 'class_report', 'learning', 'calibration',
-                 'vc', 'dimension', 'manifold', 'rfe', 'feature',
-                 'feature_all']
+        plots = [
+            'confusion_matrix',
+            'auc',
+            'threshold',
+            'pr',
+            'error',
+            'class_report',
+            'learning',
+            'calibration',
+            'vc',
+            'dimension',
+            'manifold',
+            'rfe',
+            'feature',
+            'feature_all',
+        ]
         for plot_name in plots:
             try:
-                if plot_name == 'auc' and not self.exp.is_multiclass:
-                    plot_path = self.exp.plot_model(self.best_model,
-                                                    plot=plot_name,
-                                                    save=True,
-                                                    plot_kwargs={
-                                                        'micro': False,
-                                                        'macro': False,
-                                                        'per_class': False,
-                                                        'binary': True
-                                                    })
+                if plot_name == "threshold":
+                    plot_path = self.exp.plot_model(
+                        self.best_model,
+                        plot=plot_name,
+                        save=True,
+                        plot_kwargs={"binary": True, "percentage": True},
+                    )
                     self.plots[plot_name] = plot_path
-                    continue
-
-                plot_path = self.exp.plot_model(self.best_model,
-                                                plot=plot_name, save=True)
-                self.plots[plot_name] = plot_path
+                elif plot_name == "auc" and not self.exp.is_multiclass:
+                    plot_path = self.exp.plot_model(
+                        self.best_model,
+                        plot=plot_name,
+                        save=True,
+                        plot_kwargs={
+                            "micro": False,
+                            "macro": False,
+                            "per_class": False,
+                            "binary": True,
+                        },
+                    )
+                    self.plots[plot_name] = plot_path
+                else:
+                    plot_path = self.exp.plot_model(
+                        self.best_model, plot=plot_name, save=True
+                    )
+                    self.plots[plot_name] = plot_path
             except Exception as e:
                 LOG.error(f"Error generating plot {plot_name}: {e}")
                 continue
 
     def generate_plots_explainer(self):
-        LOG.info("Generating and saving plots from explainer")
+        from explainerdashboard import ClassifierExplainer
 
-        from explainerdashboard import ClassifierExplainer
+        LOG.info("Generating explainer plots")
 
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
-
-        try:
-            explainer = ClassifierExplainer(self.best_model, X_test, y_test)
-            self.expaliner = explainer
-            plots_explainer_html = ""
-        except Exception as e:
-            LOG.error(f"Error creating explainer: {e}")
-            self.plots_explainer_html = None
-            return
+        explainer = ClassifierExplainer(self.best_model, X_test, y_test)
 
-        try:
-            fig_importance = explainer.plot_importances()
-            plots_explainer_html += add_plot_to_html(fig_importance)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot importance(mean shap): {e}")
-
-        try:
-            fig_importance_perm = explainer.plot_importances(
-                kind="permutation")
-            plots_explainer_html += add_plot_to_html(fig_importance_perm)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot importance(permutation): {e}")
-
-        # try:
-        #     fig_shap = explainer.plot_shap_summary()
-        #     plots_explainer_html += add_plot_to_html(fig_shap,
-        #       include_plotlyjs=False)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot shap: {e}")
+        # a dict to hold the raw Figure objects or callables
+        self.explainer_plots: Dict[str, Figure] = {}
 
-        # try:
-        #     fig_contributions = explainer.plot_contributions(
-        #       index=0)
-        #     plots_explainer_html += add_plot_to_html(
-        #       fig_contributions, include_plotlyjs=False)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot contributions: {e}")
+        # these go into the Test tab
+        for key, fn in [
+            ("roc_auc", explainer.plot_roc_auc),
+            ("pr_auc", explainer.plot_pr_auc),
+            ("lift_curve", explainer.plot_lift_curve),
+            ("confusion_matrix", explainer.plot_confusion_matrix),
+            ("threshold", explainer.plot_precision),  # Percentage vs probability
+            ("cumulative_precision", explainer.plot_cumulative_precision),
+        ]:
+            try:
+                self.explainer_plots[key] = fn()
+            except Exception as e:
+                LOG.error(f"Error generating explainer plot {key}: {e}")
 
-        # try:
-        #     for feature in self.features_name:
-        #         fig_dependence = explainer.plot_dependence(col=feature)
-        #         plots_explainer_html += add_plot_to_html(fig_dependence)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot dependencies: {e}")
-
+        # mean SHAP importances
         try:
-            for feature in self.features_name:
-                fig_pdp = explainer.plot_pdp(feature)
-                plots_explainer_html += add_plot_to_html(fig_pdp)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_mean"] = explainer.plot_importances()
         except Exception as e:
-            LOG.error(f"Error generating plot pdp: {e}")
+            LOG.warning(f"Could not generate shap_mean: {e}")
 
-        try:
-            for feature in self.features_name:
-                fig_interaction = explainer.plot_interaction(
-                    col=feature, interact_col=feature)
-                plots_explainer_html += add_plot_to_html(fig_interaction)
-        except Exception as e:
-            LOG.error(f"Error generating plot interactions: {e}")
-
+        # permutation importances
         try:
-            for feature in self.features_name:
-                fig_interactions_importance = \
-                    explainer.plot_interactions_importance(
-                        col=feature)
-                plots_explainer_html += add_plot_to_html(
-                    fig_interactions_importance)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_perm"] = lambda: explainer.plot_importances(
+                kind="permutation"
+            )
         except Exception as e:
-            LOG.error(f"Error generating plot interactions importance: {e}")
+            LOG.warning(f"Could not generate shap_perm: {e}")
 
-        # try:
-        #     for feature in self.features_name:
-        #         fig_interactions_detailed = \
-        #           explainer.plot_interactions_detailed(
-        #               col=feature)
-        #         plots_explainer_html += add_plot_to_html(
-        #           fig_interactions_detailed)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot interactions detailed: {e}")
-
-        try:
-            fig_precision = explainer.plot_precision()
-            plots_explainer_html += add_plot_to_html(fig_precision)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot precision: {e}")
-
-        try:
-            fig_cumulative_precision = explainer.plot_cumulative_precision()
-            plots_explainer_html += add_plot_to_html(fig_cumulative_precision)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot cumulative precision: {e}")
+        # PDPs for each feature (appended last)
+        valid_feats = []
+        for feat in self.features_name:
+            if feat in explainer.X.columns or feat in explainer.onehot_cols:
+                valid_feats.append(feat)
+            else:
+                LOG.warning(f"Skipping PDP for feature {feat!r}: not found in explainer data")
 
-        try:
-            fig_classification = explainer.plot_classification()
-            plots_explainer_html += add_plot_to_html(fig_classification)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot classification: {e}")
-
-        try:
-            fig_confusion_matrix = explainer.plot_confusion_matrix()
-            plots_explainer_html += add_plot_to_html(fig_confusion_matrix)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot confusion matrix: {e}")
+        for feat in valid_feats:
+            # Wrap each PDP call so AssertionErrors (or other failures) don't abort report generation
+            def make_pdp_plotter(f):
+                def _plot():
+                    try:
+                        return explainer.plot_pdp(f)
+                    except AssertionError as ae:
+                        LOG.warning(f"PDP AssertionError for {f!r}: {ae}")
+                        return None
+                    except Exception as e:
+                        LOG.error(f"Unexpected error plotting PDP for {f!r}: {e}")
+                        return None
+                return _plot
 
-        try:
-            fig_lift_curve = explainer.plot_lift_curve()
-            plots_explainer_html += add_plot_to_html(fig_lift_curve)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot lift curve: {e}")
-
-        try:
-            fig_roc_auc = explainer.plot_roc_auc()
-            plots_explainer_html += add_plot_to_html(fig_roc_auc)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot roc auc: {e}")
-
-        try:
-            fig_pr_auc = explainer.plot_pr_auc()
-            plots_explainer_html += add_plot_to_html(fig_pr_auc)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot pr auc: {e}")
-
-        self.plots_explainer_html = plots_explainer_html
+            self.explainer_plots[f"pdp__{feat}"] = make_pdp_plotter(feat)
--- a/pycaret_predict.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/pycaret_predict.py	Fri Jul 25 19:02:32 2025 +0000
@@ -10,6 +10,7 @@
 from sklearn.metrics import average_precision_score
 from utils import encode_image_to_base64, get_html_closing, get_html_template
 
+
 LOG = logging.getLogger(__name__)
 
 
--- a/pycaret_regression.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/pycaret_regression.py	Fri Jul 25 19:02:32 2025 +0000
@@ -3,21 +3,21 @@
 from base_model_trainer import BaseModelTrainer
 from dashboard import generate_regression_explainer_dashboard
 from pycaret.regression import RegressionExperiment
-from utils import add_hr_to_html, add_plot_to_html
 
 LOG = logging.getLogger(__name__)
 
 
 class RegressionModelTrainer(BaseModelTrainer):
     def __init__(
-            self,
-            input_file,
-            target_col,
-            output_dir,
-            task_type,
-            random_seed,
-            test_file=None,
-            **kwargs):
+        self,
+        input_file,
+        target_col,
+        output_dir,
+        task_type,
+        random_seed,
+        test_file=None,
+        **kwargs,
+    ):
         super().__init__(
             input_file,
             target_col,
@@ -25,24 +25,35 @@
             task_type,
             random_seed,
             test_file,
-            **kwargs)
+            **kwargs,
+        )
+        # BaseModelTrainer initializes self.exp to None; assign the
+        # RegressionExperiment here so setup_pycaret can configure it.
         self.exp = RegressionExperiment()
 
     def save_dashboard(self):
         LOG.info("Saving explainer dashboard")
-        dashboard = generate_regression_explainer_dashboard(self.exp,
-                                                            self.best_model)
+        dashboard = generate_regression_explainer_dashboard(self.exp, self.best_model)
         dashboard.save_html("dashboard.html")
 
     def generate_plots(self):
         LOG.info("Generating and saving plots")
-        plots = ['residuals', 'error', 'cooks',
-                 'learning', 'vc', 'manifold',
-                 'rfe', 'feature', 'feature_all']
+        plots = [
+            "residuals",
+            "error",
+            "cooks",
+            "learning",
+            "vc",
+            "manifold",
+            "rfe",
+            "feature",
+            "feature_all",
+        ]
         for plot_name in plots:
             try:
-                plot_path = self.exp.plot_model(self.best_model,
-                                                plot=plot_name, save=True)
+                plot_path = self.exp.plot_model(
+                    self.best_model, plot=plot_name, save=True
+                )
                 self.plots[plot_name] = plot_path
             except Exception as e:
                 LOG.error(f"Error generating plot {plot_name}: {e}")
@@ -58,79 +69,60 @@
 
         try:
             explainer = RegressionExplainer(self.best_model, X_test, y_test)
-            self.expaliner = explainer
-            plots_explainer_html = ""
         except Exception as e:
             LOG.error(f"Error creating explainer: {e}")
-            self.plots_explainer_html = None
             return
 
+        # --- 1) SHAP mean impact (average absolute SHAP values) ---
         try:
-            fig_importance = explainer.plot_importances()
-            plots_explainer_html += add_plot_to_html(fig_importance)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot importance: {e}")
-
-        try:
-            fig_importance_permutation = \
-                explainer.plot_importances_permutation(
-                    kind="permutation")
-            plots_explainer_html += add_plot_to_html(
-                fig_importance_permutation)
-            plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_mean"] = explainer.plot_importances()
         except Exception as e:
-            LOG.error(f"Error generating plot importance permutation: {e}")
+            LOG.error(f"Error generating SHAP mean importance: {e}")
 
-        try:
-            for feature in self.features_name:
-                fig_shap = explainer.plot_pdp(feature)
-                plots_explainer_html += add_plot_to_html(fig_shap)
-                plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot shap dependence: {e}")
-
-        # try:
-        #     for feature in self.features_name:
-        #         fig_interaction = explainer.plot_interaction(col=feature)
-        #         plots_explainer_html += add_plot_to_html(fig_interaction)
-        # except Exception as e:
-        #     LOG.error(f"Error generating plot shap interaction: {e}")
-
+        # --- 2) SHAP permutation importance ---
         try:
-            for feature in self.features_name:
-                fig_interactions_importance = \
-                    explainer.plot_interactions_importance(
-                        col=feature)
-                plots_explainer_html += add_plot_to_html(
-                    fig_interactions_importance)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["shap_perm"] = explainer.plot_importances_permutation(
+                kind="permutation"
+            )
         except Exception as e:
-            LOG.error(f"Error generating plot shap summary: {e}")
+            LOG.error(f"Error generating SHAP permutation importance: {e}")
 
-        # Regression specific plots
-        try:
-            fig_pred_actual = explainer.plot_predicted_vs_actual()
-            plots_explainer_html += add_plot_to_html(fig_pred_actual)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot prediction vs actual: {e}")
+        # Pre-filter features so we never call PDP or residual-vs-feature on missing cols
+        valid_feats = []
+        for feat in self.features_name:
+            if feat in explainer.X.columns or feat in explainer.onehot_cols:
+                valid_feats.append(feat)
+            else:
+                LOG.warning(f"Skipping feature {feat!r}: not found in explainer data")
 
-        try:
-            fig_residuals = explainer.plot_residuals()
-            plots_explainer_html += add_plot_to_html(fig_residuals)
-            plots_explainer_html += add_hr_to_html()
-        except Exception as e:
-            LOG.error(f"Error generating plot residuals: {e}")
+        # --- 3) Partial Dependence Plots (PDPs) per feature ---
+        for feature in valid_feats:
+            try:
+                fig_pdp = explainer.plot_pdp(feature)
+                self.explainer_plots[f"pdp__{feature}"] = fig_pdp
+            except AssertionError as ae:
+                LOG.warning(f"PDP AssertionError for {feature!r}: {ae}")
+            except Exception as e:
+                LOG.error(f"Error generating PDP for {feature}: {e}")
 
+        # --- 4) Predicted vs Actual plot ---
         try:
-            for feature in self.features_name:
-                fig_residuals_vs_feature = \
-                    explainer.plot_residuals_vs_feature(feature)
-                plots_explainer_html += add_plot_to_html(
-                    fig_residuals_vs_feature)
-                plots_explainer_html += add_hr_to_html()
+            self.explainer_plots["predicted_vs_actual"] = explainer.plot_predicted_vs_actual()
+        except Exception as e:
+            LOG.error(f"Error generating Predicted vs Actual plot: {e}")
+
+        # --- 5) Global residuals distribution ---
+        try:
+            self.explainer_plots["residuals"] = explainer.plot_residuals()
         except Exception as e:
-            LOG.error(f"Error generating plot residuals vs feature: {e}")
+            LOG.error(f"Error generating Residuals plot: {e}")
 
-        self.plots_explainer_html = plots_explainer_html
+        # --- 6) Residuals vs each feature ---
+        for feature in valid_feats:
+            try:
+                fig_res_vs_feat = explainer.plot_residuals_vs_feature(feature)
+                self.explainer_plots[f"residuals_vs_feature__{feature}"] = fig_res_vs_feat
+            except AssertionError as ae:
+                LOG.warning(f"Residuals-vs-feature AssertionError for {feature!r}: {ae}")
+            except Exception as e:
+                LOG.error(f"Error generating Residuals vs {feature}: {e}")
--- a/pycaret_train.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/pycaret_train.py	Fri Jul 25 19:02:32 2025 +0000
@@ -12,68 +12,123 @@
     parser = argparse.ArgumentParser()
     parser.add_argument("--input_file", help="Path to the input file")
     parser.add_argument("--target_col", help="Column number of the target")
-    parser.add_argument("--output_dir",
-                        help="Path to the output directory")
-    parser.add_argument("--model_type",
-                        choices=["classification", "regression"],
-                        help="Type of the model")
-    parser.add_argument("--train_size", type=float,
-                        default=None,
-                        help="Train size for PyCaret setup")
-    parser.add_argument("--normalize", action="store_true",
-                        default=None,
-                        help="Normalize data for PyCaret setup")
-    parser.add_argument("--feature_selection", action="store_true",
-                        default=None,
-                        help="Perform feature selection for PyCaret setup")
-    parser.add_argument("--cross_validation", action="store_true",
-                        default=None,
-                        help="Perform cross-validation for PyCaret setup")
-    parser.add_argument("--no_cross_validation", action="store_true",
-                        default=None,
-                        help="Don't perform cross-validation for PyCaret setup")
-    parser.add_argument("--cross_validation_folds", type=int,
-                        default=None,
-                        help="Number of cross-validation folds \
-                          for PyCaret setup")
-    parser.add_argument("--remove_outliers", action="store_true",
-                        default=None,
-                        help="Remove outliers for PyCaret setup")
-    parser.add_argument("--remove_multicollinearity", action="store_true",
-                        default=None,
-                        help="Remove multicollinearity for PyCaret setup")
-    parser.add_argument("--polynomial_features", action="store_true",
-                        default=None,
-                        help="Generate polynomial features for PyCaret setup")
-    parser.add_argument("--feature_interaction", action="store_true",
-                        default=None,
-                        help="Generate feature interactions for PyCaret setup")
-    parser.add_argument("--feature_ratio", action="store_true",
-                        default=None,
-                        help="Generate feature ratios for PyCaret setup")
-    parser.add_argument("--fix_imbalance", action="store_true",
-                        default=None,
-                        help="Fix class imbalance for PyCaret setup")
-    parser.add_argument("--models", nargs='+',
-                        default=None,
-                        help="Selected models for training")
-    parser.add_argument("--random_seed", type=int,
-                        default=42,
-                        help="Random seed for PyCaret setup")
-    parser.add_argument("--test_file", type=str, default=None,
-                        help="Path to the test data file")
+    parser.add_argument("--output_dir", help="Path to the output directory")
+    parser.add_argument(
+        "--model_type",
+        choices=["classification", "regression"],
+        help="Type of the model",
+    )
+    parser.add_argument(
+        "--train_size",
+        type=float,
+        default=None,
+        help="Train size for PyCaret setup",
+    )
+    parser.add_argument(
+        "--normalize",
+        action="store_true",
+        default=None,
+        help="Normalize data for PyCaret setup",
+    )
+    parser.add_argument(
+        "--feature_selection",
+        action="store_true",
+        default=None,
+        help="Perform feature selection for PyCaret setup",
+    )
+    parser.add_argument(
+        "--cross_validation",
+        action="store_true",
+        default=None,
+        help="Enable cross-validation for PyCaret setup",
+    )
+    parser.add_argument(
+        "--no_cross_validation",
+        action="store_true",
+        default=None,
+        help="Disable cross-validation for PyCaret setup",
+    )
+    parser.add_argument(
+        "--cross_validation_folds",
+        type=int,
+        default=None,
+        help="Number of cross-validation folds for PyCaret setup",
+    )
+    parser.add_argument(
+        "--remove_outliers",
+        action="store_true",
+        default=None,
+        help="Remove outliers for PyCaret setup",
+    )
+    parser.add_argument(
+        "--remove_multicollinearity",
+        action="store_true",
+        default=None,
+        help="Remove multicollinearity for PyCaret setup",
+    )
+    parser.add_argument(
+        "--polynomial_features",
+        action="store_true",
+        default=None,
+        help="Generate polynomial features for PyCaret setup",
+    )
+    parser.add_argument(
+        "--feature_interaction",
+        action="store_true",
+        default=None,
+        help="Generate feature interactions for PyCaret setup",
+    )
+    parser.add_argument(
+        "--feature_ratio",
+        action="store_true",
+        default=None,
+        help="Generate feature ratios for PyCaret setup",
+    )
+    parser.add_argument(
+        "--fix_imbalance",
+        action="store_true",
+        default=None,
+        help="Fix class imbalance for PyCaret setup",
+    )
+    parser.add_argument(
+        "--models",
+        nargs="+",
+        default=None,
+        help="Selected models for training",
+    )
+    parser.add_argument(
+        "--tune_model",
+        action="store_true",
+        default=False,
+        help="Tune the best model hyperparameters after training",
+    )
+    parser.add_argument(
+        "--random_seed",
+        type=int,
+        default=42,
+        help="Random seed for PyCaret setup",
+    )
+    parser.add_argument(
+        "--test_file",
+        type=str,
+        default=None,
+        help="Path to the test data file",
+    )
 
     args = parser.parse_args()
 
-    cross_validation = True
+    # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
     if args.no_cross_validation:
-        cross_validation = False
+        args.cross_validation = False
+    # If --cross_validation was passed, args.cross_validation is True
+    # If neither was passed, args.cross_validation remains None
 
+    # Build the model_kwargs dict from CLI args
     model_kwargs = {
         "train_size": args.train_size,
         "normalize": args.normalize,
         "feature_selection": args.feature_selection,
-        "cross_validation": cross_validation,
+        "cross_validation": args.cross_validation,
         "cross_validation_folds": args.cross_validation_folds,
         "remove_outliers": args.remove_outliers,
         "remove_multicollinearity": args.remove_multicollinearity,
@@ -81,17 +136,19 @@
         "feature_interaction": args.feature_interaction,
         "feature_ratio": args.feature_ratio,
         "fix_imbalance": args.fix_imbalance,
+        "tune_model": args.tune_model,
     }
     LOG.info(f"Model kwargs: {model_kwargs}")
 
-    # Remove None values from model_kwargs
-
-    LOG.info(f"Model kwargs 2: {model_kwargs}")
+    # The Galaxy tool XML may pass models as a single comma-separated string; split it into a list
     if args.models:
         model_kwargs["models"] = args.models[0].split(",")
 
+    # Drop None entries so PyCaret uses its default values
     model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None}
+    LOG.info(f"Model kwargs 2: {model_kwargs}")
 
+    # Instantiate the appropriate trainer
     if args.model_type == "classification":
         trainer = ClassificationModelTrainer(
             args.input_file,
@@ -100,10 +157,11 @@
             args.model_type,
             args.random_seed,
             args.test_file,
-            **model_kwargs)
+            **model_kwargs,
+        )
     elif args.model_type == "regression":
-        if "fix_imbalance" in model_kwargs:
-            del model_kwargs["fix_imbalance"]
+        # regression doesn't support fix_imbalance
+        model_kwargs.pop("fix_imbalance", None)
         trainer = RegressionModelTrainer(
             args.input_file,
             args.target_col,
@@ -111,11 +169,12 @@
             args.model_type,
             args.random_seed,
             args.test_file,
-            **model_kwargs)
+            **model_kwargs,
+        )
     else:
-        LOG.error("Invalid model type. Please choose \
-                  'classification' or 'regression'.")
+        LOG.error("Invalid model type. Please choose 'classification' or 'regression'.")
         return
+
     trainer.run()
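
The flag handling above relies on two details: --no_cross_validation overrides --cross_validation, and any option left at its None default is dropped so PyCaret's own defaults apply. A trimmed, illustrative reproduction of just that behaviour:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--cross_validation", action="store_true", default=None)
    parser.add_argument("--no_cross_validation", action="store_true", default=None)

    args = parser.parse_args(["--no_cross_validation"])
    if args.no_cross_validation:
        args.cross_validation = False
    kwargs = {"cross_validation": args.cross_validation}
    print({k: v for k, v in kwargs.items() if v is not None})  # {'cross_validation': False}

    args = parser.parse_args([])  # neither flag passed
    kwargs = {"cross_validation": args.cross_validation}
    print({k: v for k, v in kwargs.items() if v is not None})  # {} -> PyCaret default is used
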
 
 
--- a/utils.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/utils.py	Fri Jul 25 19:02:32 2025 +0000
@@ -1,5 +1,6 @@
 import base64
 import logging
+from typing import Optional
 
 import numpy as np
 
@@ -7,7 +8,7 @@
 LOG = logging.getLogger(__name__)
 
 
-def get_html_template():
+def get_html_template() -> str:
     return """
     <html>
     <head>
@@ -20,13 +21,16 @@
               padding: 20px;
               background-color: #f4f4f4;
           }
+          /* allow horizontal scrolling if content overflows */
           .container {
               max-width: 800px;
               margin: auto;
               background: white;
               padding: 20px;
               box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
+              overflow-x: auto;
           }
+
           h1 {
               text-align: center;
               color: #333;
@@ -36,6 +40,14 @@
               color: #4CAF50;
               padding-bottom: 5px;
           }
+
+          /* wrapper for tables to allow individual horizontal scroll */
+          .table-wrapper {
+              overflow-x: auto;
+              margin: 1rem 0;
+          }
+
+          /* revert table styling to full borders */
           table {
               width: 100%;
               border-collapse: collapse;
@@ -52,6 +64,7 @@
               background-color: #4CAF50;
               color: white;
           }
+
           .plot {
               text-align: center;
               margin: 20px 0;
@@ -60,106 +73,202 @@
               max-width: 100%;
               height: auto;
           }
+
           .tabs {
               display: flex;
-              margin-bottom: 20px;
-              cursor: pointer;
-              justify-content: space-around;
+              align-items: center;
+              border-bottom: 2px solid #ccc;
+              margin-bottom: 1rem;
           }
           .tab {
-              padding: 10px;
-              background-color: #4CAF50;
-              color: white;
-              border-radius: 5px 5px 0 0;
-              flex-grow: 1;
-              text-align: center;
-              margin: 0 5px;
+              padding: 10px 20px;
+              cursor: pointer;
+              border: 1px solid #ccc;
+              border-bottom: none;
+              background: #f9f9f9;
+              margin-right: 5px;
+              border-top-left-radius: 8px;
+              border-top-right-radius: 8px;
           }
-          .tab.active-tab {
-              background-color: #333;
+          .tab.active {
+              background: white;
+              font-weight: bold;
           }
+
           .tab-content {
               display: none;
               padding: 20px;
-              border: 1px solid #ddd;
+              border: 1px solid #ccc;
               border-top: none;
-              background-color: white;
+              background: white;
           }
-          .tab-content.active-content {
+          .tab-content.active {
               display: block;
           }
-      </style>
+
+          .help-btn {
+              margin-left: auto;
+              padding: 6px 12px;
+              font-size: 0.9rem;
+              border: 1px solid #4CAF50;
+              border-radius: 4px;
+              background: #4CAF50;
+              color: white;
+              cursor: pointer;
+          }
+
+          /* sortable table header arrows */
+          table.sortable th {
+              position: relative;
+              padding-right: 20px; /* room for the arrow */
+              cursor: pointer;
+          }
+          table.sortable th::after {
+              content: '↕';
+              position: absolute;
+              right: 8px;
+              opacity: 0.4;
+              transition: opacity 0.2s;
+          }
+          table.sortable th:hover::after {
+              opacity: 0.7;
+          }
+          table.sortable th.sorted-asc::after {
+              content: '↑';
+              opacity: 1;
+          }
+          table.sortable th.sorted-desc::after {
+              content: '↓';
+              opacity: 1;
+          }
+        </style>
     </head>
     <body>
     <div class="container">
     """
 
 
-def get_html_closing():
+def get_html_closing() -> str:
     return """
-        </div>
-        <script>
-            function openTab(evt, tabName) {{
-                var i, tabcontent, tablinks;
-                tabcontent = document.getElementsByClassName("tab-content");
-                for (i = 0; i < tabcontent.length; i++) {{
-                    tabcontent[i].style.display = "none";
-                }}
-                tablinks = document.getElementsByClassName("tab");
-                for (i = 0; i < tablinks.length; i++) {{
-                    tablinks[i].className =
-                        tablinks[i].className.replace(" active-tab", "");
-                }}
-                document.getElementById(tabName).style.display = "block";
-                evt.currentTarget.className += " active-tab";
-            }}
-            document.addEventListener("DOMContentLoaded", function() {{
-                document.querySelector(".tab").click();
-            }});
-        </script>
+    </div>
+    <script>
+    document.addEventListener('DOMContentLoaded', () => {
+      document.querySelectorAll('table.sortable').forEach(table => {
+        const getCellValue = (row, idx) =>
+          row.children[idx].innerText.trim() || '';
+
+        const comparer = (idx, asc) => (a, b) => {
+          const v1 = getCellValue(asc ? a : b, idx);
+          const v2 = getCellValue(asc ? b : a, idx);
+          const n1 = parseFloat(v1), n2 = parseFloat(v2);
+          if (!isNaN(n1) && !isNaN(n2)) return n1 - n2;
+          return v1.localeCompare(v2);
+        };
+
+        table.querySelectorAll('th').forEach((th, idx) => {
+          let asc = true;
+          th.addEventListener('click', () => {
+            // sort rows
+            const tbody = table.tBodies[0];
+            Array.from(tbody.rows)
+              .sort(comparer(idx, asc))
+              .forEach(row => tbody.appendChild(row));
+            // update arrow classes
+            table.querySelectorAll('th').forEach(h => {
+              h.classList.remove('sorted-asc','sorted-desc');
+            });
+            th.classList.add(asc ? 'sorted-asc' : 'sorted-desc');
+            asc = !asc;
+          });
+        });
+      });
+    });
+    </script>
     </body>
     </html>
     """
 
 
-def customize_figure_layout(fig, margin_dict=None):
+def build_tabbed_html(
+    summary_html: str,
+    test_html: str,
+    feature_html: str,
+    explainer_html: Optional[str] = None,
+) -> str:
+    """
+    Render the tabbed sections and an always-visible Help button.
     """
-    Update the layout of a Plotly figure to reduce margins.
+    # CSS: reuse the <style> block from the main HTML template
+    css = "<style>" + get_html_template().split("<style>", 1)[1].rsplit("</style>", 1)[0] + "</style>"
+
+    # Tabs header
+    tabs = [
+        '<div class="tabs">',
+        '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary & Config</div>',
+        '<div class="tab" onclick="showTab(\'test\')">Test Summary</div>',
+        '<div class="tab" onclick="showTab(\'feature\')">Feature Importance</div>',
+    ]
+    if explainer_html:
+        tabs.append(
+            '<div class="tab" onclick="showTab(\'explainer\')">Explainer Plots</div>'
+        )
+    tabs.append('<button id="openMetricsHelp" class="help-btn">Help</button>')
+    tabs.append("</div>")
+    tabs_section = "\n".join(tabs)
 
-    Parameters:
-        fig (plotly.graph_objects.Figure): The Plotly figure to customize.
-        margin_dict (dict, optional): A dictionary specifying margin sizes.
-            Example: {'l': 10, 'r': 10, 't': 10, 'b': 10}
+    # Content
+    contents = [
+        f'<div id="summary" class="tab-content active">{summary_html}</div>',
+        f'<div id="test" class="tab-content">{test_html}</div>',
+        f'<div id="feature" class="tab-content">{feature_html}</div>',
+    ]
+    if explainer_html:
+        contents.append(
+            f'<div id="explainer" class="tab-content">{explainer_html}</div>'
+        )
+    content_section = "\n".join(contents)
 
-    Returns:
-        plotly.graph_objects.Figure: The updated Plotly figure.
-    """
+    # JS
+    js = """
+<script>
+function showTab(id) {
+  document.querySelectorAll('.tab-content').forEach(el=>el.classList.remove('active'));
+  document.querySelectorAll('.tab').forEach(el=>el.classList.remove('active'));
+  document.getElementById(id).classList.add('active');
+  document.querySelector(`.tab[onclick*="${id}"]`).classList.add('active');
+}
+</script>
+"""
+
+    return css + "\n" + tabs_section + "\n" + content_section + "\n" + js
+
+
+def customize_figure_layout(fig, margin_dict=None):
     if margin_dict is None:
-        # Set default smaller margins
-        margin_dict = {'l': 40, 'r': 40, 't': 40, 'b': 40}
-
+        margin_dict = {"l": 40, "r": 40, "t": 40, "b": 40}
     fig.update_layout(margin=margin_dict)
     return fig
 
 
-def add_plot_to_html(fig, include_plotlyjs=True):
-    custom_margin = {'l': 40, 'r': 40, 't': 60, 'b': 60}
+def add_plot_to_html(fig, include_plotlyjs=True) -> str:
+    custom_margin = {"l": 40, "r": 40, "t": 60, "b": 60}
     fig = customize_figure_layout(fig, margin_dict=custom_margin)
-    return fig.to_html(full_html=False,
-                       default_height=350,
-                       include_plotlyjs="cdn" if include_plotlyjs else False)
+    return fig.to_html(
+        full_html=False,
+        default_height=350,
+        include_plotlyjs="cdn" if include_plotlyjs else False,
+    )
 
 
-def add_hr_to_html():
+def add_hr_to_html() -> str:
     return "<hr>"
 
 
-def encode_image_to_base64(image_path):
-    """Convert an image file to a base64 encoded string."""
+def encode_image_to_base64(image_path: str) -> str:
     with open(image_path, "rb") as img_file:
         return base64.b64encode(img_file.read()).decode("utf-8")
 
 
 def predict_proba(self, X):
     pred = self.predict(X)
-    return np.array([1 - pred, pred]).T
+    return np.vstack((1 - pred, pred)).T
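
The predict_proba helper above is what ClassificationModelTrainer monkey-patches onto models that lack the method. A minimal sketch with a hypothetical stand-in estimator, showing the resulting (n_samples, 2) probability matrix:

    import types

    import numpy as np

    from utils import predict_proba

    class NoProbaModel:
        def predict(self, X):
            return np.array([0, 1, 1, 0])  # hard 0/1 predictions for illustration

    model = NoProbaModel()
    if not hasattr(model, "predict_proba"):
        model.predict_proba = types.MethodType(predict_proba, model)

    proba = model.predict_proba(np.zeros((4, 3)))
    print(proba.shape)  # (4, 2): column 0 = P(class 0), column 1 = P(class 1)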