Mercurial > repos > goeckslab > ludwig_experiment
comparison ludwig_experiment.py @ 0:78e6686a218e draft
planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit bdea9430787658783a51cc6c2ae951a01e455bb4
| author | goeckslab |
|---|---|
| date | Tue, 07 Jan 2025 22:45:39 +0000 |
| parents | |
| children | 44267c11e02b |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:78e6686a218e |
|---|---|
| 1 import json | |
| 2 import logging | |
| 3 import os | |
| 4 import pickle | |
| 5 import sys | |
| 6 | |
| 7 from jinja_report import generate_report | |
| 8 | |
| 9 from ludwig.experiment import cli | |
| 10 from ludwig.globals import ( | |
| 11 DESCRIPTION_FILE_NAME, | |
| 12 PREDICTIONS_PARQUET_FILE_NAME, | |
| 13 TEST_STATISTICS_FILE_NAME, | |
| 14 TRAIN_SET_METADATA_FILE_NAME | |
| 15 ) | |
| 16 from ludwig.utils.data_utils import get_split_path | |
| 17 from ludwig.visualize import get_visualizations_registry | |
| 18 | |
| 19 from model_unpickler import SafeUnpickler | |
| 20 | |
| 21 import pandas as pd | |
| 22 | |
| 23 from utils import ( | |
| 24 encode_image_to_base64, | |
| 25 get_html_closing, | |
| 26 get_html_template | |
| 27 ) | |
| 28 | |
| 29 import yaml | |
| 30 | |
| 31 | |
# Configure the root logger so DEBUG-and-above records are emitted.
logging.basicConfig(level=logging.DEBUG)

# Logger shared by the helpers in this module.
LOG = logging.getLogger(__name__)

# Swap the pickle module's Unpickler class for the project's SafeUnpickler.
# NOTE(review): this only affects code that looks up ``pickle.Unpickler`` at
# call time; ``pickle.load``/``pickle.loads`` do not go through this name.
pickle.Unpickler = SafeUnpickler
| 37 | |
| 38 # visualization | |
| 39 output_directory = None | |
| 40 for ix, arg in enumerate(sys.argv): | |
| 41 if arg == "--output_directory": | |
| 42 output_directory = sys.argv[ix+1] | |
| 43 break | |
| 44 | |
| 45 viz_output_directory = os.path.join(output_directory, "visualizations") | |
| 46 | |
| 47 | |
def get_output_feature_name(experiment_dir, output_feature=0):
    """Read an output feature name and the dataset path from description.json.

    :param experiment_dir: Path to the Ludwig experiment directory.
    :param output_feature: index into ``config["output_features"]`` of the
        feature whose name is returned (defaults to the first one).
    :return: tuple ``(output_feature_name, dataset_path)``. Both are ``None``
        when the description file does not exist; ``dataset_path`` is
        ``None`` when the description has no ``"dataset"`` entry.
    """
    description_file = os.path.join(experiment_dir, DESCRIPTION_FILE_NAME)
    if not os.path.exists(description_file):
        return None, None

    # Binary mode is fine here: json.load detects the encoding of bytes input.
    with open(description_file, "rb") as f:
        content = json.load(f)

    output_feature_name = (
        content["config"]["output_features"][output_feature]["name"]
    )
    # NOTE(review): "dataset" is presumably absent when the experiment was run
    # with separate training/validation/test files; return None instead of
    # raising KeyError so callers can skip dataset-dependent work.
    dataset_path = content.get("dataset")
    return output_feature_name, dataset_path
| 65 | |
| 66 | |
def check_file(file_path):
    """Return *file_path* unchanged if it exists on disk, otherwise None."""
    if not os.path.exists(file_path):
        return None
    return file_path
