comparison ludwig_experiment.py @ 0:dceb8493730d draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit bdea9430787658783a51cc6c2ae951a01e455bb4
author goeckslab
date Tue, 07 Jan 2025 22:44:54 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:dceb8493730d
1 import json
2 import logging
3 import os
4 import pickle
5 import sys
6
7 from jinja_report import generate_report
8
9 from ludwig.experiment import cli
10 from ludwig.globals import (
11 DESCRIPTION_FILE_NAME,
12 PREDICTIONS_PARQUET_FILE_NAME,
13 TEST_STATISTICS_FILE_NAME,
14 TRAIN_SET_METADATA_FILE_NAME
15 )
16 from ludwig.utils.data_utils import get_split_path
17 from ludwig.visualize import get_visualizations_registry
18
19 from model_unpickler import SafeUnpickler
20
21 import pandas as pd
22
23 from utils import (
24 encode_image_to_base64,
25 get_html_closing,
26 get_html_template
27 )
28
29 import yaml
30
31
LOG = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

# Swap pickle's Unpickler class for the project's SafeUnpickler so that code
# which looks up and instantiates ``pickle.Unpickler`` at call time gets the
# restricted implementation.
# NOTE(review): ``pickle.load`` / ``pickle.loads`` are module functions that
# do not read the ``pickle.Unpickler`` attribute, so this patch presumably
# does not cover them -- verify which loader Ludwig's model loading uses.
pickle.Unpickler = SafeUnpickler
37
# visualization
def _parse_output_directory(argv, default="results"):
    """Return the output directory given on the command line.

    Recognizes ``-od VALUE``, ``--output_directory VALUE`` and the
    ``--opt=VALUE`` spellings of both. As with argparse, the last occurrence
    wins. An option with no following value is ignored. Prefix abbreviations
    such as ``--output_dir`` are NOT recognized.

    :param argv: Command-line arguments, without the program name.
    :param default: Value returned when no output directory is given. The
        ``"results"`` default presumably mirrors Ludwig's experiment CLI
        default; verify against the installed Ludwig version.
    :return: The output directory path.
    """
    found = default
    for ix, arg in enumerate(argv):
        if arg in ("-od", "--output_directory"):
            # Guard against a trailing flag with no value; the old code
            # raised IndexError here at import time.
            if ix + 1 < len(argv):
                found = argv[ix + 1]
        elif arg.startswith(("-od=", "--output_directory=")):
            found = arg.split("=", 1)[1]
    return found


# Resolved once at import time. The old code crashed with TypeError
# (os.path.join(None, ...)) when the option was absent.
output_directory = _parse_output_directory(sys.argv[1:])
viz_output_directory = os.path.join(output_directory, "visualizations")
46
47
def get_output_feature_name(experiment_dir, output_feature=0):
    """Read an output feature name and the dataset path from the description.

    :param experiment_dir: Path to the experiment directory.
    :param output_feature: Index into ``config["output_features"]`` of the
        feature whose name is returned (defaults to the first one).
    :return: ``(output_feature_name, dataset_path)``. Each element is
        ``None`` when it cannot be determined from the description file,
        and both are ``None`` when that file does not exist.
    """
    description_file = os.path.join(experiment_dir, DESCRIPTION_FILE_NAME)
    if not os.path.exists(description_file):
        return None, None

    with open(description_file, "rb") as f:
        content = json.load(f)

    # "dataset" is not guaranteed to be present (presumably it is absent
    # when the experiment was given separate train/val/test sets -- verify).
    # Downstream code treats the path as a string, so anything else is
    # reported as unknown.
    dataset_path = content.get("dataset")
    if not isinstance(dataset_path, str):
        dataset_path = None

    # A missing or short feature list yields None rather than an exception.
    # make_visualizations() falls back to the test statistics in that case.
    try:
        output_feature_name = (
            content["config"]["output_features"][output_feature]["name"]
        )
    except (KeyError, IndexError, TypeError):
        output_feature_name = None

    return output_feature_name, dataset_path
65
66
def check_file(file_path):
    """Return ``file_path`` unchanged if it exists on disk, else ``None``."""
    if not os.path.exists(file_path):
        return None
    return file_path
