lib/graphtools.py @ 0:1d1b9e1b2e2f draft
author:   petr-novak
date:     Thu, 19 Dec 2019 10:24:45 -0500
message:  Uploaded
parents:  (none)
children: (none)
#!/usr/bin/env python3
'''
This module is mainly for large graph (i.e. hitsort) storage, parsing and clustering
'''
import os
import sys
import sqlite3
import time
import subprocess
import logging
from collections import defaultdict
import collections
import operator
import math
import random
import itertools
import config
from lib import r2py
from lib.utils import FilePath
from lib.parallel.parallel import parallel2 as parallel
REQUIRED_VERSION = (3, 4)
MAX_BUFFER_SIZE = 100000
if sys.version_info < REQUIRED_VERSION:
    raise Exception("\n\npython 3.4 or higher is required!\n")
LOGGER = logging.getLogger(__name__)


def dfs(start, graph):
    """
    helper function for cluster merging.
    Does depth-first search, returning a set of all nodes seen.
    Takes: a graph in node --> [neighbors] form.
    """
    visited, worklist = set(), [start]

    while worklist:
        node = worklist.pop()
        if node not in visited:
            visited.add(node)
            # Add all the neighbors to the worklist.
            worklist.extend(graph[node])

    return visited


def graph_components(edges):
    """
    Given a graph as a list of edges, divide the nodes into components.
    Takes a list of pairs of nodes, where the nodes are integers.
    """

    # Construct a graph (mapping node --> [neighbors]) from the edges.
    graph = defaultdict(list)
    nodes = set()

    for v1, v2 in edges:
        nodes.add(v1)
        nodes.add(v2)

        graph[v1].append(v2)
        graph[v2].append(v1)

    # Traverse the graph to find the components.
    components = []

    # We don't care what order we see the nodes in.
    while nodes:
        component = dfs(nodes.pop(), graph)
        components.append(component)

        # Remove this component from the nodes under consideration.
        nodes -= component

    return components

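# Example (illustrative): graph_components([(1, 2), (2, 3), (4, 5)]) returns
# the two components {1, 2, 3} and {4, 5}; the order of the components
# depends on set iteration order.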

class Graph():
    '''
    create Graph object stored in sqlite database, either in memory or on disk
    structure of table is:
    V1 V2 weight12
    V2 V3 weight23
    V4 V5 weight45
    ...
    ...
    !! this is an undirected simple graph - duplicated edges must
    be removed on graph creation

    '''
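    # Illustrative usage (argument values are hypothetical):
    #   g = Graph(source="hitsort", filename="hitsort.db",
    #             paired=True, seqids=ordered_read_ids)
    #   g.save_indexed_graph()
    #   g.louvain_clustering(merge_threshold=0.1, cleanup=True)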
    # seed for random number generator - this is necessary for reproducibility between runs
    seed = '123'

    def __init__(self,
                 source=None,
                 filename=None,
                 new=False,
                 paired=True,
                 seqids=None):
        '''
        filename : file where to store the database, if not defined it is stored in memory
        source : ncol file describing the graph
        new : if False and source is not defined, the graph can be loaded from the database (filename)

        vertices_name must be in correct order!!!
        '''

        self.filename = filename
        self.source = source
        self.paired = paired
        # path to indexed graph - will be set later
        self.indexed_file = None
        self._cluster_list = None
        # these two attributes are set after clustering
        # communities before merging
        self.graph_2_community0 = None
        # communities after merging
        self.graph_2_community = None
        self.number_of_clusters = None
        self.binary_file = None
        self.cluster_sizes = None
        self.graph_tree = None
        self.graph_tree_log = None
        self.weights_file = None

        if filename:
            if os.path.isfile(filename) and (new or source):
                os.remove(filename)
            self.conn = sqlite3.connect(filename)
        else:
            self.conn = sqlite3.connect(":memory:")
        c = self.conn.cursor()

        c.execute("PRAGMA page_size=8192")
        c.execute("PRAGMA cache_size = 2000000 ")  # this helps

        try:
            c.execute((
                "create table graph (v1 integer, v2 integer, weight integer, "
                "pair integer, v1length integer, v1start integer, v1end integer, "
                "v2length integer, v2start integer, v2end integer, pid integer,"
                "evalue real, strand text )"))
        except sqlite3.OperationalError:
            pass  # table already exists
        else:
            c.execute(
                "create table vertices (vertexname text primary key, vertexindex integer)")
        tables = sorted(c.execute(
            "SELECT name FROM sqlite_master WHERE type='table'").fetchall())

        if not [('graph', ), ('vertices', )] == tables:
            raise Exception("tables for sqlite for graph are not correct")

        if source:
            self._read_from_hitsort()

        if paired and seqids:
            # vertices must be defined - create graph of paired reads:
            # last character must distinguish the pair
            c.execute((
                "create table pairs (basename, vertexname1, vertexname2,"
                "v1 integer, v2 integer, cluster1 integer, cluster2 integer)"))
            buffer = []
            for i, k in zip(seqids[0::2], seqids[1::2]):
                assert i[:-1] == k[:-1], "problem with pair reads ids"
                # some vertices are not in graph - singletons
                try:
                    index1 = self.vertices[i]
                except KeyError:
                    index1 = -1

                try:
                    index2 = self.vertices[k]
                except KeyError:
                    index2 = -1

                buffer.append((i[:-1], i, k, index1, index2))

            self.conn.executemany(
                "insert into pairs (basename, vertexname1, vertexname2, v1, v2) values (?,?,?,?,?)",
                buffer)
            self.conn.commit()

    def _read_from_hitsort(self):

        c = self.conn.cursor()
        c.execute("delete from graph")
        buffer = []
        vertices = {}
        counter = 0
        v_count = 0
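        # Each hitsort line is expected to provide the columns used below:
        # two read names, a weight, and the remaining fields that fill
        # v1length, v1start, v1end, v2length, v2start, v2end, pid, evalue
        # and strand of the graph table.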
        with open(self.source, 'r') as f:
            for i in f:
                edge_index = {}
                items = i.split()
                # get or insert vertex index
                for vn in items[0:2]:
                    if vn not in vertices:
                        vertices[vn] = v_count
                        edge_index[vn] = v_count
                        v_count += 1
                    else:
                        edge_index[vn] = vertices[vn]
                if self.paired:
                    pair = int(items[0][:-1] == items[1][:-1])
                else:
                    pair = 0
                buffer.append(((edge_index[items[0]], edge_index[items[1]],
                                items[2], pair) + tuple(items[3:])))
                if len(buffer) == MAX_BUFFER_SIZE:
                    counter += 1
                    self.conn.executemany(
                        "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)",
                        buffer)
                    buffer = []
        if buffer:
            self.conn.executemany(
                "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)",
                buffer)

        self.conn.commit()
        self.vertices = vertices
        self.vertexid2name = {
            vertex: index
            for index, vertex in vertices.items()
        }
        self.vcount = len(vertices)
        c = self.conn.cursor()
        c.execute("select count(*) from graph")
        self.ecount = c.fetchone()[0]
        # fill table of vertices
        self.conn.executemany("insert into vertices values (?,?)",
                              vertices.items())
        self.conn.commit()

    def save_indexed_graph(self, file=None):
        if not file:
            self.indexed_file = "{}.int".format(self.source)
        else:
            self.indexed_file = file
        c = self.conn.cursor()
        with open(self.indexed_file, 'w') as f:
            out = c.execute('select v1,v2,weight from graph')
            for v1, v2, weight in out:
                f.write('{}\t{}\t{}\n'.format(v1, v2, weight))

    def get_subgraph(self, vertices):
        pass

    def _levels(self):
        with open(self.graph_tree_log, 'r') as f:
            levels = -1
            for i in f:
                if i[:5] == 'level':
                    levels += 1
        return levels

    def _reindex_community(self, id2com):
        '''
        reindex communities and superclusters so that the biggest cluster is no. 1
        '''
        self.conn.commit()
        _, community, supercluster = zip(*id2com)
        (cluster_index, frq, self.cluster_sizes,
         self.number_of_clusters) = self._get_index_and_frequency(community)

        supercluster_index, sc_frq, _, _ = self._get_index_and_frequency(
            supercluster)
        id2com_reindexed = []

        for i, _ in enumerate(id2com):
            id2com_reindexed.append((id2com[i][0], id2com[i][1], frq[
                i], cluster_index[i], supercluster_index[i], sc_frq[i]))
        return id2com_reindexed

    @staticmethod
    def _get_index_and_frequency(membership):
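        '''
        Rank communities by size (illustrative example: membership
        [5, 5, 5, 2, 2, 9] yields cluster_index [1, 1, 1, 2, 2, 3],
        frq [3, 3, 3, 2, 2, 1], cluster_sizes [3, 2, 1] and 3 clusters).
        '''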
        frequency_table = collections.Counter(membership)
        frequency_table_sorted = sorted(frequency_table.items(),
                                        key=operator.itemgetter(1),
                                        reverse=True)
        frq = []
        for i in membership:
            frq.append(frequency_table[i])
        rank = {}
        index = 0
        for comm, _ in frequency_table_sorted:
            index += 1
            rank[comm] = index
        cluster_index = [rank[i] for i in membership]
        cluster_sizes = [i[1] for i in frequency_table_sorted]
        number_of_clusters = len(frequency_table)
        return [cluster_index, frq, cluster_sizes, number_of_clusters]

    def louvain_clustering(self, merge_threshold=0, cleanup=False):
        '''
        input - graph
        output - list of clusters
        executables path ??
        '''
        LOGGER.info("converting hitsort to binary format")
        self.binary_file = "{}.bin".format(self.indexed_file)
        self.weights_file = "{}.weight".format(self.indexed_file)
        self.graph_tree = "{}.graph_tree".format(self.indexed_file)
        self.graph_tree_log = "{}.graph_tree_log".format(self.indexed_file)
        self.graph_2_community0 = "{}.graph_2_community0".format(
            self.indexed_file)
        self._cluster_list = None
        self.graph_2_community = "{}.graph_2_community".format(
            self.indexed_file)
        print(["louvain_convert", "-i", self.indexed_file, "-o",
               self.binary_file, "-w", self.weights_file])
        subprocess.check_call(
            ["louvain_convert", "-i", self.indexed_file, "-o",
             self.binary_file, "-w", self.weights_file],
            timeout=None)

        gt = open(self.graph_tree, 'w')
        gtl = open(self.graph_tree_log, 'w')
        LOGGER.info("running louvain clustering...")
        subprocess.check_call(
            ["louvain_community", self.binary_file, "-l", "-1", "-w",
             self.weights_file, "-v ", "-s", self.seed],
            stdout=gt,
            stderr=gtl,
            timeout=None)
        gt.close()
        gtl.close()

        LOGGER.info("creating list of communities")
        gt2c = open(self.graph_2_community0, 'w')
        subprocess.check_call(
            ['louvain_hierarchy', self.graph_tree, "-l", str(self._levels())],
            stdout=gt2c)
        gt2c.close()
        if merge_threshold and self.paired:
            com2newcom = self.find_superclusters(merge_threshold)
        elif self.paired:
            com2newcom = self.find_superclusters(config.SUPERCLUSTER_THRESHOLD)
        else:
            com2newcom = {}
        # merging of clusters, creating superclusters
        LOGGER.info("merging clusters based on mate-pairs")
        # modify self.graph_2_community file
        # rewrite graph2community
        with open(self.graph_2_community0, 'r') as fin:
            with open(self.graph_2_community, 'w') as fout:
                for i in fin:
                    # write graph 2 community file in format:
                    # id communityid superclusterid
                    # if merging - community and supercluster are identical
                    vi, com = i.split()
                    if merge_threshold:
                        ## merging
                        if int(com) in com2newcom:
                            fout.write("{} {} {}\n".format(vi, com2newcom[int(
                                com)], com2newcom[int(com)]))
                        else:
                            fout.write("{} {} {}\n".format(vi, com, com))
                    else:
                        ## superclusters
                        if int(com) in com2newcom:
                            fout.write("{} {} {}\n".format(vi, com, com2newcom[
                                int(com)]))
                        else:
                            fout.write("{} {} {}\n".format(vi, com, com))

        LOGGER.info("loading communities into database")
        c = self.conn.cursor()
        c.execute(("create table communities (vertexindex integer primary key,"
                   "community integer, size integer, cluster integer, "
                   "supercluster integer, supercluster_size integer)"))
        id2com = []
        with open(self.graph_2_community, 'r') as f:
            for i in f:
                name, com, supercluster = i.split()
                id2com.append((name, com, supercluster))
        id2com_reindexed = self._reindex_community(id2com)
        c.executemany("insert into communities values (?,?,?,?,?,?)",
                      id2com_reindexed)
        # create table of superclusters - clusters
        c.execute(("create table superclusters as "
                   "select distinct supercluster, supercluster_size, "
                   "cluster, size from communities;"))
        # create view id-index-cluster
        c.execute(
            ("CREATE VIEW vertex_cluster AS SELECT vertices.vertexname,"
             "vertices.vertexindex, communities.cluster, communities.size"
             " FROM vertices JOIN communities USING (vertexindex)"))
        self.conn.commit()

        # add clustering info to graph
        LOGGER.info("updating graph table")
        t0 = time.time()

        c.execute("alter table graph add c1 integer")
        c.execute("alter table graph add c2 integer")
        c.execute(("update graph set c1 = (select cluster FROM communities "
                   "where communities.vertexindex=graph.v1)"))
        c.execute(
            ("update graph set c2 = (select cluster FROM communities where "
             "communities.vertexindex=graph.v2)"))
        self.conn.commit()
        t1 = time.time()
        LOGGER.info("updating graph table - done in {} seconds".format(t1 - t0))

        # identify similarity connections between clusters
        c.execute(
            "create table cluster_connections as SELECT c1,c2 , count(*) FROM (SELECT c1, c2 FROM graph WHERE c1>c2 UNION ALL SELECT c2 as c1, c1 as c2 FROM graph WHERE c2>c1) GROUP BY c1, c2")
        # TODO - remove directionality - summarize -

        # add cluster identity to pairs table

        if self.paired:
            LOGGER.info("analyzing pairs")
            t0 = time.time()
            c.execute(
                "UPDATE pairs SET cluster1=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v1)")
            t1 = time.time()
            LOGGER.info(
                "updating pairs table - cluster1 - done in {} seconds".format(
                    t1 - t0))

            t0 = time.time()
            c.execute(
                "UPDATE pairs SET cluster2=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v2)")
            t1 = time.time()
            LOGGER.info(
                "updating pairs table - cluster2 - done in {} seconds".format(
                    t1 - t0))
            # reorder records

            t0 = time.time()
            c.execute(
                "UPDATE pairs SET cluster1=cluster2, cluster2=cluster1, vertexname1=vertexname2,vertexname2=vertexname1 where cluster1<cluster2")
            t1 = time.time()
            LOGGER.info("sorting - done in {} seconds".format(t1 - t0))

            t0 = time.time()
            c.execute(
                "create table cluster_mate_connections as select cluster1 as c1, cluster2 as c2, count(*) as N, group_concat(basename) as ids from pairs where cluster1!=cluster2 group by cluster1, cluster2;")
            t1 = time.time()
            LOGGER.info(
                "creating cluster_mate_connections table - done in {} seconds".format(
                    t1 - t0))
            # summarize
            t0 = time.time()
            self._calculate_pair_bond()
            t1 = time.time()
            LOGGER.info(
                "calculating cluster pair bond - done in {} seconds".format(
                    t1 - t0))
            t0 = time.time()
        else:
            # not paired - create empty tables
            self._add_empty_tables()

        self.conn.commit()
        t1 = time.time()
        LOGGER.info("committing changes - done in {} seconds".format(t1 - t0))

        if cleanup:
            LOGGER.info("cleaning clustering temp files")
            os.unlink(self.binary_file)
            os.unlink(self.weights_file)
            os.unlink(self.graph_tree)
            os.unlink(self.graph_tree_log)
            os.unlink(self.graph_2_community0)
            os.unlink(self.graph_2_community)
            os.unlink(self.indexed_file)
            self.binary_file = None
            self.weights_file = None
            self.graph_tree = None
            self.graph_tree_log = None
            self.graph_2_community0 = None
            self.graph_2_community = None
            self.indexed_file = None

    # calculate k

    def find_superclusters(self, merge_threshold):
        '''Find superclusters from clustering based on paired reads'''
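        # Two clusters are merged when k = 2 * w / (n1 + n2) exceeds the
        # threshold, where w is the number of read basenames shared between
        # the two clusters and n1, n2 are the numbers of reads in each
        # cluster whose mate lies outside that cluster.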
        clsdict = {}
        with open(self.graph_2_community0, 'r') as f:
            for i in f:
                vi, com = i.split()
                if com in clsdict:
                    clsdict[com] += [self.vertexid2name[int(vi)][0:-1]]
                else:
                    clsdict[com] = [self.vertexid2name[int(vi)][0:-1]]
        # remove all small clusters - these will not be merged:
        small_cls = []
        for i in clsdict:
            if len(clsdict[i]) < config.MINIMUM_NUMBER_OF_READS_FOR_MERGING:
                small_cls.append(i)
        for i in small_cls:
            del clsdict[i]
        pairs = []
        for i, j in itertools.combinations(clsdict, 2):
            s1 = set(clsdict[i])
            s2 = set(clsdict[j])
            wgh = len(s1 & s2)
            if wgh < config.MINIMUM_NUMBER_OF_SHARED_PAIRS_FOR_MERGING:
                continue
            else:
                n1 = len(s1) * 2 - len(clsdict[i])
                n2 = len(s2) * 2 - len(clsdict[j])
                k = 2 * wgh / (n1 + n2)
                if k > merge_threshold:
                    pairs.append((int(i), int(j)))
        # find connected components - these will be merged
        cls2merge = graph_components(pairs)
        com2newcom = {}
        for i in cls2merge:
            newcom = min(i)
            for j in i:
                com2newcom[j] = newcom
        return com2newcom

    def adjust_cluster_size(self, proportion_kept, ids_kept):
        LOGGER.info("adjusting cluster sizes")
        c = self.conn.cursor()
        c.execute("ALTER TABLE superclusters ADD COLUMN size_uncorrected INTEGER")
        c.execute("UPDATE superclusters SET size_uncorrected=size")
        if ids_kept:
            ids_kept_set = set(ids_kept)
            ratio = (1 - proportion_kept) / proportion_kept
            for cl, size in c.execute("SELECT cluster,size FROM superclusters"):
                ids = self.get_cluster_reads(cl)
                ovl_size = len(ids_kept_set.intersection(ids))
                size_adjusted = int(len(ids) + ovl_size * ratio)
                if size_adjusted > size:
                    c.execute("UPDATE superclusters SET size=? WHERE cluster=?",
                              (size_adjusted, cl))
        self.conn.commit()
        LOGGER.info("adjusting cluster sizes - done")

    def export_cls(self, path):
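        '''Write clusters into a cls file: each cluster gets a ">CL<index>" header
        line with the cluster size, followed by a tab-separated list of read ids.'''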
        with open(path, 'w') as f:
            for i in range(1, self.number_of_clusters + 1):
                ids = self.get_cluster_reads(i)
                f.write(">CL{}\t{}\n".format(i, len(ids)))
                f.write("\t".join(ids))
                f.write("\n")

    def _calculate_pair_bond(self):
        c = self.conn.cursor()
        out = c.execute("select c1, c2, ids from cluster_mate_connections")
        buffer = []
        for c1, c2, ids in out:
            w = len(set(ids.split(",")))
            n1 = len(set([i[:-1] for i in self.get_cluster_reads(c1)
                          ])) * 2 - len(self.get_cluster_reads(c1))
            n2 = len(set([i[:-1] for i in self.get_cluster_reads(c2)
                          ])) * 2 - len(self.get_cluster_reads(c2))
            buffer.append((c1, c2, n1, n2, w, 2 * w / (n1 + n2)))
        c.execute(
            "CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)")
        c.executemany(" INSERT INTO cluster_mate_bond values (?,?,?,?,?,?)",
                      buffer)

    def _add_empty_tables(self):
        '''This is used with reads that are not paired
        - it creates empty mate tables; this is necessary for
        subsequent reporting to work correctly '''
        c = self.conn.cursor()
        c.execute(("CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, "
                   "n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)"))
        c.execute(
            "CREATE TABLE cluster_mate_connections (c1 INTEGER, c2 INTEGER, N INTEGER, ids TEXT) ")

    def get_cluster_supercluster(self, cluster):
        '''Get supercluster id for the supplied cluster '''
        c = self.conn.cursor()
        out = c.execute(
            'SELECT supercluster FROM communities WHERE cluster="{0}" LIMIT 1'.format(
                cluster))
        sc = out.fetchone()[0]
        return sc

    def get_cluster_reads(self, cluster):

        if self._cluster_list:
            return self._cluster_list[str(cluster)]
        else:
            # if queried for the first time
            c = self.conn.cursor()
            out = c.execute("select cluster, vertexname from vertex_cluster")
            cluster_list = collections.defaultdict(list)
            for clusterindex, vertexname in out:
                cluster_list[str(clusterindex)].append(vertexname)
            self._cluster_list = cluster_list
            return self._cluster_list[str(cluster)]


    def extract_cluster_blast(self, path, index, ids=None):
        ''' Extract blast for cluster and save it to path
        return number of blast lines (i.e. number of graph edges E)
        if ids is specified, only a subset of blast is used'''
        c = self.conn.cursor()
        if ids:
            vertexindex = (
                "select vertexindex from vertices "
                "where vertexname in ({})").format('"' + '","'.join(ids) + '"')

            out = c.execute(("select * from graph where c1={0} and c2={0}"
                             " and v1 in ({1}) and v2 in ({1})").format(
                                 index, vertexindex))
        else:
            out = c.execute(
                "select * from graph where c1={0} and c2={0}".format(index))
        E = 0
        N = len(self.get_cluster_reads(index))
        with open(path, 'w') as f:
            for i in out:
                print(self.vertexid2name[i[0]],
                      self.vertexid2name[i[1]],
                      i[2],
                      *i[4:13],
                      sep='\t',
                      file=f)
                E += 1
        return E

    def export_clusters_files_multiple(self,
                                       min_size,
                                       directory,
                                       sequences=None,
                                       tRNA_database_path=None,
                                       satellite_model_path=None):
        def load_fun(N, E):
            ''' estimate mem usage from graph size and density'''
            NE = math.log(float(N) * float(E), 10)
            if NE > 11.5:
                return 1
            if NE > 11:
                return 0.9
            if NE > 10:
                return 0.4
            if NE > 9:
                return 0.2
            if NE > 8:
                return 0.07
            return 0.02
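            # Illustrative example: N = 1e5 reads and E = 1e7 edges give
            # NE = log10(1e12) = 12 > 11.5, so load_fun returns 1 (its largest value).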

        def estimate_sample_size(NV, NE, maxv, maxe):
            ''' estimate suitable sampling based on the graph density
            NV, NE are |V| and |E| of the graph
            maxv, maxe are maximal |V| and |E|'''

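            # d    - observed density of the graph
            # eEst - expected number of edges among maxv vertices at density d
            # nEst - number of vertices at which the expected edge count reaches maxe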
            d = (2 * NE) / (NV * (NV - 1))
            eEst = (maxv * (maxv - 1) * d) / 2
            nEst = (d + math.sqrt(d**2 + 8 * d * maxe)) / (2 * d)
            if eEst >= maxe:
                N = int(nEst)
            if nEst >= maxv:
                N = int(maxv)
            return N

        clusterindex = 1
        cluster_input_args = []
        ppn = []
        # is this a comparative analysis?
        if sequences.prefix_length:
            self.conn.execute("CREATE TABLE comparative_counts (clusterindex INTEGER,"
                              + ", ".join(["[{}] INTEGER".format(i) for i in sequences.prefix_codes.keys()]) + ")")
            # do for comparative analysis

            for cl in range(self.number_of_clusters):
                prefix_codes = dict((key, 0) for key in sequences.prefix_codes.keys())
                for i in self.get_cluster_reads(cl):
                    prefix_codes[i[0:sequences.prefix_length]] += 1
                header = ", ".join(["[" + str(i) + "]" for i in prefix_codes.keys()])
                values = ", ".join([str(i) for i in prefix_codes.values()])
                self.conn.execute(
                    "INSERT INTO comparative_counts (clusterindex, {}) VALUES ({}, {})".format(
                        header, cl, values))
        else:
            prefix_codes = {}

        while True:
            read_names = self.get_cluster_reads(clusterindex)
            supercluster = self.get_cluster_supercluster(clusterindex)
            N = len(read_names)
            print("sequences.ids_kept -2 ")
            print(sequences.ids_kept)
            if sequences.ids_kept:
                N_adjusted = round(len(set(sequences.ids_kept).intersection(read_names)) *
                                   ((1 - config.FILTER_PROPORTION_OF_KEPT) /
                                    config.FILTER_PROPORTION_OF_KEPT) + N)
            else:
                N_adjusted = N
            if N < min_size:
                break
            else:
                LOGGER.info("exporting cluster {}".format(clusterindex))
                blast_file = "{dir}/dir_CL{i:04}/hitsort_part.csv".format(
                    dir=directory, i=clusterindex)
                cluster_dir = "{dir}/dir_CL{i:04}".format(dir=directory,
                                                          i=clusterindex)
                fasta_file = "{dir}/reads_selection.fasta".format(dir=cluster_dir)
                fasta_file_full = "{dir}/reads.fasta".format(dir=cluster_dir)

                os.makedirs(os.path.dirname(blast_file), exist_ok=True)
                E = self.extract_cluster_blast(index=clusterindex,
                                               path=blast_file)
                # check if blast must be sampled
                n_sample = estimate_sample_size(NV=N,
                                                NE=E,
                                                maxv=config.CLUSTER_VMAX,
                                                maxe=config.CLUSTER_EMAX)
                LOGGER.info("directories created..")
                if n_sample < N:
                    LOGGER.info(("cluster is too large - sampling.."
                                 "original size: {N}\n"
                                 "sample size: {NS}\n"
                                 "").format(N=N, NS=n_sample))
                    random.seed(self.seed)
                    read_names_sample = random.sample(read_names, n_sample)
                    LOGGER.info("reads id sampled...")
                    blast_file_sample = "{dir}/dir_CL{i:04}/blast_sample.csv".format(
                        dir=directory, i=clusterindex)
                    E_sample = self.extract_cluster_blast(
                        index=clusterindex,
                        path=blast_file,
                        ids=read_names_sample)
                    LOGGER.info("number of edges in sample: {}".format(
                        E_sample))
                    sequences.save2fasta(fasta_file, subset=read_names_sample)
                    sequences.save2fasta(fasta_file_full, subset=read_names)

                else:
                    read_names_sample = None
                    E_sample = None
                    blast_file_sample = None
                    n_sample = None
                    sequences.save2fasta(fasta_file_full, subset=read_names)
                    ## TODO - use symlink instead of:
                    sequences.save2fasta(fasta_file, subset=read_names)
                # export individual annotation tables:
                # annotation is always for the full cluster
                LOGGER.info("exporting cluster annotation")
                annotations = {}
                annotations_custom = {}
                for n in sequences.annotations:
                    print("sequences.annotations:", n)
                    if n.find("custom_db") == 0:
                        print("custom")
                        annotations_custom[n] = sequences.save_annotation(
                            annotation_name=n,
                            subset=read_names,
                            dir=cluster_dir)
                    else:
                        print("built in")
                        annotations[n] = sequences.save_annotation(
                            annotation_name=n,
                            subset=read_names,
                            dir=cluster_dir)

                cluster_input_args.append([
                    n_sample, N, N_adjusted, blast_file, fasta_file, fasta_file_full,
                    clusterindex, supercluster, self.paired,
                    tRNA_database_path, satellite_model_path, sequences.prefix_codes,
                    prefix_codes, annotations, annotations_custom
                ])
                clusterindex += 1
                ppn.append(load_fun(N, E))



        self.conn.commit()

        # run in parallel:
        # reorder jobs based on the ppn:
        cluster_input_args = [
            x
            for (y, x) in sorted(
                zip(ppn, cluster_input_args),
                key=lambda pair: pair[0],
                reverse=True)
        ]
        ppn = sorted(ppn, reverse=True)
        LOGGER.info("creating clusters in parallel")
        clusters_info = parallel(Cluster,
                                 *[list(i) for i in zip(*cluster_input_args)],
                                 ppn=ppn)
        # sort it back:
        clusters_info = sorted(clusters_info, key=lambda cl: cl.index)
        return clusters_info


class Cluster():
    ''' store and show information about cluster properties '''

    def __init__(self,
                 size,
                 size_real,
                 size_adjusted,
                 blast_file,
                 fasta_file,
                 fasta_file_full,
                 index,
                 supercluster,
                 paired,
                 tRNA_database_path,
                 satellite_model_path,
                 all_prefix_codes,
                 prefix_codes,
                 annotations,
                 annotations_custom={},
                 loop_index_threshold=0.7,
                 pair_completeness_threshold=0.40,
                 loop_index_unpaired_threshold=0.85):
        if size:
            # cluster was scaled down
            self.size = size
            self.size_real = size_real
        else:
            self.size = self.size_real = size_real
        self.size_adjusted = size_adjusted
        self.filtered = True if size_adjusted != size_real else False
        self.all_prefix_codes = all_prefix_codes.keys
        self.prefix_codes = prefix_codes
        self.dir = FilePath(os.path.dirname(blast_file))
        self.blast_file = FilePath(blast_file)
        self.fasta_file = FilePath(fasta_file)
        self.fasta_file_full = FilePath(fasta_file_full)
        self.index = index
        self.assembly_files = {}
        self.ltr_detection = None
        self.supercluster = supercluster
        self.annotations_files = annotations
        self.annotations_files_custom = annotations_custom
        self.annotations_summary, self.annotations_table = self._summarize_annotations(
            annotations, size_real)
        # add annotation
        if len(annotations_custom):
            self.annotations_summary_custom, self.annotations_custom_table = self._summarize_annotations(
                annotations_custom, size_real)
        else:
            self.annotations_summary_custom, self.annotations_custom_table = "", ""

        self.paired = paired
        self.graph_file = FilePath("{0}/graph_layout.GL".format(self.dir))
        self.directed_graph_file = FilePath(
            "{0}/graph_layout_directed.RData".format(self.dir))
        self.fasta_oriented_file = FilePath("{0}/reads_selection_oriented.fasta".format(
            self.dir))
        self.image_file = FilePath("{0}/graph_layout.png".format(self.dir))
        self.image_file_tmb = FilePath("{0}/graph_layout_tmb.png".format(self.dir))
        self.html_report_main = FilePath("{0}/index.html".format(self.dir))
        self.html_report_files = FilePath("{0}/html_files".format(self.dir))
        self.supercluster_best_hit = "NA"
        TAREAN = r2py.R(config.RSOURCE_tarean)
        LOGGER.info("creating graph no.{}".format(self.index))
        # FilePath must be converted to str for R functions
        graph_info = eval(
            TAREAN.mgblast2graph(
                self.blast_file,
                seqfile=self.fasta_file,
                seqfile_full=self.fasta_file_full,
                graph_destination=self.graph_file,
                directed_graph_destination=self.directed_graph_file,
                oriented_sequences=self.fasta_oriented_file,
                image_file=self.image_file,
                image_file_tmb=self.image_file_tmb,
                repex=True,
                paired=self.paired,
                satellite_model_path=satellite_model_path,
                maxv=config.CLUSTER_VMAX,
                maxe=config.CLUSTER_EMAX)
        )
        print(graph_info)
        self.ecount = graph_info['ecount']
        self.vcount = graph_info['vcount']
        self.loop_index = graph_info['loop_index']
        self.pair_completeness = graph_info['pair_completeness']
        self.orientation_score = graph_info['escore']
        self.satellite_probability = graph_info['satellite_probability']
        self.satellite = graph_info['satellite']
        # for paired reads:
        cond1 = (self.paired and self.loop_index > loop_index_threshold and
                 self.pair_completeness > pair_completeness_threshold)
        # no pairs
        cond2 = ((not self.paired) and
                 self.loop_index > loop_index_unpaired_threshold)
        if (cond1 or cond2) and config.ARGS.options.name != "oxford_nanopore":
            self.putative_tandem = True
            self.dir_tarean = FilePath("{}/tarean".format(self.dir))
            lock_file = self.dir + "../lock"
            out = eval(
                TAREAN.tarean(input_sequences=self.fasta_oriented_file,
                              output_dir=self.dir_tarean,
                              CPU=1,
                              reorient_reads=False,
                              tRNA_database_path=tRNA_database_path,
                              lock_file=lock_file)
            )
            self.html_tarean = FilePath(out['htmlfile'])
            self.tarean_contig_file = out['tarean_contig_file']
            self.TR_score = out['TR_score']
            self.TR_monomer_length = out['TR_monomer_length']
            self.TR_consensus = out['TR_consensus']
            self.pbs_score = out['pbs_score']
            self.max_ORF_length = out['orf_l']
            if (out['orf_l'] > config.ORF_THRESHOLD or
                    out['pbs_score'] > config.PBS_THRESHOLD):
                self.tandem_rank = 3
            elif self.satellite:
                self.tandem_rank = 1
            else:
                self.tandem_rank = 2
            # some tandems could be rDNA genes - this must be checked
            # by annotation
            if self.annotations_table:
                rdna_score = 0
                contamination_score = 0
                for i in self.annotations_table:
                    if 'rDNA/' in i[0]:
                        rdna_score += i[1]
                    if 'contamination' in i[0]:
                        contamination_score += i[1]
                if rdna_score > config.RDNA_THRESHOLD:
                    self.tandem_rank = 4
                if contamination_score > config.CONTAMINATION_THRESHOLD:
                    self.tandem_rank = 0  # other

            # by custom annotation - custom annotation has preference
            if self.annotations_custom_table:
                print("custom table searching")
                rdna_score = 0
                contamination_score = 0
                print(self.annotations_custom_table)
                for i in self.annotations_custom_table:
                    if 'rDNA' in i[0]:
                        rdna_score += i[1]
                    if 'contamination' in i[0]:
                        contamination_score += i[1]
                if rdna_score > 0:
                    self.tandem_rank = 4
                if contamination_score > config.CONTAMINATION_THRESHOLD:
                    self.tandem_rank = 0  # other

944 | |
945 else: | |
946 self.putative_tandem = False | |
947 self.dir_tarean = None | |
948 self.html_tarean = None | |
949 self.TR_score = None | |
950 self.TR_monomer_length = None | |
951 self.TR_consensus = None | |
952 self.pbs_score = None | |
953 self.max_ORF_length = None | |
954 self.tandem_rank = 0 | |
955 self.tarean_contig_file = None | |
956 | |
957 def __str__(self): | |
958 out = [ | |
959 "cluster no {}:".format(self.index), | |
960 "Number of vertices : {}".format(self.size), | |
961 "Number of edges : {}".format(self.ecount), | |
962 "Loop index : {}".format(self.loop_index), | |
963 "Pair completeness : {}".format(self.pair_completeness), | |
964 "Orientation score : {}".format(self.orientation_score) | |
965 ] | |
966 return "\n".join(out) | |
967 | |
968 def listing(self, asdict=True): | |
969 ''' convert attributes to dictionary for printing purposes''' | |
970 out = {} | |
971 for i in dir(self): | |
972 # do not show private | |
973 if i[:2] != "__": | |
974 value = getattr(self, i) | |
975 if not callable(value): | |
976 # for dictionary | |
977 if isinstance(value, dict): | |
978 for k in value: | |
979 out[i + "_" + k] = value[k] | |
980 else: | |
981 out[i] = value | |
982 if asdict: | |
983 return out | |
984 else: | |
985 return {'keys': list(out.keys()), 'values': list(out.values())} | |
986 | |
    def detect_ltr(self, trna_database):
        '''detection of ltr in assembly files, output of analysis is stored in file'''
        CREATE_ANNOTATION = r2py.R(config.RSOURCE_create_annotation, verbose=False)
        if self.assembly_files['{}.{}.ace']:
            ace_file = self.assembly_files['{}.{}.ace']
            print(ace_file, "running LTR detection")
            fout = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['BASE'])
            subprocess.check_call([
                config.LTR_DETECTION,
                '-i', ace_file,
                '-o', fout,
                '-p', trna_database])
            # evaluate LTR presence
            fn = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['PBS_BLAST'])
            self.ltr_detection = CREATE_ANNOTATION.evaluate_LTR_detection(fn)


    @staticmethod
    def _summarize_annotations(annotations_files, size):
        ''' will tabulate annotation results '''
        # TODO
        summaries = {}
        # weight is in percentage
        weight = 100 / size
        for i in annotations_files:
            with open(annotations_files[i]) as f:
                header = f.readline().split()
                id_index = [
                    i for i, item in enumerate(header) if item == "db_id"
                ][0]
                for line in f:
                    classification = line.split()[id_index].split("#")[1]
                    if classification in summaries:
                        summaries[classification] += weight
                    else:
                        summaries[classification] = weight
        # format summaries for printing
        annotation_string = ""
        annotation_table = []
        for i in sorted(summaries.items(), key=lambda x: x[1], reverse=True):
            ## hits with smaller proportion are not shown!
            if i[1] > 0.1:
                if i[1] > 1:
                    annotation_string += "<b>{1:.2f}% {0}</b>\n".format(*i)
                else:
                    annotation_string += "{1:.2f}% {0}\n".format(*i)
                annotation_table.append(i)
        return [annotation_string, annotation_table]

    @staticmethod
    def add_cluster_table_to_database(cluster_table, db_path):
        '''get column names from Cluster object and create
        corresponding table in database; values from all
        clusters are filled into the database'''
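        # cluster_table is expected to be a sequence of per-cluster dictionaries
        # (e.g. produced by Cluster.listing()); column types are inferred from
        # the values of cluster_table[1].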
        column_name_and_type = []
        column_list = []

        # get all attribute names -> they are column names
        # in sqlite table, detect proper sqlite type
        def identity(x):
            return (x)

        for i in cluster_table[1]:
            t = type(cluster_table[1][i])
            if t == int:
                sqltype = "integer"
                convert = identity
            elif t == float:
                sqltype = "real"
                convert = identity
            elif t == bool:
                sqltype = "boolean"
                convert = bool
            else:
                sqltype = "text"
                convert = str
            column_name_and_type += ["[{}] {}".format(i, sqltype)]
            column_list += [tuple((i, convert))]
        header = ", ".join(column_name_and_type)
        db = sqlite3.connect(db_path)
        c = db.cursor()
        print("CREATE TABLE cluster_info ({})".format(header))
        c.execute("CREATE TABLE cluster_info ({})".format(header))
        # fill data from cluster_table into the cluster_info table
        buffer = []
        for i in cluster_table:
            buffer.append(tuple('{}'.format(fun(i[j])) for j, fun in
                                column_list))
        wildcards = ",".join(["?"] * len(column_list))
        print(buffer)
        c.executemany("insert into cluster_info values ({})".format(wildcards),
                      buffer)
        db.commit()