| 70 | |
| 71 | |
def make_visualizations(ludwig_output_directory_name):
    """Render Ludwig's built-in visualizations for one experiment run.

    Locates the artifacts written under
    ``<output_directory>/<ludwig_output_directory_name>``: training and test
    statistics, training-set metadata and predictions. It then calls each
    visualization listed below with whatever inputs exist. Every
    visualization is best-effort: an unknown name or a failing call is
    logged, and the remaining visualizations still run. Output is written
    as PNG into ``viz_output_directory``.

    :param ludwig_output_directory_name: name of the Ludwig run directory
        inside ``output_directory`` (e.g. ``"experiment_run"``).
    """
    ludwig_output_directory = os.path.join(
        output_directory,
        ludwig_output_directory_name,
    )
    visualizations = [
        "confidence_thresholding",
        "confidence_thresholding_data_vs_acc",
        "confidence_thresholding_data_vs_acc_subset",
        "confidence_thresholding_data_vs_acc_subset_per_class",
        "confidence_thresholding_2thresholds_2d",
        "confidence_thresholding_2thresholds_3d",
        "binary_threshold_vs_metric",
        "roc_curves",
        "roc_curves_from_test_statistics",
        "calibration_1_vs_all",
        "calibration_multiclass",
        "confusion_matrix",
        "frequency_vs_f1",
        "learning_curves",
    ]

    # Each artifact path is None when the file does not exist.
    training_statistics = check_file(os.path.join(
        ludwig_output_directory,
        "training_statistics.json",
    ))
    test_statistics = check_file(os.path.join(
        ludwig_output_directory,
        TEST_STATISTICS_FILE_NAME,
    ))
    ground_truth_metadata = check_file(os.path.join(
        ludwig_output_directory,
        "model",
        TRAIN_SET_METADATA_FILE_NAME,
    ))
    probabilities = check_file(os.path.join(
        ludwig_output_directory,
        PREDICTIONS_PARQUET_FILE_NAME,
    ))

    output_feature, dataset_path = get_output_feature_name(
        ludwig_output_directory)
    ground_truth = None
    split_file = None
    if dataset_path:
        ground_truth = check_file(dataset_path)
        split_file = check_file(get_split_path(dataset_path))

    # Without a feature name from description.json, fall back to the first
    # top-level key of the test statistics (None if that mapping is empty,
    # instead of raising StopIteration).
    if not output_feature and test_statistics:
        with open(test_statistics, "rb") as f:
            content = json.load(f)
        output_feature = next(iter(content), None)

    # Build the registry once rather than once per visualization.
    registry = get_visualizations_registry()
    for viz in visualizations:
        viz_func = registry.get(viz)
        if viz_func is None:
            # Skip unknown names instead of aborting every remaining
            # visualization with a KeyError.
            LOG.warning(f"Visualization not available: {viz}")
            continue
        try:
            viz_func(
                training_statistics=[training_statistics]
                if training_statistics else [],
                test_statistics=[test_statistics] if test_statistics else [],
                probabilities=[probabilities] if probabilities else [],
                top_n_classes=[0],
                output_feature_name=output_feature if output_feature else "",
                ground_truth_split=2,
                top_k=3,
                ground_truth_metadata=ground_truth_metadata,
                ground_truth=ground_truth,
                split_file=split_file,
                output_directory=viz_output_directory,
                normalize=False,
                file_format="png",
            )
        except Exception as e:
            # Best-effort: presumably some visualizations fail when their
            # inputs are missing or do not fit the feature type; keep going.
            LOG.info(f"Visualization: {viz}")
            LOG.info(f"Error: {e}")
