changeset 10:e2a6fed32d54 draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 47a5977e074223e92e216efa42969a4056516707
author goeckslab
date Fri, 01 Aug 2025 14:02:26 +0000
parents c6c1f8777aae
children
files base_model_trainer.py dashboard.py pycaret_macros.xml pycaret_train.py test-data/expected_comparison_result_classification.html test-data/expected_comparison_result_classification_customized.html test-data/expected_comparison_result_regression.html utils.py
diffstat 8 files changed, 216 insertions(+), 74 deletions(-)
--- a/base_model_trainer.py	Thu Jul 31 15:41:24 2025 +0000
+++ b/base_model_trainer.py	Fri Aug 01 14:02:26 2025 +0000
@@ -44,6 +44,7 @@
         self.target = None
         self.best_model = None
         self.results = None
+        self.tuning_results = None
         self.features_name = None
         self.plots = {}
         self.explainer_plots = {}
@@ -57,44 +58,98 @@
         self.test_data = None
 
         if not self.output_dir:
-            raise ValueError("output_dir must be specified and not None")
+            raise ValueError(
+                "output_dir must be specified and not None"
+            )
+
+        # Warn about irrelevant kwargs for the task type
+        if self.task_type == "regression" and (
+            "probability_threshold" in self.user_kwargs
+        ):
+            LOG.warning(
+                "probability_threshold is ignored for regression tasks."
+            )
 
         LOG.info(f"Model kwargs: {self.__dict__}")
 
     def load_data(self):
         LOG.info(f"Loading data from {self.input_file}")
-        self.data = pd.read_csv(self.input_file, sep=None, engine="python")
+        self.data = pd.read_csv(
+            self.input_file, sep=None, engine="python"
+        )
         self.data.columns = self.data.columns.str.replace(".", "_")
-        if "prediction_label" in self.data.columns:
+
+        names = self.data.columns.to_list()
+        LOG.info(f"Original dataset columns: {names}")
+
+        target_index = int(self.target_col) - 1
+        num_cols = len(names)
+        if target_index < 0 or target_index >= num_cols:
+            raise ValueError(
+                f"Target column number {self.target_col} is invalid. "
+                f"Please select a number between 1 and {num_cols}."
+            )
+
+        self.target = names[target_index]
+
+        # Conditional drop: only if 'prediction_label' exists and is not
+        # the target
+        if "prediction_label" in self.data.columns and (
+            self.data.columns[target_index] != "prediction_label"
+        ):
+            LOG.info(
+                "Dropping 'prediction_label' column as it's not the target."
+            )
             self.data = self.data.drop(columns=["prediction_label"])
+        elif self.target == "prediction_label":
+            LOG.warning(
+                "Using 'prediction_label' as target column. "
+                "This may not be intended if it's a previous prediction."
+            )
 
-        numeric_cols = self.data.select_dtypes(include=["number"]).columns
-        non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
+        numeric_cols = self.data.select_dtypes(
+            include=["number"]
+        ).columns
+        non_numeric_cols = self.data.select_dtypes(
+            exclude=["number"]
+        ).columns
         self.data[numeric_cols] = self.data[numeric_cols].apply(
             pd.to_numeric, errors="coerce"
         )
         if len(non_numeric_cols) > 0:
-            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
+            LOG.info(
+                f"Non-numeric columns found: {non_numeric_cols.tolist()}"
+            )
 
+        # Update names after possible drop
         names = self.data.columns.to_list()
-        target_index = int(self.target_col) - 1
-        self.target = names[target_index]
-        self.features_name = [n for i, n in enumerate(names) if i != target_index]
+        LOG.info(f"Dataset columns after processing: {names}")
+
+        self.features_name = [n for n in names if n != self.target]
 
         if getattr(self, "missing_value_strategy", None):
             strat = self.missing_value_strategy
             if strat == "mean":
-                self.data = self.data.fillna(self.data.mean(numeric_only=True))
+                self.data = self.data.fillna(
+                    self.data.mean(numeric_only=True)
+                )
             elif strat == "median":
-                self.data = self.data.fillna(self.data.median(numeric_only=True))
+                self.data = self.data.fillna(
+                    self.data.median(numeric_only=True)
+                )
             elif strat == "drop":
                 self.data = self.data.dropna()
         else:
-            self.data = self.data.fillna(self.data.median(numeric_only=True))
+            self.data = self.data.fillna(
+                self.data.median(numeric_only=True)
+            )
 
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
-            df_test = pd.read_csv(self.test_file, sep=None, engine="python")
+            df_test = pd.read_csv(
+                self.test_file, sep=None, engine="python"
+            )
             df_test.columns = df_test.columns.str.replace(".", "_")
             self.test_data = df_test
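
For reference, the 1-based target-column handling added in this hunk boils down to the following standalone sketch (the file name and column number are hypothetical):

    import pandas as pd

    # sep=None lets the python engine sniff the delimiter, as in load_data()
    df = pd.read_csv("train.csv", sep=None, engine="python")

    target_col = 3                   # 1-based column number from the tool form
    target_index = target_col - 1    # convert to pandas' 0-based indexing
    if not (0 <= target_index < len(df.columns)):
        raise ValueError(
            f"Target column number {target_col} is invalid. "
            f"Please select a number between 1 and {len(df.columns)}."
        )
    target = df.columns[target_index]
    features = [c for c in df.columns if c != target]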
 
@@ -137,7 +192,9 @@
 
             self.exp = RegressionExperiment()
         else:
-            raise ValueError("task_type must be 'classification' or 'regression'")
+            raise ValueError(
+                "task_type must be 'classification' or 'regression'"
+            )
 
         self.exp.setup(self.data, **self.setup_params)
         self.setup_params.update(self.user_kwargs)
@@ -171,20 +228,26 @@
         if getattr(self, "tune_model", False):
             LOG.info("Tuning hyperparameters of the best model")
             self.best_model = self.exp.tune_model(self.best_model)
-            self.results = self.exp.pull()
+            self.tuning_results = self.exp.pull()
 
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
 
         prob_thresh = getattr(self, "probability_threshold", None)
-        if self.task_type == "classification" and prob_thresh is not None:
-            _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh)
+        if self.task_type == "classification" and (
+            prob_thresh is not None
+        ):
+            _ = self.exp.predict_model(
+                self.best_model, probability_threshold=prob_thresh
+            )
         else:
             _ = self.exp.predict_model(self.best_model)
 
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
-            self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
+            self.test_result_df.rename(
+                columns={"AUC": "ROC-AUC"}, inplace=True
+            )
 
     def save_model(self):
         hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
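
The probability_threshold branch above lets classification predictions use a custom cut-off instead of the default 0.5. The underlying idea, sketched without PyCaret:

    import numpy as np

    def apply_threshold(proba_pos, threshold=0.5):
        # Label is 1 wherever P(class == 1) meets or exceeds the threshold.
        return (np.asarray(proba_pos) >= threshold).astype(int)

    apply_threshold([0.2, 0.55, 0.9], threshold=0.7)  # -> array([0, 0, 1])
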
@@ -198,7 +261,7 @@
     def generate_plots(self):
         LOG.info("Generating PyCaret diagnostic pltos")
 
-        # choose the right plots based on task
+        # choose the right plots based on task type
         if self.task_type == "classification":
             plot_names = [
                 "learning",
@@ -214,10 +277,13 @@
                 "roc_auc",
             ]
         else:
-            plot_names = ["residuals", "vc", "parameter", "error", "learning"]
+            plot_names = ["residuals", "vc", "parameter", "error",
+                          "learning"]
         for name in plot_names:
             try:
-                ax = self.exp.plot_model(self.best_model, plot=name, save=False)
+                ax = self.exp.plot_model(
+                    self.best_model, plot=name, save=False
+                )
                 out_path = Path(self.output_dir) / f"plot_{name}.png"
                 fig = ax.get_figure()
                 fig.savefig(out_path, bbox_inches="tight")
@@ -239,18 +305,23 @@
             best_model_name = type(self.best_model).__name__
         LOG.info(f"Best model determined as: {best_model_name}")
 
-    # 2) Compute training sample count
+        # 2) Compute training sample count
         try:
             n_train = self.exp.X_train.shape[0]
         except Exception:
-            n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
+            n_train = getattr(
+                self.exp, "X_train_transformed", pd.DataFrame()
+            ).shape[0]
         total_rows = self.data.shape[0]
 
         # 3) Build setup parameters table
         all_params = self.setup_params.copy()
-        if self.task_type == "classification" and hasattr(self, "probability_threshold"):
-            all_params["probability_threshold"] = self.probability_threshold
-
+        if self.task_type == "classification" and (
+            hasattr(self, "probability_threshold")
+        ):
+            all_params["probability_threshold"] = (
+                self.probability_threshold
+            )
         display_keys = [
             "Target",
             "Session ID",
@@ -290,9 +361,11 @@
             elif key == "Cross Validation Folds":
                 dv = v if v is not None else "None"
             elif key == "Models":
-                dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+                dv = ", ".join(map(str, v)) if isinstance(
+                    v, (list, tuple)
+                ) else "None"
             elif key == "Probability Threshold":
-                dv = v if v is not None else "None"
+                dv = f"{v:.2f}" if v is not None else "0.5"
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
@@ -300,19 +373,29 @@
             setup_rows.append(["best_model_metric", self.exp._fold_metric])
 
         df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
-        df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False)
+        df_setup.to_csv(
+            Path(self.output_dir) / "setup_params.csv", index=False
+        )
 
         # 4) Persist CSVs
         self.results.to_csv(
-            Path(self.output_dir) / "comparison_results.csv", index=False
+            Path(self.output_dir) / "comparison_results.csv",
+            index=False
         )
         self.test_result_df.to_csv(
             Path(self.output_dir) / "test_results.csv", index=False
         )
         pd.DataFrame(
-            self.best_model.get_params().items(), columns=["Parameter", "Value"]
+            self.best_model.get_params().items(),
+            columns=["Parameter", "Value"]
         ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)
 
+        if self.tuning_results is not None:
+            self.tuning_results.to_csv(
+                Path(self.output_dir) / "tuning_results.csv",
+                index=False
+            )
+
         # 5) Header
         header = f"<h2>Best Model: {best_model_name}</h2>"
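
The new tuning_results attribute relies on exp.pull() returning the most recent scoring grid as a DataFrame immediately after tune_model; a minimal sketch of that call pattern (experiment setup elided):

    # Assumes `exp` is an initialized PyCaret experiment and `best_model`
    # came from exp.compare_models().
    best_model = exp.tune_model(best_model)
    tuning_results = exp.pull()          # CV results of the tuning run
    tuning_results.to_csv("tuning_results.csv", index=False)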
 
@@ -334,14 +417,31 @@
             "residuals": "Residuals Distribution",
             "error": "Prediction Error Distribution",
         }
-        val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True)
+        val_df.drop(
+            columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True
+        )
         summary_html = (
             header
             + "<h2>Train & Validation Summary</h2>"
             + '<div class="table-wrapper">'
             + val_df.to_html(index=False, classes="table sortable")
             + "</div>"
-            + "<h2>Setup Parameters</h2>"
+        )
+
+        if self.tuning_results is not None:
+            tuning_df = self.tuning_results.copy()
+            tuning_df.drop(
+                columns=["TT (Sec)"], errors="ignore", inplace=True
+            )
+            summary_html += (
+                f"<h2>{best_model_name}: Tuning Summary</h2>"
+                + '<div class="table-wrapper">'
+                + tuning_df.to_html(index=False, classes="table sortable")
+                + "</div>"
+            )
+
+        summary_html += (
+            "<h2>Setup Parameters</h2>"
             + '<div class="table-wrapper">'
             + df_setup.to_html(index=False, classes="table sortable")
             + "</div>"
@@ -349,7 +449,8 @@
             + "<h2>Best Model Hyperparameters</h2>"
             + '<div class="table-wrapper">'
             + pd.DataFrame(
-                self.best_model.get_params().items(), columns=["Parameter", "Value"]
+                self.best_model.get_params().items(),
+                columns=["Parameter", "Value"]
             ).to_html(index=False, classes="table sortable")
             + "</div>"
         )
@@ -373,12 +474,15 @@
             if name in self.plots:
                 summary_html += "<hr>"
                 b64 = encode_image_to_base64(self.plots[name])
-                title = plot_title_map.get(name, name.replace("_", " ").title())
+                title = plot_title_map.get(
+                    name, name.replace("_", " ").title()
+                )
                 summary_html += (
                     '<div class="plot">'
                     f"<h2>{title}</h2>"
                     f'<img src="data:image/png;base64,{b64}" '
-                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>'
+                    'style="max-width:90%;max-height:600px;'
+                    'border:1px solid #ddd;"/>'
                     "</div>"
                 )
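
encode_image_to_base64 is imported from utils.py and not shown in this diff; a plausible implementation, assuming it takes a PNG path and returns an ASCII string for the data URI:

    import base64

    def encode_image_to_base64(path):
        # Read the saved plot and return a base64 payload for
        # <img src="data:image/png;base64,...">.
        with open(path, "rb") as fh:
            return base64.b64encode(fh.read()).decode("ascii")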
 
@@ -386,7 +490,9 @@
         test_html = (
             header
             + '<div class="table-wrapper">'
-            + self.test_result_df.to_html(index=False, classes="table sortable")
+            + self.test_result_df.to_html(
+                index=False, classes="table sortable"
+            )
             + "</div>"
         )
         if self.task_type == "regression":
@@ -397,18 +503,25 @@
                     .rename("True")
                 )
                 y_pred = pd.Series(
-                    self.best_model.predict(self.exp.X_test_transformed)
+                    self.best_model.predict(
+                        self.exp.X_test_transformed
+                    )
                 ).rename("Predicted")
                 df_tp = pd.concat([y_true, y_pred], axis=1)
                 test_html += "<h2>True vs Predicted Values</h2>"
                 test_html += (
-                    '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">'
-                    + df_tp.head(50).to_html(index=False, classes="table sortable")
+                    '<div class="table-wrapper" '
+                    'style="max-height:400px; overflow-y:auto;">'
+                    + df_tp.head(50).to_html(
+                        index=False, classes="table sortable"
+                    )
                     + "</div>"
                     + add_hr_to_html()
                 )
             except Exception as e:
-                LOG.warning(f"Could not generate True vs Predicted table: {e}")
+                LOG.warning(
+                    f"Could not generate True vs Predicted table: {e}"
+                )
 
         # 5a) Explainer-substituted plots in order
         if self.task_type == "regression":
@@ -426,38 +539,53 @@
             fig_or_fn = self.explainer_plots.pop(key, None)
             if fig_or_fn is not None:
                 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
-                title = plot_title_map.get(key, key.replace("_", " ").title())
+                title = plot_title_map.get(
+                    key, key.replace("_", " ").title()
+                )
                 test_html += (
-                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig)
+                    + add_hr_to_html()
                 )
         # 5b) Remaining PyCaret test plots
         for name, path in self.plots.items():
-            # classification: include only the small extras, before skipping anything
-            if self.task_type == "classification" and name in {
-                "threshold",
-                "pr_auc",
-                "class_report",
-            }:
-                title = plot_title_map.get(name, name.replace("_", " ").title())
+            # classification: include only the small extras, before
+            # skipping anything
+            if self.task_type == "classification" and (
+                name in {
+                    "threshold",
+                    "pr_auc",
+                    "class_report",
+                }
+            ):
+                title = plot_title_map.get(
+                    name, name.replace("_", " ").title()
+                )
                 b64 = encode_image_to_base64(path)
                 test_html += (
                     f"<h2>{title}</h2>"
                     "<div class='plot'>"
                     f"<img src='data:image/png;base64,{b64}' "
-                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "style='max-width:90%;max-height:600px;"
+                    "border:1px solid #ddd;'/>"
                     "</div>" + add_hr_to_html()
                 )
                 continue
 
-            # regression: explicitly include the 'error' plot, before skipping
-            if self.task_type == "regression" and name == "error":
-                title = plot_title_map.get("error", "Prediction Error Distribution")
+            # regression: explicitly include the 'error' plot,
+            # before skipping
+            if self.task_type == "regression" and (
+                name == "error"
+            ):
+                title = plot_title_map.get(
+                    "error", "Prediction Error Distribution"
+                )
                 b64 = encode_image_to_base64(path)
                 test_html += (
                     f"<h2>{title}</h2>"
                     "<div class='plot'>"
                     f"<img src='data:image/png;base64,{b64}' "
-                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "style='max-width:90%;max-height:600px;"
+                    "border:1px solid #ddd;'/>"
                     "</div>" + add_hr_to_html()
                 )
                 continue
@@ -491,11 +619,14 @@
                     else "Permutation Feature Importance"
                 )
                 feature_html += (
-                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig)
+                    + add_hr_to_html()
                 )
 
         # 6c) PDPs last
-        pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__"))
+        pdp_keys = sorted(
+            k for k in self.explainer_plots if k.startswith("pdp__")
+        )
         for k in pdp_keys:
             fig_or_fn = self.explainer_plots[k]
             fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
@@ -503,7 +634,8 @@
             feature = k.split("__", 1)[1]
             title = f"Partial Dependence for {feature}"
             feature_html += (
-                f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                f"<h2>{title}</h2>" + add_plot_to_html(fig)
+                + add_hr_to_html()
             )
         # 7) Assemble final HTML (three tabs)
         html = get_html_template()
@@ -516,7 +648,10 @@
         (Path(self.output_dir) / "comparison_result.html").write_text(
             html, encoding="utf-8"
         )
-        LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html")
+        LOG.info(
+            f"HTML report generated at: "
+            f"{self.output_dir}/comparison_result.html"
+        )
 
     def save_dashboard(self):
         raise NotImplementedError("Subclasses should implement this method")
@@ -525,7 +660,9 @@
         raise NotImplementedError("Subclasses should implement this method")
 
     def generate_tree_plots(self):
-        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
+        from sklearn.ensemble import (
+            RandomForestClassifier, RandomForestRegressor
+        )
         from xgboost import XGBClassifier, XGBRegressor
         from explainerdashboard.explainers import RandomForestExplainer
 
@@ -533,7 +670,9 @@
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
 
-        if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)):
+        if isinstance(
+            self.best_model, (RandomForestClassifier, RandomForestRegressor)
+        ):
             n_trees = self.best_model.n_estimators
         elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
             n_trees = len(self.best_model.get_booster().get_dump())
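
The two branches count trees differently because the model APIs differ; distilled into a sketch (classifier variants shown, the regressors behave the same way):

    from sklearn.ensemble import RandomForestClassifier
    from xgboost import XGBClassifier

    def count_trees(model):
        # RandomForest exposes the count directly; XGBoost dumps each
        # booster and the dumps are counted.
        if isinstance(model, RandomForestClassifier):
            return model.n_estimators
        if isinstance(model, XGBClassifier):
            return len(model.get_booster().get_dump())
        raise TypeError(f"Unsupported model type: {type(model).__name__}")
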
--- a/dashboard.py	Thu Jul 31 15:41:24 2025 +0000
+++ b/dashboard.py	Fri Aug 01 14:02:26 2025 +0000
@@ -76,7 +76,6 @@
         col.replace(".", "__").replace("{", "__").replace("}", "__")
         for col in X_test_df.columns
     ]
-
     explainer = ClassifierExplainer(
         estimator, X_test_df, exp.y_test_transformed, labels=labels_, **kwargs
     )
@@ -153,7 +152,11 @@
         estimator, X_test_df, exp.y_test_transformed, **kwargs
     )
     return ExplainerDashboard(
-        explainer, mode=display_format, contributions=False,
-        whatif=False, shap_interaction=False, decision_trees=False,
-        **dashboard_kwargs
+        explainer,
+        mode=display_format,
+        contributions=False,
+        whatif=False,
+        shap_interaction=False,
+        decision_trees=False,
+        **dashboard_kwargs,
     )
--- a/pycaret_macros.xml	Thu Jul 31 15:41:24 2025 +0000
+++ b/pycaret_macros.xml	Fri Aug 01 14:02:26 2025 +0000
@@ -23,4 +23,4 @@
         </citations>
     </xml>
 
-</macros>
\ No newline at end of file
+</macros>
--- a/pycaret_train.py	Thu Jul 31 15:41:24 2025 +0000
+++ b/pycaret_train.py	Fri Aug 01 14:02:26 2025 +0000
@@ -126,7 +126,7 @@
     # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
     if args.no_cross_validation:
         args.cross_validation = False
-    # If --cross_validation was passed,  args.cross_validation is True
+    # If --cross_validation was passed, args.cross_validation is True
     # If neither was passed, args.cross_validation remains None
 
     # Build the model_kwargs dict from CLI args
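
The tri-state behaviour described in the comment (None / True / False) comes from giving the positive flag a None default; a minimal argparse sketch:

    import argparse

    parser = argparse.ArgumentParser()
    # Absent -> None, so PyCaret's own default applies downstream.
    parser.add_argument("--cross_validation", action="store_true", default=None)
    parser.add_argument("--no_cross_validation", action="store_true")
    args = parser.parse_args()

    if args.no_cross_validation:     # overrides --cross_validation
        args.cross_validation = False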
--- a/test-data/expected_comparison_result_classification.html	Thu Jul 31 15:41:24 2025 +0000
+++ b/test-data/expected_comparison_result_classification.html	Fri Aug 01 14:02:26 2025 +0000
@@ -86,7 +86,7 @@
             <h1>PyCaret Model Training Report</h1>
             <div class="tabs">
                 <div class="tab" onclick="openTab(event, 'summary')">
-                Setup & Best Model</div>
+                Setup and Best Model</div>
                 <div class="tab" onclick="openTab(event, 'plots')">
                 Best Model Plots</div>
                 <div class="tab" onclick="openTab(event, 'feature')">
@@ -603,4 +603,4 @@
     </body>
     </html>
     
-        
\ No newline at end of file
+        
--- a/test-data/expected_comparison_result_classification_customized.html	Thu Jul 31 15:41:24 2025 +0000
+++ b/test-data/expected_comparison_result_classification_customized.html	Fri Aug 01 14:02:26 2025 +0000
@@ -86,7 +86,7 @@
             <h1>PyCaret Model Training Report</h1>
             <div class="tabs">
                 <div class="tab" onclick="openTab(event, 'summary')">
-                Setup & Best Model</div>
+                Setup and Best Model</div>
                 <div class="tab" onclick="openTab(event, 'plots')">
                 Best Model Plots</div>
                 <div class="tab" onclick="openTab(event, 'feature')">
@@ -617,4 +617,4 @@
     </body>
     </html>
     
-        
\ No newline at end of file
+        
--- a/test-data/expected_comparison_result_regression.html	Thu Jul 31 15:41:24 2025 +0000
+++ b/test-data/expected_comparison_result_regression.html	Fri Aug 01 14:02:26 2025 +0000
@@ -86,7 +86,7 @@
             <h1>PyCaret Model Training Report</h1>
             <div class="tabs">
                 <div class="tab" onclick="openTab(event, 'summary')">
-                Setup & Best Model</div>
+                Setup and Best Model</div>
                 <div class="tab" onclick="openTab(event, 'plots')">
                 Best Model Plots</div>
                 <div class="tab" onclick="openTab(event, 'feature')">
@@ -588,4 +588,4 @@
     </body>
     </html>
     
-        
\ No newline at end of file
+        
--- a/utils.py	Thu Jul 31 15:41:24 2025 +0000
+++ b/utils.py	Fri Aug 01 14:02:26 2025 +0000
@@ -204,7 +204,7 @@
     # Tabs header
     tabs = [
         '<div class="tabs">',
-        '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary & Config</div>',
+        '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary and Config</div>',
         '<div class="tab" onclick="showTab(\'test\')">Test Summary</div>',
         '<div class="tab" onclick="showTab(\'feature\')">Feature Importance</div>',
     ]