diff ecoli_serotyping/predictionFunctions.py @ 6:fe3ceb5c4214 draft

Uploaded
author jpetteng
date Fri, 05 Jan 2018 15:43:14 -0500
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/ecoli_serotyping/predictionFunctions.py	Fri Jan 05 15:43:14 2018 -0500
@@ -0,0 +1,231 @@
+#!/usr/bin/env python
+
+import json
+import logging
+import os
+from collections import defaultdict
+
+import pandas as pd
+
+LOG = logging.getLogger(__name__)
+
+"""
+    Serotype prediction for E. coli
+"""
+
+def predict_serotype(blast_output_file, ectyper_dict_file, predictions_file, detailed=False):
+    """Make serotype prediction for all genomes based on blast output
+    
+    Args:
+        blast_output_file(str):
+            blastn output with outfmt:
+                "6 qseqid qlen sseqid length pident sstart send sframe qcovhsp -word_size 11"
+    ectyper_dict_file(str):
+        mapping file used to associate allele id to allele informations
+    predictions_file(str):
+        csv file to store result
+    detailed(bool, optional):
+        whether to generate detailed output or not
+    
+    Returns:
+        predictions_file
+    """
+    basename, extension = os.path.splitext(predictions_file)
+    parsed_output_file = ''.join([basename, '_raw', extension])
+
+    LOG.info("Predicting serotype from blast output")
+    output_df = blast_output_to_df(blast_output_file)
+    ectyper_df = ectyper_dict_to_df(ectyper_dict_file)
+    # Merge output_df and ectyper_df
+    output_df = output_df.merge(ectyper_df, left_on='qseqid', right_on='name', how='left')
+    predictions_dict = {}
+    # Select individual genomes
+    output_df['genome_name'] = output_df['sseqid'].str.split('|').str[1]
+    # Initialize constants
+    gene_pairs = {'wzx':'wzy', 'wzy':'wzx', 'wzm':'wzt', 'wzt':'wzm'}
+    predictions_columns = ['O_prediction', 'O_info', 'H_prediction', 'H_info']
+    gene_list = ['wzx', 'wzy', 'wzm', 'wzt', 'fliC', 'fllA', 'flkA', 'flmA', 'flnA']
+    if detailed:
+        # Add gene lists for detailed output report
+        for gene in gene_list:
+            predictions_columns.append(gene)
+    for genome_name, per_genome_df in output_df.groupby('genome_name'):
+        # Make prediction for each genome based on blast output
+        predictions_dict[genome_name] = get_prediction(
+            per_genome_df, predictions_columns, gene_pairs, detailed)
+    predictions_df = pd.DataFrame(predictions_dict).transpose()
+    if predictions_df.empty:
+        predictions_df = pd.DataFrame(columns=predictions_columns)
+    predictions_df = predictions_df[predictions_columns]
+    store_df(output_df, parsed_output_file)
+    store_df(predictions_df, predictions_file)
+    LOG.info("Serotype prediction completed")
+    return predictions_file
+
+def get_prediction(per_genome_df, predictions_columns, gene_pairs, detailed, ):
+    """Make serotype prediction for single genomes based on blast output
+    
+    Args:
+        per_genome_df(DataFrame):
+            blastn outputs for the given genome
+        predictions_columns(dict):
+            columns to be filled by prediction function
+        gene_pairs(dict):
+            dict of pair genes used for paired logic
+        detailed(bool):
+            whether to generate detailed output or not
+    Return:
+        Prediction dictionary for the given genome
+    """
+    # Extract potential predictors
+    useful_columns = [
+        'gene', 'serotype', 'score', 'name', 'desc', 'pident', 'qcovhsp', 'qseqid', 'sseqid'
+    ]
+    per_genome_df = per_genome_df.sort_values(['gene', 'serotype', 'score'], ascending=False)
+    per_genome_df = per_genome_df[~per_genome_df.duplicated(['gene', 'serotype'])]
+    predictors_df = per_genome_df[useful_columns]
+    predictors_df = predictors_df.sort_values('score', ascending=False)
+    predictions = {}
+    for column in predictions_columns:
+        predictions[column] = None
+    for predicting_antigen in ['O', 'H']:
+        genes_pool = defaultdict(list)
+        for index, row in predictors_df.iterrows():
+            gene = row['gene']
+            if detailed:
+                predictions[gene] = True
+            if not predictions[predicting_antigen+'_prediction']:
+                serotype = row['serotype']
+                if serotype[0] is not predicting_antigen:
+                    continue
+                genes_pool[gene].append(serotype)
+                prediction = None
+                if len(serotype) < 1:
+                    continue
+                antigen = serotype[0].upper()
+                if antigen != predicting_antigen:
+                    continue
+                if gene in gene_pairs.keys():
+                    predictions[antigen+'_info'] = 'Only unpaired alignments found'
+                    # Pair gene logic
+                    potential_pairs = genes_pool.get(gene_pairs.get(gene))
+                    if potential_pairs is None:
+                        continue
+                    if serotype in potential_pairs:
+                        prediction = serotype
+                else:
+                    # Normal logic
+                    prediction = serotype
+                if prediction is None:
+                    continue
+                predictions[antigen+'_info'] = 'Alignment found'
+                predictions[predicting_antigen+'_prediction'] = prediction
+        # No alignment found
+        ## Make prediction based on non-paired gene
+        ##   if only one non-paired gene is avaliable
+        if predictions.get(predicting_antigen+'_prediction') is None:
+            if len(genes_pool) == 1:
+                serotypes = list(genes_pool.values())[0]
+                if len(serotypes) == 1:
+                    predictions[antigen+'_info'] = 'Lone unpaired alignment found'
+                    predictions[predicting_antigen+'_prediction'] = serotypes[0]
+    return predictions
+
+def blast_output_to_df(blast_output_file):
+    '''Convert raw blast output file to DataFrame
+    Args:
+        blast_output_file(str): location of blast output
+    
+    Returns:
+        DataFrame:
+            DataFrame that contains all informations from blast output
+    '''
+    # Load blast output
+    output_data = []
+    with open(blast_output_file, 'r') as fh:
+        for line in fh:
+            fields = line.strip().split()
+            entry = {
+                'qseqid': fields[0],
+                'qlen': fields[1],
+                'sseqid': fields[2],
+                'length': fields[3],
+                'pident': fields[4],
+                'sstart': fields[5],
+                'send': fields[6],
+                'sframe': fields[7],
+                'qcovhsp': fields[8]
+            }
+            output_data.append(entry)
+    df = pd.DataFrame(output_data)
+    if not output_data:
+        LOG.info("No hit found for this blast query")
+        # Return empty dataframe with correct columns
+        return pd.DataFrame(
+            columns=[
+                'length', 'pident', 'qcovhsp',
+                'qlen', 'qseqid', 'send',
+                'sframe', 'sseqid', 'sstart'
+            ])
+    df['score'] = df['pident'].astype(float)*df['qcovhsp'].astype(float)/10000
+    return df
+
+def ectyper_dict_to_df(ectyper_dict_file):
+    # Read ectyper dict
+    with open(ectyper_dict_file) as fh:
+        ectyper_dict = json.load(fh)
+        temp_list = []
+        for antigen, alleles in ectyper_dict.items():
+            for name, allele in alleles.items():
+                new_entry = {
+                    'serotype': allele.get('allele'),
+                    'name': name,
+                    'gene': allele.get('gene'),
+                    'desc': allele.get('desc')
+                }
+                temp_list.append(new_entry)
+        df = pd.DataFrame(temp_list)
+        return df
+
+def store_df(src_df, dst_file):
+    """Append dataframe to a file if it exists, otherwise, make a new file
+    Args:
+        src_df(str): dataframe object to be stored
+        dst_file(str): dst_file to be modified/created
+    """
+    if os.path.isfile(dst_file):
+        with open(dst_file, 'a') as fh:
+            src_df.to_csv(fh, header=False)
+    else:
+        with open(dst_file, 'w') as fh:
+            src_df.to_csv(fh, header=True, index_label='genome')
+
+def report_result(csv_file):
+    '''Report the content of dataframe in log
+    
+    Args:
+        csv_file(str): location of the prediction file
+    '''
+    df = pd.read_csv(csv_file)
+    if df.empty:
+        LOG.info('No prediction was made becuase no alignment was found')
+        return
+    LOG.info('\n{0}'.format(df.to_string(index=False)))
+
+def add_non_predicted(all_genomes_list, predictions_file):
+    '''Add genomes that do not show up in blast result to prediction file
+    
+    Args:
+        all_genome_list(list):
+            list of genomes from user input
+        predictions_file(str):
+            location of the prediction file
+    Returns:
+        str: location of the prediction file
+    '''
+    df = pd.read_csv(predictions_file)
+    df = df.merge(pd.DataFrame(all_genomes_list, columns=['genome']), on='genome', how='outer')
+    df.fillna({'O_info':'No alignment found', 'H_info':'No alignment found'}, inplace=True)
+    df.fillna('-', inplace=True)
+    df.to_csv(predictions_file, index=False)
+    return predictions_file