comparison script.py @ 0:b694a77ca1e8 draft default tip

planemo upload commit 599e1135baba020195b3f7576449d595bca9af75
author iuc
date Tue, 09 Aug 2022 12:30:52 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:b694a77ca1e8
1 import json
2 import optparse
3 import os.path
4 import re
5 import unicodedata
6
7 import matplotlib.pyplot as plt
8 import numpy as np
9 import pandas as pd
10 from b2bTools import SingleSeq
11
12
def slugify(value, allow_unicode=False):
    """
    Turn an arbitrary value into a string that is safe to use in a file name.

    Adapted from Django's ``django.utils.text.slugify``
    (https://github.com/django/django/blob/master/django/utils/text.py).

    The value is converted with ``str()`` and lowercased. Characters that
    are not alphanumerics, underscores, hyphens or whitespace are removed,
    runs of whitespace and hyphens collapse into a single hyphen, and
    leading/trailing hyphens and underscores are stripped.

    :param value: object to slugify; ``str(value)`` is what gets processed.
    :param allow_unicode: when false (the default) the text is decomposed
        and reduced to ASCII, silently dropping characters that have no
        ASCII form. When true the text is NFKC-normalised and non-ASCII
        word characters are kept.
    :return: the slug. It is empty when no character survives filtering,
        and distinct inputs may map to the same slug.
    """
    text = str(value)
    if allow_unicode:
        text = unicodedata.normalize("NFKC", text)
    else:
        # NFKD splits accented letters into base letter + combining mark;
        # the ASCII round trip then discards the marks.
        text = (
            unicodedata.normalize("NFKD", text)
            .encode("ascii", "ignore")
            .decode("ascii")
        )
    text = re.sub(r"[^\w\s-]", "", text.lower())
    return re.sub(r"[-\s]+", "-", text).strip("-_")
29
30
def check_min_max(predicted_values, former_min, former_max):
    """
    Widen a y-axis range so that it encloses a series of predicted values.

    A margin of 0.1 is added above the largest and below the smallest
    finite value. Each bound is only ever moved outwards; a bound that
    already encloses the padded data is returned unchanged.

    NaN and infinite entries are ignored. The extrema are computed over the
    finite values only, so the result does not depend on where a NaN sits
    in the input (the builtin ``max``/``min`` are order-dependent when NaN
    is present).

    :param predicted_values: iterable of numbers; ``None`` entries are
        treated as NaN by the float conversion.
    :param former_min: current lower bound of the range.
    :param former_max: current upper bound of the range.
    :return: ``(new_min, new_max)``. The input bounds are returned as-is
        when there is no finite value to look at (including empty input).
    """
    margin = 0.1
    values = np.asarray(predicted_values, dtype=float)
    finite_values = values[np.isfinite(values)]
    if finite_values.size == 0:
        return former_min, former_max

    padded_max = float(finite_values.max()) + margin
    padded_min = float(finite_values.min()) - margin
    if padded_max > former_max:
        former_max = padded_max
    if padded_min < former_min:
        former_min = padded_min
    return former_min, former_max
47
48
# Horizontal bands drawn behind the curves, per predictor, listed from the
# lowest band to the highest as (legend label, lower bound, upper bound).
_HIGHLIGHT_REGIONS = {
    "backbone": [
        ("flexible", -1.0, 0.69),
        ("context-dependent", 0.69, 0.8),
        ("rigid", 0.8, 1.0),
        ("membrane spanning", 1.0, 1.5),
    ],
    "earlyFolding": [
        ("late folds", -1.0, 0.169),
        ("early folds", 0.169, 2.0),
    ],
    "disoMine": [
        ("ordered", -1.0, 0.5),
        ("disordered", 0.5, 2.0),
    ],
}
# Band colours, assigned by position (lowest band first).
_REGION_COLORS = ["yellow", "orange", "pink", "red"]
# Initial y-axis range per predictor; widened when the data exceed it.
_Y_AXIS_RANGES = {
    "backbone": [-0.2, 1.2],
    "sidechain": [-0.2, 1.2],
    "ppII": [-0.2, 1.2],
    "earlyFolding": [-0.2, 1.2],
    "disoMine": [-0.2, 1.2],
    "agmata": [-0.2, 1.2],
    "helix": [-1.0, 1.0],
    "sheet": [-1.0, 1.0],
    "coil": [-1.0, 1.0],
}
# Range used for a predictor that has no entry in _Y_AXIS_RANGES.
_DEFAULT_Y_AXIS_RANGE = [-0.2, 1.2]


