1 #!/usr/bin/env python
3 import json
4 import logging
5 import os
6 from collections import defaultdict
8 import pandas as pd
10 LOG = logging.getLogger(__name__)
12 """
13 Serotype prediction for E. coli
14 """
def predict_serotype(blast_output_file, ectyper_dict_file, predictions_file, detailed=False):
    """Make serotype prediction for all genomes based on blast output

    The merged per-hit table is stored (via ``store_df``) in
    ``<predictions_file basename>_raw<extension>`` and the per-genome
    predictions are stored in ``predictions_file``.

    Args:
        blast_output_file(str):
            blastn output with outfmt:
            "6 qseqid qlen sseqid length pident sstart send sframe qcovhsp -word_size 11"
        ectyper_dict_file(str):
            mapping file used to associate allele id to allele information
        predictions_file(str):
            csv file to store result
        detailed(bool, optional):
            whether to add one presence column per gene to the output

    Returns:
        str: predictions_file
    """
    basename, extension = os.path.splitext(predictions_file)
    parsed_output_file = ''.join([basename, '_raw', extension])

    LOG.info("Predicting serotype from blast output")
    output_df = blast_output_to_df(blast_output_file)
    ectyper_df = ectyper_dict_to_df(ectyper_dict_file)
    # Attach the allele annotation (serotype, gene, desc) to every hit by allele id
    output_df = output_df.merge(ectyper_df, left_on='qseqid', right_on='name', how='left')
    # The genome name is the second '|'-separated field of the subject id
    output_df['genome_name'] = output_df['sseqid'].str.split('|').str[1]
    unnamed_hits = int(output_df['genome_name'].isna().sum())
    if unnamed_hits:
        # groupby() drops NaN keys, so these hits would otherwise vanish silently
        LOG.warning(
            "%d blast hit(s) have no genome name in sseqid and are ignored", unnamed_hits)

    # Initialize constants
    gene_pairs = {'wzx': 'wzy', 'wzy': 'wzx', 'wzm': 'wzt', 'wzt': 'wzm'}
    predictions_columns = ['O_prediction', 'O_info', 'H_prediction', 'H_info']
    gene_list = ['wzx', 'wzy', 'wzm', 'wzt', 'fliC', 'fllA', 'flkA', 'flmA', 'flnA']
    if detailed:
        # Add gene lists for detailed output report
        predictions_columns.extend(gene_list)

    # Make prediction for each genome based on its own blast hits
    predictions_dict = {
        genome_name: get_prediction(per_genome_df, predictions_columns, gene_pairs, detailed)
        for genome_name, per_genome_df in output_df.groupby('genome_name')
    }
    predictions_df = pd.DataFrame(predictions_dict).transpose()
    if predictions_df.empty:
        predictions_df = pd.DataFrame(columns=predictions_columns)
    predictions_df = predictions_df[predictions_columns]
    store_df(output_df, parsed_output_file)
    store_df(predictions_df, predictions_file)
    LOG.info("Serotype prediction completed")
    return predictions_file
def get_prediction(per_genome_df, predictions_columns, gene_pairs, detailed):
    """Make serotype prediction for single genomes based on blast output

    Only the best-scoring hit of each (gene, serotype) combination is kept.
    The remaining hits are considered in descending score order, first for
    the O antigen and then for the H antigen (see ``_predict_antigen``).

    Args:
        per_genome_df(DataFrame):
            blastn outputs for the given genome, merged with the allele
            annotation ('gene', 'serotype', 'score', ... columns)
        predictions_columns(list):
            columns to be filled by prediction function
        gene_pairs(dict):
            dict of pair genes used for paired logic
        detailed(bool):
            whether to mark every aligned gene with True in the result

    Returns:
        dict: prediction for the given genome; every entry of
        predictions_columns is a key (None when nothing was predicted)
    """
    useful_columns = [
        'gene', 'serotype', 'score', 'name', 'desc', 'pident', 'qcovhsp', 'qseqid', 'sseqid'
    ]
    # Keep only the best-scoring hit for each (gene, serotype) combination
    per_genome_df = per_genome_df.sort_values(['gene', 'serotype', 'score'], ascending=False)
    per_genome_df = per_genome_df.drop_duplicates(['gene', 'serotype'])
    # Stable sort: hits with equal scores keep the (gene, serotype) order above,
    # so tie-breaking between alleles is deterministic
    predictors_df = per_genome_df[useful_columns].sort_values(
        'score', ascending=False, kind='mergesort')

    predictions = dict.fromkeys(predictions_columns)
    if detailed:
        for gene in predictors_df['gene']:
            predictions[gene] = True
    for antigen in ('O', 'H'):
        prediction, info = _predict_antigen(predictors_df, antigen, gene_pairs)
        predictions[antigen + '_prediction'] = prediction
        predictions[antigen + '_info'] = info
    return predictions


