comparison singleSeq/script.py @ 0:cacb90cde53e draft

First version of b2btools for single sequences in Galaxy
author adrian.diaz
date Wed, 06 Jul 2022 11:01:15 +0000
parents
children
comparison
equal deleted inserted replaced
-1:000000000000 0:cacb90cde53e
1 import optparse
2 import os.path
3 import unicodedata
4 import re
5 import pandas as pd
6 from b2bTools import SingleSeq
7 import matplotlib.pyplot as plt
8
9
def slugify(value, allow_unicode=False):
    """
    Convert ``value`` into a string that is safe to use inside a file name.

    Adapted from Django's ``django.utils.text.slugify``
    (https://github.com/django/django/blob/main/django/utils/text.py).

    Processing steps:
      1. coerce ``value`` to ``str``;
      2. unless ``allow_unicode`` is true, transliterate to ASCII (NFKD
         decomposition, then characters with no ASCII form are dropped);
         otherwise only NFKC-normalise;
      3. lower-case, then remove every character that is not a word
         character, whitespace or a hyphen;
      4. collapse runs of whitespace and/or hyphens into a single hyphen;
      5. strip leading/trailing hyphens and underscores.

    :param value: object to slugify (converted with ``str()``).
    :param allow_unicode: keep non-ASCII letters instead of transliterating
        them to ASCII. Defaults to False, which is the historical behaviour.
    :return: the slug. It may be an empty string if no character survives
        (e.g. ``slugify("???") == ""``).
    """
    value = str(value)
    if allow_unicode:
        value = unicodedata.normalize('NFKC', value)
    else:
        value = unicodedata.normalize('NFKD', value).encode('ascii', 'ignore').decode('ascii')
    value = re.sub(r'[^\w\s-]', '', value.lower())
    return re.sub(r'[-\s]+', '-', value).strip('-_')
