view lib/graphtools.py @ 0:1d1b9e1b2e2f draft

Uploaded
author petr-novak
date Thu, 19 Dec 2019 10:24:45 -0500

#!/usr/bin/env python3
'''
This module is mainly for large graph (i.e. hitsort) storage, parsing and clustering
'''
import os
import sys
import sqlite3
import time
import subprocess
import logging
from collections import defaultdict
import collections
import operator
import math
import random
import itertools
import config
from lib import r2py
from lib.utils import FilePath
from lib.parallel.parallel import parallel2 as parallel
REQUIRED_VERSION = (3, 4)
MAX_BUFFER_SIZE = 100000
if sys.version_info < REQUIRED_VERSION:
    raise Exception("\n\npython 3.4 or higher is required!\n")
LOGGER = logging.getLogger(__name__)


def dfs(start, graph):
    """
    helper function for cluster merging.
    Does depth-first search, returning a set of all nodes seen.
    Takes: a graph in node --> [neighbors] form.
    """
    visited, worklist = set(), [start]

    while worklist:
        node = worklist.pop()
        if node not in visited:
            visited.add(node)
            # Add all the neighbors to the worklist.
            worklist.extend(graph[node])

    return visited
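# Illustrative example (not part of the pipeline): for the adjacency mapping
# {1: [2], 2: [1, 3], 3: [2]}, dfs(1, graph) returns {1, 2, 3}.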


def graph_components(edges):
    """
    Given a graph as a list of edges, divide the nodes into components.
    Takes a list of pairs of nodes, where the nodes are integers.
    """

    # Construct a graph (mapping node --> [neighbors]) from the edges.
    graph = defaultdict(list)
    nodes = set()

    for v1, v2 in edges:
        nodes.add(v1)
        nodes.add(v2)

        graph[v1].append(v2)
        graph[v2].append(v1)

    # Traverse the graph to find the components.
    components = []

    # We don't care what order we see the nodes in.
    while nodes:
        component = dfs(nodes.pop(), graph)
        components.append(component)

        # Remove this component from the nodes under consideration.
        nodes -= component

    return components
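# Illustrative example (not part of the pipeline):
#     graph_components([(1, 2), (2, 3), (4, 5)])
# yields the two components {1, 2, 3} and {4, 5}; the order of the returned
# list depends on set iteration order.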


class Graph():
    '''
    create Graph object stored in an sqlite database, either in memory or on disk
    structure of the table is:
    V1 V2 weight12
    V2 V3 weight23
    V4 V5 weight45
    ...
    ...
    !! this is an undirected simple graph - duplicated edges must
    be removed on graph creation

    '''
    # seed for random number generator - this is necessary for reproducibility between runs
    seed = '123'

    def __init__(self,
                 source=None,
                 filename=None,
                 new=False,
                 paired=True,
                 seqids=None):
        '''
        filename : file where the database is stored; if not defined it is stored in memory
        source : ncol file describing the graph
        new : if False and source is not defined, the graph can be loaded from the database (filename)

        vertices_name must be in the correct order!!!
        '''

        self.filename = filename
        self.source = source
        self.paired = paired
        # path to indexed graph - will be set later
        self.indexed_file = None
        self._cluster_list = None
        # these two attributes are set after clustering
        # communities before merging
        self.graph_2_community0 = None
        # communities after merging
        self.graph_2_community = None
        self.number_of_clusters = None
        self.binary_file = None
        self.cluster_sizes = None
        self.graph_tree = None
        self.graph_tree_log = None
        self.weights_file = None

        if filename:
            if os.path.isfile(filename) and (new or source):
                os.remove(filename)
            self.conn = sqlite3.connect(filename)
        else:
            self.conn = sqlite3.connect(":memory:")
        c = self.conn.cursor()

        c.execute("PRAGMA page_size=8192")
        c.execute("PRAGMA cache_size = 2000000 ")  # this helps

        try:
            c.execute((
                "create table graph (v1 integer, v2 integer, weight integer, "
                "pair integer, v1length integer, v1start  integer, v1end  integer, "
                "v2length  integer, v2start  integer, v2end  integer, pid  integer,"
                "evalue real, strand text )"))
        except sqlite3.OperationalError:
            pass  # table already exists
        else:
            c.execute(
                "create table vertices (vertexname text primary key, vertexindex integer)")
        tables = sorted(c.execute(
            "SELECT name FROM sqlite_master WHERE type='table'").fetchall())

        if not [('graph', ), ('vertices', )] == tables:
            raise Exception("tables for sqlite for graph are not correct")

        if source:
            self._read_from_hitsort()

        if paired and seqids:
            # vertices must be defined - create graph of paired reads:
            # last character must distinguish the pair
            c.execute((
                "create table pairs (basename, vertexname1, vertexname2,"
                "v1 integer, v2 integer, cluster1 integer, cluster2 integer)"))
            buffer = []
            for i, k in zip(seqids[0::2], seqids[1::2]):
                assert i[:-1] == k[:-1], "problem with paired read ids"
                # some vertices are not in graph - singletons
                try:
                    index1 = self.vertices[i]
                except KeyError:
                    index1 = -1

                try:
                    index2 = self.vertices[k]
                except KeyError:
                    index2 = -1

                buffer.append((i[:-1], i, k, index1, index2))

            self.conn.executemany(
                "insert into pairs (basename, vertexname1, vertexname2, v1, v2) values (?,?,?,?,?)",
                buffer)
            self.conn.commit()

    def _read_from_hitsort(self):

        c = self.conn.cursor()
        c.execute("delete from graph")
        buffer = []
        vertices = {}
        counter = 0
        v_count = 0
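        # Note (layout inferred from the insert statement below): each hitsort line
        # is expected to hold whitespace-separated fields in this order:
        #   read1 read2 weight v1length v1start v1end v2length v2start v2end pid evalue strand
        # read1/read2 are translated to integer vertex indices and a computed 'pair'
        # flag is added, giving the 13 columns of the graph table.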
        with open(self.source, 'r') as f:
            for i in f:
                edge_index = {}
                items = i.split()
                # get or insert vertex index
                for vn in items[0:2]:
                    if vn not in vertices:
                        vertices[vn] = v_count
                        edge_index[vn] = v_count
                        v_count += 1
                    else:
                        edge_index[vn] = vertices[vn]
                if self.paired:
                    pair = int(items[0][:-1] == items[1][:-1])
                else:
                    pair = 0
                buffer.append(((edge_index[items[0]], edge_index[items[1]],
                                items[2], pair) + tuple(items[3:])))
                if len(buffer) == MAX_BUFFER_SIZE:
                    counter += 1
                    self.conn.executemany(
                        "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)",
                        buffer)
                    buffer = []
        if buffer:
            self.conn.executemany(
                "insert or ignore into graph values (?,?,?,?,?,?,?,?,?,?,?,?,?)",
                buffer)

        self.conn.commit()
        self.vertices = vertices
        self.vertexid2name = {
            index: name
            for name, index in vertices.items()
        }
        self.vcount = len(vertices)
        c = self.conn.cursor()
        c.execute("select count(*) from graph")
        self.ecount = c.fetchone()[0]
        # fill table of vertices
        self.conn.executemany("insert into vertices values (?,?)",
                              vertices.items())
        self.conn.commit()

    def save_indexed_graph(self, file=None):
        if not file:
            self.indexed_file = "{}.int".format(self.source)
        else:
            self.indexed_file = file
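        # the .int file written below is the plain three-column (v1 v2 weight)
        # edge list that louvain_convert consumes in louvain_clustering()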
        c = self.conn.cursor()
        with open(self.indexed_file, 'w') as f:
            out = c.execute('select v1,v2,weight from graph')
            for v1, v2, weight in out:
                f.write('{}\t{}\t{}\n'.format(v1, v2, weight))

    def get_subgraph(self, vertices):
        pass

    def _levels(self):
        with open(self.graph_tree_log, 'r') as f:
            levels = -1
            for i in f:
                if i[:5] == 'level':
                    levels += 1
        return levels

    def _reindex_community(self, id2com):
        '''
        reindex communities and superclusters so that the biggest cluster is no. 1
        '''
        self.conn.commit()
        _, community, supercluster = zip(*id2com)
        (cluster_index, frq, self.cluster_sizes,
         self.number_of_clusters) = self._get_index_and_frequency(community)

        supercluster_index, sc_frq, _, _ = self._get_index_and_frequency(
            supercluster)
        id2com_reindexed = []

        for i, _ in enumerate(id2com):
            id2com_reindexed.append((id2com[i][0], id2com[i][1], frq[
                i], cluster_index[i], supercluster_index[i], sc_frq[i]))
        return id2com_reindexed

    @staticmethod
    def _get_index_and_frequency(membership):
        frequency_table = collections.Counter(membership)
        frequency_table_sorted = sorted(frequency_table.items(),
                                        key=operator.itemgetter(1),
                                        reverse=True)
        frq = []
        for i in membership:
            frq.append(frequency_table[i])
        rank = {}
        index = 0
        for comm, _ in frequency_table_sorted:
            index += 1
            rank[comm] = index
        cluster_index = [rank[i] for i in membership]
        cluster_sizes = [i[1] for i in frequency_table_sorted]
        number_of_clusters = len(frequency_table)
        return [cluster_index, frq, cluster_sizes, number_of_clusters]
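    # Illustrative example (assumed input): for membership = [7, 7, 7, 3, 3, 9]
    # the largest community (7) gets rank 1, so the method returns
    #   cluster_index      = [1, 1, 1, 2, 2, 3]
    #   frq                = [3, 3, 3, 2, 2, 1]
    #   cluster_sizes      = [3, 2, 1]
    #   number_of_clusters = 3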

    def louvain_clustering(self, merge_threshold=0, cleanup=False):
        '''
        input - graph
        output - list of clusters
        executables path ??
        '''
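        # The clustering below shells out to the external louvain tools in three
        # steps: louvain_convert (indexed edge list -> binary graph + weights),
        # louvain_community (hierarchical clustering, written to graph_tree),
        # louvain_hierarchy (extract the communities of the last level found
        # in the clustering log).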
        LOGGER.info("converting hitsort to binary format")
        self.binary_file = "{}.bin".format(self.indexed_file)
        self.weights_file = "{}.weight".format(self.indexed_file)
        self.graph_tree = "{}.graph_tree".format(self.indexed_file)
        self.graph_tree_log = "{}.graph_tree_log".format(self.indexed_file)
        self.graph_2_community0 = "{}.graph_2_community0".format(
            self.indexed_file)
        self._cluster_list = None
        self.graph_2_community = "{}.graph_2_community".format(
            self.indexed_file)
        print(["louvain_convert", "-i", self.indexed_file, "-o",
               self.binary_file, "-w", self.weights_file])
        subprocess.check_call(
            ["louvain_convert", "-i", self.indexed_file, "-o",
             self.binary_file, "-w", self.weights_file],
            timeout=None)

        gt = open(self.graph_tree, 'w')
        gtl = open(self.graph_tree_log, 'w')
        LOGGER.info("running louvain clustering...")
        subprocess.check_call(
            ["louvain_community", self.binary_file, "-l", "-1", "-w",
             self.weights_file, "-v ", "-s", self.seed],
            stdout=gt,
            stderr=gtl,
            timeout=None)
        gt.close()
        gtl.close()

        LOGGER.info("creating list of communities")
        gt2c = open(self.graph_2_community0, 'w')
        subprocess.check_call(
            ['louvain_hierarchy', self.graph_tree, "-l", str(self._levels())],
            stdout=gt2c)
        gt2c.close()
        if merge_threshold and self.paired:
            com2newcom = self.find_superclusters(merge_threshold)
        elif self.paired:
            com2newcom = self.find_superclusters(config.SUPERCLUSTER_THRESHOLD)
        else:
            com2newcom = {}
        # merging of clusters, creating superclusters
        LOGGER.info("merging clusters based on mate pairs")
        # modify self.graph_2_community file
        # rewrite graph2community
        with open(self.graph_2_community0, 'r') as fin:
            with open(self.graph_2_community, 'w') as fout:
                for i in fin:
                    # write graph 2 community file in the format:
                    # id communityid superclusterid
                    # if merging - community and supercluster are identical
                    vi, com = i.split()
                    if merge_threshold:
                        ## merging
                        if int(com) in com2newcom:
                            fout.write("{} {} {}\n".format(vi, com2newcom[int(
                                com)], com2newcom[int(com)]))
                        else:
                            fout.write("{} {} {}\n".format(vi, com, com))
                    else:
                        ## superclusters
                        if int(com) in com2newcom:
                            fout.write("{} {} {}\n".format(vi, com, com2newcom[
                                int(com)]))
                        else:
                            fout.write("{} {} {}\n".format(vi, com, com))

        LOGGER.info("loading communities into database")
        c = self.conn.cursor()
        c.execute(("create table communities (vertexindex integer primary key,"
                   "community integer, size integer, cluster integer, "
                   "supercluster integer, supercluster_size integer)"))
        id2com = []
        with open(self.graph_2_community, 'r') as f:
            for i in f:
                name, com, supercluster = i.split()
                id2com.append((name, com, supercluster))
        id2com_reindexed = self._reindex_community(id2com)
        c.executemany("insert into communities values (?,?,?,?,?,?)",
                      id2com_reindexed)
        #create table of superclusters - clusters
        c.execute(("create table superclusters as "
                   "select distinct supercluster, supercluster_size, "
                   "cluster, size from communities;"))
        # create view id-index-cluster
        c.execute(
            ("CREATE VIEW vertex_cluster AS SELECT vertices.vertexname,"
             "vertices.vertexindex, communities.cluster, communities.size"
             " FROM vertices JOIN communities USING (vertexindex)"))
        self.conn.commit()

        # add clustering info to graph
        LOGGER.info("updating graph table")
        t0 = time.time()

        c.execute("alter table graph add c1 integer")
        c.execute("alter table graph add c2 integer")
        c.execute(("update graph set c1 = (select cluster  FROM communities "
                   "where communities.vertexindex=graph.v1)"))
        c.execute(
            ("update graph set c2 = (select cluster  FROM communities where "
             "communities.vertexindex=graph.v2)"))
        self.conn.commit()
        t1 = time.time()
        LOGGER.info("updating graph table - done in {} seconds".format(t1 -
                                                                       t0))

        # identify similarity  connections between clusters
        c.execute(
            "create table cluster_connections as SELECT c1,c2 , count(*) FROM (SELECT c1, c2  FROM graph WHERE c1>c2 UNION ALL SELECT c2 as c1, c1  as c2 FROM graph WHERE c2>c1) GROUP BY c1, c2")
        # TODO - remove directionality - summarize -
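        # the UNION ALL above reorders every inter-cluster edge so that c1 > c2,
        # hence the GROUP BY yields one edge count per unordered cluster pair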

        # add cluster identity to pairs table

        if self.paired:
            LOGGER.info("analyzing pairs ")
            t0 = time.time()
            c.execute(
                "UPDATE pairs SET cluster1=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v1)")
            t1 = time.time()
            LOGGER.info(
                "updating pairs table - cluster1 - done in {} seconds".format(
                    t1 - t0))

            t0 = time.time()
            c.execute(
                "UPDATE pairs SET cluster2=(SELECT cluster FROM communities WHERE communities.vertexindex=pairs.v2)")
            t1 = time.time()
            LOGGER.info(
                "updating pairs table - cluster2 - done in {} seconds".format(
                    t1 - t0))
            # reorder records

            t0 = time.time()
            c.execute(
                "UPDATE pairs SET cluster1=cluster2, cluster2=cluster1, vertexname1=vertexname2,vertexname2=vertexname1 where cluster1<cluster2")
            t1 = time.time()
            LOGGER.info("sorting - done in {} seconds".format(t1 - t0))

            t0 = time.time()
            c.execute(
                "create table cluster_mate_connections as select cluster1 as c1, cluster2 as c2, count(*) as N, group_concat(basename) as ids from pairs where cluster1!=cluster2 group by cluster1, cluster2;")
            t1 = time.time()
            LOGGER.info(
                "creating cluster_mate_connections table - done in {} seconds".format(
                    t1 - t0))
            # summarize
            t0 = time.time()
            self._calculate_pair_bond()
            t1 = time.time()
            LOGGER.info(
                "calculating cluster pair bond - done in {} seconds".format(
                    t1 - t0))
            t0 = time.time()
        else:
            # not paired - create empty tables
            self._add_empty_tables()

        self.conn.commit()
        t1 = time.time()
        LOGGER.info("committing changes - done in {} seconds".format(t1 - t0))

        if cleanup:
            LOGGER.info("cleaning clustering temp files")
            os.unlink(self.binary_file)
            os.unlink(self.weights_file)
            os.unlink(self.graph_tree)
            os.unlink(self.graph_tree_log)
            os.unlink(self.graph_2_community0)
            os.unlink(self.graph_2_community)
            os.unlink(self.indexed_file)
            self.binary_file = None
            self.weights_file = None
            self.graph_tree = None
            self.graph_tree_log = None
            self.graph_2_community0 = None
            self.graph_2_community = None
            self.indexed_file = None

        # calculate k

    def find_superclusters(self, merge_threshold):
        '''Find superclusters from clustering based on paired reads '''
        clsdict = {}
        with open(self.graph_2_community0, 'r') as f:
            for i in f:
                vi, com = i.split()
                if com in clsdict:
                    clsdict[com] += [self.vertexid2name[int(vi)][0:-1]]
                else:
                    clsdict[com] = [self.vertexid2name[int(vi)][0:-1]]
        # remove all small clusters - these will not be merged:
        small_cls = []
        for i in clsdict:
            if len(clsdict[i]) < config.MINIMUM_NUMBER_OF_READS_FOR_MERGING:
                small_cls.append(i)
        for i in small_cls:
            del clsdict[i]
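        # For every remaining cluster pair the loop below computes a mate-pair bond
        # k = 2*wgh / (n1 + n2): wgh is the number of read basenames shared by the
        # two clusters and n1, n2 count basenames with only one mate present in the
        # respective cluster (len(set) * 2 - number_of_reads). Pairs with k above
        # merge_threshold are connected and later collapsed into one cluster.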
        pairs = []
        for i, j in itertools.combinations(clsdict, 2):
            s1 = set(clsdict[i])
            s2 = set(clsdict[j])
            wgh = len(s1 & s2)
            if wgh < config.MINIMUM_NUMBER_OF_SHARED_PAIRS_FOR_MERGING:
                continue
            else:
                n1 = len(s1) * 2 - len(clsdict[i])
                n2 = len(s2) * 2 - len(clsdict[j])
                k = 2 * wgh / (n1 + n2)
            if k > merge_threshold:
                pairs.append((int(i), int(j)))
        # find connected components - these will be merged
        cls2merge = graph_components(pairs)
        com2newcom = {}
        for i in cls2merge:
            newcom = min(i)
            for j in i:
                com2newcom[j] = newcom
        return com2newcom

    def adjust_cluster_size(self, proportion_kept, ids_kept):
        LOGGER.info("adjusting cluster sizes")
        c = self.conn.cursor()
        c.execute("ALTER TABLE superclusters ADD COLUMN size_uncorrected INTEGER")
        c.execute("UPDATE superclusters SET size_uncorrected=size")
        if ids_kept:
            ids_kept_set = set(ids_kept)
            ratio = (1 - proportion_kept)/proportion_kept
            for cl, size in c.execute("SELECT cluster,size FROM superclusters"):
                ids = self.get_cluster_reads(cl)
                ovl_size = len(ids_kept_set.intersection(ids))
                size_adjusted = int(len(ids) + ovl_size * ratio)
                if size_adjusted > size:
                    c.execute("UPDATE superclusters SET size=? WHERE cluster=?",
                              (size_adjusted, cl))
        self.conn.commit()
        LOGGER.info("adjusting cluster sizes - done")

    def export_cls(self, path):
        with open(path, 'w') as f:
            for i in range(1, self.number_of_clusters + 1):
                ids = self.get_cluster_reads(i)
                f.write(">CL{}\t{}\n".format(i, len(ids)))
                f.write("\t".join(ids))
                f.write("\n")

    def _calculate_pair_bond(self):
        c = self.conn.cursor()
        out = c.execute("select c1, c2, ids from cluster_mate_connections")
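        # same bond statistic as in find_superclusters: w = read pairs split between
        # clusters c1 and c2, n1/n2 = broken pairs (single mates) in each cluster,
        # k = 2*w / (n1 + n2)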
        buffer = []
        for c1, c2, ids in out:
            w = len(set(ids.split(",")))
            n1 = len(set([i[:-1] for i in self.get_cluster_reads(c1)
                          ])) * 2 - len(self.get_cluster_reads(c1))
            n2 = len(set([i[:-1] for i in self.get_cluster_reads(c2)
                          ])) * 2 - len(self.get_cluster_reads(c2))
            buffer.append((c1, c2, n1, n2, w, 2 * w / (n1 + n2)))
        c.execute(
            "CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)")
        c.executemany(" INSERT INTO cluster_mate_bond values (?,?,?,?,?,?)",
                      buffer)

    def _add_empty_tables(self):
        '''This is used with reads that are not paired
        - it creates empty mate tables, this is necessary for
        subsequent reporting to work correctly '''
        c = self.conn.cursor()
        c.execute(("CREATE TABLE cluster_mate_bond (c1 INTEGER, c2 INTEGER, "
                   "n1 INTEGER, n2 INTEGER, w INTEGER, k FLOAT)"))
        c.execute(
            "CREATE TABLE cluster_mate_connections (c1 INTEGER, c2 INTEGER, N INTEGER, ids TEXT) ")

    def get_cluster_supercluster(self, cluster):
        '''Get supercluster id for the supplied cluster '''
        c = self.conn.cursor()
        out = c.execute(
            'SELECT supercluster FROM communities WHERE cluster="{0}" LIMIT 1'.format(
                cluster))
        sc = out.fetchone()[0]
        return sc

    def get_cluster_reads(self, cluster):

        if self._cluster_list:
            return self._cluster_list[str(cluster)]
        else:
            # if queried first time
            c = self.conn.cursor()
            out = c.execute("select cluster, vertexname from vertex_cluster")
            cluster_list = collections.defaultdict(list)
            for clusterindex, vertexname in out:
                cluster_list[str(clusterindex)].append(vertexname)
            self._cluster_list = cluster_list
            return self._cluster_list[str(cluster)]

        
    def extract_cluster_blast(self, path, index, ids=None):
        ''' Extract blast for cluster and save it to path
        return number of blast lines (i.e. number of graph edges E)
        if ids is specified, only a subset of the blast is used'''
        c = self.conn.cursor()
        if ids:
            vertexindex = (
                "select vertexindex from vertices "
                "where vertexname in ({})").format('"' + '","'.join(ids) + '"')

            out = c.execute(("select * from graph where c1={0} and c2={0}"
                             " and v1 in ({1}) and v2 in ({1})").format(
                                 index, vertexindex))
        else:
            out = c.execute(
                "select * from graph where c1={0} and c2={0}".format(index))
        E = 0
        N = len(self.get_cluster_reads(index))
        with open(path, 'w') as f:
            for i in out:
                print(self.vertexid2name[i[0]],
                      self.vertexid2name[
                          i[1]],
                      i[2],
                      *i[4:13],
                      sep='\t',
                      file=f)
                E += 1
        return E

    def export_clusters_files_multiple(self,
                                       min_size,
                                       directory,
                                       sequences=None,
                                       tRNA_database_path=None,
                                       satellite_model_path=None):
        def load_fun(N, E):
            ''' estimate mem usage from graph size and density'''
            NE = math.log(float(N) * float(E), 10)
            if NE > 11.5:
                return 1
            if NE > 11:
                return 0.9
            if NE > 10:
                return 0.4
            if NE > 9:
                return 0.2
            if NE > 8:
                return 0.07
            return 0.02
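        # load_fun maps log10(N*E) to an estimated per-job load; the values
        # collected in ppn below are used to order and schedule the parallel
        # Cluster jobs at the end of this method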

        def estimate_sample_size(NV, NE, maxv, maxe):
            ''' estimate a suitable sample size based on the graph density
            NV, NE are |V| and |E| of the graph
            maxv, maxe are the maximal |V| and |E|'''

            d = (2 * NE) / (NV * (NV - 1))
            eEst = (maxv * (maxv - 1) * d) / 2
            nEst = (d + math.sqrt(d**2 + 8 * d * maxe)) / (2 * d)
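            # derivation sketch: a sample of n vertices keeps about n*(n-1)*d/2
            # edges; setting this equal to maxe gives d*n^2 - d*n - 2*maxe = 0,
            # whose positive root is nEst = (d + sqrt(d^2 + 8*d*maxe)) / (2*d)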
            if eEst >= maxe:
                N = int(nEst)
            if nEst >= maxv:
                N = int(maxv)
            return N

        clusterindex = 1
        cluster_input_args = []
        ppn = []
        # is it a comparative analysis?
        if sequences.prefix_length:
            # do for comparative analysis
            self.conn.execute(
                "CREATE TABLE comparative_counts (clusterindex INTEGER,"
                + ", ".join(["[{}] INTEGER".format(i)
                             for i in sequences.prefix_codes.keys()]) + ")")

            # clusters are numbered from 1 (see _get_index_and_frequency)
            for cl in range(1, self.number_of_clusters + 1):
                prefix_codes = dict((key, 0) for key in sequences.prefix_codes.keys())
                for i in self.get_cluster_reads(cl):
                    prefix_codes[i[0:sequences.prefix_length]] += 1
                header = ", ".join(["[" + str(i) + "]" for i in prefix_codes.keys()])
                values = ", ".join([str(i) for i in prefix_codes.values()])
                self.conn.execute(
                    "INSERT INTO comparative_counts (clusterindex, {}) VALUES ({}, {})".format(
                        header, cl, values))
        else:
            prefix_codes = {}

        while True:
            read_names = self.get_cluster_reads(clusterindex)
            supercluster = self.get_cluster_supercluster(clusterindex)
            N = len(read_names)
            print("sequences.ids_kept -2 ")
            print(sequences.ids_kept)
            if sequences.ids_kept:
                N_adjusted = round(len(set(sequences.ids_kept).intersection(read_names)) *
                                   ((1 - config.FILTER_PROPORTION_OF_KEPT) /
                                    config.FILTER_PROPORTION_OF_KEPT) + N)
            else:
                N_adjusted = N
            if N < min_size:
                break
            else:
                LOGGER.info("exporting cluster {}".format(clusterindex))
                blast_file = "{dir}/dir_CL{i:04}/hitsort_part.csv".format(
                    dir=directory, i=clusterindex)
                cluster_dir = "{dir}/dir_CL{i:04}".format(dir=directory,
                                                          i=clusterindex)
                fasta_file = "{dir}/reads_selection.fasta".format(dir=cluster_dir)
                fasta_file_full = "{dir}/reads.fasta".format(dir=cluster_dir)

                os.makedirs(os.path.dirname(blast_file), exist_ok=True)
                E = self.extract_cluster_blast(index=clusterindex,
                                               path=blast_file)
                # check if blast must be sampled
                n_sample = estimate_sample_size(NV=N,
                                                NE=E,
                                                maxv=config.CLUSTER_VMAX,
                                                maxe=config.CLUSTER_EMAX)
                LOGGER.info("directories created..")
                if n_sample < N:
                    LOGGER.info(("cluster is too large - sampling.."
                                 "original size: {N}\n"
                                 "sample size: {NS}\n"
                                 "").format(N=N, NS=n_sample))
                    random.seed(self.seed)
                    read_names_sample = random.sample(read_names, n_sample)
                    LOGGER.info("reads id sampled...")
                    blast_file_sample = "{dir}/dir_CL{i:04}/blast_sample.csv".format(
                        dir=directory, i=clusterindex)
                    E_sample = self.extract_cluster_blast(
                        index=clusterindex,
                        path=blast_file,
                        ids=read_names_sample)
                    LOGGER.info("number of edges in sample: {}".format(
                        E_sample))
                    sequences.save2fasta(fasta_file, subset=read_names_sample)
                    sequences.save2fasta(fasta_file_full, subset=read_names)

                else:
                    read_names_sample = None
                    E_sample = None
                    blast_file_sample = None
                    n_sample = None
                    sequences.save2fasta(fasta_file_full, subset=read_names)
                    ## TODO - use a symlink instead of saving twice
                    sequences.save2fasta(fasta_file, subset=read_names)
                # export individual annotations tables:
                # annotation is always for full cluster
                LOGGER.info("exporting cluster annotation")
                annotations = {}
                annotations_custom = {}
                for n in sequences.annotations:
                    print("sequences.annotations:", n)
                    if n.find("custom_db") == 0:
                        print("custom")
                        annotations_custom[n] = sequences.save_annotation(
                            annotation_name=n,
                            subset=read_names,
                            dir=cluster_dir)
                    else:
                        print("built in")
                        annotations[n] = sequences.save_annotation(
                            annotation_name=n,
                            subset=read_names,
                            dir=cluster_dir)

                cluster_input_args.append([
                    n_sample, N, N_adjusted, blast_file, fasta_file, fasta_file_full,
                    clusterindex, supercluster, self.paired,
                    tRNA_database_path, satellite_model_path, sequences.prefix_codes,
                    prefix_codes, annotations, annotations_custom
                ])
                clusterindex += 1
                ppn.append(load_fun(N, E))

                    

        self.conn.commit()

        # run in parallel:
        # reorder jobs based on the ppn:
        cluster_input_args = [
            x
            for (y, x) in sorted(
                zip(ppn, cluster_input_args),
                key=lambda pair: pair[0],
                reverse=True)
        ]
        ppn = sorted(ppn, reverse=True)
        LOGGER.info("creating clusters in parallel")
        clusters_info = parallel(Cluster,
                                 *[list(i) for i in zip(*cluster_input_args)],
                                 ppn=ppn)
        # sort it back:
        clusters_info = sorted(clusters_info, key=lambda cl: cl.index)
        return clusters_info


class Cluster():
    ''' store and show information about cluster properties '''

    def __init__(self,
                 size,
                 size_real,
                 size_adjusted,
                 blast_file,
                 fasta_file,
                 fasta_file_full,
                 index,
                 supercluster,
                 paired,
                 tRNA_database_path,
                 satellite_model_path,
                 all_prefix_codes,
                 prefix_codes,
                 annotations,
                 annotations_custom={},
                 loop_index_threshold=0.7,
                 pair_completeness_threshold=0.40,
                 loop_index_unpaired_threshold=0.85):
        if size:
            # cluster was scaled down
            self.size = size
            self.size_real = size_real
        else:
            self.size = self.size_real = size_real
        self.size_adjusted = size_adjusted
        self.filtered = size_adjusted != size_real
        self.all_prefix_codes = all_prefix_codes.keys
        self.prefix_codes = prefix_codes
        self.dir = FilePath(os.path.dirname(blast_file))
        self.blast_file = FilePath(blast_file)
        self.fasta_file = FilePath(fasta_file)
        self.fasta_file_full = FilePath(fasta_file_full)
        self.index = index
        self.assembly_files = {}
        self.ltr_detection = None
        self.supercluster = supercluster
        self.annotations_files = annotations
        self.annotations_files_custom = annotations_custom
        self.annotations_summary, self.annotations_table = self._summarize_annotations(
            annotations, size_real)
        # add annotation
        if len(annotations_custom):
            self.annotations_summary_custom, self.annotations_custom_table = self._summarize_annotations(
                annotations_custom, size_real)
        else:
            self.annotations_summary_custom, self.annotations_custom_table = "", ""

        self.paired = paired
        self.graph_file = FilePath("{0}/graph_layout.GL".format(self.dir))
        self.directed_graph_file = FilePath(
            "{0}/graph_layout_directed.RData".format(self.dir))
        self.fasta_oriented_file = FilePath("{0}/reads_selection_oriented.fasta".format(
            self.dir))
        self.image_file = FilePath("{0}/graph_layout.png".format(self.dir))
        self.image_file_tmb = FilePath("{0}/graph_layout_tmb.png".format(self.dir))
        self.html_report_main = FilePath("{0}/index.html".format(self.dir))
        self.html_report_files = FilePath("{0}/html_files".format(self.dir))
        self.supercluster_best_hit = "NA"
        TAREAN = r2py.R(config.RSOURCE_tarean)
        LOGGER.info("creating graph no.{}".format(self.index))
        # FilePath must be converted to str for the R functions
        graph_info = eval(
            TAREAN.mgblast2graph(
                self.blast_file,
                seqfile=self.fasta_file,
                seqfile_full=self.fasta_file_full,
                graph_destination=self.graph_file,
                directed_graph_destination=self.directed_graph_file,
                oriented_sequences=self.fasta_oriented_file,
                image_file=self.image_file,
                image_file_tmb=self.image_file_tmb,
                repex=True,
                paired=self.paired,
                satellite_model_path=satellite_model_path,
                maxv=config.CLUSTER_VMAX,
                maxe=config.CLUSTER_EMAX)
        )
        print(graph_info)
        self.ecount = graph_info['ecount']
        self.vcount = graph_info['vcount']
        self.loop_index = graph_info['loop_index']
        self.pair_completeness = graph_info['pair_completeness']
        self.orientation_score = graph_info['escore']
        self.satellite_probability = graph_info['satellite_probability']
        self.satellite = graph_info['satellite']
        # for paired reads:
        cond1 = (self.paired and self.loop_index > loop_index_threshold and
                 self.pair_completeness > pair_completeness_threshold)
        # no pairs
        cond2 = ((not self.paired) and
                 self.loop_index > loop_index_unpaired_threshold)
        if (cond1 or cond2) and config.ARGS.options.name != "oxford_nanopore":
            self.putative_tandem = True
            self.dir_tarean = FilePath("{}/tarean".format(self.dir))
            lock_file = self.dir + "../lock"
            out = eval(
                TAREAN.tarean(input_sequences=self.fasta_oriented_file,
                              output_dir=self.dir_tarean,
                              CPU=1,
                              reorient_reads=False,
                              tRNA_database_path=tRNA_database_path,
                              lock_file=lock_file)
            )
            self.html_tarean = FilePath(out['htmlfile'])
            self.tarean_contig_file = out['tarean_contig_file']
            self.TR_score = out['TR_score']
            self.TR_monomer_length = out['TR_monomer_length']
            self.TR_consensus = out['TR_consensus']
            self.pbs_score = out['pbs_score']
            self.max_ORF_length = out['orf_l']
            if (out['orf_l'] > config.ORF_THRESHOLD or
                    out['pbs_score'] > config.PBS_THRESHOLD):
                self.tandem_rank = 3
            elif self.satellite:
                self.tandem_rank = 1
            else:
                self.tandem_rank = 2
            # some tandems could be rDNA genes - this must be checked
            # by annotation
            if self.annotations_table:
                rdna_score = 0
                contamination_score = 0
                for i in self.annotations_table:
                    if 'rDNA/' in i[0]:
                        rdna_score += i[1]
                    if 'contamination' in i[0]:
                        contamination_score += i[1]
                if rdna_score > config.RDNA_THRESHOLD:
                    self.tandem_rank = 4
                if contamination_score > config.CONTAMINATION_THRESHOLD:
                    self.tandem_rank = 0  # other

            # by custom annotation - custom annotation has preference
            if self.annotations_custom_table:
                print("custom table searching")
                rdna_score = 0
                contamination_score = 0
                print(self.annotations_custom_table)
                for i in self.annotations_custom_table:
                    if 'rDNA' in i[0]:
                        rdna_score += i[1]
                    if 'contamination' in i[0]:
                        contamination_score += i[1]
                if rdna_score > 0:
                    self.tandem_rank = 4
                if contamination_score > config.CONTAMINATION_THRESHOLD:
                    self.tandem_rank = 0  # other

        else:
            self.putative_tandem = False
            self.dir_tarean = None
            self.html_tarean = None
            self.TR_score = None
            self.TR_monomer_length = None
            self.TR_consensus = None
            self.pbs_score = None
            self.max_ORF_length = None
            self.tandem_rank = 0
            self.tarean_contig_file = None

    def __str__(self):
        out = [
            "cluster no {}:".format(self.index),
            "Number of vertices  : {}".format(self.size),
            "Number of edges     : {}".format(self.ecount),
            "Loop index          : {}".format(self.loop_index),
            "Pair completeness   : {}".format(self.pair_completeness),
            "Orientation score   : {}".format(self.orientation_score)
        ]
        return "\n".join(out)

    def listing(self, asdict=True):
        ''' convert attributes to dictionary for printing purposes'''
        out = {}
        for i in dir(self):
            # do not show private
            if i[:2] != "__":
                value = getattr(self, i)
                if not callable(value):
                    # for dictionary
                    if isinstance(value, dict):
                        for k in value:
                            out[i + "_" + k] = value[k]
                    else:
                        out[i] = value
        if asdict:
            return out
        else:
            return {'keys': list(out.keys()), 'values': list(out.values())}

    def detect_ltr(self, trna_database):
        '''detection of LTRs in assembly files; the output of the analysis is stored in a file'''
        CREATE_ANNOTATION = r2py.R(config.RSOURCE_create_annotation, verbose=False)
        if self.assembly_files['{}.{}.ace']:
            ace_file = self.assembly_files['{}.{}.ace']
            print(ace_file, "running LTR detection")
            fout = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['BASE'])
            subprocess.check_call([
                config.LTR_DETECTION,
                '-i', ace_file,
                '-o', fout,
                '-p', trna_database])
            # evaluate LTR presence
            fn = "{}/{}".format(self.dir, config.LTR_DETECTION_FILES['PBS_BLAST'])
            self.ltr_detection = CREATE_ANNOTATION.evaluate_LTR_detection(fn)


    @staticmethod
    def _summarize_annotations(annotations_files, size):
        ''' will tabulate annotation results '''
        # TODO
        summaries = {}
        # weight is in percentage
        weight = 100 / size
        for i in annotations_files:
            with open(annotations_files[i]) as f:
                header = f.readline().split()
                id_index = [
                    i for i, item in enumerate(header) if item == "db_id"
                ][0]
                for line in f:
                    classification = line.split()[id_index].split("#")[1]
                    if classification in summaries:
                        summaries[classification] += weight
                    else:
                        summaries[classification] = weight
        # format summaries for printing
        annotation_string = ""
        annotation_table = []
        for i in sorted(summaries.items(), key=lambda x: x[1], reverse=True):
            ## hits with smaller proportion are not shown!
            if i[1] > 0.1:
                if i[1] > 1:
                    annotation_string += "<b>{1:.2f}% {0}</b>\n".format(*i)
                else:
                    annotation_string += "{1:.2f}% {0}\n".format(*i)
            annotation_table.append(i)
        return [annotation_string, annotation_table]
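    # Illustrative example (assumed numbers and label): with size = 200 reads,
    # weight is 0.5% per annotated read; 30 hits classified as "Ty3/gypsy" would
    # be reported as "15.00% Ty3/gypsy", in bold because it exceeds 1%.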

    @staticmethod
    def add_cluster_table_to_database(cluster_table, db_path):
        '''get column names from the Cluster object and create a
        corresponding table in the database; values from all
        clusters are filled into the database'''
        column_name_and_type = []
        column_list = []

        # get all attribute names -> they are column names
        # in the sqlite table, detect the proper sqlite type
        def identity(x):
            return (x)

        for i in cluster_table[1]:
            t = type(cluster_table[1][i])
            if t == int:
                sqltype = "integer"
                convert = identity
            elif t == float:
                sqltype = "real"
                convert = identity
            elif t == bool:
                sqltype = "boolean"
                convert = bool
            else:
                sqltype = "text"
                convert = str
            column_name_and_type += ["[{}] {}".format(i, sqltype)]
            column_list += [tuple((i, convert))]
        header = ", ".join(column_name_and_type)
        db = sqlite3.connect(db_path)
        c = db.cursor()
        print("CREATE TABLE cluster_info ({})".format(header))
        c.execute("CREATE TABLE cluster_info ({})".format(header))
        # fill data into the cluster_info table
        buffer = []
        for i in cluster_table:
            buffer.append(tuple('{}'.format(fun(i[j])) for j, fun in
                                column_list))
        wildcards = ",".join(["?"] * len(column_list))
        print(buffer)
        c.executemany("insert into cluster_info values ({})".format(wildcards),
                      buffer)
        db.commit()