Mercurial > repos > iuc > virhunter
diff utils/batch_loader.py @ 0:457fd8fd681a draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
author | iuc |
---|---|
date | Wed, 09 Nov 2022 12:19:26 +0000 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/utils/batch_loader.py Wed Nov 09 12:19:26 2022 +0000 @@ -0,0 +1,73 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# Credits: Grigorii Sukhorukov, Macha Nikolski +import numpy as np +from sklearn.utils import shuffle +from tensorflow import keras + + +class BatchLoader(keras.utils.Sequence): + """Helper to iterate over the data (as Numpy arrays).""" + def __init__( + self, + input_seqs, + input_seqs_rc, + input_labs, + batches, + rc=True, + random_seed=1 + ): + self.input_seqs = input_seqs + self.input_seqs_rc = input_seqs_rc + self.input_labs = input_labs + self.batches = batches + self.rc = rc + self.random_seed = random_seed + + def __len__(self): + return len(self.batches) + + def __getitem__(self, idx): + batch = sorted(self.batches[idx]) + batch_seqs, batch_seqs_rc, batch_labs = shuffle( + np.array(self.input_seqs[batch, ...]), + np.array(self.input_seqs_rc[batch, ...]), + np.array(self.input_labs[batch, ...]), + random_state=self.random_seed + ) + # adding reverse batches + # batch_seqs = np.concatenate((batch_seqs, batch_seqs[:, ::-1, ...])) + # batch_seqs_rc = np.concatenate((batch_seqs_rc, batch_seqs_rc[:, ::-1, ...])) + # batch_labs = np.concatenate((batch_labs, batch_labs[:, ::-1, ...])) + if self.rc: + return (batch_seqs, batch_seqs_rc), batch_labs + else: + return batch_seqs, batch_labs + + +class BatchGenerator: + """Helper to iterate over the data (as Numpy arrays).""" + def __init__( + self, + input_seqs, + input_seqs_rc, + input_labs, + batches, + random_seed=1 + ): + self.input_seqs = input_seqs + self.input_seqs_rc = input_seqs_rc + self.input_labs = input_labs + self.batches = batches + self.random_seed = random_seed + + def __call__(self): + for batch in self.batches: + batch = sorted(batch) + batch_seqs, batch_seqs_rc, batch_labs = shuffle( + np.array(self.input_seqs[batch, ...]), + np.array(self.input_seqs_rc[batch, ...]), + np.array(self.input_labs[batch, ...]), + random_state=self.random_seed + ) + yield (batch_seqs, batch_seqs_rc), batch_labs