Mercurial > repos > iuc > decontaminator
comparison utils/preprocess.py @ 0:b856d3d95413 draft default tip
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/decontaminator commit 3f8e87001f3dfe7d005d0765aeaa930225c93b72
author | iuc |
---|---|
date | Mon, 09 Jan 2023 13:27:09 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:b856d3d95413 |
---|---|
1 #!/usr/bin/env python | |
2 # -*- coding: utf-8 -*- | |
3 # Credits: Grigorii Sukhorukov, Macha Nikolski | |
4 | |
5 import math | |
6 import os | |
7 import pathlib | |
8 import random | |
9 | |
10 import h5py | |
11 import numpy as np | |
12 from Bio import SeqIO | |
13 from Bio.Seq import Seq | |
14 from Bio.SeqRecord import SeqRecord | |
15 from sklearn.utils import shuffle | |
16 | |
17 | |
def reverse_complement(fragment):
    """
    provides reverse complement of a nucleotide string
    Input:
    fragment - string of nucleotides
    Output:
    fragment - the input reversed, with A<->T and C<->G swapped.
    Only upper-case A, C, G, T are complemented; every other
    character (e.g. '-') is kept as is
    """
    return fragment[::-1].translate(str.maketrans('ACGT', 'TGCA'))
36 | |
37 | |
def introduce_mutations(seqs, mut_rate, rs=None):
    """
    Function that mutates sequences in the entering fasta file
    A proportion of nucleotides are changed to other nucleotide
    Not yet taking account of mutation for gaps
    Input:
    seqs - iterable of SeqRecord sequences
    mut_rate - proportion from 0.0 to 1.0, float
    rs - seed passed to random.seed (re-seeds the global generator)
    Output:
    mutated_seqs - list with new SeqRecord sequences, ids are
    suffixed with "mut_{mut_rate}"
    Raises:
    ValueError - if mut_rate is outside of [0.0, 1.0]
    """
    # explicit check instead of assert, so it is not stripped by "python -O"
    if not 0.0 <= mut_rate <= 1.0:
        raise ValueError(f"mut_rate must be between 0.0 and 1.0, got {mut_rate}")
    random.seed(a=rs)
    nucleotides = ("A", "C", "T", "G")
    mutated_seqs = []
    for seq in seqs:
        mut_seq = list(str(seq.seq))
        l_ = len(mut_seq)
        mutated_sites_i = random.sample(range(l_), int(mut_rate * l_))
        for mut_site_i in mutated_sites_i:
            # compare case-insensitively, otherwise a lower-case 'a' could be
            # "mutated" into 'A', which is not a nucleotide change
            mut_site = mut_seq[mut_site_i].upper()
            mutations = [nt for nt in nucleotides if nt != mut_site]
            mut_seq[mut_site_i] = random.sample(mutations, 1)[0]
        mutated_seq = SeqRecord(
            seq=Seq("".join(mut_seq)),
            id=seq.id + f"mut_{mut_rate}",
            name="",
            description="",
        )
        mutated_seqs.append(mutated_seq)
    return mutated_seqs
66 | |
67 | |
def separate_by_length(length_, seq_list, fold=None,):
    """
    splits sequences in two groups based on their length
    Input:
    length_ - minimal length, shorter sequences are dropped and counted
    seq_list - iterable of records exposing a .seq attribute
    fold - optional multiplier; sequences of length >= length_ * fold
    are put aside in to_process. None keeps every long enough sequence
    Output:
    included - sequences with length in [length_, length_ * fold)
    to_process - sequences with length >= length_ * fold
    """
    included, to_process = [], []
    n_too_short = 0
    for record in seq_list:
        size = len(record.seq)
        if size >= length_:
            keep_as_is = fold is None or size < length_ * fold
            target = included if keep_as_is else to_process
            target.append(record)
        else:
            n_too_short += 1
    print(f"A total of {n_too_short} sequences was excluded due to being smaller than {length_}")
    return included, to_process
