Mercurial > repos > iuc > virhunter
comparison utils/preprocess.py @ 0:457fd8fd681a draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/VirHunter commit 628688c1302dbf972e48806d2a5bafe27847bdcc
author | iuc |
---|---|
date | Wed, 09 Nov 2022 12:19:26 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:457fd8fd681a |
---|---|
1 #!/usr/bin/env python | |
2 # -*- coding: utf-8 -*- | |
3 # Credits: Grigorii Sukhorukov, Macha Nikolski | |
4 import math | |
5 import os | |
6 import pathlib | |
7 import random | |
8 | |
9 import h5py | |
10 import numpy as np | |
11 from Bio import SeqIO | |
12 from Bio.Seq import Seq | |
13 from Bio.SeqRecord import SeqRecord | |
14 from sklearn.utils import shuffle | |
15 | |
16 | |
def reverse_complement(fragment):
    """
    returns the reverse complement of a nucleotide string
    Input:
    fragment - string of nucleotides
    Output:
    string of the same length read in reverse order, with
    A<->T and C<->G swapped; every other character
    (e.g. '-' or lowercase letters) is kept unchanged
    """
    complement_table = str.maketrans('ACGT', 'TGCA')
    reversed_fragment = fragment[::-1]
    return reversed_fragment.translate(complement_table)
35 | |
36 | |
def introduce_mutations(seqs, mut_rate, rs=None):
    """
    Builds mutated copies of the given sequence records.
    For every record, int(mut_rate * length) distinct positions are drawn
    and each is replaced by a nucleotide different from the current one
    (a position holding anything other than A/C/T/G gets any of the four).
    Gaps are not treated specially.
    Input:
    seqs - iterable of records exposing .seq and .id
    mut_rate - proportion of positions to mutate, float from 0.0 to 1.0
    rs - seed passed to random.seed (the module-level generator is reseeded)
    Output:
    list of new SeqRecord objects, id is the original id + "mut_<mut_rate>"
    """
    # seeding happens before validation, same order as callers rely on
    random.seed(a=rs)
    assert 0.0 <= mut_rate <= 1.0
    nucleotides = ("A", "C", "T", "G")
    result = []
    for record in seqs:
        bases = list(str(record.seq))
        n_bases = len(bases)
        n_sites = int(mut_rate * n_bases)
        for position in random.sample(range(n_bases), n_sites):
            current = bases[position]
            # candidates keep the A, C, T, G order minus the current base
            candidates = [nt for nt in nucleotides if nt != current]
            bases[position] = random.sample(candidates, 1)[0]
        result.append(
            SeqRecord(
                seq=Seq("".join(bases)),
                id=record.id + f"mut_{mut_rate}",
                name="",
                description="",
            )
        )
    return result
65 | |
66 | |
def separate_by_length(length_, seq_list, fold=None,):
    """
    Splits records into two lists according to sequence length.
    Input:
    length_ - minimal length; shorter records are dropped (only counted)
    seq_list - iterable of records exposing .seq
    fold - optional multiplier; when given, records with
    length >= length_ * fold go to the second list
    Output:
    included - records with length_ <= length (< length_ * fold if fold is set)
    to_process - records with length >= length_ * fold (empty if fold is None)
    """
    included = []
    to_process = []
    n_too_short = 0
    for record in seq_list:
        size = len(record.seq)
        if size < length_:
            n_too_short += 1
            continue
        if fold is None or size < length_ * fold:
            included.append(record)
        else:
            to_process.append(record)
    print(f"A total of {n_too_short} sequences was excluded due to being smaller than {length_}")
    return included, to_process
85 | |
86 | |
def chunks(lst, n):
    """Yield consecutive slices of lst, each holding at most n items.
    The last slice is shorter when len(lst) is not a multiple of n.
    https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks"""
    total = len(lst)
    for start in range(0, total, n):
        stop = start + n
        yield lst[start:stop]
92 | |
93 | |
def correct(frag):
    """
    upper-cases the fragment and masks everything that is not
    unambiguous DNA code (ACTG-) with '-'
    Input:
    frag - string of nucleotides
    Output:
    string where each character outside A, C, G, T, '-' became '-'
    """
    allowed = {"A", "C", "G", "T", "-"}
    upper_frag = frag.upper()
    return "".join(base if base in allowed else "-" for base in upper_frag)
