6
|
1 #!/usr/bin/env python
|
|
2
|
|
3 import json
|
|
4 import logging
|
|
5 import os
|
|
6 from collections import defaultdict
|
|
7
|
|
8 import pandas as pd
|
|
9
|
|
# Serotype prediction for E. coli.

# Module-level logger, named after this module so that callers can
# configure or filter its output through the standard logging hierarchy.
LOG = logging.getLogger(__name__)
|
|
15
|
|
def predict_serotype(blast_output_file, ectyper_dict_file, predictions_file, detailed=False):
    """Predict the serotype of every genome present in a blast output file.

    The blast hits are joined with the allele dictionary, grouped per genome
    (the genome name is the second '|'-separated field of ``sseqid``), and a
    prediction is made for each group. Two files are written through
    ``store_df``: the merged raw hits (``<predictions_file stem>_raw<ext>``)
    and the predictions themselves.

    Args:
        blast_output_file(str):
            blastn output with outfmt:
                "6 qseqid qlen sseqid length pident sstart send sframe qcovhsp -word_size 11"
        ectyper_dict_file(str):
            mapping file used to associate allele ids with allele information
        predictions_file(str):
            csv file to store the result
        detailed(bool, optional):
            whether to add one column per known gene to the report

    Returns:
        str: predictions_file, as given
    """
    LOG.info("Predicting serotype from blast output")

    # The raw hits are stored next to the predictions, with a "_raw" suffix
    root, ext = os.path.splitext(predictions_file)
    raw_hits_file = root + '_raw' + ext

    hits_df = blast_output_to_df(blast_output_file)
    alleles_df = ectyper_dict_to_df(ectyper_dict_file)
    # Left join: every blast hit is kept, annotated with its allele if known
    hits_df = hits_df.merge(alleles_df, left_on='qseqid', right_on='name', how='left')
    hits_df['genome_name'] = hits_df['sseqid'].str.split('|').str[1]

    gene_pairs = {'wzx': 'wzy', 'wzy': 'wzx', 'wzm': 'wzt', 'wzt': 'wzm'}
    predictions_columns = ['O_prediction', 'O_info', 'H_prediction', 'H_info']
    if detailed:
        # One extra column per gene for the detailed report
        predictions_columns += [
            'wzx', 'wzy', 'wzm', 'wzt', 'fliC', 'fllA', 'flkA', 'flmA', 'flnA'
        ]

    predictions_dict = {
        genome_name: get_prediction(
            genome_hits, predictions_columns, gene_pairs, detailed)
        for genome_name, genome_hits in hits_df.groupby('genome_name')
    }

    # Genomes become rows, prediction fields become columns
    predictions_df = pd.DataFrame(predictions_dict).transpose()
    if predictions_df.empty:
        predictions_df = pd.DataFrame(columns=predictions_columns)
    predictions_df = predictions_df[predictions_columns]

    store_df(hits_df, raw_hits_file)
    store_df(predictions_df, predictions_file)
    LOG.info("Serotype prediction completed")
    return predictions_file
|
|
64
|
|
def get_prediction(per_genome_df, predictions_columns, gene_pairs, detailed, ):
    """Make the serotype prediction for a single genome based on blast output.

    Hits are visited from the highest to the lowest score, once for the O
    antigen and once for the H antigen. A hit on a gene listed in
    ``gene_pairs`` only yields a prediction when its partner gene was already
    seen with the same serotype; a hit on any other gene yields a prediction
    directly. When no prediction is reached but exactly one gene produced
    exactly one candidate serotype, that lone candidate is used.

    Args:
        per_genome_df(DataFrame):
            blastn outputs for the given genome, merged with allele info
        predictions_columns(list):
            columns to be filled by the prediction function
        gene_pairs(dict):
            dict of paired genes used for the paired logic
        detailed(bool):
            whether to flag every gene that has a hit
    Return:
        dict: prediction for the given genome, keyed by column name
    """
    useful_columns = [
        'gene', 'serotype', 'score', 'name', 'desc', 'pident', 'qcovhsp', 'qseqid', 'sseqid'
    ]
    # Keep only the best-scoring hit of each (gene, serotype) combination
    per_genome_df = per_genome_df.sort_values(['gene', 'serotype', 'score'], ascending=False)
    per_genome_df = per_genome_df[~per_genome_df.duplicated(['gene', 'serotype'])]
    predictors_df = per_genome_df[useful_columns].sort_values('score', ascending=False)

    predictions = {column: None for column in predictions_columns}
    for predicting_antigen in ['O', 'H']:
        prediction_key = predicting_antigen + '_prediction'
        info_key = predicting_antigen + '_info'
        # gene -> serotypes seen so far for the antigen being predicted
        genes_pool = defaultdict(list)
        for _, row in predictors_df.iterrows():
            gene = row['gene']
            # Hits without a matching allele entry have a missing gene
            if detailed and pd.notnull(gene):
                predictions[gene] = True
            if predictions.get(prediction_key):
                # Already predicted; keep looping only to flag genes above
                continue
            serotype = row['serotype']
            # Missing or empty serotypes cannot be attributed to an antigen
            if pd.isnull(serotype) or len(serotype) < 1:
                continue
            # Value comparison; identity of str objects is not guaranteed
            if serotype[0] != predicting_antigen:
                continue
            genes_pool[gene].append(serotype)
            if gene in gene_pairs:
                predictions[info_key] = 'Only unpaired alignments found'
                # Paired logic: the partner gene must agree on the serotype
                potential_pairs = genes_pool.get(gene_pairs[gene])
                if potential_pairs is None or serotype not in potential_pairs:
                    continue
            predictions[info_key] = 'Alignment found'
            predictions[prediction_key] = serotype
        # No paired or unpaired-gene prediction: fall back on a lone candidate
        if predictions.get(prediction_key) is None and len(genes_pool) == 1:
            serotypes = next(iter(genes_pool.values()))
            if len(serotypes) == 1:
                predictions[info_key] = 'Lone unpaired alignment found'
                predictions[prediction_key] = serotypes[0]
    return predictions
