comparison utils/batch_loader.py @ 0:457fd8fd681a draft

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
author iuc
date Wed, 09 Nov 2022 12:19:26 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:457fd8fd681a
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Credits: Grigorii Sukhorukov, Macha Nikolski
4 import numpy as np
5 from sklearn.utils import shuffle
6 from tensorflow import keras
7
8
class BatchLoader(keras.utils.Sequence):
    """Keras ``Sequence`` that serves pre-computed index batches.

    Item ``idx`` is built by gathering the rows listed in ``batches[idx]``
    from the sequence inputs and the labels, then shuffling those rows
    jointly (the same permutation is applied to every gathered array).

    Parameters
    ----------
    input_seqs
        Sequence inputs; must support ``input_seqs[list_of_rows, ...]``
        indexing.
    input_seqs_rc
        Companion (presumably reverse-complement) inputs, indexed like
        ``input_seqs``. Only read when ``rc`` is true, so it may be ``None``
        when ``rc`` is false.
    input_labs
        Labels, indexed like ``input_seqs``.
    batches
        Sequence of row-index collections; one entry per item.
    rc : bool, default True
        If true, items are ``((seqs, seqs_rc), labs)``; otherwise
        ``(seqs, labs)``.
    random_seed : int, default 1
        ``random_state`` passed to ``sklearn.utils.shuffle``. The same seed
        is reused on every call, so all batches of a given size receive the
        same within-batch permutation.
    """

    def __init__(
        self,
        input_seqs,
        input_seqs_rc,
        input_labs,
        batches,
        rc=True,
        random_seed=1
    ):
        # Newer Keras (PyDataset) expects subclasses to call the base
        # constructor; on older Keras ``Sequence`` this is a harmless no-op.
        super().__init__()
        self.input_seqs = input_seqs
        self.input_seqs_rc = input_seqs_rc
        self.input_labs = input_labs
        self.batches = batches
        self.rc = rc
        self.random_seed = random_seed

    def __len__(self):
        """Return the number of batches (one item per entry of ``batches``)."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Return batch ``idx`` with its rows shuffled by ``self.random_seed``.

        Returns ``((seqs, seqs_rc), labs)`` when ``self.rc`` is true and
        ``(seqs, labs)`` otherwise.
        """
        # Rows are gathered in ascending order; presumably because the inputs
        # may be h5py datasets, whose fancy indexing requires increasing
        # indices (TODO confirm against callers).
        rows = sorted(self.batches[idx])
        seqs = np.array(self.input_seqs[rows, ...])
        labs = np.array(self.input_labs[rows, ...])
        if not self.rc:
            # sklearn's shuffle draws one permutation from the row count and
            # the seed and applies it to every array passed in, so leaving the
            # reverse-complement array out gives the same ordering for seqs and
            # labs while skipping an unneeded read of input_seqs_rc.
            seqs, labs = shuffle(seqs, labs, random_state=self.random_seed)
            return seqs, labs
        seqs_rc = np.array(self.input_seqs_rc[rows, ...])
        seqs, seqs_rc, labs = shuffle(
            seqs, seqs_rc, labs, random_state=self.random_seed
        )
        return (seqs, seqs_rc), labs
46
47
class BatchGenerator:
    """Callable producing a generator over pre-computed index batches.

    Each call walks ``batches`` in order. For every entry it gathers the
    listed rows (in ascending index order) from the sequence inputs, the
    companion inputs and the labels. It then shuffles them jointly with
    ``sklearn.utils.shuffle`` using ``random_seed`` and yields
    ``((seqs, seqs_rc), labs)``.
    """

    def __init__(
        self,
        input_seqs,
        input_seqs_rc,
        input_labs,
        batches,
        random_seed=1
    ):
        self.input_seqs, self.input_seqs_rc, self.input_labs = (
            input_seqs,
            input_seqs_rc,
            input_labs,
        )
        self.batches = batches
        self.random_seed = random_seed

    def __call__(self):
        """Yield one ``((seqs, seqs_rc), labs)`` tuple per entry of ``batches``."""
        for indices in self.batches:
            ordered = sorted(indices)
            # Gather all three arrays with the same ascending row list so the
            # joint shuffle below keeps them aligned.
            gathered = [
                np.array(source[ordered, ...])
                for source in (self.input_seqs, self.input_seqs_rc, self.input_labs)
            ]
            fwd, rev, labels = shuffle(*gathered, random_state=self.random_seed)
            yield (fwd, rev), labels