comparison pycaret_regression.py @ 8:1aed7d47c5ec draft default tip

planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
author goeckslab
date Fri, 25 Jul 2025 19:02:32 +0000
parents ccd798db5abb
children
comparison
equal deleted inserted replaced
7:f4cb41f458fd 8:1aed7d47c5ec
1 import logging 1 import logging
2 2
3 from base_model_trainer import BaseModelTrainer 3 from base_model_trainer import BaseModelTrainer
4 from dashboard import generate_regression_explainer_dashboard 4 from dashboard import generate_regression_explainer_dashboard
5 from pycaret.regression import RegressionExperiment 5 from pycaret.regression import RegressionExperiment
6 from utils import add_hr_to_html, add_plot_to_html
7 6
8 LOG = logging.getLogger(__name__) 7 LOG = logging.getLogger(__name__)
9 8
10 9
11 class RegressionModelTrainer(BaseModelTrainer): 10 class RegressionModelTrainer(BaseModelTrainer):
12 def __init__( 11 def __init__(
13 self, 12 self,
14 input_file, 13 input_file,
15 target_col, 14 target_col,
16 output_dir, 15 output_dir,
17 task_type, 16 task_type,
18 random_seed, 17 random_seed,
19 test_file=None, 18 test_file=None,
20 **kwargs): 19 **kwargs,
20 ):
21 super().__init__( 21 super().__init__(
22 input_file, 22 input_file,
23 target_col, 23 target_col,
24 output_dir, 24 output_dir,
25 task_type, 25 task_type,
26 random_seed, 26 random_seed,
27 test_file, 27 test_file,
28 **kwargs) 28 **kwargs,
29 )
30 # The BaseModelTrainer.setup_pycaret will set self.exp appropriately
31 # But we reassign here for clarity
29 self.exp = RegressionExperiment() 32 self.exp = RegressionExperiment()
30 33
31 def save_dashboard(self): 34 def save_dashboard(self):
32 LOG.info("Saving explainer dashboard") 35 LOG.info("Saving explainer dashboard")
33 dashboard = generate_regression_explainer_dashboard(self.exp, 36 dashboard = generate_regression_explainer_dashboard(self.exp, self.best_model)
34 self.best_model)
35 dashboard.save_html("dashboard.html") 37 dashboard.save_html("dashboard.html")
36 38
37 def generate_plots(self): 39 def generate_plots(self):
38 LOG.info("Generating and saving plots") 40 LOG.info("Generating and saving plots")
39 plots = ['residuals', 'error', 'cooks', 41 plots = [
40 'learning', 'vc', 'manifold', 42 "residuals",
41 'rfe', 'feature', 'feature_all'] 43 "error",
44 "cooks",
45 "learning",
46 "vc",
47 "manifold",
48 "rfe",
49 "feature",
50 "feature_all",
51 ]
42 for plot_name in plots: 52 for plot_name in plots:
43 try: 53 try:
44 plot_path = self.exp.plot_model(self.best_model, 54 plot_path = self.exp.plot_model(
45 plot=plot_name, save=True) 55 self.best_model, plot=plot_name, save=True
56 )
46 self.plots[plot_name] = plot_path 57 self.plots[plot_name] = plot_path
47 except Exception as e: 58 except Exception as e:
48 LOG.error(f"Error generating plot {plot_name}: {e}") 59 LOG.error(f"Error generating plot {plot_name}: {e}")
49 continue 60 continue
50 61
56 X_test = self.exp.X_test_transformed.copy() 67 X_test = self.exp.X_test_transformed.copy()
57 y_test = self.exp.y_test_transformed 68 y_test = self.exp.y_test_transformed
58 69
59 try: 70 try:
60 explainer = RegressionExplainer(self.best_model, X_test, y_test) 71 explainer = RegressionExplainer(self.best_model, X_test, y_test)
61 self.expaliner = explainer
62 plots_explainer_html = ""
63 except Exception as e: 72 except Exception as e:
64 LOG.error(f"Error creating explainer: {e}") 73 LOG.error(f"Error creating explainer: {e}")
65 self.plots_explainer_html = None
66 return 74 return
67 75
76 # --- 1) SHAP mean impact (average absolute SHAP values) ---
68 try: 77 try:
69 fig_importance = explainer.plot_importances() 78 self.explainer_plots["shap_mean"] = explainer.plot_importances()
70 plots_explainer_html += add_plot_to_html(fig_importance)
71 plots_explainer_html += add_hr_to_html()
72 except Exception as e: 79 except Exception as e:
73 LOG.error(f"Error generating plot importance: {e}") 80 LOG.error(f"Error generating SHAP mean importance: {e}")
74 81
82 # --- 2) SHAP permutation importance ---
75 try: 83 try:
76 fig_importance_permutation = \ 84 self.explainer_plots["shap_perm"] = explainer.plot_importances_permutation(
77 explainer.plot_importances_permutation( 85 kind="permutation"
78 kind="permutation") 86 )
79 plots_explainer_html += add_plot_to_html(
80 fig_importance_permutation)
81 plots_explainer_html += add_hr_to_html()
82 except Exception as e: 87 except Exception as e:
83 LOG.error(f"Error generating plot importance permutation: {e}") 88 LOG.error(f"Error generating SHAP permutation importance: {e}")
84 89
90 # Pre-filter features so we never call PDP or residual-vs-feature on missing cols
91 valid_feats = []
92 for feat in self.features_name:
93 if feat in explainer.X.columns or feat in explainer.onehot_cols:
94 valid_feats.append(feat)
95 else:
96 LOG.warning(f"Skipping feature {feat!r}: not found in explainer data")
97
98 # --- 3) Partial Dependence Plots (PDPs) per feature ---
99 for feature in valid_feats:
100 try:
101 fig_pdp = explainer.plot_pdp(feature)
102 self.explainer_plots[f"pdp__{feature}"] = fig_pdp
103 except AssertionError as ae:
104 LOG.warning(f"PDP AssertionError for {feature!r}: {ae}")
105 except Exception as e:
106 LOG.error(f"Error generating PDP for {feature}: {e}")
107
108 # --- 4) Predicted vs Actual plot ---
85 try: 109 try:
86 for feature in self.features_name: 110 self.explainer_plots["predicted_vs_actual"] = explainer.plot_predicted_vs_actual()
87 fig_shap = explainer.plot_pdp(feature)
88 plots_explainer_html += add_plot_to_html(fig_shap)
89 plots_explainer_html += add_hr_to_html()
90 except Exception as e: 111 except Exception as e:
91 LOG.error(f"Error generating plot shap dependence: {e}") 112 LOG.error(f"Error generating Predicted vs Actual plot: {e}")
92 113
93 # try: 114 # --- 5) Global residuals distribution ---
94 # for feature in self.features_name: 115 try:
95 # fig_interaction = explainer.plot_interaction(col=feature) 116 self.explainer_plots["residuals"] = explainer.plot_residuals()
96 # plots_explainer_html += add_plot_to_html(fig_interaction) 117 except Exception as e:
97 # except Exception as e: 118 LOG.error(f"Error generating Residuals plot: {e}")
98 # LOG.error(f"Error generating plot shap interaction: {e}")
99 119
100 try: 120 # --- 6) Residuals vs each feature ---
101 for feature in self.features_name: 121 for feature in valid_feats:
102 fig_interactions_importance = \ 122 try:
103 explainer.plot_interactions_importance( 123 fig_res_vs_feat = explainer.plot_residuals_vs_feature(feature)
104 col=feature) 124 self.explainer_plots[f"residuals_vs_feature__{feature}"] = fig_res_vs_feat
105 plots_explainer_html += add_plot_to_html( 125 except AssertionError as ae:
106 fig_interactions_importance) 126 LOG.warning(f"Residuals-vs-feature AssertionError for {feature!r}: {ae}")
107 plots_explainer_html += add_hr_to_html() 127 except Exception as e:
108 except Exception as e: 128 LOG.error(f"Error generating Residuals vs {feature}: {e}")
109 LOG.error(f"Error generating plot shap summary: {e}")
110
111 # Regression specific plots
112 try:
113 fig_pred_actual = explainer.plot_predicted_vs_actual()
114 plots_explainer_html += add_plot_to_html(fig_pred_actual)
115 plots_explainer_html += add_hr_to_html()
116 except Exception as e:
117 LOG.error(f"Error generating plot prediction vs actual: {e}")
118
119 try:
120 fig_residuals = explainer.plot_residuals()
121 plots_explainer_html += add_plot_to_html(fig_residuals)
122 plots_explainer_html += add_hr_to_html()
123 except Exception as e:
124 LOG.error(f"Error generating plot residuals: {e}")
125
126 try:
127 for feature in self.features_name:
128 fig_residuals_vs_feature = \
129 explainer.plot_residuals_vs_feature(feature)
130 plots_explainer_html += add_plot_to_html(
131 fig_residuals_vs_feature)
132 plots_explainer_html += add_hr_to_html()
133 except Exception as e:
134 LOG.error(f"Error generating plot residuals vs feature: {e}")
135
136 self.plots_explainer_html = plots_explainer_html