repeat_annotation_pipeline3 / reAnnotate.py @ 11:5366d5ea04bc (draft)

commit:  planemo upload commit 9d1b19f98d8b7f0a0d1baf2da63a373d155626f8-dirty
author:  petr-novak
date:    Fri, 04 Aug 2023 12:35:32 +0000
#!/usr/bin/env python
"""
parse blast output table to gff file
"""
import argparse
import itertools
import os
import re
import shutil
import subprocess
import sys
import tempfile
from collections import defaultdict

# check version of python, must be at least 3.10
if sys.version_info < (3, 10):
    sys.exit("Python 3.10 or a more recent version is required.")


def make_temp_files(number_of_files):
    """
    Make named temporary files; the files will not be deleted upon exit!
    :param number_of_files:
    :return:
        filepaths
    """
    temp_files = []
    for i in range(number_of_files):
        temp_files.append(tempfile.NamedTemporaryFile(delete=False).name)
        os.remove(temp_files[-1])
    return temp_files


def split_fasta_to_chunks(fasta_file, chunk_size=100000000, overlap=100000):
    """
    Split a fasta file into chunks; sequences longer than the chunk size are split into
    overlapping pieces, while shorter sequences are grouped into chunks containing
    multiple sequences.
    :param fasta_file:
    :param chunk_size:
    :param overlap:
    :return:
        fasta_file_split
        matching_table (list of lists [header, chunk_number, start, end, new_header])
    """
    min_chunk_size = chunk_size * 2
    fasta_sizes_dict = read_fasta_sequence_size(fasta_file)
    # calculate total size of sequences in the fasta_sizes_dict dictionary
    fasta_size = sum(fasta_sizes_dict.values())

    # calculate ranges for splitting of fasta files and store them in a list
    matching_table = []
    fasta_file_split = tempfile.NamedTemporaryFile(delete=False).name
    for header, size in fasta_sizes_dict.items():
        print(header, size, min_chunk_size)

        if size > min_chunk_size:
            number_of_chunks = int(size / chunk_size)
            print("number_of_chunks", number_of_chunks)
            print("size", size)
            print("chunk_size", chunk_size)
            print("-----------------------------------------")
            adjusted_chunk_size = int(size / number_of_chunks)
            for i in range(number_of_chunks):
                start = i * adjusted_chunk_size
                end = ((i + 1) *
                       adjusted_chunk_size
                       + overlap) if i + 1 < number_of_chunks else size
                new_header = header + '_' + str(i)
                matching_table.append([header, i, start, end, new_header])
        else:
            new_header = header + '_0'
            matching_table.append([header, 0, 0, size, new_header])
    # read sequences from the fasta file and split them into chunks according to the matching table
    # open output and input files, use with statement to close files
    number_of_temp_files = len(matching_table)
    print('number of temp files', number_of_temp_files)
    fasta_dict = read_single_fasta_to_dictionary(open(fasta_file, 'r'))
    with open(fasta_file_split, 'w') as fh_out:
        for header in fasta_dict:
            matching_table_part = [x for x in matching_table if x[0] == header]
            for header2, i, start, end, new_header in matching_table_part:
                fh_out.write('>' + new_header + '\n')
                fh_out.write(fasta_dict[header][start:end] + '\n')
    temp_files_fasta = make_temp_files(number_of_temp_files)
    fasta_seq_size = read_fasta_sequence_size(fasta_file_split)
    seq_id_size_sorted = [i[0] for i in sorted(
        fasta_seq_size.items(), key=lambda x: int(x[1]), reverse=True
    )]
    seq_id_file_dict = dict(zip(seq_id_size_sorted, itertools.cycle(temp_files_fasta)))
    # write sequences to temporary files
    with open(fasta_file_split, 'r') as f:
        first = True
        for line in f:
            if line[0] == '>':
                # close previous file if it is not the first sequence
                if not first:
                    fout.close()
                first = False
                header = line.strip().split(' ')[0][1:]
                fout = open(seq_id_file_dict[header], 'a')
                fout.write(line)
            else:
                fout.write(line)
    os.remove(fasta_file_split)
    return temp_files_fasta, matching_table


def read_fasta_sequence_size(fasta_file):
    """Read sizes of sequences into a dictionary"""
    fasta_dict = {}
    with open(fasta_file, 'r') as f:
        for line in f:
            if line[0] == '>':
                header = line.strip().split(' ')[0][1:]  # remove part of name after space
                fasta_dict[header] = 0
            else:
                fasta_dict[header] += len(line.strip())
    return fasta_dict


def read_single_fasta_to_dictionary(fh):
    """
    Read fasta file into a dictionary
    :param fh:
    :return:
        fasta_dict
    """
    fasta_dict = {}
    for line in fh:
        if line[0] == '>':
            header = line.strip().split(' ')[0][1:]  # remove part of name after space
            fasta_dict[header] = []
        else:
            fasta_dict[header] += [line.strip()]
    fasta_dict = {k: ''.join(v) for k, v in fasta_dict.items()}
    return fasta_dict


def overlap(a, b):
    """
    check if two intervals overlap
    """
    return max(a[0], b[0]) <= min(a[1], b[1])


def blast2disjoint(
        blastfile, seqid_counts=None, start_column=6, end_column=7, class_column=1,
        bitscore_column=11, pident_column=2, canonical_classification=True
):
    """
    find all interval beginnings and ends in the blast file and create disjoint regions
    the input blastfile is a tab separated file with the columns:
    'qaccver saccver pident length mismatch gapopen qstart qend sstart send
    evalue bitscore' (default outfmt 6)
    blast must be sorted on qseqid and qstart
    """
    # assume all in one chromosome!
    starts_ends = {}
    intervals = {}
    if canonical_classification:
        # make regular expression for canonical classification
        # to match: Name#classification
        # e.g. "Name_of_sequence#LTR/Ty1_copia/Angela"
        regex = re.compile(r"(.*)[#](.*)")
        group = 2
    else:
        # make regular expression for non-canonical classification
        # to match: Classification__Name
        # e.g. "LTR/Ty1_copia/Angela__Name_of_sequence"
        regex = re.compile(r"(.*)__(.*)")
        group = 1

    # identify continuous intervals
    with open(blastfile, "r") as f:
        for seqid in sorted(seqid_counts.keys()):
            n_lines = seqid_counts[seqid]
            starts_ends[seqid] = set()
            for i in range(n_lines):
                items = f.readline().strip().split()
                # note: the '1s' and '2e' labels are used to distinguish starts from ends
                # and guarantee that, for the same coordinate, the start is sorted before
                # the end ('1s' < '2e')
                starts_ends[seqid].add((int(items[start_column]), '1s'))
                starts_ends[seqid].add((int(items[end_column]), '2e'))
            intervals[seqid] = []
            for p1, p2 in itertools.pairwise(sorted(starts_ends[seqid])):
                if p1[1] == '1s':
                    sp = 0
                else:
                    sp = 1
                if p2[1] == '2e':
                    ep = 0
                else:
                    ep = 1
                intervals[seqid].append((p1[0] + sp, p2[0] - ep))
    # scan each blast hit against the continuous regions and record the hit with the best score
    with open(blastfile, "r") as f:
        disjoint_regions = []
        for seqid in sorted(seqid_counts.keys()):
            n_lines = seqid_counts[seqid]
            idx_of_overlaps = {}
            best_pident = defaultdict(lambda: 0.0)
            best_bitscore = defaultdict(lambda: 0.0)
            best_hit_name = defaultdict(lambda: "")
            i1 = 0
            for i in range(n_lines):
                items = f.readline().strip().split()
                start = int(items[start_column])
                end = int(items[end_column])
                pident = float(items[pident_column])
                bitscore = float(items[bitscore_column])
                classification = items[class_column]
                j = 0
                done = False
                while True:
                    # beginning of searched region - does it overlap?
                    c_ovl = overlap(intervals[seqid][i1], (start, end))
                    if c_ovl:
                        # if overlap is detected, add to dictionary
                        idx_of_overlaps[i] = [i1]
                        if best_bitscore[i1] < bitscore:
                            best_pident[i1] = pident
                            best_bitscore[i1] = bitscore
                            best_hit_name[i1] = classification
                        # also search downstream
                        while True:
                            j += 1
                            if j + i1 >= len(intervals[seqid]):
                                done = True
                                break
                            c_ovl = overlap(intervals[seqid][i1 + j], (start, end))
                            if c_ovl:
                                idx_of_overlaps[i].append(i1 + j)
                                if best_bitscore[i1 + j] < bitscore:
                                    best_pident[i1 + j] = pident
                                    best_bitscore[i1 + j] = bitscore
                                    best_hit_name[i1 + j] = classification
                            else:
                                done = True
                                break

                    else:
                        # does not overlap - search the next interval
                        i1 += 1
                    if done or i1 >= (len(intervals[seqid]) - 1):
                        break

            for i in sorted(best_pident.keys()):
                try:
                    classification = re.match(regex, best_hit_name[i]).group(group)
                except AttributeError:
                    classification = best_hit_name[i]
                record = (
                    seqid, intervals[seqid][i][0], intervals[seqid][i][1], best_pident[i],
                    classification)
                disjoint_regions.append(record)
    return disjoint_regions


def remove_short_interrupting_regions(regions, min_len=10, max_gap=2):
    """
    remove intervals shorter than min_len that are directly adjacent, on both sides, to
    regions longer than min_len which share the same classification (different from the
    classification of the removed interval)
    """
    regions_to_remove = []
    for i in range(1, len(regions) - 1):
        if regions[i][2] - regions[i][1] < min_len:
            c1 = regions[i - 1][2] - regions[i - 1][1] > min_len
            c2 = regions[i + 1][2] - regions[i + 1][1] > min_len
            c3 = regions[i - 1][4] == regions[i + 1][4]  # same classification
            c4 = regions[i + 1][4] != regions[i][4]  # different classification
            c5 = regions[i][1] - regions[i - 1][2] < max_gap  # max gap between regions
            c6 = regions[i + 1][1] - regions[i][2] < max_gap  # max gap between regions
            if c1 and c2 and c3 and c4 and c5 and c6:
                regions_to_remove.append(i)
    for i in sorted(regions_to_remove, reverse=True):
        del regions[i]
    return regions


def remove_short_regions(regions, min_l_score=600):
    """
    remove short, weak intervals
    min_l_score is the minimum score for a region to be kept
    l_score = (PID - 50) * length
    """
    regions_to_remove = []
    for i in range(len(regions)):
        l_score = (regions[i][3] - 50) * (regions[i][2] - regions[i][1])
        if l_score < min_l_score:
            regions_to_remove.append(i)
    for i in sorted(regions_to_remove, reverse=True):
        del regions[i]
    return regions


def join_disjoint_regions_by_classification(disjoint_regions, max_gap=0):
    """
    merge neighboring intervals with the same classification and calculate a
    length-weighted mean score; the weight corresponds to the length of the interval
    """
    merged_regions = []
    for seqid, start, end, score, classification in disjoint_regions:
        score_length = (end - start + 1) * score
        if len(merged_regions) == 0:
            merged_regions.append([seqid, start, end, score_length, classification])
        else:
            cond_same_class = merged_regions[-1][4] == classification
            cond_same_seqid = merged_regions[-1][0] == seqid
            cond_neighboring = start - merged_regions[-1][2] + 1 <= max_gap
            if cond_same_class and cond_same_seqid and cond_neighboring:
                # extend region
                merged_regions[-1] = [merged_regions[-1][0], merged_regions[-1][1], end,
                                      merged_regions[-1][3] + score_length,
                                      merged_regions[-1][4]]
            else:
                merged_regions.append([seqid, start, end, score_length, classification])
    # recalculate length weighted score
    for record in merged_regions:
        record[3] = record[3] / (record[2] - record[1] + 1)
    return merged_regions


def write_merged_regions_to_gff3(merged_regions, outfile):
    """
    write merged regions to gff3 file
    """
    with open(outfile, "w") as f:
        # write header
        f.write("##gff-version 3\n")
        for seqid, start, end, score, classification in merged_regions:
            attributes = "Name={};score={}".format(classification, score)
            f.write(
                "\t".join(
                    [seqid, "blast_parsed", "repeat_region", str(start), str(end),
                     str(round(score, 2)), ".", ".", attributes]
                )
            )
            f.write("\n")


def sort_blast_table(
        blastfile, seqid_column=0, start_column=6, cpu=1
):
    """
    split blast table by seqid and sort by start position
    stores output in temp files
    columns are indexed from 0,
    but sort and cut use 1-based indexing!
    """
    blast_sorted = tempfile.NamedTemporaryFile().name
    # create a dictionary of seqid counts (in sorted order)
    seq_id_counts = {}
    # sort blast file on disk using sort on seqid and start (numeric) position columns
    # using the sort command as blast output could be very large
    cmd = "sort -k {0},{0} -k {1},{1}n --parallel {4} {2} > {3}".format(
        seqid_column + 1, start_column + 1, blastfile, blast_sorted, cpu
    )
    subprocess.check_call(cmd, shell=True)

    # count seqids using the uniq command
    cmd = "cut -f {0} {1} | uniq -c > {2}".format(
        seqid_column + 1, blast_sorted, blast_sorted + ".counts"
    )
    subprocess.check_call(cmd, shell=True)
    # read counts file and create dictionary
    with open(blast_sorted + ".counts", "r") as f:
        for line in f:
            line = line.strip().split()
            seq_id_counts[line[1]] = int(line[0])
    # remove counts file
    subprocess.call(["rm", blast_sorted + ".counts"])
    # return the counts dictionary and the sorted blast file
    return seq_id_counts, blast_sorted


def run_blastn(
        query, db, blastfile, evalue=1e-3, max_target_seqs=999999999, gapopen=2,
        gapextend=1, reward=1, penalty=-1, word_size=9, num_threads=1, outfmt="6"
):
    """
    run blastn
    """
    # create temporary blast database:
    db_formated = tempfile.NamedTemporaryFile().name
    cmd = "makeblastdb -in {0} -dbtype nucl -out {1}".format(db, db_formated)
    subprocess.check_call(cmd, shell=True)
    # if the query is smaller than max_size, run blast on a single file
    size = os.path.getsize(query)
    print("query size: {} bytes".format(size))
    max_size = 1e6
    overlap = 50000
    if size < max_size:
        cmd = ("blastn -task rmblastn -query {0} -db {1} -out {2} -evalue {3} "
               "-max_target_seqs {4} "
               "-gapopen {5} -gapextend {6} -word_size {7} -num_threads "
               "{8} -outfmt '{9}' -reward {10} -penalty {11} -dust no").format(
            query, db_formated, blastfile, evalue, max_target_seqs, gapopen, gapextend,
            word_size, num_threads, outfmt, reward, penalty
        )
        subprocess.check_call(cmd, shell=True)
    # if the query is larger than max_size, split it into chunks and run blast on each chunk
    else:
        print(f"query is larger than {max_size}, splitting query in chunks")
        query_parts, matching_table = split_fasta_to_chunks(query, max_size, overlap)
        print(query_parts)
        for i, part in enumerate(query_parts):
            print(f"running blast on chunk {i}")
            print(part)
            cmd = ("blastn -task rmblastn -query {0} -db {1} -out {2} -evalue {3} "
                   "-max_target_seqs {4} "
                   "-gapopen {5} -gapextend {6} -word_size {7} -num_threads "
                   "{8} -outfmt '{9}' -reward {10} -penalty {11} -dust no").format(
                part, db_formated, f'{blastfile}.{i}', evalue, max_target_seqs, gapopen,
                gapextend,
                word_size, num_threads, outfmt, reward, penalty
            )
            subprocess.check_call(cmd, shell=True)
            print(cmd)
            # remove part file
            # os.unlink(part)
        # merge blast results and recalculate start, end positions and headers
        merge_blast_results(blastfile, matching_table, n_parts=len(query_parts))

    # remove temporary blast database
    os.unlink(db_formated + ".nhr")
    os.unlink(db_formated + ".nin")
    os.unlink(db_formated + ".nsq")


def merge_blast_results(blastfile, matching_table, n_parts):
    """
    Merge blast tables and recalculate start and end positions based on the
    matching table
    """
    with open(blastfile, "w") as f:
        matching_table_dict = {i[4]: i for i in matching_table}
        print(matching_table_dict)
        for i in range(n_parts):
            with open(f'{blastfile}.{i}', "r") as f2:
                for line in f2:
                    line = line.strip().split("\t")
                    # seqid (header) is in the first column
                    seqid = line[0]
                    line[0] = matching_table_dict[seqid][0]
                    # increase coordinates by the start position of the chunk
                    line[6] = str(int(line[6]) + matching_table_dict[seqid][2])
                    line[7] = str(int(line[7]) + matching_table_dict[seqid][2])
                    f.write("\t".join(line) + "\n")
            # remove temporary blast file
            # os.unlink(f'{blastfile}.{i}')


def main():
    """
    main function
    """
    # get command line arguments
    parser = argparse.ArgumentParser(
        description="""This script is used to parse blast output table to gff file""",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        '-i', '--input', default=None, required=True, help="input file", type=str,
        action='store'
    )
    parser.add_argument(
        '-d', '--db', default=None, required=False,
        help="Fasta file with repeat database", type=str, action='store'
    )
    parser.add_argument(
        '-o', '--output', default=None, required=True, help="output file name", type=str,
        action='store'
    )
    parser.add_argument(
        '-a', '--alternative_classification_coding', default=False,
        help="Use alternative classification coding", action='store_true'
    )
    parser.add_argument(
        '-f', '--fasta_input', default=False,
        help="Input is fasta file instead of blast table", action='store_true'
    )
    parser.add_argument(
        '-c', '--cpu', default=1, help="Number of cpu to use", type=int
    )

    args = parser.parse_args()

    if args.fasta_input:
        # run blast using blastn
        blastfile = tempfile.NamedTemporaryFile().name
        if args.db:
            run_blastn(args.input, args.db, blastfile, num_threads=args.cpu)
        else:
            sys.exit("No repeat database provided")
    else:
        blastfile = args.input

    # sort blast table
    seq_id_counts, blast_sorted = sort_blast_table(blastfile, cpu=args.cpu)
    disjoint_regions = blast2disjoint(
        blast_sorted, seq_id_counts,
        canonical_classification=not args.alternative_classification_coding
    )

    # remove short interrupting regions
    disjoint_regions = remove_short_interrupting_regions(disjoint_regions)

    # join neighboring regions with the same classification
    merged_regions = join_disjoint_regions_by_classification(disjoint_regions)

    # remove short interrupting regions again
    merged_regions = remove_short_interrupting_regions(merged_regions)

    # merge again neighboring regions with the same classification
    merged_regions = join_disjoint_regions_by_classification(merged_regions, max_gap=10)

    # remove short weak regions
    merged_regions = remove_short_regions(merged_regions)

    # last merge
    merged_regions = join_disjoint_regions_by_classification(merged_regions, max_gap=20)
    write_merged_regions_to_gff3(merged_regions, args.output)
    # remove temporary files
    os.remove(blast_sorted)


if __name__ == "__main__":
    main()
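# Example invocations (file names below are illustrative placeholders, not files from the
# repository):
#   reAnnotate.py -f -i genome.fasta -d repeat_db.fasta -o annotation.gff3 -c 4
#       run blastn of the genome against the repeat database first, then build the GFF3
#   reAnnotate.py -i blast_table.tsv -o annotation.gff3
#       parse an existing blast table (outfmt 6) directly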