comparison utils/preprocess.py @ 0:b856d3d95413 draft default tip

planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/decontaminator commit 3f8e87001f3dfe7d005d0765aeaa930225c93b72
author iuc
date Mon, 09 Jan 2023 13:27:09 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:b856d3d95413
1 #!/usr/bin/env python
2 # -*- coding: utf-8 -*-
3 # Credits: Grigorii Sukhorukov, Macha Nikolski
4
5 import math
6 import os
7 import pathlib
8 import random
9
10 import h5py
11 import numpy as np
12 from Bio import SeqIO
13 from Bio.Seq import Seq
14 from Bio.SeqRecord import SeqRecord
15 from sklearn.utils import shuffle
16
17
def reverse_complement(fragment):
    """
    Returns the reverse complement of a nucleotide string.
    Input:
        fragment - string of nucleotides
    Output:
        the string read in the opposite direction with A<->T and C<->G
        exchanged; every other character (e.g. '-') is kept as it is
    """
    complement_table = str.maketrans("ACGT", "TGCA")
    complemented = fragment.translate(complement_table)
    return complemented[::-1]
36
37
def introduce_mutations(seqs, mut_rate, rs=None):
    """
    Substitutes a proportion of the nucleotides of every sequence.
    Positions are drawn without replacement; a drawn A/C/T/G is replaced
    by one of the three other nucleotides. Drawn positions holding any
    other character (gaps, ambiguous codes) are left untouched.
    Input:
        seqs - iterable with SeqRecord sequences
        mut_rate - proportion of positions to draw, float from 0.0 to 1.0
        rs - seed for the global random generator (None - not fixed)
    Output:
        mutated_seqs - list with new SeqRecord sequences,
        ids are the original ids followed by "mut_<mut_rate>"
    """
    # the module-level generator is re-seeded, as callers may rely on it
    random.seed(a=rs)
    assert 0.0 <= mut_rate <= 1.0
    nucleotides = ("A", "C", "T", "G")
    mutated_seqs = []
    for record in seqs:
        bases = list(str(record.seq))
        n_bases = len(bases)
        for position in random.sample(range(n_bases), int(mut_rate * n_bases)):
            current = bases[position]
            if current not in nucleotides:
                continue
            alternatives = [nt for nt in nucleotides if nt != current]
            bases[position] = random.sample(alternatives, 1)[0]
        mutated_seqs.append(
            SeqRecord(
                seq=Seq("".join(bases)),
                id=record.id + f"mut_{mut_rate}",
                name="",
                description="",
            )
        )
    return mutated_seqs
66
67
def separate_by_length(length_, seq_list, fold=None,):
    """
    Splits sequences into groups according to their length.
    Input:
        length_ - minimal length; shorter sequences are dropped
        seq_list - iterable with records having a .seq attribute
        fold - optional multiplier. When given, sequences with length
            of at least length_ * fold are put aside in to_process
    Output:
        included - sequences with length >= length_
            (and < length_ * fold when fold is given)
        to_process - sequences with length >= length_ * fold
            (always empty when fold is None)
    The number of dropped sequences is printed.
    """
    included, to_process = [], []
    n_too_short = 0
    for record in seq_list:
        size = len(record.seq)
        if size < length_:
            n_too_short += 1
        elif fold is None or size < length_ * fold:
            included.append(record)
        else:
            to_process.append(record)
    print(f"A total of {n_too_short} sequences was excluded due to being smaller than {length_}")
    return included, to_process
86
87
def chunks(lst, n):
    """Lazily yield consecutive slices of lst holding at most n items.
    The last slice is shorter when len(lst) is not a multiple of n.
    https://stackoverflow.com/questions/312443/how-do-you-split-a-list-into-evenly-sized-chunks"""
    starts = range(0, len(lst), n)
    yield from (lst[start:start + n] for start in starts)
93
94
def correct(frag):
    """
    Leaves only unambiguous DNA code (ACTG-).
    The fragment is upper-cased and every character
    other than A, C, G, T or '-' is replaced by '-'.
    Input:
        frag - string of nucleotides
    Output:
        corrected string of nucleotides
    """
    allowed = {"A", "C", "G", "T", "-"}
    return "".join(base if base in allowed else "-" for base in frag.upper())
