# check dependencies
import sys

# Probe every third-party dependency up front so that the user gets a single
# readable message naming everything that is missing, rather than a traceback
# for the first failed import only.
missing_deps = []
try:
    import pysam
except ImportError:
    missing_deps.append('pysam')

try:
    from matplotlib import pyplot as plt
except ImportError:
    missing_deps.append('matplotlib')
else:
    # Only reachable when matplotlib itself imported successfully.
    from matplotlib import gridspec
    from matplotlib.font_manager import FontProperties

if missing_deps:
    sys.exit('''\nCould not import the following modules. Please check if they
    (or their dependencies) are installed correctly.
    \n\n{}\n'''.format('\n'.join(missing_deps)))


import os
import re
import config
import pysam
# import timeit
import shutil
import logging
import tempfile
import argparse
import subprocess

import core
# Module-wide settings object; the production configuration is the default.
# plot_profile() reads CONFIG.DATA_DIR to locate the HTML template
# ('riboplot.html') and the 'css' directory that are copied into the output.
CONFIG = config.ProductionConfig()


def get_start_stops(transcript_sequence, start_codons=None, stop_codons=None):
    """Return start and stop codon positions for all three reading frames.

    Positions are 1-based and refer to the first nucleotide of the codon.
    A codon starting at position 1, 4, 7, ... is in frame 1; at 2, 5, 8, ...
    in frame 2; at 3, 6, 9, ... in frame 3.

    Every occurrence is reported, including occurrences that overlap each
    other (possible with custom codon lists, e.g. 'ATA' in 'ATATA').

    Args:
        transcript_sequence: Nucleotide sequence to scan. Matching is
            case-sensitive.
        start_codons: Start codon patterns; defaults to ['ATG'] when empty
            or None. Entries are combined into one regular expression.
        stop_codons: Stop codon patterns; defaults to ['TAA', 'TAG', 'TGA']
            when empty or None.

    Returns:
        dict mapping frame (1, 2, 3) to {'starts': [...], 'stops': [...]},
        each list holding positions in ascending order.

    """
    if not start_codons:
        start_codons = ['ATG']
    if not stop_codons:
        stop_codons = ['TAA', 'TAG', 'TGA']

    seq_frames = {frame: {'starts': [], 'stops': []} for frame in (1, 2, 3)}

    for codons, positions in ((start_codons, 'starts'),
                              (stop_codons, 'stops')):
        # A zero-width lookahead does not consume the codon, so the scan
        # resumes one nucleotide later and overlapping hits are not skipped.
        pat = re.compile('(?=(?:{}))'.format('|'.join(codons)))

        for m in pat.finditer(transcript_sequence):
            offset = m.start()  # 0-based index of the codon's first base
            # offset % 3 is 0, 1, 2 for frames 1, 2, 3; reported positions
            # are 1-based, hence offset + 1.
            seq_frames[offset % 3 + 1][positions].append(offset + 1)
    return seq_frames


def get_rna_counts(rna_file, transcript_name):
    """Get coverage for a given RNA BAM file, return read counts.

    Runs ``bedtools genomecov -ibam <rna_file> -bg`` into a temporary file
    and collects the lines whose first column equals ``transcript_name``.

    Args:
        rna_file: Path to the RNA-Seq BAM file.
        transcript_name: Name to match against the first column of the
            coverage output.

    Returns:
        dict mapping position (second column + 1) to the count in the
        fourth column.

    Raises:
        OSError: If bedtools cannot be executed or ``rna_file`` does not
            exist.
        core.RNACountsError: If bedtools exits with a non-zero status.

    """
    try:
        subprocess.check_output(['bedtools', '--version'])
    except OSError:
        logging.error('Could not find bedtools in PATH. bedtools is '
                      'required for generating RNA coverage plot. ')
        raise
    logging.debug('Get RNA coverage for transcript using bedtools')
    # check if the RNA file exists
    if not os.path.exists(rna_file):
        msg = 'RNA-Seq BAM file "{}" does not exist'.format(rna_file)
        logging.error(msg)
        raise OSError(msg)
    rna_counts = {}

    # delete=False because the file is re-opened by name below; the outer
    # ``finally`` guarantees it is removed on every exit path.
    cov_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        try:
            subprocess.check_call(
                ['bedtools', 'genomecov', '-ibam', rna_file,
                 '-bg'], stdout=cov_file)
        except subprocess.CalledProcessError as e:
            # needs testing
            raise core.RNACountsError('Could not generate coverage for RNA BAM file. \n{}\n'.format(e))
        finally:
            cov_file.close()
        logging.debug('Finished generating RNA coverage')
        logging.debug('Processing coverage file')
        with open(cov_file.name) as coverage:
            for line in coverage:
                fields = line.split()
                # Skip blank lines and records for other transcripts.
                if not fields or fields[0] != transcript_name:
                    continue
                # NOTE(review): only the start column of each record is
                # used; if bedtools merges adjacent equal-coverage positions
                # into one interval, the remaining positions of that
                # interval get no entry - confirm this is intended.
                rna_counts[int(fields[1]) + 1] = int(fields[3])
    finally:
        os.unlink(cov_file.name)
    logging.debug('Finished processing coverage file')
    return rna_counts