22
23
def plot_prediction(prediction_name, highlighting_regions, predicted_values, seq_name, plot_output=None):
    """
    Plot one predictor's per-residue values and save the figure as a PNG.

    The file is written to ``<plot_output>/<slugify(seq_name)>_<prediction_name>.png``.

    :param prediction_name: predictor key (e.g. 'backbone', 'disoMine'). It is
        used for the title, the optional shaded value bands, the y-axis range
        and the output file name.
    :param highlighting_regions: if true and ``prediction_name`` has bands
        defined in ``thresholds_dict``, shade those bands and add a second
        legend listing them.
    :param predicted_values: either a sequence of per-residue values for one
        sequence, or a dict mapping sequence name -> per-residue values, in
        which case every sequence is overlaid in a single plot.
    :param seq_name: label of the line when ``predicted_values`` is a single
        sequence; always used (slugified) in the output file name. ``main``
        passes 'all' for the combined plot.
    :param plot_output: output directory. When omitted it falls back to the
        module-level ``options.plot_output`` created by the script entry
        point, which keeps older callers working.
    """
    # (lower, upper) value band for each named region, per predictor.
    thresholds_dict = {'backbone': {'membrane spanning': [1., 1.5],
                                    'rigid': [0.8, 1.],
                                    'context-dependent': [0.69, 0.8],
                                    'flexible': [-1.0, 0.69]},
                       'earlyFolding': {'early folds': [0.169, 2.], 'late folds': [-1., 0.169]},
                       'disoMine': {'ordered': [-1., 0.5], 'disordered': [0.5, 2.]},
                       }
    # Regions listed bottom-to-top; the position selects the shading colour.
    ordered_regions_dict = {'backbone': ['flexible', 'context-dependent', 'rigid', 'membrane spanning'],
                            'earlyFolding': ['late folds', 'early folds'],
                            'disoMine': ['ordered', 'disordered'],
                            }
    colors = ['yellow', 'orange', 'pink', 'red']
    # Fixed y-axis limits per predictor; unlisted predictors are autoscaled.
    ranges_dict = {
        'backbone': [-0.2, 1.2],
        'sidechain': [-0.2, 1.2],
        'ppII': [-0.2, 1.2],
        'earlyFolding': [-0.2, 1.2],
        'disoMine': [-0.2, 1.2],
        'agmata': [-0.2, 1.2],
        'helix': [-1., 1.],
        'sheet': [-1., 1.],
        'coil': [-1., 1.],
    }
    if plot_output is None:
        # Backward compatibility: the directory used to be read only from the
        # global ``options`` defined under ``if __name__ == "__main__"``.
        plot_output = options.plot_output

    fig, ax = plt.subplots(1, 1)
    try:
        fig.set_figwidth(10)
        fig.set_figheight(5)
        ax.set_title(prediction_name + ' ' + 'prediction')
        plt.tight_layout(rect=[0, 0, 0.75, 1])

        # Dispatch on the data type rather than on seq_name == 'all', so that a
        # real sequence called "all" is still plotted as a single sequence.
        if isinstance(predicted_values, dict):
            max_len = 0
            for seq, predictions in predicted_values.items():
                ax.plot(range(len(predictions)), predictions, label=seq)
                max_len = max(max_len, len(predictions))
            ax.set_xlim([0, max_len - 1])
        else:
            predictions = predicted_values
            ax.plot(range(len(predictions)), predictions, label=seq_name)
            ax.set_xlim([0, len(predictions) - 1])

        # Sequence legend, outside the axes on the upper right. Re-adding it as
        # an artist keeps it visible after ax.legend() is called again below
        # for the regions, because that call replaces the axes' own legend.
        legend_lines = ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left", fancybox=True, shadow=True)
        ax.add_artist(legend_lines)

        regions = ordered_regions_dict.get(prediction_name) if highlighting_regions else None
        if regions:
            for i, region in enumerate(regions):
                lower, upper = thresholds_dict[prediction_name][region]
                ax.axhspan(lower, upper, alpha=0.3, color=colors[i], label=region)
            # Region legend lists the bands top-to-bottom, matching the plot.
            included_in_regions_legend = list(reversed(regions))
            handles, labels = ax.get_legend_handles_labels()
            handles_dict = dict(zip(labels, handles))
            region_legend = ax.legend([handles_dict[region] for region in included_in_regions_legend],
                                      included_in_regions_legend, fancybox=True, shadow=True,
                                      loc='lower left', bbox_to_anchor=(1.04, 0))
            ax.add_artist(region_legend)

        y_range = ranges_dict.get(prediction_name)
        if y_range is not None:
            ax.set_ylim(y_range)
        ax.set_xlabel('residue index')
        ax.set_ylabel('prediction values')
        ax.grid(axis='y')
        # NOTE(review): a per-sequence plot for a sequence slugified to "all"
        # shares its file name with the combined plot and is overwritten by it.
        fig.savefig(os.path.join(plot_output, "{0}_{1}.png".format(slugify(seq_name), prediction_name)),
                    bbox_inches="tight")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)
91
def df_dict_to_dict_of_values(df_dict, predictor):
    """
    Load a single predictor column from each per-sequence TSV file.

    :param df_dict: mapping of sequence name -> path of a tab-separated file
        (as written by ``main``).
    :param predictor: name of the column to extract from every file.
    :return: a new dict with the same keys in the same order, mapping each
        sequence name to that file's ``predictor`` column (a pandas Series).
    """
    return {
        sequence_name: pd.read_csv(tsv_path, sep='\t')[predictor]
        for sequence_name, tsv_path in df_dict.items()
    }
