Mercurial > repos > dcouvin > resfinder4
comparison resfinder/cge/standardize_results.py @ 0:55051a9bc58d draft default tip
Uploaded
| author | dcouvin |
|---|---|
| date | Mon, 10 Jan 2022 20:06:07 +0000 |
| parents | |
| children | |
comparison
equal
deleted
inserted
replaced
| -1:000000000000 | 0:55051a9bc58d |
|---|---|
| 1 #!/usr/bin/env python3 | |
| 2 import random | |
| 3 import string | |
| 4 | |
| 5 from .phenotype2genotype.feature import ResGene, ResMutation | |
| 6 from .phenotype2genotype.res_profile import PhenoDB | |
| 7 from .out.util.generator import Generator | |
| 8 | |
| 9 import json | |
| 10 | |
| 11 | |
class SeqVariationResult(dict):
    """Standardized record of a single sequence variation (mismatch).

    The instance itself is the result dict. ``res_collection`` is kept as an
    attribute because the unique key is derived from the variations already
    stored in ``res_collection["seq_variations"]``.

    Args:
        res_collection: Result collection; indexed with "seq_variations" and
            passed to ``DatabaseHandler.get_key``.
        mismatch: Indexable mismatch record. Positions read here:
            0 variation type code ("sub", "ins" or "del"),
            1 / 2 reference start / end position,
            4 mutation string,
            5 / 6 reference / variant codon,
            7 / 8 reference / variant amino acid (only when the record has
            more than 7 items). Index 3 is not used here.
        region_results: Gene results the variation belongs to. The first
            entry's "ref_id" names the region; every entry's "key" is listed
            under "genes". Must be non-empty (``region_results[0]`` is read).
        db_name: Database name used to look up the "ref_database" key.
    """

    def __init__(self, res_collection, mismatch, region_results, db_name):
        self.res_collection = res_collection
        self.load_var_type(mismatch[0])
        self["ref_start_pos"] = mismatch[1]
        self["ref_end_pos"] = mismatch[2]
        mut_string = mismatch[4]
        self["ref_codon"] = mismatch[5].lower()
        self["var_codon"] = mismatch[6].lower()

        has_aa_change = len(mismatch) > 7
        if has_aa_change:
            self["ref_aa"] = mismatch[7].lower()
            self["var_aa"] = mismatch[8].lower()

        first_region_id = region_results[0]["ref_id"]
        region_name = PhenoDB.if_promoter_rename(first_region_id)

        self["type"] = "seq_variation"
        # The reference id ends with the variant amino acid when one is
        # available, otherwise with the variant codon.
        variant = self["var_aa"] if has_aa_change else self["var_codon"]
        self["ref_id"] = "{}_{}_{}".format(
            region_name, self["ref_start_pos"], variant)
        self["key"] = self._get_unique_key()
        self["seq_var"] = mut_string

        ref_codon = self["ref_codon"]
        if len(ref_codon) == 3:
            self["codon_change"] = "{}>{}".format(ref_codon, self["var_codon"])

        self["ref_database"] = DatabaseHandler.get_key(res_collection,
                                                       db_name)
        self["genes"] = [region["key"] for region in region_results]

    def load_var_type(self, type):
        """Set the substitution / deletion / insertion flags from *type*.

        All three flags are first reset to False; the flag matching the code
        "sub", "ins" or "del" is then set to True. Any other code leaves all
        three flags False.
        """
        for flag in ("substitution", "deletion", "insertion"):
            self[flag] = False
        for code, flag in (("sub", "substitution"),
                           ("ins", "insertion"),
                           ("del", "deletion")):
            if type == code:
                self[flag] = True
                break

    def _get_unique_key(self, delimiter=";;"):
        """Return ``self["ref_id"]``, suffixed with a random string if needed.

        While the candidate is already a key of
        ``res_collection["seq_variations"]``, a new candidate is built as
        ``<ref_id><delimiter><random 4-letter string>``.
        """
        base_key = self["ref_id"]
        taken = self.res_collection["seq_variations"]
        candidate = base_key
        while candidate in taken:
            candidate = "{}{}{}".format(base_key, delimiter,
                                        GeneResult.randomString())
        return candidate
