diff base_model_trainer.py @ 0:1f20fe57fdee draft

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author goeckslab
date Wed, 11 Dec 2024 04:59:43 +0000
parents
children 0314dad38aaa
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/base_model_trainer.py	Wed Dec 11 04:59:43 2024 +0000
@@ -0,0 +1,359 @@
+import base64
+import logging
+import os
+import tempfile
+
+from feature_importance import FeatureImportanceAnalyzer
+
+import h5py
+
+import joblib
+
+import numpy as np
+
+import pandas as pd
+
+from sklearn.metrics import average_precision_score
+
+from utils import get_html_closing, get_html_template
+
+# Configure the root logger as soon as this module is imported so that
+# DEBUG-level (and higher) records are emitted.
+logging.basicConfig(level=logging.DEBUG)
+# Module-level logger used by every method of BaseModelTrainer below.
+LOG = logging.getLogger(__name__)
+
+
+class BaseModelTrainer:
+    """Base class driving a PyCaret experiment from data loading to report.
+
+    The class never creates the PyCaret experiment itself: ``self.exp`` is
+    initialised to ``None`` and must be assigned by a subclass, which also
+    has to implement ``generate_plots``, ``generate_plots_explainer`` and
+    ``save_dashboard``.  ``run`` executes the pipeline steps in order.
+    """
+
+    def __init__(
+            self,
+            input_file,
+            target_col,
+            output_dir,
+            task_type,
+            random_seed,
+            test_file=None,
+            **kwargs
+            ):
+        """Store the run configuration; no file is read here.
+
+        Args:
+            input_file: Path of the delimited training table.
+            target_col: 1-based position of the target column (any value
+                accepted by ``int()``).
+            output_dir: Directory receiving the CSV and HTML outputs.
+            task_type: Task name; ``"classification"`` enables the extra
+                PR-AUC metric and the ``AUC`` -> ``ROC-AUC`` rename.
+            random_seed: Forwarded to PyCaret as ``session_id``.
+            test_file: Optional path of a separate test table.
+            **kwargs: Stored verbatim as instance attributes.  Other
+                methods look for ``missing_value_strategy``,
+                ``train_size``, ``normalize``, ``feature_selection``,
+                ``cross_validation``, ``cross_validation_folds``,
+                ``remove_outliers``, ``remove_multicollinearity``,
+                ``polynomial_features``, ``fix_imbalance`` and ``models``.
+        """
+        self.exp = None  # This will be set in the subclass
+        self.input_file = input_file
+        self.target_col = target_col
+        self.output_dir = output_dir
+        self.task_type = task_type
+        self.random_seed = random_seed
+        self.data = None
+        self.target = None
+        self.best_model = None
+        self.results = None
+        self.features_name = None
+        self.plots = {}
+        # NOTE: the attribute name is misspelt ("explainer"); it is kept
+        # as-is because code outside this class may already reference it.
+        self.expaliner = None
+        self.plots_explainer_html = None
+        self.trees = []
+        # kwargs are applied here, so a keyword may override any attribute
+        # assigned above but never the three assigned below.
+        for key, value in kwargs.items():
+            setattr(self, key, value)
+        self.setup_params = {}
+        self.test_file = test_file
+        self.test_data = None
+
+        LOG.info(f"Model kwargs: {self.__dict__}")
+
+    @staticmethod
+    def _sanitize_columns(frame):
+        """Replace every literal ``.`` in the column labels with ``_``.
+
+        ``regex=False`` is passed explicitly: with pandas releases whose
+        default is ``regex=True`` the pattern ``'.'`` matches any
+        character and would turn every label into underscores.
+
+        Args:
+            frame: DataFrame whose string column labels are rewritten
+                in place.
+
+        Returns:
+            The same DataFrame object, for call chaining.
+        """
+        frame.columns = frame.columns.str.replace('.', '_', regex=False)
+        return frame
+
+    def load_data(self):
+        """Read the training (and optional test) table and prepare it.
+
+        Sets ``self.data``, ``self.target``, ``self.features_name`` and,
+        when ``self.test_file`` is truthy, ``self.test_data``.  Missing
+        values are handled according to ``missing_value_strategy``
+        (``'mean'``, ``'median'`` or ``'drop'``; any other value leaves the
+        data untouched); when the attribute is absent the numeric columns
+        are filled with their median.
+
+        Raises:
+            IndexError: If ``target_col`` is beyond the last column.
+        """
+        LOG.info(f"Loading data from {self.input_file}")
+        # sep=None lets the python engine sniff the delimiter.
+        self.data = self._sanitize_columns(
+            pd.read_csv(self.input_file, sep=None, engine='python'))
+
+        numeric_cols = self.data.select_dtypes(include=['number']).columns
+        non_numeric_cols = self.data.select_dtypes(exclude=['number']).columns
+
+        self.data[numeric_cols] = self.data[numeric_cols].apply(
+            pd.to_numeric, errors='coerce')
+
+        if len(non_numeric_cols) > 0:
+            LOG.info(f"Non-numeric columns found: {non_numeric_cols.tolist()}")
+
+        names = self.data.columns.to_list()
+        # target_col is 1-based; convert to a 0-based list index.
+        target_index = int(self.target_col) - 1
+        self.target = names[target_index]
+        self.features_name = [
+            name for i, name in enumerate(names) if i != target_index]
+
+        if hasattr(self, 'missing_value_strategy'):
+            strategy = self.missing_value_strategy
+            if strategy == 'mean':
+                self.data = self.data.fillna(
+                    self.data.mean(numeric_only=True))
+            elif strategy == 'median':
+                self.data = self.data.fillna(
+                    self.data.median(numeric_only=True))
+            elif strategy == 'drop':
+                self.data = self.data.dropna()
+        else:
+            # Default strategy if not specified
+            self.data = self.data.fillna(self.data.median(numeric_only=True))
+
+        if self.test_file:
+            LOG.info(f"Loading test data from {self.test_file}")
+            # The headers must be sanitised BEFORE selecting numeric_cols:
+            # those labels come from the already-sanitised training frame,
+            # so indexing the raw test headers would raise KeyError for any
+            # column whose name contains a dot.
+            test_frame = self._sanitize_columns(
+                pd.read_csv(self.test_file, sep=None, engine='python'))
+            # NOTE(review): only the columns that are numeric in the
+            # training data are kept, so a non-numeric target is absent
+            # from test_data -- confirm this is intended.
+            self.test_data = test_frame[numeric_cols].apply(
+                pd.to_numeric, errors='coerce')
+
+    def setup_pycaret(self):
+        """Build ``self.setup_params`` and call ``self.exp.setup``.
+
+        Optional attributes (normally supplied through the constructor's
+        ``**kwargs``) are forwarded only when present and not ``None``.
+        ``train_size`` is ignored when a separate test set was loaded.
+        ``cross_validation`` is forwarded only when it is exactly
+        ``False``; ``fold`` is taken from ``cross_validation_folds``
+        whenever ``cross_validation`` is set and that attribute exists.
+        """
+        LOG.info("Initializing PyCaret")
+        self.setup_params = {
+            'target': self.target,
+            'session_id': self.random_seed,
+            'html': True,
+            'log_experiment': False,
+            'system_log': False,
+            'index': False,
+        }
+
+        def copy_if_set(*option_names):
+            # Forward each named attribute that exists and is not None.
+            for option in option_names:
+                value = getattr(self, option, None)
+                if value is not None:
+                    self.setup_params[option] = value
+
+        # Insertion order matters: it is the row order of the report table.
+        if self.test_data is not None:
+            self.setup_params['test_data'] = self.test_data
+        elif getattr(self, 'train_size', None) is not None:
+            self.setup_params['train_size'] = self.train_size
+
+        copy_if_set('normalize', 'feature_selection')
+
+        # NOTE(review): confirm that the experiment's setup() accepts a
+        # `cross_validation` keyword; this file cannot verify it.
+        cross_validation = getattr(self, 'cross_validation', None)
+        if cross_validation is False:
+            self.setup_params['cross_validation'] = cross_validation
+        if cross_validation is not None and \
+                hasattr(self, 'cross_validation_folds'):
+            self.setup_params['fold'] = self.cross_validation_folds
+
+        copy_if_set(
+            'remove_outliers',
+            'remove_multicollinearity',
+            'polynomial_features',
+            'fix_imbalance',
+        )
+
+        LOG.info(self.setup_params)
+        self.exp.setup(self.data, **self.setup_params)
+
+    def train_model(self):
+        """Compare candidate models and score the best one.
+
+        Stores the winner in ``self.best_model``, the comparison grid in
+        ``self.results`` and the grid pulled after ``predict_model`` in
+        ``self.test_result_df``.  For classification a weighted PR-AUC
+        metric is registered first and the ``AUC`` column of both grids is
+        renamed to ``ROC-AUC``.
+        """
+        LOG.info("Training and selecting the best model")
+        is_classification = self.task_type == "classification"
+        if is_classification:
+            average_displayed = "Weighted"
+            metric_label = f'PR-AUC-{average_displayed}'
+            self.exp.add_metric(id=metric_label,
+                                name=metric_label,
+                                target='pred_proba',
+                                score_func=average_precision_score,
+                                average='weighted'
+                                )
+
+        # Restrict the comparison to the requested models, if any.
+        requested_models = getattr(self, 'models', None)
+        if requested_models is not None:
+            self.best_model = self.exp.compare_models(
+                include=requested_models)
+        else:
+            self.best_model = self.exp.compare_models()
+        self.results = self.exp.pull()
+        if is_classification:
+            self.results.rename(columns={'AUC': 'ROC-AUC'}, inplace=True)
+
+        # predict_model is called for its side effect: the score grid it
+        # produces is fetched with pull() right after.
+        _ = self.exp.predict_model(self.best_model)
+        self.test_result_df = self.exp.pull()
+        if is_classification:
+            self.test_result_df.rename(
+                columns={'AUC': 'ROC-AUC'}, inplace=True)
+
+    def save_model(self):
+        """Serialise ``self.best_model`` into ``pycaret_model.h5``.
+
+        The model is dumped with joblib and the resulting bytes are stored
+        as the ``model`` dataset of an HDF5 file created in the current
+        working directory.
+        """
+        hdf5_model_path = "pycaret_model.h5"
+        # Dump straight into the open handle of an auto-deleting temporary
+        # file: nothing is left behind on disk and the file is never
+        # reopened by name while it is still open.
+        with tempfile.NamedTemporaryFile() as temp_file:
+            joblib.dump(self.best_model, temp_file)
+            temp_file.seek(0)
+            model_bytes = temp_file.read()
+        with h5py.File(hdf5_model_path, 'w') as f:
+            # np.void stores the pickle as one opaque binary scalar.
+            f.create_dataset('model', data=np.void(model_bytes))
+
+    def generate_plots(self):
+        """Hook for subclasses; the base implementation always raises."""
+        raise NotImplementedError("Subclasses should implement this method")
+
+    def encode_image_to_base64(self, img_path):
+        """Return the content of ``img_path`` as a base64 text string."""
+        with open(img_path, 'rb') as img_file:
+            raw_bytes = img_file.read()
+        return base64.b64encode(raw_bytes).decode('utf-8')
+
+    def save_html_report(self):
+        """Write the CSV result tables and the tabbed HTML report.
+
+        Files created in ``self.output_dir``: ``best_model.csv``,
+        ``comparison_results.csv``, ``test_results.csv`` and
+        ``comparison_result.html``.  The report embeds every image listed
+        in ``self.plots``, the output of ``FeatureImportanceAnalyzer.run``
+        and whatever ``self.plots_explainer_html`` / ``self.trees`` hold.
+        """
+        LOG.info("Saving HTML report")
+
+        model_name = type(self.best_model).__name__
+        # Internal switches and the raw test frame are not worth showing.
+        excluded_params = ['html', 'log_experiment', 'system_log', 'test_data']
+        filtered_setup_params = {
+            k: v
+            for k, v in self.setup_params.items() if k not in excluded_params
+        }
+        setup_params_table = pd.DataFrame(
+            list(filtered_setup_params.items()),
+            columns=['Parameter', 'Value'])
+
+        best_model_params = pd.DataFrame(
+            list(self.best_model.get_params().items()),
+            columns=['Parameter', 'Value'])
+        best_model_params.to_csv(
+            os.path.join(self.output_dir, 'best_model.csv'),
+            index=False)
+        self.results.to_csv(os.path.join(
+            self.output_dir, "comparison_results.csv"))
+        self.test_result_df.to_csv(os.path.join(
+            self.output_dir, "test_results.csv"))
+
+        # One <div> per plot, separated (not terminated) by <hr>.
+        plot_sections = []
+        for plot_name, plot_path in self.plots.items():
+            encoded_image = self.encode_image_to_base64(plot_path)
+            plot_sections.append(f"""
+            <div class="plot">
+                <h3>{plot_name.capitalize()}</h3>
+                <img src="data:image/png;base64,{encoded_image}"
+                    alt="{plot_name}">
+            </div>
+            """)
+        plots_html = "<hr>".join(plot_sections)
+
+        # Falsy entries are skipped but still count towards the numbering.
+        # The payload is kept on the same line as the data-URI prefix so
+        # no whitespace ends up inside the URI.
+        tree_plots = ""
+        for i, tree in enumerate(self.trees):
+            if tree:
+                tree_plots += f"""
+                <div class="plot">
+                    <h3>Tree {i+1}</h3>
+                    <img src="data:image/png;base64,{tree}"
+                    alt="tree {i+1}">
+                </div>
+                """
+
+        analyzer = FeatureImportanceAnalyzer(
+            data=self.data,
+            target_col=self.target_col,
+            task_type=self.task_type,
+            output_dir=self.output_dir)
+        feature_importance_html = analyzer.run()
+
+        # plots_explainer_html stays None unless a subclass fills it in;
+        # render that as nothing rather than the literal text "None".
+        explainer_html = self.plots_explainer_html or ""
+
+        setup_params_html = setup_params_table.to_html(
+            index=False, header=False, classes='table')
+        best_model_html = best_model_params.to_html(
+            index=False, header=False, classes='table')
+        comparison_html = self.results.to_html(index=False, classes='table')
+        test_html = self.test_result_df.to_html(index=False, classes='table')
+
+        html_content = f"""
+        {get_html_template()}
+            <h1>PyCaret Model Training Report</h1>
+            <div class="tabs">
+                <div class="tab" onclick="openTab(event, 'summary')">
+                Setup & Best Model</div>
+                <div class="tab" onclick="openTab(event, 'plots')">
+                Best Model Plots</div>
+                <div class="tab" onclick="openTab(event, 'feature')">
+                Feature Importance</div>
+                <div class="tab" onclick="openTab(event, 'explainer')">
+                Explainer
+                </div>
+            </div>
+            <div id="summary" class="tab-content">
+                <h2>Setup Parameters</h2>
+                <table>
+                    <tr><th>Parameter</th><th>Value</th></tr>
+                    {setup_params_html}
+                </table>
+                <h5>If you want to know all the experiment setup parameters,
+                  please check the PyCaret documentation for
+                  the classification/regression <code>exp</code> function.</h5>
+                <h2>Best Model: {model_name}</h2>
+                <table>
+                    <tr><th>Parameter</th><th>Value</th></tr>
+                    {best_model_html}
+                </table>
+                <h2>Comparison Results on the Cross-Validation Set</h2>
+                <table>
+                    {comparison_html}
+                </table>
+                <h2>Results on the Test Set for the best model</h2>
+                <table>
+                    {test_html}
+                </table>
+            </div>
+            <div id="plots" class="tab-content">
+                <h2>Best Model Plots on the testing set</h2>
+                {plots_html}
+            </div>
+            <div id="feature" class="tab-content">
+                {feature_importance_html}
+            </div>
+            <div id="explainer" class="tab-content">
+                {explainer_html}
+                {tree_plots}
+            </div>
+        {get_html_closing()}
+        """
+
+        report_path = os.path.join(self.output_dir, "comparison_result.html")
+        # Explicit encoding so non-ASCII labels do not depend on the locale.
+        with open(report_path, "w", encoding="utf-8") as file:
+            file.write(html_content)
+
+    def save_dashboard(self):
+        """Hook for subclasses; the base implementation always raises."""
+        raise NotImplementedError("Subclasses should implement this method")
+
+    def generate_plots_explainer(self):
+        """Hook for subclasses; the base implementation always raises."""
+        raise NotImplementedError("Subclasses should implement this method")
+
+    # not working now
+    def generate_tree_plots(self):
+        """Best-effort rendering of the best model's trees.
+
+        Only random-forest and XGBoost models are handled; for any other
+        model, or when an optional dependency is missing, the method logs
+        the reason and returns without touching ``self.trees``.  Each
+        rendered tree is appended to ``self.trees``.  Failures while
+        rendering are logged, never raised.
+        """
+        LOG.info("Generating tree plots")
+        try:
+            from sklearn.ensemble import RandomForestClassifier, \
+                RandomForestRegressor
+            from xgboost import XGBClassifier, XGBRegressor
+            from explainerdashboard.explainers import RandomForestExplainer
+        except ImportError as e:
+            # A missing optional package must not abort the whole run.
+            LOG.error(f"Error generating tree plots: {e}")
+            return
+
+        is_rf = isinstance(
+            self.best_model, (RandomForestClassifier, RandomForestRegressor))
+        is_xgb = isinstance(self.best_model, (XGBClassifier, XGBRegressor))
+        if not (is_rf or is_xgb):
+            # Nothing to draw; without this guard the tree count below
+            # would be undefined.
+            LOG.info("Best model is not tree-based; skipping tree plots")
+            return
+
+        try:
+            X_test = self.exp.X_test_transformed.copy()
+            y_test = self.exp.y_test_transformed
+            if is_xgb:
+                num_trees = len(self.best_model.get_booster().get_dump())
+            else:
+                num_trees = self.best_model.n_estimators
+            # NOTE(review): the same explainer class is used for XGBoost
+            # models -- confirm that explainerdashboard supports this.
+            explainer = RandomForestExplainer(self.best_model, X_test, y_test)
+            for i in range(num_trees):
+                fig = explainer.decisiontree_encoded(tree_idx=i, index=0)
+                LOG.info(f"Tree {i+1}")
+                LOG.info(fig)
+                self.trees.append(fig)
+        except Exception as e:
+            LOG.error(f"Error generating tree plots: {e}")
+
+    def run(self):
+        """Execute the full pipeline, from loading data to the report."""
+        self.load_data()
+        self.setup_pycaret()
+        self.train_model()
+        self.save_model()
+        self.generate_plots()
+        self.generate_plots_explainer()
+        self.generate_tree_plots()
+        self.save_html_report()
+        # Dashboard generation is currently disabled.
+        # self.save_dashboard()