comparison feature_importance.py @ 0:1f20fe57fdee draft

planemo upload for repository https://github.com/goeckslab/Galaxy-Pycaret commit d79b0f722b7d09505a526d1a4332f87e548a3df1
author goeckslab
date Wed, 11 Dec 2024 04:59:43 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:1f20fe57fdee
1 import base64
2 import logging
3 import os
4
5 import matplotlib.pyplot as plt
6
7 import pandas as pd
8
9 from pycaret.classification import ClassificationExperiment
10 from pycaret.regression import RegressionExperiment
11
# Module-wide logging setup: logging.basicConfig is called with level DEBUG
# as soon as this module is imported, and LOG is the logger named after
# this module that the rest of the file uses.
12 logging.basicConfig(level=logging.DEBUG)
13 LOG = logging.getLogger(__name__)
14
15
class FeatureImportanceAnalyzer:
    """Train PyCaret models on a dataset and report feature importance.

    Two plots are produced and embedded into an HTML fragment:

    * ``tree_importance`` - the ``feature_importances_`` of a Random Forest
      created with ``create_model('rf')``.
    * ``shap_summary`` - a SHAP summary plot for a model created with
      ``create_model('lightgbm')``.
    """

    # Report text per plot name: (heading, sub-heading).
    _PLOT_TEXT = {
        'tree_importance': (
            'Feature importance analysis from a trained Random Forest',
            'Use gini impurity for calculating feature importance for '
            'classification and Variance Reduction for regression',
        ),
    }
    # Any plot name not listed above is described as the SHAP summary.
    _DEFAULT_PLOT_TEXT = ('SHAP Summary from a trained lightgbm', '')

    def __init__(
            self,
            task_type,
            output_dir,
            data_path=None,
            data=None,
            target_col=None):
        """Load the data and pick the experiment type.

        Args:
            task_type: ``'classification'`` selects a
                ``ClassificationExperiment``; any other value selects a
                ``RegressionExperiment``.
            output_dir: Directory the plot images are written to.
            data_path: Path of a delimited text file, read only when
                ``data`` is ``None``. The delimiter is sniffed by pandas.
            data: An already loaded DataFrame. It is used as given: column
                names are not rewritten and missing values are not filled.
            target_col: 1-based position of the target column.

        Raises:
            TypeError: If ``target_col`` is ``None``.
            ValueError: If neither ``data`` nor ``data_path`` is given.
            IndexError: If ``target_col`` is outside ``1..number of columns``.
        """
        if target_col is None:
            raise TypeError(
                "target_col is required (1-based index of the target column)")
        # Stored on both branches so the attribute always exists.
        self.target_col = target_col

        if data is not None:
            self.data = data
            LOG.info("Data loaded from memory")
        else:
            if data_path is None:
                raise ValueError(
                    "Either data or data_path must be provided")
            self.data = pd.read_csv(data_path, sep=None, engine='python')
            # Literal replacement: with regex semantics '.' would match
            # every character, and the default differs across pandas
            # versions.
            self.data.columns = self.data.columns.str.replace(
                '.', '_', regex=False)
            self.data = self.data.fillna(
                self.data.median(numeric_only=True))

        self.task_type = task_type

        target_index = int(target_col)
        n_columns = len(self.data.columns)
        # Reject 0 and negatives explicitly: Python's negative indexing
        # would otherwise silently select a column counted from the end.
        if not 1 <= target_index <= n_columns:
            raise IndexError(
                f"target_col must be between 1 and {n_columns}, "
                f"got {target_index}")
        self.target = self.data.columns[target_index - 1]

        if task_type == 'classification':
            self.exp = ClassificationExperiment()
        else:
            self.exp = RegressionExperiment()
        self.plots = {}
        self.output_dir = output_dir

    def setup_pycaret(self):
        """Run ``setup`` on the experiment with the loaded data."""
        LOG.info("Initializing PyCaret")
        setup_params = {
            'target': self.target,
            'session_id': 123,
            'html': True,
            'log_experiment': False,
            'system_log': False
        }
        LOG.info(self.task_type)
        LOG.info(self.exp)
        self.exp.setup(self.data, **setup_params)

    def _save_current_figure(self, plot_name, file_name):
        """Save the current pyplot figure and register it in ``self.plots``.

        The figure is closed even when saving fails so it does not leak.
        """
        if self.output_dir:
            os.makedirs(self.output_dir, exist_ok=True)
        plot_path = os.path.join(self.output_dir, file_name)
        try:
            plt.savefig(plot_path)
        finally:
            plt.close()
        self.plots[plot_name] = plot_path

    def save_tree_importance(self):
        """Plot Random Forest feature importances to ``tree_importance.png``."""
        model = self.exp.create_model('rf')
        importances = model.feature_importances_
        processed_features = self.exp.get_config('X_transformed').columns
        LOG.debug("Feature importances: %s", importances)
        LOG.debug("Features: %s", processed_features)
        feature_importances = pd.DataFrame({
            'Feature': processed_features,
            'Importance': importances
        }).sort_values(by='Importance', ascending=False)
        plt.figure(figsize=(10, 6))
        plt.barh(
            feature_importances['Feature'],
            feature_importances['Importance'])
        plt.xlabel('Importance')
        plt.title('Feature Importance (Random Forest)')
        self._save_current_figure('tree_importance', 'tree_importance.png')

    def save_shap_values(self):
        """Plot a SHAP summary for a LightGBM model to ``shap_summary.png``."""
        model = self.exp.create_model('lightgbm')
        # Imported lazily, as in the original, so shap is only needed here.
        import shap
        transformed = self.exp.get_config('X_transformed')
        explainer = shap.Explainer(model)
        shap_values = explainer.shap_values(transformed)
        shap.summary_plot(shap_values, transformed, show=False)
        plt.title('Shap (LightGBM)')
        self._save_current_figure('shap_summary', 'shap_summary.png')

    def generate_feature_importance(self):
        """Create both plots; their paths end up in ``self.plots``."""
        self.save_tree_importance()
        self.save_shap_values()

    def encode_image_to_base64(self, img_path):
        """Return the content of ``img_path`` as a base64 text string."""
        with open(img_path, 'rb') as img_file:
            return base64.b64encode(img_file.read()).decode('utf-8')

    def _render_plot_html(self, plot_name, plot_path):
        """Return the ``<div>`` that embeds one plot image inline."""
        heading, subheading = self._PLOT_TEXT.get(
            plot_name, self._DEFAULT_PLOT_TEXT)
        encoded_image = self.encode_image_to_base64(plot_path)
        # The data URI is kept on a single line: no whitespace may sneak in
        # between "base64," and the payload.
        return (
            f'<div class="plot" id="{plot_name}">\n'
            f'<h2>{heading}</h2>\n'
            f'<h3>{subheading}</h3>\n'
            f'<img src="data:image/png;base64,{encoded_image}" '
            f'alt="{plot_name}">\n'
            f'</div>\n'
        )

    def generate_html_report(self):
        """Return an HTML fragment with a title and every saved plot.

        Plots appear in the order they were added to ``self.plots``.
        """
        LOG.info("Generating HTML report")

        plots_html = "".join(
            self._render_plot_html(plot_name, plot_path)
            for plot_name, plot_path in self.plots.items())

        html_content = (
            "\n<h1>PyCaret Feature Importance Report</h1>\n"
            f"{plots_html}\n"
        )

        return html_content

    def run(self):
        """Set up PyCaret, build the plots and return the HTML fragment."""
        LOG.info("Running feature importance analysis")
        self.setup_pycaret()
        self.generate_feature_importance()
        html_content = self.generate_html_report()
        LOG.info("Feature importance analysis completed")
        return html_content
