Mercurial > repos > goeckslab > ludwig_train
diff ludwig_hyperopt.py @ 0:f0be10937f5c draft default tip
planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit bdea9430787658783a51cc6c2ae951a01e455bb4
author | goeckslab |
---|---|
date | Tue, 07 Jan 2025 22:44:09 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/ludwig_hyperopt.py Tue Jan 07 22:44:09 2025 +0000 @@ -0,0 +1,108 @@ +import logging +import os +import pickle +import sys + +from ludwig.globals import ( + HYPEROPT_STATISTICS_FILE_NAME, +) +from ludwig.hyperopt_cli import cli +from ludwig.visualize import get_visualizations_registry + +from model_unpickler import SafeUnpickler + +from utils import ( + encode_image_to_base64, + get_html_closing, + get_html_template +) + +logging.basicConfig(level=logging.DEBUG) + +LOG = logging.getLogger(__name__) + +setattr(pickle, 'Unpickler', SafeUnpickler) + +cli(sys.argv[1:]) + + +def generate_html_report(title): + + # Read test statistics JSON and convert to HTML table + # try: + # test_statistics_path = hyperopt_stats_path + # with open(test_statistics_path, "r") as f: + # test_statistics = json.load(f) + # test_statistics_html = "<h2>Hyperopt Statistics</h2>" + # test_statistics_html += json_to_html_table(test_statistics) + # except Exception as e: + # LOG.info(f"Error reading hyperopt statistics: {e}") + + plots_html = "" + # Convert visualizations to HTML + hyperopt_hiplot_path = os.path.join( + viz_output_directory, "hyperopt_hiplot.html") + if os.path.isfile(hyperopt_hiplot_path): + with open(hyperopt_hiplot_path, "r", encoding="utf-8") as file: + hyperopt_hiplot_html = file.read() + plots_html += f'<div class="hiplot">{hyperopt_hiplot_html}</div>' + plots_html += "uid is the identifier for different hyperopt runs" + plots_html += "<br><br>" + + # Iterate through other files in viz_output_directory + for plot_file in sorted(os.listdir(viz_output_directory)): + plot_path = os.path.join(viz_output_directory, plot_file) + if os.path.isfile(plot_path) and plot_file.endswith((".png", ".jpg")): + encoded_image = encode_image_to_base64(plot_path) + plots_html += ( + f'<div class="plot">' + f'<h3>{os.path.splitext(plot_file)[0]}</h3>' + '<img src="data:image/png;base64,' + f'{encoded_image}" alt="{plot_file}">' + f'</div>' + ) + + # Generate the full HTML content + html_content = f""" + {get_html_template()} + <h1>{title}</h1> + <h2>Visualizations</h2> + {plots_html} + {get_html_closing()} + """ + + # Save the HTML report + report_name = title.lower().replace(" ", "_") + report_path = os.path.join(output_directory, f"{report_name}_report.html") + with open(report_path, "w") as report_file: + report_file.write(html_content) + + LOG.info(f"HTML report generated at: {report_path}") + + +# visualization +output_directory = None +for ix, arg in enumerate(sys.argv): + if arg == "--output_directory": + output_directory = sys.argv[ix+1] + break + +hyperopt_stats_path = os.path.join( + output_directory, + "hyperopt", HYPEROPT_STATISTICS_FILE_NAME +) + +visualizations = ["hyperopt_report", "hyperopt_hiplot"] + +viz_output_directory = os.path.join(output_directory, "visualizations") +for viz in visualizations: + viz_func = get_visualizations_registry()[viz] + viz_func( + hyperopt_stats_path=hyperopt_stats_path, + output_directory=viz_output_directory, + file_format="png", + ) + +# report +title = "Ludwig Hyperopt" +generate_html_report(title)