108 | |
109 | |
def fragmenting(sequences, sl_wind_size, max_gap=0.05, sl_wind_step=None):
    """
    cuts sequences into fixed-size fragments with a sliding window.
    A sequence shorter than the window yields one fragment padded with '-'.
    For longer sequences the last fragment is taken from the end of the
    sequence (it may overlap the previous one); its id still carries the
    nominal offset n * step.
    Ambiguous bases are replaced by '-' and fragments whose share
    of '-' exceeds max_gap are discarded.
    Input:
    sequences - list with SeqRecord sequences in fasta format
    sl_wind_size - sliding window size (length of every fragment)
    max_gap - max allowed proportion of '-'
    sl_wind_step - sliding window step, by default equals
    sliding window size (None is replaced by it)
    Output:
    fragments - list with sequence fragments (strings)
    fragments_rc - list with their reverse complements
    out_sequences - list with SeqRecord per kept fragment,
    id is "<sequence id>_<offset>_<window size>"
    """
    step = sl_wind_size if sl_wind_step is None else sl_wind_step
    fragments, fragments_rc, out_sequences = [], [], []
    for record in sequences:
        full_seq = str(record.seq)
        seq_len = len(full_seq)
        n_fragments = 1 + max(0, math.ceil((seq_len - sl_wind_size) / step))
        last_index = n_fragments - 1
        for n in range(n_fragments):
            offset = n * step
            if n != last_index:
                frag = full_seq[offset: offset + sl_wind_size]
            elif n_fragments == 1:
                # single window: pad the (possibly shorter) piece up to sl_wind_size
                piece = full_seq[offset: offset + sl_wind_size]
                frag = piece + (sl_wind_size - len(piece)) * "-"
            else:
                # final window is anchored to the end of the sequence
                frag = full_seq[(seq_len - sl_wind_size):]
            # replace ambiguous characters
            frag = correct(frag)
            assert len(frag) == sl_wind_size, f"{len(frag)} vs {sl_wind_size}"
            # keep only fragments with an acceptable share of gaps
            if frag.count("-") / sl_wind_size <= max_gap:
                fragments.append(frag)
                fragments_rc.append(reverse_complement(frag))
                out_sequences.append(
                    SeqRecord(
                        seq=Seq(frag),
                        id=f"{record.id}_{offset}_{sl_wind_size}",
                        name="",
                        description="",
                    )
                )
    return fragments, fragments_rc, out_sequences
159 | |
160 | |
def label_fasta_fragments(sequences, label):
    """
    Appends "_<label>" to the id of every record.
    The records are modified in place and also returned in a new list.
    Input:
    sequences - list with SeqRecord sequences
    label - type of label (bacteria, virus, plant)
    Output:
    labeled_fragments - list with the same, now labeled, SeqRecord sequences
    """
    assert label in ["virus", "plant", "bacteria"]
    suffix = f"_{label}"
    labeled_fragments = []
    for record in sequences:
        record.id = record.id + suffix
        labeled_fragments.append(record)
    return labeled_fragments
176 | |
177 | |
def one_hot_encode(fragments):
    """
    produces one-hot matrices from fragments
    channels are ordered A, C, G, T; '-' is given all zeros
    Input:
    fragments - non-empty list with sequence fragments of equal length,
    containing only the characters A, C, G, T and '-'
    Output:
    encoded_fragments - int8 numpy array of shape
    (number of fragments, fragment length, 4)
    Raises:
    ValueError - on a character outside A, C, G, T, '-',
    or (from np.stack) when fragments is empty
    """
    map_dict = {"A": 0, "C": 1, "G": 2, "T": 3, "-": -1}
    # rows 0-3 are the one-hot vectors; the extra last row is all zeros,
    # so index -1 (used for '-') selects the zero vector
    lookup = np.zeros((5, 4), dtype=np.int8)
    lookup[:4] = np.eye(4, dtype=np.int8)
    encoded_fragments = []
    for frag in fragments:
        try:
            integer_encoded = np.array([map_dict[base] for base in frag], dtype=np.int8)
        except KeyError as err:
            raise ValueError(
                f"unexpected character {err.args[0]!r} in fragment; "
                "expected only A, C, G, T or '-'"
            ) from err
        encoded_fragments.append(lookup[integer_encoded])
    return np.stack(encoded_fragments)
200 | |
201 | |
def prepare_labels(fragments, label, label_depth):
    """
    produces one-hot labels, one identical row per fragment
    Input:
    fragments - list with sequence fragments (only its length is used)
    label - index of the class, 0 <= label < label_depth;
    a label outside this range yields all-zero rows
    label_depth - number of possible labels
    Output:
    labels - float32 numpy array of shape (number of fragments, label_depth)
    """
    n_fragments = len(fragments)
    labels = np.zeros((n_fragments, label_depth), dtype=np.float32)
    label_index = int(label)
    # out-of-range labels are left as zero rows instead of raising
    if 0 <= label_index < label_depth:
        labels[:, label_index] = 1.0
    return labels
218 | |
219 | |
220 # TODO: write docs for functions | |
def calculate_total_length(seq_path):
    """
    Sums the lengths of all records of a fasta file.
    Needed for weighted sampling
    Input:
    seq_path - path to the file with sequences
    Output:
    seq_length - total length of all sequences in the file (0 if no records)
    """
    records = SeqIO.parse(seq_path, "fasta")
    return sum(len(record.seq) for record in records)
