Mercurial > repos > adrian.diaz > b2btools_single_sequence
comparison script.py @ 2:a9db23ac113f draft default tip
Uploaded new version formatted.
| author | adrian.diaz |
|---|---|
| date | Tue, 02 Aug 2022 09:44:33 +0000 |
| parents | |
| children |
comparison
equal
deleted
inserted
replaced
| 1:891ccfd22633 | 2:a9db23ac113f |
|---|---|
| 1 import optparse | |
| 2 import os.path | |
| 3 import unicodedata | |
| 4 import re | |
| 5 import numpy as np | |
| 6 import pandas as pd | |
| 7 from b2bTools import SingleSeq | |
| 8 import matplotlib.pyplot as plt | |
| 9 | |
| 10 | |
def slugify(value, allow_unicode=False):
    """
    Turn ``value`` into a lowercase slug usable as part of a file name.

    Adapted from
    https://github.com/django/django/blob/master/django/utils/text.py

    ``value`` is converted with ``str()`` first, so any object is accepted.
    Unless ``allow_unicode`` is true, the text is reduced to ASCII:
    characters are NFKD-decomposed and whatever cannot be encoded as ASCII
    is dropped. Characters that are not alphanumerics, underscores,
    whitespace or hyphens are then removed, runs of whitespace and hyphens
    collapse to a single hyphen, and leading/trailing hyphens and
    underscores are stripped.

    The result can be an empty string when nothing in ``value`` survives
    the filtering.
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize('NFKC', value)
    else:
        value = (
            unicodedata.normalize('NFKD', value)
            .encode('ascii', 'ignore')
            .decode('ascii')
        )
    value = re.sub(r'[^\w\s-]', '', value.lower())
    return re.sub(r'[-\s]+', '-', value).strip('-_')
| 26 | |
| 27 | |
def check_min_max(predicted_values, former_min, former_max):
    """
    Widen a ``(min, max)`` axis range so it encloses ``predicted_values``.

    A margin of 0.1 is kept on both sides: the upper bound is raised to
    ``max + 0.1`` and the lower bound is lowered to ``min - 0.1`` whenever
    the current bound is tighter than that. The range is never narrowed.

    NaN and infinite entries are ignored. If ``predicted_values`` is empty
    or holds no finite value, the range is returned unchanged.

    Returns the tuple ``(new_min, new_max)``.
    """
    margin = 0.1
    values = np.asarray(predicted_values, dtype=float)
    # Builtin max()/min() give order-dependent results when NaN is present,
    # so the extremes are taken over the finite entries only.
    finite_values = values[np.isfinite(values)]
    if finite_values.size == 0:
        return former_min, former_max

    highest = float(finite_values.max()) + margin
    lowest = float(finite_values.min()) - margin
    if highest > former_max:
        former_max = highest
    if lowest < former_min:
        former_min = lowest
    return former_min, former_max
| 40 | |
| 41 | |
# Starting y-axis limits per predictor; check_min_max() may widen them.
_DEFAULT_Y_RANGE = (-0.2, 1.2)
_Y_RANGES = {
    'backbone': (-0.2, 1.2),
    'sidechain': (-0.2, 1.2),
    'ppII': (-0.2, 1.2),
    'earlyFolding': (-0.2, 1.2),
    'disoMine': (-0.2, 1.2),
    'agmata': (-0.2, 1.2),
    'helix': (-1., 1.),
    'sheet': (-1., 1.),
    'coil': (-1., 1.),
}

# Highlighted horizontal bands per predictor as (label, lower, upper),
# listed from the lowest band to the highest.
_HIGHLIGHT_REGIONS = {
    'backbone': (
        ('flexible', -1.0, 0.69),
        ('context-dependent', 0.69, 0.8),
        ('rigid', 0.8, 1.),
        ('membrane spanning', 1., 1.5),
    ),
    'earlyFolding': (
        ('late folds', -1., 0.169),
        ('early folds', 0.169, 2.),
    ),
    'disoMine': (
        ('ordered', -1., 0.5),
        ('disordered', 0.5, 2.),
    ),
}

# Band colours, assigned in the same low-to-high order as the regions.
_REGION_COLORS = ('yellow', 'orange', 'pink', 'red')


def plot_prediction(prediction_name, highlighting_regions,
                    pred_vals, seq_name, output_dir=None):
    """
    Draw one predictor as a line plot and save it as a PNG file.

    Parameters:
        prediction_name: name of the predictor being plotted. It is used
            in the title and the file name and selects the starting y-axis
            range; names without a known range start from (-0.2, 1.2).
        highlighting_regions: when true, predictors that have highlighted
            bands defined ('backbone', 'earlyFolding', 'disoMine') get
            coloured horizontal bands plus a second legend naming them.
        pred_vals: if ``seq_name`` is 'all', a mapping of sequence name to
            its per-residue values (one line per sequence); otherwise the
            per-residue values of a single sequence.
        seq_name: label of the plotted sequence, or the sentinel 'all'.
        output_dir: directory that receives the image. When omitted, the
            module-level ``options.plot_output`` is used, which only exists
            when the file runs as a script.

    The image is written to ``<output_dir>/<slug of seq_name>_<prediction_name>.png``.
    """
    if output_dir is None:
        # Resolved before any drawing so a missing global fails early.
        output_dir = options.plot_output

    min_value, max_value = _Y_RANGES.get(prediction_name, _DEFAULT_Y_RANGE)
    fig, ax = plt.subplots(1, 1, figsize=(10, 5))
    try:
        ax.set_title(prediction_name + ' ' + 'prediction')

        if seq_name == 'all':
            max_len = 0
            for seq, predictions in pred_vals.items():
                min_value, max_value = check_min_max(
                    predictions, min_value, max_value)
                ax.plot(range(len(predictions)), predictions, label=seq)
                max_len = max(max_len, len(predictions))
            ax.set_xlim([0, max_len - 1])
        else:
            predictions = pred_vals
            min_value, max_value = check_min_max(
                predictions, min_value, max_value)
            ax.plot(range(len(predictions)), predictions, label=seq_name)
            ax.set_xlim([0, len(predictions) - 1])

        # Legend for the plotted lines. It is created before the bands are
        # drawn so it lists only the sequences, and it is registered as a
        # separate artist so a later ax.legend() call does not replace it.
        lines_legend = ax.legend(
            bbox_to_anchor=(1.04, 1),
            loc="upper left",
            fancybox=True,
            shadow=True)
        ax.add_artist(lines_legend)

        regions = ()
        if highlighting_regions:
            regions = _HIGHLIGHT_REGIONS.get(prediction_name, ())
        if regions:
            bands = []
            for color, (label, lower, upper) in zip(_REGION_COLORS, regions):
                band = ax.axhspan(
                    lower, upper, alpha=0.3, color=color, label=label)
                bands.append((band, label))
            # The legend lists the bands from the highest to the lowest.
            bands.reverse()
            region_legend = ax.legend(
                [band for band, _ in bands],
                [label for _, label in bands],
                fancybox=True,
                shadow=True,
                loc='lower left',
                bbox_to_anchor=(1.04, 0))
            ax.add_artist(region_legend)

        ax.set_ylim([min_value, max_value])
        ax.set_xlabel('residue index')
        ax.set_ylabel('prediction values')
        ax.grid(axis='y')
        fig.savefig(
            os.path.join(
                output_dir,
                "{0}_{1}.png".format(slugify(seq_name), prediction_name)),
            bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
| 143 | |
| 144 | |
def df_dict_to_dict_of_values(df_dict, predictor):
    """
    Load one column from each tab-separated file listed in ``df_dict``.

    ``df_dict`` maps a sequence name to the path of a TSV file. Each file
    is read with pandas and its ``predictor`` column is kept. The result
    maps every sequence name to that column, in the iteration order of
    ``df_dict``.
    """
    return {
        seq_name: pd.read_csv(tsv_path, sep='\t')[predictor]
        for seq_name, tsv_path in df_dict.items()
    }
| 151 | |
| 152 | |
def main(options):
    """
    Run the selected predictors on a FASTA file and write the results.

    ``options`` is the parsed command-line namespace. The attributes read
    here are ``input_fasta``, the predictor switches ``dynamine``,
    ``disomine``, ``efoldmine`` and ``agmata``, ``json_output``,
    ``output``, ``plot`` and ``plot_all``.

    Outputs:
        * the JSON returned by ``get_all_predictions_json('all')``,
          written to ``options.json_output``;
        * one ``<slug of sequence name>.tsv`` per sequence inside
          ``options.output``, values rounded to two decimals, with the
          'seq' column renamed to 'residue';
        * if ``options.plot`` is set, one plot per sequence and predictor;
        * if ``options.plot_all`` is set, one plot per predictor showing
          every sequence, built from the TSV files written above.

    If no predictions are returned, only the JSON file is written.
    """
    single_seq = SingleSeq(options.input_fasta)
    b2b_tools = [
        tool
        for tool in ('dynamine', 'disomine', 'efoldmine', 'agmata')
        if getattr(options, tool)
    ]

    single_seq.predict(b2b_tools)
    predictions = single_seq.get_all_predictions()
    results_json = single_seq.get_all_predictions_json('all')
    with open(options.json_output, 'w', encoding='utf-8') as f:
        f.write(results_json)

    if not predictions:
        # Nothing to tabulate or plot; avoids StopIteration just below.
        return

    # The first sequence's keys define the columns used for every sequence.
    first_sequence_key = next(iter(predictions))
    prediction_keys = list(predictions[first_sequence_key].keys())
    plotted_predictors = [key for key in prediction_keys if key != 'seq']

    # NOTE(review): the script defines a --highlight flag, but highlighting
    # is always requested below. Confirm the intended default before
    # wiring the flag in, since that would change the generated plots.
    df_dictionary = {}
    for sequence_key, sequence_predictions in predictions.items():
        residues = sequence_predictions['seq']
        residues_count = len(residues)
        sequence_df = pd.DataFrame(
            columns=prediction_keys,
            index=range(residues_count))
        sequence_df.index.name = 'residue_index'
        for predictor in prediction_keys:
            sequence_df[predictor] = sequence_predictions[predictor]
        sequence_df = sequence_df.rename(columns={"seq": "residue"})
        sequence_df = sequence_df.round(decimals=2)
        filename = os.path.join(
            options.output, f'{slugify(sequence_key)}.tsv')
        df_dictionary[sequence_key] = filename
        sequence_df.to_csv(filename, sep="\t")
        # Plot each individual plot (compatible with plot all)
        if options.plot:
            for predictor in plotted_predictors:
                plot_prediction(prediction_name=predictor,
                                highlighting_regions=True,
                                pred_vals=sequence_predictions[predictor],
                                seq_name=sequence_key)

    # Plot all together (compatible with plot individual)
    if options.plot_all:
        for predictor in plotted_predictors:
            results_dictionary = df_dict_to_dict_of_values(
                df_dict=df_dictionary, predictor=predictor)
            plot_prediction(prediction_name=predictor,
                            highlighting_regions=True,
                            pred_vals=results_dictionary,
                            seq_name='all')
| 205 | |
| 206 | |
if __name__ == "__main__":
    # Command-line entry point. `options` is intentionally left as a
    # module-level name because plot_prediction() reads
    # options.plot_output from it.
    parser = optparse.OptionParser()
    parser.add_option("--dynamine", action="store_true", default=False)
    parser.add_option("--disomine", action="store_true", default=False)
    parser.add_option("--efoldmine", action="store_true", default=False)
    parser.add_option("--agmata", action="store_true", default=False)
    # Path options default to None rather than False: False is the integer
    # 0, which open() would treat as a file descriptor instead of failing.
    parser.add_option("--file", dest="input_fasta", default=None)
    parser.add_option("--output", dest="output", default=None)
    parser.add_option("--plot-output", dest="plot_output", default=None)

    parser.add_option("--json", dest="json_output", default=None)
    parser.add_option("--plot", action="store_true", default=False)
    parser.add_option("--plot_all", action="store_true", default=False)
    parser.add_option("--highlight", action="store_true", default=False)
    options, _args = parser.parse_args()

    # main() always reads the FASTA file and writes the JSON and TSV
    # outputs; the plot directory is needed only when a plot is requested.
    missing = []
    if not options.input_fasta:
        missing.append("--file")
    if not options.output:
        missing.append("--output")
    if not options.json_output:
        missing.append("--json")
    if (options.plot or options.plot_all) and not options.plot_output:
        missing.append("--plot-output")
    if missing:
        # parser.error() prints the usage message and exits with status 2.
        parser.error("missing required option(s): " + ", ".join(missing))

    main(options)
