diff lib/graphtools.py @ 0:1d1b9e1b2e2f draft
Uploaded
author: petr-novak
date: Thu, 19 Dec 2019 10:24:45 -0500
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/lib/graphtools.py Thu Dec 19 10:24:45 2019 -0500 @@ -0,0 +1,1079 @@ +#!/usr/bin/env python3 +''' +This module is mainly for large graph (e.i hitsort) storage, parsing and for clustering +''' +import os +import sys +import sqlite3 +import time +import subprocess +import logging +from collections import defaultdict +import collections +import operator +import math +import random +import itertools +import config +from lib import r2py +from lib.utils import FilePath +from lib.parallel.parallel import parallel2 as parallel +REQUIRED_VERSION = (3, 4) +MAX_BUFFER_SIZE = 100000 +if sys.version_info < REQUIRED_VERSION: + raise Exception("\n\npython 3.4 or higher is required!\n") +LOGGER = logging.getLogger(__name__) + + +def dfs(start, graph): + """ + helper function for cluster merging. + Does depth-first search, returning a set of all nodes seen. + Takes: a graph in node --> [neighbors] form. + """ + visited, worklist = set(), [start] + + while worklist: + node = worklist.pop() + if node not in visited: + visited.add(node) + # Add all the neighbors to the worklist. + worklist.extend(graph[node]) + + return visited + + +def graph_components(edges): + """ + Given a graph as a list of edges, divide the nodes into components. + Takes a list of pairs of nodes, where the nodes are integers. + """ + + # Construct a graph (mapping node --> [neighbors]) from the edges. + graph = defaultdict(list) + nodes = set() + + for v1, v2 in edges: + nodes.add(v1) + nodes.add(v2) + + graph[v1].append(v2) + graph[v2].append(v1) + + # Traverse the graph to find the components. + components = [] + + # We don't care what order we see the nodes in. + while nodes: + component = dfs(nodes.pop(), graph) + components.append(component) + + # Remove this component from the nodes under consideration. + nodes -= component + + return components + + +class Graph(): + ''' + create Graph object stored in sqlite database, either in memory or on disk + structure of table is: + V1 V2 weigth12 + V2 V3 weight23 + V4 V5 weight45 + ... + ... + !! this is undirected simple graph - duplicated edges must + be removed on graph creation + + ''' + # seed for random number generator - this is necessary for reproducibility between runs + seed = '123' + + def __init__(self, + source=None, + filename=None, + new=False, + paired=True, + seqids=None): + ''' + filename : fite where to store database, if not defined it is stored in memory + source : ncol file from which describe graph + new : if false and source is not define graph can be loaded from database (filename) + + vertices_name must be in correcti order!!! 
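# Editor's note -- usage sketch, not part of the original file; file names and
# the merge threshold below are hypothetical. A typical life cycle of this
# Graph object: edges are loaded from a hitsort file into SQLite, dumped as an
# integer-indexed ncol file, and clustered with the external Louvain tools:
#
#     g = Graph(source="hitsort", filename="hitsort.db",
#               paired=True, seqids=read_ids_in_pair_order)
#     g.save_indexed_graph()            # writes "<source>.int" for conversion
#     g.louvain_clustering(merge_threshold=0.1, cleanup=True)
#     reads = g.get_cluster_reads(1)    # read names of the largest cluster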
+ ''' + + self.filename = filename + self.source = source + self.paired = paired + # path to indexed graph - will be set later + self.indexed_file = None + self._cluster_list = None + # these two attributes are set after clustering + # communities before merging + self.graph_2_community0 = None + # communities after merging + self.graph_2_community = None + self.number_of_clusters = None + self.binary_file = None + self.cluster_sizes = None + self.graph_tree = None + self.graph_tree_log = None + self.weights_file = None + + if filename: + if os.path.isfile(filename) and (new or source): + os.remove(filename) + self.conn = sqlite3.connect(filename) + else: + self.conn = sqlite3.connect(":memory:") + c = self.conn.cursor() + + c.execute("PRAGMA page_size=8192") + c.execute("PRAGMA cache_size = 2000000 ") # this helps + + try: + c.execute(( + "create table graph (v1 integer, v2 integer, weight integer, " + "pair integer, v1length integer, v1start integer, v1end integer, " + "v2length integer, v2start integer, v2end integer, pid integer," + "evalue real, strand text )")) + except sqlite3.OperationalError: + pass # table already exist + else: + c.execute( + "create table vertices (vertexname text primary key, vertexindex integer)") + tables = sorted(c.execute( + "SELECT name FROM sqlite_master WHERE type='table'").fetchall()) + + if not [('graph', ), ('vertices', )] == tables: + raise Exception("tables for sqlite for graph are not correct") + + if source: + self._read_from_hitsort() + + if paired and seqids: + # vertices must be defined - create graph of paired reads: + # last character must disinguish pair + c.execute(( + "create table pairs (basename, vertexname1, vertexname2," + "v1 integer, v2 integer, cluster1 integer, cluster2 integer)")) + buffer = [] + for i, k in zip(seqids[0::2], seqids[1::2]): + assert i[:-1] == k[:-1], "problem with pair reads ids" + # some vertices are not in graph - singletons + try: + index1 = self.vertices[i] + except KeyError: + index1 = -1 + + try: + index2 = self.vertices[k] + except KeyError: + index2 = -1 + + buffer.append((i[:-1], i, k, index1, index2)) + + self.conn.executemany( + "insert into pairs (basename, vertexname1, vertexname2, v1, v2) values (?,?,?,?,?)", + buffer) + self.conn.commit() + + def _read_from_hitsort(self): + + c = self.conn.cursor() + c.execute("delete from graph") + buffer = [] + vertices = {} + counter = 0 + v_count = 0 + with open(self.source, 'r') as f: + for i in f: + edge_index = {} + items = i.split() + # get or insert vertex index + for vn in items[0:2]: + if vn not in vertices: + vertices[vn] = v_count + edge_index[vn] = v_count + v_count += 1 + else: + edge_index[vn] = vertices[vn] + if self.paired: + pair = int(items[0][:-1] == items[1][:-1]) + else: + pair = 0 + buffer.append(((edge_index[items[0]], edge_index[items[1]], + items[2], pair) + tuple(items[3:]))) + if len(buffer) == MAX_BUFFER_SIZE: + counter += 1 + self.conn.executemany( + "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)", + buffer) + buffer = [] + if buffer: + self.conn.executemany( + "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)", + buffer) + + self.conn.commit() + self.vertices = vertices + self.vertexid2name = { + vertex: index + for index, vertex in vertices.items() + } + self.vcount = len(vertices) + c = self.conn.cursor() + c.execute("select count(*) from graph") + self.ecount = c.fetchone()[0] + # fill table of vertices + self.conn.executemany("insert into vertices values (?,?)", + vertices.items()) + 
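# Editor's note -- illustrative sketch, not part of the original file: the
# buffered-insert pattern used by _read_from_hitsort() above. Rows are
# accumulated in memory and flushed with executemany() whenever the buffer
# reaches MAX_BUFFER_SIZE, then once more for the remainder. The table and
# example rows below are hypothetical.
import sqlite3

def bulk_insert(rows, batch_size=100000):
    conn = sqlite3.connect(":memory:")
    conn.execute("CREATE TABLE edges (v1 INTEGER, v2 INTEGER, weight INTEGER)")
    buffer = []
    for row in rows:
        buffer.append(row)
        if len(buffer) == batch_size:
            conn.executemany("INSERT INTO edges VALUES (?,?,?)", buffer)
            buffer = []
    if buffer:  # flush whatever is left over
        conn.executemany("INSERT INTO edges VALUES (?,?,?)", buffer)
    conn.commit()
    return conn

# bulk_insert([(0, 1, 90), (1, 2, 85), (0, 2, 40)])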
self.conn.commit() + + def save_indexed_graph(self, file=None): + if not file: + self.indexed_file = "{}.int".format(self.source) + else: + self.indexed_file = file + c = self.conn.cursor() + with open(self.indexed_file, 'w') as f: + out = c.execute('select v1,v2,weight from graph') + for v1, v2, weight in out: + f.write('{}\t{}\t{}\n'.format(v1, v2, weight)) + + def get_subgraph(self, vertices): + pass + + def _levels(self): + with open(self.graph_tree_log, 'r') as f: + levels = -1 + for i in f: + if i[:5] == 'level': + levels += 1 + return levels + + def _reindex_community(self, id2com): + ''' + reindex community and superclusters so that biggest cluster is no.1 + ''' + self.conn.commit() + _, community, supercluster = zip(*id2com) + (cluster_index, frq, self.cluster_sizes, + self.number_of_clusters) = self._get_index_and_frequency(community) + + supercluster_index, sc_frq, _, _ = self._get_index_and_frequency( + supercluster) + id2com_reindexed = [] + + for i, _ in enumerate(id2com): + id2com_reindexed.append((id2com[i][0], id2com[i][1], frq[ + i], cluster_index[i], supercluster_index[i], sc_frq[i])) + return id2com_reindexed + + @staticmethod + def _get_index_and_frequency(membership): + frequency_table = collections.Counter(membership) + frequency_table_sorted = sorted(frequency_table.items(), + key=operator.itemgetter(1), + reverse=True) + frq = [] + for i in membership: + frq.append(frequency_table[i]) + rank = {} + index = 0 + for comm, _ in frequency_table_sorted: + index += 1 + rank[comm] = index + cluster_index = [rank[i] for i in membership] + cluster_sizes = [i[1] for i in frequency_table_sorted] + number_of_clusters = len(frequency_table) + return [cluster_index, frq, cluster_sizes, number_of_clusters] + + def louvain_clustering(self, merge_threshold=0, cleanup=False): + ''' + input - graph + output - list of clusters + executables path ?? 
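# Editor's note -- illustrative sketch, not part of the original file: the
# external pipeline that the louvain_clustering() method here drives, reduced
# to its three subprocess calls (flags copied from the method). File names are
# hypothetical; the louvain_convert, louvain_community and louvain_hierarchy
# binaries are assumed to be on PATH. The original derives `levels` from the
# louvain_community log via _levels().
import subprocess

def run_louvain(indexed_file, seed="123", levels=1):
    binary_file = indexed_file + ".bin"
    weights_file = indexed_file + ".weight"
    graph_tree = indexed_file + ".graph_tree"
    communities = indexed_file + ".graph_2_community0"
    subprocess.check_call(["louvain_convert", "-i", indexed_file,
                           "-o", binary_file, "-w", weights_file])
    with open(graph_tree, "w") as gt, open(graph_tree + "_log", "w") as gtl:
        subprocess.check_call(["louvain_community", binary_file, "-l", "-1",
                               "-w", weights_file, "-v ", "-s", seed],
                              stdout=gt, stderr=gtl)
    with open(communities, "w") as g2c:
        subprocess.check_call(["louvain_hierarchy", graph_tree,
                               "-l", str(levels)], stdout=g2c)
    return communities  # one "vertex_index community_id" pair per line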
+ ''' + LOGGER.info("converting hitsort to binary format") + self.binary_file = "{}.bin".format(self.indexed_file) + self.weights_file = "{}.weight".format(self.indexed_file) + self.graph_tree = "{}.graph_tree".format(self.indexed_file) + self.graph_tree_log = "{}.graph_tree_log".format(self.indexed_file) + self.graph_2_community0 = "{}.graph_2_community0".format( + self.indexed_file) + self._cluster_list = None + self.graph_2_community = "{}.graph_2_community".format( + self.indexed_file) + print(["louvain_convert", "-i", self.indexed_file, "-o", + self.binary_file, "-w", self.weights_file]) + subprocess.check_call( + ["louvain_convert", "-i", self.indexed_file, "-o", + self.binary_file, "-w", self.weights_file], + timeout=None) + + gt = open(self.graph_tree, 'w') + gtl = open(self.graph_tree_log, 'w') + LOGGER.info("running louvain clustering...") + subprocess.check_call( + ["louvain_community", self.binary_file, "-l", "-1", "-w", + self.weights_file, "-v ", "-s", self.seed], + stdout=gt, + stderr=gtl, + timeout=None) + gt.close() + gtl.close() + + LOGGER.info("creating list of cummunities") + gt2c = open(self.graph_2_community0, 'w') + subprocess.check_call( + ['louvain_hierarchy', self.graph_tree, "-l", str(self._levels())], + stdout=gt2c) + gt2c.close() + if merge_threshold and self.paired: + com2newcom = self.find_superclusters(merge_threshold) + elif self.paired: + com2newcom = self.find_superclusters(config.SUPERCLUSTER_THRESHOLD) + else: + com2newcom = {} + # merging of clusters, creatting superclusters + LOGGER.info("mergings clusters based on mate-pairs ") + # modify self.graph_2_community file + # rewrite graph2community + with open(self.graph_2_community0, 'r') as fin: + with open(self.graph_2_community, 'w') as fout: + for i in fin: + # write graph 2 community file in format: + # id communityid supeclusterid + # if merging - community and superclustwers are identical + vi, com = i.split() + if merge_threshold: + ## mergin + if int(com) in com2newcom: + fout.write("{} {} {}\n".format(vi, com2newcom[int( + com)], com2newcom[int(com)])) + else: + fout.write("{} {} {}\n".format(vi, com, com)) + else: + ## superclusters + if int(com) in com2newcom: + fout.write("{} {} {}\n".format(vi, com, com2newcom[ + int(com)])) + else: + fout.write("{} {} {}\n".format(vi, com, com)) + + LOGGER.info("loading communities into database") + c = self.conn.cursor() + c.execute(("create table communities (vertexindex integer primary key," + "community integer, size integer, cluster integer, " + "supercluster integer, supercluster_size integer)")) + id2com = [] + with open(self.graph_2_community, 'r') as f: + for i in f: + name, com, supercluster = i.split() + id2com.append((name, com, supercluster)) + id2com_reindexed = self._reindex_community(id2com) + c.executemany("insert into communities values (?,?,?,?,?,?)", + id2com_reindexed) + #create table of superclusters - clusters + c.execute(("create table superclusters as " + "select distinct supercluster, supercluster_size, " + "cluster, size from communities;")) + # create view id-index-cluster + c.execute( + ("CREATE VIEW vertex_cluster AS SELECT vertices.vertexname," + "vertices.vertexindex, communities.cluster, communities.size" + " FROM vertices JOIN communities USING (vertexindex)")) + self.conn.commit() + + # add clustering infor to graph + LOGGER.info("updating graph table") + t0 = time.time() + + c.execute("alter table graph add c1 integer") + c.execute("alter table graph add c2 integer") + c.execute(("update graph set c1 = (select cluster 
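# Editor's note -- explanatory comment, not part of the original file: the two
# UPDATE statements at this point copy the (re-indexed) cluster number of each
# edge's endpoints into the new graph columns c1 and c2. Edges with c1 == c2
# are the within-cluster similarity hits that extract_cluster_blast() later
# dumps per cluster, while the cluster_connections table created just below
# counts similarity edges running between different clusters.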
FROM communities " + "where communities.vertexindex=graph.v1)")) + c.execute( + ("update graph set c2 = (select cluster FROM communities where " + "communities.vertexindex=graph.v2)")) + self.conn.commit() + t1 = time.time() + LOGGER.info("updating graph table - done in {} seconds".format(t1 - + t0)) + + # identify similarity connections between clusters + c.execute( + "create table cluster_connections as SELECT c1,c2 , count(*) FROM (SELECT c1, c2 FROM graph WHERE c1>c2 UNION ALL SELECT c2 as c1, c1 as c2 FROM graph WHERE c2>c1) GROUP BY c1, c2") + # TODO - remove directionality - summarize - + + # add cluster identity to pairs table + + if self.paired: + LOGGER.info("analyzing pairs ") + t0 = time.time() + c.execute( + "UPDATE pairs SET cluster1=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v1)") + t1 = time.time() + LOGGER.info( + "updating pairs table - cluster1 - done in {} seconds".format( + t1 - t0)) + + t0 = time.time() + c.execute( + "UPDATE pairs SET cluster2=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v2)") + t1 = time.time() + LOGGER.info( + "updating pairs table - cluster2 - done in {} seconds".format( + t1 - t0)) + # reorder records + + t0 = time.time() + c.execute( + "UPDATE pairs SET cluster1=cluster2, cluster2=cluster1, vertexname1=vertexname2,vertexname2=vertexname1 where cluster1<cluster2") + t1 = time.time() + LOGGER.info("sorting - done in {} seconds".format(t1 - t0)) + + t0 = time.time() + c.execute( + "create table cluster_mate_connections as select cluster1 as c1, cluster2 as c2, count(*) as N, group_concat(basename) as ids from pairs where cluster1!=cluster2 group by cluster1, cluster2;") + t1 = time.time() + LOGGER.info( + "creating cluster_mate_connections table - done in {} seconds".format( + t1 - t0)) + # summarize + t0 = time.time() + self._calculate_pair_bond() + t1 = time.time() + LOGGER.info( + "calculating cluster pair bond - done in {} seconds".format( + t1 - t0)) + t0 = time.time() + else: + # not paired - create empty tables + self._add_empty_tables() + + self.conn.commit() + t1 = time.time() + LOGGER.info("commiting changes - done in {} seconds".format(t1 - t0)) + + if cleanup: + LOGGER.info("cleaning clustering temp files") + os.unlink(self.binary_file) + os.unlink(self.weights_file) + os.unlink(self.graph_tree) + os.unlink(self.graph_tree_log) + os.unlink(self.graph_2_community0) + os.unlink(self.graph_2_community) + os.unlink(self.indexed_file) + self.binary_file = None + self.weights_file = None + self.graph_tree = None + self.graph_tree_log = None + self.graph_2_community0 = None + self.graph_2_community = None + self.indexed_file = None + + # calcultate k + + def find_superclusters(self, merge_threshold): + '''Find superclusters from clustering based on paired reads ''' + clsdict = {} + with open(self.graph_2_community0, 'r') as f: + for i in f: + vi, com = i.split() + if com in clsdict: + clsdict[com] += [self.vertexid2name[int(vi)][0:-1]] + else: + clsdict[com] = [self.vertexid2name[int(vi)][0:-1]] + # remove all small clusters - these will not be merged: + small_cls = [] + for i in clsdict: + if len(clsdict[i]) < config.MINIMUM_NUMBER_OF_READS_FOR_MERGING: + small_cls.append(i) + for i in small_cls: + del clsdict[i] + pairs = [] + for i, j in itertools.combinations(clsdict, 2): + s1 = set(clsdict[i]) + s2 = set(clsdict[j]) + wgh = len(s1 & s2) + if wgh < config.MINIMUM_NUMBER_OF_SHARED_PAIRS_FOR_MERGING: + continue + else: + n1 = len(s1) * 2 - len(clsdict[i]) + n2 = len(s2) * 2 - len(clsdict[j]) 
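# Editor's note -- explanatory comment, not part of the original file: with
# s1, s2 the sets of distinct pair basenames and clsdict[i], clsdict[j] the
# per-read basename lists, n1 = 2*|s1| - len(clsdict[i]) is the number of read
# pairs of cluster i that have exactly one mate inside the cluster ("broken"
# pairs), and likewise n2 for cluster j; wgh = |s1 & s2| is the number of
# pairs split between the two clusters. The statistic computed next,
# k = 2*wgh / (n1 + n2), is therefore the fraction of broken pairs explained
# by the i-j connection, and the clusters are merged when k > merge_threshold.
# Example: 30 broken pairs in i, 20 in j, 10 of them shared ->
# k = 2*10 / (30 + 20) = 0.4.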
+ k = 2 * wgh / (n1 + n2) + if k > merge_threshold: + pairs.append((int(i), int(j))) + # find connected commponents - will be merged + cls2merge = graph_components(pairs) + com2newcom = {} + for i in cls2merge: + newcom = min(i) + for j in i: + com2newcom[j] = newcom + return com2newcom + + def adjust_cluster_size(self, proportion_kept, ids_kept): + LOGGER.info("adjusting cluster sizes") + c = self.conn.cursor() + c.execute("ALTER TABLE superclusters ADD COLUMN size_uncorrected INTEGER") + c.execute("UPDATE superclusters SET size_uncorrected=size") + if ids_kept: + ids_kept_set = set(ids_kept) + ratio = (1 - proportion_kept)/proportion_kept + for cl, size in c.execute("SELECT cluster,size FROM superclusters"): + ids = self.get_cluster_reads(cl) + ovl_size = len(ids_kept_set.intersection(ids)) + size_adjusted = int(len(ids) + ovl_size * ratio) + if size_adjusted > size: + c.execute("UPDATE superclusters SET size=? WHERE cluster=?", + (size_adjusted, cl)) + self.conn.commit() + LOGGER.info("adjusting cluster sizes - done") + + def export_cls(self, path): + with open(path, 'w') as f: + for i in range(1, self.number_of_clusters + 1): + ids = self.get_cluster_reads(i) + f.write(">CL{}\t{}\n".format(i, len(ids))) + f.write("\t".join(ids)) + f.write("\n") + + def _calculate_pair_bond(self): + c = self.conn.cursor() + out = c.execute("select c1, c2, ids from cluster_mate_connections") + buffer = [] + for c1, c2, ids in out: + w = len(set(ids.split(","))) + n1 = len(set([i[:-1] for i in self.get_cluster_reads(c1) + ])) * 2 - len(self.get_cluster_reads(c1)) + n2 = len(set([i[:-1] for i in self.get_cluster_reads(c2) + ])) * 2 - len(self.get_cluster_reads(c2)) + buffer.append((c1, c2, n1, n2, w, 2 * w / (n1 + n2))) + c.execute( + "CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)") + c.executemany(" INSERT INTO cluster_mate_bond values (?,?,?,?,?,?)", + buffer) + + def _add_empty_tables(self): + '''This is used with reads that are not paired + - it creates empty mate tables, this is necessary for + subsequent reporting to work corectly ''' + c = self.conn.cursor() + c.execute(("CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, " + "n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)")) + c.execute( + "CREATE TABLE cluster_mate_connections (c1 INTEGER, c2 INTEGER, N INTEGER, ids TEXT) ") + + def get_cluster_supercluster(self, cluster): + '''Get supercluster id for suplied cluster ''' + c = self.conn.cursor() + out = c.execute( + 'SELECT supercluster FROM communities WHERE cluster="{0}" LIMIT 1'.format( + cluster)) + sc = out.fetchone()[0] + return sc + + def get_cluster_reads(self, cluster): + + if self._cluster_list: + return self._cluster_list[str(cluster)] + else: + # if queried first time + c = self.conn.cursor() + out = c.execute("select cluster, vertexname from vertex_cluster") + cluster_list = collections.defaultdict(list) + for clusterindex, vertexname in out: + cluster_list[str(clusterindex)].append(vertexname) + self._cluster_list = cluster_list + return self._cluster_list[str(cluster)] + + + def extract_cluster_blast(self, path, index, ids=None): + ''' Extract blast for cluster and save it to path + return number of blast lines ( i.e. 
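# Editor's note -- illustrative sketch, not part of the original file: the
# size correction applied by adjust_cluster_size() above. When read
# pre-filtering kept only a proportion p of a read fraction, every cluster
# read that is also in ids_kept is assumed to stand for (1 - p) / p discarded
# reads. Numbers below are hypothetical.
def adjusted_size(cluster_ids, ids_kept, proportion_kept):
    overlap = len(set(ids_kept).intersection(cluster_ids))
    ratio = (1 - proportion_kept) / proportion_kept
    return int(len(cluster_ids) + overlap * ratio)

# a 100-read cluster with 40 reads from the filtered fraction, of which only
# a quarter was kept:  100 + 40 * 3 = 220
# adjusted_size(range(100), range(40), proportion_kept=0.25) -> 220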
number of graph edges E) + if ids is specified , only subset of blast is used''' + c = self.conn.cursor() + if ids: + vertexindex = ( + "select vertexindex from vertices " + "where vertexname in ({})").format('"' + '","'.join(ids) + '"') + + out = c.execute(("select * from graph where c1={0} and c2={0}" + " and v1 in ({1}) and v2 in ({1})").format( + index, vertexindex)) + else: + out = c.execute( + "select * from graph where c1={0} and c2={0}".format(index)) + E = 0 + N = len(self.get_cluster_reads(index)) + with open(path, 'w') as f: + for i in out: + print(self.vertexid2name[i[0]], + self.vertexid2name[ + i[1]], + i[2], + *i[4:13], + sep='\t', + file=f) + E += 1 + return E + + def export_clusters_files_multiple(self, + min_size, + directory, + sequences=None, + tRNA_database_path=None, + satellite_model_path=None): + def load_fun(N, E): + ''' estimate mem usage from graph size and density''' + NE = math.log(float(N) * float(E), 10) + if NE > 11.5: + return 1 + if NE > 11: + return 0.9 + if NE > 10: + return 0.4 + if NE > 9: + return 0.2 + if NE > 8: + return 0.07 + return 0.02 + + def estimate_sample_size(NV, NE, maxv, maxe): + ''' estimat suitable sampling based on the graph density + NV,NE is |V| and |E| of the graph + maxv, maxe are maximal |V| and |E|''' + + d = (2 * NE) / (NV * (NV - 1)) + eEst = (maxv * (maxv - 1) * d) / 2 + nEst = (d + math.sqrt(d**2 + 8 * d * maxe)) / (2 * d) + if eEst >= maxe: + N = int(nEst) + if nEst >= maxv: + N = int(maxv) + return N + + clusterindex = 1 + cluster_input_args = [] + ppn = [] + # is is comparative analysis? + if sequences.prefix_length: + self.conn.execute("CREATE TABLE comparative_counts (clusterindex INTEGER," + + ", ".join(["[{}] INTEGER".format(i) for i in sequences.prefix_codes.keys()]) + ")") + # do for comparative analysis + + for cl in range(self.number_of_clusters): + prefix_codes = dict((key, 0) for key in sequences.prefix_codes.keys()) + for i in self.get_cluster_reads(cl): + prefix_codes[i[0:sequences.prefix_length]] += 1 + header = ", ".join(["[" + str(i) + "]" for i in prefix_codes.keys()]) + values = ", ".join([str(i) for i in prefix_codes.values()]) + self.conn.execute( + "INSERT INTO comparative_counts (clusterindex, {}) VALUES ({}, {})".format( + header, cl, values)) + else: + prefix_codes = {} + + while True: + read_names = self.get_cluster_reads(clusterindex) + supercluster = self.get_cluster_supercluster(clusterindex) + N = len(read_names) + print("sequences.ids_kept -2 ") + print(sequences.ids_kept) + if sequences.ids_kept: + N_adjusted = round(len(set(sequences.ids_kept).intersection(read_names)) * + ((1 - config.FILTER_PROPORTION_OF_KEPT) / + config.FILTER_PROPORTION_OF_KEPT) + N) + else: + N_adjusted = N + if N < min_size: + break + else: + LOGGER.info("exporting cluster {}".format(clusterindex)) + blast_file = "{dir}/dir_CL{i:04}/hitsort_part.csv".format( + dir=directory, i=clusterindex) + cluster_dir = "{dir}/dir_CL{i:04}".format(dir=directory, + i=clusterindex) + fasta_file = "{dir}/reads_selection.fasta".format(dir=cluster_dir) + fasta_file_full = "{dir}/reads.fasta".format(dir=cluster_dir) + + os.makedirs(os.path.dirname(blast_file), exist_ok=True) + E = self.extract_cluster_blast(index=clusterindex, + path=blast_file) + # check if blast must be sampled + n_sample = estimate_sample_size(NV=N, + NE=E, + maxv=config.CLUSTER_VMAX, + maxe=config.CLUSTER_EMAX) + LOGGER.info("directories created..") + if n_sample < N: + LOGGER.info(("cluster is too large - sampling.." 
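# Editor's note -- illustrative sketch, not part of the original file: the
# reasoning behind estimate_sample_size() above. The observed density
# d = 2E / (V*(V-1)) is assumed to survive sampling, so n sampled vertices are
# expected to carry n*(n-1)*d/2 edges; solving n*(n-1)*d/2 = maxe for n gives
# n = (d + sqrt(d^2 + 8*d*maxe)) / (2*d), and the result is additionally
# capped at maxv. Numbers below are hypothetical.
import math

def sample_size(NV, NE, maxv, maxe):
    d = (2 * NE) / (NV * (NV - 1))        # observed graph density
    e_est = (maxv * (maxv - 1) * d) / 2   # expected edge count at maxv reads
    n = NV                                # small graphs are kept complete
    if e_est >= maxe:
        n = int((d + math.sqrt(d ** 2 + 8 * d * maxe)) / (2 * d))
    return min(n, maxv)

# a dense cluster of 100,000 reads and 250 million edges would be sampled
# down to roughly 28,000 reads for caps of maxv=40,000 and maxe=20,000,000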
+ "original size: {N}\n" + "sample size: {NS}\n" + "").format(N=N, NS=n_sample)) + random.seed(self.seed) + read_names_sample = random.sample(read_names, n_sample) + LOGGER.info("reads id sampled...") + blast_file_sample = "{dir}/dir_CL{i:04}/blast_sample.csv".format( + dir=directory, i=clusterindex) + E_sample = self.extract_cluster_blast( + index=clusterindex, + path=blast_file, + ids=read_names_sample) + LOGGER.info("numner of edges in sample: {}".format( + E_sample)) + sequences.save2fasta(fasta_file, subset=read_names_sample) + sequences.save2fasta(fasta_file_full, subset=read_names) + + else: + read_names_sample = None + E_sample = None + blast_file_sample = None + n_sample = None + sequences.save2fasta(fasta_file_full, subset=read_names) + ## TODO - use symlink instead of : + sequences.save2fasta(fasta_file, subset=read_names) + # export individual annotations tables: + # annotation is always for full cluster + LOGGER.info("exporting cluster annotation") + annotations = {} + annotations_custom = {} + for n in sequences.annotations: + print("sequences.annotations:", n) + if n.find("custom_db") == 0: + print("custom") + annotations_custom[n] = sequences.save_annotation( + annotation_name=n, + subset=read_names, + dir=cluster_dir) + else: + print("built in") + annotations[n] = sequences.save_annotation( + annotation_name=n, + subset=read_names, + dir=cluster_dir) + + cluster_input_args.append([ + n_sample, N,N_adjusted, blast_file, fasta_file, fasta_file_full, + clusterindex, supercluster, self.paired, + tRNA_database_path, satellite_model_path, sequences.prefix_codes, + prefix_codes, annotations, annotations_custom + ]) + clusterindex += 1 + ppn.append(load_fun(N, E)) + + + + self.conn.commit() + + # run in parallel: + # reorder jobs based on the ppn: + cluster_input_args = [ + x + for (y, x) in sorted( + zip(ppn, cluster_input_args), + key=lambda pair: pair[0], + reverse=True) + ] + ppn = sorted(ppn, reverse=True) + LOGGER.info("creating clusters in parallel") + clusters_info = parallel(Cluster, + *[list(i) for i in zip(*cluster_input_args)], + ppn=ppn) + # sort it back: + clusters_info = sorted(clusters_info, key=lambda cl: cl.index) + return clusters_info + + +class Cluster(): + ''' store and show information about cluster properties ''' + + def __init__(self, + size, + size_real, + size_adjusted, + blast_file, + fasta_file, + fasta_file_full, + index, + supercluster, + paired, + tRNA_database_path, + satellite_model_path, + all_prefix_codes, + prefix_codes, + annotations, + annotations_custom={}, + loop_index_threshold=0.7, + pair_completeness_threshold=0.40, + loop_index_unpaired_threshold=0.85): + if size: + # cluster was scaled down + self.size = size + self.size_real = size_real + else: + self.size = self.size_real = size_real + self.size_adjusted = size_adjusted + self.filtered = True if size_adjusted != size_real else False + self.all_prefix_codes = all_prefix_codes.keys + self.prefix_codes = prefix_codes + self.dir = FilePath(os.path.dirname(blast_file)) + self.blast_file = FilePath(blast_file) + self.fasta_file = FilePath(fasta_file) + self.fasta_file_full = FilePath(fasta_file_full) + self.index = index + self.assembly_files = {} + self.ltr_detection = None + self.supercluster = supercluster + self.annotations_files = annotations + self.annotations_files_custom = annotations_custom + self.annotations_summary, self.annotations_table = self._summarize_annotations( + annotations, size_real) + # add annotation + if len(annotations_custom): + self.annotations_summary_custom, 
self.annotations_custom_table = self._summarize_annotations( + annotations_custom, size_real) + else: + self.annotations_summary_custom, self.annotations_custom_table = "", "" + + self.paired = paired + self.graph_file = FilePath("{0}/graph_layout.GL".format(self.dir)) + self.directed_graph_file = FilePath( + "{0}/graph_layout_directed.RData".format(self.dir)) + self.fasta_oriented_file = FilePath("{0}/reads_selection_oriented.fasta".format( + self.dir)) + self.image_file = FilePath("{0}/graph_layout.png".format(self.dir)) + self.image_file_tmb = FilePath("{0}/graph_layout_tmb.png".format(self.dir)) + self.html_report_main = FilePath("{0}/index.html".format(self.dir)) + self.html_report_files = FilePath("{0}/html_files".format(self.dir)) + self.supercluster_best_hit = "NA" + TAREAN = r2py.R(config.RSOURCE_tarean) + LOGGER.info("creating graph no.{}".format(self.index)) + # if FileType muast be converted to str for rfunctions + graph_info = eval( + TAREAN.mgblast2graph( + self.blast_file, + seqfile=self.fasta_file, + seqfile_full=self.fasta_file_full, + graph_destination=self.graph_file, + directed_graph_destination=self.directed_graph_file, + oriented_sequences=self.fasta_oriented_file, + image_file=self.image_file, + image_file_tmb=self.image_file_tmb, + repex=True, + paired=self.paired, + satellite_model_path=satellite_model_path, + maxv=config.CLUSTER_VMAX, + maxe=config.CLUSTER_EMAX) + ) + print(graph_info) + self.ecount = graph_info['ecount'] + self.vcount = graph_info['vcount'] + self.loop_index = graph_info['loop_index'] + self.pair_completeness = graph_info['pair_completeness'] + self.orientation_score = graph_info['escore'] + self.satellite_probability = graph_info['satellite_probability'] + self.satellite = graph_info['satellite'] + # for paired reads: + cond1 = (self.paired and self.loop_index > loop_index_threshold and + self.pair_completeness > pair_completeness_threshold) + # no pairs + cond2 = ((not self.paired) and + self.loop_index > loop_index_unpaired_threshold) + if (cond1 or cond2) and config.ARGS.options.name != "oxford_nanopore": + self.putative_tandem = True + self.dir_tarean = FilePath("{}/tarean".format(self.dir)) + lock_file = self.dir + "../lock" + out = eval( + TAREAN.tarean(input_sequences=self.fasta_oriented_file, + output_dir=self.dir_tarean, + CPU=1, + reorient_reads=False, + tRNA_database_path=tRNA_database_path, + lock_file=lock_file) + ) + self.html_tarean = FilePath(out['htmlfile']) + self.tarean_contig_file = out['tarean_contig_file'] + self.TR_score = out['TR_score'] + self.TR_monomer_length = out['TR_monomer_length'] + self.TR_consensus = out['TR_consensus'] + self.pbs_score = out['pbs_score'] + self.max_ORF_length = out['orf_l'] + if (out['orf_l'] > config.ORF_THRESHOLD or + out['pbs_score'] > config.PBS_THRESHOLD): + self.tandem_rank = 3 + elif self.satellite: + self.tandem_rank = 1 + else: + self.tandem_rank = 2 + # some tandems could be rDNA genes - this must be check + # by annotation + if self.annotations_table: + rdna_score = 0 + contamination_score = 0 + for i in self.annotations_table: + if 'rDNA/' in i[0]: + rdna_score += i[1] + if 'contamination' in i[0]: + contamination_score += i[1] + if rdna_score > config.RDNA_THRESHOLD: + self.tandem_rank = 4 + if contamination_score > config.CONTAMINATION_THRESHOLD: + self.tandem_rank = 0 # other + + # by custom annotation - castom annotation has preference + if self.annotations_custom_table: + print("custom table searching") + rdna_score = 0 + contamination_score = 0 + 
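# Editor's note -- illustrative sketch, not part of the original file: the
# tandem-repeat decision taken earlier in this constructor, written out as a
# standalone function (thresholds are the defaults of Cluster.__init__; the
# original additionally requires that the run is not the "oxford_nanopore"
# preset).
def is_putative_tandem(paired, loop_index, pair_completeness,
                       loop_index_threshold=0.7,
                       pair_completeness_threshold=0.40,
                       loop_index_unpaired_threshold=0.85):
    if paired:
        return (loop_index > loop_index_threshold and
                pair_completeness > pair_completeness_threshold)
    return loop_index > loop_index_unpaired_threshold

# is_putative_tandem(True, loop_index=0.82, pair_completeness=0.55) -> True
# is_putative_tandem(False, loop_index=0.80, pair_completeness=0.0) -> False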
print(self.annotations_custom_table) + for i in self.annotations_custom_table: + if 'rDNA' in i[0]: + rdna_score += i[1] + if 'contamination' in i[0]: + contamination_score += i[1] + if rdna_score > 0: + self.tandem_rank = 4 + if contamination_score > config.CONTAMINATION_THRESHOLD: + self.tandem_rank = 0 # other + + else: + self.putative_tandem = False + self.dir_tarean = None + self.html_tarean = None + self.TR_score = None + self.TR_monomer_length = None + self.TR_consensus = None + self.pbs_score = None + self.max_ORF_length = None + self.tandem_rank = 0 + self.tarean_contig_file = None + + def __str__(self): + out = [ + "cluster no {}:".format(self.index), + "Number of vertices : {}".format(self.size), + "Number of edges : {}".format(self.ecount), + "Loop index : {}".format(self.loop_index), + "Pair completeness : {}".format(self.pair_completeness), + "Orientation score : {}".format(self.orientation_score) + ] + return "\n".join(out) + + def listing(self, asdict=True): + ''' convert attributes to dictionary for printing purposes''' + out = {} + for i in dir(self): + # do not show private + if i[:2] != "__": + value = getattr(self, i) + if not callable(value): + # for dictionary + if isinstance(value, dict): + for k in value: + out[i + "_" + k] = value[k] + else: + out[i] = value + if asdict: + return out + else: + return {'keys': list(out.keys()), 'values': list(out.values())} + + def detect_ltr(self, trna_database): + '''detection of ltr in assembly files, output of analysis is stored in file''' + CREATE_ANNOTATION = r2py.R(config.RSOURCE_create_annotation, verbose=False) + if self.assembly_files['{}.{}.ace']: + ace_file = self.assembly_files['{}.{}.ace'] + print(ace_file, "running LTR detection") + fout = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['BASE']) + subprocess.check_call([ + config.LTR_DETECTION, + '-i', ace_file, + '-o', fout, + '-p', trna_database]) + # evaluate LTR presence + fn = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['PBS_BLAST']) + self.ltr_detection = CREATE_ANNOTATION.evaluate_LTR_detection(fn) + + + @staticmethod + def _summarize_annotations(annotations_files, size): + ''' will tabulate annotation results ''' + # TODO + summaries = {} + # weight is in percentage + weight = 100 / size + for i in annotations_files: + with open(annotations_files[i]) as f: + header = f.readline().split() + id_index = [ + i for i, item in enumerate(header) if item == "db_id" + ][0] + for line in f: + classification = line.split()[id_index].split("#")[1] + if classification in summaries: + summaries[classification] += weight + else: + summaries[classification] = weight + # format summaries for printing + annotation_string = "" + annotation_table = [] + for i in sorted(summaries.items(), key=lambda x: x[1], reverse=True): + ## hits with smaller proportion are not shown! 
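# Editor's note -- illustrative sketch, not part of the original file: the
# weighting used by _summarize_annotations() here. Every annotated read
# contributes 100/size percent to its classification, so the summary is the
# percentage of cluster reads hitting each database class. Input values are
# hypothetical.
def summarize(classifications, cluster_size):
    weight = 100 / cluster_size
    summary = {}
    for classification in classifications:   # one entry per annotated read
        summary[classification] = summary.get(classification, 0) + weight
    return summary

# summarize(["rDNA/45S", "rDNA/45S", "LTR/Ty3"], cluster_size=10)
# -> {"rDNA/45S": 20.0, "LTR/Ty3": 10.0}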
+ if i[1] > 0.1: + if i[1] > 1: + annotation_string += "<b>{1:.2f}% {0}</b>\n".format(*i) + else: + annotation_string += "{1:.2f}% {0}\n".format(*i) + annotation_table.append(i) + return [annotation_string, annotation_table] + + @staticmethod + def add_cluster_table_to_database(cluster_table, db_path): + '''get column names from Cluster object and create + correspopnding table in database values from all + clusters are filled to database''' + column_name_and_type = [] + column_list = [] + + # get all atribute names -> they are column names + # in sqlite table, detect proper sqlite type + def identity(x): + return (x) + + for i in cluster_table[1]: + t = type(cluster_table[1][i]) + if t == int: + sqltype = "integer" + convert = identity + elif t == float: + sqltype = "real" + convert = identity + elif t == bool: + sqltype = "boolean" + convert = bool + else: + sqltype = "text" + convert = str + column_name_and_type += ["[{}] {}".format(i, sqltype)] + column_list += [tuple((i, convert))] + header = ", ".join(column_name_and_type) + db = sqlite3.connect(db_path) + c = db.cursor() + print("CREATE TABLE cluster_info ({})".format(header)) + c.execute("CREATE TABLE cluster_info ({})".format(header)) + # file data to cluster_table + buffer = [] + for i in cluster_table: + buffer.append(tuple('{}'.format(fun(i[j])) for j, fun in + column_list)) + wildcards = ",".join(["?"] * len(column_list)) + print(buffer) + c.executemany("insert into cluster_info values ({})".format(wildcards), + buffer) + db.commit()
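# Editor's note -- illustrative sketch, not part of the original file: the
# Python-type to SQLite-type mapping that add_cluster_table_to_database()
# above derives from the cluster attributes (exact type checks, so bool is
# not caught by the int branch); example values are hypothetical.
def sqlite_type(value):
    t = type(value)
    if t == int:
        return "integer"
    if t == float:
        return "real"
    if t == bool:
        return "boolean"
    return "text"

# sqlite_type(12) -> "integer"      sqlite_type(0.92) -> "real"
# sqlite_type(True) -> "boolean"    sqlite_type("dir_CL0001") -> "text"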