Mercurial > repos > jpetteng > ectyper
comparison ecoli_serotyping/predictionFunctions.py @ 6:fe3ceb5c4214 draft
Uploaded
author | jpetteng |
---|---|
date | Fri, 05 Jan 2018 15:43:14 -0500 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
5:a202cc394af8 | 6:fe3ceb5c4214 |
---|---|
#!/usr/bin/env python
"""Serotype prediction for E. coli"""

import json
import logging
import os
from collections import defaultdict

import pandas as pd

# Module-level logger shared by every function in this file.
LOG = logging.getLogger(__name__)
15 | |
def predict_serotype(blast_output_file, ectyper_dict_file, predictions_file, detailed=False):
    """Make serotype prediction for all genomes based on blast output

    The parsed blast hits are annotated with allele information, grouped by
    genome, and one prediction is made per genome. Two files are written via
    ``store_df``: the annotated hits (``<basename>_raw<ext>``) and the
    predictions themselves (``predictions_file``).

    Args:
        blast_output_file(str):
            blastn output with outfmt:
                "6 qseqid qlen sseqid length pident sstart send sframe qcovhsp -word_size 11"
        ectyper_dict_file(str):
            mapping file used to associate allele id to allele informations
        predictions_file(str):
            csv file to store result
        detailed(bool, optional):
            whether to generate detailed output or not

    Returns:
        predictions_file
    """
    basename, extension = os.path.splitext(predictions_file)
    parsed_output_file = ''.join([basename, '_raw', extension])

    LOG.info("Predicting serotype from blast output")
    output_df = blast_output_to_df(blast_output_file)
    ectyper_df = ectyper_dict_to_df(ectyper_dict_file)
    # Annotate every hit with its allele record; hits whose query id is not
    # in the dictionary are kept (left join) with missing annotation.
    output_df = output_df.merge(ectyper_df, left_on='qseqid', right_on='name', how='left')
    # The genome name is the second '|'-delimited field of the subject id.
    output_df['genome_name'] = output_df['sseqid'].str.split('|').str[1]

    # Initialize constants
    gene_pairs = {'wzx':'wzy', 'wzy':'wzx', 'wzm':'wzt', 'wzt':'wzm'}
    predictions_columns = ['O_prediction', 'O_info', 'H_prediction', 'H_info']
    gene_list = ['wzx', 'wzy', 'wzm', 'wzt', 'fliC', 'fllA', 'flkA', 'flmA', 'flnA']
    if detailed:
        # Add gene lists for detailed output report
        predictions_columns.extend(gene_list)

    # Make prediction for each genome based on blast output
    predictions_dict = {
        genome_name: get_prediction(
            per_genome_df, predictions_columns, gene_pairs, detailed)
        for genome_name, per_genome_df in output_df.groupby('genome_name')
    }

    predictions_df = pd.DataFrame(predictions_dict).transpose()
    if predictions_df.empty:
        # No genome had any hit: still emit a frame with the expected header.
        predictions_df = pd.DataFrame(columns=predictions_columns)
    predictions_df = predictions_df[predictions_columns]
    store_df(output_df, parsed_output_file)
    store_df(predictions_df, predictions_file)
    LOG.info("Serotype prediction completed")
    return predictions_file