def _predict_antigen(predictors_df, antigen, gene_pairs):
    """Predict one antigen ('O' or 'H') from hits sorted by descending score.

    A hit on an unpaired gene is accepted immediately. A hit on a paired gene
    is only accepted once its partner gene has a hit with the same serotype.
    If nothing is accepted but exactly one gene with exactly one serotype was
    seen, that lone hit is used.

    Args:
        predictors_df(DataFrame): hits with 'gene' and 'serotype' columns
        antigen(str): antigen letter that serotypes must start with
        gene_pairs(dict): gene -> partner gene used for paired logic

    Returns:
        tuple: (prediction, info); both None when no hit of this antigen exists
    """
    genes_pool = defaultdict(list)  # gene -> serotypes seen so far, best score first
    info = None
    for gene, serotype in zip(predictors_df['gene'], predictors_df['serotype']):
        # Skip hits of the other antigen and hits without a usable annotation
        # (e.g. NaN serotype when an allele id is missing from the ectyper dict)
        if not isinstance(serotype, str) or not serotype.startswith(antigen):
            continue
        genes_pool[gene].append(serotype)
        if gene in gene_pairs:
            info = 'Only unpaired alignments found'
            if serotype not in genes_pool.get(gene_pairs[gene], ()):
                continue
        return serotype, 'Alignment found'
    # No accepted alignment: fall back to a lone hit on a single gene
    if len(genes_pool) == 1:
        (serotypes,) = genes_pool.values()
        if len(serotypes) == 1:
            return serotypes[0], 'Lone unpaired alignment found'
    return None, info
def blast_output_to_df(blast_output_file):
    '''Convert raw blast output file to DataFrame

    Args:
        blast_output_file(str): location of blast output

    Returns:
        DataFrame:
            one row per blast hit with the nine outfmt fields (kept as
            strings) plus a 'score' column (pident * qcovhsp / 10000).
            The columns are identical, in the same order, whether or not
            any hit exists, so results stored by appending stay aligned.
    '''
    columns = [
        'qseqid', 'qlen', 'sseqid', 'length', 'pident',
        'sstart', 'send', 'sframe', 'qcovhsp'
    ]
    output_data = []
    with open(blast_output_file, 'r') as fh:
        for line in fh:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing empty line)
                continue
            output_data.append({column: fields[i] for i, column in enumerate(columns)})
    if not output_data:
        LOG.info("No hit found for this blast query")
        # Return empty dataframe with the same columns as the non-empty case
        return pd.DataFrame(columns=columns + ['score'])
    df = pd.DataFrame(output_data, columns=columns)
    df['score'] = df['pident'].astype(float) * df['qcovhsp'].astype(float) / 10000
    return df
def ectyper_dict_to_df(ectyper_dict_file):
    """Load the ectyper allele dictionary into a DataFrame

    Args:
        ectyper_dict_file(str):
            JSON file shaped like
            ``{antigen: {allele_id: {"allele": ..., "gene": ..., "desc": ...}}}``

    Returns:
        DataFrame:
            one row per allele with columns 'serotype', 'name', 'gene' and
            'desc' (missing fields become None). The columns exist even when
            the dictionary holds no allele, so merging on 'name' still works.
    """
    # JSON text is UTF-8; do not depend on the platform's default encoding
    with open(ectyper_dict_file, encoding='utf-8') as fh:
        ectyper_dict = json.load(fh)
    rows = [
        {
            'serotype': allele.get('allele'),
            'name': name,
            'gene': allele.get('gene'),
            'desc': allele.get('desc'),
        }
        for alleles in ectyper_dict.values()
        for name, allele in alleles.items()
    ]
    return pd.DataFrame(rows, columns=['serotype', 'name', 'gene', 'desc'])
def store_df(src_df, dst_file):
    """Write a DataFrame as CSV, appending when the destination already exists.

    A newly created file gets a header row whose index column is labelled
    'genome'. Rows appended to an existing file are written without a header,
    so the file keeps its single header line.

    Args:
        src_df(DataFrame): dataframe object to be stored
        dst_file(str): path of the CSV file to be appended to or created
    """
    file_exists = os.path.isfile(dst_file)
    if file_exists:
        csv_kwargs = {'header': False}
    else:
        csv_kwargs = {'header': True, 'index_label': 'genome'}
    with open(dst_file, 'a' if file_exists else 'w') as out_handle:
        src_df.to_csv(out_handle, **csv_kwargs)
def report_result(csv_file):
    '''Report the content of the prediction file in the log

    Logs a notice instead of a table when the file holds no rows.

    Args:
        csv_file(str): location of the prediction file
    '''
    df = pd.read_csv(csv_file)
    if df.empty:
        LOG.info('No prediction was made because no alignment was found')
        return
    LOG.info('\n%s', df.to_string(index=False))
def add_non_predicted(all_genomes_list, predictions_file):
    '''Add genomes that do not show up in blast result to prediction file

    Genomes absent from the prediction file are added with 'No alignment
    found' as their O_info/H_info. Every other empty cell is then filled
    with '-'. The prediction file is rewritten in place.

    Args:
        all_genomes_list(list):
            list of genomes from user input
        predictions_file(str):
            location of the prediction file (must have a 'genome' column)

    Returns:
        str: location of the prediction file
    '''
    # Compare genome names as text on both sides: otherwise names that look
    # numeric (e.g. "1", "007") are parsed as integers, fail to merge with the
    # user-supplied names and lose their leading zeros.
    df = pd.read_csv(predictions_file, dtype={'genome': str})
    all_genomes_df = pd.DataFrame(
        [str(genome) for genome in all_genomes_list], columns=['genome'], dtype=object)
    df = df.merge(all_genomes_df, on='genome', how='outer')
    df = df.fillna({'O_info': 'No alignment found', 'H_info': 'No alignment found'})
    df = df.fillna('-')
    df.to_csv(predictions_file, index=False)
    return predictions_file