view base_model_trainer.py @ 16:4fee4504646e draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 2b826699ef9518d4610f5cfb6468ce719ec8039d
author goeckslab
date Fri, 28 Nov 2025 22:28:26 +0000

import base64
import logging
import tempfile
from pathlib import Path

import h5py
import joblib
import numpy as np
import pandas as pd
from feature_help_modal import get_feature_metrics_help_modal
from feature_importance import FeatureImportanceAnalyzer
from sklearn.metrics import average_precision_score
from utils import (
    add_hr_to_html,
    add_plot_to_html,
    build_tabbed_html,
    encode_image_to_base64,
    get_html_closing,
    get_html_template,
)

logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)


class BaseModelTrainer:
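    """
    Shared training pipeline for PyCaret-based tabular models.

    This base class handles data loading, PyCaret setup, model comparison
    and optional tuning, model serialization to HDF5, diagnostic plots, and
    the final tabbed HTML report. Subclasses supply the explainer/dashboard
    pieces (``generate_plots_explainer`` and ``save_dashboard``).
    """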
    def __init__(
        self,
        input_file,
        target_col,
        output_dir,
        task_type,
        random_seed,
        test_file=None,
        **kwargs,
    ):
        self.exp = None
        self.input_file = input_file
        self.target_col = target_col
        self.output_dir = output_dir
        self.task_type = task_type
        self.random_seed = random_seed
        self.data = None
        self.target = None
        self.best_model = None
        self.results = None
        self.tuning_results = None
        self.features_name = None
        self.plot_feature_names = None
        self.plots = {}
        self.explainer_plots = {}
        self.plots_explainer_html = None
        self.trees = []
        self.user_kwargs = kwargs.copy()
        for key, value in self.user_kwargs.items():
            setattr(self, key, value)
        if not hasattr(self, "plot_feature_limit"):
            self.plot_feature_limit = 30
        self._shap_row_cap = None
        if getattr(self, "polynomial_features", False):
            # Keep feature importance responsive by trimming plots/SHAP rows
            try:
                limit_val = int(self.plot_feature_limit)
            except (TypeError, ValueError):
                limit_val = 30
            self.plot_feature_limit = min(limit_val, 15)
            self._shap_row_cap = 200
            LOG.info(
                "Polynomial features enabled; limiting feature plots to %s and SHAP rows to %s",
                self.plot_feature_limit,
                self._shap_row_cap,
            )
        self.imputed_training_data = None
        self._best_model_metric_used = None
        self.setup_params = {}
        self.test_file = test_file
        self.test_data = None

        if not self.output_dir:
            raise ValueError(
                "output_dir must be specified and not None"
            )

        # Warn about irrelevant kwargs for the task type
        if self.task_type == "regression" and (
            "probability_threshold" in self.user_kwargs
        ):
            LOG.warning(
                "probability_threshold is ignored for regression tasks."
            )

        LOG.info(f"Model kwargs: {self.__dict__}")

    def load_data(self):
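        """
        Read the input CSV/TSV, resolve the 1-based target column index,
        drop a stale 'prediction_label' column if present, and load the
        optional test file.
        """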
        LOG.info(f"Loading data from {self.input_file}")
        self.data = pd.read_csv(
            self.input_file, sep=None, engine="python"
        )
        self.data.columns = self.data.columns.str.replace(
            ".", "_", regex=False
        )

        names = self.data.columns.to_list()
        LOG.info(f"Original dataset columns: {names}")

        target_index = int(self.target_col) - 1
        num_cols = len(names)
        if target_index < 0 or target_index >= num_cols:
            raise ValueError(
                f"Target column number {self.target_col} is invalid. "
                f"Please select a number between 1 and {num_cols}."
            )

        self.target = names[target_index]

        # Conditional drop: only if 'prediction_label' exists and is not
        # the target
        if "prediction_label" in self.data.columns and (
            self.data.columns[target_index] != "prediction_label"
        ):
            LOG.info(
                "Dropping 'prediction_label' column as it's not the target."
            )
            self.data = self.data.drop(columns=["prediction_label"])
        else:
            if self.target == "prediction_label":
                LOG.warning(
                    "Using 'prediction_label' as target column. "
                    "This may not be intended if it's a previous prediction."
                )

        numeric_cols = self.data.select_dtypes(
            include=["number"]
        ).columns
        non_numeric_cols = self.data.select_dtypes(
            exclude=["number"]
        ).columns
        self.data[numeric_cols] = self.data[numeric_cols].apply(
            pd.to_numeric, errors="coerce"
        )
        if len(non_numeric_cols) > 0:
            LOG.info(
                f"Non-numeric columns found: {non_numeric_cols.tolist()}"
            )

        # Update names after possible drop
        names = self.data.columns.to_list()
        LOG.info(f"Dataset columns after processing: {names}")

        self.features_name = [n for n in names if n != self.target]
        self.plot_feature_names = self._select_plot_features(self.features_name)

        if self.test_file:
            LOG.info(f"Loading test data from {self.test_file}")
            df_test = pd.read_csv(
                self.test_file, sep=None, engine="python"
            )
            df_test.columns = df_test.columns.str.replace(
                ".", "_", regex=False
            )
            self.test_data = df_test

    def _select_plot_features(self, all_features):
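        """
        Return at most ``plot_feature_limit`` features for feature-level
        plots, ranking numeric columns by variance and filling any
        remaining slots in the original column order.
        """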
        limit = getattr(self, "plot_feature_limit", 30)
        if not isinstance(limit, int) or limit <= 0:
            LOG.info(
                "Feature plotting limit disabled (plot_feature_limit=%s).", limit
            )
            return all_features
        if len(all_features) <= limit:
            LOG.info(
                "Feature plotting limit not needed (%s features <= limit %s).",
                len(all_features),
                limit,
            )
            return all_features
        df = self.data[all_features].copy()
        numeric_cols = df.select_dtypes(include=["number"]).columns
        ranked = []
        if len(numeric_cols) > 0:
            variances = (
                df[numeric_cols]
                .var()
                .fillna(0)
                .abs()
                .sort_values(ascending=False)
            )
            ranked = variances.index.tolist()
        selected = []
        for col in ranked:
            if len(selected) >= limit:
                break
            selected.append(col)
        if len(selected) < limit:
            for col in all_features:
                if col in selected:
                    continue
                selected.append(col)
                if len(selected) >= limit:
                    break
        LOG.info(
            "Limiting feature-level plots to %s of %s available features (limit=%s).",
            len(selected),
            len(all_features),
            limit,
        )
        return selected

    def setup_pycaret(self):
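        """
        Build the PyCaret setup parameters from user kwargs and initialize
        a ClassificationExperiment or RegressionExperiment accordingly.
        """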
        LOG.info("Initializing PyCaret")
        self.setup_params = {
            "target": self.target,
            "session_id": self.random_seed,
            "html": True,
            "log_experiment": False,
            "system_log": False,
            "index": False,
        }
        if self.test_data is not None:
            self.setup_params["test_data"] = self.test_data
        for attr in [
            "train_size",
            "normalize",
            "feature_selection",
            "remove_outliers",
            "remove_multicollinearity",
            "polynomial_features",
            "feature_interaction",
            "feature_ratio",
            "fix_imbalance",
            "n_jobs",
        ]:
            val = getattr(self, attr, None)
            if val is not None:
                self.setup_params[attr] = val
        if getattr(self, "cross_validation_folds", None) is not None:
            self.setup_params["fold"] = self.cross_validation_folds
        LOG.info(self.setup_params)

        if self.task_type == "classification":
            from pycaret.classification import ClassificationExperiment

            self.exp = ClassificationExperiment()
        elif self.task_type == "regression":
            from pycaret.regression import RegressionExperiment

            self.exp = RegressionExperiment()
        else:
            raise ValueError(
                "task_type must be 'classification' or 'regression'"
            )

        self.exp.setup(self.data, **self.setup_params)
        self._capture_imputed_training_data()
        self.setup_params.update(self.user_kwargs)

    def _capture_imputed_training_data(self):
        """
        Cache the dataset as transformed/imputed by PyCaret so downstream
        components (e.g., feature importance) can operate on the exact data
        used for training.
        """
        if self.exp is None:
            return
        try:
            X_processed = self.exp.get_config("X_transformed").copy()
            y_processed = self.exp.get_config("y")
            if isinstance(y_processed, pd.Series):
                y_series = y_processed.reset_index(drop=True)
            else:
                y_series = pd.Series(y_processed)
            y_series.name = self.target
            X_processed = X_processed.reset_index(drop=True)
            self.imputed_training_data = pd.concat(
                [X_processed, y_series], axis=1
            )
            LOG.info(
                "Captured imputed training dataset from PyCaret "
                "(%s rows, %s features).",
                self.imputed_training_data.shape[0],
                self.imputed_training_data.shape[1] - 1,
            )
        except Exception as exc:
            LOG.warning(
                "Unable to capture processed training data from PyCaret: %s",
                exc,
            )
            self.imputed_training_data = None

    def train_model(self):
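        """
        Compare candidate models with PyCaret, optionally tune the best
        one, and collect cross-validation and hold-out result tables.
        """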
        LOG.info("Training and selecting the best model")
        if self.task_type == "classification":
            self.exp.add_metric(
                id="PR-AUC-Weighted",
                name="PR-AUC-Weighted",
                target="pred_proba",
                score_func=average_precision_score,
                average="weighted",
            )
        # Build arguments for compare_models()
        compare_kwargs = {}
        if getattr(self, "models", None):
            compare_kwargs["include"] = self.models

        # Respect explicit cross-validation flag
        if getattr(self, "cross_validation", None) is not None:
            compare_kwargs["cross_validation"] = self.cross_validation

        # Respect explicit fold count
        if getattr(self, "cross_validation_folds", None) is not None:
            compare_kwargs["fold"] = self.cross_validation_folds

        best_metric = getattr(self, "best_model_metric", None)
        if best_metric:
            compare_kwargs["sort"] = best_metric
            self._best_model_metric_used = best_metric
            LOG.info(f"Ranking models using metric: {best_metric}")

        LOG.info(f"compare_models kwargs: {compare_kwargs}")
        self.best_model = self.exp.compare_models(**compare_kwargs)
        if self._best_model_metric_used is None:
            self._best_model_metric_used = getattr(self.exp, "_fold_metric", None)
        self.results = self.exp.pull()
        if getattr(self, "tune_model", False):
            LOG.info("Tuning hyperparameters of the best model")
            self.best_model = self.exp.tune_model(self.best_model)
            self.tuning_results = self.exp.pull()

        if self.task_type == "classification":
            self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)

        prob_thresh = getattr(self, "probability_threshold", None)
        if self.task_type == "classification" and (
            prob_thresh is not None
        ):
            _ = self.exp.predict_model(
                self.best_model, probability_threshold=prob_thresh
            )
        else:
            _ = self.exp.predict_model(self.best_model)

        self.test_result_df = self.exp.pull()
        if self.task_type == "classification":
            self.test_result_df.rename(
                columns={"AUC": "ROC-AUC"}, inplace=True
            )

    def save_model(self):
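        """
        Serialize the best model with joblib and embed the raw bytes in an
        HDF5 file at ``<output_dir>/pycaret_model.h5``.
        """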
        hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
        with h5py.File(hdf5_path, "w") as f:
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                joblib.dump(self.best_model, tmp.name)
                tmp.seek(0)
                model_bytes = tmp.read()
            # Remove the temporary pickle now that its bytes are in memory.
            Path(tmp.name).unlink(missing_ok=True)
            f.create_dataset("model", data=np.void(model_bytes))

    def generate_plots(self):
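        """
        Render PyCaret diagnostic plots for the task type and save each
        one as a PNG in the output directory.
        """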
        LOG.info("Generating PyCaret diagnostic pltos")

        # choose the right plots based on task type
        if self.task_type == "classification":
            plot_names = [
                "learning",
                "vc",
                "calibration",
                "dimension",
                "manifold",
                "rfe",
                "threshold",
                "percentage_above_below",
                "class_report",
                "pr_auc",
                "roc_auc",
            ]
        else:
            plot_names = ["residuals", "vc", "parameter", "error",
                          "learning"]
        for name in plot_names:
            try:
                ax = self.exp.plot_model(
                    self.best_model, plot=name, save=False
                )
                out_path = Path(self.output_dir) / f"plot_{name}.png"
                fig = ax.get_figure()
                fig.savefig(out_path, bbox_inches="tight")
                self.plots[name] = str(out_path)
            except Exception as e:
                LOG.warning(f"Could not generate {name} plot: {e}")

    def encode_image_to_base64(self, img_path: str) -> str:
        with open(img_path, "rb") as img_file:
            return base64.b64encode(img_file.read()).decode("utf-8")

    def _resolve_plot_callable(self, key, fig_or_fn, section):
        """
        Safely execute stored plot callables so a single failure does not
        abort the entire HTML report generation.
        """
        if fig_or_fn is None:
            return None
        try:
            return fig_or_fn() if callable(fig_or_fn) else fig_or_fn
        except Exception as exc:
            extra = ""
            if isinstance(exc, ValueError) and "Input contains NaN" in str(exc):
                extra = (
                    " (model returned NaN probabilities; "
                    "consider checking data preprocessing)"
                )
            LOG.warning(
                "Skipping %s plot '%s' due to error: %s%s",
                section,
                key,
                exc,
                extra,
            )
            return None

    def save_html_report(self):
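        """
        Assemble the tabbed HTML report (summary, test results, feature
        importance) and write it, plus supporting CSVs, to the output
        directory.
        """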
        LOG.info("Saving HTML report")

        # 1) Determine best model name
        try:
            best_model_name = str(self.results.iloc[0]["Model"])
        except Exception:
            best_model_name = type(self.best_model).__name__
        LOG.info(f"Best model determined as: {best_model_name}")

        # 2) Compute training sample count
        try:
            n_train = self.exp.X_train.shape[0]
        except Exception:
            n_train = getattr(
                self.exp, "X_train_transformed", pd.DataFrame()
            ).shape[0]
        total_rows = self.data.shape[0]

        # 3) Build setup parameters table
        all_params = self.setup_params.copy()
        if self.task_type == "classification" and (
            hasattr(self, "probability_threshold")
        ):
            all_params["probability_threshold"] = (
                self.probability_threshold
            )
        display_keys = [
            "Target",
            "Session ID",
            "Train Size",
            "Normalize",
            "Feature Selection",
            "Cross Validation",
            "Cross Validation Folds",
            "Remove Outliers",
            "Remove Multicollinearity",
            "Polynomial Features",
            "Fix Imbalance",
            "Models",
            "Probability Threshold",
        ]
        setup_rows = []
        for key in display_keys:
            pk = key.lower().replace(" ", "_")
            v = all_params.get(pk)
            if key == "Train Size":
                frac = (
                    float(v)
                    if v is not None
                    else (n_train / total_rows if total_rows else 0)
                )
                dv = f"{frac:.2f} ({n_train} rows)"
            elif key in {
                "Normalize",
                "Feature Selection",
                "Cross Validation",
                "Remove Outliers",
                "Remove Multicollinearity",
                "Polynomial Features",
                "Fix Imbalance",
            }:
                dv = bool(v)
            elif key == "Cross Validation Folds":
                dv = v if v is not None else "None"
            elif key == "Models":
                dv = ", ".join(map(str, v)) if isinstance(
                    v, (list, tuple)
                ) else "None"
            elif key == "Probability Threshold":
                dv = f"{v:.2f}" if v is not None else "0.5"
            else:
                dv = v if v is not None else "None"
            setup_rows.append([key, dv])
        metric_label = self._best_model_metric_used or getattr(
            self.exp, "_fold_metric", None
        )
        if metric_label:
            setup_rows.append(["Best Model Metric", metric_label])

        df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
        df_setup.to_csv(
            Path(self.output_dir) / "setup_params.csv", index=False
        )

        # 4) Persist CSVs
        self.results.to_csv(
            Path(self.output_dir) / "comparison_results.csv",
            index=False
        )
        self.test_result_df.to_csv(
            Path(self.output_dir) / "test_results.csv", index=False
        )
        pd.DataFrame(
            self.best_model.get_params().items(),
            columns=["Parameter", "Value"]
        ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)

        if self.tuning_results is not None:
            self.tuning_results.to_csv(
                Path(self.output_dir) / "tuning_results.csv",
                index=False
            )

        # 5) Header
        header = f"<h2>Best Model: {best_model_name}</h2>"

        # — Validation Summary & Configuration —
        val_df = self.results.copy()
        # mapping raw plot keys to user-friendly titles
        plot_title_map = {
            "learning": "Learning Curve",
            "vc": "Validation Curve",
            "calibration": "Calibration Curve",
            "dimension": "Dimensionality Reduction",
            "manifold": "Manifold Learning",
            "rfe": "Recursive Feature Elimination",
            "threshold": "Threshold Plot",
            "percentage_above_below": "Percentage Above vs. Below Cutoff",
            "class_report": "Classification Report",
            "pr_auc": "Precision-Recall AUC",
            "roc_auc": "Receiver Operating Characteristic AUC",
            "residuals": "Residuals Distribution",
            "error": "Prediction Error Distribution",
        }
        val_df.drop(
            columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True
        )
        summary_html = (
            header
            + "<h2>Train & Validation Summary</h2>"
            + '<div class="table-wrapper">'
            + val_df.to_html(index=False, classes="table sortable")
            + "</div>"
        )

        if self.tuning_results is not None:
            tuning_df = self.tuning_results.copy()
            tuning_df.drop(
                columns=["TT (Sec)"], errors="ignore", inplace=True
            )
            summary_html += (
                f"<h2>{best_model_name}: Tuning Summary</h2>"
                + '<div class="table-wrapper">'
                + tuning_df.to_html(index=False, classes="table sortable")
                + "</div>"
            )

        summary_html += (
            "<h2>Setup Parameters</h2>"
            + '<div class="table-wrapper">'
            + df_setup.to_html(index=False, classes="table sortable")
            + "</div>"
            # — Hyperparameters
            + "<h2>Best Model Hyperparameters</h2>"
            + '<div class="table-wrapper">'
            + pd.DataFrame(
                self.best_model.get_params().items(),
                columns=["Parameter", "Value"]
            ).to_html(index=False, classes="table sortable")
            + "</div>"
        )

        # choose summary plots based on task type
        if self.task_type == "classification":
            summary_plots = [
                "learning",
                "vc",
                "calibration",
                "dimension",
                "manifold",
                "rfe",
                "threshold",
                "percentage_above_below",
            ]
        else:
            summary_plots = ["learning", "vc", "parameter", "residuals"]

        for name in summary_plots:
            if name in self.plots:
                summary_html += "<hr>"
                b64 = encode_image_to_base64(self.plots[name])
                title = plot_title_map.get(
                    name, name.replace("_", " ").title()
                )
                summary_html += (
                    '<div class="plot">'
                    f"<h2>{title}</h2>"
                    f'<img src="data:image/png;base64,{b64}" '
                    'style="max-width:90%;max-height:600px;'
                    'border:1px solid #ddd;"/>'
                    "</div>"
                )

        # — Test Summary —
        test_html = (
            header
            + '<div class="table-wrapper">'
            + self.test_result_df.to_html(
                index=False, classes="table sortable"
            )
            + "</div>"
        )
        if self.task_type == "regression":
            try:
                y_true = (
                    pd.Series(self.exp.y_test_transformed)
                    .reset_index(drop=True)
                    .rename("True")
                )
                y_pred = pd.Series(
                    self.best_model.predict(
                        self.exp.X_test_transformed
                    )
                ).rename("Predicted")
                df_tp = pd.concat([y_true, y_pred], axis=1)
                test_html += "<h2>True vs Predicted Values</h2>"
                test_html += (
                    '<div class="table-wrapper" '
                    'style="max-height:400px; overflow-y:auto;">'
                    + df_tp.head(50).to_html(
                        index=False, classes="table sortable"
                    )
                    + "</div>"
                    + add_hr_to_html()
                )
            except Exception as e:
                LOG.warning(
                    f"Could not generate True vs Predicted table: {e}"
                )

        # 5a) Explainer-substituted plots in order
        if self.task_type == "regression":
            test_order = ["residuals"]
        else:
            test_order = [
                "confusion_matrix",
                "roc_auc",
                "pr_auc",
                "lift_curve",
                "cumulative_precision",
            ]
        for key in test_order:
            fig_or_fn = self.explainer_plots.pop(key, None)
            if fig_or_fn is not None:
                fig = self._resolve_plot_callable(
                    key, fig_or_fn, section="test/explainer"
                )
                if fig is None:
                    continue
                title = plot_title_map.get(
                    key, key.replace("_", " ").title()
                )
                test_html += (
                    f"<h2>{title}</h2>" + add_plot_to_html(fig)
                    + add_hr_to_html()
                )
        # 5b) Remaining PyCaret test plots
        for name, path in self.plots.items():
            # classification: include only the small extras, before
            # skipping anything
            if self.task_type == "classification" and (
                name in {
                    "pr_auc",
                    "class_report",
                }
            ):
                title = plot_title_map.get(
                    name, name.replace("_", " ").title()
                )
                b64 = encode_image_to_base64(path)
                test_html += (
                    f"<h2>{title}</h2>"
                    "<div class='plot'>"
                    f"<img src='data:image/png;base64,{b64}' "
                    "style='max-width:90%;max-height:600px;"
                    "border:1px solid #ddd;'/>"
                    "</div>" + add_hr_to_html()
                )
                continue

            # regression: explicitly include the 'error' plot,
            # before skipping
            if self.task_type == "regression" and (
                name == "error"
            ):
                title = plot_title_map.get(
                    "error", "Prediction Error Distribution"
                )
                b64 = encode_image_to_base64(path)
                test_html += (
                    f"<h2>{title}</h2>"
                    "<div class='plot'>"
                    f"<img src='data:image/png;base64,{b64}' "
                    "style='max-width:90%;max-height:600px;"
                    "border:1px solid #ddd;'/>"
                    "</div>" + add_hr_to_html()
                )
                continue

            # now skip any plots already rendered via test_order
            if name in test_order:
                continue

        # — Feature Importance —
        feature_html = header

        # 6a) PyCaret’s default feature importances
        imputed_data = (
            self.imputed_training_data
            if self.imputed_training_data is not None
            else self.data
        )
        fi_analyzer = FeatureImportanceAnalyzer(
            data=imputed_data,
            target_col=self.target_col,
            task_type=self.task_type,
            output_dir=self.output_dir,
            exp=self.exp,
            best_model=self.best_model,
            max_plot_features=self.plot_feature_limit,
            processed_data=self.imputed_training_data,
            max_shap_rows=self._shap_row_cap,
        )
        fi_html = fi_analyzer.run()
        # Add a small table to show SHAP feature caps near the Best Model header.
        cap_rows = []
        if fi_analyzer.shap_total_features is not None:
            cap_rows.append(
                ("Total transformed features", fi_analyzer.shap_total_features)
            )
        if fi_analyzer.shap_used_features is not None:
            cap_rows.append(
                ("Features used in SHAP", fi_analyzer.shap_used_features)
            )
        if cap_rows:
            cap_table = (
                "<div class='table-wrapper'>"
                "<table class='table sortable'>"
                "<thead><tr><th>Feature Importance Scope</th><th>Count</th></tr></thead>"
                "<tbody>"
                + "".join(
                    f"<tr><td>{label}</td><td>{value}</td></tr>"
                    for label, value in cap_rows
                )
                + "</tbody></table></div>"
            )
            feature_html += cap_table
        feature_html += fi_html

        # 6b) Explainer SHAP importances
        for key in ["shap_mean", "shap_perm"]:
            fig_or_fn = self.explainer_plots.pop(key, None)
            if fig_or_fn is not None:
                fig = self._resolve_plot_callable(
                    key, fig_or_fn, section="feature importance"
                )
                if fig is None:
                    continue
                # give SHAP plots explicit titles
                title = (
                    "Mean Absolute SHAP Value Impact"
                    if key == "shap_mean"
                    else "Permutation Feature Importance"
                )
                feature_html += (
                    f"<h2>{title}</h2>" + add_plot_to_html(fig)
                    + add_hr_to_html()
                )

        # 6c) PDPs last
        pdp_keys = sorted(
            k for k in self.explainer_plots if k.startswith("pdp__")
        )
        for k in pdp_keys:
            fig_or_fn = self.explainer_plots[k]
            fig = self._resolve_plot_callable(
                k, fig_or_fn, section="pdp"
            )
            if fig is None:
                continue
            # extract feature name
            feature = k.split("__", 1)[1]
            title = f"Partial Dependence for {feature}"
            feature_html += (
                f"<h2>{title}</h2>" + add_plot_to_html(fig)
                + add_hr_to_html()
            )
        # 7) Assemble final HTML (three tabs)
        html = get_html_template()
        html += "<h1>Tabular Learner Model Report</h1>"
        html += build_tabbed_html(summary_html, test_html, feature_html)
        html += get_feature_metrics_help_modal()
        html += get_html_closing()

        # 8) Write out
        (Path(self.output_dir) / "comparison_result.html").write_text(
            html, encoding="utf-8"
        )
        LOG.info(
            f"HTML report generated at: "
            f"{self.output_dir}/comparison_result.html"
        )

    def save_dashboard(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_plots_explainer(self):
        raise NotImplementedError("Subclasses should implement this method")

    def generate_tree_plots(self):
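        """
        Generate per-tree visualizations for random forest or XGBoost
        models using explainerdashboard's RandomForestExplainer.
        """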
        from sklearn.ensemble import (
            RandomForestClassifier, RandomForestRegressor
        )
        from xgboost import XGBClassifier, XGBRegressor
        from explainerdashboard.explainers import RandomForestExplainer

        LOG.info("Generating tree plots")
        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        if isinstance(
            self.best_model, (RandomForestClassifier, RandomForestRegressor)
        ):
            n_trees = self.best_model.n_estimators
        elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
            n_trees = len(self.best_model.get_booster().get_dump())
        else:
            LOG.warning("Tree plots not supported for this model type.")
            return

        explainer = RandomForestExplainer(self.best_model, X_test, y_test)
        for i in range(n_trees):
            fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
            self.trees.append(fig)

    def run(self):
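        """Execute the full training, plotting, and reporting pipeline."""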
        self.load_data()
        self.setup_pycaret()
        self.train_model()
        self.save_model()
        self.generate_plots()
        self.generate_plots_explainer()
        self.generate_tree_plots()
        self.save_html_report()
        # self.save_dashboard()
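
# Illustrative usage sketch (not executed here): a concrete subclass such as
# the hypothetical ``MyTrainer`` below must implement save_dashboard() and
# generate_plots_explainer() before run() can complete the full pipeline.
#
#   trainer = MyTrainer(
#       input_file="train.csv",
#       target_col=5,                 # 1-based column index of the target
#       output_dir="results",
#       task_type="classification",
#       random_seed=42,
#       test_file=None,
#   )
#   trainer.run()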