235 | |
236 | |
def prepare_seq_lists(in_paths, n_fragments, weights=None,):
    """
    selects files with sequences based on extension
    and calculates number of fragments to be sampled
    Input:
    in_paths - list of paths to folders with sequence files.
    Can also be a single folder (string or path-like object)
    or a string path to a single .fna/.fasta file
    n_fragments - number of fragments to be sampled
    weights - upsampling of fragments, one fraction per folder.
    fractions should sum to one; folders are weighted equally if omitted
    Output:
    list of [path, n_fragments] pairs: for a single file the pair is
    returned as is; otherwise every fasta file of every folder gets a
    (possibly fractional) number of fragments proportional to its
    share of the folder's total sequence length, times the folder weight
    """
    # case when we receive a single sequence file
    if isinstance(in_paths, str) and in_paths.endswith(('.fna', '.fasta')):
        return [[in_paths, n_fragments]]

    # a single folder (string or any path-like object) becomes a one-item list
    if isinstance(in_paths, (str, os.PathLike)):
        in_paths = [in_paths]

    if weights:
        assert len(weights) == len(in_paths)
        assert 1.01 > round(sum(weights), 2) > 0.99
    else:
        n_folders = len(in_paths)
        # guard keeps an empty folder list from dividing by zero
        weights = [1 / n_folders] * n_folders if n_folders else []

    n_fragments_list_all = []
    seqs_list_all = []
    for folder, w_ in zip(in_paths, weights):
        seqs_list = []
        seq_length_list = []
        total_length = 0
        for file in os.listdir(folder):
            if file.endswith(("fna", "fasta")):
                seq_path = os.path.join(folder, file)
                seqs_length = calculate_total_length(seq_path)
                seqs_list.append(seq_path)
                seq_length_list.append(seqs_length)
                total_length += seqs_length
        # + 1 may lead to a slightly bigger number than desired
        n_fragments_list = [
            ((seq_length / total_length) * n_fragments * w_ + 1)
            for seq_length in seq_length_list
        ]
        n_fragments_list_all.extend(n_fragments_list)
        seqs_list_all.extend(seqs_list)
    print("list calculation done")
    return list(zip(seqs_list_all, n_fragments_list_all))
282 | |
283 | |
def sample_fragments(seq_container, length, random_seed=1, limit=None, max_gap=0.05, sl_wind_step=None):
    """
    Randomly samples fragments from sequences in the list.
    The module-level random generator is reseeded with random_seed.
    For every entry, records of the file are picked at random and a window
    of the given length is cut at a random start; the window is kept only
    if fragmenting() returns it (i.e. it passes the max_gap filter).
    Records too short to be sampled are silently retried.
    Input:
    seq_container - list with each entry containing path to sequence,
    and n samples from this sequence.
    length - desired length of sampled fragments
    random_seed - seed for the random module
    limit - if set, at most limit * n attempts are allowed per entry,
    otherwise an AssertionError is raised
    max_gap - max allowed proportion of '-', passed to fragmenting()
    sl_wind_step - passed to fragmenting()
    Output:
    total_fragments - list with sequence fragments
    total_fragments_rc - list with their reverse complements
    total_seqs - list with SeqRecord of every sampled window
    """
    random.seed(a=random_seed)
    total_fragments, total_fragments_rc, total_seqs = [], [], []
    for entry in seq_container:
        records = list(SeqIO.parse(entry[0], "fasta"))
        n_fragments = entry[1]
        sampled_seqs, sampled, sampled_rc = [], [], []
        n_accepted = 0
        n_attempts = 0
        while n_accepted < n_fragments:
            # select chromosomes if there are any
            source = random.choice(records)
            r_end = len(source.seq) - length
            try:
                # randrange raises ValueError when the record is not longer than length
                r_start = random.randrange(r_end)
                window = SeqRecord(
                    seq=source.seq[r_start:(r_start + length)],
                    id=f"{source.id}_{length}_{r_start}",
                    name="",
                    description="",
                )
                kept, kept_rc, _ = fragmenting([window], length, max_gap, sl_wind_step=sl_wind_step)
                if kept and kept_rc:
                    sampled_seqs.append(window)
                    sampled.extend(kept)
                    sampled_rc.extend(kept_rc)
                    n_accepted += 1
            except ValueError:
                pass
            n_attempts += 1
            if limit:
                assert n_attempts <= limit * n_fragments, (
                    f"While cycle iterated more than {limit}, data is ambiguous."
                    f" Only {len(sampled)} fragments were sampled out of {n_fragments}"
                )
        total_fragments.extend(sampled)
        total_fragments_rc.extend(sampled_rc)
        total_seqs.extend(sampled_seqs)
    return total_fragments, total_fragments_rc, total_seqs