def plot_profile(ribo_counts, transcript_name, transcript_length,
                 start_stops, read_length=None, read_offset=None, rna_counts=None,
                 html_file='index.html', output_path='output'):
    """Plot read counts (in all 3 frames) and RNA coverage if provided for a
    single transcript.

    The figure has one large panel with the Ribo-Seq counts (plus RNA-Seq
    counts on a twin y-axis when given) and three narrow panels, one per
    frame, marking start and stop codon positions.

    Args:
        ribo_counts: Mapping of position -> mapping of frame (1, 2, 3) ->
            count.
        transcript_name: Used in the plot title and substituted into the
            HTML template.
        transcript_length: Upper x-axis limit and value shown in the x label.
        start_stops: Mapping of frame -> {'starts': [...], 'stops': [...]}
            (the structure returned by get_start_stops).
        read_length: If truthy, shown in the Ribo-Seq y-axis label.
        read_offset: If truthy, added to every Ribo-Seq x position.
        rna_counts: Optional mapping of position -> RNA-Seq count.
        html_file: Name of the HTML file, joined onto ``output_path``.
        output_path: Directory receiving ribograph.svg, ribograph.png, the
            HTML file and a ``css`` sub-directory; created if missing.

    """
    gs = gridspec.GridSpec(4, 1, height_ratios=[7, 1, 1, 1])
    font_xsmall = {'family': 'sans-serif', 'color': '#555555',
                   'weight': 'normal', 'size': 'x-small'}

    ax1 = plt.subplot(gs[0])

    ax_rna = None
    if rna_counts:
        ax_rna = ax1.twinx()
        ax_rna.set_ylabel('RNA-Seq count', fontdict=font_xsmall, labelpad=10)
        # list() so that dict views (Python 3) are accepted as bar data.
        ax_rna.bar(list(rna_counts.keys()), list(rna_counts.values()),
                   edgecolor='#e6e6e6', label='RNA')
        ax_rna.set_zorder(1)

    plt.title('Transcript - {}'.format(transcript_name),
              fontdict={'family': 'sans-serif', 'color': '#333333',
                        'weight': 'normal', 'size': 'x-small'}, y=1.06)
    if read_length:
        ax1.set_ylabel('Ribo-Seq count ({}-mer)'.format(read_length),
                       fontdict=font_xsmall, labelpad=10)
    else:
        ax1.set_ylabel('Ribo-Seq count', fontdict=font_xsmall, labelpad=10)

    # Regroup the counts by frame. Only the first frame with a non-zero count
    # is kept for each position (presumably a position carries reads in a
    # single frame only - verify against the producer of ribo_counts).
    frame_counts = {1: {}, 2: {}, 3: {}}
    for position, per_frame in ribo_counts.items():
        for fr in (1, 2, 3):
            if per_frame[fr] > 0:
                frame_counts[fr][position] = per_frame[fr]
                break

    cnts = [count for counts in frame_counts.values()
            for count in counts.values()]
    # Leave 20% headroom above the tallest bar; fall back to 1 when the
    # transcript has no reads so that max() is never given an empty list.
    y_max = int(round(max(cnts) * 1.2)) if cnts else 1
    ax1.set_ylim(0.0, y_max)
    # Draw the Ribo-Seq panel above the RNA twin axis and make its
    # background transparent so the RNA bars remain visible.
    ax1.set_zorder(2)
    ax1.patch.set_facecolor('none')

    for frame, color in ((1, 'tomato'), (2, 'limegreen'), (3, 'deepskyblue')):
        positions = list(frame_counts[frame].keys())
        if read_offset:
            positions = [pos + read_offset for pos in positions]
        ax1.bar(positions, list(frame_counts[frame].values()), edgecolor=color)

    ax2 = plt.subplot(gs[1], sharex=ax1)
    ax3 = plt.subplot(gs[2], sharex=ax1)
    ax4 = plt.subplot(gs[3], sharex=ax1)

    axes = [ax1]
    if ax_rna:
        axes.append(ax_rna)

    fp = FontProperties(size='5')
    for axis in axes:
        for side in ('top', 'right', 'bottom', 'left'):
            axis.spines[side].set_color('#f7f7f7')
        for item in (axis.get_xticklabels() + axis.get_yticklabels()):
            item.set_fontproperties(fp)
            item.set_color('#555555')

    for axis, frame in ((ax2, 1), (ax3, 2), (ax4, 3)):
        # Grey panel background, set on the patch instead of through the
        # 'axisbg' keyword, which newer Matplotlib releases no longer accept.
        axis.patch.set_facecolor('#c6c6c6')
        for side in ('top', 'right', 'bottom', 'left'):
            axis.spines[side].set_color('#C6C6C6')
        for item in (axis.get_xticklabels()):
            item.set_fontproperties(fp)
            item.set_color('#555555')
        axis.set_ylim(0, 0.2)
        axis.set_xlim(0, transcript_length)
        # broken_barh takes (xmin, width) pairs; each codon mark is 1 wide.
        starts = [(item, 1) for item in start_stops[frame]['starts']]
        stops = [(item, 1) for item in start_stops[frame]['stops']]
        start_colors = ['#ffffff' for item in starts]
        axis.broken_barh(starts, (0.11, 0.2),
                         facecolors=start_colors, edgecolors=start_colors, label='start', zorder=5)
        stop_colors = ['#666666' for item in stops]
        axis.broken_barh(stops, (0, 0.2), facecolors=stop_colors,
                         edgecolors=stop_colors, label='stop', zorder=5)
        axis.set_ylabel('{}'.format(frame),
                        fontdict={'family': 'sans-serif', 'color': '#555555',
                                  'weight': 'normal', 'size': '6'},
                        rotation='horizontal', labelpad=10, verticalalignment='center')
        axis.tick_params(top=False, left=False, right=False, labeltop=False,
                         labelleft=False, labelright=False, direction='out')
    plt.xlabel('Transcript length ({} nt)'.format(transcript_length),
               fontdict=font_xsmall, labelpad=10)

    # makedirs (rather than mkdir) so nested output paths can be created.
    if not os.path.exists(output_path):
        os.makedirs(output_path)
    plt.savefig(os.path.join(output_path, 'ribograph.svg'))
    plt.savefig(os.path.join(output_path, 'ribograph.png'), dpi=300)

    # Fill the HTML template with the transcript name.
    with open(os.path.join(CONFIG.DATA_DIR, 'riboplot.html')) as g, open(os.path.join(output_path, html_file), 'w') as h:
        h.write(g.read().format(transcript_name=transcript_name))

    css_dir = os.path.join(output_path, 'css')
    if not os.path.exists(css_dir):
        os.mkdir(css_dir)

    # Copy the bundled stylesheets next to the generated HTML page.
    css_data_dir = os.path.join(CONFIG.DATA_DIR, 'css')
    for fname in os.listdir(css_data_dir):
        shutil.copy(os.path.join(css_data_dir, fname), os.path.join(css_dir, fname))


