view mismatch_frequencies.py @ 2:2974c382105c draft default tip

planemo upload for repository https://github.com/ARTbio/tools-artbio/tree/master/tools/mismatch_frequencies commit 10a7e3877c2568d9c23de53fc97dc1c902ff0524-dirty
author mvdbeek
date Sat, 22 Dec 2018 04:15:47 -0500
parents 77de5fc623f9
children
line wrap: on
line source

import re
import string
import pysam
import matplotlib
import pandas as pd
from collections import defaultdict
from collections import OrderedDict
import argparse
import itertools

matplotlib.use('pdf')
import matplotlib.pyplot as plt  # noqa: E402


class MismatchFrequencies:
    '''Collect mismatch counts from one SAM/BAM alignment file.

    One instance is created per alignment file. When ``alignment_file`` is
    given, the file is opened with pysam, scanned once, and the nested counts
    are stored under ``result_dict[name]`` with the layout
    ``{readlength: {reference_name: {'total valid reads': n, 'AG': n, ...}}}``.
    '''

    def __init__(self, result_dict=None, alignment_file=None, name="name", minimal_readlength=21,
                 maximal_readlength=21,
                 number_of_allowed_mismatches=1,
                 ignore_5p_nucleotides=0,
                 ignore_3p_nucleotides=0,
                 possible_mismatches=None):
        '''Store the filter settings and, if a file is given, scan it.

        result_dict -- mapping that receives the counts under ``name``; a new
            dict is created when omitted. Pass the same mapping to several
            instances to gather all libraries in one place.
        alignment_file -- path of a SAM/BAM file; nothing is read when falsy.
        name -- key under which this file's counts are stored.
        minimal_readlength / maximal_readlength -- inclusive read length bounds.
        number_of_allowed_mismatches -- upper bound for the NM tag of a read.
        ignore_5p_nucleotides / ignore_3p_nucleotides -- number of read
            positions at each end in which mismatches are not counted.
        possible_mismatches -- two-letter codes (reference base followed by
            read base) that are counted; defaults to all 12 substitutions.
        '''
        # Build fresh containers per instance: a mutable default argument
        # would be shared by every instance created without that argument.
        if result_dict is None:
            result_dict = {}
        if possible_mismatches is None:
            possible_mismatches = [
                'AC', 'AG', 'AT',
                'CA', 'CG', 'CT',
                'GA', 'GC', 'GT',
                'TA', 'TC', 'TG'
            ]

        self.result_dict = result_dict
        self.name = name
        self.minimal_readlength = minimal_readlength
        self.maximal_readlength = maximal_readlength
        self.number_of_allowed_mismatches = number_of_allowed_mismatches
        self.ignore_5p_nucleotides = ignore_5p_nucleotides
        self.ignore_3p_nucleotides = ignore_3p_nucleotides
        self.possible_mismatches = possible_mismatches

        if alignment_file:
            self.pysam_alignment = pysam.Samfile(alignment_file)
            self.references = self.pysam_alignment.references  # names of fasta reference sequences
            result_dict[name] = self.get_mismatches(
                self.pysam_alignment,
                minimal_readlength,
                maximal_readlength,
                possible_mismatches
            )

    def get_mismatches(self, pysam_alignment, minimal_readlength,
                       maximal_readlength, possible_mismatches):
        '''Scan all reads and return the nested count dictionary.

        The result maps read length -> reference name -> counter name, where
        the counter name is 'total valid reads' or a mismatch code listed in
        ``possible_mismatches``. The containers are nested defaultdicts.
        '''
        def nested_counts():
            return defaultdict(nested_counts)

        len_dict = nested_counts()
        for alignedread in pysam_alignment:
            if not self.read_is_valid(alignedread, minimal_readlength, maximal_readlength):
                continue
            chromosome = pysam_alignment.getrname(alignedread.rname)
            counts = len_dict[int(alignedread.rlen)][chromosome]
            # .get() does not trigger the defaultdict factory, so a missing
            # counter simply starts at 0.
            counts['total valid reads'] = counts.get('total valid reads', 0) + 1
            if not self.read_has_mismatch(alignedread, self.number_of_allowed_mismatches):
                continue
            # The MD tag is only needed (and only read) for reads that pass
            # the NM filter.
            (ref_base, mismatch_base) = self.read_to_reference_mismatch(
                alignedread.opt('MD'), alignedread.seq, alignedread.is_reverse)
            if not ref_base:
                continue
            for reference, observed in zip(ref_base, mismatch_base):
                mismatch = reference + observed
                if mismatch in possible_mismatches:
                    counts[mismatch] = counts.get(mismatch, 0) + 1
        return len_dict

    def read_is_valid(self, read, min_readlength, max_readlength):
        '''Return True for mapped reads whose length lies within
        [min_readlength, max_readlength]. No CIGAR-based filtering is done
        here.'''
        if read.is_unmapped:
            return False
        if read.rlen < min_readlength:
            return False
        if read.rlen > max_readlength:
            return False
        return True

    def read_has_mismatch(self, read, number_of_allowed_mismatches=1):
        '''Return True when the NM tag of the read is at least 1 and at most
        number_of_allowed_mismatches.'''
        NM = read.opt('NM')
        if NM < 1:  # filter out reads with no mismatch
            return False
        if NM > number_of_allowed_mismatches:  # filter out reads with too many mismatches
            return False
        return True

    def mismatch_in_allowed_region(self, readseq, mismatch_position):
        '''Tell whether a mismatch lies outside the ignored read ends.

        mismatch_position is the 0-based position of the mismatch counted
        from the 5' end of the read.

        >>> M = MismatchFrequencies()
        >>> readseq = 'AAAAAA'
        >>> mismatch_position = 2
        >>> M.mismatch_in_allowed_region(readseq, mismatch_position)
        True
        >>> M = MismatchFrequencies(ignore_3p_nucleotides=2, ignore_5p_nucleotides=2)
        >>> readseq = 'AAAAAA'
        >>> mismatch_position = 1
        >>> M.mismatch_in_allowed_region(readseq, mismatch_position)
        False
        >>> readseq = 'AAAAAA'
        >>> mismatch_position = 4
        >>> M.mismatch_in_allowed_region(readseq, mismatch_position)
        False
        '''
        five_p = self.ignore_5p_nucleotides
        three_p = self.ignore_3p_nucleotides
        if five_p <= 0 and three_p <= 0:
            return True
        position = mismatch_position + 1  # switch to counting from 1
        if position <= five_p:
            return False
        if position > len(readseq) - three_p:
            return False
        return True

    def read_to_reference_mismatch(self, MD, readseq, is_reverse):
        '''Derive reference and read bases of all mismatches from the MD tag.

        The MD tag describes mismatches without needing the genome sequence.
        Example: 3C0G2A6 means 3 matching bases, a mismatch where the
        reference base is C, 0 matching bases, a mismatch where the reference
        base is G, 2 matching bases, a mismatch where the reference base is A,
        and 6 matching bases.
        Suppose a reference 'CTTCGATAATCCTT'
                             |||  || ||||||
                 and a read 'CTTATATTATCCTT'.
        This situation is represented by the above MD tag.
        Given MD tag and read sequence this function returns the reference
        bases C, G and A, and the mismatched read bases A, T, T.

        Each mismatch is filtered on its own: it is dropped when the read base
        is N or when it lies in the ignored 5'/3' region of the read. For
        reverse-strand reads the position is counted from the other end of
        readseq and the returned strings are reverse complemented.

        Returns a tuple (reference_bases, mismatched_bases) of equal-length
        strings, or (None, None) when the MD tag contains '^' or when no
        mismatch is left after filtering.

        >>> M = MismatchFrequencies()
        >>> MD='3C0G2A6'
        >>> seq='CTTATATTATCCTT'
        >>> result=M.read_to_reference_mismatch(MD, seq, is_reverse=False)
        >>> result[0]=="CGA"
        True
        >>> result[1]=="ATT"
        True
        >>>
        '''
        if '^' in MD:
            # '^' introduces reference bases that are absent from the read;
            # read positions can then no longer be derived from MD alone.
            print('WARNING insertion detected, mismatch calling skipped for this read!!!')
            return (None, None)
        # NOTE(review): readseq is indexed directly with MD-derived offsets;
        # this presumably assumes no soft-clipped or inserted bases in the
        # read - confirm against the aligner settings.
        read_length = len(readseq)
        reference_bases = []
        mismatched_bases = []
        start_index = 0  # start of the digits that precede the next MD letter
        read_index = -1  # 0-based index in readseq of the previous mismatch
        # Any capital letter is matched as a reference base, so that e.g. an
        # N in the MD tag cannot end up inside the int() conversion below.
        for result in re.finditer('[A-Z]', MD):
            read_index += 1 + int(MD[start_index:result.start()] or 0)
            start_index = result.end()
            observed = readseq[read_index]
            if observed == 'N':
                continue
            if is_reverse:
                position_in_read = read_length - 1 - read_index
            else:
                position_in_read = read_index
            if not self.mismatch_in_allowed_region(readseq, position_in_read):
                continue
            reference_bases.append(result.group())
            mismatched_bases.append(observed)
        if not reference_bases:
            return (None, None)
        reference_base = ''.join(reference_bases)
        mismatched_base = ''.join(mismatched_bases)
        if is_reverse:
            reference_base = reverseComplement(reference_base)
            mismatched_base = reverseComplement(mismatched_base)
        return (reference_base, mismatched_base)


