comparison ludwig_hyperopt.py @ 0:70a4d910f09a draft default tip

planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit bdea9430787658783a51cc6c2ae951a01e455bb4
author goeckslab
date Tue, 07 Jan 2025 22:46:16 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:70a4d910f09a
1 import logging
2 import os
3 import pickle
4 import sys
5
6 from ludwig.globals import (
7 HYPEROPT_STATISTICS_FILE_NAME,
8 )
9 from ludwig.hyperopt_cli import cli
10 from ludwig.visualize import get_visualizations_registry
11
12 from model_unpickler import SafeUnpickler
13
14 from utils import (
15 encode_image_to_base64,
16 get_html_closing,
17 get_html_template
18 )
19
# Configure the root logger at DEBUG so the whole run, including library
# loggers that propagate to root, is logged verbosely.
logging.basicConfig(level=logging.DEBUG)

LOG = logging.getLogger(__name__)

# Replace pickle's Unpickler class with the project's SafeUnpickler. Only
# code that resolves ``pickle.Unpickler`` at call time picks this up.
# NOTE(review): module-level ``pickle.load``/``pickle.loads`` are normally
# bound to the C implementation and may bypass this attribute -- confirm
# which model-loading paths actually go through SafeUnpickler.
pickle.Unpickler = SafeUnpickler

# Hand this script's arguments (program name excluded) to Ludwig's
# hyperopt command-line entry point.
cli(sys.argv[1:])
27
28
def generate_html_report(title, *, viz_dir=None, report_dir=None):
    """Write a self-contained HTML report of the hyperopt visualizations.

    The report contains, in order: the HTML template header, ``title`` as a
    heading, the contents of ``hyperopt_hiplot.html`` (when present), every
    PNG/JPEG image in the visualization directory inlined as a base64 data
    URI (sorted by file name), and the HTML closing. It is written to
    ``<report_dir>/<title lower-cased, spaces -> underscores>_report.html``.

    Args:
        title: Report heading; also used to derive the report file name.
        viz_dir: Directory holding the visualization outputs. Defaults to
            the module-level ``viz_output_directory``.
        report_dir: Directory the report is written into. Defaults to the
            module-level ``output_directory``.

    Raises:
        OSError: If a visualization cannot be read or the report cannot be
            written.
    """
    import html  # function-local so the file's import block stays untouched

    if viz_dir is None:
        viz_dir = viz_output_directory
    if report_dir is None:
        report_dir = output_directory

    # Data-URI MIME type per supported image extension (matched
    # case-insensitively); a JPEG must not be labelled image/png.
    mime_types = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
    }

    sections = []

    # HiPlot output is already HTML, so it is embedded verbatim.
    hyperopt_hiplot_path = os.path.join(viz_dir, "hyperopt_hiplot.html")
    if os.path.isfile(hyperopt_hiplot_path):
        with open(hyperopt_hiplot_path, "r", encoding="utf-8") as file:
            hyperopt_hiplot_html = file.read()
        sections.append(f'<div class="hiplot">{hyperopt_hiplot_html}</div>')
        sections.append("uid is the identifier for different hyperopt runs")
        sections.append("<br><br>")

    if os.path.isdir(viz_dir):
        plot_files = sorted(os.listdir(viz_dir))
    else:
        # A report without images is more useful than no report at all.
        LOG.warning("Visualization directory not found: %s", viz_dir)
        plot_files = []

    for plot_file in plot_files:
        plot_path = os.path.join(viz_dir, plot_file)
        stem, ext = os.path.splitext(plot_file)
        mime_type = mime_types.get(ext.lower())
        if mime_type is None or not os.path.isfile(plot_path):
            continue
        encoded_image = encode_image_to_base64(plot_path)
        sections.append(
            f'<div class="plot">'
            f'<h3>{html.escape(stem)}</h3>'
            f'<img src="data:{mime_type};base64,'
            f'{encoded_image}" alt="{html.escape(plot_file)}">'
            f'</div>'
        )

    html_content = "\n".join([
        get_html_template(),
        f"<h1>{html.escape(title)}</h1>",
        "<h2>Visualizations</h2>",
        "".join(sections),
        get_html_closing(),
    ])

    report_name = title.lower().replace(" ", "_")
    report_path = os.path.join(report_dir, f"{report_name}_report.html")
    # Explicit UTF-8: the embedded HiPlot page is read as UTF-8 and may hold
    # characters the platform's default encoding cannot represent.
    with open(report_path, "w", encoding="utf-8") as report_file:
        report_file.write(html_content)

    LOG.info("HTML report generated at: %s", report_path)
81
82
# visualization
def _get_cli_option(argv, names, default=None):
    """Return the value passed for any spelling of a CLI option.

    Accepts both ``NAME VALUE`` and ``NAME=VALUE`` forms for every spelling
    in ``names``. The last occurrence wins, as with argparse. Returns
    ``default`` when the option is absent or given without a value.
    """
    value = default
    for ix, arg in enumerate(argv):
        for name in names:
            if arg == name and ix + 1 < len(argv):
                value = argv[ix + 1]
            elif arg.startswith(name + "="):
                value = arg[len(name) + 1:]
    return value


# The hyperopt statistics file is expected under
# <output_directory>/<experiment_name>/. The defaults below are assumed to
# mirror ``ludwig hyperopt``'s own CLI defaults -- TODO confirm when
# upgrading Ludwig.
output_directory = _get_cli_option(
    sys.argv[1:], ("-od", "--output_directory"), default="results")
experiment_name = _get_cli_option(
    sys.argv[1:], ("--experiment_name",), default="hyperopt")

hyperopt_stats_path = os.path.join(
    output_directory, experiment_name, HYPEROPT_STATISTICS_FILE_NAME
)

visualizations = ["hyperopt_report", "hyperopt_hiplot"]

viz_output_directory = os.path.join(output_directory, "visualizations")
viz_registry = get_visualizations_registry()
for viz in visualizations:
    viz_registry[viz](
        hyperopt_stats_path=hyperopt_stats_path,
        output_directory=viz_output_directory,
        file_format="png",
    )

# report
title = "Ludwig Hyperopt"
generate_html_report(title)
99 viz_func = get_visualizations_registry()[viz]
100 viz_func(
101 hyperopt_stats_path=hyperopt_stats_path,
102 output_directory=viz_output_directory,
103 file_format="png",
104 )
105
106 # report
107 title = "Ludwig Hyperopt"
108 generate_html_report(title)