changeset 6:a32ff7201629 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 06c0da44ac93256dfb616a6b40276b5485a71e8e
author goeckslab
date Wed, 02 Jul 2025 19:00:03 +0000
parents c846405830eb
children
files base_model_trainer.py feature_help_modal.py feature_importance.py pycaret_predict.xml utils.py
diffstat 5 files changed, 519 insertions(+), 330 deletions(-)
--- a/base_model_trainer.py	Sat Jun 21 15:07:04 2025 +0000
+++ b/base_model_trainer.py	Wed Jul 02 19:00:03 2025 +0000
@@ -7,6 +7,7 @@
 import joblib
 import numpy as np
 import pandas as pd
+from feature_help_modal import get_feature_metrics_help_modal
 from feature_importance import FeatureImportanceAnalyzer
 from sklearn.metrics import average_precision_score
 from utils import get_html_closing, get_html_template
@@ -16,16 +17,16 @@
 
 
 class BaseModelTrainer:
-
     def __init__(
-            self,
-            input_file,
-            target_col,
-            output_dir,
-            task_type,
-            random_seed,
-            test_file=None,
-            **kwargs):
+        self,
+        input_file,
+        target_col,
+        output_dir,
+        task_type,
+        random_seed,
+        test_file=None,
+        **kwargs,
+    ):
         self.exp = None  # This will be set in the subclass
         self.input_file = input_file
         self.target_col = target_col
@@ -47,18 +48,26 @@
         self.test_file = test_file
         self.test_data = None
 
+        if not self.output_dir:
+            raise ValueError("output_dir must be specified and not None")
+
         LOG.info(f"Model kwargs: {self.__dict__}")
 
     def load_data(self):
         LOG.info(f"Loading data from {self.input_file}")
-        self.data = pd.read_csv(self.input_file, sep=None, engine='python')
-        self.data.columns = self.data.columns.str.replace('.', '_')
+        self.data = pd.read_csv(self.input_file, sep=None, engine="python")
+        self.data.columns = self.data.columns.str.replace(".", "_")
 
-        numeric_cols = self.data.select_dtypes(include=['number']).columns
-        non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns
+        # Remove prediction_label if present
+        if "prediction_label" in self.data.columns:
+            self.data = self.data.drop(columns=["prediction_label"])
+
+        numeric_cols = self.data.select_dtypes(include=["number"]).columns
+        non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
 
         self.data[numeric_cols] = self.data[numeric_cols].apply(
-            pd.to_numeric, errors='coerce')
+            pd.to_numeric, errors="coerce"
+        )
 
         if len(non_numeric_cols) > 0:
             LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
@@ -66,17 +75,13 @@
         names = self.data.columns.to_list()
         target_index = int(self.target_col) - 1
         self.target = names[target_index]
-        self.features_name = [name
-                              for i, name in enumerate(names)
-                              if i != target_index]
-        if hasattr(self, 'missing_value_strategy'):
-            if self.missing_value_strategy == 'mean':
-                self.data = self.data.fillna(
-                    self.data.mean(numeric_only=True))
-            elif self.missing_value_strategy == 'median':
-                self.data = self.data.fillna(
-                    self.data.median(numeric_only=True))
-            elif self.missing_value_strategy == 'drop':
+        self.features_name = [name for i, name in enumerate(names) if i != target_index]
+        if hasattr(self, "missing_value_strategy"):
+            if self.missing_value_strategy == "mean":
+                self.data = self.data.fillna(self.data.mean(numeric_only=True))
+            elif self.missing_value_strategy == "median":
+                self.data = self.data.fillna(self.data.median(numeric_only=True))
+            elif self.missing_value_strategy == "drop":
                 self.data = self.data.dropna()
         else:
             # Default strategy if not specified
@@ -84,287 +89,322 @@
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
-            self.test_data = pd.read_csv(
-                self.test_file, sep=None, engine='python')
+            self.test_data = pd.read_csv(self.test_file, sep=None, engine="python")
             self.test_data = self.test_data[numeric_cols].apply(
-                pd.to_numeric, errors='coerce')
-            self.test_data.columns = self.test_data.columns.str.replace(
-                '.', '_'
+                pd.to_numeric, errors="coerce"
             )
+            self.test_data.columns = self.test_data.columns.str.replace(".", "_")
 
     def setup_pycaret(self):
         LOG.info("Initializing PyCaret")
         self.setup_params = {
-            'target': self.target,
-            'session_id': self.random_seed,
-            'html': True,
-            'log_experiment': False,
-            'system_log': False,
-            'index': False,
+            "target": self.target,
+            "session_id": self.random_seed,
+            "html": True,
+            "log_experiment": False,
+            "system_log": False,
+            "index": False,
         }
 
         if self.test_data is not None:
-            self.setup_params['test_data'] = self.test_data
+            self.setup_params["test_data"] = self.test_data
 
-        if hasattr(self, 'train_size') and self.train_size is not None \
-                and self.test_data is None:
-            self.setup_params['train_size'] = self.train_size
-
-        if hasattr(self, 'normalize') and self.normalize is not None:
-            self.setup_params['normalize'] = self.normalize
+        if (
+            hasattr(self, "train_size")
+            and self.train_size is not None
+            and self.test_data is None
+        ):
+            self.setup_params["train_size"] = self.train_size
 
-        if hasattr(self, 'feature_selection') and \
-                self.feature_selection is not None:
-            self.setup_params['feature_selection'] = self.feature_selection
+        if hasattr(self, "normalize") and self.normalize is not None:
+            self.setup_params["normalize"] = self.normalize
+
+        if hasattr(self, "feature_selection") and self.feature_selection is not None:
+            self.setup_params["feature_selection"] = self.feature_selection
 
-        if hasattr(self, 'cross_validation') and \
-                self.cross_validation is not None \
-                and self.cross_validation is False:
-            self.setup_params['cross_validation'] = self.cross_validation
+        if (
+            hasattr(self, "cross_validation")
+            and self.cross_validation is not None
+            and self.cross_validation is False
+        ):
+            self.setup_params["cross_validation"] = self.cross_validation
 
-        if hasattr(self, 'cross_validation') and \
-                self.cross_validation is not None:
-            if hasattr(self, 'cross_validation_folds'):
-                self.setup_params['fold'] = self.cross_validation_folds
+        if hasattr(self, "cross_validation") and self.cross_validation is not None:
+            if hasattr(self, "cross_validation_folds"):
+                self.setup_params["fold"] = self.cross_validation_folds
 
-        if hasattr(self, 'remove_outliers') and \
-                self.remove_outliers is not None:
-            self.setup_params['remove_outliers'] = self.remove_outliers
+        if hasattr(self, "remove_outliers") and self.remove_outliers is not None:
+            self.setup_params["remove_outliers"] = self.remove_outliers
 
-        if hasattr(self, 'remove_multicollinearity') and \
-                self.remove_multicollinearity is not None:
-            self.setup_params['remove_multicollinearity'] = \
+        if (
+            hasattr(self, "remove_multicollinearity")
+            and self.remove_multicollinearity is not None
+        ):
+            self.setup_params["remove_multicollinearity"] = (
                 self.remove_multicollinearity
+            )
 
-        if hasattr(self, 'polynomial_features') and \
-                self.polynomial_features is not None:
-            self.setup_params['polynomial_features'] = self.polynomial_features
+        if (
+            hasattr(self, "polynomial_features")
+            and self.polynomial_features is not None
+        ):
+            self.setup_params["polynomial_features"] = self.polynomial_features
 
-        if hasattr(self, 'fix_imbalance') and \
-                self.fix_imbalance is not None:
-            self.setup_params['fix_imbalance'] = self.fix_imbalance
+        if hasattr(self, "fix_imbalance") and self.fix_imbalance is not None:
+            self.setup_params["fix_imbalance"] = self.fix_imbalance
 
         LOG.info(self.setup_params)
+
+        # Instantiate the correct PyCaret experiment based on task_type
+        if self.task_type == "classification":
+            from pycaret.classification import ClassificationExperiment
+
+            self.exp = ClassificationExperiment()
+        elif self.task_type == "regression":
+            from pycaret.regression import RegressionExperiment
+
+            self.exp = RegressionExperiment()
+        else:
+            raise ValueError("task_type must be 'classification' or 'regression'")
+
         self.exp.setup(self.data, **self.setup_params)
 
     def train_model(self):
         LOG.info("Training and selecting the best model")
         if self.task_type == "classification":
             average_displayed = "Weighted"
-            self.exp.add_metric(id=f'PR-AUC-{average_displayed}',
-                                name=f'PR-AUC-{average_displayed}',
-                                target='pred_proba',
-                                score_func=average_precision_score,
-                                average='weighted'
-                                )
+            self.exp.add_metric(
+                id=f"PR-AUC-{average_displayed}",
+                name=f"PR-AUC-{average_displayed}",
+                target="pred_proba",
+                score_func=average_precision_score,
+                average="weighted",
+            )
 
-        if hasattr(self, 'models') and self.models is not None:
-            self.best_model = self.exp.compare_models(
-                include=self.models)
+        if hasattr(self, "models") and self.models is not None:
+            self.best_model = self.exp.compare_models(include=self.models)
         else:
             self.best_model = self.exp.compare_models()
         self.results = self.exp.pull()
         if self.task_type == "classification":
-            self.results.rename(columns={'AUC': 'ROC-AUC'}, inplace=True)
+            self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
         _ = self.exp.predict_model(self.best_model)
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
-            self.test_result_df.rename(
-                columns={'AUC': 'ROC-AUC'}, inplace=True)
+            self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
     def save_model(self):
         hdf5_model_path = "pycaret_model.h5"
-        with h5py.File(hdf5_model_path, 'w') as f:
+        with h5py.File(hdf5_model_path, "w") as f:
             with tempfile.NamedTemporaryFile(delete=False) as temp_file:
                 joblib.dump(self.best_model, temp_file.name)
                 temp_file.seek(0)
                 model_bytes = temp_file.read()
-            f.create_dataset('model', data=np.void(model_bytes))
+            f.create_dataset("model", data=np.void(model_bytes))
 
     def generate_plots(self):
         raise NotImplementedError("Subclasses should implement this method")
 
     def encode_image_to_base64(self, img_path):
-        with open(img_path, 'rb') as img_file:
-            return base64.b64encode(img_file.read()).decode('utf-8')
+        with open(img_path, "rb") as img_file:
+            return base64.b64encode(img_file.read()).decode("utf-8")
 
     def save_html_report(self):
         LOG.info("Saving HTML report")
 
+        if not self.output_dir:
+            raise ValueError("output_dir must be specified and not None")
+
         model_name = type(self.best_model).__name__
-        excluded_params = ['html', 'log_experiment', 'system_log', 'test_data']
+        excluded_params = ["html", "log_experiment", "system_log", "test_data"]
         filtered_setup_params = {
-            k: v
-            for k, v in self.setup_params.items() if k not in excluded_params
+            k: v for k, v in self.setup_params.items() if k not in excluded_params
         }
         setup_params_table = pd.DataFrame(
-            list(filtered_setup_params.items()), columns=['Parameter', 'Value']
+            list(filtered_setup_params.items()), columns=["Parameter", "Value"]
         )
 
         best_model_params = pd.DataFrame(
-            self.best_model.get_params().items(),
-            columns=['Parameter', 'Value']
+            self.best_model.get_params().items(), columns=["Parameter", "Value"]
         )
         best_model_params.to_csv(
             os.path.join(self.output_dir, "best_model.csv"), index=False
         )
-        self.results.to_csv(
-            os.path.join(self.output_dir, "comparison_results.csv")
-        )
-        self.test_result_df.to_csv(
-            os.path.join(self.output_dir, "test_results.csv")
-        )
+        self.results.to_csv(os.path.join(self.output_dir, "comparison_results.csv"))
+        self.test_result_df.to_csv(os.path.join(self.output_dir, "test_results.csv"))
 
         plots_html = ""
         length = len(self.plots)
         for i, (plot_name, plot_path) in enumerate(self.plots.items()):
             encoded_image = self.encode_image_to_base64(plot_path)
-            plots_html += f"""
-            <div class="plot">
-                <h3>{plot_name.capitalize()}</h3>
-                <img src="data:image/png;base64,{encoded_image}"
-                    alt="{plot_name}">
-            </div>
-            """
+            plots_html += (
+                f'<div class="plot">'
+                f"<h3>{plot_name.capitalize()}</h3>"
+                f'<img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">'
+                f"</div>"
+            )
             if i < length - 1:
                 plots_html += "<hr>"
 
         tree_plots = ""
         for i, tree in enumerate(self.trees):
             if tree:
-                tree_plots += f"""
-                <div class="plot">
-                    <h3>Tree {i+1}</h3>
-                    <img src="data:image/png;base64,
-                    {tree}"
-                    alt="tree {i+1}">
-                </div>
-                """
+                tree_plots += (
+                    f'<div class="plot">'
+                    f"<h3>Tree {i + 1}</h3>"
+                    f'<img src="data:image/png;base64,{tree}" alt="tree {i + 1}">'
+                    f"</div>"
+                )
 
         analyzer = FeatureImportanceAnalyzer(
             data=self.data,
             target_col=self.target_col,
             task_type=self.task_type,
             output_dir=self.output_dir,
+            exp=self.exp,
+            best_model=self.best_model,
         )
         feature_importance_html = analyzer.run()
 
-        html_content = f"""
-        {get_html_template()}
-            <h1>PyCaret Model Training Report</h1>
-            <div class="tabs">
-                <div class="tab" onclick="openTab(event, 'summary')">
-                Setup & Best Model</div>
-                <div class="tab" onclick="openTab(event, 'plots')">
-                Best Model Plots</div>
-                <div class="tab" onclick="openTab(event, 'feature')">
-                Feature Importance</div>
-        """
-        if self.plots_explainer_html:
-            html_content += """
-                <div class="tab" onclick="openTab(event, 'explainer')">
-                Explainer Plots</div>
-            """
-        html_content += f"""
-            </div>
-            <div id="summary" class="tab-content">
-                <h2>Setup Parameters</h2>
-                {setup_params_table.to_html(
-                    index=False,
-                    header=True,
-                    classes='table sortable'
-                )}
-                <h5>If you want to know all the experiment setup parameters,
-                  please check the PyCaret documentation for
-                  the classification/regression <code>exp</code> function.</h5>
-                <h2>Best Model: {model_name}</h2>
-                {best_model_params.to_html(
-                    index=False,
-                    header=True,
-                    classes='table sortable'
-                )}
-                <h2>Comparison Results on the Cross-Validation Set</h2>
-                {self.results.to_html(index=False, classes='table sortable')}
-                <h2>Results on the Test Set for the best model</h2>
-                {self.test_result_df.to_html(
-                    index=False,
-                    classes='table sortable'
-                )}
-            </div>
-            <div id="plots" class="tab-content">
-                <h2>Best Model Plots on the testing set</h2>
-                {plots_html}
-            </div>
-            <div id="feature" class="tab-content">
-                {feature_importance_html}
-            </div>
-        """
+        # --- Feature Metrics Help Button ---
+        feature_metrics_button_html = (
+            '<button class="help-modal-btn" id="openFeatureMetricsHelp" style="margin-bottom:12px;">'
+            "Help: Metrics Guide"
+            "</button>"
+            "<style>"
+            ".help-modal-btn {"
+            "background-color: #17623b;"
+            "color: #fff;"
+            "border: none;"
+            "border-radius: 24px;"
+            "padding: 10px 28px;"
+            "font-size: 1.1rem;"
+            "font-weight: bold;"
+            "letter-spacing: 0.03em;"
+            "cursor: pointer;"
+            "transition: background 0.2s, box-shadow 0.2s;"
+            "box-shadow: 0 2px 8px rgba(23,98,59,0.07);"
+            "}"
+            ".help-modal-btn:hover, .help-modal-btn:focus {"
+            "background-color: #21895e;"
+            "outline: none;"
+            "box-shadow: 0 4px 16px rgba(23,98,59,0.14);"
+            "}"
+            "</style>"
+        )
+
+        html_content = (
+            f"{get_html_template()}"
+            "<h1>Tabular Learner Model Report</h1>"
+            f"{feature_metrics_button_html}"
+            '<div class="tabs">'
+            '<div class="tab" onclick="openTab(event, \'summary\')">'
+            "Validation Result Summary & Config</div>"
+            '<div class="tab" onclick="openTab(event, \'plots\')">'
+            "Test Results</div>"
+            '<div class="tab" onclick="openTab(event, \'feature\')">'
+            "Feature Importance</div>"
+        )
         if self.plots_explainer_html:
-            html_content += f"""
-            <div id="explainer" class="tab-content">
-                {self.plots_explainer_html}
-                {tree_plots}
-            </div>
-            """
-        html_content += """
-        <script>
-        document.addEventListener("DOMContentLoaded", function() {
-            var tables = document.querySelectorAll("table.sortable");
-            tables.forEach(function(table) {
-                var headers = table.querySelectorAll("th");
-                headers.forEach(function(header, index) {
-                    header.style.cursor = "pointer";
-                    // Add initial arrow (up) to indicate sortability
-                    header.innerHTML += '<span class="sort-arrow"> ↑</span>';
-                    header.addEventListener("click", function() {
-                        var direction = this.getAttribute(
-                            "data-sort-direction"
-                        ) || "asc";
-                        // Reset arrows in all headers of this table
-                        headers.forEach(function(h) {
-                            var arrow = h.querySelector(".sort-arrow");
-                            if (arrow) arrow.textContent = " ↑";
-                        });
-                        // Set arrow for clicked header
-                        var arrow = this.querySelector(".sort-arrow");
-                        arrow.textContent = direction === "asc" ? " ↓" : " ↑";
-                        sortTable(table, index, direction);
-                        this.setAttribute("data-sort-direction",
-                        direction === "asc" ? "desc" : "asc");
-                    });
-                });
-            });
-        });
-
-        function sortTable(table, colNum, direction) {
-            var tb = table.tBodies[0];
-            var tr = Array.prototype.slice.call(tb.rows, 0);
-            var multiplier = direction === "asc" ? 1 : -1;
-            tr = tr.sort(function(a, b) {
-                var aText = a.cells[colNum].textContent.trim();
-                var bText = b.cells[colNum].textContent.trim();
-                // Remove arrow from text comparison
-                aText = aText.replace(/[↑↓]/g, '').trim();
-                bText = bText.replace(/[↑↓]/g, '').trim();
-                if (!isNaN(aText) && !isNaN(bText)) {
-                    return multiplier * (
-                        parseFloat(aText) - parseFloat(bText)
-                    );
-                } else {
-                    return multiplier * aText.localeCompare(bText);
-                }
-            });
-            for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);
-        }
-        </script>
-        """
-        html_content += f"""
-        {get_html_closing()}
-        """
+            html_content += (
+                '<div class="tab" onclick="openTab(event, \'explainer\')">'
+                "Explainer Plots</div>"
+            )
+        html_content += (
+            "</div>"
+            '<div id="summary" class="tab-content">'
+            "<h2>Model Metrics from Cross-Validation Set</h2>"
+            f"<h2>Best Model: {model_name}</h2>"
+            "<h5>The best model is selected by: Accuracy (Classification)"
+            " or R2 (Regression).</h5>"
+            f"{self.results.to_html(index=False, classes='table sortable')}"
+            "<h2>Best Model's Hyperparameters</h2>"
+            f"{best_model_params.to_html(index=False, header=True, classes='table sortable')}"
+            "<h2>Setup Parameters</h2>"
+            f"{setup_params_table.to_html(index=False, header=True, classes='table sortable')}"
+            "<h5>If you want to know all the experiment setup parameters,"
+            " please check the PyCaret documentation for"
+            " the classification/regression <code>exp</code> function.</h5>"
+            "</div>"
+            '<div id="plots" class="tab-content">'
+            f"<h2>Best Model: {model_name}</h2>"
+            "<h5>The best model is selected by: Accuracy (Classification)"
+            " or R2 (Regression).</h5>"
+            "<h2>Test Metrics</h2>"
+            f"{self.test_result_df.to_html(index=False)}"
+            "<h2>Test Results</h2>"
+            f"{plots_html}"
+            "</div>"
+            '<div id="feature" class="tab-content">'
+            f"{feature_importance_html}"
+            "</div>"
+        )
+        if self.plots_explainer_html:
+            html_content += (
+                '<div id="explainer" class="tab-content">'
+                f"{self.plots_explainer_html}"
+                f"{tree_plots}"
+                "</div>"
+            )
+        html_content += (
+            "<script>"
+            "document.addEventListener(\"DOMContentLoaded\", function() {"
+            "var tables = document.querySelectorAll(\"table.sortable\");"
+            "tables.forEach(function(table) {"
+            "var headers = table.querySelectorAll(\"th\");"
+            "headers.forEach(function(header, index) {"
+            "header.style.cursor = \"pointer\";"
+            "// Add initial arrow (up) to indicate sortability, use Unicode ↑ (U+2191)"
+            "header.innerHTML += '<span class=\"sort-arrow\"> ↑</span>';"
+            "header.addEventListener(\"click\", function() {"
+            "var direction = this.getAttribute("
+            "\"data-sort-direction\""
+            ") || \"asc\";"
+            "// Reset arrows in all headers of this table"
+            "headers.forEach(function(h) {"
+            "var arrow = h.querySelector(\".sort-arrow\");"
+            "if (arrow) arrow.textContent = \" ↑\";"
+            "});"
+            "// Set arrow for clicked header"
+            "var arrow = this.querySelector(\".sort-arrow\");"
+            "arrow.textContent = direction === \"asc\" ? \" ↓\" : \" ↑\";"
+            "sortTable(table, index, direction);"
+            "this.setAttribute(\"data-sort-direction\","
+            "direction === \"asc\" ? \"desc\" : \"asc\");"
+            "});"
+            "});"
+            "});"
+            "});"
+            "function sortTable(table, colNum, direction) {"
+            "var tb = table.tBodies[0];"
+            "var tr = Array.prototype.slice.call(tb.rows, 0);"
+            "var multiplier = direction === \"asc\" ? 1 : -1;"
+            "tr = tr.sort(function(a, b) {"
+            "var aText = a.cells[colNum].textContent.trim();"
+            "var bText = b.cells[colNum].textContent.trim();"
+            "// Remove arrow from text comparison"
+            "aText = aText.replace(/[↑↓]/g, '').trim();"
+            "bText = bText.replace(/[↑↓]/g, '').trim();"
+            "if (!isNaN(aText) && !isNaN(bText)) {"
+            "return multiplier * ("
+            "parseFloat(aText) - parseFloat(bText)"
+            ");"
+            "} else {"
+            "return multiplier * aText.localeCompare(bText);"
+            "}"
+            "});"
+            "for (var i = 0; i < tr.length; ++i) tb.appendChild(tr[i]);"
+            "}"
+            "</script>"
+        )
+        # --- Add the Feature Metrics Help Modal ---
+        html_content += get_feature_metrics_help_modal()
+        html_content += f"{get_html_closing()}"
         with open(
             os.path.join(self.output_dir, "comparison_result.html"),
-            "w"
+            "w",
+            encoding="utf-8",
         ) as file:
             file.write(html_content)
 
@@ -374,10 +414,8 @@
     def generate_plots_explainer(self):
         raise NotImplementedError("Subclasses should implement this method")
 
-    # not working now
     def generate_tree_plots(self):
-        from sklearn.ensemble import RandomForestClassifier, \
-            RandomForestRegressor
+        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
         from xgboost import XGBClassifier, XGBRegressor
         from explainerdashboard.explainers import RandomForestExplainer
 
@@ -385,21 +423,25 @@
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
-        is_rf = isinstance(self.best_model, RandomForestClassifier) or \
-            isinstance(self.best_model, RandomForestRegressor)
+        is_rf = isinstance(
+            self.best_model, (RandomForestClassifier, RandomForestRegressor)
+        )
+        is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))
 
-        is_xgb = isinstance(self.best_model, XGBClassifier) or \
-            isinstance(self.best_model, XGBRegressor)
+        num_trees = None
+        if is_rf:
+            num_trees = self.best_model.n_estimators
+        elif is_xgb:
+            num_trees = len(self.best_model.get_booster().get_dump())
+        else:
+            LOG.warning("Tree plots not supported for this model type.")
+            return
 
         try:
-            if is_rf:
-                num_trees = self.best_model.n_estimators
-            if is_xgb:
-                num_trees = len(self.best_model.get_booster().get_dump())
             explainer = RandomForestExplainer(self.best_model, X_test, y_test)
             for i in range(num_trees):
                 fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
-                LOG.info(f"Tree {i+1}")
+                LOG.info(f"Tree {i + 1}")
                 LOG.info(fig)
                 self.trees.append(fig)
         except Exception as e:
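
Note on the hunk above: save_model serializes the winning estimator with joblib and embeds the raw bytes in a single HDF5 dataset named "model". A minimal sketch of the reverse operation, assuming the file layout shown in that hunk (the helper name load_model_from_hdf5 is illustrative and not part of this changeset):

import io

import h5py
import joblib


def load_model_from_hdf5(path="pycaret_model.h5"):
    # The 'model' dataset holds an opaque np.void scalar wrapping the joblib dump
    with h5py.File(path, "r") as f:
        model_bytes = f["model"][()].tobytes()
    # joblib.load accepts any file-like object, so no temp file is needed
    return joblib.load(io.BytesIO(model_bytes))
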
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/feature_help_modal.py	Wed Jul 02 19:00:03 2025 +0000
@@ -0,0 +1,120 @@
+def get_feature_metrics_help_modal() -> str:
+    modal_html = """
+<div id="featureMetricsHelpModal" class="modal">
+  <div class="modal-content">
+    <span class="close-feature-metrics">&times;</span>
+    <h2>Help Guide: Common Model Metrics</h2>
+    <div class="metrics-guide" style="max-height:65vh;overflow-y:auto;font-size:1.04em;">
+      <h3>1) General Metrics</h3>
+      <h4>Classification</h4>
+      <p><strong>Accuracy:</strong> The proportion of correct predictions among all predictions. It is calculated as (TP + TN) / (TP + TN + FP + FN). While intuitive, Accuracy can be misleading for imbalanced datasets where one class dominates. For example, in a dataset with 95% negative cases, a model predicting all negatives achieves 95% Accuracy but fails to identify positives.</p>
+      <p><strong>AUC (Area Under the Curve):</strong> In this guide, AUC refers to the Area Under the Receiver Operating Characteristic Curve (ROC-AUC), which measures a model’s ability to distinguish between classes. It ranges from 0 to 1, where 1 indicates perfect separation and 0.5 suggests random guessing. ROC-AUC is robust for binary and multiclass problems but may be less informative for highly imbalanced datasets.</p>
+      <h4>Regression</h4>
+      <p><strong>R2 (Coefficient of Determination):</strong> Measures the proportion of variance in the dependent variable explained by the independent variables. A value of 1 indicates perfect prediction, 0 indicates no explanatory power, and negative values are possible when the model performs worse than a mean-based baseline. R2 is widely used but sensitive to outliers.</p>
+      <p><strong>RMSE (Root Mean Squared Error):</strong> The square root of the average squared differences between predicted and actual values. It penalizes larger errors more heavily and is expressed in the same units as the target variable, making it interpretable. Lower RMSE indicates better model performance.</p>
+      <p><strong>MAE (Mean Absolute Error):</strong> The average of absolute differences between predicted and actual values. It is less sensitive to outliers than RMSE and provides a straightforward measure of average error magnitude. Lower MAE is better.</p>
+
+      <h3>2) Precision, Recall & Specificity</h3>
+      <h4>Classification</h4>
+      <p><strong>Precision:</strong> The proportion of positive predictions that are correct, calculated as TP / (TP + FP). High Precision is crucial when false positives are costly, such as in spam email detection, where misclassifying legitimate emails as spam disrupts user experience.</p>
+      <p><strong>Recall (Sensitivity):</strong> The proportion of actual positives correctly predicted, calculated as TP / (TP + FN). High Recall is vital when missing positives is risky, such as in disease diagnosis, where failing to identify a sick patient could have severe consequences.</p>
+      <p><strong>Specificity:</strong> The true negative rate, calculated as TN / (TN + FP). It measures how well a model identifies negatives, making it valuable in medical testing to minimize false alarms (e.g., incorrectly diagnosing healthy patients as sick).</p>
+
+      <h3>3) Macro, Micro, and Weighted Averages</h3>
+      <h4>Classification</h4>
+      <p><strong>Macro Precision / Recall / F1:</strong> Computes the metric for each class independently and averages them, treating all classes equally. This is ideal for balanced datasets or when all classes are equally important, such as in multiclass image classification with similar class frequencies.</p>
+      <p><strong>Micro Precision / Recall / F1:</strong> Aggregates true positives (TP), false positives (FP), and false negatives (FN) across all classes before computing the metric. It provides a global perspective and is suitable for imbalanced datasets or multilabel problems, as it accounts for class frequency.</p>
+      <p><strong>Weighted Precision / Recall / F1:</strong> Averages the metric across classes, weighted by the number of true instances per class. This balances the importance of classes based on their frequency, making it useful for imbalanced datasets where larger classes should have more influence but smaller classes are still considered.</p>
+
+      <h3>4) Average Precision (PR-AUC Variants)</h3>
+      <h4>Classification</h4>
+      <p><strong>Average Precision:</strong> The Area Under the Precision-Recall Curve (PR-AUC) summarizes the trade-off between Precision and Recall. It is particularly useful for imbalanced datasets, where ROC-AUC may overestimate performance. Average Precision is computed by averaging Precision values at different Recall thresholds, providing a robust measure for ranking tasks or rare class detection.</p>
+
+      <h3>5) ROC-AUC Variants</h3>
+      <h4>Classification</h4>
+      <p><strong>ROC-AUC:</strong> The Area Under the Receiver Operating Characteristic Curve plots the true positive rate (Recall) against the false positive rate (1 - Specificity) at various thresholds. It quantifies the model’s ability to separate classes, with higher values indicating better performance.</p>
+      <p><strong>Macro ROC-AUC:</strong> Averages the ROC-AUC scores across all classes, treating each class equally. This is suitable for balanced multiclass problems where all classes are of equal importance.</p>
+      <p><strong>Micro ROC-AUC:</strong> Computes a single ROC-AUC by aggregating predictions and true labels across all classes. It is effective for multiclass or multilabel problems with class imbalance, as it accounts for the overall prediction distribution.</p>
+
+      <h3>6) Confusion Matrix Stats (Per Class)</h3>
+      <h4>Classification</h4>
+      <p><strong>True Positives (TP):</strong> The number of correct positive predictions for a given class.</p>
+      <p><strong>True Negatives (TN):</strong> The number of correct negative predictions for a given class.</p>
+      <p><strong>False Positives (FP):</strong> The number of incorrect positive predictions for a given class (false alarms).</p>
+      <p><strong>False Negatives (FN):</strong> The number of incorrect negative predictions for a given class (missed detections). These stats are visualized in PyCaret’s confusion matrix plots, aiding class-wise performance analysis.</p>
+
+      <h3>7) Other Useful Metrics</h3>
+      <h4>Classification</h4>
+      <p><strong>Cohen’s Kappa:</strong> Measures the agreement between predicted and actual labels, adjusted for chance. It ranges from -1 to 1, where 1 indicates perfect agreement, 0 indicates chance-level agreement, and negative values suggest worse-than-chance performance. Kappa is useful for multiclass problems with imbalanced labels.</p>
+      <p><strong>Matthews Correlation Coefficient (MCC):</strong> A balanced measure that considers TP, TN, FP, and FN, calculated as (TP * TN - FP * FN) / sqrt((TP + FP)(TP + FN)(TN + FP)(TN + FN)). It ranges from -1 to 1, with 1 being perfect prediction. MCC is particularly effective for imbalanced datasets due to its symmetry across classes.</p>
+      <h4>Regression</h4>
+      <p><strong>MSE (Mean Squared Error):</strong> The average of squared differences between predicted and actual values. It amplifies larger errors, making it sensitive to outliers. Lower MSE indicates better performance.</p>
+      <p><strong>MAPE (Mean Absolute Percentage Error):</strong> The average of absolute percentage differences between predicted and actual values, calculated as (1/n) * Σ(|actual - predicted| / |actual|) * 100. It is useful when relative errors are important but can be unstable if actual values are near zero.</p>
+    </div>
+  </div>
+</div>
+"""
+    modal_css = """
+<style>
+/* Modal Background & Content */
+#featureMetricsHelpModal.modal {
+  display: none;
+  position: fixed;
+  z-index: 9999;
+  left: 0; top: 0;
+  width: 100%; height: 100%;
+  overflow: auto;
+  background-color: rgba(0,0,0,0.45);
+}
+#featureMetricsHelpModal .modal-content {
+  background-color: #fefefe;
+  margin: 5% auto;
+  padding: 24px 28px 20px 28px;
+  border: 1.5px solid #17623b;
+  width: 90%;
+  max-width: 800px;
+  border-radius: 18px;
+  box-shadow: 0 8px 32px rgba(23,98,59,0.20);
+}
+#featureMetricsHelpModal .close-feature-metrics {
+  color: #17623b;
+  float: right;
+  font-size: 28px;
+  font-weight: bold;
+  cursor: pointer;
+  transition: color 0.2s;
+}
+#featureMetricsHelpModal .close-feature-metrics:hover {
+  color: #21895e;
+}
+.metrics-guide h3 { margin-top: 20px; }
+.metrics-guide h4 { margin-top: 12px; color: #17623b; }
+.metrics-guide p { margin: 5px 0 10px 0; }
+.metrics-guide ul { margin: 10px 0 10px 24px; }
+</style>
+"""
+    modal_js = """
+<script>
+document.addEventListener("DOMContentLoaded", function() {
+  var modal = document.getElementById("featureMetricsHelpModal");
+  var openBtn = document.getElementById("openFeatureMetricsHelp");
+  var span = document.getElementsByClassName("close-feature-metrics")[0];
+  if (openBtn && modal) {
+    openBtn.onclick = function() {
+      modal.style.display = "block";
+    };
+  }
+  if (span && modal) {
+    span.onclick = function() {
+      modal.style.display = "none";
+    };
+  }
+  window.onclick = function(event) {
+    if (event.target == modal) {
+      modal.style.display = "none";
+    }
+  }
+});
+</script>
+"""
+    return modal_css + modal_html + modal_js
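
The metrics described in the modal above map directly onto scikit-learn functions. A small self-contained sketch, with made-up toy labels, showing how the averaging modes from section 3 and the chance-corrected scores from section 7 are computed in practice:

from sklearn.metrics import (
    accuracy_score,
    cohen_kappa_score,
    f1_score,
    matthews_corrcoef,
    precision_score,
    recall_score,
)

y_true = [1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [1, 0, 1, 0, 0, 1, 1, 0]

print(accuracy_score(y_true, y_pred))    # (TP + TN) / (TP + TN + FP + FN)
print(precision_score(y_true, y_pred))   # TP / (TP + FP)
print(recall_score(y_true, y_pred))      # TP / (TP + FN)

# Macro, micro, and weighted F1 differ only in how per-class scores are
# combined; on near-balanced binary data the three values are close.
for avg in ("macro", "micro", "weighted"):
    print(avg, f1_score(y_true, y_pred, average=avg))

print(matthews_corrcoef(y_true, y_pred))   # in [-1, 1], robust to imbalance
print(cohen_kappa_score(y_true, y_pred))   # agreement corrected for chance
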
--- a/feature_importance.py	Sat Jun 21 15:07:04 2025 +0000
+++ b/feature_importance.py	Wed Jul 02 19:00:03 2025 +0000
@@ -4,6 +4,7 @@
 
 import matplotlib.pyplot as plt
 import pandas as pd
+import shap
 from pycaret.classification import ClassificationExperiment
 from pycaret.regression import RegressionExperiment
 
@@ -18,25 +19,38 @@
             output_dir,
             data_path=None,
             data=None,
-            target_col=None):
+            target_col=None,
+            exp=None,
+            best_model=None):
 
-        if data is not None:
-            self.data = data
-            LOG.info("Data loaded from memory")
+        self.task_type = task_type
+        self.output_dir = output_dir
+        self.exp = exp
+        self.best_model = best_model
+
+        if exp is not None:
+            # Assume all configs (data, target) are in exp
+            self.data = exp.dataset.copy()
+            self.target = exp.target_param
+            LOG.info("Using provided experiment object")
         else:
-            self.target_col = target_col
-            self.data = pd.read_csv(data_path, sep=None, engine='python')
-            self.data.columns = self.data.columns.str.replace('.', '_')
-            self.data = self.data.fillna(self.data.median(numeric_only=True))
-        self.task_type = task_type
-        self.target = self.data.columns[int(target_col) - 1]
-        self.exp = ClassificationExperiment() \
-            if task_type == 'classification' \
-            else RegressionExperiment()
+            if data is not None:
+                self.data = data
+                LOG.info("Data loaded from memory")
+            else:
+                self.target_col = target_col
+                self.data = pd.read_csv(data_path, sep=None, engine='python')
+                self.data.columns = self.data.columns.str.replace('.', '_')
+                self.data = self.data.fillna(self.data.median(numeric_only=True))
+            self.target = self.data.columns[int(target_col) - 1]
+            self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment()
+
         self.plots = {}
-        self.output_dir = output_dir
 
     def setup_pycaret(self):
+        if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup:
+            LOG.info("Experiment already set up. Skipping PyCaret setup.")
+            return
         LOG.info("Initializing PyCaret")
         setup_params = {
             'target': self.target,
@@ -45,25 +59,36 @@
             'log_experiment': False,
             'system_log': False
         }
-        LOG.info(self.task_type)
-        LOG.info(self.exp)
         self.exp.setup(self.data, **setup_params)
 
-    # def save_coefficients(self):
-    #     model = self.exp.create_model('lr')
-    #     coef_df = pd.DataFrame({
-    #         'Feature': self.data.columns.drop(self.target),
-    #         'Coefficient': model.coef_[0]
-    #     })
-    #     coef_html = coef_df.to_html(index=False)
-    #     return coef_html
+    def save_tree_importance(self):
+        model = self.best_model or self.exp.get_config('best_model')
+        processed_features = self.exp.get_config('X_transformed').columns
+
+        # Try feature_importances_ or coef_ if available
+        importances = None
+        model_type = model.__class__.__name__
+        self.tree_model_name = model_type  # Store the model name for reporting
 
-    def save_tree_importance(self):
-        model = self.exp.create_model('rf')
-        importances = model.feature_importances_
-        processed_features = self.exp.get_config('X_transformed').columns
-        LOG.debug(f"Feature importances: {importances}")
-        LOG.debug(f"Features: {processed_features}")
+        if hasattr(model, "feature_importances_"):
+            importances = model.feature_importances_
+        elif hasattr(model, "coef_"):
+            # For linear models, flatten coef_ and take abs (importance as magnitude)
+            importances = abs(model.coef_).flatten()
+        else:
+            # Neither attribute exists; skip the plot
+            LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.")
+            self.tree_model_name = None  # No plot generated
+            return
+
+        # Defensive: handle mismatch in number of features
+        if len(importances) != len(processed_features):
+            LOG.warning(
+                f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot."
+            )
+            self.tree_model_name = None
+            return
+
         feature_importances = pd.DataFrame({
             'Feature': processed_features,
             'Importance': importances
@@ -73,7 +98,7 @@
             feature_importances['Feature'],
             feature_importances['Importance'])
         plt.xlabel('Importance')
-        plt.title('Feature Importance (Random Forest)')
+        plt.title(f'Feature Importance ({model_type})')
         plot_path = os.path.join(
             self.output_dir,
             'tree_importance.png')
@@ -82,53 +107,64 @@
         self.plots['tree_importance'] = plot_path
 
     def save_shap_values(self):
-        model = self.exp.create_model('lightgbm')
-        import shap
-        explainer = shap.Explainer(model)
-        shap_values = explainer.shap_values(
-            self.exp.get_config('X_transformed'))
-        shap.summary_plot(shap_values,
-                          self.exp.get_config('X_transformed'), show=False)
-        plt.title('Shap (LightGBM)')
-        plot_path = os.path.join(
-            self.output_dir, 'shap_summary.png')
+        model = self.best_model or self.exp.get_config('best_model')
+        X_transformed = self.exp.get_config('X_transformed')
+        tree_classes = (
+            "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting"
+        )
+        model_class_name = model.__class__.__name__
+        self.shap_model_name = model_class_name
+
+        # Ensure feature alignment
+        if hasattr(model, "feature_name_"):
+            used_features = model.feature_name_
+        elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"):
+            used_features = model.booster_.feature_name()
+        else:
+            used_features = X_transformed.columns
+
+        if any(tc in model_class_name for tc in tree_classes):
+            explainer = shap.TreeExplainer(model)
+            X_shap = X_transformed[used_features]
+            shap_values = explainer.shap_values(X_shap)
+            plot_X = X_shap
+            plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)"
+        else:
+            sampled_X = X_transformed[used_features].sample(100, random_state=42)
+            explainer = shap.KernelExplainer(model.predict, sampled_X)
+            shap_values = explainer.shap_values(sampled_X)
+            plot_X = sampled_X
+            plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)"
+
+        shap.summary_plot(shap_values, plot_X, show=False)
+        plt.title(plot_title)
+        plot_path = os.path.join(self.output_dir, "shap_summary.png")
         plt.savefig(plot_path)
         plt.close()
-        self.plots['shap_summary'] = plot_path
-
-    def generate_feature_importance(self):
-        # coef_html = self.save_coefficients()
-        self.save_tree_importance()
-        self.save_shap_values()
-
-    def encode_image_to_base64(self, img_path):
-        with open(img_path, 'rb') as img_file:
-            return base64.b64encode(img_file.read()).decode('utf-8')
+        self.plots["shap_summary"] = plot_path
 
     def generate_html_report(self):
         LOG.info("Generating HTML report")
 
-        # Read and encode plot images
         plots_html = ""
         for plot_name, plot_path in self.plots.items():
+            # Special handling for tree importance: skip if no model name (not generated)
+            if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None):
+                continue
             encoded_image = self.encode_image_to_base64(plot_path)
+            if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None):
+                section_title = f"Feature importance analysis from a trained {self.tree_model_name}"
+            elif plot_name == 'shap_summary':
+                section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}"
+            else:
+                section_title = plot_name
             plots_html += f"""
             <div class="plot" id="{plot_name}">
-                <h2>{'Feature importance analysis from a'
-                    'trained Random Forest'
-                    if plot_name == 'tree_importance'
-                    else 'SHAP Summary from a trained lightgbm'}</h2>
-                <h3>{'Use gini impurity for'
-                    'calculating feature importance for classification'
-                    'and Variance Reduction for regression'
-                  if plot_name == 'tree_importance'
-                  else ''}</h3>
-                <img src="data:image/png;base64,
-                {encoded_image}" alt="{plot_name}">
+                <h2>{section_title}</h2>
+                <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}">
             </div>
             """
 
-        # Generate HTML content with tabs
         html_content = f"""
             <h1>PyCaret Feature Importance Report</h1>
             {plots_html}
@@ -136,34 +172,14 @@
 
         return html_content
 
-    def run(self):
-        LOG.info("Running feature importance analysis")
-        self.setup_pycaret()
-        self.generate_feature_importance()
-        html_content = self.generate_html_report()
-        LOG.info("Feature importance analysis completed")
-        return html_content
-
+    def encode_image_to_base64(self, img_path):
+        with open(img_path, 'rb') as img_file:
+            return base64.b64encode(img_file.read()).decode('utf-8')
 
-if __name__ == "__main__":
-    import argparse
-    parser = argparse.ArgumentParser(description="Feature Importance Analysis")
-    parser.add_argument(
-        "--data_path", type=str, help="Path to the dataset")
-    parser.add_argument(
-        "--target_col", type=int,
-        help="Index of the target column (1-based)")
-    parser.add_argument(
-        "--task_type", type=str,
-        choices=["classification", "regression"],
-        help="Task type: classification or regression")
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        help="Directory to save the outputs")
-    args = parser.parse_args()
-
-    analyzer = FeatureImportanceAnalyzer(
-        args.data_path, args.target_col,
-        args.task_type, args.output_dir)
-    analyzer.run()
+    def run(self):
+        if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup:
+            self.setup_pycaret()
+        self.save_tree_importance()
+        self.save_shap_values()
+        html_content = self.generate_html_report()
+        return html_content
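
The explainer choice in save_shap_values above reduces to one predicate: models whose class name matches a known tree ensemble get shap.TreeExplainer, while everything else falls back to the model-agnostic shap.KernelExplainer on a 100-row background sample. A condensed sketch of that dispatch (the helper make_explainer is illustrative and not part of this changeset; model is a fitted estimator and X the transformed feature frame):

import shap

TREE_CLASSES = ("LGBM", "XGB", "CatBoost", "RandomForest",
                "DecisionTree", "ExtraTrees", "HistGradientBoosting")


def make_explainer(model, X):
    if any(tc in model.__class__.__name__ for tc in TREE_CLASSES):
        # Fast, exact path for supported tree ensembles
        return shap.TreeExplainer(model), X
    # Model-agnostic fallback: explain predictions against a sampled background
    sample = X.sample(100, random_state=42)
    return shap.KernelExplainer(model.predict, sample), sample
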
--- a/pycaret_predict.xml	Sat Jun 21 15:07:04 2025 +0000
+++ b/pycaret_predict.xml	Wed Jul 02 19:00:03 2025 +0000
@@ -35,7 +35,12 @@
             <param name="model_type" value="classification" />
             <param name="target_feature" value="11" />
             <output name="prediction" file="predictions_classification.csv" />
-            <output name="report" file="evaluation_report_classification.html" compare="sim_size" />
+            <output name="report">
+                <assert_contents>
+                    <has_text text="Metrics" />
+                    <has_text text="Plots" />
+                </assert_contents>
+            </output>
         </test>
         <test expect_num_outputs="2">
             <param name="input_model" value="expected_model_regression.h5" />
@@ -43,7 +48,12 @@
             <param name="model_type" value="regression" />
             <param name="target_feature" value="1" />
             <output name="prediction" file="predictions_regression.csv" />
-            <output name="report" file="evaluation_report_regression.html" compare="sim_size" />
+            <output name="report">
+                <assert_contents>
+                    <has_text text="Metrics" />
+                    <has_text text="Plots" />
+                </assert_contents>
+            </output>
         </test>
     </tests>
     <help>
@@ -58,4 +68,4 @@
 
     </help>
     <expand macro="macro_citations" />
-</tool>
\ No newline at end of file
+</tool>
--- a/utils.py	Sat Jun 21 15:07:04 2025 +0000
+++ b/utils.py	Wed Jul 02 19:00:03 2025 +0000
@@ -11,6 +11,7 @@
     return """
     <html>
     <head>
+        <meta charset="UTF-8">
         <title>Model Training Report</title>
         <style>
           body {