def create_parser():
    """Build and return the command-line parser for the plot tool.

    The three mandatory inputs (--ribo_file, --transcriptome_fasta,
    --transcript_name) are registered in a dedicated "required arguments"
    group; the remaining options are optional and carry their defaults.

    """
    cli = argparse.ArgumentParser(
        description='Plot and output read counts for a single transcript',
        prog='ribograph.plot')

    # Mandatory inputs: (flag, help text, extra keyword arguments).
    mandatory_specs = (
        ('--ribo_file', 'Ribo-Seq alignment file in BAM format', {}),
        ('--transcriptome_fasta', 'FASTA format file of the transcriptome', {}),
        ('--transcript_name', 'Transcript name', {'metavar': 'TEXT'}),
    )
    mandatory = cli.add_argument_group('required arguments')
    for flag, help_text, extra in mandatory_specs:
        mandatory.add_argument(flag, help=help_text, required=True, **extra)

    # Optional inputs controlling which reads are plotted.
    cli.add_argument('--rna_file', help='RNA-Seq alignment file (BAM)')
    cli.add_argument(
        '--read_length', type=int, metavar='INTEGER',
        help='Read length to consider (default: %(default)s)')
    cli.add_argument(
        '--read_offset', type=int, metavar='INTEGER', default=0,
        help='Read offset (default: %(default)s)')

    # Output locations and verbosity.
    cli.add_argument(
        '--html_file', default='index.html',
        help='Output file for results (HTML)')
    cli.add_argument(
        '--output_path', default='output',
        help='Files are saved in this directory')
    cli.add_argument(
        '--debug', action='store_true',
        help='Flag. Produce debug output')

    return cli


