'''Takes tab-delimited SNP tables from user input and merges them into one.'''

import sys
def _split_line(line):
    """Strip the line terminator from *line* and split it on tabs."""
    return line.rstrip('\r\n').split('\t')


def _next_row(handle):
    """Return the next non-blank data row of *handle* as a list of fields.

    Returns None once *handle* is exhausted. Blank lines (e.g. an empty line
    at the end of a table) are skipped rather than parsed, since they have no
    position to merge on.
    """
    while True:
        line = handle.readline()
        if not line:
            return None
        if line.strip():
            return _split_line(line)


def merge_snp_tables(output, inputs, names=None):
    """Merge position-sorted SNP tables into one table written to *output*.

    Each input is a tab-delimited text stream whose first line is a header
    (position column, reference column, then one column per strain) followed
    by ``position<TAB>reference<TAB>snp...`` rows. Within each input the rows
    are assumed to be sorted by ascending integer position: the merge only
    ever compares the current row of each input.

    The output header is ``Position``, the reference column name of the first
    input, then every input's strain names in input order. Each output row
    holds the position (as spelled in the first input that has it), the
    lower-cased reference base, then each input's SNP cells, with empty cells
    for inputs that have no row at that position. Every line ends in ``\\n``.

    Args:
        output: writable text stream.
        inputs: non-empty list of readable text streams.
        names: display names for *inputs*, used in error messages; defaults
            to each stream's ``name`` attribute.

    Raises:
        ValueError: if *inputs* is empty, if inputs disagree on the reference
            base at a shared position (compared case-insensitively), or
            (from ``int``) if a position is not an integer.
    """
    if not inputs:
        raise ValueError('No input files given.')
    if names is None:
        names = [getattr(handle, 'name', '<input %d>' % i)
                 for i, handle in enumerate(inputs)]

    headers = [_split_line(handle.readline()) for handle in inputs]
    # Number of strain (SNP) columns each input contributes; used to pad the
    # output when an input has no SNP at the current position.
    columns = [max(len(header) - 2, 0) for header in headers]
    ref_name = headers[0][1] if len(headers[0]) > 1 else 'Reference'
    header_out = ['Position', ref_name]
    for header in headers:
        header_out.extend(header[2:])
    output.write('\t'.join(header_out) + '\n')

    # Current (not yet written) row of each input; None once it is exhausted.
    rows = [_next_row(handle) for handle in inputs]
    while True:
        live = [i for i, row in enumerate(rows) if row is not None]
        if not live:
            break
        positions = dict((i, int(rows[i][0])) for i in live)
        lowest = min(positions.values())
        # Each live input sitting at the lowest position, listed exactly once,
        # so each of them is advanced by exactly one row below.
        current = [i for i in live if positions[i] == lowest]

        ref_bases = set(rows[i][1].lower() for i in current)
        if len(ref_bases) > 1:
            raise ValueError(
                '\nError: Reference bases not the same for position %s, '
                'present in following files: %s.'
                % (lowest, ' '.join(names[i] for i in current)))

        first = rows[current[0]]
        cells = [first[0], first[1].lower()]
        for i, row in enumerate(rows):
            if i in current:
                cells.extend(row[2:])
            else:
                cells.extend([''] * columns[i])
        output.write('\t'.join(cells) + '\n')

        for i in current:
            rows[i] = _next_row(inputs[i])


def main(argv=None):
    """Command-line entry point: ``script OUTPUT INPUT [INPUT ...]``.

    Input files that cannot be opened are skipped with a warning on stderr.
    Exits with a message if the output file cannot be opened, if no input
    file could be opened, or if merging fails (e.g. mismatched reference
    bases). All files are closed on every exit path.
    """
    import contextlib

    argv = sys.argv if argv is None else argv
    with contextlib.ExitStack() as stack:
        try:
            output = stack.enter_context(open(argv[1], 'w'))
        except (IndexError, OSError):
            sys.exit("No output file given or unable to open output file.")

        inputs = []
        names = []
        for name in argv[2:]:
            try:
                # Default text mode already applies universal newlines;
                # the old "rU" mode was removed in Python 3.11.
                inputs.append(stack.enter_context(open(name)))
            except OSError as err:
                sys.stderr.write('Warning: skipping input %s (%s)\n' % (name, err))
                continue
            names.append(name)

        if not inputs:
            sys.exit('No readable input files given.')
        try:
            merge_snp_tables(output, inputs, names)
        except ValueError as err:
            sys.exit(str(err))


if __name__ == '__main__':
    main()
