Mercurial > repos > goeckslab > pycaret_predict
comparison pycaret_classification.py @ 2:0314dad38aaa draft default tip
planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
author | goeckslab |
---|---|
date | Wed, 01 Jan 2025 03:19:27 +0000 |
parents | 1f20fe57fdee |
children |
comparison
equal
deleted
inserted
replaced
1:4a7df9abe4c4 | 2:0314dad38aaa |
---|---|
4 | 4 |
5 from dashboard import generate_classifier_explainer_dashboard | 5 from dashboard import generate_classifier_explainer_dashboard |
6 | 6 |
7 from pycaret.classification import ClassificationExperiment | 7 from pycaret.classification import ClassificationExperiment |
8 | 8 |
9 from utils import add_hr_to_html, add_plot_to_html | 9 from utils import add_hr_to_html, add_plot_to_html, predict_proba |
10 | 10 |
11 LOG = logging.getLogger(__name__) | 11 LOG = logging.getLogger(__name__) |
12 | 12 |
13 | 13 |
14 class ClassificationModelTrainer(BaseModelTrainer): | 14 class ClassificationModelTrainer(BaseModelTrainer): |
37 self.best_model) | 37 self.best_model) |
38 dashboard.save_html("dashboard.html") | 38 dashboard.save_html("dashboard.html") |
39 | 39 |
40 def generate_plots(self): | 40 def generate_plots(self): |
41 LOG.info("Generating and saving plots") | 41 LOG.info("Generating and saving plots") |
42 | |
43 if not hasattr(self.best_model, "predict_proba"): | |
44 import types | |
45 self.best_model.predict_proba = types.MethodType( | |
46 predict_proba, self.best_model) | |
47 LOG.warning( | |
48 f"The model {type(self.best_model).__name__}\ | |
49 does not support `predict_proba`. \ | |
50 Applying monkey patch.") | |
51 | |
42 plots = ['confusion_matrix', 'auc', 'threshold', 'pr', | 52 plots = ['confusion_matrix', 'auc', 'threshold', 'pr', |
43 'error', 'class_report', 'learning', 'calibration', | 53 'error', 'class_report', 'learning', 'calibration', |
44 'vc', 'dimension', 'manifold', 'rfe', 'feature', | 54 'vc', 'dimension', 'manifold', 'rfe', 'feature', |
45 'feature_all'] | 55 'feature_all'] |
46 for plot_name in plots: | 56 for plot_name in plots: |
72 from explainerdashboard import ClassifierExplainer | 82 from explainerdashboard import ClassifierExplainer |
73 | 83 |
74 X_test = self.exp.X_test_transformed.copy() | 84 X_test = self.exp.X_test_transformed.copy() |
75 y_test = self.exp.y_test_transformed | 85 y_test = self.exp.y_test_transformed |
76 | 86 |
77 explainer = ClassifierExplainer(self.best_model, X_test, y_test) | 87 try: |
78 self.expaliner = explainer | 88 explainer = ClassifierExplainer(self.best_model, X_test, y_test) |
79 plots_explainer_html = "" | 89 self.expaliner = explainer |
90 plots_explainer_html = "" | |
91 except Exception as e: | |
92 LOG.error(f"Error creating explainer: {e}") | |
93 self.plots_explainer_html = None | |
94 return | |
80 | 95 |
81 try: | 96 try: |
82 fig_importance = explainer.plot_importances() | 97 fig_importance = explainer.plot_importances() |
83 plots_explainer_html += add_plot_to_html(fig_importance) | 98 plots_explainer_html += add_plot_to_html(fig_importance) |
84 plots_explainer_html += add_hr_to_html() | 99 plots_explainer_html += add_hr_to_html() |