Mercurial > repos > galaxy-australia > alphafold2
view docker/alphafold/alphafold/model/layer_stack_test.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
line wrap: on
line source
# Copyright 2021 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for layer_stack."""

import functools

from absl.testing import absltest
from absl.testing import parameterized
from alphafold.model import layer_stack
import haiku as hk
import jax
import jax.numpy as jnp
import numpy as np
# FIX: `import scipy` alone does not make the `scipy.stats` subpackage
# available as an attribute; it must be imported explicitly.
import scipy.stats

# Suffixes applied by Haiku for repeated module names.
suffixes = [''] + [f'_{i}' for i in range(1, 100)]


def _slice_layers_params(layers_params):
  """Converts stacked layer_stack params into per-layer Haiku param dicts.

  layer_stack stores each variable with a leading stack axis under a single
  module name; this re-keys every slice along that axis to the module names
  Haiku would have produced for an unrolled stack ('linear1', 'linear1_1', ...).

  Args:
    layers_params: Haiku params mapping produced by a layer_stack-transformed
      function; each leaf has a leading num_layers axis.

  Returns:
    A params mapping usable by the equivalent unrolled Haiku function.
  """
  sliced_layers_params = {}
  for k, v in layers_params.items():
    for inner_k in v:
      # zip pairs slice i of the stacked variable with Haiku's repeated-name
      # suffix for the i-th instance of the module.
      for var_slice, suffix in zip(v[inner_k], suffixes):
        k_new = k.split('/')[-1] + suffix
        if k_new not in sliced_layers_params:
          sliced_layers_params[k_new] = {}
        sliced_layers_params[k_new][inner_k] = var_slice
  return sliced_layers_params


