Mercurial > repos > iuc > virhunter
comparison utils/batch_loader.py @ 0:457fd8fd681a draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
author | iuc |
---|---|
date | Wed, 09 Nov 2022 12:19:26 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:457fd8fd681a |
---|---|
1 #!/usr/bin/env python | |
2 # -*- coding: utf-8 -*- | |
3 # Credits: Grigorii Sukhorukov, Macha Nikolski | |
4 import numpy as np | |
5 from sklearn.utils import shuffle | |
6 from tensorflow import keras | |
7 | |
8 | |
class BatchLoader(keras.utils.Sequence):
    """Keras ``Sequence`` that serves pre-computed batches of NumPy arrays.

    Item ``idx`` is one batch: the rows listed in ``batches[idx]`` are read
    from the three input arrays, copied into in-memory ``np.ndarray`` objects,
    shuffled together (rows stay aligned across the three arrays) with a fixed
    seed, and returned in the ``(inputs, labels)`` form used by ``Model.fit``.

    Args:
        input_seqs: array-like indexed as ``input_seqs[list_of_rows, ...]``.
        input_seqs_rc: array-like aligned row-for-row with ``input_seqs``
            (presumably the reverse-complement encodings — confirm with caller).
        input_labs: labels aligned row-for-row with ``input_seqs``.
        batches: sequence of row-index collections, one per batch; its length
            is the number of batches reported by ``len()``.
        rc: if True, items are ``((seqs, seqs_rc), labs)``; otherwise
            ``(seqs, labs)``.
        random_seed: ``random_state`` handed to ``sklearn.utils.shuffle`` for
            every batch, making the within-batch order reproducible.
        **kwargs: forwarded to ``keras.utils.Sequence.__init__`` (e.g.
            ``workers`` / ``use_multiprocessing`` on Keras versions whose base
            class accepts them). Leave empty on versions that do not.
    """

    def __init__(
        self,
        input_seqs,
        input_seqs_rc,
        input_labs,
        batches,
        rc=True,
        random_seed=1,
        **kwargs,
    ):
        # Newer Keras (PyDataset-backed Sequence) expects its initializer to
        # run and warns otherwise; where the base defines no __init__ this is
        # simply object.__init__() and kwargs must be empty.
        super().__init__(**kwargs)
        self.input_seqs = input_seqs
        self.input_seqs_rc = input_seqs_rc
        self.input_labs = input_labs
        self.batches = batches
        self.rc = rc
        self.random_seed = random_seed

    def __len__(self):
        """Return the number of batches."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Return batch ``idx`` as ``((seqs, seqs_rc), labs)`` or ``(seqs, labs)``."""
        # Indices are sorted before fancy indexing; presumably the inputs may
        # be h5py datasets, which require increasing index lists — confirm.
        batch = sorted(self.batches[idx])
        batch_seqs, batch_seqs_rc, batch_labs = shuffle(
            np.array(self.input_seqs[batch, ...]),
            np.array(self.input_seqs_rc[batch, ...]),
            np.array(self.input_labs[batch, ...]),
            # One permutation is applied to all three arrays; the fixed seed
            # makes it identical on every call.
            random_state=self.random_seed,
        )
        if self.rc:
            return (batch_seqs, batch_seqs_rc), batch_labs
        return batch_seqs, batch_labs
46 | |
47 | |
class BatchGenerator:
    """Callable that yields pre-computed batches of NumPy arrays.

    Each call returns a fresh generator making one pass over ``batches`` in
    order (presumably used as a generator factory, e.g. for
    ``tf.data.Dataset.from_generator`` — confirm with caller). For every batch
    the listed rows are read from the three input arrays, copied into
    ``np.ndarray`` objects and shuffled together with a fixed seed.

    Args:
        input_seqs: array-like indexed as ``input_seqs[list_of_rows, ...]``.
        input_seqs_rc: array-like aligned row-for-row with ``input_seqs``
            (presumably the reverse-complement encodings — confirm with caller).
        input_labs: labels aligned row-for-row with ``input_seqs``.
        batches: iterable of row-index collections, one per batch.
        random_seed: ``random_state`` handed to ``sklearn.utils.shuffle`` for
            every batch, making the within-batch order reproducible.
        rc: keyword-only. If True (default, the original behaviour) yield
            ``((seqs, seqs_rc), labs)``; otherwise ``(seqs, labs)``, matching
            the ``rc`` option of ``BatchLoader``.
    """

    def __init__(
        self,
        input_seqs,
        input_seqs_rc,
        input_labs,
        batches,
        random_seed=1,
        *,
        rc=True,
    ):
        self.input_seqs = input_seqs
        self.input_seqs_rc = input_seqs_rc
        self.input_labs = input_labs
        self.batches = batches
        self.random_seed = random_seed
        self.rc = rc

    def __call__(self):
        """Yield every batch in ``self.batches`` order (one full pass per call)."""
        for batch in self.batches:
            # Sorted indices: presumably required when inputs are h5py
            # datasets (increasing-order fancy indexing) — confirm.
            batch = sorted(batch)
            batch_seqs, batch_seqs_rc, batch_labs = shuffle(
                np.array(self.input_seqs[batch, ...]),
                np.array(self.input_seqs_rc[batch, ...]),
                np.array(self.input_labs[batch, ...]),
                # Same permutation for all three arrays; fixed seed => the
                # same permutation on every pass.
                random_state=self.random_seed,
            )
            if self.rc:
                yield (batch_seqs, batch_seqs_rc), batch_labs
            else:
                yield batch_seqs, batch_labs