Mercurial > repos > iuc > decontaminator
comparison utils/batch_loader.py @ 0:b856d3d95413 draft default tip
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/decontaminator commit 3f8e87001f3dfe7d005d0765aeaa930225c93b72
| author | iuc |
|---|---|
| date | Mon, 09 Jan 2023 13:27:09 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:b856d3d95413 |
|---|---|
| 1 #!/usr/bin/env python | |
| 2 # -*- coding: utf-8 -*- | |
| 3 # Credits: Grigorii Sukhorukov, Macha Nikolski | |
| 4 import numpy as np | |
| 5 from sklearn.utils import shuffle | |
| 6 from tensorflow import keras | |
| 7 | |
| 8 | |
class BatchLoader(keras.utils.Sequence):
    """Keras ``Sequence`` serving pre-defined batches as NumPy arrays.

    Each element of ``batches`` is a collection of row indices.  Item ``idx``
    of the loader gathers those rows from the three input containers, turns
    them into NumPy arrays and shuffles them with one shared permutation so
    that sequences and labels stay aligned.

    Args:
        input_seqs: Indexable container of input sequences; it must support
            ``container[list_of_indices, ...]``.
        input_seqs_rc: Container indexed in parallel with ``input_seqs``
            (presumably the reverse-complement sequences -- confirm with
            the caller).
        input_labs: Container of labels, indexed in parallel with
            ``input_seqs``.
        batches: Sequence whose items are iterables of row indices; its
            length is the number of batches served.
        rc: When true, ``__getitem__`` returns both sequence arrays as a
            pair of model inputs; otherwise only ``input_seqs`` is returned.
        random_seed: Value passed as ``random_state`` to
            ``sklearn.utils.shuffle`` for every batch.
    """

    def __init__(
        self,
        input_seqs,
        input_seqs_rc,
        input_labs,
        batches,
        rc=True,
        random_seed=1
    ):
        # Newer Keras releases warn when a Sequence subclass skips the
        # base-class constructor (it sets up worker/queue settings there);
        # on older releases this resolves to a harmless no-op.
        super().__init__()
        self.input_seqs = input_seqs
        self.input_seqs_rc = input_seqs_rc
        self.input_labs = input_labs
        self.batches = batches
        self.rc = rc
        self.random_seed = random_seed

    def __len__(self):
        """Return the number of batches."""
        return len(self.batches)

    def __getitem__(self, idx):
        """Return batch number ``idx``.

        Returns:
            ``((seqs, seqs_rc), labs)`` when ``self.rc`` is true, otherwise
            ``(seqs, labs)``.
        """
        # Indices are put in increasing order before the fancy-indexing
        # read (presumably required by the backing store, e.g. an HDF5
        # dataset -- confirm); the shuffle below re-mixes the rows.
        batch = sorted(self.batches[idx])
        # A single shuffle call applies the same permutation to all three
        # arrays.  The seed is fixed, so the call is reproducible.
        batch_seqs, batch_seqs_rc, batch_labs = shuffle(
            np.array(self.input_seqs[batch, ...]),
            np.array(self.input_seqs_rc[batch, ...]),
            np.array(self.input_labs[batch, ...]),
            random_state=self.random_seed
        )
        if self.rc:
            return (batch_seqs, batch_seqs_rc), batch_labs
        return batch_seqs, batch_labs
| 46 | |
| 47 | |
class BatchGenerator:
    """Callable producing a generator over pre-defined batches.

    Calling an instance walks ``batches`` in order.  For every batch the
    selected rows of the three input containers are converted to NumPy
    arrays, shuffled with one shared permutation, and yielded as
    ``((seqs, seqs_rc), labs)``.

    Args:
        input_seqs: Indexable container of input sequences; it must support
            ``container[list_of_indices, ...]``.
        input_seqs_rc: Container indexed in parallel with ``input_seqs``
            (presumably the reverse-complement sequences -- confirm with
            the caller).
        input_labs: Container of labels, indexed in parallel with
            ``input_seqs``.
        batches: Iterable whose items are iterables of row indices.
        random_seed: Value passed as ``random_state`` to
            ``sklearn.utils.shuffle`` for every batch.
    """

    def __init__(self, input_seqs, input_seqs_rc, input_labs, batches,
                 random_seed=1):
        self.input_seqs = input_seqs
        self.input_seqs_rc = input_seqs_rc
        self.input_labs = input_labs
        self.batches = batches
        self.random_seed = random_seed

    def __call__(self):
        """Yield ``((seqs, seqs_rc), labs)`` for each entry of ``batches``."""
        for indices in self.batches:
            # Rows are read in increasing index order; the shuffle below
            # then permutes all three arrays identically.
            ordered = sorted(indices)
            arrays = [
                np.array(source[ordered, ...])
                for source in (
                    self.input_seqs,
                    self.input_seqs_rc,
                    self.input_labs,
                )
            ]
            seqs, seqs_rc, labs = shuffle(
                *arrays, random_state=self.random_seed
            )
            yield (seqs, seqs_rc), labs