336 | |
337 | |
def prepare_ds_fragmenting(in_seq, label, label_int, fragment_length, sl_wind_step, max_gap=0.05, n_cpus=1):
    """
    Builds an encoded dataset by cutting all sequences of a fasta file
    into sliding-window fragments and keeping unique fragments only.
    Input:
    in_seq - path to the fasta file
    label - text label appended to fragment ids (virus, plant, bacteria)
    label_int - integer class used for the one-hot labels (depth is 3)
    fragment_length - length of fragments
    sl_wind_step - sliding window step, None means half of fragment_length
    max_gap - max allowed proportion of '-' in a fragment
    n_cpus - not used in this function
    Output:
    u_encoded - one-hot encoded unique fragments (ordered as np.unique returns them)
    u_encoded_rc - encoded reverse complements of the same fragments
    u_labs - one-hot labels of the same fragments
    u_seqs - labeled SeqRecord of the same fragments
    n_frags - number of unique fragments
    """
    step = int(fragment_length / 2) if sl_wind_step is None else sl_wind_step
    # generating fragments and labels
    records = list(SeqIO.parse(in_seq, "fasta"))
    frags, frags_rc, frag_records = fragmenting(records, fragment_length, max_gap=max_gap, sl_wind_step=step)
    encoded = one_hot_encode(frags)
    encoded_rc = one_hot_encode(frags_rc)
    labs = prepare_labels(frags, label=label_int, label_depth=3)
    frag_records = label_fasta_fragments(frag_records, label=label)
    # subsetting to unique fragments; keep_idx points at one original row per unique fragment
    u_encoded, keep_idx = np.unique(encoded, axis=0, return_index=True)
    u_encoded_rc = encoded_rc[keep_idx]
    u_labs = labs[keep_idx]
    u_seqs = [frag_records[i] for i in keep_idx]
    assert (np.shape(u_encoded)[0] == np.shape(u_encoded_rc)[0])
    print(f"Encoding {label} sequences finished")
    n_frags = np.shape(u_encoded)[0]
    return u_encoded, u_encoded_rc, u_labs, u_seqs, n_frags
358 | |
359 | |
def prepare_ds_sampling(in_seqs, fragment_length, n_frags, label, label_int, random_seed, n_cpus=1, limit=100):
    """
    Builds an encoded dataset by randomly sampling fragments from fasta files.
    Input:
    in_seqs - folder(s) with fasta files or a single fasta file,
    passed to prepare_seq_lists
    fragment_length - length of sampled fragments
    n_frags - number of fragments to keep
    label - text label appended to fragment ids (virus, plant, bacteria)
    label_int - integer class used for the one-hot labels (depth is 3)
    random_seed - seed used for sampling and for shuffling
    n_cpus - not used in this function
    limit - passed to sample_fragments as the attempts multiplier
    Output:
    encoded - one-hot encoded fragments
    encoded_rc - encoded reverse complements of the same fragments
    labs - one-hot labels
    seqs_ - labeled SeqRecord of the sampled windows
    n_frags - the requested number of fragments, returned unchanged
    """
    # generating fragments and labels
    seqs_list = prepare_seq_lists(in_seqs, n_frags)
    # sample_fragments is a plain function in this module (no Ray decorator),
    # so it is called directly; ".remote" does not exist on it
    frags, frags_rc, seqs_ = sample_fragments(seqs_list, fragment_length, random_seed, limit=limit, max_gap=0.05)
    # shuffle the three lists consistently and cut them down to n_frags items
    frags, frags_rc, seqs_ = shuffle(frags, frags_rc, seqs_, random_state=random_seed, n_samples=int(n_frags))
    encoded = one_hot_encode(frags)
    encoded_rc = one_hot_encode(frags_rc)
    labs = prepare_labels(frags, label=label_int, label_depth=3)
    seqs_ = label_fasta_fragments(seqs_, label=label)
    assert (np.shape(encoded)[0] == np.shape(encoded_rc)[0])
    print(f"Encoding {label} sequences finished")
    return encoded, encoded_rc, labs, seqs_, n_frags
373 | |
374 | |
def storing_encoded(encoded, encoded_rc, labs, out_path, ):
    """
    Writes encoded fragments and labels into an HDF5 file.
    An existing file at out_path is overwritten (mode "w").
    Input:
    encoded - data stored as dataset "fragments"
    encoded_rc - data stored as dataset "fragments_rc"
    labs - data stored as dataset "labels"
    out_path - path to the output file
    """
    # context manager closes the file even if writing a dataset fails
    with h5py.File(out_path, "w") as f:
        f.create_dataset("fragments", data=encoded)
        f.create_dataset("fragments_rc", data=encoded_rc)
        f.create_dataset("labels", data=labs)
    print(f"encoded fragments and labels stored in {out_path}")