"""
@summary: Calculate the frequency of eQTLs and genes per bin
@author: nanette.coetzer@gmail.com
@version 5

"""
import collections
import optparse
import os
import re
import subprocess
import sys
import tempfile

def stop_err(msg):
    """Write *msg* (plus a newline) to stderr and terminate with exit status 1.

    A non-zero status lets the calling shell or workflow engine detect the
    failure from the exit code, not only from the presence of stderr output.
    """
    # Wrap in a tuple so that a tuple-valued msg is printed, not unpacked.
    sys.stderr.write("%s\n" % (msg,))
    sys.exit(1)
 
# Each required input: (optparse dest, message if the option is missing,
# message if the file cannot be opened). The message texts are user-facing
# output and are intentionally left exactly as they were.
_REQUIRED_INPUTS = (
    ("input1",
     "You need to supply the Lookup table file:\n",
     "Can not open the Lookup table file:\n"),
    ("input2",
     "You need to supply the Gene positions file:\n",
     "Can not open the Gene positions file:\n"),
    ("input3",
     "You need to supply the eQTL all classification  file:\n",
     "Can not open the eQTL all classification  file:\n"),
    ("input4",
     "You need to supply the eQTL cis classification file:\n",
     "Can not open the eQTL cis classification  file:\n"),
    ("input5",
     "You need to supply the eQTL trans classification file:\n",
     "Can not open the eQTL trans classification  file:\n"),
)


def _read_rows(path):
    """Return the tab-separated fields of every non-blank line of *path*.

    Each line is stripped of surrounding whitespace before splitting, so
    re-joining a row with tabs reproduces the stripped line. Blank lines
    (e.g. a trailing empty line) are skipped instead of producing a
    one-field row that would break column indexing.
    """
    with open(path, "r") as handle:
        return [line.strip().split("\t") for line in handle if line.strip()]


def _position_key(chrom, position):
    """Return the (chromosome, position) key used to match eQTL peaks to intervals.

    The position is parsed as a float, rounded to 4 decimals and stringified
    identically on both sides of the match, so e.g. "5" and "5.0" compare equal.
    """
    return (chrom, str(round(float(position), 4)))


def _build_position_lookup(lookup_rows):
    """Map (chromosome, rounded interval position) -> interval id.

    The first lookup row is treated as a header and skipped. Columns used
    (named after the freq-file header): 0 = int.id, 1 = chr, 3 = interval.
    If two rows produce the same key, the later row wins.
    """
    position_to_id = {}
    for row in lookup_rows[1:]:
        position_to_id[_position_key(row[1], row[3])] = row[0]
    return position_to_id


def _count_eqtl_intervals(path, position_to_id):
    """Count eQTL peaks per interval id for one classification file.

    Column 2 is the chromosome and column 8 the peak position. Rows whose
    column 3 starts with "s" are ignored (presumably a header or a flag
    value -- confirm against the file format). Peaks that match no lookup
    interval are dropped.

    Returns a collections.Counter mapping interval id -> number of peaks;
    an empty file yields an empty Counter (total 0).
    """
    counts = collections.Counter()
    for row in _read_rows(path):
        if row[3].startswith("s"):
            continue
        interval_id = position_to_id.get(_position_key(row[2], row[8]), "")
        if interval_id:
            counts[interval_id] += 1
    return counts