def main():
    """Start program.

    Parses the command line, configures logging, and then:

    1. checks the Ribo-Seq BAM file via core.check_bam_file,
    2. reads the transcript sequence from the transcriptome FASTA file,
    3. locates start/stop codons in all three frames,
    4. collects RNA-Seq coverage when --rna_file is given,
    5. obtains Ribo-Seq counts via core.get_ribo_counts,
    6. writes RiboCounts.csv (one row per transcript position) and
    7. renders the plot and HTML page into the output directory.

    Exits with a non-zero status when the fetched transcript sequence is
    empty or when get_rna_counts raises OSError.

    """
    args = create_parser().parse_args()

    level = logging.DEBUG if args.debug else logging.INFO
    logging.basicConfig(format='%(asctime)s: %(levelname)s %(message)s',
                        level=level, stream=sys.stdout, datefmt='%d/%m/%Y %I:%M:%S %p')

    logging.debug('Start')
    logging.debug('Supplied Arguments')
    logging.debug('\n{}'.format('\n'.join(['{:<20}: {}'.format(k, v) for k, v in vars(args).items()])))

    ribo_file = args.ribo_file
    rna_file = args.rna_file
    transcript_name = args.transcript_name
    transcriptome_fasta = args.transcriptome_fasta
    read_length = args.read_length
    read_offset = args.read_offset
    output_path = args.output_path

    logging.debug('Checking if BAM file is indexed...')
    core.check_bam_file(ribo_file)

    logging.debug('Get transcript information...')
    transcripts = pysam.FastaFile(transcriptome_fasta)
    try:
        sequence = transcripts.fetch(transcript_name)
    finally:
        # Release the FASTA handle as soon as the sequence is in memory.
        transcripts.close()
    length = len(sequence)

    if not length:
        logging.error('Transcript "{}" does not exist in transcriptome '
                      'FASTA file "{}"'.format(transcript_name, os.path.basename(transcriptome_fasta)))
        sys.exit(1)

    logging.debug('Get start/stop positions in transcript (3 frames)...')
    codon_positions = get_start_stops(sequence)

    mrna_counts = {}
    if rna_file:
        try:
            mrna_counts = get_rna_counts(rna_file, transcript_name)
        except OSError as e:
            sys.exit(e)
    else:
        logging.debug('No RNA-Seq data provided. Not generating coverage')

    logging.debug('Get ribo-seq read counts and total reads in Ribo-Seq...')
    bam_fileobj = pysam.AlignmentFile(ribo_file, 'rb')
    try:
        ribo_counts, total_reads = core.get_ribo_counts(bam_fileobj, transcript_name, read_length, read_offset)
    finally:
        bam_fileobj.close()

    # makedirs (rather than mkdir) so nested output paths can be created.
    if not os.path.exists(output_path):
        os.makedirs(output_path)

    logging.debug('Writing counts for {}'.format(transcript_name))
    with open(os.path.join(output_path, 'RiboCounts.csv'), 'w') as f:
        f.write('"Position","Frame 1","Frame 2","Frame 3"\n')

        # One row per position; positions without counts are written as 0.
        for pos in range(1, length + 1):
            if pos in ribo_counts:
                f.write('{0},{1},{2},{3}\n'.format(
                    pos, ribo_counts[pos][1], ribo_counts[pos][2], ribo_counts[pos][3]))
            else:
                f.write('{0},{1},{2},{3}\n'.format(pos, 0, 0, 0))

    logging.debug('Generating plot...')
    # plot_profile joins html_file onto output_path itself. Passing the
    # already-joined path made a relative output directory appear twice
    # (e.g. output/output/index.html) and the HTML write failed.
    plot_profile(ribo_counts, transcript_name, length,
                 codon_positions, read_length, read_offset, mrna_counts,
                 html_file=args.html_file,
                 output_path=output_path)
    logging.debug('Finished')


# Run the command-line workflow only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    # To time a single run instead (requires the commented-out timeit import):
    # print timeit.timeit('main()', number=1, setup='from __main__ import main')
    main()
