2
|
1 import optparse
|
|
2 import os.path
|
|
3 import unicodedata
|
|
4 import re
|
|
5 import numpy as np
|
|
6 import pandas as pd
|
|
7 from b2bTools import SingleSeq
|
|
8 import matplotlib.pyplot as plt
|
|
9
|
|
10
|
|
def slugify(value):
    """
    Turn an arbitrary value into an ASCII, filename-friendly slug.

    Adapted from Django's ``slugify``:
    https://github.com/django/django/blob/master/django/utils/text.py

    The value is converted to ``str``, NFKD-normalised, and every character
    that cannot be encoded as ASCII is dropped.  The text is then lowercased,
    characters that are not word characters, whitespace or hyphens are
    removed, each run of whitespace and/or hyphens is collapsed into a single
    hyphen, and leading/trailing hyphens and underscores are stripped.
    """
    decomposed = unicodedata.normalize('NFKD', str(value))
    ascii_only = decomposed.encode('ascii', 'ignore').decode('ascii')
    cleaned = re.sub(r'[^\w\s-]', '', ascii_only.lower())
    hyphenated = re.sub(r'[-\s]+', '-', cleaned)
    return hyphenated.strip('-_')
|
|
26
|
|
27
|
|
def check_min_max(predicted_values, former_min, former_max):
    """
    Widen a ``(min, max)`` range so that it covers the given predictions.

    Only finite values take part in the computation.  NaN and +/-inf entries
    are ignored wherever they occur in the input; relying on the built-in
    ``max``/``min`` would make the result depend on the position of a NaN
    (``max([nan, 2.0])`` is ``nan`` while ``max([2.0, nan])`` is ``2.0``).

    A margin of 0.1 is kept around the data: the upper bound is raised to
    ``largest + 0.1`` when that exceeds ``former_max`` and the lower bound is
    lowered to ``smallest - 0.1`` when that is below ``former_min``.

    Args:
        predicted_values: sequence of numeric prediction values (a list,
            NumPy array or pandas Series).
        former_min: current lower bound of the range.
        former_max: current upper bound of the range.

    Returns:
        The ``(min, max)`` tuple, widened where necessary.  The bounds are
        returned unchanged when the input is empty or holds no finite value.
    """
    values = np.asarray(predicted_values, dtype=float)
    finite_values = values[np.isfinite(values)]
    if finite_values.size == 0:
        # Nothing usable to derive a range from; keep the caller's bounds.
        return former_min, former_max

    seq_max = float(finite_values.max())
    seq_min = float(finite_values.min())
    if seq_max + 0.1 > former_max:
        former_max = seq_max + 0.1
    if seq_min - 0.1 < former_min:
        former_min = seq_min - 0.1
    return former_min, former_max