def _gene_intervals(lookup_rows):
    """Derive the gene bins and the total map length from the lookup table.

    Rows whose column 5 (bp) starts with "bp" are header rows and are
    skipped. Each remaining row opens a bin that runs from its own bp up to
    (not including) the next row's bp; the final row opens no bin.

    Returns (interval_ids, intervals_by_chrom, total_cM) where
      * interval_ids lists column 0 of every data row, in file order;
      * intervals_by_chrom maps a chromosome (column 1 of the opening row)
        to a list of (interval id, start bp string, end bp string);
      * total_cM sums column 4 (cM) of the last row of each chromosome,
        a new chromosome being recognised by a cM value of exactly "0.0".
    """
    data_rows = [row for row in lookup_rows if not row[5].startswith("bp")]
    intervals_by_chrom = {}
    total_cm = 0
    for prev, row in zip(data_rows, data_rows[1:]):
        # NOTE(review): consecutive rows are paired even across a chromosome
        # boundary; such a bin is filed under the earlier row's chromosome and
        # ends at the next chromosome's first bp. Kept to preserve existing
        # results -- confirm this is intended.
        intervals_by_chrom.setdefault(prev[1], []).append((prev[0], prev[5], row[5]))
        if row[4] == "0.0":
            total_cm += float(prev[4])
    if data_rows:
        total_cm += float(data_rows[-1][4])
    interval_ids = [row[0] for row in data_rows]
    return interval_ids, intervals_by_chrom, total_cm


def _count_gene_intervals(path, intervals_by_chrom):
    """Count genes per interval id, placing each gene by its midpoint.

    Gene rows use column 1 = chromosome, columns 2/3 = start/end (integers).
    Rows whose chromosome column starts with "c" are skipped as headers.
    A gene is counted in every bin of its chromosome with
    start <= midpoint < end.

    Returns a collections.Counter mapping interval id -> number of genes.
    """
    counts = collections.Counter()
    for row in _read_rows(path):
        # NOTE(review): chromosome names such as "chr1" would also be skipped
        # by this header test -- confirm the expected naming scheme.
        if row[1].startswith("c"):
            continue
        chrom = row[1]
        mid_pos = (int(row[2]) + int(row[3])) / float(2)
        # Only this chromosome's bins can match, so scan just those.
        for interval_id, start_bp, end_bp in intervals_by_chrom.get(chrom, ()):
            if int(float(start_bp)) <= mid_pos < int(float(end_bp)):
                counts[interval_id] += 1
    return counts


def _format_count(counts, interval_id):
    """Format a per-interval count for the freq file.

    Counted intervals are written as a rounded float (e.g. "3.0"); intervals
    with no hits are written as "0". This is the existing output format.
    """
    if interval_id in counts:
        return str(round(float(counts[interval_id]), 2))
    return "0"


