comparison feature_importance.py @ 8:1aed7d47c5ec draft
planemo upload for repository https://github.com/goeckslab/gleam commit 8112548ac44b7a4769093d76c722c8fcdeaaef54
| author | goeckslab |
|---|---|
| date | Fri, 25 Jul 2025 19:02:32 +0000 |
| parents | f4cb41f458fd |
| children | |
| 7:f4cb41f458fd | 8:1aed7d47c5ec |
|---|---|
| 12 LOG = logging.getLogger(__name__) | 12 LOG = logging.getLogger(__name__) |
| 13 | 13 |
| 14 | 14 |
| 15 class FeatureImportanceAnalyzer: | 15 class FeatureImportanceAnalyzer: |
| 16 def __init__( | 16 def __init__( |
| 17 self, | 17 self, |
| 18 task_type, | 18 task_type, |
| 19 output_dir, | 19 output_dir, |
| 20 data_path=None, | 20 data_path=None, |
| 21 data=None, | 21 data=None, |
| 22 target_col=None, | 22 target_col=None, |
| 23 exp=None, | 23 exp=None, |
| 24 best_model=None): | 24 best_model=None, |
| | 25 ): |
| 25 | 26 |
| 26 self.task_type = task_type | 27 self.task_type = task_type |
| 27 self.output_dir = output_dir | 28 self.output_dir = output_dir |
| 28 self.exp = exp | 29 self.exp = exp |
| 29 self.best_model = best_model | 30 self.best_model = best_model |
| 41 self.target_col = target_col | 42 self.target_col = target_col |
| 42 self.data = pd.read_csv(data_path, sep=None, engine='python') | 43 self.data = pd.read_csv(data_path, sep=None, engine='python') |
| 43 self.data.columns = self.data.columns.str.replace('.', '_') | 44 self.data.columns = self.data.columns.str.replace('.', '_') |
| 44 self.data = self.data.fillna(self.data.median(numeric_only=True)) | 45 self.data = self.data.fillna(self.data.median(numeric_only=True)) |
| 45 self.target = self.data.columns[int(target_col) - 1] | 46 self.target = self.data.columns[int(target_col) - 1] |
| 46 self.exp = ClassificationExperiment() if task_type == 'classification' else RegressionExperiment() | 47 self.exp = ( |
| | 48 ClassificationExperiment() |
| | 49 if task_type == "classification" |
| | 50 else RegressionExperiment() |
| | 51 ) |
| 47 | 52 |
| 48 self.plots = {} | 53 self.plots = {} |
| 49 | 54 |
| 50 def setup_pycaret(self): | 55 def setup_pycaret(self): |
| 51 if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: | 56 if self.exp is not None and hasattr(self.exp, 'is_setup') and self.exp.is_setup: |
| 55 setup_params = { | 60 setup_params = { |
| 56 'target': self.target, | 61 'target': self.target, |
| 57 'session_id': 123, | 62 'session_id': 123, |
| 58 'html': True, | 63 'html': True, |
| 59 'log_experiment': False, | 64 'log_experiment': False, |
| 60 'system_log': False | 65 'system_log': False, |
| 61 } | 66 } |
| 62 self.exp.setup(self.data, **setup_params) | 67 self.exp.setup(self.data, **setup_params) |
| 63 | 68 |
| 64 def save_tree_importance(self): | 69 def save_tree_importance(self): |
| 65 model = self.best_model or self.exp.get_config('best_model') | 70 model = self.best_model or self.exp.get_config('best_model') |
| 68 # Try feature_importances_ or coef_ if available | 73 # Try feature_importances_ or coef_ if available |
| 69 importances = None | 74 importances = None |
| 70 model_type = model.__class__.__name__ | 75 model_type = model.__class__.__name__ |
| 71 self.tree_model_name = model_type # Store the model name for reporting | 76 self.tree_model_name = model_type # Store the model name for reporting |
| 72 | 77 |
| 73 if hasattr(model, "feature_importances_"): | 78 if hasattr(model, 'feature_importances_'): |
| 74 importances = model.feature_importances_ | 79 importances = model.feature_importances_ |
| 75 elif hasattr(model, "coef_"): | 80 elif hasattr(model, 'coef_'): |
| 76 # For linear models, flatten coef_ and take abs (importance as magnitude) | 81 # For linear models, flatten coef_ and take abs (importance as magnitude) |
| 77 importances = abs(model.coef_).flatten() | 82 importances = abs(model.coef_).flatten() |
| 78 else: | 83 else: |
| 79 # Neither attribute exists; skip the plot | 84 # Neither attribute exists; skip the plot |
| 80 LOG.warning(f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot.") | 85 LOG.warning( |
| 86 f"Model {model_type} does not have feature_importances_ or coef_ attribute. Skipping feature importance plot." | |
| 87 ) | |
| 81 self.tree_model_name = None # No plot generated | 88 self.tree_model_name = None # No plot generated |
| 82 return | 89 return |
| 83 | 90 |
| 84 # Defensive: handle mismatch in number of features | 91 # Defensive: handle mismatch in number of features |
| 85 if len(importances) != len(processed_features): | 92 if len(importances) != len(processed_features): |
| 86 LOG.warning( | 93 LOG.warning( |
| 87 f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." | 94 f"Number of importances ({len(importances)}) does not match number of features ({len(processed_features)}). Skipping plot." |
| 88 ) | 95 ) |
| 89 self.tree_model_name = None | 96 self.tree_model_name = None |
| 90 return | 97 return |
| 91 | 98 |
| 92 feature_importances = pd.DataFrame({ | 99 feature_importances = pd.DataFrame( |
| 93 'Feature': processed_features, | 100 {'Feature': processed_features, 'Importance': importances} |
| 94 'Importance': importances | 101 ).sort_values(by='Importance', ascending=False) |
| 95 }).sort_values(by='Importance', ascending=False) | |
| 96 plt.figure(figsize=(10, 6)) | 102 plt.figure(figsize=(10, 6)) |
| 97 plt.barh( | 103 plt.barh(feature_importances['Feature'], feature_importances['Importance']) |
| 98 feature_importances['Feature'], | |
| 99 feature_importances['Importance']) | |
| 100 plt.xlabel('Importance') | 104 plt.xlabel('Importance') |
| 101 plt.title(f'Feature Importance ({model_type})') | 105 plt.title(f'Feature Importance ({model_type})') |
| 102 plot_path = os.path.join( | 106 plot_path = os.path.join(self.output_dir, 'tree_importance.png') |
| 103 self.output_dir, | |
| 104 'tree_importance.png') | |
| 105 plt.savefig(plot_path) | 107 plt.savefig(plot_path) |
| 106 plt.close() | 108 plt.close() |
| 107 self.plots['tree_importance'] = plot_path | 109 self.plots['tree_importance'] = plot_path |
| 108 | 110 |
| 109 def save_shap_values(self): | 111 def save_shap_values(self): |
| 110 model = self.best_model or self.exp.get_config('best_model') | 112 |
| 111 X_transformed = self.exp.get_config('X_transformed') | 113 model = self.best_model or self.exp.get_config("best_model") |
| 112 tree_classes = ( | 114 |
| 113 "LGBM", "XGB", "CatBoost", "RandomForest", "DecisionTree", "ExtraTrees", "HistGradientBoosting" | 115 X_data = None |
| 114 ) | 116 for key in ("X_test_transformed", "X_train_transformed"): |
| 115 model_class_name = model.__class__.__name__ | 117 try: |
| 116 self.shap_model_name = model_class_name | 118 X_data = self.exp.get_config(key) |
| 117 | 119 break |
| 118 # Ensure feature alignment | 120 except KeyError: |
| 119 if hasattr(model, "feature_name_"): | 121 continue |
| 120 used_features = model.feature_name_ | 122 if X_data is None: |
| 121 elif hasattr(model, "booster_") and hasattr(model.booster_, "feature_name"): | 123 raise RuntimeError( |
| 124 "Could not find 'X_test_transformed' or 'X_train_transformed' in the experiment. " | |
| 125 "Make sure PyCaret setup/compare_models was run with feature_selection=True." | |
| 126 ) | |
| 127 | |
| 128 try: | |
| 122 used_features = model.booster_.feature_name() | 129 used_features = model.booster_.feature_name() |
| 123 elif hasattr(model, "feature_names_in_"): | 130 except Exception: |
| 124 # scikit-learn's standard attribute for the names of features used during fit | 131 used_features = getattr(model, "feature_names_in_", X_data.columns.tolist()) |
| 125 used_features = list(model.feature_names_in_) | 132 X_data = X_data[used_features] |
| | 133 |
| | 134 max_bg = min(len(X_data), 100) |
| | 135 bg = X_data.sample(max_bg, random_state=42) |
| | 136 |
| | 137 predict_fn = model.predict_proba if hasattr(model, "predict_proba") else model.predict |
| | 138 |
| | 139 explainer = shap.Explainer(predict_fn, bg) |
| | 140 self.shap_model_name = explainer.__class__.__name__ |
| | 141 |
| | 142 shap_values = explainer(X_data) |
| | 143 |
| | 144 output_names = getattr(shap_values, "output_names", None) |
| | 145 if output_names is None and hasattr(model, "classes_"): |
| | 146 output_names = list(model.classes_) |
| | 147 if output_names is None: |
| | 148 n_out = shap_values.values.shape[-1] |
| | 149 output_names = list(map(str, range(n_out))) |
| | 150 |
| | 151 values = shap_values.values |
| | 152 if values.ndim == 3: |
| | 153 for j, name in enumerate(output_names): |
| | 154 safe = name.replace(" ", "_").replace("/", "_") |
| | 155 out_path = os.path.join(self.output_dir, f"shap_summary_{safe}.png") |
| | 156 plt.figure() |
| | 157 shap.plots.beeswarm(shap_values[..., j], show=False) |
| | 158 plt.title(f"SHAP for {model.__class__.__name__} ⇒ {name}") |
| | 159 plt.savefig(out_path) |
| | 160 plt.close() |
| | 161 self.plots[f"shap_summary_{safe}"] = out_path |
| 126 else: | 162 else: |
| 127 used_features = X_transformed.columns | 163 plt.figure() |
| 128 | 164 shap.plots.beeswarm(shap_values, show=False) |
| 129 if any(tc in model_class_name for tc in tree_classes): | 165 plt.title(f"SHAP Summary for {model.__class__.__name__}") |
| 130 explainer = shap.TreeExplainer(model) | 166 out_path = os.path.join(self.output_dir, "shap_summary.png") |
| 131 X_shap = X_transformed[used_features] | 167 plt.savefig(out_path) |
| 132 shap_values = explainer.shap_values(X_shap) | 168 plt.close() |
| 133 plot_X = X_shap | 169 self.plots["shap_summary"] = out_path |
| 134 plot_title = f"SHAP Summary for {model_class_name} (TreeExplainer)" | |
| 135 else: | |
| 136 logging.warning(f"len(X_transformed) = {len(X_transformed)}") | |
| 137 max_samples = 100 | |
| 138 n_samples = min(max_samples, len(X_transformed)) | |
| 139 sampled_X = X_transformed[used_features].sample( | |
| 140 n=n_samples, | |
| 141 replace=False, | |
| 142 random_state=42 | |
| 143 ) | |
| 144 explainer = shap.KernelExplainer(model.predict, sampled_X) | |
| 145 shap_values = explainer.shap_values(sampled_X) | |
| 146 plot_X = sampled_X | |
| 147 plot_title = f"SHAP Summary for {model_class_name} (KernelExplainer)" | |
| 148 | |
| 149 shap.summary_plot(shap_values, plot_X, show=False) | |
| 150 plt.title(plot_title) | |
| 151 plot_path = os.path.join(self.output_dir, "shap_summary.png") | |
| 152 plt.savefig(plot_path) | |
| 153 plt.close() | |
| 154 self.plots["shap_summary"] = plot_path | |
| 155 | 170 |
| 156 def generate_html_report(self): | 171 def generate_html_report(self): |
| 157 LOG.info("Generating HTML report") | 172 LOG.info("Generating HTML report") |
| 158 | 173 |
| 159 plots_html = "" | 174 plots_html = "" |
| 160 for plot_name, plot_path in self.plots.items(): | 175 for plot_name, plot_path in self.plots.items(): |
| 161 # Special handling for tree importance: skip if no model name (not generated) | 176 # Special handling for tree importance: skip if no model name (not generated) |
| 162 if plot_name == 'tree_importance' and not getattr(self, 'tree_model_name', None): | 177 if plot_name == 'tree_importance' and not getattr( |
| 178 self, 'tree_model_name', None | |
| 179 ): | |
| 163 continue | 180 continue |
| 164 encoded_image = self.encode_image_to_base64(plot_path) | 181 encoded_image = self.encode_image_to_base64(plot_path) |
| 165 if plot_name == 'tree_importance' and getattr(self, 'tree_model_name', None): | 182 if plot_name == 'tree_importance' and getattr( |
| 166 section_title = f"Feature importance analysis from a trained {self.tree_model_name}" | 183 self, 'tree_model_name', None |
| 184 ): | |
| 185 section_title = ( | |
| 186 f"Feature importance analysis from a trained {self.tree_model_name}" | |
| 187 ) | |
| 167 elif plot_name == 'shap_summary': | 188 elif plot_name == 'shap_summary': |
| 168 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" | 189 section_title = f"SHAP Summary from a trained {getattr(self, 'shap_model_name', 'model')}" |
| 169 else: | 190 else: |
| 170 section_title = plot_name | 191 section_title = plot_name |
| 171 plots_html += f""" | 192 plots_html += f""" |
| 174 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> | 195 <img src="data:image/png;base64,{encoded_image}" alt="{plot_name}"> |
| 175 </div> | 196 </div> |
| 176 """ | 197 """ |
| 177 | 198 |
| 178 html_content = f""" | 199 html_content = f""" |
| 179 <h1>PyCaret Feature Importance Report</h1> | |
| 180 {plots_html} | 200 {plots_html} |
| 181 """ | 201 """ |
| 182 | 202 |
| 183 return html_content | 203 return html_content |
| 184 | 204 |
| 185 def encode_image_to_base64(self, img_path): | 205 def encode_image_to_base64(self, img_path): |
| 186 with open(img_path, 'rb') as img_file: | 206 with open(img_path, 'rb') as img_file: |
| 187 return base64.b64encode(img_file.read()).decode('utf-8') | 207 return base64.b64encode(img_file.read()).decode('utf-8') |
| 188 | 208 |
| 189 def run(self): | 209 def run(self): |
| 190 if self.exp is None or not hasattr(self.exp, 'is_setup') or not self.exp.is_setup: | 210 if ( |
| 211 self.exp is None | |
| 212 or not hasattr(self.exp, 'is_setup') | |
| 213 or not self.exp.is_setup | |
| 214 ): | |
| 191 self.setup_pycaret() | 215 self.setup_pycaret() |
| 192 self.save_tree_importance() | 216 self.save_tree_importance() |
| 193 self.save_shap_values() | 217 self.save_shap_values() |
| 194 html_content = self.generate_html_report() | 218 html_content = self.generate_html_report() |
| 195 return html_content | 219 return html_content |
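
For context, a minimal driver sketch for the revised `FeatureImportanceAnalyzer`: train a model with PyCaret first, then pass the experiment and the winning estimator so the analyzer produces the tree/coefficient importance plot and the per-output SHAP beeswarms. This is a sketch under assumptions, not the tool's actual wrapper: the import path `feature_importance`, the file `train.csv`, the target column `label`, and the output directory `fi_report` are all hypothetical, and the elided `__init__` lines (30-40 old / 31-41 new) are assumed to accept the `data`/`exp`/`best_model` combination.

```python
import os

import pandas as pd
from pycaret.classification import ClassificationExperiment

from feature_importance import FeatureImportanceAnalyzer  # hypothetical import path

# Train with PyCaret first, then hand the experiment and the winning
# estimator to the analyzer.
df = pd.read_csv("train.csv")                  # hypothetical dataset
exp = ClassificationExperiment()
exp.setup(df, target="label", session_id=123)  # "label" is an assumed column name
best = exp.compare_models()

os.makedirs("fi_report", exist_ok=True)        # in case __init__ does not create it
analyzer = FeatureImportanceAnalyzer(
    task_type="classification",
    output_dir="fi_report",
    data=df,
    target_col="label",
    exp=exp,
    best_model=best,
)

# run() saves the tree/coefficient importance plot plus one SHAP beeswarm
# per output (one per class for multi-class models) and returns the HTML body.
html = analyzer.run()
with open(os.path.join("fi_report", "report.html"), "w", encoding="utf-8") as fh:
    fh.write(html)
```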