| 75 | |
| 76 | |
class GeneResult(dict):
    """Standardized record of one gene hit from ResFinder or PointFinder.

    The instance itself is the result dict, built from a single hit record
    ``res`` (BLAST- or KMA-formatted). Entries whose value is ``"NA"`` or
    ``None`` are removed at the end of construction.

    Args:
        res_collection: Result collection; indexed with "genes" (to keep keys
            unique) and passed to ``DatabaseHandler.get_key``.
        res: Hit record with at least "sbjct_header", "sbjct_start",
            "sbjct_end", "perc_ident", "HSP_length", "sbjct_length",
            "contig_name", "query_start", "query_end" and either "coverage"
            or "perc_coverage"; "depth" is optional.
        db_name: "ResFinder" or "PointFinder".

    Raises:
        ValueError: if *db_name* is neither "ResFinder" nor "PointFinder".
    """

    def __init__(self, res_collection, res, db_name):
        self.db_name = db_name
        self["type"] = "gene"

        self["ref_id"] = PhenoDB.if_promoter_rename(res["sbjct_header"])

        if db_name == "ResFinder":
            self["name"], self.variant, self["ref_acc"] = (
                GeneResult._split_sbjct_header(self["ref_id"]))
        elif db_name == "PointFinder":
            self["name"] = self["ref_id"]

        self["ref_start_pos"] = res["sbjct_start"]
        self["ref_end_pos"] = res["sbjct_end"]
        self["identity"] = res["perc_ident"]
        self["alignment_length"] = res["HSP_length"]
        # NOTE: "lenght" is misspelled, but this key is part of the emitted
        # result; presumably downstream consumers rely on it, so do not
        # rename it without checking the output schema.
        self["ref_gene_lenght"] = res["sbjct_length"]
        self["query_id"] = res["contig_name"]
        self["query_start_pos"] = res["query_start"]
        self["query_end_pos"] = res["query_end"]
        self["key"] = self._get_unique_gene_key(res_collection)

        # BLAST-formatted results carry "coverage", which is multiplied by
        # 100 here (presumably a fraction); KMA-formatted results carry
        # "perc_coverage", which is used as-is.
        coverage = res.get("coverage")
        if coverage is None:
            coverage = res["perc_coverage"]
        else:
            coverage = float(coverage) * 100
        self["coverage"] = coverage

        depth = res.get("depth")
        if depth is not None:
            self["depth"] = depth

        self["ref_database"] = DatabaseHandler.get_key(res_collection,
                                                       db_name)
        self.remove_NAs()

    @staticmethod
    def _split_sbjct_header(header):
        """Split *header* on "_" into ``(template, variant, accession)``.

        The accession is every part after the second, re-joined with "_".
        A header without "_" yields ``(header, None, None)``.
        """
        parts = header.split("_")
        template = parts[0]

        if len(parts) > 1:
            variant = parts[1]
            acc = "_".join(parts[2:])
        else:
            variant = None
            acc = None

        return (template, variant, acc)

    def remove_NAs(self):
        """Delete every entry whose value is ``"NA"`` or ``None``."""
        na_keys = [key for key, val in self.items()
                   if val == "NA" or val is None]
        for key in na_keys:
            del self[key]

    def _get_unique_gene_key(self, res_collection, delimiter=";;"):
        """Return a key for this gene that is safe to use in the collection.

        The base key is ``<name><delimiter><variant><delimiter><ref_acc>``
        for ResFinder and ``<name>`` for PointFinder. If the base key is
        already present in ``res_collection["genes"]``, it is kept only when
        the stored gene was found at the same query id/start/end (the caller
        then merges into it). Otherwise a random suffix is appended. An
        ``"NA"`` query id always gets a random suffix on collision.

        Raises:
            ValueError: if ``self.db_name`` is not a supported database.
        """
        if self.db_name == "ResFinder":
            gene_key = "{name}{deli}{var}{deli}{ref_acc}".format(
                name=self["name"], var=self.variant,
                ref_acc=self["ref_acc"], deli=delimiter)
        elif self.db_name == "PointFinder":
            gene_key = self["name"]
        else:
            raise ValueError(
                "Unsupported database name: {!r}".format(self.db_name))

        genes = res_collection["genes"]
        if gene_key in genes:
            previous = genes[gene_key]
            # .get(): the stored gene may have lost these fields in
            # remove_NAs(); a missing field counts as a different location.
            needs_new_key = (
                self["query_id"] == "NA"
                or self["query_id"] != previous.get("query_id")
                or self["query_start_pos"] != previous.get("query_start_pos")
                or self["query_end_pos"] != previous.get("query_end_pos"))
            if needs_new_key:
                gene_key = self.get_rnd_unique_gene_key(
                    gene_key, res_collection, gene_key, delimiter)

        return gene_key

    def get_rnd_unique_gene_key(self, gene_key, res_collection,
                                minimum_gene_key, delimiter):
        """Return a key not present in ``res_collection["genes"]``.

        Starting from *gene_key*, repeatedly builds
        ``<minimum_gene_key><delimiter><random 4-letter string>`` until the
        candidate is unused.
        """
        while gene_key in res_collection["genes"]:
            rnd_str = GeneResult.randomString()
            gene_key = "{key}{deli}{rnd}".format(
                key=minimum_gene_key, deli=delimiter, rnd=rnd_str)
        return gene_key

    @staticmethod
    def randomString(stringLength=4):
        """Return *stringLength* random lowercase ASCII letters.

        Uses ``random`` (not ``secrets``); only meant for de-duplicating
        result keys.
        """
        letters = string.ascii_lowercase
        return "".join(random.choice(letters) for _ in range(stringLength))
