#!/home/gianmarco/galaxy-python/python

import os
import re
import sys
from collections import Counter

import Bio
import pandas as pd
from Bio import SeqIO
from Bio.Data import CodonTable

def read_input(data = "example.fna"):
    """Read a FASTA file and return all of its sequences joined together.

    Args:
        data: Path of the FASTA file to read.

    Returns:
        A single string made of every record's sequence, in file order.
        No separator is inserted between records, so a caller that splits
        the result into triplets will carry the reading frame across record
        boundaries.
    """
    # Plain "r" instead of "rU": the "U" flag was removed in Python 3.11
    # (open() raises ValueError), and text mode already translates newlines.
    with open(data, "r") as handle:
        # join() builds the result in one pass instead of re-copying the
        # accumulated string for every record.
        return "".join(str(record.seq) for record in SeqIO.parse(handle, "fasta"))

def codon_usage(seqs, codonTable):
    """Count codon occurrences in ``seqs`` and group them by amino acid.

    The sequence is read as consecutive, non-overlapping triplets starting at
    its first character. Every codon of the chosen genetic code is reported,
    including those that never occur (count 0).

    Args:
        seqs: Nucleotide sequence as a string. Case is ignored.
        codonTable: Name of a genetic code understood by Biopython, used as a
            key of ``CodonTable.unambiguous_dna_by_name`` (e.g. "Bacterial").

    Returns:
        A dict with three views of the same counts:
            "Dictionary": ``{amino_acid: {codon_label: count}}``
            "Tuples":     ``{(amino_acid, codon_label): count}``
            "Table":      DataFrame indexed by (AA, Codon) with a "Count"
                          column and a "Proportion" column holding each
                          codon's share within its amino acid, rounded to
                          two decimals (NaN when the amino acid has no
                          occurrences at all).
        Stop codons are grouped under the pseudo amino acid "_Stop", and
        start codons carry the label suffix " Start" (e.g. "ATG Start").

    Raises:
        KeyError: if ``codonTable`` is not a known table name.
    """
    genetic_code = CodonTable.unambiguous_dna_by_name[codonTable]

    # Work on a copy. forward_table is the dict stored on Biopython's shared
    # table object; relabelling it in place corrupted the table for the rest
    # of the process and made a second call fail with KeyError.
    codon_to_aa = dict(genetic_code.forward_table)

    for codon in genetic_code.stop_codons:
        codon_to_aa[codon] = "_Stop"

    # Relabel start codons so they are distinguishable in the output while
    # keeping the amino acid they encode.
    for codon in genetic_code.start_codons:
        codon_to_aa[codon + " Start"] = codon_to_aa.pop(codon)

    # One pass over the sequence. findall yields the same triplets the
    # previous re.split(r'(\w{3})') captured; upper() lets soft-masked
    # (lowercase) bases be counted as well.
    observed = Counter(re.findall(r"\w{3}", seqs.upper()))

    usage = {}
    for label, amino_acid in codon_to_aa.items():
        # The bare codon is the part of the label before any " Start" suffix.
        usage.setdefault(amino_acid, {})[label] = observed[label.split(" ")[0]]

    tups = {
        (amino_acid, label): count
        for amino_acid, per_codon in usage.items()
        for label, count in per_codon.items()
    }

    # Sort the keys so the row order is deterministic; "_Stop" sorts after
    # the upper-case amino-acid letters and therefore ends up last.
    keys = sorted(tups)
    index = pd.MultiIndex.from_tuples(keys, names=["AA", "Codon"])
    table = pd.DataFrame({"Count": [tups[key] for key in keys]}, index=index)

    table["Proportion"] = table["Count"].groupby(level="AA").transform(
        lambda counts: (counts / counts.sum()).round(2)
    )

    return {"Dictionary": usage, "Tuples": tups, "Table": table}



if __name__ == '__main__':

    # Command line: <script> <input FASTA> <output TSV>
    if len(sys.argv) < 3:
        # Fail with a readable message instead of a bare IndexError.
        sys.exit("Usage: %s <input.fasta> <output.tsv>" % sys.argv[0])

    seqs = read_input(data=sys.argv[1])
    out = codon_usage(seqs, "Bacterial")

    # Write the codon-usage table as tab-separated text.
    with open(sys.argv[2], "w") as outf:
        out["Table"].to_csv(outf, sep="\t")