86 | |
87 | |
def chunks(lst, n):
    """Lazily yield consecutive slices of lst holding at most n items each.
    https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks"""
    total = len(lst)
    for start in range(0, total, n):
        stop = start + n
        yield lst[start:stop]
93 | |
94 | |
def correct(frag):
    """
    upper-cases the fragment and masks everything that is not
    unambiguous DNA code (ACTG-) with '-'
    Input:
    frag - string of nucleotides
    Output:
    pr_frag - corrected string of nucleotides
    """
    allowed = {"A", "C", "G", "T", "-"}
    upper_frag = frag.upper()
    return "".join(base if base in allowed else "-" for base in upper_frag)
109 | |
110 | |
def fragmenting(sequences, sl_wind_size, max_gap=0.05, sl_wind_step=None):
    """
    slices sequences in fragments by sliding window
    based on its size and step.
    a sequence shorter than the window is padded by '-'
    the last fragment of a longer sequence is taken from the sequence end,
    so it may overlap the previous fragment
    fragments have ambiguous bases replaced by '-'
    fragments with many '-' are discarded
    Input:
    sequences - list with SeqRecord sequences in fasta format
    sl_wind_size - sliding window size
    max_gap - max allowed proportion of '-'
    sl_wind_step - sliding window step, by default equals
    sliding window size (None is replaced by it)
    Output:
    fragments - list with sequence fragments
    fragments_rc - list with reverse complements of the fragments
    out_sequences - list with SeqRecord fragments, ids have the form
    "{sequence.id}_{fragment start}_{sl_wind_size}"
    """
    if sl_wind_step is None:
        sl_wind_step = sl_wind_size
    fragments = []
    fragments_rc = []
    out_sequences = []
    for sequence in sequences:
        seq = str(sequence.seq)
        n_fragments = 1 + max(0, math.ceil((len(seq) - sl_wind_size) / sl_wind_step))
        for n in range(n_fragments):
            start = n * sl_wind_step
            if n + 1 != n_fragments:
                frag = seq[start: start + sl_wind_size]
            elif n_fragments == 1:
                # padding the shorter fragment to sl_wind_size
                frag_short = seq[start: start + sl_wind_size]
                frag = frag_short + (sl_wind_size - len(frag_short)) * "-"
            else:
                # last window is anchored to the sequence end; remember its
                # real start so that the id below reports the true offset
                start = len(seq) - sl_wind_size
                frag = seq[start:]
            # replace ambiguous characters
            frag = correct(frag)
            assert len(frag) == sl_wind_size, f"{len(frag)} vs {sl_wind_size}"
            # skipping sequences with many gaps
            if frag.count("-") / sl_wind_size <= max_gap:
                fragments.append(frag)
                # generating reverse complement
                fragments_rc.append(reverse_complement(frag))
                fr_seq = SeqRecord(
                    seq=Seq(frag),
                    id=f"{sequence.id}_{start}_{sl_wind_size}",
                    name="",
                    description="",
                )
                out_sequences.append(fr_seq)
    return fragments, fragments_rc, out_sequences
160 | |
161 | |
def label_fasta_fragments(sequences, label):
    """
    Appends "_{label}" to the id of every fragment.
    The records are modified in place and returned in a new list.
    Input:
    sequences - list with SeqRecord sequences
    label - type of label (bacteria, virus, plant); not validated
    Output:
    labeled_fragments - list with labeled SeqRecord sequences
    """
    suffix = f"_{label}"

    def _tag(record):
        record.id = record.id + suffix
        return record

    return [_tag(record) for record in sequences]
