goeckslab / pycaret_predict: changeset 10:e2a6fed32d54 (draft, default, tip)
planemo upload for repository https://github.com/goeckslab/gleam commit 47a5977e074223e92e216efa42969a4056516707
author   | goeckslab
date     | Fri, 01 Aug 2025 14:02:26 +0000
parents  | c6c1f8777aae
children |
files    | base_model_trainer.py dashboard.py pycaret_macros.xml pycaret_train.py test-data/expected_comparison_result_classification.html test-data/expected_comparison_result_classification_customized.html test-data/expected_comparison_result_regression.html utils.py
diffstat | 8 files changed, 216 insertions(+), 74 deletions(-)
--- a/base_model_trainer.py	Thu Jul 31 15:41:24 2025 +0000
+++ b/base_model_trainer.py	Fri Aug 01 14:02:26 2025 +0000
@@ -44,6 +44,7 @@
         self.target = None
         self.best_model = None
         self.results = None
+        self.tuning_results = None
         self.features_name = None
         self.plots = {}
         self.explainer_plots = {}
@@ -57,44 +58,98 @@
         self.test_data = None
         if not self.output_dir:
-            raise ValueError("output_dir must be specified and not None")
+            raise ValueError(
+                "output_dir must be specified and not None"
+            )
+
+        # Warn about irrelevant kwargs for the task type
+        if self.task_type == "regression" and (
+            "probability_threshold" in self.user_kwargs
+        ):
+            LOG.warning(
+                "probability_threshold is ignored for regression tasks."
+            )
         LOG.info(f"Model kwargs: {self.__dict__}")
     def load_data(self):
         LOG.info(f"Loading data from {self.input_file}")
-        self.data = pd.read_csv(self.input_file, sep=None, engine="python")
+        self.data = pd.read_csv(
+            self.input_file, sep=None, engine="python"
+        )
         self.data.columns = self.data.columns.str.replace(".", "_")
-        if "prediction_label" in self.data.columns:
+
+        names = self.data.columns.to_list()
+        LOG.info(f"Original dataset columns: {names}")
+
+        target_index = int(self.target_col) - 1
+        num_cols = len(names)
+        if target_index < 0 or target_index >= num_cols:
+            raise ValueError(
+                f"Target column number {self.target_col} is invalid. "
+                f"Please select a number between 1 and {num_cols}."
+            )
+
+        self.target = names[target_index]
+
+        # Conditional drop: only if 'prediction_label' exists and is not
+        # the target
+        if "prediction_label" in self.data.columns and (
+            self.data.columns[target_index] != "prediction_label"
+        ):
+            LOG.info(
+                "Dropping 'prediction_label' column as it's not the target."
+            )
             self.data = self.data.drop(columns=["prediction_label"])
+        else:
+            if self.target == "prediction_label":
+                LOG.warning(
+                    "Using 'prediction_label' as target column. "
+                    "This may not be intended if it's a previous prediction."
+                )
-        numeric_cols = self.data.select_dtypes(include=["number"]).columns
-        non_numeric_cols = self.data.select_dtypes(exclude=["number"]).columns
+        numeric_cols = self.data.select_dtypes(
+            include=["number"]
+        ).columns
+        non_numeric_cols = self.data.select_dtypes(
+            exclude=["number"]
+        ).columns
         self.data[numeric_cols] = self.data[numeric_cols].apply(
             pd.to_numeric, errors="coerce"
         )
         if len(non_numeric_cols) > 0:
-            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
+            LOG.info(
+                f"Non-numeric columns found: {non_numeric_cols.tolist()}"
+            )
+        # Update names after possible drop
         names = self.data.columns.to_list()
-        target_index = int(self.target_col) - 1
-        self.target = names[target_index]
-        self.features_name = [n for i, n in enumerate(names) if i != target_index]
+        LOG.info(f"Dataset columns after processing: {names}")
+
+        self.features_name = [n for n in names if n != self.target]
         if getattr(self, "missing_value_strategy", None):
             strat = self.missing_value_strategy
             if strat == "mean":
-                self.data = self.data.fillna(self.data.mean(numeric_only=True))
+                self.data = self.data.fillna(
+                    self.data.mean(numeric_only=True)
+                )
             elif strat == "median":
-                self.data = self.data.fillna(self.data.median(numeric_only=True))
+                self.data = self.data.fillna(
+                    self.data.median(numeric_only=True)
+                )
             elif strat == "drop":
                 self.data = self.data.dropna()
         else:
-            self.data = self.data.fillna(self.data.median(numeric_only=True))
+            self.data = self.data.fillna(
+                self.data.median(numeric_only=True)
+            )
         if self.test_file:
             LOG.info(f"Loading test data from {self.test_file}")
-            df_test = pd.read_csv(self.test_file, sep=None, engine="python")
+            df_test = pd.read_csv(
+                self.test_file, sep=None, engine="python"
+            )
             df_test.columns = df_test.columns.str.replace(".", "_")
             self.test_data = df_test
@@ -137,7 +192,9 @@
             self.exp = RegressionExperiment()
         else:
-            raise ValueError("task_type must be 'classification' or 'regression'")
+            raise ValueError(
+                "task_type must be 'classification' or 'regression'"
+            )
         self.exp.setup(self.data, **self.setup_params)
         self.setup_params.update(self.user_kwargs)
@@ -171,20 +228,26 @@
         if getattr(self, "tune_model", False):
             LOG.info("Tuning hyperparameters of the best model")
             self.best_model = self.exp.tune_model(self.best_model)
-            self.results = self.exp.pull()
+            self.tuning_results = self.exp.pull()
         if self.task_type == "classification":
             self.results.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
         prob_thresh = getattr(self, "probability_threshold", None)
-        if self.task_type == "classification" and prob_thresh is not None:
-            _ = self.exp.predict_model(self.best_model, probability_threshold=prob_thresh)
+        if self.task_type == "classification" and (
+            prob_thresh is not None
+        ):
+            _ = self.exp.predict_model(
+                self.best_model, probability_threshold=prob_thresh
+            )
         else:
             _ = self.exp.predict_model(self.best_model)
         self.test_result_df = self.exp.pull()
         if self.task_type == "classification":
-            self.test_result_df.rename(columns={"AUC": "ROC-AUC"}, inplace=True)
+            self.test_result_df.rename(
+                columns={"AUC": "ROC-AUC"}, inplace=True
+            )
     def save_model(self):
         hdf5_path = Path(self.output_dir) / "pycaret_model.h5"
@@ -198,7 +261,7 @@
     def generate_plots(self):
         LOG.info("Generating PyCaret diagnostic pltos")
-        # choose the right plots based on task
+        # choose the right plots based on task type
         if self.task_type == "classification":
             plot_names = [
                 "learning",
@@ -214,10 +277,13 @@
                 "roc_auc",
             ]
         else:
-            plot_names = ["residuals", "vc", "parameter", "error", "learning"]
+            plot_names = ["residuals", "vc", "parameter", "error",
+                          "learning"]
         for name in plot_names:
             try:
-                ax = self.exp.plot_model(self.best_model, plot=name, save=False)
+                ax = self.exp.plot_model(
+                    self.best_model, plot=name, save=False
+                )
                 out_path = Path(self.output_dir) / f"plot_{name}.png"
                 fig = ax.get_figure()
                 fig.savefig(out_path, bbox_inches="tight")
@@ -239,18 +305,23 @@
         best_model_name = type(self.best_model).__name__
         LOG.info(f"Best model determined as: {best_model_name}")
-        # 2) Compute training sample count
+        # 2) Compute training sample count
         try:
             n_train = self.exp.X_train.shape[0]
         except Exception:
-            n_train = getattr(self.exp, "X_train_transformed", pd.DataFrame()).shape[0]
+            n_train = getattr(
+                self.exp, "X_train_transformed", pd.DataFrame()
+            ).shape[0]
         total_rows = self.data.shape[0]
         # 3) Build setup parameters table
         all_params = self.setup_params.copy()
-        if self.task_type == "classification" and hasattr(self, "probability_threshold"):
-            all_params["probability_threshold"] = self.probability_threshold
-
+        if self.task_type == "classification" and (
+            hasattr(self, "probability_threshold")
+        ):
+            all_params["probability_threshold"] = (
+                self.probability_threshold
+            )
         display_keys = [
             "Target",
             "Session ID",
@@ -290,9 +361,11 @@
             elif key == "Cross Validation Folds":
                 dv = v if v is not None else "None"
             elif key == "Models":
-                dv = ", ".join(map(str, v)) if isinstance(v, (list, tuple)) else "None"
+                dv = ", ".join(map(str, v)) if isinstance(
+                    v, (list, tuple)
+                ) else "None"
             elif key == "Probability Threshold":
-                dv = v if v is not None else "None"
+                dv = f"{v:.2f}" if v is not None else "0.5"
             else:
                 dv = v if v is not None else "None"
             setup_rows.append([key, dv])
@@ -300,19 +373,29 @@
             setup_rows.append(["best_model_metric", self.exp._fold_metric])
         df_setup = pd.DataFrame(setup_rows, columns=["Parameter", "Value"])
-        df_setup.to_csv(Path(self.output_dir) / "setup_params.csv", index=False)
+        df_setup.to_csv(
+            Path(self.output_dir) / "setup_params.csv", index=False
+        )
         # 4) Persist CSVs
         self.results.to_csv(
-            Path(self.output_dir) / "comparison_results.csv", index=False
+            Path(self.output_dir) / "comparison_results.csv",
+            index=False
         )
         self.test_result_df.to_csv(
             Path(self.output_dir) / "test_results.csv", index=False
         )
         pd.DataFrame(
-            self.best_model.get_params().items(), columns=["Parameter", "Value"]
+            self.best_model.get_params().items(),
+            columns=["Parameter", "Value"]
         ).to_csv(Path(self.output_dir) / "best_model.csv", index=False)
+        if self.tuning_results is not None:
+            self.tuning_results.to_csv(
+                Path(self.output_dir) / "tuning_results.csv",
+                index=False
+            )
+
         # 5) Header
         header = f"<h2>Best Model: {best_model_name}</h2>"
@@ -334,14 +417,31 @@
             "residuals": "Residuals Distribution",
             "error": "Prediction Error Distribution",
         }
-        val_df.drop(columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True)
+        val_df.drop(
+            columns=["TT (Ec)", "TT (Sec)"], errors="ignore", inplace=True
+        )
         summary_html = (
             header
             + "<h2>Train & Validation Summary</h2>"
             + '<div class="table-wrapper">'
             + val_df.to_html(index=False, classes="table sortable")
             + "</div>"
-            + "<h2>Setup Parameters</h2>"
+        )
+
+        if self.tuning_results is not None:
+            tuning_df = self.tuning_results.copy()
+            tuning_df.drop(
+                columns=["TT (Sec)"], errors="ignore", inplace=True
+            )
+            summary_html += (
+                f"<h2>{best_model_name}: Tuning Summary</h2>"
+                + '<div class="table-wrapper">'
+                + tuning_df.to_html(index=False, classes="table sortable")
+                + "</div>"
+            )
+
+        summary_html += (
+            "<h2>Setup Parameters</h2>"
             + '<div class="table-wrapper">'
             + df_setup.to_html(index=False, classes="table sortable")
             + "</div>"
@@ -349,7 +449,8 @@
             + "<h2>Best Model Hyperparameters</h2>"
             + '<div class="table-wrapper">'
             + pd.DataFrame(
-                self.best_model.get_params().items(), columns=["Parameter", "Value"]
+                self.best_model.get_params().items(),
+                columns=["Parameter", "Value"]
             ).to_html(index=False, classes="table sortable")
             + "</div>"
         )
@@ -373,12 +474,15 @@
             if name in self.plots:
                 summary_html += "<hr>"
                 b64 = encode_image_to_base64(self.plots[name])
-                title = plot_title_map.get(name, name.replace("_", " ").title())
+                title = plot_title_map.get(
+                    name, name.replace("_", " ").title()
+                )
                 summary_html += (
                     '<div class="plot">'
                     f"<h2>{title}</h2>"
                     f'<img src="data:image/png;base64,{b64}" '
-                    'style="max-width:90%;max-height:600px;border:1px solid #ddd;"/>'
+                    'style="max-width:90%;max-height:600px;'
+                    'border:1px solid #ddd;"/>'
                     "</div>"
                 )
@@ -386,7 +490,9 @@
         test_html = (
             header
             + '<div class="table-wrapper">'
-            + self.test_result_df.to_html(index=False, classes="table sortable")
+            + self.test_result_df.to_html(
+                index=False, classes="table sortable"
+            )
            + "</div>"
        )
        if self.task_type == "regression":
@@ -397,18 +503,25 @@
                     .rename("True")
                 )
                 y_pred = pd.Series(
-                    self.best_model.predict(self.exp.X_test_transformed)
+                    self.best_model.predict(
+                        self.exp.X_test_transformed
+                    )
                 ).rename("Predicted")
                 df_tp = pd.concat([y_true, y_pred], axis=1)
                 test_html += "<h2>True vs Predicted Values</h2>"
                 test_html += (
-                    '<div class="table-wrapper" style="max-height:400px; overflow-y:auto;">'
-                    + df_tp.head(50).to_html(index=False, classes="table sortable")
+                    '<div class="table-wrapper" '
+                    'style="max-height:400px; overflow-y:auto;">'
+                    + df_tp.head(50).to_html(
+                        index=False, classes="table sortable"
+                    )
                     + "</div>"
                     + add_hr_to_html()
                 )
             except Exception as e:
-                LOG.warning(f"Could not generate True vs Predicted table: {e}")
+                LOG.warning(
+                    f"Could not generate True vs Predicted table: {e}"
+                )
         # 5a) Explainer-substituted plots in order
         if self.task_type == "regression":
@@ -426,38 +539,53 @@
             fig_or_fn = self.explainer_plots.pop(key, None)
             if fig_or_fn is not None:
                 fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
-                title = plot_title_map.get(key, key.replace("_", " ").title())
+                title = plot_title_map.get(
+                    key, key.replace("_", " ").title()
+                )
                 test_html += (
-                    f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                    f"<h2>{title}</h2>" + add_plot_to_html(fig)
+                    + add_hr_to_html()
                 )
         # 5b) Remaining PyCaret test plots
         for name, path in self.plots.items():
-            # classification: include only the small extras, before skipping anything
-            if self.task_type == "classification" and name in {
-                "threshold",
-                "pr_auc",
-                "class_report",
-            }:
-                title = plot_title_map.get(name, name.replace("_", " ").title())
+            # classification: include only the small extras, before
+            # skipping anything
+            if self.task_type == "classification" and (
+                name in {
+                    "threshold",
+                    "pr_auc",
+                    "class_report",
+                }
+            ):
+                title = plot_title_map.get(
+                    name, name.replace("_", " ").title()
+                )
                 b64 = encode_image_to_base64(path)
                 test_html += (
                     f"<h2>{title}</h2>"
                     "<div class='plot'>"
                     f"<img src='data:image/png;base64,{b64}' "
-                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "style='max-width:90%;max-height:600px;"
+                    "border:1px solid #ddd;'/>"
                     "</div>" + add_hr_to_html()
                 )
                 continue
-            # regression: explicitly include the 'error' plot, before skipping
-            if self.task_type == "regression" and name == "error":
-                title = plot_title_map.get("error", "Prediction Error Distribution")
+            # regression: explicitly include the 'error' plot,
+            # before skipping
+            if self.task_type == "regression" and (
+                name == "error"
+            ):
+                title = plot_title_map.get(
+                    "error", "Prediction Error Distribution"
+                )
                 b64 = encode_image_to_base64(path)
                 test_html += (
                     f"<h2>{title}</h2>"
                     "<div class='plot'>"
                     f"<img src='data:image/png;base64,{b64}' "
-                    "style='max-width:90%;max-height:600px;border:1px solid #ddd;'/>"
+                    "style='max-width:90%;max-height:600px;"
+                    "border:1px solid #ddd;'/>"
                     "</div>" + add_hr_to_html()
                 )
                 continue
@@ -491,11 +619,14 @@
                 else "Permutation Feature Importance"
             )
             feature_html += (
-                f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                f"<h2>{title}</h2>" + add_plot_to_html(fig)
+                + add_hr_to_html()
            )
         # 6c) PDPs last
-        pdp_keys = sorted(k for k in self.explainer_plots if k.startswith("pdp__"))
+        pdp_keys = sorted(
+            k for k in self.explainer_plots if k.startswith("pdp__")
+        )
         for k in pdp_keys:
             fig_or_fn = self.explainer_plots[k]
             fig = fig_or_fn() if callable(fig_or_fn) else fig_or_fn
@@ -503,7 +634,8 @@
             feature = k.split("__", 1)[1]
             title = f"Partial Dependence for {feature}"
             feature_html += (
-                f"<h2>{title}</h2>" + add_plot_to_html(fig) + add_hr_to_html()
+                f"<h2>{title}</h2>" + add_plot_to_html(fig)
+                + add_hr_to_html()
             )
         # 7) Assemble final HTML (three tabs)
         html = get_html_template()
@@ -516,7 +648,10 @@
         (Path(self.output_dir) / "comparison_result.html").write_text(
             html, encoding="utf-8"
         )
-        LOG.info(f"HTML report generated at: {self.output_dir}/comparison_result.html")
+        LOG.info(
+            f"HTML report generated at: "
+            f"{self.output_dir}/comparison_result.html"
+        )
     def save_dashboard(self):
         raise NotImplementedError("Subclasses should implement this method")
@@ -525,7 +660,9 @@
         raise NotImplementedError("Subclasses should implement this method")
     def generate_tree_plots(self):
-        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
+        from sklearn.ensemble import (
+            RandomForestClassifier, RandomForestRegressor
+        )
         from xgboost import XGBClassifier, XGBRegressor
         from explainerdashboard.explainers import RandomForestExplainer
@@ -533,7 +670,9 @@
         X_test = self.exp.X_test_transformed.copy()
         y_test = self.exp.y_test_transformed
-        if isinstance(self.best_model, (RandomForestClassifier, RandomForestRegressor)):
+        if isinstance(
+            self.best_model, (RandomForestClassifier, RandomForestRegressor)
+        ):
             n_trees = self.best_model.n_estimators
         elif isinstance(self.best_model, (XGBClassifier, XGBRegressor)):
             n_trees = len(self.best_model.get_booster().get_dump())
--- a/dashboard.py	Thu Jul 31 15:41:24 2025 +0000
+++ b/dashboard.py	Fri Aug 01 14:02:26 2025 +0000
@@ -76,7 +76,6 @@
        col.replace(".", "__").replace("{", "__").replace("}", "__")
        for col in X_test_df.columns
    ]
-
    explainer = ClassifierExplainer(
        estimator, X_test_df, exp.y_test_transformed, labels=labels_, **kwargs
    )
@@ -153,7 +152,11 @@
        estimator, X_test_df, exp.y_test_transformed, **kwargs
    )
    return ExplainerDashboard(
-        explainer, mode=display_format, contributions=False,
-        whatif=False, shap_interaction=False, decision_trees=False,
-        **dashboard_kwargs
+        explainer,
+        mode=display_format,
+        contributions=False,
+        whatif=False,
+        shap_interaction=False,
+        decision_trees=False,
+        **dashboard_kwargs,
    )
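The dashboard.py hunk only re-wraps the ExplainerDashboard call, but it documents which tabs the tool turns off. A hedged sketch of the same construction for the regression case, using the public explainerdashboard API; the names estimator, X_test, y_test and the default display format are placeholders rather than values taken from dashboard.py:

from explainerdashboard import ExplainerDashboard, RegressionExplainer


def build_dashboard(estimator, X_test, y_test, display_format="dash", **dashboard_kwargs):
    explainer = RegressionExplainer(estimator, X_test, y_test)
    # Contributions, what-if, SHAP-interaction and decision-tree tabs are
    # disabled, matching the keyword arguments in the diff above.
    return ExplainerDashboard(
        explainer,
        mode=display_format,
        contributions=False,
        whatif=False,
        shap_interaction=False,
        decision_trees=False,
        **dashboard_kwargs,
    )

Disabling those tabs keeps the generated dashboard lighter to build and serve.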
--- a/pycaret_macros.xml	Thu Jul 31 15:41:24 2025 +0000
+++ b/pycaret_macros.xml	Fri Aug 01 14:02:26 2025 +0000
@@ -23,4 +23,4 @@
        </citations>
    </xml>
-</macros>
\ No newline at end of file
+</macros>
--- a/pycaret_train.py	Thu Jul 31 15:41:24 2025 +0000
+++ b/pycaret_train.py	Fri Aug 01 14:02:26 2025 +0000
@@ -126,7 +126,7 @@
    # Normalize cross-validation flags: --no_cross_validation overrides --cross_validation
    if args.no_cross_validation:
        args.cross_validation = False
-    # If --cross_validation was passed, args.cross_validation is True
+    # If --cross_validation was passed, args.cross_validation is True
    # If neither was passed, args.cross_validation remains None

    # Build the model_kwargs dict from CLI args
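The only change in pycaret_train.py is trailing whitespace on a comment, but the surrounding context describes the tri-state cross-validation flag handling. An illustrative sketch of that pattern, assuming the flags are defined roughly as below (the real argument definitions live elsewhere in pycaret_train.py and may differ):

import argparse

parser = argparse.ArgumentParser()
# default=None keeps the flag tri-state: None (unset), True, or False.
parser.add_argument("--cross_validation", action="store_true", default=None)
parser.add_argument("--no_cross_validation", action="store_true")
args = parser.parse_args([])  # empty list keeps the sketch self-contained

# --no_cross_validation overrides --cross_validation.
if args.no_cross_validation:
    args.cross_validation = False
# If --cross_validation was passed, args.cross_validation is True.
# If neither flag was passed, it stays None.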
--- a/test-data/expected_comparison_result_classification.html	Thu Jul 31 15:41:24 2025 +0000
+++ b/test-data/expected_comparison_result_classification.html	Fri Aug 01 14:02:26 2025 +0000
@@ -86,7 +86,7 @@
    <h1>PyCaret Model Training Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'summary')">
-            Setup & Best Model</div>
+            Setup and Best Model</div>
        <div class="tab" onclick="openTab(event, 'plots')">
            Best Model Plots</div>
        <div class="tab" onclick="openTab(event, 'feature')">
@@ -603,4 +603,4 @@
 </body>
 </html>
-
\ No newline at end of file
+
--- a/test-data/expected_comparison_result_classification_customized.html	Thu Jul 31 15:41:24 2025 +0000
+++ b/test-data/expected_comparison_result_classification_customized.html	Fri Aug 01 14:02:26 2025 +0000
@@ -86,7 +86,7 @@
    <h1>PyCaret Model Training Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'summary')">
-            Setup & Best Model</div>
+            Setup and Best Model</div>
        <div class="tab" onclick="openTab(event, 'plots')">
            Best Model Plots</div>
        <div class="tab" onclick="openTab(event, 'feature')">
@@ -617,4 +617,4 @@
 </body>
 </html>
-
\ No newline at end of file
+
--- a/test-data/expected_comparison_result_regression.html	Thu Jul 31 15:41:24 2025 +0000
+++ b/test-data/expected_comparison_result_regression.html	Fri Aug 01 14:02:26 2025 +0000
@@ -86,7 +86,7 @@
    <h1>PyCaret Model Training Report</h1>
    <div class="tabs">
        <div class="tab" onclick="openTab(event, 'summary')">
-            Setup & Best Model</div>
+            Setup and Best Model</div>
        <div class="tab" onclick="openTab(event, 'plots')">
            Best Model Plots</div>
        <div class="tab" onclick="openTab(event, 'feature')">
@@ -588,4 +588,4 @@
 </body>
 </html>
-
\ No newline at end of file
+
--- a/utils.py	Thu Jul 31 15:41:24 2025 +0000
+++ b/utils.py	Fri Aug 01 14:02:26 2025 +0000
@@ -204,7 +204,7 @@
    # Tabs header
    tabs = [
        '<div class="tabs">',
-        '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary & Config</div>',
+        '<div class="tab active" onclick="showTab(\'summary\')">Validation Summary and Config</div>',
        '<div class="tab" onclick="showTab(\'test\')">Test Summary</div>',
        '<div class="tab" onclick="showTab(\'feature\')">Feature Importance</div>',
    ]