
#!/usr/bin/env python

import sys
import string
import os

class convertitori(object):
    """Convert sequence files between FASTA, PHYLIP, NEXUS and GenBank.

    Each public method performs one conversion and is named after the first
    letter of the input and output formats (f = FASTA, p = PHYLIP,
    n = NEXUS, g = GenBank): ``fp`` converts FASTA to PHYLIP, ``gf`` converts
    GenBank to FASTA, and so on.  The converted text is written to
    ``self.output``.  Malformed input terminates the program through
    ``sys.exit`` with a user-facing message.

    Args:
        input: the input file as a list of lines, e.g. the result of
            ``readlines()`` (line terminators may be present).
        inouttype: conversion code such as ``"f-p"`` (stored only).
        type: ``"single"`` or ``"multi"`` (stored only).
        output: writable text stream; only its ``write`` method is used.
    """

    # Width of sequence names in PHYLIP / NEXUS output.
    NAME_WIDTH = 10
    # Alignment columns per interleaved NEXUS block, and per spaced group.
    NEXUS_BLOCK = 100
    NEXUS_GROUP = 20
    # Residues per GenBank ORIGIN line, and per spaced group.
    GENBANK_LINE = 60
    GENBANK_GROUP = 10
    # Residues per line of FASTA output.
    FASTA_WIDTH = 80

    _NOT_ALIGNED = "The input file does not contains a multiple alignment in fasta format. Please ensure that all the sequences have the same length"
    _NOT_FASTA = "The input file is not in fasta format. Please check that the first row starts with > and that the sequence starts from the second line"
    _NOT_PHYLIP = "The input file is not in the proper format. Please check that your file is in Phylip standard interleaved (or sequential) format "
    _NOT_NEXUS = "The input file is not in Nexus format. "

    def __init__(self, input, inouttype, type, output):
        self.input = input
        self.inouttype = inouttype
        self.type = type
        self.output = output

    # ------------------------------------------------------------------
    # Small text helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _compact(text):
        """Return *text* with every whitespace character removed."""
        return "".join(text.split())

    @staticmethod
    def _group(seq, size):
        """Return *seq* cut into space-separated groups of *size* chars."""
        return " ".join(seq[i:i + size] for i in range(0, len(seq), size))

    def _pad_name(self, name):
        """Truncate *name*, or right-pad it with '_', to NAME_WIDTH chars."""
        return name[:self.NAME_WIDTH].ljust(self.NAME_WIDTH, "_")

    # ------------------------------------------------------------------
    # Readers: each returns a list of (name, sequence) pairs in file order
    # ------------------------------------------------------------------

    def _read_fasta(self):
        """Parse ``self.input`` as FASTA.

        Spaces in names become underscores.  Sequence lines are stripped and
        concatenated.  Lines before the first '>' header are ignored.
        """
        records = []
        for line in self.input:
            if line.startswith(">"):
                # Drop only the line terminator so a name never embeds '\n'.
                name = line[1:].rstrip("\r\n").replace(" ", "_")
                records.append((name, []))
            elif records:
                records[-1][1].append(line.strip())
        return [(name, "".join(chunks)) for name, chunks in records]

    def _read_alignment(self):
        """Parse FASTA input and exit unless all sequences share one length."""
        records = self._read_fasta()
        if not records:
            sys.exit(self._NOT_FASTA)
        width = len(records[0][1])
        if any(len(seq) != width for _, seq in records):
            sys.exit(self._NOT_ALIGNED)
        return records

    def _read_phylip(self):
        """Parse ``self.input`` as PHYLIP.

        The first line must start with the number of taxa.  Blank lines are
        ignored.  Two layouts are accepted:

        * sequential / interleaved: the first ``ntax`` rows hold
          "name sequence".  The name is the text before a tab if there is
          one, otherwise the first 10 columns.  Further rows continue the
          sequences cyclically, in the same taxon order.
        * name-per-line: when there are not exactly ``ntax`` rows and the
          first row has at most 10 characters, rows alternate between a
          name row and its sequence row.

        Whitespace inside sequence data is removed.
        """
        try:
            ntax = int(self.input[0].split()[0])
        except (IndexError, ValueError):
            sys.exit(self._NOT_PHYLIP)
        rows = [line.rstrip() for line in self.input[1:] if line.strip()]
        if ntax < 1 or len(rows) < ntax:
            sys.exit(self._NOT_PHYLIP)

        if len(rows) != ntax and len(rows[0]) <= self.NAME_WIDTH:
            if len(rows) % 2:
                sys.exit(self._NOT_PHYLIP)
            names = [row.strip()[:self.NAME_WIDTH] for row in rows[0::2]]
            seqs = [self._compact(row) for row in rows[1::2]]
            return list(zip(names, seqs))

        if len(rows) % ntax:
            # A partial trailing block would silently misassign residues.
            sys.exit(self._NOT_PHYLIP)
        names, parts = [], []
        for row in rows[:ntax]:
            name, sep, rest = row.partition("\t")
            if not sep:
                name, rest = row[:self.NAME_WIDTH], row[self.NAME_WIDTH:]
            names.append(name.rstrip())
            parts.append([self._compact(rest)])
        # Continuation rows carry no names: row i belongs to taxon i % ntax.
        for index, row in enumerate(rows[ntax:]):
            parts[index % ntax].append(self._compact(row))
        return [(name, "".join(chunks)) for name, chunks in zip(names, parts)]

    def _read_nexus(self):
        """Parse the MATRIX of a NEXUS DATA block.

        Each matrix row is "name data...".  A taxon may appear once per
        interleaved block, and its data are concatenated in order.  Reading
        stops at the ';' that terminates the matrix (or at END;).
        """
        order, seqs = [], {}
        in_matrix = False
        for line in self.input:
            tokens = line.split()
            if not in_matrix:
                in_matrix = bool(tokens) and tokens[0].upper() == "MATRIX"
                continue
            if not tokens:
                continue
            if tokens[0].upper() in ("END;", "ENDBLOCK;"):
                break
            finished = tokens[-1].endswith(";")
            if finished:
                tokens[-1] = tokens[-1][:-1]
                if not tokens[-1]:
                    tokens.pop()
            if tokens:
                name = tokens[0]
                if name not in seqs:
                    order.append(name)
                    seqs[name] = []
                seqs[name].append("".join(tokens[1:]))
            if finished:
                break
        if not order:
            sys.exit(self._NOT_NEXUS)
        return [(name, "".join(seqs[name])) for name in order]

    # ------------------------------------------------------------------
    # Writers
    # ------------------------------------------------------------------

    def _write_fasta_record(self, name, seq):
        """Write one FASTA record, wrapping the sequence at FASTA_WIDTH."""
        self.output.write(">" + name + "\n")
        for start in range(0, len(seq), self.FASTA_WIDTH):
            self.output.write(seq[start:start + self.FASTA_WIDTH] + "\n")

    def _write_phylip(self, records):
        """Write "ntax<TAB>nchar" then one padded-name<TAB>sequence row each.

        nchar is taken from the first record.
        """
        self.output.write("%d\t%d\n" % (len(records), len(records[0][1])))
        for name, seq in records:
            self.output.write(self._pad_name(name) + "\t" + seq + "\n")

    def _write_nexus(self, records, nchar):
        """Write an interleaved NEXUS DATA block terminated by ';' and END;.

        Sequences are emitted in blocks of NEXUS_BLOCK columns, separated by
        blank lines.  Within a row, residues are grouped by NEXUS_GROUP.
        """
        self.output.write("#NEXUS\n\nBEGIN DATA;\nDIMENSIONS NTAX=%s NCHAR=%s;\nFORMAT DATATYPE=DNA INTERLEAVE MISSING=-;\n\nMATRIX\n" % (len(records), nchar))
        longest = max([len(seq) for _, seq in records] or [0])
        # max(..., 1) guarantees at least one block, so every taxon is listed.
        for start in range(0, max(longest, 1), self.NEXUS_BLOCK):
            if start:
                self.output.write("\n\n\n")
            for name, seq in records:
                chunk = seq[start:start + self.NEXUS_BLOCK]
                self.output.write(name + "\t" + self._group(chunk, self.NEXUS_GROUP) + "\n")
        self.output.write(";\nEND;\n")

    # ------------------------------------------------------------------
    # Public conversions
    # ------------------------------------------------------------------

    def fp(self):
        """FASTA alignment -> PHYLIP (names padded/truncated to 10 chars)."""
        self._write_phylip(self._read_alignment())

    def fn(self):
        """FASTA alignment -> interleaved NEXUS."""
        records = [(self._pad_name(name), seq) for name, seq in self._read_alignment()]
        self._write_nexus(records, len(records[0][1]))

    def pn(self):
        """PHYLIP -> interleaved NEXUS.

        NCHAR comes from the PHYLIP header, or from the first sequence when
        the header has no usable length.
        """
        records = self._read_phylip()
        try:
            nchar = int(self.input[0].split()[1])
        except (IndexError, ValueError):
            nchar = len(records[0][1])
        # NEXUS rows are whitespace-tokenised, so names must not contain spaces.
        records = [(name.replace(" ", "_"), seq) for name, seq in records]
        self._write_nexus(records, nchar)

    def pf(self):
        """PHYLIP -> FASTA (80 residues per line)."""
        for name, seq in self._read_phylip():
            self._write_fasta_record(name, seq)

    def nf(self):
        """NEXUS -> FASTA (80 residues per line)."""
        for name, seq in self._read_nexus():
            self._write_fasta_record(name, seq)

    def np(self):
        """NEXUS -> PHYLIP (names padded/truncated to 10 chars)."""
        self._write_phylip(self._read_nexus())

    def fg(self):
        """FASTA -> minimal GenBank records (LOCUS, ORIGIN, //) per sequence."""
        for name, seq in self._read_fasta():
            self.output.write("LOCUS\t%s\t%s bp\nORIGIN\n" % (name, len(seq)))
            for start in range(0, len(seq), self.GENBANK_LINE):
                line = self._group(seq[start:start + self.GENBANK_LINE], self.GENBANK_GROUP)
                self.output.write(str(start + 1) + "\t" + line + "\n")
            self.output.write("//\n\n")

    def gf(self):
        """GenBank -> FASTA.

        The FASTA header is the LOCUS name, a tab, then the DEFINITION words,
        each followed by a space.  The sequence is read from the ORIGIN
        section up to the '//' record terminator.
        """
        name, chunks, in_origin = "", [], False
        for line in self.input:
            # Keywords are matched at column 0, so "http://" inside a
            # COMMENT does not end the record.
            if line.startswith("//"):
                self._write_fasta_record(name, "".join(chunks))
                name, chunks, in_origin = "", [], False
            elif line.startswith("LOCUS"):
                fields = line.split()
                name = (fields[1] if len(fields) > 1 else "") + "\t"
                chunks, in_origin = [], False
            elif line.startswith("DEFINITION"):
                name += "".join(word + " " for word in line.split()[1:])
            elif line.startswith("ORIGIN"):
                in_origin = True
            elif in_origin:
                # Skip the leading position number of each ORIGIN row.
                chunks.extend(line.split()[1:])