| 176 | |
| 177 | |
class PhenotypeResult(dict):
    """Standardized AMR phenotype result for a single antibiotic.

    The instance itself is the result dict. It starts as not resistant and
    collects the keys of the gene / sequence-variation results that explain
    the phenotype (under "genes" / "seq_variations").

    Args:
        antibiotic: Object exposing ``name`` and ``classes``.
    """

    # Database whose key is recorded in "ref_database" for each feature type.
    _DB_NAME_BY_FEATURE_TYPE = {
        "genes": "ResFinder",
        "seq_variations": "PointFinder",
    }

    def __init__(self, antibiotic):
        self["type"] = "phenotype"
        self["category"] = "amr"
        self["key"] = antibiotic.name
        self["amr_classes"] = antibiotic.classes
        self["resistance"] = antibiotic.name
        self["resistant"] = False

    def set_resistant(self, res):
        """Set the "resistant" flag to *res*."""
        self["resistant"] = res

    def add_feature(self, res_collection, isolate, feature):
        """Link *feature* to this phenotype, in both directions.

        All results in the collection whose "ref_id" matches the feature's
        reference id are linked. Usually this is one result, but several
        identical features found in a sample have different keys that share
        a ref id. Their keys are appended to this phenotype, and this
        phenotype's key is appended to each result's "phenotypes" list.
        When at least one result matched, "ref_database" is set to the key
        of the database responsible for that feature type.
        """
        ref_id, feat_type = PhenotypeResult.get_ref_id_and_type(feature,
                                                                isolate)
        features = res_collection[feat_type]
        feature_keys = PhenotypeResult.get_keys_matching_ref_id(ref_id,
                                                                features)

        # Phenotype -> feature links.
        self[feat_type] = self.get(feat_type, []) + feature_keys

        # Feature -> phenotype links.
        for feat_key in feature_keys:
            features[feat_key].setdefault("phenotypes", []).append(self["key"])

        if feature_keys:
            db_name = self._DB_NAME_BY_FEATURE_TYPE[feat_type]
            self["ref_database"] = DatabaseHandler.get_key(res_collection,
                                                           db_name)

    @staticmethod
    def get_ref_id_and_type(feature, isolate):
        """Return ``(ref_id, collection_name)`` for *feature*.

        ResGene features map to "genes" with the id looked up in
        ``isolate.resprofile.phenodb.id_to_idwithvar``; ResMutation features
        map to "seq_variations" with their own ``unique_id``. Any other
        feature yields ``(None, None)``.
        """
        feat_type = None
        ref_id = None
        if isinstance(feature, ResGene):
            feat_type = "genes"
            ref_id = isolate.resprofile.phenodb.id_to_idwithvar[
                feature.unique_id]
        elif isinstance(feature, ResMutation):
            feat_type = "seq_variations"
            ref_id = feature.unique_id
        return (ref_id, feat_type)

    @staticmethod
    def get_keys_matching_ref_id(ref_id, res_collection):
        """Return, in iteration order, the keys whose result has *ref_id*."""
        return [key for key, result in res_collection.items()
                if ref_id == result["ref_id"]]
| 238 | |
| 239 | |
class ResFinderResultHandler():
    """Convert ResFinder (acquired gene) results into standardized results."""

    @staticmethod
    def load_res_profile(res_collection, isolate):
        """Add one phenotype result per antibiotic in the phenotype database.

        Antibiotics present in ``isolate.resprofile.resistance`` are marked
        resistant and linked to their ResGene features. Every phenotype,
        resistant or not, is added under the "phenotypes" class.
        """
        profile = isolate.resprofile
        for class_antibiotics in profile.phenodb.antibiotics.values():
            for phenodb_ab in class_antibiotics:
                phenotype = PhenotypeResult(phenodb_ab)
                if phenodb_ab in profile.resistance:
                    phenotype.set_resistant(True)
                    ab_features = profile.resistance[phenodb_ab].features
                    res_genes = (feat for feat in ab_features.values()
                                 if isinstance(feat, ResGene))
                    for res_gene in res_genes:
                        phenotype.add_feature(res_collection, isolate,
                                              res_gene)
                res_collection.add_class(cl="phenotypes", **phenotype)

    @staticmethod
    def standardize_results(res_collection, res, ref_db_name):
        """Add every non-excluded hit in *res* as a "genes" result.

        *res* maps database names to ``{hit_id: hit}`` dicts (or the string
        "No hit found"), plus an "excluded" entry listing hit ids to skip.
        A hit whose key already exists in the collection modifies that
        entry; otherwise it is added.
        """
        for db_name, db in res.items():
            if db_name == "excluded" or db == "No hit found":
                continue

            for unique_id, hit_db in db.items():
                if unique_id in res["excluded"]:
                    continue
                gene_result = GeneResult(res_collection, hit_db, ref_db_name)
                already_stored = gene_result["key"] in res_collection["genes"]
                store = (res_collection.modify_class if already_stored
                         else res_collection.add_class)
                store(cl="genes", **gene_result)
