diff base_model_trainer.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents f4cb41f458fd
children
--- a/base_model_trainer.py	Wed Jul 09 01:13:01 2025 +0000
+++ b/base_model_trainer.py	Fri Jul 25 19:02:32 2025 +0000
@@ -1,7 +1,7 @@
 import base64
 import logging
-import os
 import tempfile
+from pathlib import Path
 
 import h5py
 import joblib
@@ -10,7 +10,14 @@
 from feature_help_modal import get_feature_metrics_help_modal
 from feature_importance import FeatureImportanceAnalyzer
 from sklearn.metrics import average_precision_score
-from utils import get_html_closing, get_html_template
+from utils import (
+    add_hr_to_html,
+    add_plot_to_html,
+    build_tabbed_html,
+    encode_image_to_base64,
+    get_html_closing,
+    get_html_template,
+)
 
 logging.basicConfig(level=logging.DEBUG)
 LOG = logging.getLogger(__name__)
@@ -27,7 +34,7 @@
         test_file=None,
         **kwargs,
     ):
-        self.exp = None  # This will be set in the subclass
+        self.exp = None
         self.input_file = input_file
         self.target_col = target_col
         self.output_dir = output_dir
@@ -39,10 +46,11 @@
         self.results = None
         self.features_name = None
         self.plots = {}
-        self.expaliner = None
+        self.explainer_plots = {}
         self.plots_explainer_html = None
         self.trees = []
-        for key, value in kwargs.items():
+        self.user_kwargs = kwargs.copy()
+        for key, value in self.user_kwargs.items():
             setattr(self, key, value)
         self.setup_params = {}
         self.test_file = test_file
@@ -57,43 +65,38 @@
         LOG.info(f"Loading data from {self.input_file}")
         self.data = pd.read_csv(self.input_file, sep=None, engine="python")
         self.data.columns = self.data.columns.str.replace(".", "_")
-
-        # Remove prediction_label if present
         if "prediction_label" in self.data.columns:
             self.data = self.data.drop(columns=["prediction_label"])
 
         numeric_cols = self.data.select_dtypes(include=["number"]).columns
         non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
-
         self.data[numeric_cols] = self.data[numeric_cols].apply(
             pd.to_numeric, errors="coerce"
         )
-
         if len(non_numeric_cols) > 0:
             LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
 
         names = self.data.columns.to_list()
         target_index = int(self.target_col) - 1
         self.target = names[target_index]
-        self.features_name = [name for i, name in enumerate(names) if i != target_index]
-        if hasattr(self, "missing_value_strategy"):
-            if self.missing_value_strategy == "mean":
+        self.features_name = [n for i, n in enumerate(names) if i != target_index]
+
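+        # Impute missing values according to the user-selected strategy;
+        # fall back to median imputation when no strategy is given.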
+        if getattr(self, "missing_value_strategy", None):
+            strat = self.missing_value_strategy
+            if strat == "mean":
                 self.data = self.data.fillna(self.data.mean(numeric_only=True))
-            elif self.missing_value_strategy == "median":
+            elif strat == "median":
                 self.data = self.data.fillna(self.data.median(numeric_only=True))
-            elif self.missing_value_strategy == "drop":
+            elif strat == "drop":
                 self.data = self.data.dropna()
         else:
-            # Default strategy if not specified
             self.data = self.data.fillna(self.data.median(numeric_only=True))
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
-            self.test_data = pd.read_csv(self.test_file, sep=None, engine="python")
-            self.test_data = self.test_data[numeric_cols].apply(
-                pd.to_numeric, errors="coerce"
-            )
-            self.test_data.columns = self.test_data.columns.str.replace(".", "_")
+            df_test = pd.read_csv(self.test_file, sep=None, engine="python")
+            df_test.columns = df_test.columns.str.replace(".", "_")
+            self.test_data = df_test
 
     def setup_pycaret(self):
         LOG.info("Initializing PyCaret")
@@ -105,59 +108,26 @@
             "system_log": False,
             "index": False,
         }
-
         if self.test_data is not None:
             self.setup_params["test_data"] = self.test_data
-
-        if (
-            hasattr(self, "train_size")
-            and self.train_size is not None
-            and self.test_data is None
-        ):
-            self.setup_params["train_size"] = self.train_size
-
-        if hasattr(self, "normalize") and self.normalize is not None:
-            self.setup_params["normalize"] = self.normalize
-
-        if hasattr(self, "feature_selection") and self.feature_selection is not None:
-            self.setup_params["feature_selection"] = self.feature_selection
-
-        if (
-            hasattr(self, "cross_validation")
-            and self.cross_validation is not None
-            and self.cross_validation is False
-        ):
-            logging.info(
-                "cross_validation is set to False. This will disable cross-validation."
-            )
-
-        if hasattr(self, "cross_validation") and self.cross_validation:
-            if hasattr(self, "cross_validation_folds"):
-                self.setup_params["fold"] = self.cross_validation_folds
-
-        if hasattr(self, "remove_outliers") and self.remove_outliers is not None:
-            self.setup_params["remove_outliers"] = self.remove_outliers
-
-        if (
-            hasattr(self, "remove_multicollinearity")
-            and self.remove_multicollinearity is not None
-        ):
-            self.setup_params["remove_multicollinearity"] = (
-                self.remove_multicollinearity
-            )
-
-        if (
-            hasattr(self, "polynomial_features")
-            and self.polynomial_features is not None
-        ):
-            self.setup_params["polynomial_features"] = self.polynomial_features
-
-        if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None:
-            self.setup_params["fix_imbalance"] = self.fix_imbalance
-
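+        # Forward optional user settings to setup() only when they were
+        # explicitly provided; None means "use the PyCaret default".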
+        for attr in [
+            "train_size",
+            "normalize",
+            "feature_selection",
+            "remove_outliers",
+            "remove_multicollinearity",
+            "polynomial_features",
+            "feature_interaction",
+            "feature_ratio",
+            "fix_imbalance",
+        ]:
+            val = getattr(self, attr, None)
+            if val is not None:
+                self.setup_params[attr] = val
+        if getattr(self, "cross_validation_folds", None) is not None:
+            self.setup_params["fold"] = self.cross_validation_folds
         LOG.info(self.setup_params)
 
-        # Solution: instantiate the correct PyCaret experiment based on task_type
         if self.task_type == "classification":
             from pycaret.classification import ClassificationExperiment
 
@@ -170,246 +140,371 @@
             raise ValueError("task_type must be 'classification' or 'regression'")
 
         self.exp.setup(self.data, **self.setup_params)
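+        # Merge the raw user kwargs back in so the HTML report can list
+        # every option that was requested, not just the ones passed to setup().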
+        self.setup_params.update(self.user_kwargs)
 
     def train_model(self):
         LOG.info("Training and selecting the best model")
         if self.task_type == "classification":
-            average_displayed = "Weighted"
             self.exp.add_metric(
-                id=f"PR-AUC-{average_displayed}",
-                name=f"PR-AUC-{average_displayed}",
+                id="PR-AUC-Weighted",
+                name="PR-AUC-Weighted",
                 target="pred_proba",
                 score_func=average_precision_score,
                 average="weighted",
             )
+        # Build arguments for compare_models()
+        compare_kwargs = {}
+        if getattr(self, "models", None):
+            compare_kwargs["include"] = self.models
 
-        if hasattr(self, "models") and self.models is not None:
-            self.best_model = self.exp.compare_models(include=self.models, cross_validation=self.cross_validation)
-        else:
-            self.best_model = self.exp.compare_models(cross_validation=self.cross_validation)
+        # Respect explicit cross-validation flag
+        if getattr(self, "cross_validation", None) is not None:
+            compare_kwargs["cross_validation"] = self.cross_validation
+
+        # Respect explicit fold count
+        if getattr(self, "cross_validation_folds", None) is not None:
+            compare_kwargs["fold"] = self.cross_validation_folds
+
+        LOG.info(f"compare_models kwargs: {compare_kwargs}")
+        self.best_model = self.exp.compare_models(**compare_kwargs)
         self.results = self.exp.pull()
+        if getattr(self, "tune_model", False):
+            LOG.info("Tuning hyperparameters of the best model")
+            self.best_model = self.exp.tune_model(self.best_model)
+            self.results = self.exp.pull()
 
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
-
         _ = self.exp.predict_model(self.best_model)
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
             self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
     def save_model(self):
-        hdf5_model_path = "pycaret_model.h5"
-        with h5py.File(hdf5_model_path, "w") as f:
-            with tempfile.NamedTemporaryFile(delete=False) as temp_file:
-                joblib.dump(self.best_model, temp_file.name)
-                temp_file.seek(0)
-                model_bytes = temp_file.read()
+        hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
+        with h5py.File(hdf5_path, "w") as f:
+            with tempfile.NamedTemporaryFile(delete=False) as tmp:
+                joblib.dump(self.best_model, tmp.name)
+                tmp.seek(0)
+                model_bytes = tmp.read()
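+            # Store the pickled model as a single opaque byte blob inside the HDF5 file.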
             f.create_dataset("model", data=np.void(model_bytes))
 
     def generate_plots(self):
-        raise NotImplementedError("Subclasses should implement this method")
+        LOG.info("Generating PyCaret diagnostic plots")
 
-    def encode_image_to_base64(self, img_path):
+        # choose the right plots based on task
+        if self.task_type == "classification":
+            plot_names = [
+                "learning",
+                "vc",
+                "calibration",
+                "dimension",
+                "manifold",
+                "rfe",
+                "threshold",
+                "percentage_above_below",
+                "class_report",
+                "pr_auc",
+                "roc_auc",
+            ]
+        else:
+            plot_names = ["residuals", "vc", "parameter", "error", "learning"]
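+        # Render each diagnostic with PyCaret, then save the underlying
+        # matplotlib figure as a PNG in the output directory.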
+        for name in plot_names:
+            try:
+                ax = self.exp.plot_model(self.best_model, plot=name, save=False)
+                out_path = Path(self.output_dir) / f"plot_{name}.png"
+                fig = ax.get_figure()
+                fig.savefig(out_path, bbox_inches="tight")
+                self.plots[name] = str(out_path)
+            except Exception as e:
+                LOG.warning(f"Could not generate {name} plot: {e}")
+
+    def encode_image_to_base64(self, img_path: str) -> str:
         with open(img_path, "rb") as img_file:
             return base64.b64encode(img_file.read()).decode("utf-8")
 
     def save_html_report(self):
         LOG.info("Saving HTML report")
 
-        if not self.output_dir:
-            raise ValueError("output_dir must be specified and not None")
+        # 1) Determine best model name
+        try:
+            best_model_name = str(self.results.iloc[0]["Model"])
+        except Exception:
+            best_model_name = type(self.best_model).__name__
+        LOG.info(f"Best model determined as: {best_model_name}")
+
+        # 2) Compute training sample count
+        try:
+            n_train = self.exp.X_train.shape[0]
+        except Exception:
+            n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
+        total_rows = self.data.shape[0]
 
-        model_name = type(self.best_model).__name__
-        excluded_params = ["html", "log_experiment", "system_log", "test_data"]
-        filtered_setup_params = {
-            k: v for k, v in self.setup_params.items() if k not in excluded_params
+        # 3) Build setup parameters table
+        all_params = self.setup_params
+        display_keys = [
+            "Target",
+            "Session ID",
+            "Train Size",
+            "Normalize",
+            "Feature Selection",
+            "Cross Validation",
+            "Cross Validation Folds",
+            "Remove Outliers",
+            "Remove Multicollinearity",
+            "Polynomial Features",
+            "Fix Imbalance",
+            "Models",
+        ]
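+        # Each display key is looked up in setup_params by its snake_case
+        # equivalent (e.g. "Train Size" -> "train_size").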
+        setup_rows = []
+        for key in display_keys:
+            pk = key.lower().replace(" ", "_")
+            v = all_params.get(pk)
+            if key == "Train Size":
+                frac = (
+                    float(v)
+                    if v is not None
+                    else (n_train / total_rows if total_rows else 0)
+                )
+                dv = f"{frac:.2f} ({n_train} rows)"
+            elif key in {
+                "Normalize",
+                "Feature Selection",
+                "Cross Validation",
+                "Remove Outliers",
+                "Remove Multicollinearity",
+                "Polynomial Features",
+                "Fix Imbalance",
+            }:
+                dv = bool(v)
+            elif key == "Cross Validation Folds":
+                dv = v if v is not None else "None"
+            elif key == "Models":
+                dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+            else:
+                dv = v if v is not None else "None"
+            setup_rows.append([key, dv])
+        if hasattr(self.exp, "_fold_metric"):
+            setup_rows.append(["best_model_metric", self.exp._fold_metric])
+
+        df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
+        df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False)
+
+        # 4) Persist CSVs
+        self.results.to_csv(
+            Path(self.output_dir) / "comparison_results.csv", index=False
+        )
+        self.test_result_df.to_csv(
+            Path(self.output_dir) / "test_results.csv", index=False
+        )
+        pd.DataFrame(
+            self.best_model.get_params().items(), columns=["Parameter", "Value"]
+        ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)
+
+        # 5) Header
+        header = f"<h2>Best Model: {best_model_name}</h2>"
+
+        # — Validation Summary & Configuration —
+        val_df = self.results.copy()
+        # mapping raw plot keys to user-friendly titles
+        plot_title_map = {
+            "learning": "Learning Curve",
+            "vc": "Validation Curve",
+            "calibration": "Calibration Curve",
+            "dimension": "Dimensionality Reduction",
+            "manifold": "Manifold Learning",
+            "rfe": "Recursive Feature Elimination",
+            "threshold": "Threshold Plot",
+            "percentage_above_below": "Percentage Above vs. Below Cutoff",
+            "class_report": "Classification Report",
+            "pr_auc": "Precision-Recall AUC",
+            "roc_auc": "Receiver Operating Characteristic AUC",
+            "residuals": "Residuals Distribution",
+            "error": "Prediction Error Distribution",
         }
-        setup_params_table = pd.DataFrame(
-            list(filtered_setup_params.items()), columns=["Parameter", "Value"]
+        val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True)
+        summary_html = (
+            header
+            + "<h2>Train & Validation Summary</h2>"
+            + '<div class="table-wrapper">'
+            + val_df.to_html(index=False, classes="table sortable")
+            + "</div>"
+            + "<h2>Setup Parameters</h2>"
+            + '<div class="table-wrapper">'
+            + df_setup.to_html(index=False, classes="table sortable")
+            + "</div>"
+            # — Hyperparameters
+            + "<h2>Best Model Hyperparameters</h2>"
+            + '<div class="table-wrapper">'
+            + pd.DataFrame(
+                self.best_model.get_params().items(), columns=["Parameter", "Value"]
+            ).to_html(index=False, classes="table sortable")
+            + "</div>"
         )
 
-        best_model_params = pd.DataFrame(
-            self.best_model.get_params().items(), columns=["Parameter", "Value"]
-        )
-        best_model_params.to_csv(
-            os.path.join(self.output_dir, "best_model.csv"), index=False
-        )
-        self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
-        self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv"))
+        # choose summary plots based on task type
+        if self.task_type == "classification":
+            summary_plots = [
+                "learning",
+                "vc",
+                "calibration",
+                "dimension",
+                "manifold",
+                "rfe",
+                "threshold",
+                "percentage_above_below",
+            ]
+        else:
+            summary_plots = ["learning", "vc", "parameter", "residuals"]
 
-        plots_html = ""
-        length = len(self.plots)
-        for i, (plot_name, plot_path) in enumerate(self.plots.items()):
-            encoded_image = self.encode_image_to_base64(plot_path)
-            plots_html += (
-                f'<div class="plot">'
-                f"<h3>{plot_name.capitalize()}</h3>"
-                f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">'
-                f"</div>"
-            )
-            if i < length - 1:
-                plots_html += "<hr>"
-
-        tree_plots = ""
-        for i, tree in enumerate(self.trees):
-            if tree:
-                tree_plots += (
-                    f'<div class="plot">'
-                    f"<h3>Tree {i + 1}</h3>"
-                    f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">'
-                    f"</div>"
+        for name in summary_plots:
+            if name in self.plots:
+                summary_html += "<hr>"
+                b64 = encode_image_to_base64(self.plots[name])
+                title = plot_title_map.get(name, name.replace("_", " ").title())
+                summary_html += (
+                    '<div class="plot">'
+                    f"<h2>{title}</h2>"
+                    f'<img src="data:image/png;base64,{b64}" '
+                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>'
+                    "</div>"
                 )
 
-        analyzer = FeatureImportanceAnalyzer(
+        # — Test Summary —
+        test_html = (
+            header
+            + '<div class="table-wrapper">'
+            + self.test_result_df.to_html(index=False, classes="table sortable")
+            + "</div>"
+        )
+        if self.task_type == "regression":
+            try:
+                y_true = (
+                    pd.Series(self.exp.y_test_transformed)
+                    .reset_index(drop=True)
+                    .rename("True")
+                )
+                y_pred = pd.Series(
+                    self.best_model.predict(self.exp.X_test_transformed)
+                ).rename("Predicted")
+                df_tp = pd.concat([y_true, y_pred], axis=1)
+                test_html += "<h2>True vs Predicted Values</h2>"
+                test_html += (
+                    '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">'
+                    + df_tp.head(50).to_html(index=False, classes="table sortable")
+                    + "</div>"
+                    + add_hr_to_html()
+                )
+            except Exception as e:
+                LOG.warning(f"Could not generate True vs Predicted table: {e}")
+
+        # 5a) Explainer-substituted plots in order
+        if self.task_type == "regression":
+            test_order = ["residuals"]
+        else:
+            test_order = [
+                "confusion_matrix",
+                "roc_auc",
+                "pr_auc",
+                "lift_curve",
+                "threshold",
+                "cumulative_precision",
+            ]
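+        # Each stored entry may be a ready-made figure or a zero-argument
+        # callable that builds the figure lazily.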
+        for key in test_order:
+            fig_or_fn = self.explainer_plots.pop(key, None)
+            if fig_or_fn is not None:
+                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                title = plot_title_map.get(key, key.replace("_", " ").title())
+                test_html += (
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                )
+        # 5b) Remaining PyCaret test plots
+        for name, path in self.plots.items():
+            # classification: include only the small extras, before skipping anything
+            if self.task_type == "classification" and name in {
+                "threshold",
+                "pr_auc",
+                "class_report",
+            }:
+                title = plot_title_map.get(name, name.replace("_", " ").title())
+                b64 = encode_image_to_base64(path)
+                test_html += (
+                    f"<h2>{title}</h2>"
+                    "<div class='plot'>"
+                    f"<img src='data:image/png;base64,{b64}' "
+                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "</div>" + add_hr_to_html()
+                )
+                continue
+
+            # regression: explicitly include the 'error' plot, before skipping
+            if self.task_type == "regression" and name == "error":
+                title = plot_title_map.get("error", "Prediction Error Distribution")
+                b64 = encode_image_to_base64(path)
+                test_html += (
+                    f"<h2>{title}</h2>"
+                    "<div class='plot'>"
+                    f"<img src='data:image/png;base64,{b64}' "
+                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "</div>" + add_hr_to_html()
+                )
+                continue
+
+            # now skip any plots already rendered via test_order
+            if name in test_order:
+                continue
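+            # Plots matching none of the cases above are not repeated in the Test tab.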
+
+        # — Feature Importance —
+        feature_html = header
+
+        # 6a) PyCaret’s default feature importances
+        feature_html += FeatureImportanceAnalyzer(
             data=self.data,
             target_col=self.target_col,
             task_type=self.task_type,
             output_dir=self.output_dir,
             exp=self.exp,
             best_model=self.best_model,
-        )
-        feature_importance_html = analyzer.run()
+        ).run()
 
-        # --- Feature Metrics Help Button ---
-        feature_metrics_button_html = (
-            '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">'
-            "Help: Metrics Guide"
-            "</button>"
-            "<style>"
-            ".help-modal-btn {"
-            "background-color: #17623b;"
-            "color: #fff;"
-            "border: none;"
-            "border-radius: 24px;"
-            "padding: 10px 28px;"
-            "font-size: 1.1rem;"
-            "font-weight: bold;"
-            "letter-spacing: 0.03em;"
-            "cursor: pointer;"
-            "transition: background 0.2s, box-shadow 0.2s;"
-            "box-shadow: 0 2px 8px rgba(23,98,59,0.07);"
-            "}"
-            ".help-modal-btn:hover, .help-modal-btn:focus {"
-            "background-color: #21895e;"
-            "outline: none;"
-            "box-shadow: 0 4px 16px rgba(23,98,59,0.14);"
-            "}"
-            "</style>"
-        )
+        # 6b) Explainer SHAP importances
+        for key in ["shap_mean", "shap_perm"]:
+            fig_or_fn = self.explainer_plots.pop(key, None)
+            if fig_or_fn is not None:
+                fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+                # give SHAP plots explicit titles
+                title = (
+                    "Mean Absolute SHAP Value Impact"
+                    if key == "shap_mean"
+                    else "Permutation Feature Importance"
+                )
+                feature_html += (
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                )
 
-        html_content = (
-            f"{get_html_template()}"
-            "<h1>Tabular Learner Model Report</h1>"
-            f"{feature_metrics_button_html}"
-            '<div class="tabs">'
-            '<div class="tab" onclick="openTab(event, \'summary\')">'
-            "Validation Result Summary & Config</div>"
-            '<div class="tab" onclick="openTab(event, \'plots\')">'
-            "Test Results</div>"
-            '<div class="tab" onclick="openTab(event, \'feature\')">'
-            "Feature Importance</div>"
-        )
-        if self.plots_explainer_html:
-            html_content += (
-                '<div class="tab" onclick="openTab(event, \'explainer\')">'
-                "Explainer Plots</div>"
+        # 6c) PDPs last
+        pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__"))
+        for k in pdp_keys:
+            fig_or_fn = self.explainer_plots[k]
+            fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
+            # extract feature name
+            feature = k.split("__", 1)[1]
+            title = f"Partial Dependence for {feature}"
+            feature_html += (
+                f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
             )
-        html_content += (
-            "</div>"
-            '<div id="summary" class="tab-content">'
-            f"<h2>Model Metrics from {'Cross-Validation Set' if self.cross_validation else 'Validation set'}</h2>"
-            f"<h2>Best Model: {model_name}</h2>"
-            "<h5>The best model is selected by: Accuracy (Classification)"
-            " or R2 (Regression).</h5>"
-            f"{self.results.to_html(index=False, classes='table sortable')}"
-            "<h2>Best Model's Hyperparameters</h2>"
-            f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}"
-            "<h2>Setup Parameters</h2>"
-            f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}"
-            "<h5>If you want to know all the experiment setup parameters,"
-            " please check the PyCaret documentation for"
-            " the classification/regression <code>exp</code> function.</h5>"
-            "</div>"
-            '<div id="plots" class="tab-content">'
-            f"<h2>Best Model: {model_name}</h2>"
-            "<h5>The best model is selected by: Accuracy (Classification)"
-            " or R2 (Regression).</h5>"
-            "<h2>Test Metrics</h2>"
-            f"{self.test_result_df.to_html(index=False)}"
-            "<h2>Test Results</h2>"
-            f"{plots_html}"
-            "</div>"
-            '<div id="feature" class="tab-content">'
-            f"{feature_importance_html}"
-            "</div>"
+        # 7) Assemble final HTML (three tabs)
+        html = get_html_template()
+        html += "<h1>Tabular Learner Model Report</h1>"
+        html += build_tabbed_html(summary_html, test_html, feature_html)
+        html += get_feature_metrics_help_modal()
+        html += get_html_closing()
+
+        # 8) Write out
+        (Path(self.output_dir) / "comparison_result.html").write_text(
+            html, encoding="utf-8"
         )
-        if self.plots_explainer_html:
-            html_content += (
-                '<div id="explainer" class="tab-content">'
-                f"{self.plots_explainer_html}"
-                f"{tree_plots}"
-                "</div>"
-            )
-        html_content += (
-            "<script>"
-            "document.addEventListener(\"DOMContentLoaded\", function() {"
-            "var tables = document.querySelectorAll(\"table.sortable\");"
-            "tables.forEach(function(table) {"
-            "var headers = table.querySelectorAll(\"th\");"
-            "headers.forEach(function(header, index) {"
-            "header.style.cursor = \"pointer\";"
-            "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)"
-            "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';"
-            "header.addEventListener(\"click\", function() {"
-            "var direction = this.getAttribute("
-            "\"data-sort-direction\""
-            ") || \"asc\";"
-            "// Reset arrows in all headers of this table"
-            "headers.forEach(function(h) {"
-            "var arrow = h.querySelector(\".sort-arrow\");"
-            "if (arrow) arrow.textContent = \" ↑\";"
-            "});"
-            "// Set arrow for clicked header"
-            "var arrow = this.querySelector(\".sort-arrow\");"
-            "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";"
-            "sortTable(table, index, direction);"
-            "this.setAttribute(\"data-sort-direction\","
-            "direction === \"asc\" ? \"desc\" : \"asc\");"
-            "});"
-            "});"
-            "});"
-            "});"
-            "function sortTable(table, colNum, direction) {"
-            "var tb = table.tBodies[0];"
-            "var tr = Array.prototype.slice.call(tb.rows, 0);"
-            "var multiplier = direction === \"asc\" ? 1 : -1;"
-            "tr = tr.sort(function(a, b) {"
-            "var aText = a.cells[colNum].textContent.trim();"
-            "var bText = b.cells[colNum].textContent.trim();"
-            "// Remove arrow from text comparison"
-            "aText = aText.replace(/[↑↓]/g, '').trim();"
-            "bText = bText.replace(/[↑↓]/g, '').trim();"
-            "if (!isNaN(aText) && !isNaN(bText)) {"
-            "return multiplier * ("
-            "parseFloat(aText) - parseFloat(bText)"
-            ");"
-            "} else {"
-            "return multiplier * aText.localeCompare(bText);"
-            "}"
-            "});"
-            "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);"
-            "}"
-            "</script>"
-        )
-        # --- Add the Feature Metrics Help Modal ---
-        html_content += get_feature_metrics_help_modal()
-        html_content += f"{get_html_closing()}"
-        with open(
-            os.path.join(self.output_dir, "comparison_result.html"),
-            "w",
-            encoding="utf-8",
-        ) as file:
-            file.write(html_content)
+        LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html")
 
     def save_dashboard(self):
         raise NotImplementedError("Subclasses should implement this method")
@@ -426,29 +521,18 @@
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
-        is_rf = isinstance(
-            self.best_model, (RandomForestClassifier, RandomForestRegressor)
-        )
-        is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))
-
-        num_trees = None
-        if is_rf:
-            num_trees = self.best_model.n_estimators
-        elif is_xgb:
-            num_trees = len(self.best_model.get_booster().get_dump())
+        if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)):
+            n_trees = self.best_model.n_estimators
+        elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
+            n_trees = len(self.best_model.get_booster().get_dump())
         else:
             LOG.warning("Tree plots not supported for this model type.")
             return
 
-        try:
-            explainer = RandomForestExplainer(self.best_model, X_test, y_test)
-            for i in range(num_trees):
-                fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
-                LOG.info(f"Tree {i + 1}")
-                LOG.info(fig)
-                self.trees.append(fig)
-        except Exception as e:
-            LOG.error(f"Error generating tree plots: {e}")
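+        # Collect an encoded rendering of each individual tree in the ensemble.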
+        explainer = RandomForestExplainer(self.best_model, X_test, y_test)
+        for i in range(n_trees):
+            fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
+            self.trees.append(fig)
 
     def run(self):
         self.load_data()