Mercurial > repos > iuc > virhunter
annotate utils/batch_loader.py @ 1:9b12bc1b1e2c draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit cea5a324cffd684e4cb08438fe5a305cfaf0f73a
author | iuc |
---|---|
date | Wed, 30 Nov 2022 17:31:52 +0000 |
parents | 457fd8fd681a |
children |
rev | line source |
---|---|
0
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
1 #!/usr/bin/env python |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
2 # -*- coding: utf-8 -*- |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
3 # Credits: Grigorii Sukhorukov, Macha Nikolski |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
4 import numpy as np |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
5 from sklearn.utils import shuffle |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
6 from tensorflow import keras |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
7 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
8 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
class BatchLoader(keras.utils.Sequence):
    """Keras ``Sequence`` that serves pre-defined batches as NumPy arrays.

    Every element of ``batches`` is a collection of row indices. Requesting
    batch ``idx`` gathers those rows from the data sources, converts them to
    NumPy arrays and shuffles them consistently with a fixed seed.
    """

    def __init__(
        self,
        input_seqs,
        input_seqs_rc,
        input_labs,
        batches,
        rc=True,
        random_seed=1
    ):
        """Store the data sources and the batch layout.

        Args:
            input_seqs: Array-like indexed as ``input_seqs[rows, ...]``.
            input_seqs_rc: Array-like indexed the same way; presumably the
                reverse-complement counterpart of ``input_seqs`` (confirm
                against the caller). Only read when ``rc`` is true.
            input_labs: Array-like of labels, indexed the same way.
            batches: Sequence whose items are collections of row indices.
            rc: When true, batches contain both ``input_seqs`` and
                ``input_seqs_rc`` rows; otherwise only ``input_seqs`` rows.
            random_seed: ``random_state`` forwarded to
                ``sklearn.utils.shuffle`` for every batch.
        """
        # Keras 3 expects subclasses of its Sequence/PyDataset base to
        # initialise the parent; on older versions this is a harmless no-op.
        super().__init__()
        self.input_seqs = input_seqs
        self.input_seqs_rc = input_seqs_rc
        self.input_labs = input_labs
        self.batches = batches
        self.rc = rc
        self.random_seed = random_seed

    def __len__(self):
        """Return the number of batches."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Return batch ``idx``.

        Returns:
            ``((seqs, seqs_rc), labs)`` when ``self.rc`` is true, otherwise
            ``(seqs, labs)``. All members are NumPy arrays whose rows share
            one permutation.
        """
        # Indices are sorted before fancy indexing; presumably the backing
        # store (e.g. an HDF5 dataset) requires increasing order - confirm.
        rows = sorted(self.batches[idx])
        seqs = np.array(self.input_seqs[rows, ...])
        labs = np.array(self.input_labs[rows, ...])
        if not self.rc:
            # The reverse-strand source is not needed here, so it is not
            # read at all. The permutation applied by shuffle() depends on
            # the seed and the number of rows, not on how many arrays are
            # passed, so the result matches the three-array call.
            batch_seqs, batch_labs = shuffle(
                seqs,
                labs,
                random_state=self.random_seed
            )
            return batch_seqs, batch_labs
        seqs_rc = np.array(self.input_seqs_rc[rows, ...])
        batch_seqs, batch_seqs_rc, batch_labs = shuffle(
            seqs,
            seqs_rc,
            labs,
            random_state=self.random_seed
        )
        return (batch_seqs, batch_seqs_rc), batch_labs
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
46 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
47 |
457fd8fd681a
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
iuc
parents:
diff
changeset
|
class BatchGenerator:
    """Callable that yields pre-defined batches as NumPy arrays.

    Calling an instance returns a generator producing one
    ``((seqs, seqs_rc), labs)`` item per entry of ``batches``.
    """

    def __init__(
        self,
        input_seqs,
        input_seqs_rc,
        input_labs,
        batches,
        random_seed=1
    ):
        """Store the data sources and the batch layout.

        Args:
            input_seqs: Array-like indexed as ``input_seqs[rows, ...]``.
            input_seqs_rc: Array-like indexed the same way; presumably the
                reverse-complement counterpart of ``input_seqs`` (confirm
                against the caller).
            input_labs: Array-like of labels, indexed the same way.
            batches: Iterable whose items are collections of row indices.
            random_seed: ``random_state`` forwarded to
                ``sklearn.utils.shuffle`` for every batch.
        """
        self.input_seqs = input_seqs
        self.input_seqs_rc = input_seqs_rc
        self.input_labs = input_labs
        self.batches = batches
        self.random_seed = random_seed

    def __call__(self):
        """Yield ``((seqs, seqs_rc), labs)`` for each batch in order."""
        for indices in self.batches:
            # Sorted before fancy indexing; presumably required by the
            # backing store - confirm against the caller.
            rows = sorted(indices)
            seqs = np.array(self.input_seqs[rows, ...])
            seqs_rc = np.array(self.input_seqs_rc[rows, ...])
            labs = np.array(self.input_labs[rows, ...])
            # One shuffle call keeps the three arrays row-aligned.
            shuffled = shuffle(
                seqs,
                seqs_rc,
                labs,
                random_state=self.random_seed
            )
            out_seqs, out_seqs_rc, out_labs = shuffled
            yield (out_seqs, out_seqs_rc), out_labs