# -*- coding: utf-8 -*-
import sys

# check dependencies
try:
    import pysam
except ImportError as e:
    # Abort with a readable message instead of a traceback when the
    # third-party dependency is missing. The separator after "with" keeps
    # the error text from being glued onto the preceding word.
    sys.exit('Could not import the "pysam" module\n\nImporting failed with: '
             '{}\n\n'.format(e))

import os
# import timeit
import shutil
import zipfile
import logging
import argparse
from datetime import datetime

import ribocore
import config

# Runtime configuration; the production settings are used by default.
# CONFIG.DATA_DIR is the directory the HTML report templates and the
# css/js assets are read from.
CONFIG = config.ProductionConfig()


def create_parser():
    """Build the command-line argument parser for ribocount.

    Returns:
        argparse.ArgumentParser: parser with the required ``-b``/``-f``
        inputs plus optional read length, read offset, output locations and
        a debug flag.
    """
    parser = argparse.ArgumentParser(
        prog='ribocount.py', description='Output read counts for all transcripts')

    # required arguments
    required = parser.add_argument_group('required arguments')
    required.add_argument('-b', '--ribo_file', help='Ribo-Seq alignment file in BAM format', required=True)
    required.add_argument('-f', '--transcriptome_fasta', help='FASTA format file of the transcriptome', required=True)

    # optional arguments
    # --read_length has no default (None); the script treats a missing value
    # as "all read lengths", so state that rather than "default: None".
    parser.add_argument('-l', '--read_length', help='Read length to consider (default: all read lengths)',
                        metavar='INTEGER', type=int)
    parser.add_argument('-s', '--read_offset', help='Read offset (default: %(default)s)',
                        metavar='INTEGER', type=int, default=0)
    parser.add_argument('-m', '--html_file', help='Output file for results (HTML)', default='output/index.html')
    parser.add_argument('-o', '--output_path', help='Files are saved in this directory', default='output')
    parser.add_argument('-d', '--debug', help='Flag. Produce debug output', action='store_true')

    return parser


def _ensure_dir(path):
    """Create directory *path*, including missing parents, unless it exists.

    Unlike ``os.mkdir`` this also works for nested paths such as
    ``results/run1`` whose parent directories do not exist yet.
    """
    if not os.path.isdir(path):
        os.makedirs(path)


def _write_counts_csv(csv_path, rp_counts, transcript_length):
    """Write per-position, per-frame read counts for one transcript as CSV.

    One row is written for every position 1..*transcript_length*; positions
    absent from *rp_counts* are written with zero counts.

    :param csv_path: destination file path.
    :param rp_counts: mapping of 1-based position -> counts indexable by
        frame 1, 2 and 3 (as returned by ``ribocore.get_ribo_counts``;
        presumably a dict of dicts -- confirm against ribocore).
    :param transcript_length: number of positions in the transcript.
    """
    with open(csv_path, 'w') as f:
        f.write('"Position","Frame 1","Frame 2","Frame 3"\n')
        for pos in range(1, transcript_length + 1):
            if pos in rp_counts:
                frames = rp_counts[pos]
                f.write('{0},{1},{2},{3}\n'.format(pos, frames[1], frames[2], frames[3]))
            else:
                f.write('{0},{1},{2},{3}\n'.format(pos, 0, 0, 0))


def _zip_directory(src_dir, zip_path):
    """Archive every file below *src_dir* into a new zip file at *zip_path*.

    Member names are stored relative to the parent of *src_dir*, so the
    archive extracts into one folder named after *src_dir*. Explicit paths
    are used instead of ``os.chdir`` so the process working directory, and
    therefore any relative path used later, is left untouched.
    """
    base_dir = os.path.dirname(os.path.abspath(src_dir))
    with zipfile.ZipFile(zip_path, 'w') as zipf:
        for root, _dirs, files in os.walk(src_dir):
            for name in files:
                file_path = os.path.join(root, name)
                zipf.write(file_path, os.path.relpath(os.path.abspath(file_path), base_dir))