109
110
def fragmenting(sequences, sl_wind_size, max_gap=0.05, sl_wind_step=None):
    """
    Slices sequences in fragments of fixed size by sliding window.
    A sequence that is not longer than the window gives one fragment,
    padded on the right by '-'. For a longer sequence the final window
    is taken from the last sl_wind_size bases of the sequence.
    Ambiguous bases are replaced by '-' (see correct)
    and fragments with too many '-' are discarded.
    Input:
        sequences - list with SeqRecord sequences
        sl_wind_size - sliding window size (length of every fragment)
        max_gap - max allowed proportion of '-' in a fragment
        sl_wind_step - sliding window step, by default equals
            sliding window size (None is replaced by it)
    Output:
        fragments - list with fragment strings
        fragments_rc - list with reverse complements of the fragments
        out_sequences - list with one SeqRecord per kept fragment,
            id is "<sequence id>_<n * step>_<window size>"
    """
    step = sl_wind_size if sl_wind_step is None else sl_wind_step
    fragments, fragments_rc, out_sequences = [], [], []
    for record in sequences:
        seq = str(record.seq)
        seq_len = len(seq)
        n_windows = 1 + max(0, math.ceil((seq_len - sl_wind_size) / step))
        last = n_windows - 1
        for idx in range(n_windows):
            offset = idx * step
            if idx != last:
                window = seq[offset:offset + sl_wind_size]
            elif n_windows == 1:
                # the sequence fits in one window: pad it to sl_wind_size
                window = seq[offset:offset + sl_wind_size]
                window += "-" * (sl_wind_size - len(window))
            else:
                # final window of a longer sequence is aligned to its end
                window = seq[seq_len - sl_wind_size:]
            # replace ambiguous characters
            window = correct(window)
            assert len(window) == sl_wind_size, f"{len(window)} vs {sl_wind_size}"
            # skipping fragments with many gaps
            if window.count("-") / sl_wind_size > max_gap:
                continue
            fragments.append(window)
            fragments_rc.append(reverse_complement(window))
            # NOTE: the offset in the id is always idx * step; for the final
            # end-aligned window this is not necessarily its real start
            out_sequences.append(
                SeqRecord(
                    seq=Seq(window),
                    id=f"{record.id}_{offset}_{sl_wind_size}",
                    name="",
                    description="",
                )
            )
    return fragments, fragments_rc, out_sequences
160
161
def label_fasta_fragments(sequences, label):
    """
    Appends "_<label>" to the id of every sequence.
    The records are modified in place; the same objects are returned.
    Input:
        sequences - list with SeqRecord sequences
        label - label to append (e.g. bacteria, virus, plant);
            the value is not validated
    Output:
        labeled_fragments - new list with the labeled SeqRecord sequences
    """
    suffix = f"_{label}"
    labeled_fragments = []
    for record in sequences:
        record.id += suffix
        labeled_fragments.append(record)
    return labeled_fragments
177
178
def one_hot_encode(fragments):
    """
    Produces one-hot matrices from fragments.
    Columns correspond to A, C, G, T in this order.
    '-' and any other character are given all zeros.
    Input:
        fragments - list with sequence fragments of equal length
    Output:
        encoded_fragments - int8 array of shape
            (number of fragments, fragment length, 4)
    Raises:
        ValueError - when fragments is empty or fragments differ
            in length (raised by np.stack)
    """
    base_index = {"A": 0, "C": 1, "G": 2, "T": 3}
    channels = np.arange(4, dtype=np.int8)
    encoded_fragments = []
    for frag in fragments:
        # -1 never matches a channel, so unknown characters become zero rows
        codes = np.fromiter(
            (base_index.get(base, -1) for base in frag),
            dtype=np.int8,
            count=len(frag),
        )
        # compare each position with the four channel indices at once
        encoded_fragments.append((codes[:, np.newaxis] == channels).astype(np.int8))
    return np.stack(encoded_fragments)
201
202
def prepare_labels(fragments, label, label_depth):
    """
    Produces one-hot labels, the same label for every fragment.
    Input:
        fragments - list with sequence fragments (only its length is used)
        label - index of the class (int, 0 <= label < label_depth);
            an index outside of this range gives rows of zeros
        label_depth - number of possible labels
    Output:
        labels - float32 array of shape (number of fragments, label_depth)
    """
    n_fragments = len(fragments)
    # the label is cast to int8, as the class index was stored before
    indices = np.full(n_fragments, label).astype(np.int8)
    classes = np.arange(label_depth)
    # row i has 1.0 in the column equal to indices[i]
    return (indices[:, np.newaxis] == classes).astype(np.float32)
219
220
221 # TODO: write docs for functions
def calculate_total_length(seq_path):
    """
    Calculate total length of the sequences in the fasta file.
    Needed for weighted sampling
    Input:
        seq_path - path to the file with sequences
    Output:
        seq_length - total length of all sequences in the file
            (0 when the file holds no records)
    """
    # records are consumed one by one, the file is never held in memory
    return sum(len(record.seq) for record in SeqIO.parse(seq_path, "fasta"))