148
149
def build_arg_parser():
    """Return the command-line parser for the feature importance script."""
    import argparse
    parser = argparse.ArgumentParser(description="Feature Importance Analysis")
    parser.add_argument(
        "--data_path", type=str, required=True,
        help="Path to the dataset")
    parser.add_argument(
        "--target_col", type=int, required=True,
        help="Index of the target column (1-based)")
    parser.add_argument(
        "--task_type", type=str, required=True,
        choices=["classification", "regression"],
        help="Task type: classification or regression")
    parser.add_argument(
        "--output_dir", type=str, required=True,
        help="Directory to save the outputs")
    return parser


def main(argv=None):
    """Parse ``argv`` (default: ``sys.argv[1:]``) and run the analysis."""
    args = build_arg_parser().parse_args(argv)

    # Keyword arguments on purpose: the constructor's positional order is
    # (task_type, output_dir, data_path, data, target_col), which does not
    # match the order of the command-line options.
    analyzer = FeatureImportanceAnalyzer(
        task_type=args.task_type,
        output_dir=args.output_dir,
        data_path=args.data_path,
        target_col=args.target_col)
    # NOTE(review): the HTML returned by run() is not persisted here, same
    # as before; only the plot images are written to output_dir.
    analyzer.run()


if __name__ == "__main__":
    main()
171 analyzer.run()