Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_regression.py @ 8:1aed7d47c5ec draft default tip
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author | goeckslab |
---|---|
date | Fri, 25 Jul 2025 19:02:32 +0000 |
parents | ccd798db5abb |
children |
comparison
equal
deleted
inserted
replaced
7:f4cb41f458fd | 8:1aed7d47c5ec |
---|---|
1 import logging | 1 import logging |
2 | 2 |
3 from base_model_trainer import BaseModelTrainer | 3 from base_model_trainer import BaseModelTrainer |
4 from dashboard import generate_regression_explainer_dashboard | 4 from dashboard import generate_regression_explainer_dashboard |
5 from pycaret.regression import RegressionExperiment | 5 from pycaret.regression import RegressionExperiment |
6 from utils import add_hr_to_html, add_plot_to_html | |
7 | 6 |
8 LOG = logging.getLogger(__name__) | 7 LOG = logging.getLogger(__name__) |
9 | 8 |
10 | 9 |
11 class RegressionModelTrainer(BaseModelTrainer): | 10 class RegressionModelTrainer(BaseModelTrainer): |
12 def __init__( | 11 def __init__( |
13 self, | 12 self, |
14 input_file, | 13 input_file, |
15 target_col, | 14 target_col, |
16 output_dir, | 15 output_dir, |
17 task_type, | 16 task_type, |
18 random_seed, | 17 random_seed, |
19 test_file=None, | 18 test_file=None, |
20 **kwargs): | 19 **kwargs, |
20 ): | |
21 super().__init__( | 21 super().__init__( |
22 input_file, | 22 input_file, |
23 target_col, | 23 target_col, |
24 output_dir, | 24 output_dir, |
25 task_type, | 25 task_type, |
26 random_seed, | 26 random_seed, |
27 test_file, | 27 test_file, |
28 **kwargs) | 28 **kwargs, |
29 ) | |
30 # The BaseModelTrainer.setup_pycaret will set self.exp appropriately | |
31 # But we reassign here for clarity | |
29 self.exp = RegressionExperiment() | 32 self.exp = RegressionExperiment() |
30 | 33 |
31 def save_dashboard(self): | 34 def save_dashboard(self): |
32 LOG.info("Saving explainer dashboard") | 35 LOG.info("Saving explainer dashboard") |
33 dashboard = generate_regression_explainer_dashboard(self.exp, | 36 dashboard = generate_regression_explainer_dashboard(self.exp, self.best_model) |
34 self.best_model) | |
35 dashboard.save_html("dashboard.html") | 37 dashboard.save_html("dashboard.html") |
36 | 38 |
37 def generate_plots(self): | 39 def generate_plots(self): |
38 LOG.info("Generating and saving plots") | 40 LOG.info("Generating and saving plots") |
39 plots = ['residuals', 'error', 'cooks', | 41 plots = [ |
40 'learning', 'vc', 'manifold', | 42 "residuals", |
41 'rfe', 'feature', 'feature_all'] | 43 "error", |
44 "cooks", | |
45 "learning", | |
46 "vc", | |
47 "manifold", | |
48 "rfe", | |
49 "feature", | |
50 "feature_all", | |
51 ] | |
42 for plot_name in plots: | 52 for plot_name in plots: |
43 try: | 53 try: |
44 plot_path = self.exp.plot_model(self.best_model, | 54 plot_path = self.exp.plot_model( |
45 plot=plot_name, save=True) | 55 self.best_model, plot=plot_name, save=True |
56 ) | |
46 self.plots[plot_name] = plot_path | 57 self.plots[plot_name] = plot_path |
47 except Exception as e: | 58 except Exception as e: |
48 LOG.error(f"Error generating plot {plot_name}: {e}") | 59 LOG.error(f"Error generating plot {plot_name}: {e}") |
49 continue | 60 continue |
50 | 61 |
56 X_test = self.exp.X_test_transformed.copy() | 67 X_test = self.exp.X_test_transformed.copy() |
57 y_test = self.exp.y_test_transformed | 68 y_test = self.exp.y_test_transformed |
58 | 69 |
59 try: | 70 try: |
60 explainer = RegressionExplainer(self.best_model, X_test, y_test) | 71 explainer = RegressionExplainer(self.best_model, X_test, y_test) |
61 self.expaliner = explainer | |
62 plots_explainer_html = "" | |
63 except Exception as e: | 72 except Exception as e: |
64 LOG.error(f"Error creating explainer: {e}") | 73 LOG.error(f"Error creating explainer: {e}") |
65 self.plots_explainer_html = None | |
66 return | 74 return |
67 | 75 |
76 # --- 1) SHAP mean impact (average absolute SHAP values) --- | |
68 try: | 77 try: |
69 fig_importance = explainer.plot_importances() | 78 self.explainer_plots["shap_mean"] = explainer.plot_importances() |
70 plots_explainer_html += add_plot_to_html(fig_importance) | |
71 plots_explainer_html += add_hr_to_html() | |
72 except Exception as e: | 79 except Exception as e: |
73 LOG.error(f"Error generating plot importance: {e}") | 80 LOG.error(f"Error generating SHAP mean importance: {e}") |
74 | 81 |
82 # --- 2) SHAP permutation importance --- | |
75 try: | 83 try: |
76 fig_importance_permutation = \ | 84 self.explainer_plots["shap_perm"] = explainer.plot_importances_permutation( |
77 explainer.plot_importances_permutation( | 85 kind="permutation" |
78 kind="permutation") | 86 ) |
79 plots_explainer_html += add_plot_to_html( | |
80 fig_importance_permutation) | |
81 plots_explainer_html += add_hr_to_html() | |
82 except Exception as e: | 87 except Exception as e: |
83 LOG.error(f"Error generating plot importance permutation: {e}") | 88 LOG.error(f"Error generating SHAP permutation importance: {e}") |
84 | 89 |
90 # Pre-filter features so we never call PDP or residual-vs-feature on missing cols | |
91 valid_feats = [] | |
92 for feat in self.features_name: | |
93 if feat in explainer.X.columns or feat in explainer.onehot_cols: | |
94 valid_feats.append(feat) | |
95 else: | |
96 LOG.warning(f"Skipping feature {feat!r}: not found in explainer data") | |
97 | |
98 # --- 3) Partial Dependence Plots (PDPs) per feature --- | |
99 for feature in valid_feats: | |
100 try: | |
101 fig_pdp = explainer.plot_pdp(feature) | |
102 self.explainer_plots[f"pdp__{feature}"] = fig_pdp | |
103 except AssertionError as ae: | |
104 LOG.warning(f"PDP AssertionError for {feature!r}: {ae}") | |
105 except Exception as e: | |
106 LOG.error(f"Error generating PDP for {feature}: {e}") | |
107 | |
108 # --- 4) Predicted vs Actual plot --- | |
85 try: | 109 try: |
86 for feature in self.features_name: | 110 self.explainer_plots["predicted_vs_actual"] = explainer.plot_predicted_vs_actual() |
87 fig_shap = explainer.plot_pdp(feature) | |
88 plots_explainer_html += add_plot_to_html(fig_shap) | |
89 plots_explainer_html += add_hr_to_html() | |
90 except Exception as e: | 111 except Exception as e: |
91 LOG.error(f"Error generating plot shap dependence: {e}") | 112 LOG.error(f"Error generating Predicted vs Actual plot: {e}") |
92 | 113 |
93 # try: | 114 # --- 5) Global residuals distribution --- |
94 # for feature in self.features_name: | 115 try: |
95 # fig_interaction = explainer.plot_interaction(col=feature) | 116 self.explainer_plots["residuals"] = explainer.plot_residuals() |
96 # plots_explainer_html += add_plot_to_html(fig_interaction) | 117 except Exception as e: |
97 # except Exception as e: | 118 LOG.error(f"Error generating Residuals plot: {e}") |
98 # LOG.error(f"Error generating plot shap interaction: {e}") | |
99 | 119 |
100 try: | 120 # --- 6) Residuals vs each feature --- |
101 for feature in self.features_name: | 121 for feature in valid_feats: |
102 fig_interactions_importance = \ | 122 try: |
103 explainer.plot_interactions_importance( | 123 fig_res_vs_feat = explainer.plot_residuals_vs_feature(feature) |
104 col=feature) | 124 self.explainer_plots[f"residuals_vs_feature__{feature}"] = fig_res_vs_feat |
105 plots_explainer_html += add_plot_to_html( | 125 except AssertionError as ae: |
106 fig_interactions_importance) | 126 LOG.warning(f"Residuals-vs-feature AssertionError for {feature!r}: {ae}") |
107 plots_explainer_html += add_hr_to_html() | 127 except Exception as e: |
108 except Exception as e: | 128 LOG.error(f"Error generating Residuals vs {feature}: {e}") |
109 LOG.error(f"Error generating plot shap summary: {e}") | |
110 | |
111 # Regression specific plots | |
112 try: | |
113 fig_pred_actual = explainer.plot_predicted_vs_actual() | |
114 plots_explainer_html += add_plot_to_html(fig_pred_actual) | |
115 plots_explainer_html += add_hr_to_html() | |
116 except Exception as e: | |
117 LOG.error(f"Error generating plot prediction vs actual: {e}") | |
118 | |
119 try: | |
120 fig_residuals = explainer.plot_residuals() | |
121 plots_explainer_html += add_plot_to_html(fig_residuals) | |
122 plots_explainer_html += add_hr_to_html() | |
123 except Exception as e: | |
124 LOG.error(f"Error generating plot residuals: {e}") | |
125 | |
126 try: | |
127 for feature in self.features_name: | |
128 fig_residuals_vs_feature = \ | |
129 explainer.plot_residuals_vs_feature(feature) | |
130 plots_explainer_html += add_plot_to_html( | |
131 fig_residuals_vs_feature) | |
132 plots_explainer_html += add_hr_to_html() | |
133 except Exception as e: | |
134 LOG.error(f"Error generating plot residuals vs feature: {e}") | |
135 | |
136 self.plots_explainer_html = plots_explainer_html |