236
237
def prepare_seq_lists(in_paths, n_fragments, weights=None,):
    """
    selects files with sequences based on extension
    and calculates number of fragments to be sampled
    Input:
        in_paths - list of paths to folders with sequence files.
            Can also be a single path (string or path object) to a folder
            or to one sequence file ending with .fna or .fasta
        n_fragments - number of fragments to be sampled
        weights - upsampling of fragments, one fraction per folder.
            fractions should sum to one. By default folders are
            weighted equally
    Output:
        list of (path to file with sequences, number of fragments to be
        sampled from it) entries. For a single sequence file the list is
        [[in_paths, n_fragments]]
        lists are zipped to work with ray iterators
    """
    single_path = isinstance(in_paths, (str, os.PathLike))
    # case when we receive a single sequence file
    if single_path and os.fspath(in_paths).endswith(('.fna', '.fasta')):
        return [[in_paths, n_fragments]]
    # a single folder is wrapped to be processed like a list of folders
    if single_path:
        in_paths = [in_paths]

    if weights:
        assert len(weights) == len(in_paths)
        assert 1.01 > round(sum(weights), 2) > 0.99
    else:
        n_folders = len(in_paths)
        # no folders - no weights (avoids division by zero)
        weights = [1 / n_folders] * n_folders if n_folders else []
    n_fragments_list_all = []
    seqs_list_all = []
    for folder, weight in zip(in_paths, weights):
        seqs_list = []
        seq_length_list = []
        for file_name in os.listdir(folder):
            if file_name.endswith(("fna", "fasta")):
                seq_path = os.path.join(folder, file_name)
                seqs_list.append(seq_path)
                seq_length_list.append(calculate_total_length(seq_path))
        total_length = sum(seq_length_list)
        # every file gets a share proportional to its total length
        # + 1 may lead to a slightly bigger number than desired
        n_fragments_list_all.extend(
            (seq_length / total_length) * n_fragments * weight + 1
            for seq_length in seq_length_list
        )
        seqs_list_all.extend(seqs_list)
    print("list calculation done")
    return list(zip(seqs_list_all, n_fragments_list_all))
284
285
def sample_fragments(seq_container, length, random_seed=1, limit=None, max_gap=0.05, sl_wind_step=None):
    """
    Randomly samples fragments from sequences in the list.
    Is a bit cumbersome written to work with ray.
    For every entry, records of the file are drawn at random (with
    replacement) and a window of the desired length is cut at a random
    start, until enough fragments passing the max_gap filter of
    fragmenting are collected.
    Input:
        seq_container - list with each entry containing path to sequence,
            and n samples from this sequence.
        length - desired length of sampled fragments
        random_seed - seed for the global random generator
        limit - if set, at most limit * n draws are made per entry.
            Without it the loop does not end when no fragment
            can be sampled from the file
        max_gap - max allowed proportion of '-' in a fragment
        sl_wind_step - passed to fragmenting
    Output:
        total_fragments - list with sequence fragments
        total_fragments_rc - list with their reverse complements
        total_seqs - list with SeqRecord of sampled fragments,
            id is "<record id>_<length>_<start>"
    Raises:
        AssertionError - when limit is exceeded
    """
    random.seed(a=random_seed)
    total_fragments = []
    total_fragments_rc = []
    total_seqs = []
    for entry in seq_container:
        records = list(SeqIO.parse(entry[0], "fasta"))
        n_fragments = entry[1]
        seqs = []
        fragments = []
        fragments_rc = []
        n_sampled = 0
        n_attempts = 0
        while n_sampled < n_fragments:
            # select chromosomes if there are any
            fragment_full = random.choice(records)
            r_end = len(fragment_full.seq) - length
            # a record that is not longer than `length` has no start to
            # draw from; it is skipped, but still counts as an attempt
            if r_end > 0:
                r_start = random.randrange(r_end)
                fragment = SeqRecord(
                    seq=fragment_full.seq[r_start:(r_start + length)],
                    id=f"{fragment_full.id}_{length}_{r_start}",
                    name="",
                    description="",
                )
                temp_, temp_rc, _ = fragmenting([fragment], length, max_gap, sl_wind_step=sl_wind_step)
                # empty lists mean the fragment had too many gaps
                if temp_ and temp_rc:
                    seqs.append(fragment)
                    fragments.extend(temp_)
                    fragments_rc.extend(temp_rc)
                    n_sampled += 1
            n_attempts += 1
            if limit:
                assert n_attempts <= limit * n_fragments, f"While cycle iterated more than {limit}, data is ambiguous." \
                                                          f" Only {len(fragments)} fragments were sampled out of {n_fragments}"
        total_fragments.extend(fragments)
        total_fragments_rc.extend(fragments_rc)
        total_seqs.extend(seqs)
    return total_fragments, total_fragments_rc, total_seqs