class LayerStackTest(parameterized.TestCase):

  @parameterized.parameters([1, 2, 4])
  def test_layer_stack(self, unroll):
    """Compare layer_stack to the equivalent unrolled stack.

    Tests that the layer_stack application of a Haiku layer function is
    equivalent to repeatedly applying the layer function in an unrolled loop.

    Args:
      unroll: Number of unrolled layers.
    """
    num_layers = 20

    def inner_fn(x):
      x += hk.Linear(100, name='linear1')(x)
      x += hk.Linear(100, name='linear2')(x)
      return x

    def outer_fn_unrolled(x):
      for _ in range(num_layers):
        x = inner_fn(x)
      return x

    def outer_fn_layer_stack(x):
      stack = layer_stack.layer_stack(num_layers, unroll=unroll)(inner_fn)
      return stack(x)

    unrolled_fn = hk.transform(outer_fn_unrolled)
    layer_stack_fn = hk.transform(outer_fn_layer_stack)

    x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])

    rng_init = jax.random.PRNGKey(42)
    params = layer_stack_fn.init(rng_init, x)

    # Re-key the stacked params so the unrolled function can consume them.
    sliced_params = _slice_layers_params(params)

    unrolled_pred = unrolled_fn.apply(sliced_params, None, x)
    layer_stack_pred = layer_stack_fn.apply(params, None, x)

    np.testing.assert_allclose(unrolled_pred, layer_stack_pred)

  def test_layer_stack_multi_args(self):
    """Compare layer_stack to the equivalent unrolled stack.

    Similar to `test_layer_stack`, but use a function that takes more than one
    argument.
    """
    num_layers = 20

    def inner_fn(x, y):
      x_out = x + hk.Linear(100, name='linear1')(y)
      y_out = y + hk.Linear(100, name='linear2')(x)
      return x_out, y_out

    def outer_fn_unrolled(x, y):
      for _ in range(num_layers):
        x, y = inner_fn(x, y)
      return x, y

    def outer_fn_layer_stack(x, y):
      stack = layer_stack.layer_stack(num_layers)(inner_fn)
      return stack(x, y)

    unrolled_fn = hk.transform(outer_fn_unrolled)
    layer_stack_fn = hk.transform(outer_fn_layer_stack)

    x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])
    y = jax.random.uniform(jax.random.PRNGKey(1), [10, 256, 100])

    rng_init = jax.random.PRNGKey(42)
    params = layer_stack_fn.init(rng_init, x, y)

    sliced_params = _slice_layers_params(params)

    unrolled_x, unrolled_y = unrolled_fn.apply(sliced_params, None, x, y)
    layer_stack_x, layer_stack_y = layer_stack_fn.apply(params, None, x, y)

    np.testing.assert_allclose(unrolled_x, layer_stack_x)
    np.testing.assert_allclose(unrolled_y, layer_stack_y)

  def test_layer_stack_no_varargs(self):
    """Test an error is raised when using a function with varargs."""

    class VarArgsModule(hk.Module):
      """When used, this module should cause layer_stack to raise an Error."""

      def __call__(self, *args):
        return args

    class NoVarArgsModule(hk.Module):
      """This module should be fine to use with layer_stack."""

      def __call__(self, x):
        return x

    def build_and_init_stack(module_class):
      def stack_fn(x):
        module = module_class()
        return layer_stack.layer_stack(1)(module)(x)

      stack = hk.without_apply_rng(hk.transform(stack_fn))
      stack.init(jax.random.PRNGKey(1729), jnp.ones([5]))

    build_and_init_stack(NoVarArgsModule)
    with self.assertRaisesRegex(
        ValueError, 'The function `f` should not have any `varargs`'):
      build_and_init_stack(VarArgsModule)

  @parameterized.parameters([1, 2, 4])
  def test_layer_stack_grads(self, unroll):
    """Compare layer_stack gradients to the equivalent unrolled stack.

    Tests that the layer_stack application of a Haiku layer function is
    equivalent to repeatedly applying the layer function in an unrolled loop.

    Args:
      unroll: Number of unrolled layers.
    """
    num_layers = 20

    def inner_fn(x):
      x += hk.Linear(100, name='linear1')(x)
      x += hk.Linear(100, name='linear2')(x)
      return x

    def outer_fn_unrolled(x):
      for _ in range(num_layers):
        x = inner_fn(x)
      return x

    def outer_fn_layer_stack(x):
      stack = layer_stack.layer_stack(num_layers, unroll=unroll)(inner_fn)
      return stack(x)

    unrolled_fn = hk.transform(outer_fn_unrolled)
    layer_stack_fn = hk.transform(outer_fn_layer_stack)

    x = jax.random.uniform(jax.random.PRNGKey(0), [10, 256, 100])

    rng_init = jax.random.PRNGKey(42)
    params = layer_stack_fn.init(rng_init, x)

    sliced_params = _slice_layers_params(params)

    unrolled_grad = jax.grad(
        lambda p, x: jnp.mean(unrolled_fn.apply(p, None, x)))(sliced_params, x)
    layer_stack_grad = jax.grad(
        lambda p, x: jnp.mean(layer_stack_fn.apply(p, None, x)))(params, x)

    assert_fn = functools.partial(
        np.testing.assert_allclose, atol=1e-4, rtol=1e-4)

    # FIX: jax.tree_multimap is a deprecated alias removed from newer JAX
    # releases; jax.tree_util.tree_map accepts multiple trees and is the
    # supported equivalent.
    jax.tree_util.tree_map(assert_fn, unrolled_grad,
                           _slice_layers_params(layer_stack_grad))

  def test_random(self):
    """Random numbers should be handled correctly."""
    n = 100

    @hk.transform
    @layer_stack.layer_stack(n)
    def add_random(x):
      x = x + jax.random.normal(hk.next_rng_key())
      return x

    # Evaluate a bunch of times
    key, *keys = jax.random.split(jax.random.PRNGKey(7), 1024 + 1)
    params = add_random.init(key, 0.)
    apply_fn = jax.jit(add_random.apply)
    values = [apply_fn(params, key, 0.) for key in keys]

    # Should be roughly N(0, sqrt(n)); a Kolmogorov-Smirnov test against that
    # distribution should not reject (large p-value).
    cdf = scipy.stats.norm(scale=np.sqrt(n)).cdf
    _, p = scipy.stats.kstest(values, cdf)
    self.assertLess(0.3, p)

  def test_threading(self):
    """Test @layer_stack when the function gets per-layer state."""
    n = 5

    @layer_stack.layer_stack(n, with_state=True)
    def f(x, y):
      x = x + y * jax.nn.one_hot(y, len(x)) / 10
      return x, 2 * y

    @hk.without_apply_rng
    @hk.transform
    def g(x, ys):
      x, zs = f(x, ys)
      # Check here to catch issues at init time
      self.assertEqual(zs.shape, (n,))
      return x, zs

    rng = jax.random.PRNGKey(7)
    x = np.zeros(n)
    ys = np.arange(n).astype(np.float32)
    params = g.init(rng, x, ys)
    x, zs = g.apply(params, x, ys)
    self.assertTrue(np.allclose(x, [0, .1, .2, .3, .4]))
    self.assertTrue(np.all(zs == 2 * ys))

  def test_nested_stacks(self):
    def stack_fn(x):
      def layer_fn(x):
        return hk.Linear(100)(x)

      outer_fn = layer_stack.layer_stack(10)(layer_fn)

      layer_outer = layer_stack.layer_stack(20)(outer_fn)
      return layer_outer(x)

    hk_mod = hk.transform(stack_fn)
    apply_rng, init_rng = jax.random.split(jax.random.PRNGKey(0))

    params = hk_mod.init(init_rng, jnp.zeros([10, 100]))

    hk_mod.apply(params, apply_rng, jnp.zeros([10, 100]))

    p, = params.values()

    # Nested stacks prepend one stacking axis per level: (outer, inner, ...).
    assert p['w'].shape == (10, 20, 100, 100)
    assert p['b'].shape == (10, 20, 100)

  def test_with_state_multi_args(self):
    """Test layer_stack with state with multiple arguments."""
    width = 4
    batch_size = 5
    stack_height = 3

    def f_with_multi_args(x, a, b):
      return hk.Linear(
          width, w_init=hk.initializers.Constant(
              jnp.eye(width)))(x) * a + b, None

    @hk.without_apply_rng
    @hk.transform
    def hk_fn(x):
      return layer_stack.layer_stack(
          stack_height,
          with_state=True)(f_with_multi_args)(x, jnp.full([stack_height], 2.),
                                              jnp.ones([stack_height]))

    x = jnp.zeros([batch_size, width])
    key_seq = hk.PRNGSequence(19)
    params = hk_fn.init(next(key_seq), x)
    output, z = hk_fn.apply(params, x)
    self.assertIsNone(z)
    self.assertEqual(output.shape, (batch_size, width))
    # Identity weights, so each layer maps v -> v * 2 + 1: 0 -> 1 -> 3 -> 7.
    np.testing.assert_equal(output, np.full([batch_size, width], 7.))

  def test_with_container_state(self):
    width = 2
    batch_size = 2
    stack_height = 3

    def f_with_container_state(x):
      hk_layer = hk.Linear(
          width, w_init=hk.initializers.Constant(jnp.eye(width)))
      layer_output = hk_layer(x)
      layer_state = {
          'raw_output': layer_output,
          'output_projection': jnp.sum(layer_output)
      }
      return layer_output + jnp.ones_like(layer_output), layer_state

    @hk.without_apply_rng
    @hk.transform
    def hk_fn(x):
      return layer_stack.layer_stack(
          stack_height,
          with_state=True)(f_with_container_state)(x)

    x = jnp.zeros([batch_size, width])
    key_seq = hk.PRNGSequence(19)
    params = hk_fn.init(next(key_seq), x)
    output, z = hk_fn.apply(params, x)
    self.assertEqual(z['raw_output'].shape, (stack_height, batch_size, width))
    self.assertEqual(output.shape, (batch_size, width))
    self.assertEqual(z['output_projection'].shape, (stack_height,))
    np.testing.assert_equal(np.sum(z['output_projection']), np.array(12.))
    np.testing.assert_equal(
        np.all(z['raw_output'] == np.array([0., 1., 2.])[..., None, None]),
        np.array(True))


if __name__ == '__main__':
  absltest.main()