Mercurial > repos > iuc > b2btools_single_sequence
comparison script.py @ 0:b694a77ca1e8 draft default tip
planemo upload commit 599e1135baba020195b3f7576449d595bca9af75
author | iuc |
---|---|
date | Tue, 09 Aug 2022 12:30:52 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:b694a77ca1e8 |
---|---|
1 import json | |
2 import optparse | |
3 import os.path | |
4 import re | |
5 import unicodedata | |
6 | |
7 import matplotlib.pyplot as plt | |
8 import numpy as np | |
9 import pandas as pd | |
10 from b2bTools import SingleSeq | |
11 | |
12 | |
def slugify(value):
    """
    Turn an arbitrary value into a lowercase, ASCII-only slug.

    Adapted from Django's ``django.utils.text.slugify``
    (https://github.com/django/django/blob/master/django/utils/text.py).

    The value is converted to ``str``, NFKD-normalised, and every character
    that cannot be encoded as ASCII is dropped. The text is then lowercased,
    anything that is not a word character, whitespace or a hyphen is removed,
    each run of whitespace and/or hyphens is collapsed into a single hyphen,
    and leading/trailing hyphens and underscores are trimmed.
    """
    decomposed = unicodedata.normalize("NFKD", str(value))
    ascii_text = decomposed.encode("ascii", "ignore").decode("ascii")
    kept = re.sub(r"[^\w\s-]", "", ascii_text.lower())
    hyphenated = re.sub(r"[-\s]+", "-", kept)
    return hyphenated.strip("-_")
29 | |
30 | |
def check_min_max(predicted_values, former_min, former_max):
    """
    Widen a (min, max) axis range so it covers ``predicted_values``.

    A margin of 0.1 is kept around the data: the upper bound is raised to
    ``max(values) + 0.1`` and the lower bound lowered to
    ``min(values) - 0.1`` whenever the current bounds are tighter than that.
    Bounds are never narrowed.

    Only finite values take part in the comparison. NaN and +/-inf entries
    are skipped individually, so a NaN at the start of the sequence (which
    makes the built-in ``max``/``min`` return NaN) or a single infinite value
    no longer hides the finite values around it. When there is no finite
    value at all (including empty input) the former bounds are returned
    unchanged.

    Args:
        predicted_values: Sequence of numeric predictions (list, NumPy
            array, pandas Series, ...).
        former_min: Current lower bound.
        former_max: Current upper bound.

    Returns:
        Tuple ``(new_min, new_max)``.
    """
    values = np.asarray(predicted_values, dtype=float)
    finite_values = values[np.isfinite(values)]
    if finite_values.size == 0:
        return former_min, former_max

    padded_max = float(finite_values.max()) + 0.1
    padded_min = float(finite_values.min()) - 0.1
    if padded_max > former_max:
        former_max = padded_max
    if padded_min < former_min:
        former_min = padded_min
    return former_min, former_max