def reverseComplement(sequence):
    '''Return the reverse complement of a DNA sequence, in upper case.

    A, T, C, G and N are complemented; any other character is kept as is.
    Uses a plain dict lookup instead of string.maketrans, which only exists
    on Python 2.

    >>> reverseComplement('ATGC')=='GCAT'
    True
    >>>
    '''
    complement = {'A': 'T', 'T': 'A', 'C': 'G', 'G': 'C', 'N': 'N'}
    return ''.join(complement.get(base, base) for base in reversed(sequence.upper()))


def barplot(df, library, axes):
    '''Draw df as an unstacked bar chart on axes, titled with the library name.'''
    plot_title = 'Mismatch frequencies for {0}'.format(library)
    plot_options = {
        'kind': 'bar',
        'ax': axes,
        'subplots': False,
        'stacked': False,
        'legend': 'test',
        'title': plot_title,
    }
    df.plot(**plot_options)


def df_to_tab(df, output):
    '''Write df to output as tab separated text.'''
    field_separator = '\t'
    df.to_csv(output, sep=field_separator)


def reduce_result(df, possible_mismatches):
    '''Sum the count columns of the detailed table per library and read length.

    Returns {alignment_file: {readlength: {column_name: column_sum}}}, where
    the summed columns start at column 3 and span len(possible_mismatches)
    plus one further column. Every combination of the alignment files and
    read lengths present in df gets an entry.'''
    first_count_column = 3
    stop_column = first_count_column + len(possible_mismatches) + 1
    alignments = df['Alignment_file'].unique()
    readlengths = df['Readlength'].unique()
    reduced_dict = {}
    for library in alignments:
        rows_of_library = df[df['Alignment_file'] == library]
        for readlength in readlengths:
            rows_of_length = rows_of_library[rows_of_library['Readlength'] == readlength]
            column_sums = rows_of_length.iloc[:, first_count_column:stop_column].sum()
            reduced_dict.setdefault(library, {})[readlength] = column_sums.to_dict()
    return reduced_dict