70
71
def make_visualizations(ludwig_output_directory_name):
    """Render Ludwig's standard visualizations for a finished experiment.

    Each visualization is attempted independently. One that raises, or
    whose name is missing from the installed Ludwig's visualization
    registry, is logged and skipped, so the remaining plots are still
    produced. Output files (PNG) are written to the module-level
    ``viz_output_directory``.

    :param ludwig_output_directory_name: Name of the experiment run
        directory, relative to the module-level ``output_directory``.
    """
    ludwig_output_directory = os.path.join(
        output_directory,
        ludwig_output_directory_name,
    )
    visualizations = [
        "confidence_thresholding",
        "confidence_thresholding_data_vs_acc",
        "confidence_thresholding_data_vs_acc_subset",
        "confidence_thresholding_data_vs_acc_subset_per_class",
        "confidence_thresholding_2thresholds_2d",
        "confidence_thresholding_2thresholds_3d",
        "binary_threshold_vs_metric",
        "roc_curves",
        "roc_curves_from_test_statistics",
        "calibration_1_vs_all",
        "calibration_multiclass",
        "confusion_matrix",
        "frequency_vs_f1",
        "learning_curves",
    ]

    # Create the target directory up front so it exists even when every
    # visualization fails; report generation lists it afterwards.
    os.makedirs(viz_output_directory, exist_ok=True)

    # Inputs the visualizers may use; each is None when the file is absent.
    training_statistics = check_file(os.path.join(
        ludwig_output_directory,
        "training_statistics.json",
    ))
    test_statistics = check_file(os.path.join(
        ludwig_output_directory,
        TEST_STATISTICS_FILE_NAME,
    ))
    ground_truth_metadata = check_file(os.path.join(
        ludwig_output_directory,
        "model",
        TRAIN_SET_METADATA_FILE_NAME,
    ))
    probabilities = check_file(os.path.join(
        ludwig_output_directory,
        PREDICTIONS_PARQUET_FILE_NAME,
    ))

    output_feature, dataset_path = get_output_feature_name(
        ludwig_output_directory)
    ground_truth = None
    split_file = None
    if dataset_path:
        ground_truth = check_file(dataset_path)
        split_file = check_file(get_split_path(dataset_path))

    # Without a feature name from the description file, fall back to the
    # first top-level key of the test statistics file.
    if not output_feature and test_statistics:
        with open(test_statistics, "rb") as f:
            content = json.load(f)
        # The None default avoids StopIteration on an empty statistics dict.
        output_feature = next(iter(content), None)

    # Look the registry up once, outside the loop.
    registry = get_visualizations_registry()
    for viz in visualizations:
        try:
            # The lookup is inside the try so that a name unknown to this
            # Ludwig version is logged and skipped like any other failure,
            # instead of aborting the remaining visualizations.
            viz_func = registry[viz]
            viz_func(
                training_statistics=[training_statistics]
                if training_statistics else [],
                test_statistics=[test_statistics] if test_statistics else [],
                probabilities=[probabilities] if probabilities else [],
                top_n_classes=[0],
                output_feature_name=output_feature if output_feature else "",
                # presumably selects the test split in Ludwig's
                # 0/1/2 = train/validation/test convention -- verify
                ground_truth_split=2,
                top_k=3,
                ground_truth_metadata=ground_truth_metadata,
                ground_truth=ground_truth,
                split_file=split_file,
                output_directory=viz_output_directory,
                normalize=False,
                file_format="png",
            )
        except Exception as e:
            LOG.info(f"Visualization: {viz}")
            LOG.info(f"Error: {e}")
149
150
# report
def _report_entry_type(file_name):
    """Return the report ``type`` value for a visualization file name.

    PNG files map to ``"image"`` and any other extension is used verbatim.
    A name without a ``.`` maps to ``"unclassified"``; the previous
    ``str.rindex`` logic raised ValueError for such names.
    """
    _, dot, extension = file_name.rpartition(".")
    if not dot:
        return "unclassified"
    return "image" if extension == "png" else extension


def render_report(
    title: str,
    ludwig_output_directory_name: str,
    show_visualization: bool = True
):
    """Build a report config and render it with ``jinja_report``.

    Writes ``report_config.yml`` and ``smart_report.html`` into the
    module-level ``output_directory``.

    :param title: Report title.
    :param ludwig_output_directory_name: Name of the experiment run
        directory, relative to ``output_directory``. Its ``.json`` and
        ``.parquet`` files are listed under "raw outputs".
    :param show_visualization: When True, list the files found in the
        module-level ``viz_output_directory``.
    """
    ludwig_output_directory = os.path.join(
        output_directory,
        ludwig_output_directory_name,
    )
    report_config = {
        "title": title,
    }
    if show_visualization:
        # A missing visualization directory yields an empty list instead of
        # a FileNotFoundError.
        viz_files = (
            sorted(os.listdir(viz_output_directory))
            if os.path.isdir(viz_output_directory) else []
        )
        report_config["visualizations"] = [
            {
                "src": f"visualizations/{fl}",
                "type": _report_entry_type(fl),
            } for fl in viz_files
        ]
    report_config["raw outputs"] = [
        {
            "src": f"{fl}",
            "type": "json" if fl.endswith(".json") else "unclassified",
        } for fl in sorted(os.listdir(ludwig_output_directory))
        if fl.endswith((".json", ".parquet"))
    ]

    with open(os.path.join(output_directory, "report_config.yml"), 'w') as fh:
        yaml.safe_dump(report_config, fh)

    report_path = os.path.join(output_directory, "smart_report.html")
    generate_report.main(
        report_config,
        schema={"html_height": 800},
        outfile=report_path,
    )