339
340
def prepare_ds_fragmenting(in_seq, label, label_int, fragment_length, sl_wind_step, max_gap=0.05, n_cpus=1):
    """
    Prepares a dataset by cutting all sequences of a fasta file
    with a sliding window and keeping only unique fragments.
    Input:
        in_seq - path to the fasta file
        label - text label, appended to ids of fragments
        label_int - index of the class for one-hot labels (2 classes)
        fragment_length - length of fragments
        sl_wind_step - sliding window step,
            None is replaced by half of fragment_length
        max_gap - max allowed proportion of '-' in a fragment
        n_cpus - not used by this function
    Output:
        u_encoded - one-hot encoded unique fragments
        u_encoded_rc - one-hot encoded reverse complements of them
        u_labs - one-hot labels
        u_seqs - list with labeled SeqRecord of the fragments
        n_frags - number of unique fragments
    """
    step = int(fragment_length / 2) if sl_wind_step is None else sl_wind_step
    records = list(SeqIO.parse(in_seq, "fasta"))
    frags, frags_rc, frag_records = fragmenting(records, fragment_length, max_gap=max_gap, sl_wind_step=step)
    encoded = one_hot_encode(frags)
    encoded_rc = one_hot_encode(frags_rc)
    labs = prepare_labels(frags, label=label_int, label_depth=2)
    frag_records = label_fasta_fragments(frag_records, label=label)
    # subsetting to unique fragments; np.unique returns them sorted,
    # `keep` holds the position of the first occurrence of each of them
    u_encoded, keep = np.unique(encoded, axis=0, return_index=True)
    u_encoded_rc, u_labs = encoded_rc[keep], labs[keep]
    u_seqs = [frag_records[i] for i in keep]
    n_frags = np.shape(u_encoded)[0]
    assert (n_frags == np.shape(u_encoded_rc)[0])
    print(f"Encoding {label} sequences finished")
    return u_encoded, u_encoded_rc, u_labs, u_seqs, n_frags
361
362
def prepare_ds_sampling(in_seqs, fragment_length, n_frags, label, label_int, random_seed, n_cpus=1, limit=100,
                        max_gap=0.05):
    """
    Prepares a dataset by random sampling of fragments
    from files with sequences.
    Input:
        in_seqs - path(s) to folders or to a file with sequences,
            see prepare_seq_lists
        fragment_length - length of fragments
        n_frags - number of fragments to be returned
        label - text label, appended to ids of fragments
        label_int - index of the class for one-hot labels (2 classes)
        random_seed - seed for sampling and shuffling
        n_cpus - not used by this function
        limit - passed to sample_fragments
        max_gap - max allowed proportion of '-' in a fragment
    Output:
        encoded - one-hot encoded fragments
        encoded_rc - one-hot encoded reverse complements of them
        labs - one-hot labels
        seqs_ - list with labeled SeqRecord of the fragments
        n_frags - requested number of fragments, as it was passed in
    """
    seqs_list = prepare_seq_lists(in_seqs, n_frags)
    frags, frags_rc, seqs_ = sample_fragments(seqs_list, fragment_length, random_seed, limit=limit, max_gap=max_gap)
    # the three lists are shuffled together and cut to int(n_frags) entries
    frags, frags_rc, seqs_ = shuffle(frags, frags_rc, seqs_, random_state=random_seed, n_samples=int(n_frags))
    encoded = one_hot_encode(frags)
    encoded_rc = one_hot_encode(frags_rc)
    labs = prepare_labels(frags, label=label_int, label_depth=2)
    seqs_ = label_fasta_fragments(seqs_, label=label)
    assert (np.shape(encoded)[0] == np.shape(encoded_rc)[0])
    print(f"Encoding {label} sequences finished")
    return encoded, encoded_rc, labs, seqs_, n_frags
376
377
def storing_encoded(encoded, encoded_rc, labs, out_path, ):
    """
    Stores an encoded dataset in a HDF5 file.
    Input:
        encoded - encoded fragments, saved as "fragments"
        encoded_rc - encoded reverse complements, saved as "fragments_rc"
        labs - labels, saved as "labels"
        out_path - path to the output file, opened in "w" mode
    """
    # the file is closed also when writing of a dataset fails
    with h5py.File(out_path, "w") as f:
        f.create_dataset("fragments", data=encoded)
        f.create_dataset("fragments_rc", data=encoded_rc)
        f.create_dataset("labels", data=labs)