47 | |
48 | |
def plot_prediction(
    pred_name, hlighting_regions, predicted_values, seq_name, output_dir=None
):
    """
    Plot the per-residue values of one predictor and save them as a PNG.

    The figure is written to ``<output_dir>/<slugified seq_name>_<pred_name>.png``.

    Args:
        pred_name: Predictor name. Used for the title and file name, and to
            look up the default y-range and the highlight bands. Predictors
            without a predefined range start from ``[-0.2, 1.2]``; in every
            case the range is widened to fit the data.
        hlighting_regions: When true, and the predictor has threshold bands
            defined ("backbone", "earlyFolding", "disoMine"), the bands are
            drawn as shaded horizontal spans with their own legend.
        predicted_values: When ``seq_name == "all"``, a mapping of sequence
            name to that sequence's values (one line per sequence).
            Otherwise the values of a single sequence.
        seq_name: Sequence name used as line label and in the file name, or
            the literal ``"all"`` to overlay several sequences.
        output_dir: Directory the PNG is written to. When omitted, the
            module-level ``options.plot_output`` set by the command-line
            entry point is used, as before.
    """
    thresholds_dict = {
        "backbone": {
            "membrane spanning": [1.0, 1.5],
            "rigid": [0.8, 1.0],
            "context-dependent": [0.69, 0.8],
            "flexible": [-1.0, 0.69],
        },
        "earlyFolding": {
            "early folds": [0.169, 2.0],
            "late folds": [-1.0, 0.169],
        },
        "disoMine": {"ordered": [-1.0, 0.5], "disordered": [0.5, 2.0]},
    }
    # Region names listed from the lowest band to the highest one.
    ordered_regions_dict = {
        "backbone": [
            "flexible",
            "context-dependent",
            "rigid",
            "membrane spanning",
        ],
        "earlyFolding": ["late folds", "early folds"],
        "disoMine": ["ordered", "disordered"],
    }
    colors = ["yellow", "orange", "pink", "red"]
    ranges_dict = {
        "backbone": [-0.2, 1.2],
        "sidechain": [-0.2, 1.2],
        "ppII": [-0.2, 1.2],
        "earlyFolding": [-0.2, 1.2],
        "disoMine": [-0.2, 1.2],
        "agmata": [-0.2, 1.2],
        "helix": [-1.0, 1.0],
        "sheet": [-1.0, 1.0],
        "coil": [-1.0, 1.0],
    }

    # Resolve the destination before any figure is created so that a missing
    # output directory setting fails early and leaves no open figure behind.
    if output_dir is None:
        output_dir = options.plot_output
    output_path = os.path.join(
        output_dir, "{0}_{1}.png".format(slugify(seq_name), pred_name)
    )

    fig, ax = plt.subplots(1, 1)
    try:
        fig.set_figwidth(10)
        fig.set_figheight(5)
        ax.set_title(f"{pred_name} prediction")
        min_value, max_value = ranges_dict.get(pred_name, [-0.2, 1.2])

        if seq_name == "all":
            max_len = 0
            for seq, predictions in predicted_values.items():
                min_value, max_value = check_min_max(
                    predictions, min_value, max_value
                )
                ax.plot(range(len(predictions)), predictions, label=seq)
                max_len = max(max_len, len(predictions))
            ax.set_xlim([0, max_len - 1])
        else:
            predictions = predicted_values
            min_value, max_value = check_min_max(
                predictions, min_value, max_value
            )
            ax.plot(range(len(predictions)), predictions, label=seq_name)
            ax.set_xlim([0, len(predictions) - 1])

        # Legend for the plotted lines. It is created before the highlight
        # bands exist so that it lists the lines only, and it is registered
        # as a separate artist so that the region legend created below does
        # not replace it.
        legend_lines = ax.legend(
            bbox_to_anchor=(1.04, 1),
            loc="upper left",
            fancybox=True,
            shadow=True,
        )
        ax.add_artist(legend_lines)

        # Define regions
        region_names = (
            ordered_regions_dict.get(pred_name) if hlighting_regions else None
        )
        if region_names:
            # Keep the artists returned by axhspan instead of looking them up
            # by label afterwards, so a sequence that happens to be named
            # like a region cannot be picked up by mistake.
            region_handles = []
            for color, region in zip(colors, region_names):
                lower, upper = thresholds_dict[pred_name][region]
                region_handles.append(
                    ax.axhspan(
                        lower, upper, alpha=0.3, color=color, label=region
                    )
                )
            # Reverse so the legend reads "from up to low", like the plot.
            region_legend = ax.legend(
                region_handles[::-1],
                region_names[::-1],
                fancybox=True,
                shadow=True,
                loc="lower left",
                bbox_to_anchor=(1.04, 0),
            )
            ax.add_artist(region_legend)

        ax.set_ylim([min_value, max_value])
        ax.set_xlabel("residue index")
        ax.set_ylabel("prediction values")
        ax.grid(axis="y")
        fig.savefig(output_path, bbox_inches="tight")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)
158 | |
159 | |
def df_dict_to_dict_of_values(df_dict, predictor):
    """
    Load one predictor column from each per-sequence TSV file.

    ``df_dict`` maps a sequence name to the path of a tab-separated file.
    The returned dictionary maps each sequence name to the ``predictor``
    column read from its file (a pandas Series).
    """
    return {
        seq: pd.read_csv(tsv_path, sep="\t")[predictor]
        for seq, tsv_path in df_dict.items()
    }