177 | |
178 | |
def one_hot_encode(fragments):
    """
    produces one-hot matrices from fragments
    columns follow the order A, C, G, T
    '-' is given all zeros
    Input:
    fragments - list with sequence fragments of equal length,
    containing only A, C, G, T and '-'
    Output:
    encoded_fragments - int8 numpy array of shape
    (number of fragments, fragment length, 4)
    Raises:
    ValueError - if a fragment contains an unexpected character,
    or if fragments is empty (nothing to stack)
    """
    map_dict = {"A": 0, "C": 1, "G": 2, "T": 3, "-": -1}
    # rows 0-3 are the one-hot codes; the extra last row is all zeros,
    # so index -1 (used for '-') selects it
    one_hot_table = np.vstack([np.eye(4, dtype=np.int8), np.zeros((1, 4), dtype=np.int8)])
    encoded_fragments = []
    for frag in fragments:
        try:
            integer_encoded = np.array([map_dict[base] for base in frag], dtype=np.int8)
        except KeyError as err:
            raise ValueError(
                f"unexpected character {err.args[0]!r} in fragment, only A, C, G, T and '-' are supported"
            ) from err
        encoded_fragments.append(one_hot_table[integer_encoded])
    return np.stack(encoded_fragments)
201 | |
202 | |
def prepare_labels(fragments, label, label_depth):
    """
    produces one-hot labels, the same label for every fragment
    Input:
    fragments - list with sequence fragments (only its length is used)
    label - type of label (int < label_depth); a label outside of
    [0, label_depth) gives all-zero rows
    label_depth - number of possible labels
    Output:
    labels - float32 numpy array of shape (number of fragments, label_depth)
    """
    n_fragments = len(fragments)
    labels = np.zeros((n_fragments, label_depth), dtype=np.float32)
    label_index = int(label)
    if 0 <= label_index < label_depth:
        labels[:, label_index] = 1.0
    return labels
219 | |
220 | |
221 # TODO: write docs for functions | |
def calculate_total_length(seq_path):
    """
    Calculate total length of the sequences in the fasta file.
    Needed for weighted sampling
    Input:
    seq_path - path to the file with sequences
    Output:
    seq_length - total length of all sequences in the file
    """
    # records are streamed, there is no need to keep the whole file in memory
    return sum(len(record.seq) for record in SeqIO.parse(seq_path, "fasta"))
236 | |
237 | |
def prepare_seq_lists(in_paths, n_fragments, weights=None,):
    """
    selects files with sequences based on extension
    and calculates number of fragments to be sampled
    Input:
    in_paths - list of paths to folders with sequence files.
    Can also be a single folder (string or path object), or a string
    with the path to a single .fna / .fasta file
    n_fragments - number of fragments to be sampled
    weights - upsampling of fragments, one fraction per folder.
    fractions should sum to one. By default folders are weighted equally
    Output:
    list of (path to file with sequences, number of fragments to be sampled)
    pairs; a single [path, n_fragments] pair for the single file case.
    lists are zipped to work with ray iterators
    Raises:
    ValueError - if weights do not match in_paths or do not sum to one
    """
    # case when we receive a single sequence file
    if isinstance(in_paths, str) and in_paths.endswith(('.fna', '.fasta')):
        return [[in_paths, n_fragments]]
    # a single folder is treated as a list with one folder
    if isinstance(in_paths, (str, os.PathLike)):
        in_paths = [in_paths]

    if weights:
        # explicit checks instead of asserts, so they survive "python -O"
        if len(weights) != len(in_paths):
            raise ValueError(f"got {len(weights)} weights for {len(in_paths)} paths")
        if not 1.01 > round(sum(weights), 2) > 0.99:
            raise ValueError(f"weights should sum to one, got {sum(weights)}")
    else:
        l_ = len(in_paths)
        # no folders means no weights (and avoids dividing by zero)
        weights = [1 / l_] * l_ if l_ else []
    n_fragments_list_all = []
    seqs_list_all = []
    for folder, w_ in zip(in_paths, weights):
        seqs_list = []
        seq_length_list = []
        for file in os.listdir(folder):
            if file.endswith(("fna", "fasta")):
                seq_path = os.path.join(folder, file)
                seqs_list.append(seq_path)
                seq_length_list.append(calculate_total_length(seq_path))
        total_length = sum(seq_length_list)
        # fragments are split between files proportionally to their length
        # + 1 may lead to a slightly bigger number than desired
        n_fragments_list = [((seq_length / total_length) * n_fragments * w_ + 1) for seq_length in seq_length_list]
        n_fragments_list_all.extend(n_fragments_list)
        seqs_list_all.extend(seqs_list)
    print("list calculation done")
    return list(zip(seqs_list_all, n_fragments_list_all))