|
|
40
|
|
41
|
|
def plot_prediction(prediction_name, highlighting_regions,
                    pred_vals, seq_name, output_dir=None):
    """
    Plot per-residue prediction values and save the figure as a PNG file.

    The image is written to ``<output_dir>/<slugified seq_name>_<prediction_name>.png``.

    Args:
        prediction_name: name of the predictor being plotted.  It must be a
            key of ``ranges_dict`` below, otherwise ``KeyError`` is raised.
        highlighting_regions: when true, and the predictor has named value
            bands defined in ``thresholds_dict``, shade those bands and add a
            second legend describing them.
        pred_vals: when ``seq_name`` is ``'all'``, a mapping of sequence name
            to its sequence of values (one line is drawn per entry);
            otherwise a single sequence of values.
        seq_name: label of the plotted sequence, or the sentinel ``'all'``.
            NOTE(review): a real sequence that is literally named 'all'
            would be treated as the multi-sequence case.
        output_dir: directory receiving the PNG.  When ``None`` the function
            falls back to the module-level ``options.plot_output``, which is
            only available when this file is executed as a script.
    """
    thresholds_dict = {
        'backbone': {'membrane spanning': [1., 1.5],
                     'rigid': [0.8, 1.],
                     'context-dependent': [0.69, 0.8],
                     'flexible': [-1.0, 0.69]},
        'earlyFolding': {'early folds': [0.169, 2.],
                         'late folds': [-1., 0.169]},
        'disoMine': {'ordered': [-1., 0.5],
                     'disordered': [0.5, 2.]},
    }
    # Bands listed from the lowest to the highest value range.
    ordered_regions_dict = {
        'backbone': ['flexible',
                     'context-dependent',
                     'rigid',
                     'membrane spanning'],
        'earlyFolding': ['late folds', 'early folds'],
        'disoMine': ['ordered', 'disordered'],
    }
    colors = ['yellow', 'orange', 'pink', 'red']
    # Default y-axis limits per predictor; widened below if the data exceed them.
    ranges_dict = {
        'backbone': [-0.2, 1.2],
        'sidechain': [-0.2, 1.2],
        'ppII': [-0.2, 1.2],
        'earlyFolding': [-0.2, 1.2],
        'disoMine': [-0.2, 1.2],
        'agmata': [-0.2, 1.2],
        'helix': [-1., 1.],
        'sheet': [-1., 1.],
        'coil': [-1., 1.],
    }

    # Resolve everything that can fail before a figure is allocated.
    min_value, max_value = ranges_dict[prediction_name]
    if output_dir is None:
        output_dir = options.plot_output

    fig, ax = plt.subplots(1, 1)
    try:
        fig.set_figwidth(10)
        fig.set_figheight(5)
        ax.set_title(prediction_name + ' ' + 'prediction')

        if seq_name == 'all':
            max_len = 0
            for seq, predictions in pred_vals.items():
                min_value, max_value = check_min_max(
                    predictions, min_value, max_value)
                ax.plot(range(len(predictions)), predictions, label=seq)
                max_len = max(max_len, len(predictions))
            ax.set_xlim([0, max_len - 1])
        else:
            predictions = pred_vals
            min_value, max_value = check_min_max(
                predictions, min_value, max_value)
            ax.plot(range(len(predictions)), predictions, label=seq_name)
            ax.set_xlim([0, len(predictions) - 1])

        # Legend for the plotted lines.  It is registered as a separate
        # artist so that it is kept when a second legend is created below.
        legend_lines = ax.legend(
            bbox_to_anchor=(1.04, 1),
            loc="upper left",
            fancybox=True,
            shadow=True)
        ax.add_artist(legend_lines)

        # Define regions
        if highlighting_regions and prediction_name in ordered_regions_dict:
            region_names = ordered_regions_dict[prediction_name]
            for color, region in zip(colors, region_names):
                lower, upper = thresholds_dict[prediction_name][region]
                ax.axhspan(
                    lower,
                    upper,
                    alpha=0.3,
                    color=color,
                    label=region)
            # Legend entries are listed from the highest band to the lowest.
            legend_regions = region_names[::-1]
            handles, labels = ax.get_legend_handles_labels()
            handles_dict = dict(zip(labels, handles))
            region_handles = [handles_dict[region]
                              for region in legend_regions]
            region_legend = ax.legend(region_handles,
                                      legend_regions,
                                      fancybox=True,
                                      shadow=True,
                                      loc='lower left',
                                      bbox_to_anchor=(1.04, 0))
            ax.add_artist(region_legend)

        ax.set_ylim([min_value, max_value])
        ax.set_xlabel('residue index')
        ax.set_ylabel('prediction values')
        ax.grid(axis='y')
        fig.savefig(
            os.path.join(
                output_dir,
                "{0}_{1}.png".format(slugify(seq_name), prediction_name)),
            bbox_inches="tight")
    finally:
        # Always release this specific figure, even if plotting or saving
        # failed, so figures do not pile up across repeated calls.
        plt.close(fig)
