Mercurial > repos > galaxy-australia > alphafold2
comparison docker/alphafold/run_alphafold_test.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
author | galaxy-australia |
---|---|
date | Tue, 01 Mar 2022 02:53:05 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
0:7ae9d78b06f5 | 1:6c92e000d684 |
---|---|
1 # Copyright 2021 DeepMind Technologies Limited | |
2 # | |
3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 # you may not use this file except in compliance with the License. | |
5 # You may obtain a copy of the License at | |
6 # | |
7 # http://www.apache.org/licenses/LICENSE-2.0 | |
8 # | |
9 # Unless required by applicable law or agreed to in writing, software | |
10 # distributed under the License is distributed on an "AS IS" BASIS, | |
11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 # See the License for the specific language governing permissions and | |
13 # limitations under the License. | |
14 | |
15 """Tests for run_alphafold.""" | |
16 | |
17 import os | |
18 | |
19 from absl.testing import absltest | |
20 from absl.testing import parameterized | |
21 import run_alphafold | |
22 import mock | |
23 import numpy as np | |
24 # Internal import (7716). | |
25 | |
26 | |
class RunAlphafoldTest(parameterized.TestCase):
  """End-to-end test of `run_alphafold.predict_structure` using mocks."""

  @parameterized.named_parameters(
      ('relax', True),
      ('no_relax', False),
  )
  def test_end_to_end(self, do_relax):
    """Runs predict_structure on mocked components and checks its outputs.

    Args:
      do_relax: If True, an Amber relaxer mock is passed to predict_structure
        and a relaxed PDB file is expected in the output directory; if False,
        `amber_relaxer=None` is passed and no relaxed PDB is expected.
    """
    data_pipeline_mock = mock.Mock()
    model_runner_mock = mock.Mock()
    amber_relaxer_mock = mock.Mock()

    data_pipeline_mock.process.return_value = {}
    model_runner_mock.process_features.return_value = {
        'aatype': np.zeros((12, 10), dtype=np.int32),
        'residue_index': np.tile(np.arange(10, dtype=np.int32)[None], (12, 1)),
    }
    # Mocked model outputs for a 10-residue chain. pLDDT is a constant 42 so
    # the B-factor check at the end of the test knows the expected value.
    model_runner_mock.predict.return_value = {
        'structure_module': {
            'final_atom_positions': np.zeros((10, 37, 3)),
            'final_atom_mask': np.ones((10, 37)),
        },
        'predicted_lddt': {
            'logits': np.ones((10, 50)),
        },
        'plddt': np.ones(10) * 42,
        'ranking_confidence': 90,
        'ptm': np.array(0.),
        'aligned_confidence_probs': np.zeros((10, 10, 50)),
        'predicted_aligned_error': np.zeros((10, 10)),
        'max_predicted_aligned_error': np.array(0.),
    }
    model_runner_mock.multimer_mode = False
    amber_relaxer_mock.process.return_value = ('RELAXED', None, None)

    # Use a fresh directory per test case. The process-wide default test
    # tmpdir is shared by both parameterizations, so files left behind by
    # one run (e.g. relaxed_model1.pdb) would leak into the other run's
    # directory listing and make the exact file-set assertion below depend
    # on test execution order.
    out_dir = self.create_tempdir().full_path
    fasta_path = os.path.join(out_dir, 'target.fasta')
    with open(fasta_path, 'wt') as f:
      f.write('>A\nAAAAAAAAAAAAA')
    fasta_name = 'test'

    run_alphafold.predict_structure(
        fasta_path=fasta_path,
        fasta_name=fasta_name,
        output_dir_base=out_dir,
        data_pipeline=data_pipeline_mock,
        model_runners={'model1': model_runner_mock},
        amber_relaxer=amber_relaxer_mock if do_relax else None,
        benchmark=False,
        random_seed=0)

    base_output_files = os.listdir(out_dir)
    self.assertIn('target.fasta', base_output_files)
    self.assertIn('test', base_output_files)

    target_output_files = os.listdir(os.path.join(out_dir, 'test'))
    expected_files = [
        'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json',
        'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb',
    ]
    if do_relax:
      expected_files.append('relaxed_model1.pdb')
    self.assertCountEqual(expected_files, target_output_files)

    # Check that pLDDT is set in the B-factor column.
    with open(os.path.join(out_dir, 'test', 'unrelaxed_model1.pdb')) as f:
      atom_lines = [line for line in f if line.startswith('ATOM')]
    # Without this, the per-line check below would pass vacuously on a PDB
    # that contains no ATOM records at all.
    self.assertNotEmpty(atom_lines)
    for line in atom_lines:
      self.assertEqual(line[61:66], '42.00')
98 | |
99 | |
def _main():
  """Hands control to the absltest runner for the tests in this module."""
  absltest.main()


if __name__ == '__main__':
  _main()