def plot_result(reduced_dict, args):
    '''Draw one bar chart per library and save all of them to args.output_pdf.

    reduced_dict -- {library: {readlength: {column_name: value}}}; the
        'total aligned reads' row is removed before plotting.
    args -- object with an ``output_pdf`` attribute (target file name).

    The charts are laid out in a grid with 2 columns.
    '''
    names = reduced_dict.keys()
    # Floor division keeps the row count an integer on Python 3 as well;
    # add_subplot rejects a float.
    nrows = len(names) // 2 + 1
    fig = plt.figure(figsize=(16, 32))
    for i, library in enumerate(names):
        axes = fig.add_subplot(nrows, 2, i+1)
        library_dict = reduced_dict[library]
        df = pd.DataFrame(library_dict)
        df.drop(['total aligned reads'], inplace=True)
        barplot(df, library, axes)
        # NOTE(review): the values plotted here are whatever reduced_dict
        # holds; confirm that they match the normalisation named in this label.
        axes.set_ylabel('Mismatch count / all valid reads * readlength')
    fig.savefig(args.output_pdf, format='pdf')
    plt.close(fig)  # release the figure once it has been written


def format_result_dict(result_dict, chromosomes, possible_mismatches):
    '''Turn the nested count dictionary into a detailed pandas DataFrame.

    result_dict -- {library: {readlength: {chromosome: {counter: count}}}}
    chromosomes -- reference sequence names, one block of rows per name
    possible_mismatches -- mismatch codes, one count column per code

    One row is emitted for every chromosome / library / read length
    combination, using the read lengths found in any library (sorted).
    Missing counts are reported as 0. The last column holds the summed
    mismatch counts divided by (total aligned reads * read length).
    result_dict is not modified.
    '''
    header = (['Reference sequence', 'Alignment_file', 'Readlength'] +
              list(possible_mismatches) + ['total aligned reads'])
    libraries = list(result_dict)
    # Use the read lengths of all libraries, not only those of the first one,
    # so that no library loses rows.
    readlengths = sorted(set(
        readlength for library in libraries for readlength in result_dict[library]))
    result = []
    for chromosome in chromosomes:
        for library in libraries:
            for readlength in readlengths:
                # .get() avoids inserting empty entries into nested defaultdicts.
                counts = result_dict[library].get(readlength, {}).get(chromosome, {})
                line = [chromosome, library, readlength]
                line.extend(counts.get(mismatch, 0) for mismatch in possible_mismatches)
                line.append(counts.get(u'total valid reads', 0))
                result.append(line)
    df = pd.DataFrame(result, columns=header)
    last_column = 3 + len(possible_mismatches)
    df['mismatches/per aligned nucleotides'] = df.iloc[:, 3:last_column].sum(1)/(df.iloc[:, last_column] * df['Readlength'])
    return df