98
def main(options):
    """
    Run the selected b2bTools predictors on a FASTA file and write the results.

    Outputs (locations taken from ``options``):
      * ``options.json_output``: the text returned by
        ``SingleSeq.get_all_predictions_json('all')``.
      * ``<options.output>/<slugify(sequence id)>.tsv`` for every sequence. Each
        file has one row per residue (index column ``residue_index``), a
        ``residue`` column (renamed from the ``seq`` key) and one column per
        prediction key, with values rounded to 2 decimals.
      * PNG plots via ``plot_prediction``. ``options.plot`` gives one figure
        per sequence and predictor; ``options.plot_all`` gives one overlay
        figure per predictor, built from the TSV files written above.

    :param options: parsed optparse values (see the ``__main__`` block).
    :raises ValueError: if no sequence predictions were produced.
    """
    single_seq = SingleSeq(options.input_fasta)
    # Selected predictors, in the fixed order dynamine, disomine, efoldmine, agmata.
    b2b_tools = [tool for tool in ('dynamine', 'disomine', 'efoldmine', 'agmata')
                 if getattr(options, tool)]

    single_seq.predict(b2b_tools)
    predictions = single_seq.get_all_predictions()
    results_json = single_seq.get_all_predictions_json('all')
    with open(options.json_output, 'w') as f:
        f.write(results_json)

    try:
        first_sequence_key = next(iter(predictions))
    except StopIteration:
        # Previously this escaped as a bare StopIteration traceback.
        raise ValueError(
            f"No predictions were produced for input '{options.input_fasta}'; "
            "is the FASTA file empty?") from None
    # The column layout comes from the first sequence; every sequence is
    # expected to expose the same prediction keys.
    prediction_keys = predictions[first_sequence_key].keys()

    # NOTE(review): the --highlight option is parsed but never read; regions
    # are always highlighted. Confirm the intended behaviour before wiring
    # options.highlight into the plot calls below.
    df_dictionary = {}
    for sequence_key, sequence_predictions in predictions.items():
        residues_count = len(sequence_predictions['seq'])
        sequence_df = pd.DataFrame(columns=prediction_keys, index=range(residues_count))
        sequence_df.index.name = 'residue_index'
        for predictor in prediction_keys:
            sequence_df[predictor] = sequence_predictions[predictor]
        sequence_df = sequence_df.rename(columns={"seq": "residue"}).round(decimals=2)
        # NOTE(review): two ids with the same slug overwrite each other's TSV.
        filename = os.path.join(options.output, f'{slugify(sequence_key)}.tsv')
        df_dictionary[sequence_key] = filename
        sequence_df.to_csv(filename, sep="\t")
        # Plot each individual plot (compatible with plot all)
        if options.plot:
            for predictor in prediction_keys:
                if predictor != 'seq':
                    plot_prediction(prediction_name=predictor, highlighting_regions=True,
                                    predicted_values=sequence_predictions[predictor], seq_name=sequence_key)
    # Plot all together (compatible with plot individual). These values are
    # re-read from the rounded TSV files, not from the raw predictions.
    if options.plot_all:
        for predictor in prediction_keys:
            if predictor != 'seq':
                results_dictionary = df_dict_to_dict_of_values(df_dict=df_dictionary, predictor=predictor)
                plot_prediction(prediction_name=predictor, highlighting_regions=True,
                                predicted_values=results_dictionary, seq_name='all')
144
if __name__ == "__main__":
    # Command-line entry point. The parsed values must stay bound to the
    # module-global name ``options``: plot_prediction reads
    # ``options.plot_output`` from module scope.
    parser = optparse.OptionParser()
    # Predictor switches; each enabled flag adds that tool to the run.
    for tool_flag in ("--dynamine", "--disomine", "--efoldmine", "--agmata"):
        parser.add_option(tool_flag, action="store_true", default=False)
    # Input FASTA and output locations; False means "not supplied".
    for path_flag, destination in (("--file", "input_fasta"),
                                   ("--output", "output"),
                                   ("--plot-output", "plot_output"),
                                   ("--json", "json_output")):
        parser.add_option(path_flag, dest=destination, default=False)
    # Plotting switches.
    for plot_flag in ("--plot", "--plot_all", "--highlight"):
        parser.add_option(plot_flag, action="store_true", default=False)
    options, _args = parser.parse_args()
    main(options)