Mercurial > repos > adrian.diaz > b2btools_single_sequence
comparison script.py @ 2:a9db23ac113f draft default tip
Uploaded new version formatted.
author | adrian.diaz |
---|---|
date | Tue, 02 Aug 2022 09:44:33 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
1:891ccfd22633 | 2:a9db23ac113f |
---|---|
1 import optparse | |
2 import os.path | |
3 import unicodedata | |
4 import re | |
5 import numpy as np | |
6 import pandas as pd | |
7 from b2bTools import SingleSeq | |
8 import matplotlib.pyplot as plt | |
9 | |
10 | |
def slugify(value):
    """Turn an arbitrary value into a lowercase, filename-safe ASCII slug.

    Adapted from Django's ``django.utils.text.slugify``. Unlike the Django
    original there is no ``allow_unicode`` switch: the text is always
    NFKD-normalised and folded to ASCII, and characters without an ASCII
    decomposition are dropped.

    Processing order: stringify, fold to ASCII, lowercase, delete every
    character that is not a word character, whitespace or hyphen, collapse
    runs of hyphens/whitespace into a single hyphen, and finally trim
    leading/trailing hyphens and underscores.

    >>> slugify(">sp|P12345| Some Protein")
    'spp12345-some-protein'
    """
    ascii_text = (
        unicodedata.normalize('NFKD', str(value))
        .encode('ascii', 'ignore')
        .decode('ascii')
    )
    kept = re.sub(r'[^\w\s-]', '', ascii_text.lower())
    dashed = re.sub(r'[-\s]+', '-', kept)
    return dashed.strip('-_')