284 | |
285 | |
def sample_fragments(seq_container, length, random_seed=1, limit=None, max_gap=0.05, sl_wind_step=None):
    """
    Randomly samples fragments from sequences in the list.
    Is a bit cumbersome written to work with ray.
    Input:
    seq_container - list with each entry containing path to sequence,
    and n samples from this sequence.
    length - desired length of sampled fragments
    random_seed - seed passed to random.seed (re-seeds the global generator)
    limit - if set, at most limit * n sampling attempts are made per entry
    max_gap - max allowed proportion of '-' in a fragment
    sl_wind_step - passed through to fragmenting
    Output:
    fragments - list with sequence fragments
    fragments_rc - list with reverse complements of the fragments
    seqs - list with the sampled SeqRecord fragments
    Raises:
    ValueError - if fragments are requested from a file that has no
    sequence longer than length
    RuntimeError - if limit is set and the number of attempts exceeds it
    """
    random.seed(a=random_seed)
    total_fragments = []
    total_fragments_rc = []
    total_seqs = []
    for entry in seq_container:
        seq = list(SeqIO.parse(entry[0], "fasta"))
        n_fragments = entry[1]
        # only sequences strictly longer than length can be sampled below;
        # without any of them the while loop could never finish
        if n_fragments > 0 and not any(len(record.seq) > length for record in seq):
            raise ValueError(
                f"{entry[0]} has no sequence longer than {length}, fragments can not be sampled from it"
            )
        seqs = []
        fragments = []
        fragments_rc = []
        counter_1 = 0  # accepted fragments
        counter_2 = 0  # sampling attempts
        while counter_1 < n_fragments:
            # select chromosomes if there are any
            fragment_full = random.choice(seq)
            r_end = len(fragment_full.seq) - length
            # too short sequences are skipped, but still count as an attempt
            if r_end > 0:
                r_start = random.randrange(r_end)
                fragment = SeqRecord(
                    seq=fragment_full.seq[r_start:(r_start + length)],
                    id=f"{fragment_full.id}_{length}_{r_start}",
                    name="",
                    description="",
                )
                temp_, temp_rc, _ = fragmenting([fragment], length, max_gap, sl_wind_step=sl_wind_step)
                # fragmenting returns empty lists when the fragment has too many gaps
                if temp_ and temp_rc:
                    seqs.append(fragment)
                    fragments.extend(temp_)
                    fragments_rc.extend(temp_rc)
                    counter_1 += 1
            counter_2 += 1
            # raised explicitly (not asserted), so the guard survives "python -O"
            if limit and counter_2 > limit * n_fragments:
                raise RuntimeError(
                    f"While cycle iterated more than {limit}, data is ambiguous."
                    f" Only {len(fragments)} fragments were sampled out of {n_fragments}"
                )
        total_fragments.extend(fragments)
        total_fragments_rc.extend(fragments_rc)
        total_seqs.extend(seqs)
    return total_fragments, total_fragments_rc, total_seqs
