repeat_annotation_pipeline3: reAnnotate.py @ 12:755a4d643184 (draft, default, tip)
planemo upload commit a61591d548f42ff417781e7fe7418dc2901ccc23
author:   petr-novak
date:     Tue, 26 Sep 2023 07:28:04 +0000
parents:  5366d5ea04bc
children: (none)
#!/usr/bin/env python
"""
parse blast output table to gff file
"""
import argparse
import itertools
import os
import re
import shutil
import subprocess
import sys
import tempfile
from collections import defaultdict

# check version of python, must be at least 3.10 (itertools.pairwise is used below)
if sys.version_info < (3, 10):
    sys.exit("Python 3.10 or a more recent version is required.")


def make_temp_files(number_of_files):
    """
    Make named temporary file paths, files will not be deleted upon exit!
    :param number_of_files:
    :return: filepaths
    """
    temp_files = []
    for i in range(number_of_files):
        temp_files.append(tempfile.NamedTemporaryFile(delete=False).name)
        os.remove(temp_files[-1])
    return temp_files


def split_fasta_to_chunks(fasta_file, chunk_size=100000000, overlap=100000):
    """
    Split fasta file to chunks; sequences longer than the chunk size are split into
    overlapping pieces, shorter sequences are grouped into chunks holding multiple
    sequences.
    :param fasta_file:
    :param chunk_size:
    :param overlap:
    :return: fasta_file_split,
             matching_table (list of lists [header, chunk_number, start, end, new_header])
    """
    min_chunk_size = chunk_size * 2
    fasta_sizes_dict = read_fasta_sequence_size(fasta_file)
    # calculate total size of sequences in the fasta file
    fasta_size = sum(fasta_sizes_dict.values())
    # calculate ranges for splitting of fasta files and store them in a list
    matching_table = []
    fasta_file_split = tempfile.NamedTemporaryFile(delete=False).name
    for header, size in fasta_sizes_dict.items():
        print(header, size, min_chunk_size)
        if size > min_chunk_size:
            number_of_chunks = int(size / chunk_size)
            print("number_of_chunks", number_of_chunks)
            print("size", size)
            print("chunk_size", chunk_size)
            print("-----------------------------------------")
            adjusted_chunk_size = int(size / number_of_chunks)
            for i in range(number_of_chunks):
                start = i * adjusted_chunk_size
                end = ((i + 1) * adjusted_chunk_size
                       + overlap) if i + 1 < number_of_chunks else size
                new_header = header + '_' + str(i)
                matching_table.append([header, i, start, end, new_header])
        else:
            new_header = header + '_0'
            matching_table.append([header, 0, 0, size, new_header])
    # read sequences from the fasta file and split them to chunks according to the
    # matching table; open output and input files, use with statement to close files
    number_of_temp_files = len(matching_table)
    print('number of temp files', number_of_temp_files)
    fasta_dict = read_single_fasta_to_dictionary(open(fasta_file, 'r'))
    with open(fasta_file_split, 'w') as fh_out:
        for header in fasta_dict:
            matching_table_part = [x for x in matching_table if x[0] == header]
            for header2, i, start, end, new_header in matching_table_part:
                fh_out.write('>' + new_header + '\n')
                fh_out.write(fasta_dict[header][start:end] + '\n')
    temp_files_fasta = make_temp_files(number_of_temp_files)
    fasta_seq_size = read_fasta_sequence_size(fasta_file_split)
    seq_id_size_sorted = [i[0] for i in sorted(
        fasta_seq_size.items(), key=lambda x: int(x[1]), reverse=True
    )]
    seq_id_file_dict = dict(zip(seq_id_size_sorted, itertools.cycle(temp_files_fasta)))
    # write sequences to temporary files
    with open(fasta_file_split, 'r') as f:
        first = True
        for line in f:
            if line[0] == '>':
                # close previous file if it is not the first sequence
                if not first:
                    fout.close()
                first = False
                header = line.strip().split(' ')[0][1:]
                fout = open(seq_id_file_dict[header], 'a')
                fout.write(line)
            else:
                fout.write(line)
    os.remove(fasta_file_split)
    return temp_files_fasta, matching_table
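# Illustrative example (not executed; "chr1" is a hypothetical sequence name): with the
# default chunk_size=100000000 and overlap=100000, a single 250 Mb sequence gives
# number_of_chunks = 2 and adjusted_chunk_size = 125000000, so matching_table holds
#   ['chr1', 0,         0, 125100000, 'chr1_0']   # first chunk extended by the overlap
#   ['chr1', 1, 125000000, 250000000, 'chr1_1']   # last chunk runs to the end
# Sequences shorter than 2 * chunk_size stay whole as ['header', 0, 0, size, 'header_0'].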
def read_fasta_sequence_size(fasta_file):
    """Read size of each sequence into a dictionary"""
    fasta_dict = {}
    with open(fasta_file, 'r') as f:
        for line in f:
            if line[0] == '>':
                # remove part of the name after the first space
                header = line.strip().split(' ')[0][1:]
                fasta_dict[header] = 0
            else:
                fasta_dict[header] += len(line.strip())
    return fasta_dict


def read_single_fasta_to_dictionary(fh):
    """
    Read fasta file into dictionary
    :param fh:
    :return: fasta_dict
    """
    fasta_dict = {}
    for line in fh:
        if line[0] == '>':
            # remove part of the name after the first space
            header = line.strip().split(' ')[0][1:]
            fasta_dict[header] = []
        else:
            fasta_dict[header] += [line.strip()]
    fasta_dict = {k: ''.join(v) for k, v in fasta_dict.items()}
    return fasta_dict


def overlap(a, b):
    """
    check if two intervals overlap
    """
    return max(a[0], b[0]) <= min(a[1], b[1])


def blast2disjoint(
        blastfile, seqid_counts=None, start_column=6, end_column=7, class_column=1,
        bitscore_column=11, pident_column=2, canonical_classification=True
):
    """
    find all interval beginnings and ends in the blast file and create disjoint regions
    input blastfile is a tab separated file with columns:
    'qaccver saccver pident length mismatch gapopen qstart qend sstart send evalue
    bitscore' (default outfmt 6); blast must be sorted on qseqid and qstart
    """
    # assume all in one chromosome!
    starts_ends = {}
    intervals = {}
    if canonical_classification:
        # make regular expression for canonical classification
        # to match: Name#classification
        # e.g. "Name_of_sequence#LTR/Ty1_copia/Angela"
        regex = re.compile(r"(.*)[#](.*)")
        group = 2
    else:
        # make regular expression for non-canonical classification
        # to match: Classification__Name
        # e.g. "LTR/Ty1_copia/Angela__Name_of_sequence"
        regex = re.compile(r"(.*)__(.*)")
        group = 1
    # identify continuous intervals
    with open(blastfile, "r") as f:
        for seqid in sorted(seqid_counts.keys()):
            n_lines = seqid_counts[seqid]
            starts_ends[seqid] = set()
            for i in range(n_lines):
                items = f.readline().strip().split()
                # note: 1s and 2e labels are used to distinguish between start and end
                # and guarantee that, with the same coordinate, start is sorted before
                # end (1s < 2e)
                starts_ends[seqid].add((int(items[start_column]), '1s'))
                starts_ends[seqid].add((int(items[end_column]), '2e'))
            intervals[seqid] = []
            for p1, p2 in itertools.pairwise(sorted(starts_ends[seqid])):
                if p1[1] == '1s':
                    sp = 0
                else:
                    sp = 1
                if p2[1] == '2e':
                    ep = 0
                else:
                    ep = 1
                intervals[seqid].append((p1[0] + sp, p2[0] - ep))
    # scan each blast hit against continuous regions and record the hit with best score
    with open(blastfile, "r") as f:
        disjoint_regions = []
        for seqid in sorted(seqid_counts.keys()):
            n_lines = seqid_counts[seqid]
            idx_of_overlaps = {}
            best_pident = defaultdict(lambda: 0.0)
            best_bitscore = defaultdict(lambda: 0.0)
            best_hit_name = defaultdict(lambda: "")
            i1 = 0
            for i in range(n_lines):
                items = f.readline().strip().split()
                start = int(items[start_column])
                end = int(items[end_column])
                pident = float(items[pident_column])
                bitscore = float(items[bitscore_column])
                classification = items[class_column]
                j = 0
                done = False
                while True:
                    # beginning of searched region - does it overlap?
                    c_ovl = overlap(intervals[seqid][i1], (start, end))
                    if c_ovl:
                        # if overlap is detected, add to dictionary
                        idx_of_overlaps[i] = [i1]
                        if best_bitscore[i1] < bitscore:
                            best_pident[i1] = pident
                            best_bitscore[i1] = bitscore
                            best_hit_name[i1] = classification
                        # also search downstream
                        while True:
                            j += 1
                            if j + i1 >= len(intervals[seqid]):
                                done = True
                                break
                            c_ovl = overlap(intervals[seqid][i1 + j], (start, end))
                            if c_ovl:
                                idx_of_overlaps[i].append(i1 + j)
                                if best_bitscore[i1 + j] < bitscore:
                                    best_pident[i1 + j] = pident
                                    best_bitscore[i1 + j] = bitscore
                                    best_hit_name[i1 + j] = classification
                            else:
                                done = True
                                break
                    else:
                        # does not overlap - search next interval
                        i1 += 1
                    if done or i1 >= (len(intervals[seqid]) - 1):
                        break
            for i in sorted(best_pident.keys()):
                try:
                    classification = re.match(regex, best_hit_name[i]).group(group)
                except AttributeError:
                    classification = best_hit_name[i]
                record = (
                    seqid, intervals[seqid][i][0], intervals[seqid][i][1],
                    best_pident[i], classification)
                disjoint_regions.append(record)
    return disjoint_regions
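# Illustrative example (not executed; "seq1" is a hypothetical query id): two sorted
# hits on seq1 spanning 100-200 and 150-250 produce the breakpoint set
# {(100, '1s'), (150, '1s'), (200, '2e'), (250, '2e')}, which itertools.pairwise turns
# into the disjoint intervals (100, 149), (150, 200) and (201, 250); each interval then
# receives the classification of its best-bitscore hit.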
def remove_short_interrupting_regions(regions, min_len=10, max_gap=2):
    """
    remove intervals shorter than min_len which are directly adjacent, on both sides,
    to regions that are longer than min_len and have the same classification
    """
    regions_to_remove = []
    for i in range(1, len(regions) - 1):
        if regions[i][2] - regions[i][1] < min_len:
            c1 = regions[i - 1][2] - regions[i - 1][1] > min_len
            c2 = regions[i + 1][2] - regions[i + 1][1] > min_len
            c3 = regions[i - 1][4] == regions[i + 1][4]  # same classification
            c4 = regions[i + 1][4] != regions[i][4]  # different classification
            c5 = regions[i][1] - regions[i - 1][2] < max_gap  # max gap between regions
            c6 = regions[i + 1][1] - regions[i][2] < max_gap  # max gap between regions
            if c1 and c2 and c3 and c4 and c5 and c6:
                regions_to_remove.append(i)
    for i in sorted(regions_to_remove, reverse=True):
        del regions[i]
    return regions


def remove_short_regions(regions, min_l_score=600):
    """
    remove short or weak regions
    min_l_score is the minimum score for a region to be kept,
    l_score = (percent identity - 50) * length
    """
    regions_to_remove = []
    for i in range(len(regions)):
        l_score = (regions[i][3] - 50) * (regions[i][2] - regions[i][1])
        if l_score < min_l_score:
            regions_to_remove.append(i)
    for i in sorted(regions_to_remove, reverse=True):
        del regions[i]
    return regions


def join_disjoint_regions_by_classification(disjoint_regions, max_gap=0):
    """
    merge neighboring intervals with the same classification and calculate the mean
    weighted score; the weight corresponds to the length of the interval
    """
    merged_regions = []
    for seqid, start, end, score, classification in disjoint_regions:
        score_length = (end - start + 1) * score
        if len(merged_regions) == 0:
            merged_regions.append([seqid, start, end, score_length, classification])
        else:
            cond_same_class = merged_regions[-1][4] == classification
            cond_same_seqid = merged_regions[-1][0] == seqid
            cond_neighboring = start - merged_regions[-1][2] + 1 <= max_gap
            if cond_same_class and cond_same_seqid and cond_neighboring:
                # extend the previous region
                merged_regions[-1] = [merged_regions[-1][0], merged_regions[-1][1], end,
                                      merged_regions[-1][3] + score_length,
                                      merged_regions[-1][4]]
            else:
                merged_regions.append([seqid, start, end, score_length, classification])
    # recalculate length weighted score
    for record in merged_regions:
        record[3] = record[3] / (record[2] - record[1] + 1)
    return merged_regions
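# Illustrative example (not executed): with max_gap=10, two adjacent regions on the same
# seqid with the same classification, e.g. (1, 100) scoring 90.0 and (101, 150) scoring
# 80.0, are merged into (1, 150) with the length-weighted score
# (100 * 90.0 + 50 * 80.0) / 150 = 86.67 (rounded).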
def write_merged_regions_to_gff3(merged_regions, outfile):
    """
    write merged regions to gff3 file
    """
    with open(outfile, "w") as f:
        # write header
        f.write("##gff-version 3\n")
        for seqid, start, end, score, classification in merged_regions:
            attributes = "Name={};score={}".format(classification, score)
            f.write(
                "\t".join(
                    [seqid, "blast_parsed", "repeat_region", str(start), str(end),
                     str(round(score, 2)), ".", ".", attributes]
                )
            )
            f.write("\n")


def sort_blast_table(
        blastfile, seqid_column=0, start_column=6, cpu=1
):
    """
    sort blast table by seqid and start position and count hits per seqid
    the sorted table is stored in a temp file
    columns are indexed from 0 but sort and cut use 1-based indexing!
    """
    blast_sorted = tempfile.NamedTemporaryFile().name
    # dictionary of seqid counts
    seq_id_counts = {}
    # sort blast file on disk using sort on seqid and start (numeric) position columns
    # using the sort command as blast output could be very large
    cmd = "sort -k {0},{0} -k {1},{1}n --parallel {4} {2} > {3}".format(
        seqid_column + 1, start_column + 1, blastfile, blast_sorted, cpu
    )
    subprocess.check_call(cmd, shell=True)
    # count seqids using the uniq command
    cmd = "cut -f {0} {1} | uniq -c > {2}".format(
        seqid_column + 1, blast_sorted, blast_sorted + ".counts"
    )
    subprocess.check_call(cmd, shell=True)
    # read counts file and create dictionary
    with open(blast_sorted + ".counts", "r") as f:
        for line in f:
            line = line.strip().split()
            seq_id_counts[line[1]] = int(line[0])
    # remove counts file
    subprocess.call(["rm", blast_sorted + ".counts"])
    # return counts dictionary and sorted blast file
    return seq_id_counts, blast_sorted


def run_blastn(
        query, db, blastfile, evalue=1e-3, max_target_seqs=999999999, gapopen=2,
        gapextend=1, reward=1, penalty=-1, word_size=9, num_threads=1, outfmt="6"
):
    """
    run blastn
    """
    # create temporary blast database:
    db_formated = tempfile.NamedTemporaryFile().name
    cmd = "makeblastdb -in {0} -dbtype nucl -out {1}".format(db, db_formated)
    subprocess.check_call(cmd, shell=True)
    size = os.path.getsize(query)
    print("query size: {} bytes".format(size))
    max_size = 1e6
    overlap = 50000
    # if the query is smaller than max_size, run blast on a single file
    if size < max_size:
        cmd = ("blastn -task rmblastn -query {0} -db {1} -out {2} -evalue {3} "
               "-max_target_seqs {4} "
               "-gapopen {5} -gapextend {6} -word_size {7} -num_threads "
               "{8} -outfmt '{9}' -reward {10} -penalty {11} -dust no").format(
            query, db_formated, blastfile, evalue, max_target_seqs, gapopen, gapextend,
            word_size, num_threads, outfmt, reward, penalty
        )
        subprocess.check_call(cmd, shell=True)
    # if the query is larger than max_size, split it into chunks and run blast on each
    else:
        print(f"query is larger than {max_size}, splitting query in chunks")
        query_parts, matching_table = split_fasta_to_chunks(query, max_size, overlap)
        print(query_parts)
        for i, part in enumerate(query_parts):
            print(f"running blast on chunk {i}")
            print(part)
            cmd = ("blastn -task rmblastn -query {0} -db {1} -out {2} -evalue {3} "
                   "-max_target_seqs {4} "
                   "-gapopen {5} -gapextend {6} -word_size {7} -num_threads "
                   "{8} -outfmt '{9}' -reward {10} -penalty {11} -dust no").format(
                part, db_formated, f'{blastfile}.{i}', evalue, max_target_seqs, gapopen,
                gapextend, word_size, num_threads, outfmt, reward, penalty
            )
            subprocess.check_call(cmd, shell=True)
            print(cmd)
            # remove part file
            # os.unlink(part)
        # merge blast results and recalculate start, end positions and headers
        merge_blast_results(blastfile, matching_table, n_parts=len(query_parts))
    # remove temporary blast database
    os.unlink(db_formated + ".nhr")
    os.unlink(db_formated + ".nin")
    os.unlink(db_formated + ".nsq")
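# Illustrative example (not executed; file names are placeholders): with the default
# parameters the commands assembled above expand roughly to
#   makeblastdb -in repeats.fasta -dbtype nucl -out /tmp/tmpXXXXXXXX
#   blastn -task rmblastn -query query.fasta -db /tmp/tmpXXXXXXXX -out hits.tsv
#       -evalue 0.001 -max_target_seqs 999999999 -gapopen 2 -gapextend 1 -word_size 9
#       -num_threads 1 -outfmt '6' -reward 1 -penalty -1 -dust no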
def merge_blast_results(blastfile, matching_table, n_parts):
    """
    Merge blast tables and recalculate start, end positions based on the matching table
    """
    with open(blastfile, "w") as f:
        matching_table_dict = {i[4]: i for i in matching_table}
        print(matching_table_dict)
        for i in range(n_parts):
            with open(f'{blastfile}.{i}', "r") as f2:
                for line in f2:
                    line = line.strip().split("\t")
                    # seqid (chunk header) is in the first column
                    seqid = line[0]
                    line[0] = matching_table_dict[seqid][0]
                    # increase coordinates by the start position of the chunk
                    line[6] = str(int(line[6]) + matching_table_dict[seqid][2])
                    line[7] = str(int(line[7]) + matching_table_dict[seqid][2])
                    f.write("\t".join(line) + "\n")
            # remove temporary blast file
            # os.unlink(f'{blastfile}.{i}')


def main():
    """
    main function
    """
    # get command line arguments
    parser = argparse.ArgumentParser(
        description="""This script is used to parse blast output table to gff file""",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        '-i', '--input', default=None, required=True, help="input file", type=str,
        action='store'
    )
    parser.add_argument(
        '-d', '--db', default=None, required=False,
        help="Fasta file with repeat database", type=str, action='store'
    )
    parser.add_argument(
        '-o', '--output', default=None, required=True, help="output file name",
        type=str, action='store'
    )
    parser.add_argument(
        '-a', '--alternative_classification_coding', default=False,
        help="Use alternative classification coding", action='store_true'
    )
    parser.add_argument(
        '-f', '--fasta_input', default=False,
        help="Input is fasta file instead of blast table", action='store_true'
    )
    parser.add_argument(
        '-c', '--cpu', default=1, help="Number of cpu to use", type=int
    )
    args = parser.parse_args()

    if args.fasta_input:
        # run blast using blastn
        blastfile = tempfile.NamedTemporaryFile().name
        if args.db:
            run_blastn(args.input, args.db, blastfile, num_threads=args.cpu)
        else:
            sys.exit("No repeat database provided")
    else:
        blastfile = args.input
    # sort blast table
    seq_id_counts, blast_sorted = sort_blast_table(blastfile, cpu=args.cpu)
    disjoint_regions = blast2disjoint(
        blast_sorted, seq_id_counts,
        canonical_classification=not args.alternative_classification_coding
    )
    # remove short interrupting regions
    disjoint_regions = remove_short_interrupting_regions(disjoint_regions)
    # join neighboring regions with the same classification
    merged_regions = join_disjoint_regions_by_classification(disjoint_regions)
    # remove short interrupting regions again
    merged_regions = remove_short_interrupting_regions(merged_regions)
    # merge neighboring regions with the same classification again
    merged_regions = join_disjoint_regions_by_classification(merged_regions, max_gap=10)
    # remove short weak regions
    merged_regions = remove_short_regions(merged_regions)
    # last merge
    merged_regions = join_disjoint_regions_by_classification(merged_regions, max_gap=20)
    write_merged_regions_to_gff3(merged_regions, args.output)
    # remove temporary files
    os.remove(blast_sorted)


if __name__ == "__main__":
    main()
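# Example invocations (assumed typical usage; file names are placeholders):
#   reAnnotate.py -i blast_table.tsv -o annotation.gff3
#       parse an existing outfmt 6 blast table into a GFF3 annotation
#   reAnnotate.py -f -i genome.fasta -d repeat_db.fasta -o annotation.gff3 -c 4
#       run blastn against the repeat database first, then annotate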