def __main__():
    """Count eQTL peaks and genes per lookup-table interval.

    Reads the lookup table (-i), gene positions (-j) and the all/cis/trans
    eQTL classification files (-k/-l/-m), then writes:
      * -p: a summary of the totals, the total map length in cM and the
        expected number of eQTLs/genes per cM (written first);
      * -o: one row per lookup interval -- the lookup row followed by the
        all/cis/trans eQTL counts and the gene count for that interval.

    Exits via stop_err() if an input is missing or unreadable, if -n/-g are
    not integers, or if the lookup table spans 0 cM in total.
    """
    #Parse Command Line
    parser = optparse.OptionParser()
    parser.add_option("-i", "--input1", default=None, dest="input1",
                      help="Lookup table file")
    parser.add_option("-j", "--input2", default=None, dest="input2",
                      help="Gene positions file")
    parser.add_option("-k", "--input3", default=None, dest="input3",
                      help="eQTL all classification file")
    parser.add_option("-l", "--input4", default=None, dest="input4",
                      help="eQTL cis classification file")
    parser.add_option("-m", "--input5", default=None, dest="input5",
                      help="eQTL trans classification file")
    parser.add_option("-n", "--num_intervals", default=2, dest="num_intervals",
                      help="Select the number intervals to include per sliding window")
    parser.add_option("-g", "--num_permutations", default=1000, dest="num_permutations",
                      help="Number of permutations used to calculate if eQTLs per cM is significantly higher than expected")
    parser.add_option("-o", "--output1", default=None, dest="output1",
                      help="Frequency per cM file")
    parser.add_option("-p", "--output2", default=None, dest="output2",
                      help="Frequency summary file")

    (options, args) = parser.parse_args()

    # open(None) raises TypeError (option not given); an unreadable path
    # raises IOError.
    for dest, missing_msg, unreadable_msg in _REQUIRED_INPUTS:
        try:
            open(getattr(options, dest), "r").close()
        except TypeError as e:
            stop_err(missing_msg + str(e))
        except IOError as e:
            stop_err(unreadable_msg + str(e))

    # int("abc") raises ValueError, not TypeError, so both must be caught to
    # report a clean message instead of a traceback.
    try:
        options.num_intervals = int(options.num_intervals)
    except (TypeError, ValueError) as e:
        stop_err("Not an integer, 1, 2 or 3 intervals can be included per sliding window:\n" + str(e))

    try:
        options.num_permutations = int(options.num_permutations)
    except (TypeError, ValueError) as e:
        stop_err("Not an integer, must be an integer between 1 and 5000:\n" + str(e))

    # The lookup table is read once and shared by every step below.
    lookup_rows = _read_rows(options.input1)
    position_to_id = _build_position_lookup(lookup_rows)

    # Peak eQTL counts per interval. Totals are the number of matched peaks,
    # so an empty classification file correctly contributes 0.
    eqtl_all = _count_eqtl_intervals(options.input3, position_to_id)
    eqtl_cis = _count_eqtl_intervals(options.input4, position_to_id)
    eqtl_trans = _count_eqtl_intervals(options.input5, position_to_id)
    tot_eqtls = sum(eqtl_all.values())
    tot_cis = sum(eqtl_cis.values())
    tot_trans = sum(eqtl_trans.values())

    interval_ids, intervals_by_chrom, tot_cM = _gene_intervals(lookup_rows)
    genes = _count_gene_intervals(options.input2, intervals_by_chrom)
    tot_genes = sum(genes.values())

    # Every "expected per cM" figure divides by tot_cM.
    if tot_cM == 0:
        stop_err("Total number of cM is 0; cannot compute the expected number of eQTLs or genes per cM")

    summary_lines = [
        "Total number of eQTLs (all)\t" + str(tot_eqtls),
        "Total number of cis-eQTLs\t" + str(tot_cis),
        "Total number of trans-eQTLs\t" + str(tot_trans),
        "Total number of genes\t" + str(tot_genes),
        "Total number of cM\t" + str(tot_cM),
        "Expected number of eQTL per cM (all)\t" + str(round(tot_eqtls / float(tot_cM), 2)),
        "Expected number of cis-eQTL per cM\t" + str(round(tot_cis / float(tot_cM), 2)),
        "Expected number of trans-eQTL per cM\t" + str(round(tot_trans / float(tot_cM), 2)),
        "Expected number of genes per cM\t" + str(round(tot_genes / float(tot_cM), 2)),
        "User specified number of permutations\t" + str(options.num_permutations),
        "Number of intervals per sliding window\t" + str(options.num_intervals),
    ]
    with open(options.output2, "w") as summary:    # freq_summary.txt
        summary.write("\n".join(summary_lines) + "\n")

    # Full lookup row per interval id; if an id repeats, the later row wins.
    lookup_by_id = dict((row[0], row) for row in lookup_rows)

    with open(options.output1, "w") as eqtl_gene_freq:    # freq.txt
        eqtl_gene_freq.write("int.id\tchr\tmarker\tinterval\tcM\tbp\tlength_cM\tnum.eQTL.all\tnum.eQTL.cis\tnum.eQTL.trans\tnum.genes\n")
        for interval_id in interval_ids:
            fields = [
                "\t".join(lookup_by_id[interval_id]),
                _format_count(eqtl_all, interval_id),
                _format_count(eqtl_cis, interval_id),
                _format_count(eqtl_trans, interval_id),
                _format_count(genes, interval_id),
            ]
            eqtl_gene_freq.write("\t".join(fields) + "\n")

    ##############################################
    
if __name__ == "__main__":
    # Run the tool only when this file is executed as a script, not on import.
    __main__()



