comparison pycaret_classification.py @ 2:0314dad38aaa draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit ff6d674ecc83db933153b797ef4dbde17f07b10e
author goeckslab
date Wed, 01 Jan 2025 03:19:27 +0000
parents 1f20fe57fdee
children
comparison
equal deleted inserted replaced
1:4a7df9abe4c4 2:0314dad38aaa
4 4
5 from dashboard import generate_classifier_explainer_dashboard 5 from dashboard import generate_classifier_explainer_dashboard
6 6
7 from pycaret.classification import ClassificationExperiment 7 from pycaret.classification import ClassificationExperiment
8 8
9 from utils import add_hr_to_html, add_plot_to_html 9 from utils import add_hr_to_html, add_plot_to_html, predict_proba
10 10
11 LOG = logging.getLogger(__name__) 11 LOG = logging.getLogger(__name__)
12 12
13 13
14 class ClassificationModelTrainer(BaseModelTrainer): 14 class ClassificationModelTrainer(BaseModelTrainer):
37 self.best_model) 37 self.best_model)
38 dashboard.save_html("dashboard.html") 38 dashboard.save_html("dashboard.html")
39 39
40 def generate_plots(self): 40 def generate_plots(self):
41 LOG.info("Generating and saving plots") 41 LOG.info("Generating and saving plots")
42
43 if not hasattr(self.best_model, "predict_proba"):
44 import types
45 self.best_model.predict_proba = types.MethodType(
46 predict_proba, self.best_model)
47 LOG.warning(
48 f"The model {type(self.best_model).__name__}\
49 does not support `predict_proba`. \
50 Applying monkey patch.")
51
42 plots = ['confusion_matrix', 'auc', 'threshold', 'pr', 52 plots = ['confusion_matrix', 'auc', 'threshold', 'pr',
43 'error', 'class_report', 'learning', 'calibration', 53 'error', 'class_report', 'learning', 'calibration',
44 'vc', 'dimension', 'manifold', 'rfe', 'feature', 54 'vc', 'dimension', 'manifold', 'rfe', 'feature',
45 'feature_all'] 55 'feature_all']
46 for plot_name in plots: 56 for plot_name in plots:
72 from explainerdashboard import ClassifierExplainer 82 from explainerdashboard import ClassifierExplainer
73 83
74 X_test = self.exp.X_test_transformed.copy() 84 X_test = self.exp.X_test_transformed.copy()
75 y_test = self.exp.y_test_transformed 85 y_test = self.exp.y_test_transformed
76 86
77 explainer = ClassifierExplainer(self.best_model, X_test, y_test) 87 try:
78 self.expaliner = explainer 88 explainer = ClassifierExplainer(self.best_model, X_test, y_test)
79 plots_explainer_html = "" 89 self.expaliner = explainer
90 plots_explainer_html = ""
91 except Exception as e:
92 LOG.error(f"Error creating explainer: {e}")
93 self.plots_explainer_html = None
94 return
80 95
81 try: 96 try:
82 fig_importance = explainer.plot_importances() 97 fig_importance = explainer.plot_importances()
83 plots_explainer_html += add_plot_to_html(fig_importance) 98 plots_explainer_html += add_plot_to_html(fig_importance)
84 plots_explainer_html += add_hr_to_html() 99 plots_explainer_html += add_hr_to_html()