comparison script.py @ 2:a9db23ac113f draft default tip

Uploaded new version formatted.
author adrian.diaz
date Tue, 02 Aug 2022 09:44:33 +0000
parents
children
comparison
equal deleted inserted replaced
1:891ccfd22633 2:a9db23ac113f
1 import optparse
2 import os.path
3 import unicodedata
4 import re
5 import numpy as np
6 import pandas as pd
7 from b2bTools import SingleSeq
8 import matplotlib.pyplot as plt
9
10
def slugify(value):
    """
    Turn an arbitrary value into a lowercase ASCII slug.

    Adapted from Django's ``slugify``:
    https://github.com/django/django/blob/master/django/utils/text.py

    The value is converted to ``str``, NFKD-normalised and stripped of
    every character that cannot be encoded as ASCII, then lowercased.
    Characters that are not alphanumerics, underscores, whitespace or
    hyphens are removed, each run of whitespace and/or hyphens collapses
    to a single hyphen, and leading/trailing hyphens and underscores are
    trimmed from the result.
    """
    text = str(value)
    decomposed = unicodedata.normalize('NFKD', text)
    # Dropping non-ASCII bytes after NFKD keeps the base letter of
    # accented characters and discards the combining marks.
    ascii_text = decomposed.encode('ascii', 'ignore').decode('ascii')
    cleaned = re.sub(r'[^\w\s-]', '', ascii_text.lower())
    hyphenated = re.sub(r'[-\s]+', '-', cleaned)
    return hyphenated.strip('-_')
26
27
def check_min_max(predicted_values, former_min, former_max):
    """
    Widen a (min, max) axis range so it encloses the given values.

    The finite values of ``predicted_values`` are padded by 0.1 on each
    side; ``former_min`` is lowered and ``former_max`` is raised only when
    the padded extremes fall outside the current range.

    NaN and infinite entries are ignored individually. The builtin
    ``max``/``min`` give order-dependent results when NaN is present, so
    the extremes are computed on the finite values only. If there are no
    finite values (including an empty input) the range is returned
    unchanged.

    Args:
        predicted_values: sequence of numeric prediction values.
        former_min: current lower bound of the range.
        former_max: current upper bound of the range.

    Returns:
        Tuple ``(new_min, new_max)``.
    """
    margin = 0.1
    values = np.asarray(predicted_values, dtype=float)
    finite_values = values[np.isfinite(values)]
    if finite_values.size == 0:
        return former_min, former_max
    padded_max = float(finite_values.max()) + margin
    padded_min = float(finite_values.min()) - margin
    if padded_max > former_max:
        former_max = padded_max
    if padded_min < former_min:
        former_min = padded_min
    return former_min, former_max
