Mercurial > repos > bgruening > keras_model_config
comparison keras_deep_learning.py @ 1:0fd7d8e90e2a draft
"planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/sklearn commit ba6a47bdf76bbf4cb276206ac1a8cbf61332fd16"
author | bgruening |
---|---|
date | Fri, 13 Sep 2019 12:19:45 -0400 |
parents | 1046cf73236b |
children | c3813c64d678 |
comparison
equal
deleted
inserted
replaced
0:1046cf73236b | 1:0fd7d8e90e2a |
---|---|
6 import six | 6 import six |
7 import warnings | 7 import warnings |
8 | 8 |
9 from ast import literal_eval | 9 from ast import literal_eval |
10 from keras.models import Sequential, Model | 10 from keras.models import Sequential, Model |
11 from galaxy_ml.utils import try_get_attr, get_search_params | 11 from galaxy_ml.utils import try_get_attr, get_search_params, SafeEval |
12 | |
13 | |
14 safe_eval = SafeEval() | |
12 | 15 |
13 | 16 |
14 def _handle_shape(literal): | 17 def _handle_shape(literal): |
15 """Eval integer or list/tuple of integers from string | 18 """Eval integer or list/tuple of integers from string |
16 | 19 |
98 continue | 101 continue |
99 | 102 |
100 if key in ['input_shape', 'noise_shape', 'shape', 'batch_shape', | 103 if key in ['input_shape', 'noise_shape', 'shape', 'batch_shape', |
101 'target_shape', 'dims', 'kernel_size', 'strides', | 104 'target_shape', 'dims', 'kernel_size', 'strides', |
102 'dilation_rate', 'output_padding', 'cropping', 'size', | 105 'dilation_rate', 'output_padding', 'cropping', 'size', |
103 'padding', 'pool_size', 'axis', 'shared_axes']: | 106 'padding', 'pool_size', 'axis', 'shared_axes'] \ |
107 and isinstance(value, str): | |
104 params[key] = _handle_shape(value) | 108 params[key] = _handle_shape(value) |
105 | 109 |
106 elif key.endswith('_regularizer'): | 110 elif key.endswith('_regularizer') and isinstance(value, dict): |
107 params[key] = _handle_regularizer(value) | 111 params[key] = _handle_regularizer(value) |
108 | 112 |
109 elif key.endswith('_constraint'): | 113 elif key.endswith('_constraint') and isinstance(value, dict): |
110 params[key] = _handle_constraint(value) | 114 params[key] = _handle_constraint(value) |
111 | 115 |
112 elif key == 'function': # No support for lambda/function eval | 116 elif key == 'function': # No support for lambda/function eval |
113 params.pop(key) | 117 params.pop(key) |
114 | 118 |
127 layers = config['layers'] | 131 layers = config['layers'] |
128 for layer in layers: | 132 for layer in layers: |
129 options = layer['layer_selection'] | 133 options = layer['layer_selection'] |
130 layer_type = options.pop('layer_type') | 134 layer_type = options.pop('layer_type') |
131 klass = getattr(keras.layers, layer_type) | 135 klass = getattr(keras.layers, layer_type) |
132 other_options = options.pop('layer_options', {}) | 136 kwargs = options.pop('kwargs', '') |
133 options.update(other_options) | |
134 | 137 |
135 # parameters needs special care | 138 # parameters needs special care |
136 options = _handle_layer_parameters(options) | 139 options = _handle_layer_parameters(options) |
140 | |
141 if kwargs: | |
142 kwargs = safe_eval('dict(' + kwargs + ')') | |
143 options.update(kwargs) | |
137 | 144 |
138 # add input_shape to the first layer only | 145 # add input_shape to the first layer only |
139 if not getattr(model, '_layers') and input_shape is not None: | 146 if not getattr(model, '_layers') and input_shape is not None: |
140 options['input_shape'] = input_shape | 147 options['input_shape'] = input_shape |
141 | 148 |
156 for layer in layers: | 163 for layer in layers: |
157 options = layer['layer_selection'] | 164 options = layer['layer_selection'] |
158 layer_type = options.pop('layer_type') | 165 layer_type = options.pop('layer_type') |
159 klass = getattr(keras.layers, layer_type) | 166 klass = getattr(keras.layers, layer_type) |
160 inbound_nodes = options.pop('inbound_nodes', None) | 167 inbound_nodes = options.pop('inbound_nodes', None) |
161 other_options = options.pop('layer_options', {}) | 168 kwargs = options.pop('kwargs', '') |
162 options.update(other_options) | |
163 | 169 |
164 # parameters needs special care | 170 # parameters needs special care |
165 options = _handle_layer_parameters(options) | 171 options = _handle_layer_parameters(options) |
172 | |
173 if kwargs: | |
174 kwargs = safe_eval('dict(' + kwargs + ')') | |
175 options.update(kwargs) | |
176 | |
166 # merge layers | 177 # merge layers |
167 if 'merging_layers' in options: | 178 if 'merging_layers' in options: |
168 idxs = literal_eval(options.pop('merging_layers')) | 179 idxs = literal_eval(options.pop('merging_layers')) |
169 merging_layers = [all_layers[i-1] for i in idxs] | 180 merging_layers = [all_layers[i-1] for i in idxs] |
170 new_layer = klass(**options)(merging_layers) | 181 new_layer = klass(**options)(merging_layers) |