diff otu.py @ 0:c9dac9b2e01c draft
planemo upload for repository https://github.com/galaxyproject/tools-iuc/tree/master/tools/virAnnot commit 3a3b40c15ae5e82334f016e88b1f3c5bbbb3b2cd
| author | iuc |
| --- | --- |
| date | Mon, 04 Mar 2024 19:56:40 +0000 |
| parents | |
| children | 40fb54cc6628 |
--- /dev/null Thu Jan 01 00:00:00 1970 +0000
+++ b/otu.py Mon Mar 04 19:56:40 2024 +0000
@@ -0,0 +1,442 @@
#!/usr/bin/env python3


# Name: virAnnot_otu
# Author: Marie Lefebvre - INRAE
# Requirements: Ete3 toolkit and external apps
# Aims: Create viral OTUs based on RPS and Blast annotations


import argparse
import csv
import logging as log
import os
import random
import re

import pandas as pd
import xlsxwriter
from Bio import SeqIO
from Bio.Align.Applications import ClustalOmegaCommandline
from ete3 import NodeStyle, SeqGroup, SeqMotifFace, Tree, TreeStyle


def main():
    """
    1 - retrieve info (sequence, query_id, taxonomy) from the RPS files
    2 - align protein sequences of the same domain, compute the
        distance matrix, generate trees
    3 - get statistics (read numbers) per OTU
    4 - create the HTML report
    """
    options = _set_options()
    _set_log_level(options.verbosity)
    hits_collection = _cut_sequence(options)
    _align_sequences(options, hits_collection)
    _get_stats(options, hits_collection)
    _create_html(options, hits_collection)


def _cut_sequence(options):
    """
    Retrieve viral hits and sequences from RPS files
    """
    log.info("Cut sequences")
    i = 0  # keep track of iterations over rps files to use the corresponding fasta file
    collection = {}
    options.rps.sort()
    for rps_file in options.rps:
        log.debug("Reading rps file " + str(rps_file))
        with open(rps_file[0], 'r') as rps_current_file:
            rps_reader = csv.reader(rps_current_file, delimiter='\t')
            headers = 0
            for row in rps_reader:
                if headers == 0:
                    # skip the header line
                    headers += 1
                else:
                    if row[1] == "no_hit":
                        pass
                    else:
                        query_id = row[0]
                        cdd_id = row[2]
                        startQ = int(row[5])
                        endQ = int(row[6])
                        frame = float(row[7])
                        description = row[8]
                        superkingdom = row[9]
                        match = re.search("Viruses", superkingdom)
                        # if the contig is viral, retrieve its sequence
                        if match:
                            options.fasta.sort()
                            seq = _retrieve_fasta_seq(options.fasta[i][0], query_id)
                            if seq is None:
                                # query_id is missing from the fasta file; skip it
                                continue
                            seq_length = len(seq)
                            if endQ < seq_length:
                                seq = seq[startQ - 1:endQ]
                            else:
                                seq = seq[startQ - 1:seq_length]
                            if frame < 0:
                                seq = seq.reverse_complement()
                            prot = seq.translate()
                            if len(prot) >= options.min_protein_length:
                                log.debug("Add " + query_id + " to collection")
                                if cdd_id not in collection:
                                    collection[cdd_id] = {}
                                collection[cdd_id][query_id] = {}
                                collection[cdd_id][query_id]["nucleotide"] = seq
                                collection[cdd_id][query_id]["protein"] = prot
                                collection[cdd_id][query_id]["full_description"] = description
                                if options.blast is not None:
                                    options.blast.sort()
                                    with open(options.blast[i][0], 'r') as blast_current_file:
                                        blast_reader = csv.reader(blast_current_file, delimiter='\t')
                                        for b_query in blast_reader:
                                            if b_query[1] == query_id:
                                                collection[cdd_id][query_id]["nb"] = b_query[2]
                                                # taxonomy sits in column 15 when present
                                                if len(b_query) > 14:
                                                    collection[cdd_id][query_id]["taxonomy"] = b_query[14]
                                                else:
                                                    collection[cdd_id][query_id]["taxonomy"] = "Unknown"
                                            else:
                                                if "nb" not in collection[cdd_id][query_id]:
                                                    collection[cdd_id][query_id]["nb"] = 0
                                                if "taxonomy" not in collection[cdd_id][query_id]:
                                                    collection[cdd_id][query_id]["taxonomy"] = "Unknown"
                                else:
                                    log.info("No blast file")
                                    collection[cdd_id][query_id]["taxonomy"] = "Unknown"
                                    collection[cdd_id][query_id]["nb"] = 0

                                collection[cdd_id]["short_description"] = description.split(",")[0] + description.split(",")[1]  # keep pfamXXX and RdRp 1
                                collection[cdd_id]["full_description"] = description
        i += 1
    return collection
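The coordinate handling above is easy to get wrong: the RPS query coordinates are 1-based and inclusive, and a negative frame means the hit lies on the reverse strand. Below is a minimal, self-contained sketch of the same cut-and-translate logic on a made-up contig; only Biopython is assumed, and `startQ`, `endQ` and `frame` are hypothetical hit values, not pipeline output.

from Bio.Seq import Seq

contig = Seq("ATGGCCATTGTAATGGGCCGCTGAAAGGGTGCCCGA")
startQ, endQ, frame = 1, 36, 1.0  # hypothetical RPS hit, 1-based inclusive

hit = contig[startQ - 1:min(endQ, len(contig))]  # clamp the end to the contig length
if frame < 0:
    hit = hit.reverse_complement()  # hit reported on the reverse strand
print(hit.translate())  # MAIVMGR*KGAR; the tool keeps it only if long enough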
def _retrieve_fasta_seq(fasta_file, query_id):
    """
    From a fasta file, retrieve the sequence with the given id
    """
    contigs_list = SeqIO.to_dict(SeqIO.parse(open(fasta_file), 'fasta'))
    try:
        seq = contigs_list[query_id].seq
    except KeyError:
        log.warning("KeyError for " + query_id + " file " + fasta_file)
        return None
    else:
        return seq


def _create_tree(tree, fasta, out, color):
    """
    Create a phylogenetic tree from the multiple alignment
    """
    try:
        f = open(tree, 'r')
    except IOError:
        log.info("Unknown file: " + tree + ". You may have less than 2 sequences to align.")
        return
    f.close()

    seqs = SeqGroup(fasta, format="fasta")
    t = Tree(tree)
    ts = TreeStyle()
    ts.show_branch_length = True
    colors = _parse_color_file(color)
    node_names = t.get_leaf_names()
    for name in node_names:
        seq = seqs.get_seq(name)
        seqFace = SeqMotifFace(seq, seq_format="()")
        node = t.get_leaves_by_name(name)
        for i in range(0, len(node)):
            if name in colors:
                ns = NodeStyle()
                ns['bgcolor'] = colors[name]
                node[i].set_style(ns)
            node[i].add_face(seqFace, 0, 'aligned')

    t.render(out, tree_style=ts)


def _parse_color_file(file):
    fh = open(file)
    reader = csv.reader(fh, delimiter="\t")
    data = list(reader)
    fh.close()
    colors = {}
    for i in range(0, len(data)):
        colors[data[i][0]] = data[i][1]

    return colors


def _align_sequences(options, hits_collection):
    """
    Align hit sequences with the pfam reference
    """
    log.info("Align sequences")
    if not os.path.exists(options.output):
        os.mkdir(options.output)
    color_by_sample = {}
    for cdd_id in hits_collection:
        cdd_output = options.output + "/" + hits_collection[cdd_id]["short_description"].replace(" ", "_")
        if not os.path.exists(cdd_output):
            os.mkdir(cdd_output)
        if os.path.exists(cdd_output + "/seq_to_align.fasta"):
            os.remove(cdd_output + "/seq_to_align.fasta")
        file_seq_to_align = cdd_output + "/seq_to_align.fasta"
        file_color_config = cdd_output + "/color_config.txt"
        f = open(file_seq_to_align, "a")
        f_c = open(file_color_config, "w+")
        log.info("Writing to " + file_seq_to_align)
        count = 0  # number of contigs for this domain
        for query_id in hits_collection[cdd_id]:
            if query_id not in ["short_description", "full_description"]:
                sample = query_id.split("_")[0]  # get sample from SAMPLE_IdCONTIG
                sample_color = "#" + ''.join([random.choice('ABCDEF0123456789') for i in range(6)])
                # same color for every contig of the same sample
                if sample not in color_by_sample.keys():
                    color_by_sample[sample] = sample_color
                f.write(">" + query_id + "\n")
                f.write(str(hits_collection[cdd_id][query_id]["protein"]) + "\n")
                f_c.write(query_id + '\t' + color_by_sample[sample] + '\n')
                count += 1
        f.close()
        f_c.close()
        file_seq_aligned = cdd_output + '/seq_aligned.final_tree.fa'
        tree_file = cdd_output + '/tree.dnd'
        file_cluster = cdd_output + '/otu_cluster.csv'
        # create an alignment for domains with more than one contig
        if count > 1:
            log.info("Run clustal omega...")
            clustalo_cmd = ClustalOmegaCommandline("clustalo", infile=file_seq_to_align, outfile=file_seq_aligned,
                                                   guidetree_out=tree_file, seqtype="protein", force=True)
            log.debug(clustalo_cmd)
            stdout, stderr = clustalo_cmd()
            log.debug(stdout + stderr)

            # create the tree plot with colors
            file_matrix = cdd_output + "/identity_matrix.csv"
            log.info("Create tree...")
            _create_tree(tree_file, file_seq_aligned, tree_file + '.png', file_color_config)
            _compute_pairwise_distance(options, file_seq_aligned, file_matrix, cdd_id)
            log.info("Retrieve OTUs...")
            otu_cmd = os.path.join(options.tool_path, 'seek_otu.R') + ' ' + file_matrix + ' ' + file_cluster + ' ' + str(options.perc)
            log.debug(otu_cmd)
            os.system(otu_cmd)
        # only one contig: no alignment, a single one-member OTU
        else:
            cp_cmd = 'cp ' + file_seq_to_align + ' ' + file_seq_aligned
            log.debug(cp_cmd)
            os.system(cp_cmd)

            f = open(file_cluster, "w+")
            f.write('OTU_1,1,' + list(hits_collection[cdd_id].keys())[0] + ',')
            f.close()
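`ClustalOmegaCommandline` only builds and runs a `clustalo` command line. Since the `Bio.Align.Applications` wrappers are deprecated in recent Biopython releases, an equivalent `subprocess` call is a reasonable drop-in; this sketch assumes `clustalo` is on `PATH` and mirrors the options used above (`run_clustalo` is a hypothetical helper, not part of the tool).

import subprocess

def run_clustalo(infile, outfile, guidetree_out):
    # same flags as the ClustalOmegaCommandline invocation above
    cmd = ["clustalo", "--infile", infile, "--outfile", outfile,
           "--guidetree-out", guidetree_out, "--seqtype", "protein", "--force"]
    subprocess.run(cmd, check=True, capture_output=True)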
OTUs...") + # if os.path.exists(file_cluster): + # os.remove(file_cluster) + otu_cmd = os.path.join(options.tool_path, 'seek_otu.R') + ' ' + file_matrix + ' ' + file_cluster + ' ' + str(options.perc) + log.debug(otu_cmd) + os.system(otu_cmd) + # only one contig + else: + mv_cmd = 'cp ' + file_seq_to_align + ' ' + file_seq_aligned + log.debug(mv_cmd) + os.system(mv_cmd) + + f = open(file_cluster, "w+") + f.write('OTU_1,1,' + list(hits_collection[cdd_id].keys())[0] + ',') + f.close() + + +def _compute_pairwise_distance(options, file_seq_aligned, file_matrix, cdd_id): + """ + Calculate paiwise distance between aligned protein sequences + from a cdd_id + """ + log.info("Compute pairwise distance of " + cdd_id) + matrix = {} + for k1 in SeqIO.parse(file_seq_aligned, "fasta"): + row = [] + for k2 in SeqIO.parse(file_seq_aligned, "fasta"): + identic = 0 + compared = 0 + keep_pos = 0 + for base in k1: + base2 = k2[keep_pos] + # mutation, next + if base == 'X' or base2 == 'X': + keep_pos += 1 + continue + # gap in both sequences, next + if base == '-' and base2 == '-': + keep_pos += 1 + continue + # gap in one of the sequence, next + if base == '-' or base2 == '-': + keep_pos += 1 + continue + # identity + if base == base2: + identic += 1 + compared += 1 + keep_pos += 1 + # set minimum overlap to 20 + if compared == 0 or compared < 20: + percentIdentity = 0 + else: + percentIdentity = (identic / compared) * 100 + row.append(percentIdentity) + matrix[k1.id] = row + log.debug("Write " + file_matrix) + f = open(file_matrix, "w+") + for row in matrix: + f.write(row + ',' + ', '.join(map(str, matrix[row])) + "\n") + f.close() + + +def _get_stats(options, hits_collection): + """ + Retrieve annotation and number of read + for each OTUs + """ + file_xlsx = options.output + '/otu_stats.xlsx' # Create a workbook + workbook = xlsxwriter.Workbook(file_xlsx) + log.info("Writing stats to " + file_xlsx) + for cdd_id in hits_collection: + otu_collection = {} + cdd_output = options.output + "/" + hits_collection[cdd_id]["short_description"].replace(" ", "_") + worksheet = workbook.add_worksheet(hits_collection[cdd_id]["short_description"]) # add a worksheet + file_cluster = cdd_output + '/otu_cluster.csv' + with open(file_cluster, 'r') as clust: + otu_reader = csv.reader(clust, delimiter=',') + samples_list = [] + for row in otu_reader: + contigs_list = row[2:len(row) - 1] # remove last empty column + otu_collection[row[0]] = {} # key -> otu number + otu_collection[row[0]]['contigs_list'] = contigs_list + for contig in contigs_list: + sample = contig.split('_')[0] + samples_list.append(sample) if sample not in samples_list else samples_list + if sample not in otu_collection[row[0]]: + otu_collection[row[0]][sample] = {} + otu_collection[row[0]][sample][contig] = {} + # add read number of the contig and annotation + if 'nb' in hits_collection[cdd_id][contig]: + otu_collection[row[0]][sample][contig]['nb'] = hits_collection[cdd_id][contig]["nb"] + else: + otu_collection[row[0]][sample][contig]['nb'] = 0 + if 'taxonomy' in hits_collection[cdd_id][contig]: + otu_collection[row[0]][sample][contig]['taxonomy'] = hits_collection[cdd_id][contig]["taxonomy"] + else: + otu_collection[row[0]][sample][contig]['taxonomy'] = 'unknown' + else: + otu_collection[row[0]][sample][contig] = {} + # add read number of the contig and annotation + if 'nb' in hits_collection[cdd_id][contig]: + otu_collection[row[0]][sample][contig]['nb'] = hits_collection[cdd_id][contig]["nb"] + else: + otu_collection[row[0]][sample][contig]['nb'] = 0 + if 
def _create_html(options, hits_collection):
    """
    Create an HTML file with all results
    """
    # create the mapping file with all information used to build the HTML report
    map_file_path = options.output + "/map.txt"
    if os.path.exists(map_file_path):
        os.remove(map_file_path)

    map_file = open(map_file_path, "w+")
    headers = ['#cdd_id', 'align_files', 'tree_files', 'cluster_files', 'cluster_nb_reads_files', 'pairwise_files', 'description', 'full_description\n']
    map_file.write("\t".join(headers))
    for cdd_id in hits_collection:
        cdd_output = hits_collection[cdd_id]["short_description"].replace(" ", "_")
        short_description = cdd_output
        file_seq_aligned = cdd_output + '/seq_aligned.final_tree.fa'
        tree_file = cdd_output + '/tree.dnd.png'
        file_cluster = cdd_output + '/otu_cluster.csv'
        file_matrix = cdd_output + "/identity_matrix.csv"
        cluster_nb_reads_files = cdd_output + "/cluster_nb_reads_files.tab"
        map_file.write(cdd_id + "\t" + file_seq_aligned + "\t" + tree_file + "\t")
        map_file.write(file_cluster + "\t" + cluster_nb_reads_files + "\t" + file_matrix + "\t")
        map_file.write(short_description + "\t" + hits_collection[cdd_id]["full_description"] + "\n")
    map_file.close()
    log.info("Writing HTML report")
    html_cmd = os.path.join(options.tool_path, 'rps2tree_html.py') + ' -m ' + map_file_path + ' -o ' + options.output
    log.debug(html_cmd)
    os.system(html_cmd)
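The `map.txt` written above is a plain tab-separated table whose columns are exactly the `headers` list. A downstream consumer could load it as sketched below (a hypothetical reader; `rps2tree_html.py`'s actual parsing may differ):

import csv

with open("map.txt") as fh:
    for rec in csv.DictReader(fh, delimiter="\t"):
        print(rec["#cdd_id"], rec["description"], rec["cluster_files"])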
def _set_options():
    parser = argparse.ArgumentParser()
    parser.add_argument('-b', '--blast', help='TAB blast file from blast2ecsv module.', action='append', required=False, dest='blast', nargs='+')
    parser.add_argument('-r', '--rps', help='TAB rpsblast file from rps2ecsv module.', action='append', required=True, dest='rps', nargs='+')
    parser.add_argument('-f', '--fasta', help='FASTA file with contigs', action='append', required=True, dest='fasta', nargs='+')
    parser.add_argument('-p', '--percentage', help='Percentage similarity threshold for OTUs cutoff.', action='store', type=int, default=90, dest='perc')
    parser.add_argument('-vp', '--viral_portion', help='Minimum portion of viral sequences in RPS domain to be included.', action='store', type=float, default=0.3, dest='viral_portion')
    parser.add_argument('-mpl', '--min_protein_length', help='Minimum query protein length.', action='store', type=int, default=100, dest='min_protein_length')
    parser.add_argument('-tp', '--tool_path', help='Path to seek_otu.R', action='store', type=str, default='./', dest='tool_path')
    parser.add_argument('-o', '--out', help='The output directory', action='store', type=str, default='./Rps2tree_OTU', dest='output')
    parser.add_argument('-rgb', '--rgb-conf', help='Color palette for contigs coloration', action='store', type=str, default='rgb.txt', dest='file_rgb')
    parser.add_argument('-v', '--verbosity', help='Verbose level', action='store', type=int, choices=[1, 2, 3, 4], default=1)
    args = parser.parse_args()
    return args


def _set_log_level(verbosity):
    # note: verbosity levels 2 and 4 currently keep the logging defaults
    if verbosity == 1:
        log_format = '%(asctime)s %(levelname)-8s %(message)s'
        log.basicConfig(level=log.INFO, format=log_format)
    elif verbosity == 3:
        log_format = '%(filename)s:%(lineno)s - %(asctime)s %(levelname)-8s %(message)s'
        log.basicConfig(level=log.DEBUG, format=log_format)


if __name__ == "__main__":
    main()
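Because `-b`, `-r` and `-f` combine `action='append'` with `nargs='+'`, each occurrence appends a list, which is why the code indexes `options.rps[i][0]`; the i-th RPS, BLAST and FASTA files are paired after sorting, so they must share a common naming scheme. A hypothetical invocation (all file names are placeholders):

python otu.py \
    -r sample1_rps.tab -r sample2_rps.tab \
    -b sample1_blast.tab -b sample2_blast.tab \
    -f sample1.fasta -f sample2.fasta \
    -p 90 -mpl 100 -tp /path/to/virAnnot/tools -o Rps2tree_OTU -v 3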