40
41
def plot_prediction(prediction_name, highlighting_regions,
                    pred_vals, seq_name, output_dir=None):
    """
    Plot one predictor for a single sequence or for all sequences.

    The figure is written to
    ``<output_dir>/<slugified seq_name>_<prediction_name>.png``.

    Args:
        prediction_name: name of the predictor being plotted; also used
            in the title and the output file name.
        highlighting_regions: when true, and threshold bands are defined
            for ``prediction_name``, shade them and add a second legend.
        pred_vals: when ``seq_name == 'all'``, a mapping of sequence label
            to a sequence of values (one line is drawn per entry);
            otherwise a single sequence of values.
        seq_name: label of the plotted sequence, or ``'all'``.
        output_dir: directory for the PNG file. When ``None`` the
            module-level ``options.plot_output`` is used, as before.
    """
    thresholds_dict = {
        'backbone': {'membrane spanning': [1., 1.5],
                     'rigid': [0.8, 1.],
                     'context-dependent': [0.69, 0.8],
                     'flexible': [-1.0, 0.69]},
        'earlyFolding': {'early folds': [0.169, 2.],
                         'late folds': [-1., 0.169]},
        'disoMine': {'ordered': [-1., 0.5],
                     'disordered': [0.5, 2.]},
    }
    # Bands listed from the lowest to the highest threshold interval.
    ordered_regions_dict = {
        'backbone': ['flexible',
                     'context-dependent',
                     'rigid',
                     'membrane spanning'],
        'earlyFolding': ['late folds', 'early folds'],
        'disoMine': ['ordered', 'disordered'],
    }
    colors = ['yellow', 'orange', 'pink', 'red']
    ranges_dict = {
        'backbone': [-0.2, 1.2],
        'sidechain': [-0.2, 1.2],
        'ppII': [-0.2, 1.2],
        'earlyFolding': [-0.2, 1.2],
        'disoMine': [-0.2, 1.2],
        'agmata': [-0.2, 1.2],
        'helix': [-1., 1.],
        'sheet': [-1., 1.],
        'coil': [-1., 1.],
    }
    # Starting y-range for predictors missing from ranges_dict, so an
    # unexpected predictor name no longer raises KeyError. The range is
    # still widened to fit the data by check_min_max below.
    default_range = [-0.2, 1.2]

    if output_dir is None:
        # Original behaviour: read the directory from the global options
        # object created by the script entry point.
        output_dir = options.plot_output

    fig, ax = plt.subplots(1, 1)
    fig.set_figwidth(10)
    fig.set_figheight(5)
    ax.set_title(prediction_name + ' ' + 'prediction')
    min_value, max_value = ranges_dict.get(prediction_name, default_range)

    if seq_name == 'all':
        max_len = 0
        for seq, predictions in pred_vals.items():
            min_value, max_value = check_min_max(
                predictions, min_value, max_value)
            ax.plot(range(len(predictions)), predictions, label=seq)
            max_len = max(max_len, len(predictions))
        ax.set_xlim([0, max_len - 1])
    else:
        predictions = pred_vals
        min_value, max_value = check_min_max(
            predictions, min_value, max_value)
        ax.plot(range(len(predictions)), predictions, label=seq_name)
        ax.set_xlim([0, len(predictions) - 1])

    # Legend for the plotted lines. It is created before the bands are
    # drawn so that it lists only the lines, and it is registered as a
    # plain artist so the region legend below does not replace it.
    legend_lines = ax.legend(
        bbox_to_anchor=(1.04, 1),
        loc="upper left",
        fancybox=True,
        shadow=True)
    ax.add_artist(legend_lines)

    # Define regions
    region_names = ordered_regions_dict.get(prediction_name)
    if highlighting_regions and region_names:
        region_handles = {}
        for color, region in zip(colors, region_names):
            lower, upper = thresholds_dict[prediction_name][region]
            region_handles[region] = ax.axhspan(
                lower,
                upper,
                alpha=0.3,
                color=color,
                label=region)
        # to sort it "from up to low"
        legend_order = list(reversed(region_names))
        # This second legend becomes the Axes' own legend. It is not also
        # passed to add_artist, which would register it twice.
        ax.legend([region_handles[region] for region in legend_order],
                  legend_order,
                  fancybox=True,
                  shadow=True,
                  loc='lower left',
                  bbox_to_anchor=(1.04, 0))

    ax.set_ylim([min_value, max_value])
    ax.set_xlabel('residue index')
    ax.set_ylabel('prediction values')
    ax.grid(axis='y')
    fig.savefig(
        os.path.join(
            output_dir,
            "{0}_{1}.png".format(slugify(seq_name), prediction_name)),
        bbox_inches="tight")
    plt.close(fig)
143
144
def df_dict_to_dict_of_values(df_dict, predictor):
    """
    Collect one column from every per-sequence TSV file.

    ``df_dict`` maps a sequence key to the path of a tab-separated file.
    Each file is read with pandas and its ``predictor`` column is stored
    under the same key in the returned dictionary.
    """
    return {
        sequence_key: pd.read_csv(tsv_path, sep='\t')[predictor]
        for sequence_key, tsv_path in df_dict.items()
    }