339 | |
340 | |
def prepare_ds_fragmenting(in_seq, label, label_int, fragment_length, sl_wind_step, max_gap=0.05, n_cpus=1):
    """
    Cuts sequences of a fasta file in fragments by sliding window,
    encodes them and keeps only unique fragments.
    Input:
    in_seq - path to the fasta file
    label - text label appended to fragment ids
    label_int - integer label used for one-hot labels (depth 2)
    fragment_length - length of fragments
    sl_wind_step - sliding window step, None means half of fragment_length
    max_gap - max allowed proportion of '-' in a fragment
    n_cpus - not used in this function
    Output:
    encoded unique fragments, their encoded reverse complements,
    one-hot labels, labeled SeqRecord fragments, number of unique fragments
    """
    step = int(fragment_length / 2) if sl_wind_step is None else sl_wind_step
    # fragmenting the sequences and encoding forward and reverse strands
    records = list(SeqIO.parse(in_seq, "fasta"))
    frags, frags_rc, frag_records = fragmenting(records, fragment_length, max_gap=max_gap, sl_wind_step=step)
    encoded = one_hot_encode(frags)
    encoded_rc = one_hot_encode(frags_rc)
    labs = prepare_labels(frags, label=label_int, label_depth=2)
    frag_records = label_fasta_fragments(frag_records, label=label)
    # subsetting to unique fragments, the same rows are picked in every output
    u_encoded, first_seen = np.unique(encoded, axis=0, return_index=True)
    u_encoded_rc = encoded_rc[first_seen]
    u_labs = labs[first_seen]
    u_seqs = [frag_records[i] for i in first_seen]
    n_frags = np.shape(u_encoded)[0]
    assert (n_frags == np.shape(u_encoded_rc)[0])
    print(f"Encoding {label} sequences finished")
    return u_encoded, u_encoded_rc, u_labs, u_seqs, n_frags
361 | |
362 | |
def prepare_ds_sampling(in_seqs, fragment_length, n_frags, label, label_int, random_seed, n_cpus=1, limit=100):
    """
    Randomly samples fragments from sequence files and encodes them.
    Input:
    in_seqs - path(s) accepted by prepare_seq_lists
    fragment_length - length of fragments
    n_frags - number of fragments to keep after shuffling
    label - text label appended to fragment ids
    label_int - integer label used for one-hot labels (depth 2)
    random_seed - seed for sampling and shuffling
    n_cpus - not used in this function
    limit - passed to sample_fragments to bound sampling attempts
    Output:
    encoded fragments, their encoded reverse complements,
    one-hot labels, labeled SeqRecord fragments, n_frags as it was passed in
    """
    # sampling fragments from every file, then keeping a shuffled subset
    sampling_plan = prepare_seq_lists(in_seqs, n_frags)
    sampled = sample_fragments(sampling_plan, fragment_length, random_seed, limit=limit, max_gap=0.05)
    frags, frags_rc, frag_records = sampled
    frags, frags_rc, frag_records = shuffle(
        frags, frags_rc, frag_records, random_state=random_seed, n_samples=int(n_frags)
    )
    encoded = one_hot_encode(frags)
    encoded_rc = one_hot_encode(frags_rc)
    labs = prepare_labels(frags, label=label_int, label_depth=2)
    frag_records = label_fasta_fragments(frag_records, label=label)
    assert (np.shape(encoded)[0] == np.shape(encoded_rc)[0])
    print(f"Encoding {label} sequences finished")
    return encoded, encoded_rc, labs, frag_records, n_frags
376 | |
377 | |
def storing_encoded(encoded, encoded_rc, labs, out_path, ):
    """
    Stores encoded fragments and labels in an hdf5 file.
    An existing file at out_path is overwritten (mode "w").
    Input:
    encoded - encoded fragments, stored as dataset "fragments"
    encoded_rc - encoded reverse complements, stored as dataset "fragments_rc"
    labs - labels, stored as dataset "labels"
    out_path - path to the output hdf5 file
    """
    # context manager closes the file even if writing a dataset fails
    with h5py.File(out_path, "w") as f:
        f.create_dataset("fragments", data=encoded)
        f.create_dataset("fragments_rc", data=encoded_rc)
        f.create_dataset("labels", data=labs)