Mercurial > repos > adrian.diaz > b2btools_single_sequence
comparison singleSeq/script.py @ 0:cacb90cde53e draft
First version of b2btools for single sequences in Galaxy
author | adrian.diaz |
---|---|
date | Wed, 06 Jul 2022 11:01:15 +0000 |
parents | |
children |
comparison
equal
deleted
inserted
replaced
-1:000000000000 | 0:cacb90cde53e |
---|---|
1 import optparse | |
2 import os.path | |
3 import unicodedata | |
4 import re | |
5 import pandas as pd | |
6 from b2bTools import SingleSeq | |
7 import matplotlib.pyplot as plt | |
8 | |
9 | |
def slugify(value):
    """
    Turn an arbitrary value into a lowercase, ASCII-only, file-name friendly slug.

    Adapted from https://github.com/django/django/blob/master/django/utils/text.py
    The value is converted to ``str``, decomposed with NFKD and stripped of every
    non-ASCII character. The result is lowercased, characters other than
    alphanumerics, underscores, whitespace and hyphens are removed, runs of
    whitespace and hyphens collapse into a single hyphen, and leading/trailing
    hyphens and underscores are trimmed. The slug may be empty.
    """
    decomposed = unicodedata.normalize('NFKD', str(value))
    # Encoding with 'ignore' silently drops whatever has no ASCII representation.
    ascii_only = decomposed.encode('ascii', 'ignore').decode('ascii')
    without_punctuation = re.sub(r'[^\w\s-]', '', ascii_only.lower())
    hyphenated = re.sub(r'[-\s]+', '-', without_punctuation)
    return hyphenated.strip('-_')
22 | |
23 | |
def plot_prediction(prediction_name, highlighting_regions, predicted_values, seq_name, output_dir=None):
    """
    Plot the per-residue values of one predictor and save the figure as a PNG.

    prediction_name: predictor key (e.g. 'backbone'); used for the title, the
        y-axis range lookup, the region thresholds and the output file name.
    highlighting_regions: when true and threshold regions are defined for
        ``prediction_name``, shaded horizontal bands and a second legend
        describing them are drawn.
    predicted_values: when ``seq_name == 'all'``, a mapping of sequence name to
        its values (one curve per entry); otherwise the values of one sequence.
    seq_name: name used as the curve label and in the file name; the literal
        'all' selects the multi-sequence mode.
    output_dir: directory that receives ``<slug(seq_name)>_<prediction_name>.png``.
        When omitted, the module-level ``options.plot_output`` is used, which is
        what the original implementation always did.

    Returns None; the figure is written to disk and closed.
    """
    thresholds_dict = {'backbone': {'membrane spanning': [1., 1.5],
                                    'rigid': [0.8, 1.],
                                    'context-dependent': [0.69, 0.8],
                                    'flexible': [-1.0, 0.69]},
                       'earlyFolding': {'early folds': [0.169, 2.], 'late folds': [-1., 0.169]},
                       'disoMine': {'ordered': [-1., 0.5], 'disordered': [0.5, 2.]},
                       }
    ordered_regions_dict = {'backbone': ['flexible', 'context-dependent', 'rigid', 'membrane spanning'],
                            'earlyFolding': ['late folds', 'early folds'],
                            'disoMine': ['ordered', 'disordered'],
                            }
    colors = ['yellow', 'orange', 'pink', 'red']
    ranges_dict = {
        'backbone': [-0.2, 1.2],
        'sidechain': [-0.2, 1.2],
        'ppII': [-0.2, 1.2],
        'earlyFolding': [-0.2, 1.2],
        'disoMine': [-0.2, 1.2],
        'agmata': [-0.2, 1.2],
        'helix': [-1., 1.],
        'sheet': [-1., 1.],
        'coil': [-1., 1.],
    }

    if output_dir is None:
        # Backward-compatible fallback: the script entry point stores the parsed
        # command line in the module-level name `options`.
        output_dir = options.plot_output

    fig, ax = plt.subplots(1, 1)
    fig.set_figwidth(10)
    fig.set_figheight(5)
    ax.set_title(prediction_name + ' ' + 'prediction')
    # Leave the right quarter of the figure free for the legends.
    plt.tight_layout(rect=[0, 0, 0.75, 1])

    if seq_name == 'all':
        max_len = 0
        for name, values in predicted_values.items():
            ax.plot(range(len(values)), values, label=name)
            max_len = max(max_len, len(values))
    else:
        ax.plot(range(len(predicted_values)), predicted_values, label=seq_name)
        max_len = len(predicted_values)
    ax.set_xlim([0, max_len - 1])

    # Legend for the curves; created before any band exists so it lists lines only.
    lines_legend = ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left", fancybox=True, shadow=True)

    region_names = ordered_regions_dict.get(prediction_name) if highlighting_regions else None
    if region_names:
        # A second ax.legend() call replaces the axes' current legend, so the
        # curves legend is re-attached as a plain artist first. Each legend is
        # registered exactly once (the original registered one of them twice).
        ax.add_artist(lines_legend)
        band_handles = {}
        for color, region in zip(colors, region_names):
            lower, upper = thresholds_dict[prediction_name][region]
            band_handles[region] = ax.axhspan(lower, upper, alpha=0.3, color=color, label=region)
        # List the regions from the highest band down to the lowest.
        top_down = list(reversed(region_names))
        ax.legend([band_handles[region] for region in top_down], top_down,
                  fancybox=True, shadow=True, loc='lower left', bbox_to_anchor=(1.04, 0))

    # Predictors without a configured range keep matplotlib's automatic limits
    # instead of aborting the whole run with a KeyError.
    y_range = ranges_dict.get(prediction_name)
    if y_range is not None:
        ax.set_ylim(y_range)
    ax.set_xlabel('residue index')
    ax.set_ylabel('prediction values')
    ax.grid(axis='y')

    png_name = "{0}_{1}.png".format(slugify(seq_name), prediction_name)
    fig.savefig(os.path.join(output_dir, png_name), bbox_inches="tight")
    plt.close(fig)