|
|
143
|
|
144
|
|
def df_dict_to_dict_of_values(df_dict, predictor):
    """
    Collect one predictor column from a set of tab-separated result files.

    Args:
        df_dict: mapping of sequence name to the path of a TSV file.
        predictor: name of the column to extract from every file.

    Returns:
        A dict mapping each sequence name to the pandas Series holding the
        ``predictor`` column of its file.
    """
    return {
        seq_name: pd.read_csv(tsv_path, sep='\t')[predictor]
        for seq_name, tsv_path in df_dict.items()
    }
|
|
151
|
|
152
|
|
def main(options):
    """
    Run the selected b2bTools predictors and export their results.

    Steps:
      1. Load ``options.input_fasta`` into ``SingleSeq`` and run every
         predictor whose command-line flag is set.
      2. Write the JSON produced by ``get_all_predictions_json('all')`` to
         ``options.json_output``.
      3. For every sequence, write a tab-separated table (one row per
         residue, values rounded to two decimals) into ``options.output``
         and, if ``options.plot`` is set, one plot per predictor.
      4. If ``options.plot_all`` is set, draw one plot per predictor with
         all sequences together, using the values read back from the
         tables written in step 3.
    """
    single_seq = SingleSeq(options.input_fasta)
    tool_names = ('dynamine', 'disomine', 'efoldmine', 'agmata')
    selected_tools = [tool for tool in tool_names if getattr(options, tool)]

    single_seq.predict(selected_tools)
    predictions = single_seq.get_all_predictions()
    results_json = single_seq.get_all_predictions_json('all')
    with open(options.json_output, 'w') as json_file:
        json_file.write(results_json)

    # The first sequence's keys are used as the column set for all sequences.
    prediction_keys = predictions[next(iter(predictions))].keys()
    # 'seq' holds the residues themselves, so there is nothing to plot for it.
    plottable = [key for key in prediction_keys if key != 'seq']

    # NOTE(review): highlighting_regions is hard-coded to True in both
    # plotting calls below; options.highlight is not consulted here.
    tsv_paths = {}
    for sequence_key, sequence_predictions in predictions.items():
        row_count = len(sequence_predictions['seq'])
        table = pd.DataFrame(
            columns=prediction_keys,
            index=range(row_count))
        table.index.name = 'residue_index'
        for column in prediction_keys:
            table[column] = sequence_predictions[column]
        table = table.rename(columns={"seq": "residue"})
        table = table.round(decimals=2)

        filename = f'{options.output}/{slugify(sequence_key)}.tsv'
        tsv_paths[sequence_key] = filename
        table.to_csv(filename, sep="\t")

        # Individual plots (can be combined with the all-sequences plots).
        if options.plot:
            for predictor in plottable:
                plot_prediction(prediction_name=predictor,
                                highlighting_regions=True,
                                pred_vals=sequence_predictions[predictor],
                                seq_name=sequence_key)

    # All sequences in one plot per predictor (can be combined with the
    # individual plots).
    if options.plot_all:
        for predictor in plottable:
            values_by_sequence = df_dict_to_dict_of_values(
                df_dict=tsv_paths, predictor=predictor)
            plot_prediction(prediction_name=predictor,
                            highlighting_regions=True,
                            pred_vals=values_by_sequence,
                            seq_name='all')
|
|
205
|
|
206
|
|
if __name__ == "__main__":
    # Command-line entry point.  ``options`` is bound at module level on
    # purpose: plot_prediction reads it as a global.
    parser = optparse.OptionParser()

    # Predictor selection switches.
    for switch in ("--dynamine", "--disomine", "--efoldmine", "--agmata"):
        parser.add_option(switch, action="store_true", default=False)

    # Input and output locations.
    path_options = (
        ("--file", "input_fasta"),
        ("--output", "output"),
        ("--plot-output", "plot_output"),
        ("--json", "json_output"),
    )
    for switch, destination in path_options:
        parser.add_option(switch, dest=destination, default=False)

    # Plotting switches.
    for switch in ("--plot", "--plot_all", "--highlight"):
        parser.add_option(switch, action="store_true", default=False)

    options, _args = parser.parse_args()
    main(options)
|