class check_fileformat(object):
    """Validate that an input file matches the format it is declared to be.

    Only the first character of ``inouttype`` (the input format:
    f = FASTA, g = GenBank, p = PHYLIP, n = NEXUS) is used.  ``single()``
    and ``multi()`` return ``"ok"`` when the file passes and ``None`` when
    the input format is not handled by that mode.  Otherwise they terminate
    the program with ``sys.exit`` and a user-facing message.

    Args:
        inouttype: conversion code such as ``"f-p"``.
        input: the input file as a list of lines.
    """

    _NOT_FASTA = "The input file is not in fasta format. Please check that the first row starts with > and that the sequence starts from the second line"
    _NOT_GENBANK = "The input file is not in GenBank format. Please make sure that the file contains at least the LOCUS and ORIGIN fields. The file must also ends with //"
    _NOT_PHYLIP = "the input file is not in phylip format."
    _NOT_NEXUS = "the input file is not in nexus format."
    _ONE_SEQUENCE = "There is only one sequence in the file"

    def __init__(self, inouttype, input):
        self.intype = inouttype[0]
        self.infile = input

    def _count_fasta_headers(self):
        """Return how many lines start with '>'."""
        return sum(1 for line in self.infile if line.startswith(">"))

    def _check_genbank(self):
        """Require LOCUS, ORIGIN, '//' and at least one wide row after ORIGIN."""
        has_locus = has_origin = has_end = has_full_row = False
        for line in self.infile:
            if "LOCUS" in line:
                has_locus = True
            if "ORIGIN" in line:
                has_origin = True
            elif has_origin and len(line.split()) >= 7:
                # presumably a full sequence row: position + six residue groups
                has_full_row = True
            if "//" in line:
                has_end = True
        if not (has_locus and has_origin and has_end and has_full_row):
            sys.exit(self._NOT_GENBANK)
        return "ok"

    def _check_phylip(self):
        """Require a 2- or 3-field header whose taxon count is an integer > 1."""
        header = self.infile[0].split() if self.infile else []
        if len(header) not in (2, 3):
            sys.exit(self._NOT_PHYLIP)
        try:
            ntax = int(header[0])
        except ValueError:
            sys.exit(self._NOT_PHYLIP)
        if ntax <= 1:
            sys.exit(self._ONE_SEQUENCE)
        return "ok"

    def _check_nexus(self):
        """Require '#NEXUS', a DATA block and a MATRIX; reject NTAX <= 1."""
        if not self.infile or "#NEXUS" not in self.infile[0]:
            sys.exit(self._NOT_NEXUS)
        has_begin = has_matrix = False
        ntax = None
        for line in self.infile:
            low = line.lower()
            if "begin data;" in low:
                has_begin = True
            if "matrix" in low:
                has_matrix = True
            if "ntax" in low:
                # Accept "NTAX=3", "NTAX=3;" anywhere on the line.
                for token in low.replace(";", " ").split():
                    if token.startswith("ntax="):
                        try:
                            ntax = int(token[5:])
                        except ValueError:
                            pass
        if not (has_begin and has_matrix):
            sys.exit(self._NOT_NEXUS)
        # This check used to sit after a return and never ran.
        if ntax is not None and ntax <= 1:
            sys.exit(self._ONE_SEQUENCE)
        return "ok"

    def single(self):
        """Validate a single-sequence FASTA or GenBank file."""
        if self.intype == "f":
            headers = self._count_fasta_headers()
            if headers > 1:
                sys.exit("The input file is a multi-fasta file. Please resubmit the job using the 'multi sequence' option")
            if headers == 0 or len(self.infile) < 2:
                sys.exit(self._NOT_FASTA)
            return "ok"
        if self.intype == "g":
            return self._check_genbank()
        return None

    def multi(self):
        """Validate a multi-sequence PHYLIP, NEXUS, FASTA or GenBank file."""
        if self.intype == "p":
            return self._check_phylip()
        if self.intype == "n":
            return self._check_nexus()
        if self.intype == "f":
            headers = self._count_fasta_headers()
            if headers == 1:
                sys.exit("The input file is a single-fasta file. Please resubmit the job using the 'single sequence' option")
            if headers == 0 or len(self.infile) < 4:
                sys.exit(self._NOT_FASTA)
            return "ok"
        if self.intype == "g":
            return self._check_genbank()
        return None