def plot_prediction(
    pred_name, hlighting_regions, predicted_values, seq_name, output_dir=None
):
    """
    Plot per-residue prediction values and save the figure as a PNG file.

    The file is written to ``<output_dir>/<slugify(seq_name)>_<pred_name>.png``.

    :param pred_name: predictor name; used in the title, the file name and
        to look up the y-axis range and the highlight bands.
    :param hlighting_regions: when truthy, draw the coloured bands defined
        for ``pred_name`` (only some predictors define bands) together with
        a second legend describing them.
    :param predicted_values: when ``seq_name`` is ``"all"``, a mapping of
        sequence name to its series of values, one curve per entry;
        otherwise the series of values of the single sequence.
    :param seq_name: sequence name used as curve label and in the file
        name, or the literal ``"all"`` for the combined plot.
    :param output_dir: directory receiving the PNG. When omitted, the
        module-level ``options.plot_output`` set by the script entry point
        is used, which is how existing callers behave.
    """
    if output_dir is None:
        # Backward-compatible fallback: the script entry point stores the
        # parsed command line in the module-level name ``options``.
        output_dir = options.plot_output

    fig, ax = plt.subplots(1, 1)
    try:
        fig.set_figwidth(10)
        fig.set_figheight(5)
        ax.set_title(pred_name + " " + "prediction")
        min_value, max_value = _Y_AXIS_RANGES.get(
            pred_name, _DEFAULT_Y_AXIS_RANGE
        )

        if seq_name == "all":
            max_len = 0
            for seq, predictions in predicted_values.items():
                min_value, max_value = check_min_max(
                    predictions, min_value, max_value
                )
                ax.plot(range(len(predictions)), predictions, label=seq)
                max_len = max(max_len, len(predictions))
            ax.set_xlim([0, max_len - 1])
        else:
            predictions = predicted_values
            min_value, max_value = check_min_max(
                predictions, min_value, max_value
            )
            ax.plot(range(len(predictions)), predictions, label=seq_name)
            ax.set_xlim([0, len(predictions) - 1])

        # Legend for the curves, placed outside the axes (top right).
        curves_legend = ax.legend(
            bbox_to_anchor=(1.04, 1),
            loc="upper left",
            fancybox=True,
            shadow=True,
        )

        regions = _HIGHLIGHT_REGIONS.get(pred_name) if hlighting_regions else None
        if regions:
            # An Axes keeps only the legend created by the latest
            # ``ax.legend`` call, so the curves legend has to be re-added
            # as a plain artist before the regions legend replaces it.
            ax.add_artist(curves_legend)
            bands = [
                ax.axhspan(lower, upper, alpha=0.3, color=color, label=label)
                for (label, lower, upper), color in zip(regions, _REGION_COLORS)
            ]
            band_labels = [label for label, _lower, _upper in regions]
            # Reversed so the legend reads from the highest band down.
            ax.legend(
                bands[::-1],
                band_labels[::-1],
                fancybox=True,
                shadow=True,
                loc="lower left",
                bbox_to_anchor=(1.04, 0),
            )

        ax.set_ylim([min_value, max_value])
        ax.set_xlabel("residue index")
        ax.set_ylabel("prediction values")
        ax.grid(axis="y")
        fig.savefig(
            os.path.join(
                output_dir,
                "{0}_{1}.png".format(slugify(seq_name), pred_name),
            ),
            bbox_inches="tight",
        )
    finally:
        # Always release the figure, even when plotting or saving fails;
        # pyplot otherwise keeps every figure alive.
        plt.close(fig)
158
159
def df_dict_to_dict_of_values(df_dict, predictor):
    """
    Collect one predictor column from a set of tab-separated result files.

    :param df_dict: mapping of sequence name to the path of that
        sequence's TSV file.
    :param predictor: name of the column to extract from every file.
    :return: dict mapping each sequence name to the pandas Series read
        from the ``predictor`` column of its file.
    """
    return {
        seq_name: pd.read_csv(tsv_path, sep="\t")[predictor]
        for seq_name, tsv_path in df_dict.items()
    }