| 149 | |
| 150 | |
| 151 # report | |
| 152 def render_report( | |
| 153 title: str, | |
| 154 ludwig_output_directory_name: str, | |
| 155 show_visualization: bool = True | |
| 156 ): | |
| 157 ludwig_output_directory = os.path.join( | |
| 158 output_directory, | |
| 159 ludwig_output_directory_name, | |
| 160 ) | |
| 161 report_config = { | |
| 162 "title": title, | |
| 163 } | |
| 164 if show_visualization: | |
| 165 report_config["visualizations"] = [ | |
| 166 { | |
| 167 "src": f"visualizations/{fl}", | |
| 168 "type": "image" if fl[fl.rindex(".") + 1:] == "png" else | |
| 169 fl[fl.rindex(".") + 1:], | |
| 170 } for fl in sorted(os.listdir(viz_output_directory)) | |
| 171 ] | |
| 172 report_config["raw outputs"] = [ | |
| 173 { | |
| 174 "src": f"{fl}", | |
| 175 "type": "json" if fl.endswith(".json") else "unclassified", | |
| 176 } for fl in sorted(os.listdir(ludwig_output_directory)) | |
| 177 if fl.endswith((".json", ".parquet")) | |
| 178 ] | |
| 179 | |
| 180 with open(os.path.join(output_directory, "report_config.yml"), 'w') as fh: | |
| 181 yaml.safe_dump(report_config, fh) | |
| 182 | |
| 183 report_path = os.path.join(output_directory, "smart_report.html") | |
| 184 generate_report.main( | |
| 185 report_config, | |
| 186 schema={"html_height": 800}, | |
| 187 outfile=report_path, | |
| 188 ) | |
| 189 | |
| 190 | |
def convert_parquet_to_csv(ludwig_output_directory_name):
    """Convert the run's predictions Parquet file to CSV.

    Reads Ludwig's predictions file (``PREDICTIONS_PARQUET_FILE_NAME``) from
    ``<output_directory>/<ludwig_output_directory_name>``. It writes
    ``predictions_parquet.csv`` next to it, without the DataFrame index.
    Best-effort: a missing input file or a conversion error is logged, never
    raised.

    :param ludwig_output_directory_name: name of the Ludwig run directory
        inside ``output_directory``.
    """
    ludwig_output_directory = os.path.join(
        output_directory, ludwig_output_directory_name)
    # Same constant make_visualizations uses for the predictions artifact.
    parquet_path = os.path.join(
        ludwig_output_directory, PREDICTIONS_PARQUET_FILE_NAME)
    csv_path = os.path.join(
        ludwig_output_directory, "predictions_parquet.csv")

    if not os.path.exists(parquet_path):
        LOG.warning(
            f"Predictions file not found, skipping CSV conversion: "
            f"{parquet_path}")
        return

    try:
        df = pd.read_parquet(parquet_path)
        df.to_csv(csv_path, index=False)
    except Exception as e:
        LOG.error(f"Error converting Parquet to CSV: {e}")
        return
    LOG.info(f"Converted Parquet to CSV: {csv_path}")
| 206 | |
| 207 | |
def generate_html_report(title, ludwig_output_directory_name):
    """Write a self-contained HTML report that embeds the visualization images.

    Each ``.png``/``.jpg``/``.jpeg`` file in ``viz_output_directory`` is
    inlined as a base64 data URI with the matching MIME type. The page is
    wrapped in the shared template from ``utils``. It is written to
    ``<output_directory>/<title>_report.html``, with the title lower-cased
    and spaces replaced by underscores.

    :param title: report heading; also determines the output file name.
    :param ludwig_output_directory_name: name of the Ludwig run directory.
        Currently unused; kept so callers' signature stays unchanged.
    """
    from html import escape

    # Data-URI MIME type for each embeddable image extension.
    mime_types = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
    }

    # The visualizations directory may not exist (e.g. if no plot was
    # rendered); treat that as "no plots" instead of raising.
    if os.path.isdir(viz_output_directory):
        plot_files = sorted(os.listdir(viz_output_directory))
    else:
        plot_files = []

    plot_sections = []
    for plot_file in plot_files:
        plot_path = os.path.join(viz_output_directory, plot_file)
        stem, extension = os.path.splitext(plot_file)
        mime_type = mime_types.get(extension)
        if mime_type is None or not os.path.isfile(plot_path):
            continue
        encoded_image = encode_image_to_base64(plot_path)
        plot_sections.append(
            f'<div class="plot">'
            f'<h3>{escape(stem)}</h3>'
            f'<img src="data:{mime_type};base64,'
            f'{encoded_image}" alt="{escape(plot_file)}">'
            f'</div>'
        )

    plots_html = ""
    if plot_sections:
        plots_html = "<h2>Visualizations</h2>" + "".join(plot_sections)

    # Generate the full HTML content
    html_content = f"""
    {get_html_template()}
    <h1>{escape(title)}</h1>
    {plots_html}
    {get_html_closing()}
    """

    # Save the HTML report
    report_name = title.lower().replace(" ", "_")
    report_path = os.path.join(output_directory, f"{report_name}_report.html")
    with open(report_path, "w", encoding="utf-8") as report_file:
        report_file.write(html_content)

    LOG.info(f"HTML report generated at: {report_path}")
| 257 | |
| 258 | |
if __name__ == "__main__":
    # Run the Ludwig experiment, forwarding every CLI argument unchanged.
    cli(sys.argv[1:])

    # Post-process the run directory created under output_directory.
    # NOTE(review): "experiment_run" is hard-coded; it presumably matches
    # Ludwig's default <experiment_name>_<model_name> directory and would be
    # wrong if those options (or a pre-existing run directory) change it.
    run_directory_name = "experiment_run"
    for post_process in (make_visualizations, convert_parquet_to_csv):
        post_process(run_directory_name)
    generate_html_report("Ludwig Experiment", run_directory_name)