def main(input, output, inouttype, type):
    """Validate the input file, run the requested conversion, then close *output*.

    Args:
        input: the input file as a list of lines.
        output: writable stream receiving the converted text; it is closed
            before returning.
        inouttype: conversion code "<in>-<out>", e.g. "f-p".
        type: "single" (only f-g and g-f are offered) or "multi".
    """
    checker = check_fileformat(inouttype, input)
    if type == "single":
        verdict = checker.single()
        offered = ("f-g", "g-f")
    elif type == "multi":
        verdict = checker.multi()
        offered = ("f-g", "g-f", "f-p", "f-n", "p-f", "p-n", "n-p", "n-f")
    else:
        verdict, offered = None, ()

    if verdict == "ok":
        converter = convertitori(input, inouttype, type, output)
        if inouttype in offered:
            # "f-p" -> converter.fp(), "g-f" -> converter.gf(), ...
            getattr(converter, inouttype.replace("-", ""))()
    output.close()

if __name__ == "__main__":
    # usage: script INPUT OUTPUT INOUTTYPE single|multi
    # Explicit check instead of an IndexError traceback on missing arguments.
    if len(sys.argv) < 5:
        sys.exit("usage: %s INPUT OUTPUT INOUTTYPE single|multi" % sys.argv[0])
    # Local names avoid shadowing the builtins `input` and `type`.
    with open(sys.argv[1], "r") as source:
        lines = source.readlines()
    # Output is appended to, as before.  main() closes it; the with-block
    # also closes it when a format check calls sys.exit().
    with open(sys.argv[2], "a") as destination:
        main(lines, destination, sys.argv[3], sys.argv[4])
