comparison pycaret_classification.py @ 0:1f20fe57fdee draft

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author goeckslab
date Wed, 11 Dec 2024 04:59:43 +0000
parents
children 0314dad38aaa
comparison
equal deleted inserted replaced
-1:000000000000 0:1f20fe57fdee
1 import logging
2
3 from base_model_trainer import BaseModelTrainer
4
5 from dashboard import generate_classifier_explainer_dashboard
6
7 from pycaret.classification import ClassificationExperiment
8
9 from utils import add_hr_to_html, add_plot_to_html
10
11 LOG = logging.getLogger(__name__)
12
13
class ClassificationModelTrainer(BaseModelTrainer):
    """Trainer for classification tasks backed by PyCaret.

    Extends ``BaseModelTrainer`` by installing a
    ``ClassificationExperiment`` as ``self.exp`` and by providing the
    classification-specific dashboard and plot generation steps.
    """

    # Plot identifiers handed to ``self.exp.plot_model``, in generation order.
    _PLOT_NAMES = (
        'confusion_matrix', 'auc', 'threshold', 'pr',
        'error', 'class_report', 'learning', 'calibration',
        'vc', 'dimension', 'manifold', 'rfe', 'feature',
        'feature_all',
    )

    # ``plot_kwargs`` used for the 'auc' plot when the experiment is not
    # multiclass: only the binary curve is requested.
    _BINARY_AUC_KWARGS = {
        'micro': False,
        'macro': False,
        'per_class': False,
        'binary': True,
    }

    def __init__(
            self,
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file=None,
            **kwargs):
        """Initialise the base trainer and attach a classification experiment.

        All arguments are forwarded unchanged (and in the same order) to
        ``BaseModelTrainer.__init__``; see that class for their meaning.
        """
        super().__init__(
            input_file,
            target_col,
            output_dir,
            task_type,
            random_seed,
            test_file,
            **kwargs)
        self.exp = ClassificationExperiment()

    def save_dashboard(self):
        """Build the classifier explainer dashboard and save it as HTML.

        The dashboard is written to ``dashboard.html`` (a path relative to
        the current working directory).
        """
        LOG.info("Saving explainer dashboard")
        dashboard = generate_classifier_explainer_dashboard(
            self.exp, self.best_model)
        dashboard.save_html("dashboard.html")

    def generate_plots(self):
        """Generate every plot in ``_PLOT_NAMES`` for the best model.

        The value returned by ``plot_model(..., save=True)`` is stored in
        ``self.plots`` under the plot name. A failure for one plot is logged
        and does not stop the remaining plots from being attempted.
        """
        LOG.info("Generating and saving plots")
        for plot_name in self._PLOT_NAMES:
            try:
                extra = {}
                if plot_name == 'auc' and not self.exp.is_multiclass:
                    # Copy so the shared class-level dict is never handed
                    # to (and possibly mutated by) the plotting call.
                    extra['plot_kwargs'] = dict(self._BINARY_AUC_KWARGS)
                self.plots[plot_name] = self.exp.plot_model(
                    self.best_model, plot=plot_name, save=True, **extra)
            except Exception as e:
                LOG.error("Error generating plot %s: %s", plot_name, e)

    @staticmethod
    def _render_plot_group(label, make_figures, add_rule=True):
        """Render one group of explainer figures to an HTML fragment.

        Args:
            label: Name of the group, used in the error log message.
            make_figures: Zero-argument callable returning an iterable of
                figures. It is called and iterated inside the ``try`` so a
                failure while producing any figure is logged, not raised.
            add_rule: Whether ``add_hr_to_html()`` is appended after each
                figure.

        Returns:
            The HTML accumulated before any failure occurred. The first
            failure ends the group; figures already rendered are kept.
        """
        fragment = ""
        try:
            for figure in make_figures():
                fragment += add_plot_to_html(figure)
                if add_rule:
                    fragment += add_hr_to_html()
        except Exception as e:
            LOG.error("Error generating plot %s: %s", label, e)
        return fragment

    def generate_plots_explainer(self):
        """Render explainerdashboard plots into ``self.plots_explainer_html``.

        A ``ClassifierExplainer`` is built from the best model and the
        transformed test split of the experiment. Each plot group is
        rendered independently: a failing group is logged and skipped, and
        the HTML of all groups is concatenated in a fixed order.
        """
        LOG.info("Generating and saving plots from explainer")

        # Imported lazily, as in the original implementation.
        from explainerdashboard import ClassifierExplainer

        X_test = self.exp.X_test_transformed.copy()
        y_test = self.exp.y_test_transformed

        explainer = ClassifierExplainer(self.best_model, X_test, y_test)
        self.explainer = explainer
        # Historical misspelling, kept as an alias because code outside this
        # class may still read it.
        self.expaliner = explainer

        # (label, figure factory, append <hr> after each figure).
        # Factories are lambdas so that attribute lookups and plotting calls
        # only happen inside the helper's ``try``.
        # NOTE(review): indentation is lost in the reviewed view of this
        # file; the rule is assumed to follow every per-feature plot rather
        # than only the last one -- confirm against the repository.
        # NOTE(review): the "interactions" group plots each feature against
        # itself (col == interact_col), exactly as before -- confirm intent.
        # Previously commented-out plots (plot_shap_summary,
        # plot_contributions(index=0), per-feature plot_dependence and
        # per-feature plot_interactions_detailed) remain disabled.
        groups = (
            ("importance(mean shap)",
             lambda: [explainer.plot_importances()],
             True),
            ("importance(permutation)",
             lambda: [explainer.plot_importances(kind="permutation")],
             True),
            ("pdp",
             lambda: (explainer.plot_pdp(feature)
                      for feature in self.features_name),
             True),
            ("interactions",
             lambda: (explainer.plot_interaction(
                          col=feature, interact_col=feature)
                      for feature in self.features_name),
             False),
            ("interactions importance",
             lambda: (explainer.plot_interactions_importance(col=feature)
                      for feature in self.features_name),
             True),
            ("precision",
             lambda: [explainer.plot_precision()],
             True),
            ("cumulative precision",
             lambda: [explainer.plot_cumulative_precision()],
             True),
            ("classification",
             lambda: [explainer.plot_classification()],
             True),
            ("confusion matrix",
             lambda: [explainer.plot_confusion_matrix()],
             True),
            ("lift curve",
             lambda: [explainer.plot_lift_curve()],
             True),
            ("roc auc",
             lambda: [explainer.plot_roc_auc()],
             True),
            ("pr auc",
             lambda: [explainer.plot_pr_auc()],
             True),
        )

        self.plots_explainer_html = "".join(
            self._render_plot_group(label, make_figures, add_rule)
            for label, make_figures, add_rule in groups)