166 | |
167 | |
def main(options):
    """
    Run the selected b2bTools predictors and write every requested output.

    Steps:
      1. Run the predictors enabled in ``options`` on ``options.input_fasta``.
      2. Write all predictions, rounded to 3 decimals, as JSON to
         ``options.json_output``.
      3. Write one TSV per sequence into ``options.output`` (created if it
         does not exist), named after the slugified sequence key.
      4. If ``options.plot`` is set, save one plot per sequence and
         predictor; if ``options.plot_all`` is set, save one plot per
         predictor overlaying all sequences. Plots are written by
         ``plot_prediction`` (``options.plot_output`` is created if needed).

    Args:
        options: Parsed command-line options providing ``input_fasta``,
            ``output``, ``json_output``, ``plot_output``, ``plot``,
            ``plot_all`` and the predictor flags ``dynamine``, ``disomine``,
            ``efoldmine`` and ``agmata``.

    Raises:
        ValueError: If the predictors return no sequence at all. The JSON
            file has already been written at that point.
    """
    single_seq = SingleSeq(options.input_fasta)
    tool_flags = (
        ("dynamine", options.dynamine),
        ("disomine", options.disomine),
        ("efoldmine", options.efoldmine),
        ("agmata", options.agmata),
    )
    b2b_tools = [tool for tool, enabled in tool_flags if enabled]
    single_seq.predict(b2b_tools)
    predictions = single_seq.get_all_predictions()

    def rounder_function(value):
        return round(float(value), 3)

    # Round-trip through JSON so that every float, at any nesting depth, is
    # passed through the rounder by the JSON parser.
    rounded_predictions = json.loads(
        json.dumps(predictions), parse_float=rounder_function
    )
    results_json = json.dumps(rounded_predictions, indent=2, sort_keys=True)
    with open(options.json_output, "w") as f:
        f.write(results_json)

    if not predictions:
        raise ValueError(
            "No predictions were returned for {0}".format(options.input_fasta)
        )

    # The first sequence defines the set of columns used for all sequences.
    first_sequence_key = next(iter(predictions))
    prediction_keys = list(predictions[first_sequence_key].keys())
    predictor_names = [key for key in prediction_keys if key != "seq"]
    # Sort column names
    tsv_column_names = ["residue", *sorted(predictor_names)]

    os.makedirs(options.output, exist_ok=True)
    if options.plot or options.plot_all:
        os.makedirs(options.plot_output, exist_ok=True)

    df_dictionary = {}
    for sequence_key, seq_preds in predictions.items():
        residues_count = len(seq_preds["seq"])
        sequence_df = pd.DataFrame(
            {key: seq_preds[key] for key in prediction_keys},
            index=range(residues_count),
        )
        sequence_df.index.name = "residue_index"
        sequence_df = sequence_df.rename(columns={"seq": "residue"})
        sequence_df = sequence_df.round(decimals=3)
        filename = os.path.join(
            options.output, f"{slugify(sequence_key)}.tsv"
        )
        df_dictionary[sequence_key] = filename
        sequence_df.to_csv(
            filename,
            header=True,
            columns=tsv_column_names,
            sep="\t",
        )
        # Plot each individual plot (compatible with plot all)
        # NOTE(review): region highlighting is always requested here; the
        # parsed ``--highlight`` option is not consulted. Confirm whether
        # that is intended before wiring it in.
        if options.plot:
            for predictor in predictor_names:
                plot_prediction(
                    pred_name=predictor,
                    hlighting_regions=True,
                    predicted_values=seq_preds[predictor],
                    seq_name=sequence_key,
                )

    # Plot all together (compatible with plot individual)
    if options.plot_all:
        for predictor in predictor_names:
            # Values are read back from the TSV files written above, i.e.
            # already rounded to 3 decimals.
            results_dictionary = df_dict_to_dict_of_values(
                df_dictionary, predictor
            )
            plot_prediction(
                pred_name=predictor,
                hlighting_regions=True,
                predicted_values=results_dictionary,
                seq_name="all",
            )
241 | |
242 | |
if __name__ == "__main__":
    # Command-line entry point: declare the options, validate them and run
    # main(). ``options`` is assigned at module level on purpose, because
    # plot_prediction() falls back to the global ``options.plot_output``.
    parser = optparse.OptionParser()

    # Predictor selection flags.
    for flag in ("--dynamine", "--disomine", "--efoldmine", "--agmata"):
        parser.add_option(flag, action="store_true")

    # Input and output locations.
    parser.add_option("--file", dest="input_fasta", type="string")
    parser.add_option("--output", dest="output", type="string")
    parser.add_option("--plot-output", type="string", dest="plot_output")
    parser.add_option("--json", dest="json_output", type="string")

    # Plotting flags.
    for flag in ("--plot", "--plot_all", "--highlight"):
        parser.add_option(flag, action="store_true")

    # Only the parsing step is guarded. The message is built from the
    # exception itself: ``args`` is not bound yet if parse_args() raised, so
    # formatting it here would fail with a NameError and mask the real error.
    try:
        options, args = parser.parse_args()
    except optparse.OptionError as exc:
        raise RuntimeError(f"Invalid arguments: {exc}") from exc

    # parser.error() prints the usage message and exits with status 2.
    if not (
        options.dynamine
        or options.disomine
        or options.efoldmine
        or options.agmata
    ):
        parser.error('At least one predictor is required')
    if not options.input_fasta:
        parser.error('Input file not given (--file)')
    if not options.output:
        parser.error('Output directory not given (--output)')
    if (options.plot or options.plot_all) and not options.plot_output:
        parser.error('Plot output directory not given (--plot-output)')
    if not options.json_output:
        parser.error('Json output file not given (--json)')

    main(options)