Mercurial > repos > goeckslab > ludwig_train
comparison ludwig_render_config.py @ 0:f0be10937f5c draft default tip
planemo upload for repository https://github.com/goeckslab/Galaxy-Ludwig.git commit bdea9430787658783a51cc6c2ae951a01e455bb4
author | goeckslab |
---|---|
date | Tue, 07 Jan 2025 22:44:09 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:f0be10937f5c |
---|---|
1 import json | |
2 import logging | |
3 import sys | |
4 | |
5 from ludwig.constants import ( | |
6 COMBINER, | |
7 HYPEROPT, | |
8 INPUT_FEATURES, | |
9 MODEL_TYPE, | |
10 OUTPUT_FEATURES, | |
11 PROC_COLUMN, | |
12 TRAINER, | |
13 ) | |
14 from ludwig.schema.model_types.utils import merge_with_defaults | |
15 | |
16 import yaml | |
17 | |
logging.basicConfig(level=logging.DEBUG)
LOG = logging.getLogger(__name__)

# Usage: ludwig_render_config.py <params.json> <output_config.yml>
# argv[1] is a JSON file of tool parameters (assumed to be rendered by the
# Galaxy tool wrapper -- TODO confirm); argv[2] receives the merged config.
inputs = sys.argv[1]
# JSON text is UTF-8 by spec; don't depend on the platform locale encoding.
with open(inputs, 'r', encoding='utf-8') as handler:
    params = json.load(handler)

config = {}

# input features: each repeat entry holds its feature dict under
# 'input_feature_selector'.
config[INPUT_FEATURES] = [
    ftr['input_feature_selector']
    for ftr in params[INPUT_FEATURES]['input_feature']
]

# output features: same layout, keyed by 'output_feature_selector'.
config[OUTPUT_FEATURES] = [
    ftr['output_feature_selector']
    for ftr in params[OUTPUT_FEATURES]['output_feature']
]

# combiner section is copied through unchanged.
config[COMBINER] = params[COMBINER]

# training: the trainer parameters also carry the model type, which is
# moved out of the trainer section to the top level of the config.
config[TRAINER] = params[TRAINER][TRAINER]
config[MODEL_TYPE] = config[TRAINER].pop(MODEL_TYPE)

# hyperopt: the flag may arrive as the string 'true'/'false' or as a JSON
# boolean; normalise both forms so a real boolean True is not ignored.
do_hyperopt = str(params[HYPEROPT]['do_hyperopt']).strip().lower() == 'true'
if do_hyperopt:
    config[HYPEROPT] = params[HYPEROPT][HYPEROPT]

# Snapshot of the config as assembled from the tool parameters, before any
# defaults are merged in. allow_unicode=True may emit non-ASCII characters,
# so the file is written explicitly as UTF-8 rather than the locale default.
with open('./pre_config.yml', 'w', encoding='utf-8') as f:
    yaml.safe_dump(config, f, allow_unicode=True, default_flow_style=False)

output = sys.argv[2]
# merge_with_defaults presumably fills in Ludwig's schema defaults for every
# section not set above -- see Ludwig's config documentation.
output_config = merge_with_defaults(config)
52 | |
def clean_proc_column(config: dict) -> None:
    """Remove the ``PROC_COLUMN`` key from every input and output feature.

    The feature dicts inside ``config[INPUT_FEATURES]`` and
    ``config[OUTPUT_FEATURES]`` are modified in place (inputs first, then
    outputs); features without the key are left as they are. Returns None.
    A missing feature section raises KeyError.
    """
    for section in (INPUT_FEATURES, OUTPUT_FEATURES):
        for feature in config[section]:
            # pop with a default so a feature lacking the key does not raise.
            feature.pop(PROC_COLUMN, None)
58 | |
59 | |
clean_proc_column(output_config)

# Serialise the merged config to the path given as argv[2]; sort_keys=False
# keeps the dict's insertion order instead of alphabetising the sections.
with open(output, mode="w") as yaml_out:
    yaml.safe_dump(output_config, stream=yaml_out, sort_keys=False)