Mercurial > repos > galaxy-australia > alphafold2
comparison docker/alphafold/run_alphafold_test.py @ 1:6c92e000d684 draft
"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit a510e97ebd604a5e30b1f16e5031f62074f23e86"
| author | galaxy-australia |
|---|---|
| date | Tue, 01 Mar 2022 02:53:05 +0000 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
| 0:7ae9d78b06f5 | 1:6c92e000d684 |
|---|---|
| 1 # Copyright 2021 DeepMind Technologies Limited | |
| 2 # | |
| 3 # Licensed under the Apache License, Version 2.0 (the "License"); | |
| 4 # you may not use this file except in compliance with the License. | |
| 5 # You may obtain a copy of the License at | |
| 6 # | |
| 7 # http://www.apache.org/licenses/LICENSE-2.0 | |
| 8 # | |
| 9 # Unless required by applicable law or agreed to in writing, software | |
| 10 # distributed under the License is distributed on an "AS IS" BASIS, | |
| 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| 12 # See the License for the specific language governing permissions and | |
| 13 # limitations under the License. | |
| 14 | |
| 15 """Tests for run_alphafold.""" | |
| 16 | |
| 17 import os | |
| 18 | |
| 19 from absl.testing import absltest | |
| 20 from absl.testing import parameterized | |
| 21 import run_alphafold | |
| 22 import mock | |
| 23 import numpy as np | |
| 24 # Internal import (7716). | |
| 25 | |
| 26 | |
class RunAlphafoldTest(parameterized.TestCase):
  """End-to-end test of `run_alphafold.predict_structure` using mocked stages.

  The data pipeline, model runner and Amber relaxer are replaced by
  `mock.Mock` objects with canned return values, so only the orchestration
  and output-writing logic of `predict_structure` is exercised.
  """

  @parameterized.named_parameters(
      ('relax', True),
      ('no_relax', False),
  )
  def test_end_to_end(self, do_relax):
    """Runs `predict_structure` on mocks and checks the files it writes.

    Args:
      do_relax: If True, the Amber relaxer mock is passed to
        `predict_structure` and `relaxed_model1.pdb` is additionally expected
        in the output; if False, `amber_relaxer=None` is passed.
    """
    data_pipeline_mock = mock.Mock()
    model_runner_mock = mock.Mock()
    amber_relaxer_mock = mock.Mock()

    data_pipeline_mock.process.return_value = {}
    model_runner_mock.process_features.return_value = {
        'aatype': np.zeros((12, 10), dtype=np.int32),
        'residue_index': np.tile(np.arange(10, dtype=np.int32)[None], (12, 1)),
    }
    # Canned model output. 'plddt' is a constant 42 so the B-factor check
    # at the end of the test has a known expected value.
    model_runner_mock.predict.return_value = {
        'structure_module': {
            'final_atom_positions': np.zeros((10, 37, 3)),
            'final_atom_mask': np.ones((10, 37)),
        },
        'predicted_lddt': {
            'logits': np.ones((10, 50)),
        },
        'plddt': np.ones(10) * 42,
        'ranking_confidence': 90,
        'ptm': np.array(0.),
        'aligned_confidence_probs': np.zeros((10, 10, 50)),
        'predicted_aligned_error': np.zeros((10, 10)),
        'max_predicted_aligned_error': np.array(0.),
    }
    model_runner_mock.multimer_mode = False
    amber_relaxer_mock.process.return_value = ('RELAXED', None, None)

    # Use a fresh directory for each parameterized case. The default test
    # tmpdir is shared between the 'relax' and 'no_relax' runs, so files
    # written by one case (e.g. relaxed_model1.pdb) would leak into the other
    # and make the exact file-list assertion below depend on test order.
    out_dir = self.create_tempdir().full_path

    fasta_path = os.path.join(out_dir, 'target.fasta')
    with open(fasta_path, 'wt') as f:
      f.write('>A\nAAAAAAAAAAAAA')
    fasta_name = 'test'

    run_alphafold.predict_structure(
        fasta_path=fasta_path,
        fasta_name=fasta_name,
        output_dir_base=out_dir,
        data_pipeline=data_pipeline_mock,
        model_runners={'model1': model_runner_mock},
        amber_relaxer=amber_relaxer_mock if do_relax else None,
        benchmark=False,
        random_seed=0)

    base_output_files = os.listdir(out_dir)
    self.assertIn('target.fasta', base_output_files)
    self.assertIn('test', base_output_files)

    target_output_files = os.listdir(os.path.join(out_dir, 'test'))
    expected_files = [
        'features.pkl', 'msas', 'ranked_0.pdb', 'ranking_debug.json',
        'result_model1.pkl', 'timings.json', 'unrelaxed_model1.pdb',
    ]
    if do_relax:
      expected_files.append('relaxed_model1.pdb')
    self.assertCountEqual(expected_files, target_output_files)

    # Check that pLDDT is set in the B-factor column (columns 61-66 of an
    # ATOM record); the mocked pLDDT above is 42 everywhere.
    num_atom_lines = 0
    with open(os.path.join(out_dir, 'test', 'unrelaxed_model1.pdb')) as f:
      for line in f:
        if line.startswith('ATOM'):
          num_atom_lines += 1
          self.assertEqual(line[61:66], '42.00')
    # Without this, an empty PDB would make the loop above pass vacuously.
    self.assertGreater(num_atom_lines, 0)
| 99 | |
if __name__ == '__main__':
  # When executed as a script, hand control to absl's test runner.
  absltest.main()