151
152
def main(options):
    """
    Run the selected predictors and write JSON, TSV and plot outputs.

    Steps, in order:
      1. Build a ``SingleSeq`` from ``options.input_fasta`` and run the
         tools enabled through ``options.dynamine``, ``options.disomine``,
         ``options.efoldmine`` and ``options.agmata``.
      2. Write the JSON returned by ``get_all_predictions_json('all')``
         to ``options.json_output``.
      3. For every sequence, write a TSV (values rounded to 2 decimals)
         to ``<options.output>/<slugified sequence key>.tsv``.
      4. If ``options.plot`` is set, plot every predictor per sequence.
      5. If ``options.plot_all`` is set, plot every predictor with all
         sequences together, using the values read back from the TSVs.

    Args:
        options: object exposing the attributes listed above (the parsed
            command-line options).
    """
    single_seq = SingleSeq(options.input_fasta)
    tool_flags = (
        ('dynamine', options.dynamine),
        ('disomine', options.disomine),
        ('efoldmine', options.efoldmine),
        ('agmata', options.agmata),
    )
    b2b_tools = [tool for tool, enabled in tool_flags if enabled]

    single_seq.predict(b2b_tools)
    predictions = single_seq.get_all_predictions()
    results_json = single_seq.get_all_predictions_json('all')
    with open(options.json_output, 'w') as f:
        f.write(results_json)

    if not predictions:
        # Nothing to tabulate or plot. Without this guard,
        # next(iter(predictions)) below would raise StopIteration.
        return

    # The prediction keys of the first sequence define the columns used
    # for every sequence.
    first_sequence_key = next(iter(predictions))
    prediction_keys = predictions[first_sequence_key].keys()
    # 'seq' holds the residues themselves and is never plotted.
    plotted_predictors = [key for key in prediction_keys if key != 'seq']

    # NOTE(review): the --highlight command-line flag is parsed by the
    # script entry point but is not consulted here; regions are always
    # highlighted. Confirm whether that is intended.
    df_dictionary = {}
    for sequence_key, sequence_predictions in predictions.items():
        residues = sequence_predictions['seq']
        residues_count = len(residues)
        sequence_df = pd.DataFrame(
            columns=prediction_keys,
            index=range(residues_count))
        sequence_df.index.name = 'residue_index'
        for predictor in prediction_keys:
            sequence_df[predictor] = sequence_predictions[predictor]
        sequence_df = sequence_df.rename(columns={"seq": "residue"})
        sequence_df = sequence_df.round(decimals=2)
        filename = os.path.join(
            options.output, f'{slugify(sequence_key)}.tsv')
        df_dictionary[sequence_key] = filename
        sequence_df.to_csv(filename, sep="\t")
        # Plot each individual plot (compatible with plot all)
        if options.plot:
            for predictor in plotted_predictors:
                plot_prediction(prediction_name=predictor,
                                highlighting_regions=True,
                                pred_vals=sequence_predictions[predictor],
                                seq_name=sequence_key)
    # Plot all together (compatible with plot individual)
    if options.plot_all:
        for predictor in plotted_predictors:
            results_dictionary = df_dict_to_dict_of_values(
                df_dict=df_dictionary, predictor=predictor)
            plot_prediction(prediction_name=predictor,
                            highlighting_regions=True,
                            pred_vals=results_dictionary,
                            seq_name='all')
205
206
if __name__ == "__main__":
    # Command-line entry point: parse the options, check that the paths
    # needed for the requested outputs were given, then run main().
    parser = optparse.OptionParser()
    parser.add_option("--dynamine", action="store_true", default=False)
    parser.add_option("--disomine", action="store_true", default=False)
    parser.add_option("--efoldmine", action="store_true", default=False)
    parser.add_option("--agmata", action="store_true", default=False)
    parser.add_option("--file", dest="input_fasta", default=False)
    parser.add_option("--output", dest="output", default=False)
    parser.add_option("--plot-output", dest="plot_output", default=False)

    parser.add_option("--json", dest="json_output", default=False)
    parser.add_option("--plot", action="store_true", default=False)
    parser.add_option("--plot_all", action="store_true", default=False)
    parser.add_option("--highlight", action="store_true", default=False)
    # ``options`` must remain a module-level name: plot_prediction() reads
    # ``options.plot_output`` from the global scope.
    options, _args = parser.parse_args()

    # The path options default to False. Left unchecked, a missing --json
    # would reach open(False, 'w'), which Python treats as file
    # descriptor 0, and a missing --output would produce paths under a
    # literal "False" directory. Fail early with a usage message instead.
    required_paths = [
        ("--file", options.input_fasta),
        ("--output", options.output),
        ("--json", options.json_output),
    ]
    if options.plot or options.plot_all:
        required_paths.append(("--plot-output", options.plot_output))
    missing = [flag for flag, value in required_paths if not value]
    if missing:
        parser.error("missing required option(s): " + ", ".join(missing))

    main(options)