Mercurial > repos > goeckslab > ludwig_train
comparison ludwig_experiment.py @ 1:4d12452c5361 draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit a341ff5627ef7a39489a7f377d96017fb3f42efb
author | goeckslab |
---|---|
date | Thu, 13 Mar 2025 16:43:12 +0000 (43 hours ago) |
parents | f0be10937f5c |
children |
comparison legend: equal, deleted, inserted, replaced
0:f0be10937f5c | 1:4d12452c5361 |
---|---|
1 import json | 1 import json |
2 import logging | 2 import logging |
3 import os | 3 import os |
4 import pickle | 4 import pickle |
5 import sys | 5 import sys |
6 | |
7 from jinja_report import generate_report | |
8 | 6 |
9 from ludwig.experiment import cli | 7 from ludwig.experiment import cli |
10 from ludwig.globals import ( | 8 from ludwig.globals import ( |
11 DESCRIPTION_FILE_NAME, | 9 DESCRIPTION_FILE_NAME, |
12 PREDICTIONS_PARQUET_FILE_NAME, | 10 PREDICTIONS_PARQUET_FILE_NAME, |
23 from utils import ( | 21 from utils import ( |
24 encode_image_to_base64, | 22 encode_image_to_base64, |
25 get_html_closing, | 23 get_html_closing, |
26 get_html_template | 24 get_html_template |
27 ) | 25 ) |
28 | |
29 import yaml | |
30 | 26 |
31 | 27 |
32 logging.basicConfig(level=logging.DEBUG) | 28 logging.basicConfig(level=logging.DEBUG) |
33 | 29 |
34 LOG = logging.getLogger(__name__) | 30 LOG = logging.getLogger(__name__) |
146 except Exception as e: | 142 except Exception as e: |
147 LOG.info(f"Visualization: {viz}") | 143 LOG.info(f"Visualization: {viz}") |
148 LOG.info(f"Error: {e}") | 144 LOG.info(f"Error: {e}") |
149 | 145 |
150 | 146 |
151 # report | |
152 def render_report( | |
153 title: str, | |
154 ludwig_output_directory_name: str, | |
155 show_visualization: bool = True | |
156 ): | |
157 ludwig_output_directory = os.path.join( | |
158 output_directory, | |
159 ludwig_output_directory_name, | |
160 ) | |
161 report_config = { | |
162 "title": title, | |
163 } | |
164 if show_visualization: | |
165 report_config["visualizations"] = [ | |
166 { | |
167 "src": f"visualizations/{fl}", | |
168 "type": "image" if fl[fl.rindex(".") + 1:] == "png" else | |
169 fl[fl.rindex(".") + 1:], | |
170 } for fl in sorted(os.listdir(viz_output_directory)) | |
171 ] | |
172 report_config["raw outputs"] = [ | |
173 { | |
174 "src": f"{fl}", | |
175 "type": "json" if fl.endswith(".json") else "unclassified", | |
176 } for fl in sorted(os.listdir(ludwig_output_directory)) | |
177 if fl.endswith((".json", ".parquet")) | |
178 ] | |
179 | |
180 with open(os.path.join(output_directory, "report_config.yml"), 'w') as fh: | |
181 yaml.safe_dump(report_config, fh) | |
182 | |
183 report_path = os.path.join(output_directory, "smart_report.html") | |
184 generate_report.main( | |
185 report_config, | |
186 schema={"html_height": 800}, | |
187 outfile=report_path, | |
188 ) | |
189 | |
190 | |
191 def convert_parquet_to_csv(ludwig_output_directory_name): | 147 def convert_parquet_to_csv(ludwig_output_directory_name): |
192 """Convert the predictions Parquet file to CSV.""" | 148 """Convert the predictions Parquet file to CSV.""" |
193 ludwig_output_directory = os.path.join( | 149 ludwig_output_directory = os.path.join( |
194 output_directory, ludwig_output_directory_name) | 150 output_directory, ludwig_output_directory_name) |
195 parquet_path = os.path.join( | 151 parquet_path = os.path.join( |
261 cli(sys.argv[1:]) | 217 cli(sys.argv[1:]) |
262 | 218 |
263 ludwig_output_directory_name = "experiment_run" | 219 ludwig_output_directory_name = "experiment_run" |
264 | 220 |
265 make_visualizations(ludwig_output_directory_name) | 221 make_visualizations(ludwig_output_directory_name) |
266 # title = "Ludwig Experiment" | |
267 # render_report(title, ludwig_output_directory_name) | |
268 convert_parquet_to_csv(ludwig_output_directory_name) | 222 convert_parquet_to_csv(ludwig_output_directory_name) |
269 generate_html_report("Ludwig Experiment", ludwig_output_directory_name) | 223 generate_html_report("Ludwig Experiment", ludwig_output_directory_name) |