if __name__ == '__main__':
    parsed = create_parser()
    args = parsed.parse_args()

    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(format='%(levelname)s: %(message)s',
                        level=level, stream=sys.stdout, datefmt='%d/%m/%Y %I:%M:%S %p')

    start = datetime.now()
    logging.debug('Supplied arguments\n{}'.format('\n'.join(['{:<20}: {}'.format(k, v) for k, v in vars(args).items()])))

    (ribo_file, transcriptome_fasta, read_length, read_offset, output_path) = (
        args.ribo_file, args.transcriptome_fasta, args.read_length, args.read_offset, args.output_path)

    logging.info('Checking if required arguments are valid...')
    ribocore.check_required_arguments(ribo_file=ribo_file, transcriptome_fasta=transcriptome_fasta)
    logging.info('Done')

    logging.info('Checking if optional arguments are valid...')
    ribocore.check_optional_arguments(ribo_file=ribo_file, read_length=read_length, read_offset=read_offset)
    logging.info('Done')

    # zip_dir contents are written here and a zip archive is created from
    # this directory at the end. Creating csv_dir also creates output_path
    # and zip_dir.
    zip_dir = os.path.join(output_path, 'ribocount_output')
    csv_dir = os.path.join(zip_dir, 'csv')
    _ensure_dir(csv_dir)

    logging.info('Get RiboSeq read counts for all transcripts in FASTA')
    count = 0
    table_rows = []
    # try/finally guarantees both handles are closed even if counting fails.
    bam_fileobj = pysam.AlignmentFile(ribo_file, 'rb')
    try:
        fasta_file = pysam.FastaFile(transcriptome_fasta)
        try:
            for transcript in fasta_file.references:
                rp_counts, rp_reads = ribocore.get_ribo_counts(bam_fileobj, transcript, read_length)
                if not rp_reads:  # no reads for this transcript. skip.
                    continue

                logging.debug('Writing counts for {}'.format(transcript))
                count += 1
                csv_file = 'RiboCounts{}.csv'.format(count)
                _write_counts_csv(os.path.join(csv_dir, csv_file), rp_counts,
                                  len(fasta_file[transcript]))
                # Alternate row styling: odd-numbered rows get class="odd".
                row_open = '<tr>' if count % 2 == 0 else '<tr class="odd">'
                table_rows.append(
                    row_open + '<td>{0}</td><td>{1}</td><td>{2}</td><td><a href="csv/{3}">{3}</a></td></tr>'.format(
                        count, transcript, rp_reads, csv_file))
        finally:
            fasta_file.close()
    finally:
        bam_fileobj.close()
    table_content = ''.join(table_rows)

    # Only for display in logs/HTML. Kept separate from read_length so the
    # "no transcripts found" check below still knows whether a specific
    # read length was requested.
    display_length = read_length if read_length else 'All'

    duration = str(datetime.now() - start).split('.')[0]
    logging.info('Done')

    if not count:
        if read_length:
            logging.info('No transcripts found for read length {}'.format(read_length))
        else:
            logging.info('No transcripts found')
    else:
        logging.info('Time taken for generating counts for {0} transcripts: {1}, footprint '
                     'length: {2}'.format(count, duration, display_length))

        with open(os.path.join(CONFIG.DATA_DIR, 'ribocount.html')) as g,\
                open(os.path.join(zip_dir, 'index.html'), 'w') as h:
            h.write(g.read().format(count=count, length=display_length, table_content=table_content, duration=duration))

        # Copy the static report assets next to index.html inside zip_dir.
        for asset in ('css', 'js'):
            asset_dir = os.path.join(zip_dir, asset)
            _ensure_dir(asset_dir)
            asset_data_dir = os.path.join(CONFIG.DATA_DIR, asset)
            for fname in os.listdir(asset_data_dir):
                shutil.copy(os.path.join(asset_data_dir, fname),
                            os.path.join(asset_dir, fname))

        logging.info('Creating zip file')
        # Explicit paths instead of os.chdir(): changing directory would make
        # the relative args.html_file path below resolve against the wrong
        # directory whenever output_path is '.', absolute, or nested.
        _zip_directory(zip_dir, os.path.join(output_path, 'ribocount_output.zip'))
        shutil.rmtree(zip_dir)
        logging.debug('Writing HTML report')

        html_dir = os.path.dirname(args.html_file)
        if html_dir:
            _ensure_dir(html_dir)
        with open(os.path.join(CONFIG.DATA_DIR, 'ribocount_index.html')) as j, open(args.html_file, 'w') as k:
            k.write(j.read().format(count=count, read_length=display_length))
    logging.info('Finished')
