Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_classification.py @ 0:1f20fe57fdee draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
| author | goeckslab |
|---|---|
| date | Wed, 11 Dec 2024 04:59:43 +0000 |
| parents | |
| children | 0314dad38aaa |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:1f20fe57fdee |
|---|---|
| 1 import logging | |
| 2 | |
| 3 from base_model_trainer import BaseModelTrainer | |
| 4 | |
| 5 from dashboard import generate_classifier_explainer_dashboard | |
| 6 | |
| 7 from pycaret.classification import ClassificationExperiment | |
| 8 | |
| 9 from utils import add_hr_to_html, add_plot_to_html | |
| 10 | |
# Logger used by every method of the trainer class in this module.
LOG = logging.getLogger(name=__name__)
| 12 | |
| 13 | |
class ClassificationModelTrainer(BaseModelTrainer):
    """PyCaret classification trainer with classifier-specific reporting.

    Extends ``BaseModelTrainer`` by attaching a PyCaret
    ``ClassificationExperiment`` as ``self.exp`` and adding the classifier
    plots, explainer plots and explainer dashboard used in the report.
    """

    # PyCaret ``plot_model`` plot names rendered by ``generate_plots``, in
    # output order. Subclasses may override this tuple.
    CLASSIFICATION_PLOTS = (
        'confusion_matrix', 'auc', 'threshold', 'pr',
        'error', 'class_report', 'learning', 'calibration',
        'vc', 'dimension', 'manifold', 'rfe', 'feature',
        'feature_all',
    )

    # ``plot_kwargs`` passed to ``plot_model`` for the 'auc' plot when the
    # target is not multiclass: draw only the binary ROC curve instead of
    # the micro/macro/per-class curves.
    _BINARY_AUC_PLOT_KWARGS = {
        'micro': False,
        'macro': False,
        'per_class': False,
        'binary': True,
    }

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Initialise the trainer and attach a ClassificationExperiment.

        All arguments are forwarded unchanged to ``BaseModelTrainer``.
        """
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = ClassificationExperiment()

    def save_dashboard(self):
        """Write an explainer dashboard for ``self.best_model`` to HTML.

        The file is saved as ``dashboard.html`` relative to the current
        working directory.
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_classifier_explainer_dashboard(
            self.exp, self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Render PyCaret diagnostic plots for ``self.best_model``.

        Each name in ``CLASSIFICATION_PLOTS`` is rendered with
        ``self.exp.plot_model(..., save=True)`` and the value it returns is
        stored in ``self.plots[plot_name]``. A plot that raises is logged
        and skipped so the remaining plots are still produced.
        """
        LOG.info("Generating and saving plots")
        for plot_name in self.CLASSIFICATION_PLOTS:
            try:
                extra_kwargs = {}
                if plot_name == 'auc' and not self.exp.is_multiclass:
                    # Copy so plot_model can never mutate the class default.
                    extra_kwargs['plot_kwargs'] = dict(
                        self._BINARY_AUC_PLOT_KWARGS)
                self.plots[plot_name] = self.exp.plot_model(
                    self.best_model,
                    plot=plot_name,
                    save=True,
                    **extra_kwargs)
            except Exception as e:
                # Deliberately best effort: some plots are not supported
                # for every model/dataset, so log and continue.
                LOG.error("Error generating plot %s: %s", plot_name, e)

    @staticmethod
    def _explainer_plot_html(description, plot_fn, *args, **kwargs):
        """Return one explainer plot as HTML followed by an ``<hr>``.

        Calls ``plot_fn(*args, **kwargs)`` and converts the figure with
        ``add_plot_to_html``. On any exception the error is logged with
        ``description`` and an empty string is returned, so a single failing
        plot does not abort the rest of the report.
        """
        try:
            return (add_plot_to_html(plot_fn(*args, **kwargs))
                    + add_hr_to_html())
        except Exception as e:
            LOG.error("Error generating plot %s: %s", description, e)
            return ""

    def generate_plots_explainer(self):
        """Build an HTML fragment of explainerdashboard plots.

        Wraps ``self.best_model`` and the experiment's transformed test
        split (``self.exp.X_test_transformed`` / ``y_test_transformed``)
        in a ``ClassifierExplainer``, renders a fixed sequence of plots and
        stores the concatenated HTML in ``self.plots_explainer_html``.
        The explainer itself is kept on ``self.explainer``.
        """
        LOG.info("Generating and saving plots from explainer")

        # Local import: explainerdashboard is only needed by this method.
        from explainerdashboard import ClassifierExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        explainer = ClassifierExplainer(self.best_model, X_test, y_test)
        self.explainer = explainer
        # Backward-compatible alias for the previously misspelled attribute.
        self.expaliner = explainer

        render = self._explainer_plot_html
        # Presumably populated by BaseModelTrainer; tolerate absence/None so
        # the non-per-feature plots below are still produced.
        features = getattr(self, "features_name", None) or []

        html_parts = [
            render("importance(mean shap)", explainer.plot_importances),
            render("importance(permutation)", explainer.plot_importances,
                   kind="permutation"),
        ]

        # Per-feature plots: each feature is rendered independently so one
        # failing feature does not skip the remaining ones.
        html_parts.extend(
            render(f"pdp ({feature})", explainer.plot_pdp, feature)
            for feature in features)
        # NOTE(review): this plots each feature's interaction with itself
        # (col == interact_col); confirm this is the intended pairing.
        html_parts.extend(
            render(f"interactions ({feature})", explainer.plot_interaction,
                   col=feature, interact_col=feature)
            for feature in features)
        html_parts.extend(
            render(f"interactions importance ({feature})",
                   explainer.plot_interactions_importance, col=feature)
            for feature in features)

        # Intentionally disabled explainer plots: plot_shap_summary,
        # plot_contributions(index=0), plot_dependence(col=...) and
        # plot_interactions_detailed(col=...).

        for description, plot_fn in (
                ("precision", explainer.plot_precision),
                ("cumulative precision", explainer.plot_cumulative_precision),
                ("classification", explainer.plot_classification),
                ("confusion matrix", explainer.plot_confusion_matrix),
                ("lift curve", explainer.plot_lift_curve),
                ("roc auc", explainer.plot_roc_auc),
                ("pr auc", explainer.plot_pr_auc)):
            html_parts.append(render(description, plot_fn))

        self.plots_explainer_html = "".join(html_parts)
