comparison utils/preprocess.py @ 0:457fd8fd681a draft

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
author iuc
date Wed, 09 Nov 2022 12:19:26 +0000
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Credits: Grigorii Sukhorukov, Macha Nikolski
4 import math
5 import os
6 import pathlib
7 import random
8
9 import h5py
10 import numpy as np
11 from Bio import SeqIO
12 from Bio.Seq import Seq
13 from Bio.SeqRecord import SeqRecord
14 from sklearn.utils import shuffle
15
16
17 def reverse_complement(fragment):
18 """
19 provides the reverse complement of a DNA fragment
20 Input:
21 fragment - string of nucleotides
22 Output:
23 fragment - string with the reverse complement;
24 characters outside "ACGT" are left unchanged
25 """
26 # complementary_sequences = []
27 # for sequence in sequences:
28 # complementary_sequence = SeqRecord(
29 # seq=Seq(sequence.seq).reverse_complement(),
30 # id=sequence.id + "_reverse_complement",
31 # )
32 # complementary_sequences.append(complementary_sequence)
33 fragment = fragment[::-1].translate(str.maketrans('ACGT', 'TGCA'))
34 return fragment
35
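# Illustrative example: reverse_complement("AACGT") returns "ACGTT"; characters
# outside "ACGT" (e.g. "-" or "N") are only reversed in position, since the
# translation table maps the four unambiguous bases only.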
36
37 def introduce_mutations(seqs, mut_rate, rs=None):
38 """
39 Mutates sequences from the input fasta file:
40 a proportion of nucleotides is changed to other nucleotides.
41 Gaps ('-') are not yet taken into account.
42 mut_rate - proportion from 0.0 to 1.0, float
43 """
44 random.seed(a=rs)
45 assert 0.0 <= mut_rate <= 1.0
46 mutated_seqs = []
47 for seq in seqs:
48 mut_seq = list(str(seq.seq))
49 l_ = len(mut_seq)
50 mutated_sites_i = random.sample(range(l_), int(mut_rate * l_))
51 for mut_site_i in mutated_sites_i:
52 mut_site = mut_seq[mut_site_i]
53 mutations = ["A", "C", "T", "G"]
54 if mut_site in mutations:
55 mutations.remove(mut_site)
56 mut_seq[mut_site_i] = random.sample(mutations, 1)[0]
57 mutated_seq = SeqRecord(
58 seq=Seq("".join(mut_seq)),
59 id=seq.id + f"_mut_{mut_rate}",
60 name="",
61 description="",
62 )
63 mutated_seqs.append(mutated_seq)
64 return mutated_seqs
65
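# Illustrative example: with mut_rate=0.1 and a 100 nt sequence, int(0.1 * 100) = 10
# distinct positions are drawn and each selected base is replaced by one of the
# three other bases (ambiguous characters are replaced by any of the four).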
66
67 def separate_by_length(length_, seq_list, fold=None,):
68 """Splits seq_list by length_: sequences shorter than length_ are dropped (counted as excluded); if fold is set, sequences of length >= length_ * fold go to to_process, the rest to included."""
69 included = []
70 to_process = []
71 excluded = 0
72 for seq_ in seq_list:
73 l_ = len(seq_.seq)
74 if l_ >= length_:
75 if fold is None:
76 included.append(seq_)
77 elif l_ < length_ * fold:
78 included.append(seq_)
79 else:
80 to_process.append(seq_)
81 else:
82 excluded += 1
83 print(f"A total of {excluded} sequences was excluded due to being smaller than {length_}")
84 return included, to_process
85
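# Illustrative example: with length_=1000 and fold=5, a 300 nt record is excluded,
# a 3000 nt record goes to `included`, and a 6000 nt record (>= 1000 * 5) goes to
# `to_process` for further handling by the caller.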
86
87 def chunks(lst, n):
88 """Yield successive n-sized chunks from lst.
89 https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks"""
90 for i in range(0, len(lst), n):
91 yield lst[i:i + n]
92
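# Illustrative example: list(chunks([1, 2, 3, 4, 5], 2)) == [[1, 2], [3, 4], [5]]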
93
94 def correct(frag):
95 """
96 leaves only unambiguous DNA code (ACTG-)
97 Input:
98 frag - string of nucleotides
99 Output:
100 pr_frag - corrected string of nucleotides
101 """
102 pr_frag = frag.upper()
103 pr_frag_s = set(pr_frag)
104 if not pr_frag_s.issubset({"A", "C", "G", "T", "-"}):
105 for letter in pr_frag_s - {"A", "C", "G", "T", "-"}:
106 pr_frag = pr_frag.replace(letter, "-")
107 return pr_frag
108
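# Illustrative example: correct("acgtNacgt") == "ACGT-ACGT": the fragment is
# upper-cased and every character outside "ACGT-" (here "N") is replaced by "-".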
109
110 def fragmenting(sequences, sl_wind_size, max_gap=0.05, sl_wind_step=None):
111 """
112 slices sequences in fragments by sliding window
113 based on its size and step.
114 last fragment is padded by '-'
115 fragments have ambiguous bases replaced by '-'
116 fragments with many '-' are discarded
117 Input:
118 sequences - list with SeqRecord sequences in fasta format
119 max_gap - max allowed proportion of '-'
120 sl_wind_size - sliding window size
121 sl_wind_step - sliding window step, by default equals
122 sliding window size (None is replaced by it)
123 Output:
124 fragments - list with sequence fragments
125 """
126 if sl_wind_step is None:
127 sl_wind_step = sl_wind_size
128 fragments = []
129 fragments_rc = []
130 out_sequences = []
131 for sequence in sequences:
132 seq = str(sequence.seq)
133 n_fragments = 1 + max(0, math.ceil((len(seq) - sl_wind_size) / sl_wind_step))
134 for n in range(n_fragments):
135 if n + 1 != n_fragments:
136 frag = seq[n * sl_wind_step: n * sl_wind_step + sl_wind_size]
137 elif n_fragments == 1:
138 # padding the shorter fragment to sl_wind_size
139 frag_short = seq[n * sl_wind_step: n * sl_wind_step + sl_wind_size]
140 frag = frag_short + (sl_wind_size - len(frag_short)) * "-"
141 else:
142 frag = seq[(len(seq) - sl_wind_size):]
143 # replace ambiguous characters
144 frag = correct(frag)
145 assert len(frag) == sl_wind_size, f"{len(frag)} vs {sl_wind_size}"
146 # skipping sequences with many gaps
147 if frag.count("-") / sl_wind_size <= max_gap:
148 fragments.append(frag)
149 # generating reverse complement
150 fragments_rc.append(reverse_complement(frag))
151 fr_seq = SeqRecord(
152 seq=Seq(frag),
153 id=f"{sequence.id}_{n*sl_wind_step}_{sl_wind_size}",
154 name="",
155 description="",
156 )
157 out_sequences.append(fr_seq)
158 return fragments, fragments_rc, out_sequences
159
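# Worked example: for a 2500 nt sequence with sl_wind_size=1000 and sl_wind_step=500,
# n_fragments = 1 + ceil((2500 - 1000) / 500) = 4; fragments start at 0, 500 and 1000,
# and the last one covers the final 1000 nt (seq[1500:2500]).
# A 700 nt sequence yields a single fragment padded with 300 "-" characters, which is
# then discarded whenever max_gap < 300/1000.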
160
161 def label_fasta_fragments(sequences, label):
162 """
163 Provides labels to generated fragments stored in fasta
164 Input:
165 sequences - list with SeqRecord sequences
166 label - type of label (bacteria, virus, plant)
167 Output:
168 labeled_fragments - list with labeled SeqRecord sequences
169 """
170 assert label in ["virus", "plant", "bacteria"]
171 labeled_fragments = []
172 for sequence in sequences:
173 sequence.id = sequence.id + f"_{label}"
174 labeled_fragments.append(sequence)
175 return labeled_fragments
176
177
178 def one_hot_encode(fragments):
179 """
180 produces one-hot matrices from fragments
181 '-' is given all zeros
182 Input:
183 fragments - list with sequence fragments
184 Output:
185 encoded_fragments - array with one-hot encoded fragments
189 """
190 import tensorflow as tf
191 encoded_fragments = []
192 map_dict = {"A": 0, "C": 1, "G": 2, "T": 3, "-": -1}
193 for frag in fragments:
194 frag_array = np.array(list(frag))
195 integer_encoded = np.int8(np.vectorize(map_dict.get)(frag_array))
196 one_hot_encoded = tf.one_hot(integer_encoded, depth=4, dtype=tf.int8).numpy()
197 encoded_fragments.append(one_hot_encoded)
198 encoded_fragments = np.stack(encoded_fragments)
199 return encoded_fragments
200
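# Illustrative example: "ACGT-" is mapped to the indices [0, 1, 2, 3, -1], and tf.one_hot
# turns them into [[1,0,0,0], [0,1,0,0], [0,0,1,0], [0,0,0,1], [0,0,0,0]]; the
# out-of-range index -1 produces an all-zero row, which is how gaps are encoded.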
201
202 def prepare_labels(fragments, label, label_depth):
203 """
204 produces one-hot encoded labels for a set of fragments
205 Input:
206 fragments - list with sequence fragments (only their number is used)
207 label - integer class label (0 <= label < label_depth)
208 label_depth - number of possible labels
209 Output:
210 labels - array with one-hot encoded labels
212 """
213 import tensorflow as tf
214 n_fragments = len(fragments)
215 labels = np.int8(np.full(n_fragments, label))
216 labels = tf.one_hot(labels, depth=label_depth).numpy()
217 return labels
218
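# Illustrative example: prepare_labels(fragments, label=1, label_depth=3) returns one
# [0, 1, 0] row per fragment, i.e. an array of shape (len(fragments), 3).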
219
220 # TODO: write docs for functions
221 def calculate_total_length(seq_path):
222 """
223 Calculate total length of the sequences in the fasta file.
224 Needed for weighted sampling
225 Input:
226 seq_path - path to the file with sequences
227 Output:
228 seq_length - total length of all sequences in the file
229 """
230 seqs = list(SeqIO.parse(seq_path, "fasta"))
231 seq_length = 0
232 for seq in seqs:
233 seq_length += len(seq.seq)
234 return seq_length
235
236
237 def prepare_seq_lists(in_paths, n_fragments, weights=None,):
238 """
239 selects files with sequences based on extension
240 and calculates number of fragments to be sampled
241 Input:
242 in_paths - list of paths to folders with sequence files; can also be a single string or Path
243 n_fragments - number of fragments to be sampled
244 weights - per-path sampling weights for upsampling; fractions should sum to one
245 Output:
246 seqs_list - list with path to files with sequences
247 n_fragments_list - number of fragments to be sampled
248 """
249 # case when we receive a single sequence file
250 if type(in_paths) is str and in_paths.endswith(('.fna', '.fasta')):
251 return [[in_paths, n_fragments]]
252 else:
253 # transform string to list
254 if isinstance(in_paths, (str, pathlib.Path)):
255 in_paths = [in_paths]
256
257 if weights:
258 assert len(weights) == len(in_paths)
259 assert 1.01 > round(sum(weights), 2) > 0.99
260 else:
261 l_ = len(in_paths)
262 weights = [1 / l_] * l_
263 n_fragments_list_all = []
264 seqs_list_all = []
265 for in_path, w_ in zip(in_paths, weights):
266 seqs_list = []
267 seq_length_list = []
268 total_length = 0
269 for file in os.listdir(in_path):
270 if file.endswith("fna") or file.endswith("fasta"):
271 seq_path = os.path.join(in_path, file)
272 seqs_length = calculate_total_length(seq_path)
273 seqs_list.append(seq_path)
274 seq_length_list.append(seqs_length)
275 total_length += seqs_length
276 # + 1 may lead to a slightly bigger number than desired
277 n_fragments_list = [((seq_length / total_length) * n_fragments * w_ + 1) for seq_length in seq_length_list]
278 n_fragments_list_all.extend(n_fragments_list)
279 seqs_list_all.extend(seqs_list)
280 print("list calculation done")
281 return list(zip(seqs_list_all, n_fragments_list_all))
282
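# Worked example: with two folders weighted [0.7, 0.3] and n_fragments=1000, a file
# holding half of the first folder's total sequence length is assigned
# 0.5 * 1000 * 0.7 + 1 = 351 fragments; the "+ 1" applied to every file is why the
# grand total can slightly exceed n_fragments.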
283
284 def sample_fragments(seq_container, length, random_seed=1, limit=None, max_gap=0.05, sl_wind_step=None):
285 """
286 Randomly samples fragments from sequences in the list.
287 Input:
288 seq_container - list with each entry containing path to sequence,
289 and n samples from this sequence.
290 length - desired length of sampled fragments
291 Output:
292 fragments - list with sequence fragments
293 """
294 random.seed(a=random_seed)
295 total_fragments = []
296 total_fragments_rc = []
297 total_seqs = []
298 for entry in seq_container:
299 seq = list(SeqIO.parse(entry[0], "fasta"))
300 n_fragments = entry[1]
301 seqs = []
302 fragments = []
303 fragments_rc = []
304 counter_1 = 0
305 counter_2 = 0
306 while counter_1 < n_fragments:
307 # pick a random record (e.g. a chromosome) from the multi-record fasta
308 fragment_full = random.choice(seq)
309 r_end = len(fragment_full.seq) - length
310 try:
311 r_start = random.randrange(r_end)
312 fragment = SeqRecord(
313 seq=fragment_full.seq[r_start:(r_start + length)],
314 id=f"{fragment_full.id}_{length}_{r_start}",
315 name="",
316 description="",
317 )
318 temp_, temp_rc, _ = fragmenting([fragment], length, max_gap, sl_wind_step=sl_wind_step)
319 if temp_ and temp_rc:
320 seqs.append(fragment)
321 fragments.extend(temp_)
322 fragments_rc.extend(temp_rc)
323 counter_1 += 1
324 except ValueError:
325 # print(f"{fragment_full.id} has length {len(fragment_full.seq)} and is too short to be sampled")
326 pass
327 counter_2 += 1
328 if limit:
329 assert counter_2 <= limit * n_fragments, f"The sampling loop iterated more than {limit} times per requested fragment; the data is ambiguous." \
330 f" Only {len(fragments)} fragments were sampled out of {n_fragments}"
331 total_fragments.extend(fragments)
332 total_fragments_rc.extend(fragments_rc)
333 total_seqs.extend(seqs)
334 # print("sequence sampling done")
335 return total_fragments, total_fragments_rc, total_seqs
336
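# Note: records shorter than `length` make random.randrange(r_end) raise ValueError
# (r_end <= 0) and are simply skipped, as are fragments rejected for too many gaps;
# when `limit` is set, the assertion aborts after limit * n_fragments attempts so the
# while loop cannot run forever on unsuitable input.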
337
338 def prepare_ds_fragmenting(in_seq, label, label_int, fragment_length, sl_wind_step, max_gap=0.05, n_cpus=1):
339 if sl_wind_step is None:
340 sl_wind_step = int(fragment_length / 2)
341 # generating fragments and labels by sliding-window fragmenting
342 seqs = list(SeqIO.parse(in_seq, "fasta"))
343 frags, frags_rc, seqs_ = fragmenting(seqs, fragment_length, max_gap=max_gap, sl_wind_step=sl_wind_step)
344 encoded = one_hot_encode(frags)
345 encoded_rc = one_hot_encode(frags_rc)
346 labs = prepare_labels(frags, label=label_int, label_depth=3)
347 seqs_ = label_fasta_fragments(seqs_, label=label)
348 # subsetting to unique fragments
349 u_encoded, indices = np.unique(encoded, axis=0, return_index=True)
350 u_encoded_rc = encoded_rc[indices]
351 u_labs = labs[indices]
352 u_seqs = [seqs_[i] for i in indices]
353 assert (np.shape(u_encoded)[0] == np.shape(u_encoded_rc)[0])
354 print(f"Encoding {label} sequences finished")
355 # print(f"{np.shape(u_encoded)[0]} forward fragments generated")
356 n_frags = np.shape(u_encoded)[0]
357 return u_encoded, u_encoded_rc, u_labs, u_seqs, n_frags
358
359
360 def prepare_ds_sampling(in_seqs, fragment_length, n_frags, label, label_int, random_seed, n_cpus=1, limit=100):
361 # sampling fragments and generating labels
362 seqs_list = prepare_seq_lists(in_seqs, n_frags)
363 frags, frags_rc, seqs_ = sample_fragments(seqs_list, fragment_length, random_seed, limit=limit, max_gap=0.05)
364 frags, frags_rc, seqs_ = shuffle(frags, frags_rc, seqs_, random_state=random_seed, n_samples=int(n_frags))
365 encoded = one_hot_encode(frags)
366 encoded_rc = one_hot_encode(frags_rc)
367 labs = prepare_labels(frags, label=label_int, label_depth=3)
368 seqs_ = label_fasta_fragments(seqs_, label=label)
369 assert (np.shape(encoded)[0] == np.shape(encoded_rc)[0])
370 print(f"Encoding {label} sequences finished")
371 # print(f"{np.shape(encoded)[0]} forward fragments generated")
372 return encoded, encoded_rc, labs, seqs_, n_frags
373
374
375 def storing_encoded(encoded, encoded_rc, labs, out_path, ):
376 f = h5py.File(out_path, "w")
377 f.create_dataset("fragments", data=encoded)
378 f.create_dataset("fragments_rc", data=encoded_rc)
379 f.create_dataset("labels", data=labs)
380 f.close()
381 print(f"encoded fragments and labels stored in {out_path}")
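# A dataset written by storing_encoded() can be read back with h5py; a minimal sketch,
# assuming `out_path` points to a file produced by the function above:
#
# with h5py.File(out_path, "r") as f:
#     fragments = f["fragments"][:]        # (n_frags, fragment_length, 4) one-hot array
#     fragments_rc = f["fragments_rc"][:]  # reverse-complement encodings, same shape
#     labels = f["labels"][:]              # (n_frags, label_depth) one-hot labels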