def setup_MismatchFrequencies(args):
    '''Build the keyword arguments for one MismatchFrequencies per input file.

    args -- parsed command line options (input, name, min, max, n_mm,
        five_p, three_p, possible_mismatches).

    Returns (kw_list, resultDict): a list with one keyword dictionary per
    alignment file, and the OrderedDict that all of them share as
    ``result_dict``.

    Raises ValueError when the number of names differs from the number of
    input files; pairing them with zip() alone would silently drop the
    surplus entries.
    '''
    alignment_files = args.input or []
    names = args.name or []
    if len(alignment_files) != len(names):
        raise ValueError(
            'Got {0} input file(s) but {1} name(s); exactly one name per input file is required'.format(
                len(alignment_files), len(names)))
    resultDict = OrderedDict()
    kw_list = [{'result_dict': resultDict,
                'alignment_file': alignment_file,
                'name': name,
                'minimal_readlength': args.min,
                'maximal_readlength': args.max,
                'number_of_allowed_mismatches': args.n_mm,
                'ignore_5p_nucleotides': args.five_p,
                'ignore_3p_nucleotides': args.three_p,
                'possible_mismatches': args.possible_mismatches}
               for alignment_file, name in zip(alignment_files, names)]
    return (kw_list, resultDict)


def nested_dict_to_df(dictionary):
    '''Convert {library: {readlength: {column: value}}} into a DataFrame.

    The result has one row per (library, readlength) pair, indexed by a
    two-level index named 'Library' and 'Readlength', and one column per
    key of the innermost dictionaries.
    '''
    # .items() instead of .iteritems(): available on Python 2 and Python 3.
    flattened = {(outerKey, innerKey): values
                 for outerKey, innerDict in dictionary.items()
                 for innerKey, values in innerDict.items()}
    df = pd.DataFrame.from_dict(flattened).transpose()
    df.index.names = ['Library', 'Readlength']
    return df