26 | |
27 | |
def check_min_max(predicted_values, former_min, former_max, margin=0.1):
    """Widen a y-axis range so that it encloses a series of values.

    The range ``[former_min, former_max]`` is only ever enlarged, never
    shrunk. If the largest finite value plus ``margin`` exceeds
    ``former_max``, it becomes the new upper bound. Symmetrically, the
    smallest finite value minus ``margin`` replaces ``former_min`` when it
    is lower.

    Missing values (``None`` / NaN) and infinities are ignored. Earlier
    versions used the builtin ``max``/``min``, whose result depends on
    element order when NaN is present. A leading NaN was returned as the
    extreme and the bound was silently never widened. Those versions also
    raised ``TypeError`` for ``None`` entries and ``ValueError`` for empty
    input.

    Args:
        predicted_values: iterable of numbers (list, numpy array or pandas
            Series); ``None`` entries are treated as missing.
        former_min: current lower bound of the range.
        former_max: current upper bound of the range.
        margin: padding placed beyond the extreme values (default 0.1).

    Returns:
        ``(min_value, max_value)``, the possibly widened range. When there
        is no finite value at all, the original bounds are returned.
    """
    # dtype=float turns None into NaN so it can be filtered out below.
    values = np.asarray(list(predicted_values), dtype=float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return former_min, former_max

    upper = float(finite.max()) + margin
    lower = float(finite.min()) - margin
    if upper > former_max:
        former_max = upper
    if lower < former_min:
        former_min = lower
    return former_min, former_max
40 | |
41 | |
# Threshold bands (lower, upper) that can be shaded behind the curves.
_REGION_THRESHOLDS = {
    'backbone': {'membrane spanning': [1., 1.5],
                 'rigid': [0.8, 1.],
                 'context-dependent': [0.69, 0.8],
                 'flexible': [-1.0, 0.69]},
    'earlyFolding': {'early folds': [0.169, 2.],
                     'late folds': [-1., 0.169]},
    'disoMine': {'ordered': [-1., 0.5],
                 'disordered': [0.5, 2.]},
}
# Bands listed bottom-to-top; the band at position i is painted
# _REGION_COLORS[i].
_ORDERED_REGIONS = {
    'backbone': ['flexible', 'context-dependent', 'rigid',
                 'membrane spanning'],
    'earlyFolding': ['late folds', 'early folds'],
    'disoMine': ['ordered', 'disordered'],
}
_REGION_COLORS = ['yellow', 'orange', 'pink', 'red']
# Initial y-axis range per predictor; check_min_max widens it to fit data.
_Y_RANGES = {
    'backbone': [-0.2, 1.2],
    'sidechain': [-0.2, 1.2],
    'ppII': [-0.2, 1.2],
    'earlyFolding': [-0.2, 1.2],
    'disoMine': [-0.2, 1.2],
    'agmata': [-0.2, 1.2],
    'helix': [-1., 1.],
    'sheet': [-1., 1.],
    'coil': [-1., 1.],
}
# Used for predictors missing from _Y_RANGES (previously a KeyError).
_DEFAULT_Y_RANGE = (-0.2, 1.2)


def _draw_region_bands(ax, prediction_name, sequence_legend):
    """Shade the threshold bands of ``prediction_name`` and add their legend.

    ``ax.legend`` replaces the axes' current legend. The already-built
    ``sequence_legend`` is therefore re-attached as a plain artist first,
    following the matplotlib "multiple legends" idiom.
    """
    region_names = _ORDERED_REGIONS[prediction_name]
    thresholds = _REGION_THRESHOLDS[prediction_name]
    # Keep the span handles directly instead of looking them up by label,
    # which could collide with a sequence that shares a region's name.
    handles = [
        ax.axhspan(*thresholds[name], alpha=0.3, color=color, label=name)
        for name, color in zip(region_names, _REGION_COLORS)
    ]
    ax.add_artist(sequence_legend)
    # List the bands top-down so the legend reads in the same order as
    # the plot.
    ax.legend(handles[::-1],
              region_names[::-1],
              fancybox=True,
              shadow=True,
              loc='lower left',
              bbox_to_anchor=(1.04, 0))


def plot_prediction(prediction_name, highlighting_regions,
                    pred_vals, seq_name, output_dir=None):
    """Plot one predictor's per-residue values and save the figure as PNG.

    Args:
        prediction_name: predictor key (e.g. ``'backbone'``). It selects
            the initial y range, the optional shaded bands and the file
            name suffix.
        highlighting_regions: when true, and the predictor has known
            thresholds, shade its bands and add a second legend for them.
        pred_vals: values for a single sequence or, when ``seq_name`` is
            ``'all'``, a mapping of sequence name -> values (one curve per
            entry).
        seq_name: curve label and (slugified) file name prefix. The value
            ``'all'`` switches to multi-sequence mode.
        output_dir: directory receiving
            ``<slugify(seq_name)>_<prediction_name>.png``. When omitted,
            falls back to ``options.plot_output`` from the module-level
            ``options`` parsed in ``__main__`` (the historical behaviour).
    """
    if output_dir is None:
        output_dir = options.plot_output

    fig, ax = plt.subplots(1, 1)
    try:
        fig.set_figwidth(10)
        fig.set_figheight(5)
        ax.set_title(prediction_name + ' ' + 'prediction')
        min_value, max_value = _Y_RANGES.get(prediction_name,
                                             _DEFAULT_Y_RANGE)

        # Normalise both modes to a {label: values} mapping.
        curves = pred_vals if seq_name == 'all' else {seq_name: pred_vals}
        max_len = 0
        for label, predictions in curves.items():
            min_value, max_value = check_min_max(
                predictions, min_value, max_value)
            ax.plot(range(len(predictions)), predictions, label=label)
            max_len = max(max_len, len(predictions))
        # Identical limits (a one-residue sequence) make the axis singular.
        ax.set_xlim([0, max(max_len - 1, 1)])

        sequence_legend = ax.legend(bbox_to_anchor=(1.04, 1),
                                    loc="upper left",
                                    fancybox=True,
                                    shadow=True)
        if highlighting_regions and prediction_name in _ORDERED_REGIONS:
            _draw_region_bands(ax, prediction_name, sequence_legend)

        ax.set_ylim([min_value, max_value])
        ax.set_xlabel('residue index')
        ax.set_ylabel('prediction values')
        ax.grid(axis='y')
        file_name = "{0}_{1}.png".format(slugify(seq_name), prediction_name)
        fig.savefig(os.path.join(output_dir, file_name), bbox_inches="tight")
    finally:
        # Release the figure even when plotting or saving fails, so a long
        # batch of plots does not accumulate open figures.
        plt.close(fig)
143 | |
144 | |
def df_dict_to_dict_of_values(df_dict, predictor):
    """Load one predictor column from each sequence's tab-separated file.

    Args:
        df_dict: mapping of sequence name -> path of the TSV file written
            for that sequence.
        predictor: name of the column to extract (e.g. ``'backbone'``).

    Returns:
        dict mapping each sequence name to the ``pandas.Series`` for the
        requested column, in the same key order as ``df_dict``.
    """
    return {
        sequence_name: pd.read_csv(tsv_path, sep='\t')[predictor]
        for sequence_name, tsv_path in df_dict.items()
    }
151 | |
152 | |
def _selected_tools(options):
    """Return the b2bTools predictor names whose CLI flags are set.

    The order (dynamine, disomine, efoldmine, agmata) matches the order in
    which the flags are checked below.
    """
    flags = (
        ('dynamine', options.dynamine),
        ('disomine', options.disomine),
        ('efoldmine', options.efoldmine),
        ('agmata', options.agmata),
    )
    return [tool for tool, enabled in flags if enabled]


def _check_required_paths(options):
    """Fail fast when a path option needed for this run is missing.

    The parser defaults these options to ``False``. Left unchecked,
    ``open(False, 'w')`` would write to file descriptor 0, and the other
    outputs would target a literal ``False`` directory.
    """
    required = [('input_fasta', '--file'),
                ('json_output', '--json'),
                ('output', '--output')]
    if options.plot or options.plot_all:
        required.append(('plot_output', '--plot-output'))
    missing = [flag for attr, flag in required if not getattr(options, attr)]
    if missing:
        raise ValueError(f"missing required option(s): {', '.join(missing)}")


def _write_sequence_tsv(sequence_predictions, prediction_keys, path):
    """Write one sequence's predictions as a per-residue TSV table.

    Columns follow ``prediction_keys``, with ``seq`` renamed to
    ``residue``. The index is named ``residue_index`` and values are
    rounded to two decimals.
    """
    residues = sequence_predictions['seq']
    sequence_df = pd.DataFrame(columns=prediction_keys,
                               index=range(len(residues)))
    sequence_df.index.name = 'residue_index'
    for predictor in prediction_keys:
        values = sequence_predictions[predictor]
        if isinstance(values, str):
            # NOTE(review): if b2bTools hands back the sequence as a plain
            # string, assigning it directly would broadcast the whole
            # string to every row; split it into one residue per row.
            values = list(values)
        sequence_df[predictor] = values
    sequence_df = sequence_df.rename(columns={"seq": "residue"})
    sequence_df = sequence_df.round(decimals=2)
    sequence_df.to_csv(path, sep="\t")


def main(options):
    """Run the selected b2bTools predictors and write JSON, TSV and plots.

    Outputs:
      * ``options.json_output``: JSON from
        ``get_all_predictions_json('all')``.
      * ``options.output``: one ``<slugify(sequence name)>.tsv`` per
        sequence.
      * ``options.plot_output``: one PNG per sequence and predictor
        (``--plot``) and/or one combined PNG per predictor
        (``--plot_all``).

    Raises:
        ValueError: if a required path option is missing, or if the input
            produced no predictions.
    """
    _check_required_paths(options)

    single_seq = SingleSeq(options.input_fasta)
    single_seq.predict(_selected_tools(options))
    predictions = single_seq.get_all_predictions()
    results_json = single_seq.get_all_predictions_json('all')
    with open(options.json_output, 'w', encoding='utf-8') as f:
        f.write(results_json)
    if not predictions:
        raise ValueError(
            f'b2bTools returned no predictions for {options.input_fasta}')

    # The column layout comes from the first sequence; every sequence is
    # assumed to expose the same predictor keys.
    prediction_keys = list(next(iter(predictions.values())).keys())
    plotted_keys = [key for key in prediction_keys if key != 'seq']

    os.makedirs(options.output, exist_ok=True)
    if options.plot or options.plot_all:
        os.makedirs(options.plot_output, exist_ok=True)

    # NOTE(review): the --highlight flag is parsed but never consulted;
    # region bands are always requested. Pass options.highlight below if
    # that flag was meant to control them.
    df_dictionary = {}
    for sequence_key, sequence_predictions in predictions.items():
        filename = os.path.join(options.output,
                                f'{slugify(sequence_key)}.tsv')
        _write_sequence_tsv(sequence_predictions, prediction_keys, filename)
        df_dictionary[sequence_key] = filename
        # One plot per sequence and predictor (independent of --plot_all).
        if options.plot:
            for predictor in plotted_keys:
                plot_prediction(prediction_name=predictor,
                                highlighting_regions=True,
                                pred_vals=sequence_predictions[predictor],
                                seq_name=sequence_key)

    # One combined plot per predictor (independent of --plot). Values are
    # read back from the TSV files, so they are the 2-decimal rounded ones.
    if options.plot_all:
        for predictor in plotted_keys:
            results_dictionary = df_dict_to_dict_of_values(
                df_dict=df_dictionary, predictor=predictor)
            plot_prediction(prediction_name=predictor,
                            highlighting_regions=True,
                            pred_vals=results_dictionary,
                            seq_name='all')
205 | |
206 | |
if __name__ == "__main__":
    cli = optparse.OptionParser()
    # Switches selecting which b2bTools predictors to run.
    for tool_flag in ("--dynamine", "--disomine", "--efoldmine", "--agmata"):
        cli.add_option(tool_flag, action="store_true", default=False)
    # Path-valued options (default False, as main() expects).
    for path_flag, destination in (("--file", "input_fasta"),
                                   ("--output", "output"),
                                   ("--plot-output", "plot_output"),
                                   ("--json", "json_output")):
        cli.add_option(path_flag, dest=destination, default=False)
    # Plotting switches.
    for plot_flag in ("--plot", "--plot_all", "--highlight"):
        cli.add_option(plot_flag, action="store_true", default=False)
    # The name ``options`` must stay a module-level global:
    # plot_prediction reads ``options.plot_output`` from it.
    options, _args = cli.parse_args()
    main(options)