189
190
def convert_parquet_to_csv(ludwig_output_directory_name):
    """Convert the predictions Parquet file to CSV.

    Best effort: problems are logged and never raised, so a missing or
    unreadable predictions file does not abort the rest of the run. The CSV
    is written next to the Parquet file as ``predictions_parquet.csv``.

    :param ludwig_output_directory_name: Name of the experiment run
        directory, relative to the module-level ``output_directory``.
    """
    ludwig_output_directory = os.path.join(
        output_directory, ludwig_output_directory_name)
    # Use the same Ludwig constant as make_visualizations() instead of a
    # hard-coded name, so both always refer to the same file.
    parquet_path = os.path.join(
        ludwig_output_directory, PREDICTIONS_PARQUET_FILE_NAME)
    csv_path = os.path.join(
        ludwig_output_directory, "predictions_parquet.csv")

    if not os.path.exists(parquet_path):
        LOG.warning(f"Predictions Parquet file not found: {parquet_path}")
        return

    try:
        df = pd.read_parquet(parquet_path)
        df.to_csv(csv_path, index=False)
    except Exception as e:
        LOG.error(f"Error converting Parquet to CSV: {e}")
    else:
        LOG.info(f"Converted Parquet to CSV: {csv_path}")
206
207
def _build_plots_html():
    """Return an HTML fragment embedding the visualization images.

    Files in the module-level ``viz_output_directory`` with a ``.png``,
    ``.jpg`` or ``.jpeg`` extension are inlined as base64 data URIs, in
    sorted file-name order. The ``<h2>`` heading is emitted whenever the
    directory has any entries. An empty string is returned when the
    directory is empty or does not exist.
    """
    import html

    if not os.path.isdir(viz_output_directory):
        return ""
    entries = sorted(os.listdir(viz_output_directory))
    if not entries:
        return ""

    # Declare the MIME type that matches the file, rather than labelling
    # every image (JPEGs included) as PNG.
    mime_types = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
    }
    parts = ["<h2>Visualizations</h2>"]
    for plot_file in entries:
        plot_path = os.path.join(viz_output_directory, plot_file)
        stem, extension = os.path.splitext(plot_file)
        mime_type = mime_types.get(extension)
        if mime_type is None or not os.path.isfile(plot_path):
            continue
        encoded_image = encode_image_to_base64(plot_path)
        # File names may contain characters that are special in HTML
        # (presumably derived from user-chosen feature names), so escape
        # them before embedding them in markup.
        parts.append(
            '<div class="plot">'
            f'<h3>{html.escape(stem)}</h3>'
            f'<img src="data:{mime_type};base64,{encoded_image}" '
            f'alt="{html.escape(plot_file)}">'
            '</div>'
        )
    return "".join(parts)


def generate_html_report(title: str, ludwig_output_directory_name):
    """Write a standalone HTML report embedding the visualization images.

    The report is saved in the module-level ``output_directory`` as
    ``<report_name>_report.html``, where ``report_name`` is ``title``
    lower-cased with spaces replaced by underscores.

    :param title: Report heading; also used to derive the file name.
    :param ludwig_output_directory_name: Name of the experiment run
        directory. Unused by the current implementation; kept so existing
        callers keep working.
    """
    import html

    plots_html = _build_plots_html()

    html_content = "\n".join([
        "",
        get_html_template(),
        f"<h1>{html.escape(title)}</h1>",
        plots_html,
        get_html_closing(),
        "",
    ])

    report_name = title.lower().replace(" ", "_")
    report_path = os.path.join(output_directory, f"{report_name}_report.html")
    # Use explicit UTF-8 so the output does not depend on the platform's
    # default encoding.
    with open(report_path, "w", encoding="utf-8") as report_file:
        report_file.write(html_content)

    LOG.info(f"HTML report generated at: {report_path}")
257
258
if __name__ == "__main__":
    # Run the Ludwig experiment itself with this script's arguments.
    cli(sys.argv[1:])

    # Post-process the run directory. NOTE(review): "experiment_run" is
    # hard-coded; it presumably matches Ludwig's default experiment/model
    # names in a fresh output directory -- verify.
    run_directory_name = "experiment_run"
    report_title = "Ludwig Experiment"

    make_visualizations(run_directory_name)
    convert_parquet_to_csv(run_directory_name)
    generate_html_report(report_title, run_directory_name)