def run_MismatchFrequencies(args):
    '''Analyse every input file and return (resultDict, references).

    Each MismatchFrequencies instance stores its counts in the shared
    resultDict while it is constructed. The returned references are those
    of the first alignment file.'''
    kw_list, resultDict = setup_MismatchFrequencies(args)
    references_per_file = []
    for keyword_arguments in kw_list:
        analysis = MismatchFrequencies(**keyword_arguments)
        references_per_file.append(analysis.references)
    return (resultDict, references_per_file[0])


def main():
    '''Run the whole analysis using the module level ``args``.

    Writes the PDF plot and the summary table, plus the detailed table when
    args.expanded_output_tab is set, and returns the summary dictionary.'''
    mismatch_codes = args.possible_mismatches
    result_dict, references = run_MismatchFrequencies(args)
    expanded_df = format_result_dict(result_dict, references, mismatch_codes)
    reduced_dict = reduce_result(expanded_df, mismatch_codes)
    plot_result(reduced_dict, args)
    df_to_tab(nested_dict_to_df(reduced_dict), args.output_tab)
    if args.expanded_output_tab:
        df_to_tab(expanded_df, args.expanded_output_tab)
    return reduced_dict

if __name__ == "__main__":

    # Command line interface; main() reads the parsed options from the
    # module level name ``args``.
    parser = argparse.ArgumentParser(description='Produce mismatch statistics for BAM/SAM alignment files.')

    # Input files and their display names
    parser.add_argument(
        '--input',
        nargs='*',
        help='Input files in SAM/BAM format')
    parser.add_argument(
        '--name',
        nargs='*',
        help='Name for input file to display in output file. Should have same length as the number of inputs')

    # Output files
    parser.add_argument(
        '--output_pdf',
        help='Output filename for graph')
    parser.add_argument(
        '--output_tab',
        help='Output filename for table')
    parser.add_argument(
        '--expanded_output_tab',
        default=None,
        help='Output filename for table')

    # Counting options
    parser.add_argument(
        '--possible_mismatches',
        default=[
            'AC', 'AG', 'AT', 'CA', 'CG', 'CT', 'GA', 'GC', 'GT', 'TA', 'TC', 'TG'
        ],
        nargs='+',
        help='specify mismatches that should be counted for the mismatch frequency. The format is Reference base -> observed base, eg AG for A to G mismatches.')
    parser.add_argument(
        '--min', '--minimal_readlength',
        type=int,
        help='minimum readlength')
    parser.add_argument(
        '--max', '--maximal_readlength',
        type=int,
        help='maximum readlength')
    parser.add_argument(
        '--n_mm', '--number_allowed_mismatches',
        type=int,
        default=1,
        help='discard reads with more than n mismatches')
    parser.add_argument(
        '--five_p', '--ignore_5p_nucleotides',
        type=int,
        default=0,
        help='when calculating nucleotide mismatch frequencies ignore the first N nucleotides of the read')
    parser.add_argument(
        '--three_p', '--ignore_3p_nucleotides',
        type=int,
        default=1,
        help='when calculating nucleotide mismatch frequencies ignore the last N nucleotides of the read')

    # Example for interactive debugging:
    # args = parser.parse_args(['--input', '3mismatches_ago2ip_s2.bam', '3mismatches_ago2ip_ovary.bam','--possible_mismatches','AC','AG', 'CG', 'TG', 'CT','--name', 'Siomi1', 'Siomi2' , '--five_p', '3','--three_p','3','--output_pdf', 'out.pdf', '--output_tab', 'out.tab', '--expanded_output_tab', 'expanded.tab', '--min', '20', '--max', '22'])
    args = parser.parse_args()
    reduced_dict = main()