|
|
133
|
|
def blast_output_to_df(blast_output_file):
    '''Convert a raw blast output file to a DataFrame

    Each non-blank line is split on whitespace and its first nine fields are
    stored, as strings, under the column names of the blast output format.
    A ``score`` column (pident * qcovhsp / 10000) is added when there is at
    least one hit. Blank lines are skipped.

    Args:
        blast_output_file(str): location of blast output

    Returns:
        DataFrame:
            DataFrame that contains all the information from the blast
            output; empty (and without ``score``) when there is no hit

    Raises:
        IndexError: if a non-blank line has fewer than nine fields
    '''
    field_names = (
        'qseqid', 'qlen', 'sseqid', 'length', 'pident',
        'sstart', 'send', 'sframe', 'qcovhsp'
    )
    output_data = []
    with open(blast_output_file, 'r') as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                # Blank line (e.g. trailing newline at the end of the file)
                continue
            # Explicit indexing so that a truncated line still fails loudly
            output_data.append(
                {name: fields[position] for position, name in enumerate(field_names)}
            )
    if not output_data:
        LOG.info("No hit found for this blast query")
        # Return empty dataframe with correct columns
        return pd.DataFrame(
            columns=[
                'length', 'pident', 'qcovhsp',
                'qlen', 'qseqid', 'send',
                'sframe', 'sseqid', 'sstart'
            ])
    df = pd.DataFrame(output_data)
    # Both factors are percentages, so the score lies between 0 and 1
    df['score'] = df['pident'].astype(float)*df['qcovhsp'].astype(float)/10000
    return df
|
|
172
|
|
def ectyper_dict_to_df(ectyper_dict_file):
    '''Flatten the ectyper allele dictionary (JSON) into a DataFrame

    The JSON file is expected to hold two levels of mappings: a top-level
    key mapping to ``{allele name: allele info}``. Only the allele level is
    kept; the top-level keys are not stored.

    Args:
        ectyper_dict_file(str): location of the JSON allele dictionary

    Returns:
        DataFrame:
            one row per allele with columns ``serotype`` (the entry's
            ``allele`` value), ``name``, ``gene`` and ``desc``; these
            columns are present even when the dictionary is empty
    '''
    with open(ectyper_dict_file) as fh:
        ectyper_dict = json.load(fh)
    rows = [
        {
            'serotype': allele.get('allele'),
            'name': name,
            'gene': allele.get('gene'),
            'desc': allele.get('desc')
        }
        for alleles in ectyper_dict.values()
        for name, allele in alleles.items()
    ]
    # Explicit columns so that an empty dictionary still yields a frame
    # that can be merged on 'name'
    return pd.DataFrame(rows, columns=['serotype', 'name', 'gene', 'desc'])
|
|
189
|
|
def store_df(src_df, dst_file):
    """Append a dataframe to a csv file, creating the file if needed

    When the destination does not exist yet, or exists but is empty, the
    dataframe is written with a header row and its index labelled 'genome'.
    Otherwise the rows are appended without a header.

    Args:
        src_df(DataFrame): dataframe object to be stored
        dst_file(str): dst_file to be modified/created
    """
    # A zero-byte file (e.g. a pre-created temporary file) has no header yet;
    # appending to it without one would leave the csv unreadable by column name
    needs_header = not os.path.isfile(dst_file) or os.path.getsize(dst_file) == 0
    if needs_header:
        with open(dst_file, 'w') as fh:
            src_df.to_csv(fh, header=True, index_label='genome')
    else:
        with open(dst_file, 'a') as fh:
            src_df.to_csv(fh, header=False)
|
|
202
|
|
def report_result(csv_file):
    '''Write the content of a prediction csv file to the log

    Args:
        csv_file(str): location of the prediction file
    '''
    results = pd.read_csv(csv_file)
    if not results.empty:
        # Leading newline keeps the table aligned below the log prefix
        LOG.info('\n{0}'.format(results.to_string(index=False)))
        return
    LOG.info('No prediction was made becuase no alignment was found')
|
|
214
|
|
def add_non_predicted(all_genomes_list, predictions_file):
    '''Add genomes that do not show up in the blast result to the prediction file

    The prediction file is rewritten in place: genomes missing from it get a
    row whose ``O_info``/``H_info`` are 'No alignment found' and whose other
    fields are '-'. Any other missing value in the file is also set to '-'.

    Args:
        all_genomes_list(list):
            list of genomes from user input
        predictions_file(str):
            location of the prediction file
    Returns:
        str: location of the prediction file
    '''
    # Genome names are identifiers: read them as text so that an all-numeric
    # name is not parsed as an integer, which would make the merge fail
    predictions_df = pd.read_csv(predictions_file, dtype={'genome': str})
    genomes_df = pd.DataFrame(
        [str(name) for name in all_genomes_list], columns=['genome'])
    merged_df = predictions_df.merge(genomes_df, on='genome', how='outer')
    merged_df = merged_df.fillna(
        {'O_info': 'No alignment found', 'H_info': 'No alignment found'})
    merged_df = merged_df.fillna('-')
    merged_df.to_csv(predictions_file, index=False)
    return predictions_file
|