64 | |
def get_prediction(per_genome_df, predictions_columns, gene_pairs, detailed):
    """Make serotype prediction for single genomes based on blast output

    Hits are reduced to the best-scoring row per (gene, serotype) and then
    visited from highest to lowest score, once for the O antigen and once for
    the H antigen. A hit on a paired gene only yields a prediction when its
    partner gene has already been seen with the same serotype; a hit on any
    other gene yields a prediction directly. When no prediction is made and
    exactly one gene with exactly one serotype was seen for the antigen, that
    lone serotype is reported instead.

    Args:
        per_genome_df(DataFrame):
            blastn outputs for the given genome
        predictions_columns(list):
            columns to be filled by prediction function; must contain the
            '<antigen>_prediction' entries for 'O' and 'H'
        gene_pairs(dict):
            dict of pair genes used for paired logic
        detailed(bool):
            whether to generate detailed output or not
    Return:
        Prediction dictionary for the given genome
    """
    # Extract potential predictors
    useful_columns = [
        'gene', 'serotype', 'score', 'name', 'desc', 'pident', 'qcovhsp', 'qseqid', 'sseqid'
    ]
    # Sorting before de-duplicating keeps the best-scoring row of each
    # (gene, serotype) combination.
    per_genome_df = per_genome_df.sort_values(['gene', 'serotype', 'score'], ascending=False)
    per_genome_df = per_genome_df[~per_genome_df.duplicated(['gene', 'serotype'])]
    predictors_df = per_genome_df[useful_columns].sort_values('score', ascending=False)

    predictions = {column: None for column in predictions_columns}
    for predicting_antigen in ['O', 'H']:
        prediction_key = predicting_antigen + '_prediction'
        info_key = predicting_antigen + '_info'
        # gene -> serotypes of this antigen seen so far (in score order)
        genes_pool = defaultdict(list)
        for _, row in predictors_df.iterrows():
            gene = row['gene']
            # Hits without allele annotation carry a missing gene; do not
            # record them as a detected gene.
            if detailed and pd.notnull(gene):
                predictions[gene] = True
            if predictions[prediction_key]:
                # Already predicted; keep looping only to flag genes above.
                continue
            serotype = row['serotype']
            # Guard before indexing: missing or empty serotypes cannot be
            # attributed to an antigen.
            if pd.isnull(serotype) or len(serotype) < 1:
                continue
            # Compare by value, not identity.
            if serotype[0] != predicting_antigen:
                continue
            genes_pool[gene].append(serotype)
            if gene in gene_pairs:
                predictions[info_key] = 'Only unpaired alignments found'
                # Pair gene logic: the partner gene must already have been
                # seen with the same serotype.
                potential_pairs = genes_pool.get(gene_pairs.get(gene))
                if potential_pairs is None or serotype not in potential_pairs:
                    continue
            predictions[info_key] = 'Alignment found'
            predictions[prediction_key] = serotype
        # No alignment found
        ## Make prediction based on non-paired gene
        ## if only one non-paired gene is available
        if predictions.get(prediction_key) is None:
            if len(genes_pool) == 1:
                serotypes = list(genes_pool.values())[0]
                if len(serotypes) == 1:
                    predictions[info_key] = 'Lone unpaired alignment found'
                    predictions[prediction_key] = serotypes[0]
    return predictions
133 | |
def blast_output_to_df(blast_output_file):
    '''Convert raw blast output file to DataFrame

    Each non-blank line is split on whitespace and its first nine fields are
    stored as strings. A ``score`` column is derived as
    ``pident * qcovhsp / 10000``.

    Args:
        blast_output_file(str): location of blast output

    Returns:
        DataFrame:
            DataFrame that contains all informations from blast output;
            empty (but with the expected columns) when the file has no hits

    Raises:
        IndexError: if a non-blank line has fewer than nine fields
    '''
    # Field order of the blastn outfmt used to produce the file.
    field_names = (
        'qseqid', 'qlen', 'sseqid', 'length', 'pident',
        'sstart', 'send', 'sframe', 'qcovhsp'
    )
    # Load blast output
    output_data = []
    with open(blast_output_file, 'r') as fh:
        for line in fh:
            fields = line.strip().split()
            if not fields:
                # Blank lines (e.g. a trailing newline) carry no hit.
                continue
            output_data.append(
                {name: fields[position] for position, name in enumerate(field_names)}
            )
    if not output_data:
        LOG.info("No hit found for this blast query")
        # Return empty dataframe with correct columns, including the derived
        # 'score' so the header matches the non-empty case.
        return pd.DataFrame(
            columns=[
                'length', 'pident', 'qcovhsp',
                'qlen', 'qseqid', 'send',
                'sframe', 'sseqid', 'sstart',
                'score'
            ])
    df = pd.DataFrame(output_data)
    df['score'] = df['pident'].astype(float)*df['qcovhsp'].astype(float)/10000
    return df
