diff reAnnotate.py @ 11:5366d5ea04bc draft
planemo upload commit 9d1b19f98d8b7f0a0d1baf2da63a373d155626f8-dirty
author:   petr-novak
date:     Fri, 04 Aug 2023 12:35:32 +0000
parents:
children:
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/reAnnotate.py	Fri Aug 04 12:35:32 2023 +0000
@@ -0,0 +1,528 @@
#!/usr/bin/env python
"""
parse blast output table to gff file
"""
import argparse
import itertools
import os
import re
import shutil
import subprocess
import sys
import tempfile
from collections import defaultdict

# check version of python, must be at least 3.10 (itertools.pairwise is used below)
if sys.version_info < (3, 10):
    sys.exit("Python 3.10 or a more recent version is required.")


def make_temp_files(number_of_files):
    """
    Make named temporary files; only the generated names are used, so the
    files will not be deleted automatically upon exit!
    :param number_of_files:
    :return:
        filepaths
    """
    temp_files = []
    for i in range(number_of_files):
        temp_files.append(tempfile.NamedTemporaryFile(delete=False).name)
        os.remove(temp_files[-1])
    return temp_files


def split_fasta_to_chunks(fasta_file, chunk_size=100000000, overlap=100000):
    """
    Split fasta file to chunks; sequences longer than chunk_size are split into
    overlapping pieces, shorter sequences are grouped into chunks containing
    multiple sequences.
    :param fasta_file:
    :param chunk_size:
    :param overlap:
    :return:
        fasta_file_split
        matching_table (list of lists [header, chunk_number, start, end, new_header])
    """
    min_chunk_size = chunk_size * 2
    fasta_sizes_dict = read_fasta_sequence_size(fasta_file)
    # calculate total size of sequences in the fasta file
    fasta_size = sum(fasta_sizes_dict.values())

    # calculate ranges for splitting of fasta files and store them in a list
    matching_table = []
    fasta_file_split = tempfile.NamedTemporaryFile(delete=False).name
    for header, size in fasta_sizes_dict.items():
        print(header, size, min_chunk_size)

        if size > min_chunk_size:
            number_of_chunks = int(size / chunk_size)
            print("number_of_chunks", number_of_chunks)
            print("size", size)
            print("chunk_size", chunk_size)
            print("-----------------------------------------")
            adjusted_chunk_size = int(size / number_of_chunks)
            for i in range(number_of_chunks):
                start = i * adjusted_chunk_size
                end = ((i + 1) *
                       adjusted_chunk_size +
                       overlap) if i + 1 < number_of_chunks else size
                new_header = header + '_' + str(i)
                matching_table.append([header, i, start, end, new_header])
        else:
            new_header = header + '_0'
            matching_table.append([header, 0, 0, size, new_header])
    # read sequences from the fasta file and split them to chunks according to
    # the matching table
    number_of_temp_files = len(matching_table)
    print('number of temp files', number_of_temp_files)
    fasta_dict = read_single_fasta_to_dictionary(open(fasta_file, 'r'))
    with open(fasta_file_split, 'w') as fh_out:
        for header in fasta_dict:
            matching_table_part = [x for x in matching_table if x[0] == header]
            for header2, i, start, end, new_header in matching_table_part:
                fh_out.write('>' + new_header + '\n')
                fh_out.write(fasta_dict[header][start:end] + '\n')
    temp_files_fasta = make_temp_files(number_of_temp_files)
    fasta_seq_size = read_fasta_sequence_size(fasta_file_split)
    seq_id_size_sorted = [i[0] for i in sorted(
        fasta_seq_size.items(), key=lambda x: int(x[1]), reverse=True
    )]
    seq_id_file_dict = dict(zip(seq_id_size_sorted, itertools.cycle(temp_files_fasta)))
    # distribute chunk sequences to the temporary files, largest first
    with open(fasta_file_split, 'r') as f:
        first = True
        for line in f:
            if line[0] == '>':
                # close previous file if it is not the first sequence
                if not first:
                    fout.close()
                first = False
                header = line.strip().split(' ')[0][1:]
                fout = open(seq_id_file_dict[header], 'a')
                fout.write(line)
            else:
                fout.write(line)
        if not first:
            fout.close()  # close the last opened chunk file
    os.remove(fasta_file_split)
    return temp_files_fasta, matching_table
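
# Worked example of the matching table built above (added for illustration;
# numbers are hypothetical): with the default chunk_size=100000000 (100 Mb) and
# overlap=100000, a 250 Mb sequence "chr1" exceeds min_chunk_size (200 Mb), so
# number_of_chunks = 2, adjusted_chunk_size = 125 Mb, and the table receives:
#   ["chr1", 0, 0,         125100000, "chr1_0"]  # first chunk + 100 kb overlap
#   ["chr1", 1, 125000000, 250000000, "chr1_1"]  # last chunk ends at sequence end
# merge_blast_results() later uses the start column (index 2) to shift blast
# hit coordinates back onto the original sequence.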


def read_fasta_sequence_size(fasta_file):
    """Read size of each sequence into a dictionary"""
    fasta_dict = {}
    with open(fasta_file, 'r') as f:
        for line in f:
            if line[0] == '>':
                header = line.strip().split(' ')[0][1:]  # remove part of name after space
                fasta_dict[header] = 0
            else:
                fasta_dict[header] += len(line.strip())
    return fasta_dict


def read_single_fasta_to_dictionary(fh):
    """
    Read fasta file into dictionary
    :param fh:
    :return:
        fasta_dict
    """
    fasta_dict = {}
    for line in fh:
        if line[0] == '>':
            header = line.strip().split(' ')[0][1:]  # remove part of name after space
            fasta_dict[header] = []
        else:
            fasta_dict[header] += [line.strip()]
    fasta_dict = {k: ''.join(v) for k, v in fasta_dict.items()}
    return fasta_dict


def overlap(a, b):
    """
    check if two closed intervals overlap
    """
    return max(a[0], b[0]) <= min(a[1], b[1])
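
# Doctest-style illustration of overlap() (added for clarity, not part of the
# original commit); coordinates are treated as inclusive, so intervals that
# merely touch are reported as overlapping:
#   >>> overlap((1, 10), (5, 20))
#   True
#   >>> overlap((1, 5), (5, 9))   # shared endpoint counts as an overlap
#   True
#   >>> overlap((1, 4), (5, 9))
#   False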


def blast2disjoint(
        blastfile, seqid_counts=None, start_column=6, end_column=7, class_column=1,
        bitscore_column=11, pident_column=2, canonical_classification=True
        ):
    """
    find all interval beginnings and ends in the blast file and create disjoint regions
    input blastfile is a tab separated file with columns:
    'qaccver saccver pident length mismatch gapopen qstart qend sstart send
    evalue bitscore' (default outfmt 6)
    blast must be sorted on qseqid and qstart
    """
    # assume all in one chromosome!
    starts_ends = {}
    intervals = {}
    if canonical_classification:
        # make regular expression for canonical classification
        # to match: Name#classification
        # e.g. "Name_of_sequence#LTR/Ty1_copia/Angela"
        regex = re.compile(r"(.*)[#](.*)")
        group = 2
    else:
        # make regular expression for non-canonical classification
        # to match: Classification__Name
        # e.g. "LTR/Ty1_copia/Angela__Name_of_sequence"
        regex = re.compile(r"(.*)__(.*)")
        group = 1

    # identify continuous intervals
    with open(blastfile, "r") as f:
        for seqid in sorted(seqid_counts.keys()):
            n_lines = seqid_counts[seqid]
            starts_ends[seqid] = set()
            for i in range(n_lines):
                items = f.readline().strip().split()
                # note: the '1s' and '2e' labels distinguish starts from ends and
                # guarantee that, for the same coordinate, a start sorts before
                # an end ('1s' < '2e')
                starts_ends[seqid].add((int(items[start_column]), '1s'))
                starts_ends[seqid].add((int(items[end_column]), '2e'))
            intervals[seqid] = []
            for p1, p2 in itertools.pairwise(sorted(starts_ends[seqid])):
                if p1[1] == '1s':
                    sp = 0
                else:
                    sp = 1
                if p2[1] == '2e':
                    ep = 0
                else:
                    ep = 1
                intervals[seqid].append((p1[0] + sp, p2[0] - ep))
    # scan each blast hit against the continuous regions and record the hit
    # with the best score
    with open(blastfile, "r") as f:
        disjoint_regions = []
        for seqid in sorted(seqid_counts.keys()):
            n_lines = seqid_counts[seqid]
            idx_of_overlaps = {}
            best_pident = defaultdict(lambda: 0.0)
            best_bitscore = defaultdict(lambda: 0.0)
            best_hit_name = defaultdict(lambda: "")
            i1 = 0
            for i in range(n_lines):
                items = f.readline().strip().split()
                start = int(items[start_column])
                end = int(items[end_column])
                pident = float(items[pident_column])
                bitscore = float(items[bitscore_column])
                classification = items[class_column]
                j = 0
                done = False
                while True:
                    # beginning of searched region - does it overlap?
                    c_ovl = overlap(intervals[seqid][i1], (start, end))
                    if c_ovl:
                        # if overlap is detected, add to dictionary
                        idx_of_overlaps[i] = [i1]
                        if best_bitscore[i1] < bitscore:
                            best_pident[i1] = pident
                            best_bitscore[i1] = bitscore
                            best_hit_name[i1] = classification
                        # also search downstream intervals
                        while True:
                            j += 1
                            if j + i1 >= len(intervals[seqid]):
                                done = True
                                break
                            c_ovl = overlap(intervals[seqid][i1 + j], (start, end))
                            if c_ovl:
                                idx_of_overlaps[i].append(i1 + j)
                                if best_bitscore[i1 + j] < bitscore:
                                    best_pident[i1 + j] = pident
                                    best_bitscore[i1 + j] = bitscore
                                    best_hit_name[i1 + j] = classification
                            else:
                                done = True
                                break
                    else:
                        # does not overlap - search the next interval
                        i1 += 1
                    if done or i1 >= (len(intervals[seqid]) - 1):
                        break

            for i in sorted(best_pident.keys()):
                try:
                    classification = re.match(regex, best_hit_name[i]).group(group)
                except AttributeError:
                    classification = best_hit_name[i]
                record = (
                    seqid, intervals[seqid][i][0], intervals[seqid][i][1],
                    best_pident[i], classification)
                disjoint_regions.append(record)
    return disjoint_regions


def remove_short_interrupting_regions(regions, min_len=10, max_gap=2):
    """
    remove regions shorter than min_len which are directly adjacent, on both
    sides, to regions longer than min_len that share the same classification
    (different from the short region's own classification)
    """
    regions_to_remove = []
    for i in range(1, len(regions) - 1):
        if regions[i][2] - regions[i][1] < min_len:
            c1 = regions[i - 1][2] - regions[i - 1][1] > min_len
            c2 = regions[i + 1][2] - regions[i + 1][1] > min_len
            c3 = regions[i - 1][4] == regions[i + 1][4]  # same classification on both sides
            c4 = regions[i + 1][4] != regions[i][4]  # different from the short region
            c5 = regions[i][1] - regions[i - 1][2] < max_gap  # max gap between regions
            c6 = regions[i + 1][1] - regions[i][2] < max_gap  # max gap between regions
            if c1 and c2 and c3 and c4 and c5 and c6:
                regions_to_remove.append(i)
    for i in sorted(regions_to_remove, reverse=True):
        del regions[i]
    return regions


def remove_short_regions(regions, min_l_score=600):
    """
    remove weak regions with a low length-identity score
    min_l_score is the minimum score for a region to be kept
    l_score = (PID - 50) * length
    """
    regions_to_remove = []
    for i in range(len(regions)):
        l_score = (regions[i][3] - 50) * (regions[i][2] - regions[i][1])
        if l_score < min_l_score:
            regions_to_remove.append(i)
    for i in sorted(regions_to_remove, reverse=True):
        del regions[i]
    return regions


def join_disjoint_regions_by_classification(disjoint_regions, max_gap=0):
    """
    merge neighboring intervals with the same classification and calculate a
    mean weighted score; weights correspond to the lengths of the intervals
    """
    merged_regions = []
    for seqid, start, end, score, classification in disjoint_regions:
        score_length = (end - start + 1) * score
        if len(merged_regions) == 0:
            merged_regions.append([seqid, start, end, score_length, classification])
        else:
            cond_same_class = merged_regions[-1][4] == classification
            cond_same_seqid = merged_regions[-1][0] == seqid
            cond_neighboring = start - merged_regions[-1][2] + 1 <= max_gap
            if cond_same_class and cond_same_seqid and cond_neighboring:
                # extend the last region
                merged_regions[-1] = [merged_regions[-1][0], merged_regions[-1][1], end,
                                      merged_regions[-1][3] + score_length,
                                      merged_regions[-1][4]]
            else:
                merged_regions.append([seqid, start, end, score_length, classification])
    # recalculate the length-weighted score
    for record in merged_regions:
        record[3] = record[3] / (record[2] - record[1] + 1)
    return merged_regions
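
# Worked example of the two steps above (hypothetical hits, added for clarity).
# Two blast hits on "chr1" spanning (100, 200) and (150, 300) produce the
# boundary set {(100, '1s'), (150, '1s'), (200, '2e'), (300, '2e')}; the
# pairwise sweep in blast2disjoint() converts it into the disjoint intervals
#   (100, 149), (150, 200), (201, 300)
# each labelled with the classification of the best-bitscore hit covering it.
# If (100, 149) and (150, 200) carry the same classification with pident
# scores 90.0 and 80.0, join_disjoint_regions_by_classification(..., max_gap=10)
# merges them into (100, 200) with the length-weighted score
#   (50 * 90.0 + 51 * 80.0) / 101 ≈ 84.95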


def write_merged_regions_to_gff3(merged_regions, outfile):
    """
    write merged regions to gff3 file
    """
    with open(outfile, "w") as f:
        # write header
        f.write("##gff-version 3\n")
        for seqid, start, end, score, classification in merged_regions:
            attributes = "Name={};score={}".format(classification, score)
            f.write(
                "\t".join(
                    [seqid, "blast_parsed", "repeat_region", str(start), str(end),
                     str(round(score, 2)), ".", ".", attributes]
                )
            )
            f.write("\n")


def sort_blast_table(
        blastfile, seqid_column=0, start_column=6, cpu=1
        ):
    """
    sort the blast table by seqid and start position and count hits per seqid
    stores output in temp files
    columns are indexed from 0, but sort and cut use 1-based indexing!
    """
    blast_sorted = tempfile.NamedTemporaryFile().name
    # dictionary of per-seqid hit counts
    seq_id_counts = {}
    # sort blast file on disk using sort on seqid and start (numeric) position columns
    # the external sort command is used because blast output can be very large
    cmd = "sort -k {0},{0} -k {1},{1}n --parallel {4} {2} > {3}".format(
        seqid_column + 1, start_column + 1, blastfile, blast_sorted, cpu
    )
    subprocess.check_call(cmd, shell=True)

    # count seqids using uniq command
    cmd = "cut -f {0} {1} | uniq -c > {2}".format(
        seqid_column + 1, blast_sorted, blast_sorted + ".counts"
    )
    subprocess.check_call(cmd, shell=True)
    # read counts file and create dictionary
    with open(blast_sorted + ".counts", "r") as f:
        for line in f:
            line = line.strip().split()
            seq_id_counts[line[1]] = int(line[0])
    # remove counts file
    subprocess.call(["rm", blast_sorted + ".counts"])
    # return the counts dictionary and the path to the sorted blast file
    return seq_id_counts, blast_sorted


def run_blastn(
        query, db, blastfile, evalue=1e-3, max_target_seqs=999999999, gapopen=2,
        gapextend=1, reward=1, penalty=-1, word_size=9, num_threads=1, outfmt="6"
        ):
    """
    run blastn
    """
    # create temporary blast database:
    db_formated = tempfile.NamedTemporaryFile().name
    cmd = "makeblastdb -in {0} -dbtype nucl -out {1}".format(db, db_formated)
    subprocess.check_call(cmd, shell=True)
    # if the query is smaller than max_size, run blast on a single file
    size = os.path.getsize(query)
    print("query size: {} bytes".format(size))
    max_size = 1e6
    overlap = 50000
    if size < max_size:
        cmd = ("blastn -task rmblastn -query {0} -db {1} -out {2} -evalue {3} "
               "-max_target_seqs {4} "
               "-gapopen {5} -gapextend {6} -word_size {7} -num_threads "
               "{8} -outfmt '{9}' -reward {10} -penalty {11} -dust no").format(
            query, db_formated, blastfile, evalue, max_target_seqs, gapopen, gapextend,
            word_size, num_threads, outfmt, reward, penalty
        )
        subprocess.check_call(cmd, shell=True)
    # if the query is larger than max_size, split it into chunks and run blast
    # on each chunk
    else:
        print(f"query is larger than {max_size}, splitting query in chunks")
        query_parts, matching_table = split_fasta_to_chunks(query, max_size, overlap)
        print(query_parts)
        for i, part in enumerate(query_parts):
            print(f"running blast on chunk {i}")
            print(part)
            cmd = ("blastn -task rmblastn -query {0} -db {1} -out {2} -evalue {3} "
                   "-max_target_seqs {4} "
                   "-gapopen {5} -gapextend {6} -word_size {7} -num_threads "
                   "{8} -outfmt '{9}' -reward {10} -penalty {11} -dust no").format(
                part, db_formated, f'{blastfile}.{i}', evalue, max_target_seqs,
                gapopen, gapextend,
                word_size, num_threads, outfmt, reward, penalty
            )
            subprocess.check_call(cmd, shell=True)
            print(cmd)
            # remove part file
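
# For orientation, with the defaults above run_blastn() shells out commands of
# roughly this shape (file names are hypothetical placeholders, the database
# path is a generated temporary name):
#   makeblastdb -in repeats.fasta -dbtype nucl -out /tmp/tmpXXXXXXXX
#   blastn -task rmblastn -query genome.fasta -db /tmp/tmpXXXXXXXX \
#          -out hits.tsv -evalue 0.001 -max_target_seqs 999999999 \
#          -gapopen 2 -gapextend 1 -word_size 9 -num_threads 1 \
#          -outfmt '6' -reward 1 -penalty -1 -dust no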
            # os.unlink(part)
        # merge blast results and recalculate start, end positions and header
        merge_blast_results(blastfile, matching_table, n_parts=len(query_parts))

    # remove temporary blast database
    os.unlink(db_formated + ".nhr")
    os.unlink(db_formated + ".nin")
    os.unlink(db_formated + ".nsq")


def merge_blast_results(blastfile, matching_table, n_parts):
    """
    Merge blast tables from chunks and recalculate start, end positions and
    headers based on the matching table
    """
    with open(blastfile, "w") as f:
        matching_table_dict = {i[4]: i for i in matching_table}
        print(matching_table_dict)
        for i in range(n_parts):
            with open(f'{blastfile}.{i}', "r") as f2:
                for line in f2:
                    line = line.strip().split("\t")
                    # seqid (header) is in the first column
                    seqid = line[0]
                    line[0] = matching_table_dict[seqid][0]
                    # increase coordinates by the start position of the chunk
                    line[6] = str(int(line[6]) + matching_table_dict[seqid][2])
                    line[7] = str(int(line[7]) + matching_table_dict[seqid][2])
                    f.write("\t".join(line) + "\n")
            # remove temporary blast file
            # os.unlink(f'{blastfile}.{i}')


def main():
    """
    main function
    """
    # get command line arguments
    parser = argparse.ArgumentParser(
        description="""This script is used to parse blast output table to gff file""",
        formatter_class=argparse.RawTextHelpFormatter
    )
    parser.add_argument(
        '-i', '--input', default=None, required=True, help="input file", type=str,
        action='store'
    )
    parser.add_argument(
        '-d', '--db', default=None, required=False,
        help="Fasta file with repeat database", type=str, action='store'
    )
    parser.add_argument(
        '-o', '--output', default=None, required=True, help="output file name",
        type=str, action='store'
    )
    parser.add_argument(
        '-a', '--alternative_classification_coding', default=False,
        help="Use alternative classification coding", action='store_true'
    )
    parser.add_argument(
        '-f', '--fasta_input', default=False,
        help="Input is fasta file instead of blast table", action='store_true'
    )
    parser.add_argument(
        '-c', '--cpu', default=1, help="Number of cpu to use", type=int
    )

    args = parser.parse_args()

    if args.fasta_input:
        # run blast using blastn
        blastfile = tempfile.NamedTemporaryFile().name
        if args.db:
            run_blastn(args.input, args.db, blastfile, num_threads=args.cpu)
        else:
            sys.exit("No repeat database provided")
    else:
        blastfile = args.input

    # sort blast table
    seq_id_counts, blast_sorted = sort_blast_table(blastfile, cpu=args.cpu)
    disjoint_regions = blast2disjoint(
        blast_sorted, seq_id_counts,
        canonical_classification=not args.alternative_classification_coding
    )

    # remove short interrupting regions
    disjoint_regions = remove_short_interrupting_regions(disjoint_regions)

    # join neighboring regions with the same classification
    merged_regions = join_disjoint_regions_by_classification(disjoint_regions)

    # remove short interrupting regions again
    merged_regions = remove_short_interrupting_regions(merged_regions)

    # merge neighboring regions with the same classification again, allowing small gaps
    merged_regions = join_disjoint_regions_by_classification(merged_regions, max_gap=10)

    # remove short weak regions
    merged_regions = remove_short_regions(merged_regions)

    # last merge
    merged_regions = join_disjoint_regions_by_classification(merged_regions, max_gap=20)
    write_merged_regions_to_gff3(merged_regions, args.output)
    # remove temporary files
    os.remove(blast_sorted)


if __name__ == "__main__":
    main()
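
# Example invocation (a sketch; file names are hypothetical):
#   ./reAnnotate.py -f -i genome.fasta -d repeat_db.fasta -o annotation.gff3 -c 4
# or, starting from an existing blast table in outfmt 6:
#   ./reAnnotate.py -i hits.blast.tsv -o annotation.gff3
# The output is a GFF3 file with tab-separated records of the form:
#   chr1  blast_parsed  repeat_region  100  200  84.95  .  .  Name=LTR/Ty1_copia/Angela;score=84.95049504950495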