91 | |
def df_dict_to_dict_of_values(df_dict, predictor):
    """
    Collect one predictor column from every per-sequence TSV file.

    df_dict maps a sequence name to the path of a tab-separated file. The
    returned dict maps each of those names, in the same order, to the
    ``predictor`` column read from its file (as returned by pandas).
    """
    return {
        sequence_name: pd.read_csv(tsv_path, sep='\t')[predictor]
        for sequence_name, tsv_path in df_dict.items()
    }
98 | |
def main(options):
    """
    Run the selected b2bTools predictors on a FASTA file and write the results.

    options: parsed command-line values. The attributes read here are
        input_fasta, dynamine, disomine, efoldmine, agmata, json_output,
        output, plot and plot_all.

    Side effects:
      * the JSON produced by ``get_all_predictions_json('all')`` is written to
        ``options.json_output``;
      * one TSV per sequence (values rounded to 2 decimals, the 'seq' column
        renamed to 'residue') is written to ``<options.output>/<slug>.tsv``;
      * with ``options.plot``, one PNG per sequence and predictor is drawn;
      * with ``options.plot_all``, one PNG per predictor overlays all sequences.
    """
    single_seq = SingleSeq(options.input_fasta)

    tool_flags = (
        ('dynamine', options.dynamine),
        ('disomine', options.disomine),
        ('efoldmine', options.efoldmine),
        ('agmata', options.agmata),
    )
    b2b_tools = [tool for tool, enabled in tool_flags if enabled]

    single_seq.predict(b2b_tools)
    predictions = single_seq.get_all_predictions()
    results_json = single_seq.get_all_predictions_json('all')
    with open(options.json_output, 'w') as f:
        f.write(results_json)

    first_sequence_key = next(iter(predictions), None)
    if first_sequence_key is None:
        # No sequences were predicted: nothing to tabulate or plot. The
        # original raised StopIteration at this point.
        return

    # The first sequence's keys define the columns used for every sequence.
    prediction_keys = list(predictions[first_sequence_key].keys())
    plot_keys = [key for key in prediction_keys if key != 'seq']

    # NOTE(review): the command line also parses --highlight, but it is not
    # consulted here; regions are always highlighted. Confirm the intent
    # before wiring it in, since that would change the default output.

    # Rounded per-sequence tables, kept in memory so the "plot all" step does
    # not have to re-read every TSV from disk once per predictor.
    rounded_frames = {}
    for sequence_key, sequence_predictions in predictions.items():
        residues_count = len(sequence_predictions['seq'])
        sequence_df = pd.DataFrame(columns=prediction_keys, index=range(residues_count))
        sequence_df.index.name = 'residue_index'
        for predictor in prediction_keys:
            sequence_df[predictor] = sequence_predictions[predictor]
        sequence_df = sequence_df.rename(columns={"seq": "residue"})
        sequence_df = sequence_df.round(decimals=2)
        filename = f'{options.output}/{slugify(sequence_key)}.tsv'
        sequence_df.to_csv(filename, sep="\t")
        rounded_frames[sequence_key] = sequence_df

        # Plot each individual plot (compatible with plot all)
        if options.plot:
            for predictor in plot_keys:
                plot_prediction(prediction_name=predictor, highlighting_regions=True,
                                predicted_values=sequence_predictions[predictor], seq_name=sequence_key)

    # Plot all together (compatible with plot individual). The curves use the
    # same 2-decimal values that were written to the TSV files.
    if options.plot_all:
        for predictor in plot_keys:
            results_dictionary = {sequence_key: frame[predictor]
                                  for sequence_key, frame in rounded_frames.items()}
            plot_prediction(prediction_name=predictor, highlighting_regions=True,
                            predicted_values=results_dictionary, seq_name='all')
144 | |
if __name__ == "__main__":
    # Command-line entry point: parse the options, validate the required
    # paths, then hand over to main(). The parsed values are kept in the
    # module-level name `options`, which plot_prediction() also reads.
    parser = optparse.OptionParser()
    parser.add_option("--dynamine", action="store_true", default=False)
    parser.add_option("--disomine", action="store_true", default=False)
    parser.add_option("--efoldmine", action="store_true", default=False)
    parser.add_option("--agmata", action="store_true", default=False)
    parser.add_option("--file", dest="input_fasta", default=False)
    parser.add_option("--output", dest="output", default=False)
    parser.add_option("--plot-output", dest="plot_output", default=False)

    parser.add_option("--json", dest="json_output", default=False)
    parser.add_option("--plot", action="store_true", default=False)
    parser.add_option("--plot_all", action="store_true", default=False)
    parser.add_option("--highlight", action="store_true", default=False)
    options, _args = parser.parse_args()

    # Path options default to False. Left unchecked, open(False, 'w') would
    # open file descriptor 0 and the TSVs would target a literal "False/"
    # directory, so fail fast with a usage error instead.
    required_paths = (
        ("--file", options.input_fasta),
        ("--output", options.output),
        ("--json", options.json_output),
    )
    missing = [flag for flag, value in required_paths if not value]
    if (options.plot or options.plot_all) and not options.plot_output:
        missing.append("--plot-output")
    if missing:
        parser.error("missing required option(s): " + ", ".join(missing))

    main(options)