| 278 | |
| 279 | |
class DatabaseHandler():
    """Helpers for the "databases" entries of a result collection."""

    @staticmethod
    def load_database_metadata(name, res_collection, db_dir):
        """Add a "databases" entry describing database *name*.

        Version and commit come from ``Generator.get_version_commit(db_dir)``;
        the entry's key is ``"<name>-<version>"``.
        """
        version, commit = Generator.get_version_commit(db_dir)
        res_collection.add_class(
            cl="databases",
            type="database",
            database_name=name,
            database_version=version,
            key="{}-{}".format(name, version),
            database_commit=commit,
        )

    @staticmethod
    def get_key(res_collection, name):
        """Return the key of the first database entry named *name*, or None."""
        databases = res_collection["databases"]
        return next((db_key for db_key, entry in databases.items()
                     if entry["database_name"] == name), None)
| 300 | |
| 301 | |
class PointFinderResultHandler():
    """Convert PointFinder (point mutation) results into standardized results."""

    @staticmethod
    def load_res_profile(res_collection, isolate):
        """Add one phenotype result per antibiotic in the phenotype database.

        Antibiotics present in ``isolate.resprofile.resistance`` are marked
        resistant and linked to their ResMutation features. Every phenotype,
        resistant or not, is added under the "phenotypes" class.
        """
        resprofile = isolate.resprofile
        for class_antibiotics in resprofile.phenodb.antibiotics.values():
            for phenodb_ab in class_antibiotics:
                phenotype = PhenotypeResult(phenodb_ab)
                if phenodb_ab in resprofile.resistance:
                    phenotype.set_resistant(True)
                    isolate_ab = resprofile.resistance[phenodb_ab]
                    for feature in isolate_ab.features.values():
                        if isinstance(feature, ResMutation):
                            phenotype.add_feature(res_collection, isolate,
                                                  feature)
                res_collection.add_class(cl="phenotypes", **phenotype)

    @staticmethod
    def standardize_results(res_collection, res, ref_db_name):
        """Add gene and sequence-variation results for each gene in *res*.

        *res* maps gene names to a result dict (or a status string), plus an
        "excluded" entry listing gene names / hit ids to skip. For each
        non-excluded hit, a "genes" result is added. Each entry of the
        gene's "mis_matches" then becomes a "seq_variations" result linked
        to those genes. Mismatches of a gene whose hits were all excluded
        are skipped, since there is no region to attribute them to.
        """
        for gene_name, db in res.items():
            if gene_name == "excluded" or db == "No hit found":
                continue

            # Workarounds for the current PointFinder output format.
            if gene_name in res["excluded"]:
                continue
            if (isinstance(db, str)
                    and db.startswith("Gene found with coverage")):
                continue

            gene_results = []
            db_hits = PointFinderResultHandler._get_hits(db)
            for unique_id, hit_db in db_hits.items():
                if unique_id in res["excluded"]:
                    continue
                gene_result = GeneResult(res_collection, hit_db, ref_db_name)
                res_collection.add_class(cl="genes", **gene_result)
                gene_results.append(gene_result)

            if not gene_results:
                # SeqVariationResult needs at least one region result.
                continue

            for mismatch in db["mis_matches"]:
                seq_var_result = SeqVariationResult(
                    res_collection, mismatch, gene_results, ref_db_name)
                res_collection.add_class(cl="seq_variations",
                                         **seq_var_result)

    @staticmethod
    def _get_hits(db):
        """Return the ``{hit_id: hit}`` mapping for one gene result.

        BLAST-formatted results carry a non-empty "hits" dict, which is
        returned as-is. Otherwise (KMA-formatted) *db* is itself the single
        hit, keyed by its "sbjct_header". A new dict is built for that case
        so *db* is never inserted into its own "hits" entry.
        """
        hits = db.get("hits")
        if hits:
            return hits
        return {db["sbjct_header"]: db}
