comparison lib/graphtools.py @ 0:1d1b9e1b2e2f draft

Uploaded
author petr-novak
date Thu, 19 Dec 2019 10:24:45 -0500
1 #!/usr/bin/env python3
2 '''
3 This module is mainly for large graph (i.e. hitsort) storage, parsing and clustering
4 '''
5 import os
6 import sys
7 import sqlite3
8 import time
9 import subprocess
10 import logging
11 from collections import defaultdict
12 import collections
13 import operator
14 import math
15 import random
16 import itertools
17 import config
18 from lib import r2py
19 from lib.utils import FilePath
20 from lib.parallel.parallel import parallel2 as parallel
21 REQUIRED_VERSION = (3, 4)
22 MAX_BUFFER_SIZE = 100000
23 if sys.version_info < REQUIRED_VERSION:
24 raise Exception("\n\npython 3.4 or higher is required!\n")
25 LOGGER = logging.getLogger(__name__)
26
27
28 def dfs(start, graph):
29 """
30 helper function for cluster merging.
31 Does depth-first search, returning a set of all nodes seen.
32 Takes: a graph in node --> [neighbors] form.
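
Example (illustrative graph, written as a doctest):
>>> sorted(dfs(1, {1: [2], 2: [1, 3], 3: [2]}))
[1, 2, 3]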
33 """
34 visited, worklist = set(), [start]
35
36 while worklist:
37 node = worklist.pop()
38 if node not in visited:
39 visited.add(node)
40 # Add all the neighbors to the worklist.
41 worklist.extend(graph[node])
42
43 return visited
44
45
46 def graph_components(edges):
47 """
48 Given a graph as a list of edges, divide the nodes into components.
49 Takes a list of pairs of nodes, where the nodes are integers.
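
Example (illustrative edge list, written as a doctest):
>>> sorted(sorted(c) for c in graph_components([(1, 2), (2, 3), (5, 6)]))
[[1, 2, 3], [5, 6]]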
50 """
51
52 # Construct a graph (mapping node --> [neighbors]) from the edges.
53 graph = defaultdict(list)
54 nodes = set()
55
56 for v1, v2 in edges:
57 nodes.add(v1)
58 nodes.add(v2)
59
60 graph[v1].append(v2)
61 graph[v2].append(v1)
62
63 # Traverse the graph to find the components.
64 components = []
65
66 # We don't care what order we see the nodes in.
67 while nodes:
68 component = dfs(nodes.pop(), graph)
69 components.append(component)
70
71 # Remove this component from the nodes under consideration.
72 nodes -= component
73
74 return components
75
76
77 class Graph():
78 '''
79 create a Graph object stored in an sqlite database, either in memory or on disk
80 the structure of the graph table is:
81 V1 V2 weight12
82 V2 V3 weight23
83 V4 V5 weight45
84 ...
85 ...
86 !! this is an undirected simple graph - duplicate edges must
87 be removed on graph creation
88
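A typical usage sketch (illustrative file names and parameters; the
louvain_* executables must be installed):

    g = Graph(source="hitsort", filename="hitsort.db",
              paired=True, seqids=sequence_ids)
    g.save_indexed_graph()
    g.louvain_clustering(merge_threshold=0.2, cleanup=True)
    g.export_cls("hitsort.cls")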
89 '''
90 # seed for random number generator - this is necessary for reproducibility between runs
91 seed = '123'
92
93 def __init__(self,
94 source=None,
95 filename=None,
96 new=False,
97 paired=True,
98 seqids=None):
99 '''
100 filename : file where the database is stored; if not defined it is kept in memory
101 source : ncol (hitsort) file from which the graph is built
102 new : if False and source is not defined, the graph can be loaded from an existing database (filename)
103
104 seqids must be in the correct order - paired reads must be adjacent!
105 '''
106
107 self.filename = filename
108 self.source = source
109 self.paired = paired
110 # path to indexed graph - will be set later
111 self.indexed_file = None
112 self._cluster_list = None
113 # these two attributes are set after clustering
114 # communities before merging
115 self.graph_2_community0 = None
116 # communities after merging
117 self.graph_2_community = None
118 self.number_of_clusters = None
119 self.binary_file = None
120 self.cluster_sizes = None
121 self.graph_tree = None
122 self.graph_tree_log = None
123 self.weights_file = None
124
125 if filename:
126 if os.path.isfile(filename) and (new or source):
127 os.remove(filename)
128 self.conn = sqlite3.connect(filename)
129 else:
130 self.conn = sqlite3.connect(":memory:")
131 c = self.conn.cursor()
132
133 c.execute("PRAGMA page_size=8192")
134 c.execute("PRAGMA cache_size = 2000000 ") # this helps
135
136 try:
137 c.execute((
138 "create table graph (v1 integer, v2 integer, weight integer, "
139 "pair integer, v1length integer, v1start integer, v1end integer, "
140 "v2length integer, v2start integer, v2end integer, pid integer,"
141 "evalue real, strand text )"))
142 except sqlite3.OperationalError:
143 pass # table already exists
144 else:
145 c.execute(
146 "create table vertices (vertexname text primary key, vertexindex integer)")
147 tables = sorted(c.execute(
148 "SELECT name FROM sqlite_master WHERE type='table'").fetchall())
149
150 if not [('graph', ), ('vertices', )] == tables:
151 raise Exception("tables for sqlite for graph are not correct")
152
153 if source:
154 self._read_from_hitsort()
155
156 if paired and seqids:
157 # vertices must be defined - create graph of paired reads:
158 # the last character of the read name must distinguish the pair
159 c.execute((
160 "create table pairs (basename, vertexname1, vertexname2,"
161 "v1 integer, v2 integer, cluster1 integer, cluster2 integer)"))
162 buffer = []
163 for i, k in zip(seqids[0::2], seqids[1::2]):
164 assert i[:-1] == k[:-1], "problem with pair reads ids"
165 # some vertices are not in graph - singletons
166 try:
167 index1 = self.vertices[i]
168 except KeyError:
169 index1 = -1
170
171 try:
172 index2 = self.vertices[k]
173 except KeyError:
174 index2 = -1
175
176 buffer.append((i[:-1], i, k, index1, index2))
177
178 self.conn.executemany(
179 "insert into pairs (basename, vertexname1, vertexname2, v1, v2) values (?,?,?,?,?)",
180 buffer)
181 self.conn.commit()
182
183 def _read_from_hitsort(self):
184
185 c = self.conn.cursor()
186 c.execute("delete from graph")
187 buffer = []
188 vertices = {}
189 counter = 0
190 v_count = 0
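# each hitsort line is whitespace-separated; the first three columns are
# read1 read2 weight, followed (as implied by the graph table schema) by
# v1length v1start v1end v2length v2start v2end pid evalue strand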
191 with open(self.source, 'r') as f:
192 for i in f:
193 edge_index = {}
194 items = i.split()
195 # get or insert vertex index
196 for vn in items[0:2]:
197 if vn not in vertices:
198 vertices[vn] = v_count
199 edge_index[vn] = v_count
200 v_count += 1
201 else:
202 edge_index[vn] = vertices[vn]
203 if self.paired:
204 pair = int(items[0][:-1] == items[1][:-1])
205 else:
206 pair = 0
207 buffer.append(((edge_index[items[0]], edge_index[items[1]],
208 items[2], pair) + tuple(items[3:])))
209 if len(buffer) == MAX_BUFFER_SIZE:
210 counter += 1
211 self.conn.executemany(
212 "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)",
213 buffer)
214 buffer = []
215 if buffer:
216 self.conn.executemany(
217 "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)",
218 buffer)
219
220 self.conn.commit()
221 self.vertices = vertices
222 self.vertexid2name = {
223 vertex: index
224 for index, vertex in vertices.items()
225 }
226 self.vcount = len(vertices)
227 c = self.conn.cursor()
228 c.execute("select count(*) from graph")
229 self.ecount = c.fetchone()[0]
230 # fill table of vertices
231 self.conn.executemany("insert into vertices values (?,?)",
232 vertices.items())
233 self.conn.commit()
234
235 def save_indexed_graph(self, file=None):
236 if not file:
237 self.indexed_file = "{}.int".format(self.source)
238 else:
239 self.indexed_file = file
240 c = self.conn.cursor()
241 with open(self.indexed_file, 'w') as f:
242 out = c.execute('select v1,v2,weight from graph')
243 for v1, v2, weight in out:
244 f.write('{}\t{}\t{}\n'.format(v1, v2, weight))
245
246 def get_subgraph(self, vertices):
247 pass
248
249 def _levels(self):
250 with open(self.graph_tree_log, 'r') as f:
251 levels = -1
252 for i in f:
253 if i[:5] == 'level':
254 levels += 1
255 return levels
256
257 def _reindex_community(self, id2com):
258 '''
259 reindex communities and superclusters so that the biggest cluster is no. 1
260 '''
261 self.conn.commit()
262 _, community, supercluster = zip(*id2com)
263 (cluster_index, frq, self.cluster_sizes,
264 self.number_of_clusters) = self._get_index_and_frequency(community)
265
266 supercluster_index, sc_frq, _, _ = self._get_index_and_frequency(
267 supercluster)
268 id2com_reindexed = []
269
270 for i, _ in enumerate(id2com):
271 id2com_reindexed.append((id2com[i][0], id2com[i][1], frq[
272 i], cluster_index[i], supercluster_index[i], sc_frq[i]))
273 return id2com_reindexed
274
275 @staticmethod
276 def _get_index_and_frequency(membership):
277 frequency_table = collections.Counter(membership)
278 frequency_table_sorted = sorted(frequency_table.items(),
279 key=operator.itemgetter(1),
280 reverse=True)
281 frq = []
282 for i in membership:
283 frq.append(frequency_table[i])
284 rank = {}
285 index = 0
286 for comm, _ in frequency_table_sorted:
287 index += 1
288 rank[comm] = index
289 cluster_index = [rank[i] for i in membership]
290 cluster_sizes = [i[1] for i in frequency_table_sorted]
291 number_of_clusters = len(frequency_table)
292 return [cluster_index, frq, cluster_sizes, number_of_clusters]
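# Illustrative example (hypothetical membership vector):
#   _get_index_and_frequency([5, 5, 7, 5, 9]) returns
#   [[1, 1, 2, 1, 3],   # cluster_index - rank 1 is the most frequent community
#    [3, 3, 1, 3, 1],   # frq - size of the community each member belongs to
#    [3, 1, 1],         # cluster_sizes, sorted from the largest
#    3]                 # number_of_clusters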
293
294 def louvain_clustering(self, merge_threshold=0, cleanup=False):
295 '''
296 input - indexed graph file (see save_indexed_graph)
297 output - cluster membership stored in the database and graph_2_community files
298 the louvain_* executables must be available on the PATH
299 '''
300 LOGGER.info("converting hitsort to binary format")
301 self.binary_file = "{}.bin".format(self.indexed_file)
302 self.weights_file = "{}.weight".format(self.indexed_file)
303 self.graph_tree = "{}.graph_tree".format(self.indexed_file)
304 self.graph_tree_log = "{}.graph_tree_log".format(self.indexed_file)
305 self.graph_2_community0 = "{}.graph_2_community0".format(
306 self.indexed_file)
307 self._cluster_list = None
308 self.graph_2_community = "{}.graph_2_community".format(
309 self.indexed_file)
310 print(["louvain_convert", "-i", self.indexed_file, "-o",
311 self.binary_file, "-w", self.weights_file])
312 subprocess.check_call(
313 ["louvain_convert", "-i", self.indexed_file, "-o",
314 self.binary_file, "-w", self.weights_file],
315 timeout=None)
316
317 gt = open(self.graph_tree, 'w')
318 gtl = open(self.graph_tree_log, 'w')
319 LOGGER.info("running louvain clustering...")
320 subprocess.check_call(
321 ["louvain_community", self.binary_file, "-l", "-1", "-w",
322 self.weights_file, "-v ", "-s", self.seed],
323 stdout=gt,
324 stderr=gtl,
325 timeout=None)
326 gt.close()
327 gtl.close()
328
329 LOGGER.info("creating list of communities")
330 gt2c = open(self.graph_2_community0, 'w')
331 subprocess.check_call(
332 ['louvain_hierarchy', self.graph_tree, "-l", str(self._levels())],
333 stdout=gt2c)
334 gt2c.close()
335 if merge_threshold and self.paired:
336 com2newcom = self.find_superclusters(merge_threshold)
337 elif self.paired:
338 com2newcom = self.find_superclusters(config.SUPERCLUSTER_THRESHOLD)
339 else:
340 com2newcom = {}
341 # merging of clusters, creating superclusters
342 LOGGER.info("merging clusters based on mate-pairs")
343 # modify self.graph_2_community file
344 # rewrite graph2community
345 with open(self.graph_2_community0, 'r') as fin:
346 with open(self.graph_2_community, 'w') as fout:
347 for i in fin:
348 # write graph_2_community file in the format:
349 # id communityid superclusterid
350 # if merging - community and supercluster ids are identical
351 vi, com = i.split()
352 if merge_threshold:
353 ## merging
354 if int(com) in com2newcom:
355 fout.write("{} {} {}\n".format(vi, com2newcom[int(
356 com)], com2newcom[int(com)]))
357 else:
358 fout.write("{} {} {}\n".format(vi, com, com))
359 else:
360 ## superclusters
361 if int(com) in com2newcom:
362 fout.write("{} {} {}\n".format(vi, com, com2newcom[
363 int(com)]))
364 else:
365 fout.write("{} {} {}\n".format(vi, com, com))
366
367 LOGGER.info("loading communities into database")
368 c = self.conn.cursor()
369 c.execute(("create table communities (vertexindex integer primary key,"
370 "community integer, size integer, cluster integer, "
371 "supercluster integer, supercluster_size integer)"))
372 id2com = []
373 with open(self.graph_2_community, 'r') as f:
374 for i in f:
375 name, com, supercluster = i.split()
376 id2com.append((name, com, supercluster))
377 id2com_reindexed = self._reindex_community(id2com)
378 c.executemany("insert into communities values (?,?,?,?,?,?)",
379 id2com_reindexed)
380 # create table mapping superclusters to clusters
381 c.execute(("create table superclusters as "
382 "select distinct supercluster, supercluster_size, "
383 "cluster, size from communities;"))
384 # create view id-index-cluster
385 c.execute(
386 ("CREATE VIEW vertex_cluster AS SELECT vertices.vertexname,"
387 "vertices.vertexindex, communities.cluster, communities.size"
388 " FROM vertices JOIN communities USING (vertexindex)"))
389 self.conn.commit()
390
391 # add clustering info to the graph table
392 LOGGER.info("updating graph table")
393 t0 = time.time()
394
395 c.execute("alter table graph add c1 integer")
396 c.execute("alter table graph add c2 integer")
397 c.execute(("update graph set c1 = (select cluster FROM communities "
398 "where communities.vertexindex=graph.v1)"))
399 c.execute(
400 ("update graph set c2 = (select cluster FROM communities where "
401 "communities.vertexindex=graph.v2)"))
402 self.conn.commit()
403 t1 = time.time()
404 LOGGER.info("updating graph table - done in {} seconds".format(t1 -
405 t0))
406
407 # identify similarity connections between clusters
408 c.execute(
409 "create table cluster_connections as SELECT c1,c2 , count(*) FROM (SELECT c1, c2 FROM graph WHERE c1>c2 UNION ALL SELECT c2 as c1, c1 as c2 FROM graph WHERE c2>c1) GROUP BY c1, c2")
410 # TODO - remove directionality - summarize -
411
412 # add cluster identity to pairs table
413
414 if self.paired:
415 LOGGER.info("analyzing pairs ")
416 t0 = time.time()
417 c.execute(
418 "UPDATE pairs SET cluster1=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v1)")
419 t1 = time.time()
420 LOGGER.info(
421 "updating pairs table - cluster1 - done in {} seconds".format(
422 t1 - t0))
423
424 t0 = time.time()
425 c.execute(
426 "UPDATE pairs SET cluster2=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v2)")
427 t1 = time.time()
428 LOGGER.info(
429 "updating pairs table - cluster2 - done in {} seconds".format(
430 t1 - t0))
431 # reorder records
432
433 t0 = time.time()
434 c.execute(
435 "UPDATE pairs SET cluster1=cluster2, cluster2=cluster1, vertexname1=vertexname2,vertexname2=vertexname1 where cluster1<cluster2")
436 t1 = time.time()
437 LOGGER.info("sorting - done in {} seconds".format(t1 - t0))
438
439 t0 = time.time()
440 c.execute(
441 "create table cluster_mate_connections as select cluster1 as c1, cluster2 as c2, count(*) as N, group_concat(basename) as ids from pairs where cluster1!=cluster2 group by cluster1, cluster2;")
442 t1 = time.time()
443 LOGGER.info(
444 "creating cluster_mate_connections table - done in {} seconds".format(
445 t1 - t0))
446 # summarize
447 t0 = time.time()
448 self._calculate_pair_bond()
449 t1 = time.time()
450 LOGGER.info(
451 "calculating cluster pair bond - done in {} seconds".format(
452 t1 - t0))
453 t0 = time.time()
454 else:
455 # not paired - create empty tables
456 self._add_empty_tables()
457
458 self.conn.commit()
459 t1 = time.time()
460 LOGGER.info("committing changes - done in {} seconds".format(t1 - t0))
461
462 if cleanup:
463 LOGGER.info("cleaning clustering temp files")
464 os.unlink(self.binary_file)
465 os.unlink(self.weights_file)
466 os.unlink(self.graph_tree)
467 os.unlink(self.graph_tree_log)
468 os.unlink(self.graph_2_community0)
469 os.unlink(self.graph_2_community)
470 os.unlink(self.indexed_file)
471 self.binary_file = None
472 self.weights_file = None
473 self.graph_tree = None
474 self.graph_tree_log = None
475 self.graph_2_community0 = None
476 self.graph_2_community = None
477 self.indexed_file = None
478
479 # calculate k
480
481 def find_superclusters(self, merge_threshold):
482 '''Find superclusters from clustering based on paired reads '''
483 clsdict = {}
484 with open(self.graph_2_community0, 'r') as f:
485 for i in f:
486 vi, com = i.split()
487 if com in clsdict:
488 clsdict[com] += [self.vertexid2name[int(vi)][0:-1]]
489 else:
490 clsdict[com] = [self.vertexid2name[int(vi)][0:-1]]
491 # remove all small clusters - these will not be merged:
492 small_cls = []
493 for i in clsdict:
494 if len(clsdict[i]) < config.MINIMUM_NUMBER_OF_READS_FOR_MERGING:
495 small_cls.append(i)
496 for i in small_cls:
497 del clsdict[i]
498 pairs = []
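# for every pair of remaining clusters compute the mate-pair bond k:
#   wgh    - number of read pairs shared between the two clusters
#   n1, n2 - number of reads in each cluster whose mate lies outside of it
#            (2 * |distinct pair basenames| - |reads in the cluster|)
#   k = 2 * wgh / (n1 + n2)
# cluster pairs with k above merge_threshold will be merged below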
499 for i, j in itertools.combinations(clsdict, 2):
500 s1 = set(clsdict[i])
501 s2 = set(clsdict[j])
502 wgh = len(s1 & s2)
503 if wgh < config.MINIMUM_NUMBER_OF_SHARED_PAIRS_FOR_MERGING:
504 continue
505 else:
506 n1 = len(s1) * 2 - len(clsdict[i])
507 n2 = len(s2) * 2 - len(clsdict[j])
508 k = 2 * wgh / (n1 + n2)
509 if k > merge_threshold:
510 pairs.append((int(i), int(j)))
511 # find connected components - these will be merged
512 cls2merge = graph_components(pairs)
513 com2newcom = {}
514 for i in cls2merge:
515 newcom = min(i)
516 for j in i:
517 com2newcom[j] = newcom
518 return com2newcom
519
520 def adjust_cluster_size(self, proportion_kept, ids_kept):
521 LOGGER.info("adjusting cluster sizes")
522 c = self.conn.cursor()
523 c.execute("ALTER TABLE superclusters ADD COLUMN size_uncorrected INTEGER")
524 c.execute("UPDATE superclusters SET size_uncorrected=size")
525 if ids_kept:
526 ids_kept_set = set(ids_kept)
527 ratio = (1 - proportion_kept)/proportion_kept
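# e.g. proportion_kept = 0.25 gives ratio = 3: each read from ids_kept that
# is present in the cluster stands for 3 additional reads removed by the
# filtering, so 100 cluster reads with 40 of them in ids_kept give an
# adjusted size of 100 + 40 * 3 = 220 (hypothetical numbers)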
528 for cl, size in c.execute("SELECT cluster,size FROM superclusters"):
529 ids = self.get_cluster_reads(cl)
530 ovl_size = len(ids_kept_set.intersection(ids))
531 size_adjusted = int(len(ids) + ovl_size * ratio)
532 if size_adjusted > size:
533 c.execute("UPDATE superclusters SET size=? WHERE cluster=?",
534 (size_adjusted, cl))
535 self.conn.commit()
536 LOGGER.info("adjusting cluster sizes - done")
537
538 def export_cls(self, path):
539 with open(path, 'w') as f:
540 for i in range(1, self.number_of_clusters + 1):
541 ids = self.get_cluster_reads(i)
542 f.write(">CL{}\t{}\n".format(i, len(ids)))
543 f.write("\t".join(ids))
544 f.write("\n")
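# the resulting cls file contains one two-line record per cluster, e.g.
# (with made-up read names):
#   >CL1\t1500
#   read001f\tread001r\tread073f\t...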
545
546 def _calculate_pair_bond(self):
547 c = self.conn.cursor()
548 out = c.execute("select c1, c2, ids from cluster_mate_connections")
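# for every connected cluster pair compute:
#   w      - number of distinct read pairs connecting c1 and c2
#   n1, n2 - number of reads in c1 / c2 whose mate lies outside the cluster
#            (2 * |distinct pair basenames| - |reads in the cluster|)
#   k = 2 * w / (n1 + n2) - mate-pair bond strength stored in cluster_mate_bond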
549 buffer = []
550 for c1, c2, ids in out:
551 w = len(set(ids.split(",")))
552 n1 = len(set([i[:-1] for i in self.get_cluster_reads(c1)
553 ])) * 2 - len(self.get_cluster_reads(c1))
554 n2 = len(set([i[:-1] for i in self.get_cluster_reads(c2)
555 ])) * 2 - len(self.get_cluster_reads(c2))
556 buffer.append((c1, c2, n1, n2, w, 2 * w / (n1 + n2)))
557 c.execute(
558 "CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)")
559 c.executemany(" INSERT INTO cluster_mate_bond values (?,?,?,?,?,?)",
560 buffer)
561
562 def _add_empty_tables(self):
563 '''This is used with reads that are not paired
564 - it creates empty mate tables, which is necessary for
565 subsequent reporting to work correctly '''
566 c = self.conn.cursor()
567 c.execute(("CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, "
568 "n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)"))
569 c.execute(
570 "CREATE TABLE cluster_mate_connections (c1 INTEGER, c2 INTEGER, N INTEGER, ids TEXT) ")
571
572 def get_cluster_supercluster(self, cluster):
573 '''Get supercluster id for the supplied cluster '''
574 c = self.conn.cursor()
575 out = c.execute(
576 'SELECT supercluster FROM communities WHERE cluster="{0}" LIMIT 1'.format(
577 cluster))
578 sc = out.fetchone()[0]
579 return sc
580
581 def get_cluster_reads(self, cluster):
582
583 if self._cluster_list:
584 return self._cluster_list[str(cluster)]
585 else:
586 # if queried first time
587 c = self.conn.cursor()
588 out = c.execute("select cluster, vertexname from vertex_cluster")
589 cluster_list = collections.defaultdict(list)
590 for clusterindex, vertexname in out:
591 cluster_list[str(clusterindex)].append(vertexname)
592 self._cluster_list = cluster_list
593 return self._cluster_list[str(cluster)]
594
595
596 def extract_cluster_blast(self, path, index, ids=None):
597 ''' Extract blast hits for the cluster and save them to path;
598 return the number of blast lines (i.e. the number of graph edges E).
599 If ids is specified, only a subset of the blast hits is used'''
600 c = self.conn.cursor()
601 if ids:
602 vertexindex = (
603 "select vertexindex from vertices "
604 "where vertexname in ({})").format('"' + '","'.join(ids) + '"')
605
606 out = c.execute(("select * from graph where c1={0} and c2={0}"
607 " and v1 in ({1}) and v2 in ({1})").format(
608 index, vertexindex))
609 else:
610 out = c.execute(
611 "select * from graph where c1={0} and c2={0}".format(index))
612 E = 0
613 N = len(self.get_cluster_reads(index))
614 with open(path, 'w') as f:
615 for i in out:
616 print(self.vertexid2name[i[0]],
617 self.vertexid2name[
618 i[1]],
619 i[2],
620 *i[4:13],
621 sep='\t',
622 file=f)
623 E += 1
624 return E
625
626 def export_clusters_files_multiple(self,
627 min_size,
628 directory,
629 sequences=None,
630 tRNA_database_path=None,
631 satellite_model_path=None):
632 def load_fun(N, E):
633 ''' estimate memory usage from graph size and density'''
634 NE = math.log(float(N) * float(E), 10)
635 if NE > 11.5:
636 return 1
637 if NE > 11:
638 return 0.9
639 if NE > 10:
640 return 0.4
641 if NE > 9:
642 return 0.2
643 if NE > 8:
644 return 0.07
645 return 0.02
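# e.g. a cluster with N = 200000 reads and E = 1000000 edges gives
# NE = log10(2e11) ~ 11.3, so load_fun returns 0.9, which is later used
# as the job's relative load (ppn) when clusters are processed in parallel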
646
647 def estimate_sample_size(NV, NE, maxv, maxe):
648 ''' estimate a suitable sample size based on the graph density;
649 NV, NE are |V| and |E| of the graph,
650 maxv, maxe are the maximal allowed |V| and |E|'''
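# assuming a roughly constant edge density d = 2*NE / (NV*(NV-1)):
#   eEst - expected number of edges if only maxv vertices were kept
#   nEst - number of vertices n for which n*(n-1)*d/2 == maxe, i.e. the
#          positive root of d*n**2 - d*n - 2*maxe = 0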
651
652 d = (2 * NE) / (NV * (NV - 1))
653 eEst = (maxv * (maxv - 1) * d) / 2
654 nEst = (d + math.sqrt(d**2 + 8 * d * maxe)) / (2 * d)
655 if eEst >= maxe:
656 N = int(nEst)
657 if nEst >= maxv:
658 N = int(maxv)
659 return N
660
661 clusterindex = 1
662 cluster_input_args = []
663 ppn = []
664 # is this a comparative analysis?
665 if sequences.prefix_length:
666 self.conn.execute("CREATE TABLE comparative_counts (clusterindex INTEGER,"
667 + ", ".join(["[{}] INTEGER".format(i) for i in sequences.prefix_codes.keys()]) + ")")
668 # do for comparative analysis
669
670 for cl in range(1, self.number_of_clusters + 1):  # clusters are indexed from 1
671 prefix_codes = dict((key, 0) for key in sequences.prefix_codes.keys())
672 for i in self.get_cluster_reads(cl):
673 prefix_codes[i[0:sequences.prefix_length]] += 1
674 header = ", ".join(["[" + str(i) + "]" for i in prefix_codes.keys()])
675 values = ", ".join([str(i) for i in prefix_codes.values()])
676 self.conn.execute(
677 "INSERT INTO comparative_counts (clusterindex, {}) VALUES ({}, {})".format(
678 header, cl, values))
679 else:
680 prefix_codes = {}
681
682 while True:
683 read_names = self.get_cluster_reads(clusterindex)
684 supercluster = self.get_cluster_supercluster(clusterindex)
685 N = len(read_names)
686 print("sequences.ids_kept -2 ")
687 print(sequences.ids_kept)
688 if sequences.ids_kept:
689 N_adjusted = round(len(set(sequences.ids_kept).intersection(read_names)) *
690 ((1 - config.FILTER_PROPORTION_OF_KEPT) /
691 config.FILTER_PROPORTION_OF_KEPT) + N)
692 else:
693 N_adjusted = N
694 if N < min_size:
695 break
696 else:
697 LOGGER.info("exporting cluster {}".format(clusterindex))
698 blast_file = "{dir}/dir_CL{i:04}/hitsort_part.csv".format(
699 dir=directory, i=clusterindex)
700 cluster_dir = "{dir}/dir_CL{i:04}".format(dir=directory,
701 i=clusterindex)
702 fasta_file = "{dir}/reads_selection.fasta".format(dir=cluster_dir)
703 fasta_file_full = "{dir}/reads.fasta".format(dir=cluster_dir)
704
705 os.makedirs(os.path.dirname(blast_file), exist_ok=True)
706 E = self.extract_cluster_blast(index=clusterindex,
707 path=blast_file)
708 # check if blast must be sampled
709 n_sample = estimate_sample_size(NV=N,
710 NE=E,
711 maxv=config.CLUSTER_VMAX,
712 maxe=config.CLUSTER_EMAX)
713 LOGGER.info("directories created..")
714 if n_sample < N:
715 LOGGER.info(("cluster is too large - sampling.."
716 "original size: {N}\n"
717 "sample size: {NS}\n"
718 "").format(N=N, NS=n_sample))
719 random.seed(self.seed)
720 read_names_sample = random.sample(read_names, n_sample)
721 LOGGER.info("read ids sampled...")
722 blast_file_sample = "{dir}/dir_CL{i:04}/blast_sample.csv".format(
723 dir=directory, i=clusterindex)
724 E_sample = self.extract_cluster_blast(
725 index=clusterindex,
726 path=blast_file,
727 ids=read_names_sample)
728 LOGGER.info("number of edges in sample: {}".format(
729 E_sample))
730 sequences.save2fasta(fasta_file, subset=read_names_sample)
731 sequences.save2fasta(fasta_file_full, subset=read_names)
732
733 else:
734 read_names_sample = None
735 E_sample = None
736 blast_file_sample = None
737 n_sample = None
738 sequences.save2fasta(fasta_file_full, subset=read_names)
739 ## TODO - use symlink instead of :
740 sequences.save2fasta(fasta_file, subset=read_names)
741 # export individual annotations tables:
742 # annotation is always for full cluster
743 LOGGER.info("exporting cluster annotation")
744 annotations = {}
745 annotations_custom = {}
746 for n in sequences.annotations:
747 print("sequences.annotations:", n)
748 if n.find("custom_db") == 0:
749 print("custom")
750 annotations_custom[n] = sequences.save_annotation(
751 annotation_name=n,
752 subset=read_names,
753 dir=cluster_dir)
754 else:
755 print("built in")
756 annotations[n] = sequences.save_annotation(
757 annotation_name=n,
758 subset=read_names,
759 dir=cluster_dir)
760
761 cluster_input_args.append([
762 n_sample, N, N_adjusted, blast_file, fasta_file, fasta_file_full,
763 clusterindex, supercluster, self.paired,
764 tRNA_database_path, satellite_model_path, sequences.prefix_codes,
765 prefix_codes, annotations, annotations_custom
766 ])
767 clusterindex += 1
768 ppn.append(load_fun(N, E))
769
770
771
772 self.conn.commit()
773
774 # run in parallel:
775 # reorder jobs based on the ppn:
776 cluster_input_args = [
777 x
778 for (y, x) in sorted(
779 zip(ppn, cluster_input_args),
780 key=lambda pair: pair[0],
781 reverse=True)
782 ]
783 ppn = sorted(ppn, reverse=True)
784 LOGGER.info("creating clusters in parallel")
785 clusters_info = parallel(Cluster,
786 *[list(i) for i in zip(*cluster_input_args)],
787 ppn=ppn)
788 # sort it back:
789 clusters_info = sorted(clusters_info, key=lambda cl: cl.index)
790 return clusters_info
791
792
793 class Cluster():
794 ''' store and show information about cluster properties '''
795
796 def __init__(self,
797 size,
798 size_real,
799 size_adjusted,
800 blast_file,
801 fasta_file,
802 fasta_file_full,
803 index,
804 supercluster,
805 paired,
806 tRNA_database_path,
807 satellite_model_path,
808 all_prefix_codes,
809 prefix_codes,
810 annotations,
811 annotations_custom={},
812 loop_index_threshold=0.7,
813 pair_completeness_threshold=0.40,
814 loop_index_unpaired_threshold=0.85):
815 if size:
816 # cluster was scaled down
817 self.size = size
818 self.size_real = size_real
819 else:
820 self.size = self.size_real = size_real
821 self.size_adjusted = size_adjusted
822 self.filtered = size_adjusted != size_real
823 self.all_prefix_codes = all_prefix_codes.keys
824 self.prefix_codes = prefix_codes
825 self.dir = FilePath(os.path.dirname(blast_file))
826 self.blast_file = FilePath(blast_file)
827 self.fasta_file = FilePath(fasta_file)
828 self.fasta_file_full = FilePath(fasta_file_full)
829 self.index = index
830 self.assembly_files = {}
831 self.ltr_detection = None
832 self.supercluster = supercluster
833 self.annotations_files = annotations
834 self.annotations_files_custom = annotations_custom
835 self.annotations_summary, self.annotations_table = self._summarize_annotations(
836 annotations, size_real)
837 # add custom annotation summary, if any
838 if len(annotations_custom):
839 self.annotations_summary_custom, self.annotations_custom_table = self._summarize_annotations(
840 annotations_custom, size_real)
841 else:
842 self.annotations_summary_custom, self.annotations_custom_table = "", ""
843
844 self.paired = paired
845 self.graph_file = FilePath("{0}/graph_layout.GL".format(self.dir))
846 self.directed_graph_file = FilePath(
847 "{0}/graph_layout_directed.RData".format(self.dir))
848 self.fasta_oriented_file = FilePath("{0}/reads_selection_oriented.fasta".format(
849 self.dir))
850 self.image_file = FilePath("{0}/graph_layout.png".format(self.dir))
851 self.image_file_tmb = FilePath("{0}/graph_layout_tmb.png".format(self.dir))
852 self.html_report_main = FilePath("{0}/index.html".format(self.dir))
853 self.html_report_files = FilePath("{0}/html_files".format(self.dir))
854 self.supercluster_best_hit = "NA"
855 TAREAN = r2py.R(config.RSOURCE_tarean)
856 LOGGER.info("creating graph no.{}".format(self.index))
857 # FilePath arguments must be converted to str for the R functions
858 graph_info = eval(
859 TAREAN.mgblast2graph(
860 self.blast_file,
861 seqfile=self.fasta_file,
862 seqfile_full=self.fasta_file_full,
863 graph_destination=self.graph_file,
864 directed_graph_destination=self.directed_graph_file,
865 oriented_sequences=self.fasta_oriented_file,
866 image_file=self.image_file,
867 image_file_tmb=self.image_file_tmb,
868 repex=True,
869 paired=self.paired,
870 satellite_model_path=satellite_model_path,
871 maxv=config.CLUSTER_VMAX,
872 maxe=config.CLUSTER_EMAX)
873 )
874 print(graph_info)
875 self.ecount = graph_info['ecount']
876 self.vcount = graph_info['vcount']
877 self.loop_index = graph_info['loop_index']
878 self.pair_completeness = graph_info['pair_completeness']
879 self.orientation_score = graph_info['escore']
880 self.satellite_probability = graph_info['satellite_probability']
881 self.satellite = graph_info['satellite']
882 # for paired reads:
883 cond1 = (self.paired and self.loop_index > loop_index_threshold and
884 self.pair_completeness > pair_completeness_threshold)
885 # no pairs
886 cond2 = ((not self.paired) and
887 self.loop_index > loop_index_unpaired_threshold)
888 if (cond1 or cond2) and config.ARGS.options.name != "oxford_nanopore":
889 self.putative_tandem = True
890 self.dir_tarean = FilePath("{}/tarean".format(self.dir))
891 lock_file = self.dir + "../lock"
892 out = eval(
893 TAREAN.tarean(input_sequences=self.fasta_oriented_file,
894 output_dir=self.dir_tarean,
895 CPU=1,
896 reorient_reads=False,
897 tRNA_database_path=tRNA_database_path,
898 lock_file=lock_file)
899 )
900 self.html_tarean = FilePath(out['htmlfile'])
901 self.tarean_contig_file = out['tarean_contig_file']
902 self.TR_score = out['TR_score']
903 self.TR_monomer_length = out['TR_monomer_length']
904 self.TR_consensus = out['TR_consensus']
905 self.pbs_score = out['pbs_score']
906 self.max_ORF_length = out['orf_l']
907 if (out['orf_l'] > config.ORF_THRESHOLD or
908 out['pbs_score'] > config.PBS_THRESHOLD):
909 self.tandem_rank = 3
910 elif self.satellite:
911 self.tandem_rank = 1
912 else:
913 self.tandem_rank = 2
914 # some tandems could be rDNA genes - this must be checked
915 # by annotation
916 if self.annotations_table:
917 rdna_score = 0
918 contamination_score = 0
919 for i in self.annotations_table:
920 if 'rDNA/' in i[0]:
921 rdna_score += i[1]
922 if 'contamination' in i[0]:
923 contamination_score += i[1]
924 if rdna_score > config.RDNA_THRESHOLD:
925 self.tandem_rank = 4
926 if contamination_score > config.CONTAMINATION_THRESHOLD:
927 self.tandem_rank = 0 # other
928
929 # by custom annotation - custom annotation takes precedence
930 if self.annotations_custom_table:
931 print("custom table searching")
932 rdna_score = 0
933 contamination_score = 0
934 print(self.annotations_custom_table)
935 for i in self.annotations_custom_table:
936 if 'rDNA' in i[0]:
937 rdna_score += i[1]
938 if 'contamination' in i[0]:
939 contamination_score += i[1]
940 if rdna_score > 0:
941 self.tandem_rank = 4
942 if contamination_score > config.CONTAMINATION_THRESHOLD:
943 self.tandem_rank = 0 # other
944
945 else:
946 self.putative_tandem = False
947 self.dir_tarean = None
948 self.html_tarean = None
949 self.TR_score = None
950 self.TR_monomer_length = None
951 self.TR_consensus = None
952 self.pbs_score = None
953 self.max_ORF_length = None
954 self.tandem_rank = 0
955 self.tarean_contig_file = None
956
957 def __str__(self):
958 out = [
959 "cluster no {}:".format(self.index),
960 "Number of vertices : {}".format(self.size),
961 "Number of edges : {}".format(self.ecount),
962 "Loop index : {}".format(self.loop_index),
963 "Pair completeness : {}".format(self.pair_completeness),
964 "Orientation score : {}".format(self.orientation_score)
965 ]
966 return "\n".join(out)
967
968 def listing(self, asdict=True):
969 ''' convert attributes to dictionary for printing purposes'''
970 out = {}
971 for i in dir(self):
972 # do not show private
973 if i[:2] != "__":
974 value = getattr(self, i)
975 if not callable(value):
976 # for dictionary
977 if isinstance(value, dict):
978 for k in value:
979 out[i + "_" + k] = value[k]
980 else:
981 out[i] = value
982 if asdict:
983 return out
984 else:
985 return {'keys': list(out.keys()), 'values': list(out.values())}
986
987 def detect_ltr(self, trna_database):
988 '''detection of LTRs in assembly files; the output of the analysis is stored in a file'''
989 CREATE_ANNOTATION = r2py.R(config.RSOURCE_create_annotation, verbose=False)
990 if self.assembly_files['{}.{}.ace']:
991 ace_file = self.assembly_files['{}.{}.ace']
992 print(ace_file, "running LTR detection")
993 fout = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['BASE'])
994 subprocess.check_call([
995 config.LTR_DETECTION,
996 '-i', ace_file,
997 '-o', fout,
998 '-p', trna_database])
999 # evaluate LTR presence
1000 fn = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['PBS_BLAST'])
1001 self.ltr_detection = CREATE_ANNOTATION.evaluate_LTR_detection(fn)
1002
1003
1004 @staticmethod
1005 def _summarize_annotations(annotations_files, size):
1006 ''' will tabulate annotation results '''
1007 # TODO
1008 summaries = {}
1009 # weight is in percentage
1010 weight = 100 / size
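# e.g. in a cluster of 400 reads every annotated read adds
# 100 / 400 = 0.25 percentage points to its classification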
1011 for i in annotations_files:
1012 with open(annotations_files[i]) as f:
1013 header = f.readline().split()
1014 id_index = [
1015 i for i, item in enumerate(header) if item == "db_id"
1016 ][0]
1017 for line in f:
1018 classification = line.split()[id_index].split("#")[1]
1019 if classification in summaries:
1020 summaries[classification] += weight
1021 else:
1022 summaries[classification] = weight
1023 # format summaries for printing
1024 annotation_string = ""
1025 annotation_table = []
1026 for i in sorted(summaries.items(), key=lambda x: x[1], reverse=True):
1027 ## hits with smaller proportion are not shown!
1028 if i[1] > 0.1:
1029 if i[1] > 1:
1030 annotation_string += "<b>{1:.2f}% {0}</b>\n".format(*i)
1031 else:
1032 annotation_string += "{1:.2f}% {0}\n".format(*i)
1033 annotation_table.append(i)
1034 return [annotation_string, annotation_table]
1035
1036 @staticmethod
1037 def add_cluster_table_to_database(cluster_table, db_path):
1038 '''get column names from the Cluster object and create
1039 the corresponding table in the database; values from all
1040 clusters are then inserted into the database'''
1041 column_name_and_type = []
1042 column_list = []
1043
1044 # get all attribute names -> they are column names
1045 # in sqlite table, detect proper sqlite type
1046 def identity(x):
1047 return (x)
1048
1049 for i in cluster_table[1]:
1050 t = type(cluster_table[1][i])
1051 if t == int:
1052 sqltype = "integer"
1053 convert = identity
1054 elif t == float:
1055 sqltype = "real"
1056 convert = identity
1057 elif t == bool:
1058 sqltype = "boolean"
1059 convert = bool
1060 else:
1061 sqltype = "text"
1062 convert = str
1063 column_name_and_type += ["[{}] {}".format(i, sqltype)]
1064 column_list += [tuple((i, convert))]
1065 header = ", ".join(column_name_and_type)
1066 db = sqlite3.connect(db_path)
1067 c = db.cursor()
1068 print("CREATE TABLE cluster_info ({})".format(header))
1069 c.execute("CREATE TABLE cluster_info ({})".format(header))
1070 # fill data into the cluster_info table
1071 buffer = []
1072 for i in cluster_table:
1073 buffer.append(tuple('{}'.format(fun(i[j])) for j, fun in
1074 column_list))
1075 wildcards = ",".join(["?"] * len(column_list))
1076 print(buffer)
1077 c.executemany("insert into cluster_info values ({})".format(wildcards),
1078 buffer)
1079 db.commit()