166
167
def main(options):
    """
    Run the selected b2bTools predictors and write JSON, TSV and plot output.

    Steps:

    1. Run the predictors enabled in ``options`` on ``options.input_fasta``.
    2. Write all predictions, rounded to 3 decimals, as JSON to
       ``options.json_output``.
    3. Write one TSV per sequence into ``options.output`` with a
       ``residue`` column followed by the predictor columns in
       alphabetical order.
    4. With ``options.plot``, draw one figure per sequence and predictor;
       with ``options.plot_all``, draw one figure per predictor holding
       every sequence.

    The predictions object is used as a mapping of sequence name to a
    mapping of predictor name to per-residue values, with the residues
    under the key ``"seq"``. The set of predictors is taken from the first
    sequence and assumed to be the same for all of them.

    :param options: parsed command-line options (see the option
        definitions at the bottom of this file).
    """
    single_seq = SingleSeq(options.input_fasta)
    b2b_tools = [
        tool
        for tool in ("dynamine", "disomine", "efoldmine", "agmata")
        if getattr(options, tool)
    ]
    single_seq.predict(b2b_tools)
    predictions = single_seq.get_all_predictions()

    def rounder_function(value):
        return round(float(value), 3)

    # A JSON round trip is used to round every float in the nested result.
    rounded_predictions = json.loads(
        json.dumps(predictions), parse_float=rounder_function
    )
    results_json = json.dumps(rounded_predictions, indent=2, sort_keys=True)
    with open(options.json_output, "w") as f:
        f.write(results_json)

    if not predictions:
        # Nothing to tabulate or plot; the (empty) JSON is already written.
        return

    # Make sure the target directories exist before writing into them.
    os.makedirs(options.output, exist_ok=True)
    plot_dir = getattr(options, "plot_output", None)
    if (options.plot or options.plot_all) and plot_dir:
        os.makedirs(plot_dir, exist_ok=True)

    first_sequence_key = next(iter(predictions))
    prediction_keys = list(predictions[first_sequence_key])
    predictor_names = [key for key in prediction_keys if key != "seq"]
    tsv_column_names = ["residue", *sorted(predictor_names)]

    # NOTE(review): the --highlight command-line flag is parsed but never
    # consulted; region highlighting is always enabled below. Confirm the
    # intended default before wiring the flag through.
    highlight = True

    # Rounded per-sequence tables, kept only for the combined plot so the
    # TSV files do not have to be read back once per predictor.
    rounded_frames = {}
    for sequence_key, seq_preds in predictions.items():
        residues_count = len(seq_preds["seq"])
        sequence_df = pd.DataFrame(index=range(residues_count))
        sequence_df.index.name = "residue_index"
        for key in prediction_keys:
            sequence_df[key] = seq_preds[key]
        sequence_df = sequence_df.rename(columns={"seq": "residue"})
        sequence_df = sequence_df.round(decimals=3)
        filename = os.path.join(
            options.output, f"{slugify(sequence_key)}.tsv"
        )
        sequence_df.to_csv(
            filename,
            header=True,
            columns=tsv_column_names,
            sep="\t",
        )
        if options.plot_all:
            rounded_frames[sequence_key] = sequence_df
        # Plot each individual plot (compatible with plot all)
        if options.plot:
            for predictor in predictor_names:
                plot_prediction(
                    pred_name=predictor,
                    hlighting_regions=highlight,
                    predicted_values=seq_preds[predictor],
                    seq_name=sequence_key,
                )

    # Plot all together (compatible with plot individual)
    if options.plot_all:
        for predictor in predictor_names:
            plot_prediction(
                pred_name=predictor,
                hlighting_regions=highlight,
                predicted_values={
                    sequence_key: frame[predictor]
                    for sequence_key, frame in rounded_frames.items()
                },
                seq_name="all",
            )
241
242
# Command-line switches that each enable one predictor.
PREDICTOR_FLAGS = ("dynamine", "disomine", "efoldmine", "agmata")


def build_parser():
    """Create the ``optparse`` parser describing this script's options."""
    parser = optparse.OptionParser()
    for predictor in PREDICTOR_FLAGS:
        parser.add_option(
            f"--{predictor}",
            action="store_true",
            help=f"run the {predictor} predictor",
        )
    parser.add_option(
        "--file",
        dest="input_fasta",
        type="string",
        help="input FASTA file",
    )
    parser.add_option(
        "--output",
        dest="output",
        type="string",
        help="directory receiving one TSV file per sequence",
    )
    parser.add_option(
        "--plot-output",
        type="string",
        dest="plot_output",
        help="directory receiving the PNG plots",
    )
    parser.add_option(
        "--json",
        dest="json_output",
        type="string",
        help="path of the JSON results file",
    )
    parser.add_option(
        "--plot",
        action="store_true",
        help="plot every predictor for each sequence separately",
    )
    parser.add_option(
        "--plot_all",
        action="store_true",
        help="plot every predictor with all sequences in one figure",
    )
    parser.add_option(
        "--highlight",
        action="store_true",
        help="region highlighting flag (plots are currently always "
        "highlighted)",
    )
    return parser


def parse_options(argv=None):
    """
    Parse and validate the command line.

    Invalid or incomplete input is reported through ``parser.error``, which
    prints the usage message to stderr and exits with status 2. Positional
    arguments are accepted and ignored.

    :param argv: list of argument strings; ``None`` means ``sys.argv[1:]``.
    :return: the populated ``optparse.Values`` object.
    """
    parser = build_parser()
    options, _positional = parser.parse_args(argv)
    if not any(getattr(options, flag) for flag in PREDICTOR_FLAGS):
        parser.error('At least one predictor is required')
    if not options.input_fasta:
        parser.error('Input file not given (--file)')
    if not options.output:
        parser.error('Output directory not given (--output)')
    if (options.plot or options.plot_all) and not options.plot_output:
        parser.error('Plot output directory not given (--plot-output)')
    if not options.json_output:
        parser.error('Json output file not given (--json)')
    return options


if __name__ == "__main__":
    # ``options`` must remain a module-level name: plot_prediction reads
    # ``options.plot_output`` from the global scope.
    options = parse_options()
    main(options)