172 | |
def ectyper_dict_to_df(ectyper_dict_file):
    """Flatten the ectyper allele dictionary into a DataFrame

    The JSON file is read as a two-level mapping: a top-level key maps to a
    mapping of allele name to allele record. One row is produced per allele
    name, taking the record's 'allele', 'gene' and 'desc' entries (missing
    entries become None).

    Args:
        ectyper_dict_file(str): location of the JSON allele dictionary

    Returns:
        DataFrame:
            one row per allele with columns 'serotype', 'name', 'gene' and
            'desc'; empty with those columns when the dictionary has no
            alleles
    """
    columns = ['serotype', 'name', 'gene', 'desc']
    # Read ectyper dict
    with open(ectyper_dict_file) as fh:
        ectyper_dict = json.load(fh)
    temp_list = [
        {
            'serotype': allele.get('allele'),
            'name': name,
            'gene': allele.get('gene'),
            'desc': allele.get('desc')
        }
        for alleles in ectyper_dict.values()
        for name, allele in alleles.items()
    ]
    if not temp_list:
        # Keep the schema so a later merge on 'name' still works.
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(temp_list)
189 | |
def store_df(src_df, dst_file):
    """Append dataframe to a file if it exists, otherwise, make a new file

    A file that exists but is empty is treated like a new file, so the
    result always starts with a header row whose index column is labelled
    'genome'. Rows appended to a non-empty file are written without header.

    Args:
        src_df(DataFrame): dataframe object to be stored
        dst_file(str): dst_file to be modified/created
    """
    has_content = os.path.isfile(dst_file) and os.path.getsize(dst_file) > 0
    if has_content:
        with open(dst_file, 'a') as fh:
            src_df.to_csv(fh, header=False)
    else:
        with open(dst_file, 'w') as fh:
            src_df.to_csv(fh, header=True, index_label='genome')
202 | |
def report_result(csv_file):
    '''Report the content of dataframe in log

    The csv file is loaded and logged at INFO level as a table without the
    index column; when it holds no rows a short notice is logged instead.

    Args:
        csv_file(str): location of the prediction file
    '''
    df = pd.read_csv(csv_file)
    if df.empty:
        LOG.info('No prediction was made becuase no alignment was found')
        return
    # Lazy %-style argument: the table is only formatted into the message
    # when the record is actually emitted.
    LOG.info('\n%s', df.to_string(index=False))
214 | |
def add_non_predicted(all_genomes_list, predictions_file):
    '''Add genomes that do not show up in blast result to prediction file

    The prediction file is rewritten in place (without an index column).
    Genomes from ``all_genomes_list`` that are missing from the file get a
    row whose 'O_info' and 'H_info' are 'No alignment found'; every other
    missing value in the file becomes '-'.

    Args:
        all_genomes_list(list):
            list of genomes from user input
        predictions_file(str):
            location of the prediction file
    Returns:
        str: location of the prediction file
    '''
    # Read genome names as text so numeric-looking names still match the
    # names supplied by the caller.
    df = pd.read_csv(predictions_file, dtype={'genome': str})
    # A genome listed more than once must not duplicate its prediction row.
    all_genomes_df = pd.DataFrame(all_genomes_list, columns=['genome']).drop_duplicates()
    df = df.merge(all_genomes_df, on='genome', how='outer')
    df = df.fillna({'O_info':'No alignment found', 'H_info':'No alignment found'})
    df = df.fillna('-')
    df.to_csv(predictions_file, index=False)
    return predictions_file