diff script.py @ 2:a9db23ac113f draft default tip

Uploaded new version formatted.
author adrian.diaz
date Tue, 02 Aug 2022 09:44:33 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/script.py	Tue Aug 02 09:44:33 2022 +0000
@@ -0,0 +1,222 @@
+import optparse
+import os.path
+import unicodedata
+import re
+import numpy as np
+import pandas as pd
+from b2bTools import SingleSeq
+import matplotlib.pyplot as plt
+
+
+def slugify(value):
+    """
+    Taken from
+    https://github.com/django/django/blob/master/django/utils/text.py
+    Convert to ASCII if 'allow_unicode'. Convert spaces or repeated
+    dashes to single dashes. Remove characters that aren't alphanumerics,
+    underscores, or hyphens. Convert to lowercase. Also strip leading and
+    trailing whitespace, dashes, and underscores.
+    """
+    value = str(value)
+    value = unicodedata.normalize(
+        'NFKD', value).encode(
+        'ascii', 'ignore').decode('ascii')
+    value = re.sub(r'[^\w\s-]', '', value.lower())
+    return re.sub(r'[-\s]+', '-', value).strip('-_')
+
+
+def check_min_max(predicted_values, former_min, former_max):
+    seq_max = max(predicted_values)
+    seq_min = min(predicted_values)
+    if seq_max + \
+            0.1 > former_max and not np.isnan(seq_max) \
+            and not np.isinf(seq_max):
+        former_max = seq_max + 0.1
+    if seq_min - \
+            0.1 < former_min and not np.isnan(seq_min) \
+            and not np.isinf(seq_min):
+        former_min = seq_min - 0.1
+    return former_min, former_max
+
+
+def plot_prediction(prediction_name, highlighting_regions,
+                    pred_vals, seq_name):
+    thresholds_dict = {'backbone': {'membrane spanning': [1., 1.5],
+                                    'rigid': [0.8, 1.],
+                                    'context-dependent': [0.69, 0.8],
+                                    'flexible': [-1.0, 0.69]},
+                       'earlyFolding': {'early folds': [0.169, 2.],
+                                        'late folds': [-1., 0.169]},
+                       'disoMine': {'ordered': [-1., 0.5],
+                                    'disordered': [0.5, 2.]},
+                       }
+    ordered_regions_dict = {'backbone': ['flexible',
+                                         'context-dependent',
+                                         'rigid',
+                                         'membrane spanning'],
+                            'earlyFolding': ['late folds', 'early folds'],
+                            'disoMine': ['ordered', 'disordered'],
+                            }
+    colors = ['yellow', 'orange', 'pink', 'red']
+    ranges_dict = {
+        'backbone': [-0.2, 1.2],
+        'sidechain': [-0.2, 1.2],
+        'ppII': [-0.2, 1.2],
+        'earlyFolding': [-0.2, 1.2],
+        'disoMine': [-0.2, 1.2],
+        'agmata': [-0.2, 1.2],
+        'helix': [-1., 1.],
+        'sheet': [-1., 1.],
+        'coil': [-1., 1.],
+    }
+    fig, ax = plt.subplots(1, 1)
+    fig.set_figwidth(10)
+    fig.set_figheight(5)
+    ax.set_title(prediction_name + ' ' + 'prediction')
+    min_value, max_value = ranges_dict[prediction_name]
+    if seq_name == 'all':
+        max_len = 0
+        for seq in pred_vals.keys():
+            predictions = pred_vals[seq]
+            min_value, max_value = check_min_max(
+                predictions, min_value, max_value)
+            ax.plot(range(len(predictions)), predictions, label=seq)
+            if len(predictions) > max_len:
+                max_len = len(predictions)
+            ax.set_xlim([0, max_len - 1])
+    else:
+        predictions = pred_vals
+        min_value, max_value = check_min_max(predictions, min_value, max_value)
+        ax.plot(range(len(predictions)), predictions, label=seq_name)
+        ax.set_xlim([0, len(predictions) - 1])
+    legend_lines = plt.legend(
+        bbox_to_anchor=(
+            1.04,
+            1),
+        loc="upper left",
+        fancybox=True,
+        shadow=True)
+    ax.add_artist(legend_lines)
+    # Define regions
+    if highlighting_regions:
+        if prediction_name in ordered_regions_dict.keys():
+            for i, prediction in enumerate(
+                    ordered_regions_dict[prediction_name]):
+                lower = thresholds_dict[prediction_name][prediction][0]
+                upper = thresholds_dict[prediction_name][prediction][1]
+                color = colors[i]
+                ax.axhspan(
+                    lower,
+                    upper,
+                    alpha=0.3,
+                    color=color,
+                    label=prediction)
+            # to sort it "from up to low"
+            included_in_regions_legend = list(reversed(
+                [r_pred for r_pred in ordered_regions_dict[prediction_name]]))
+            # Get handles and labels
+            handles, labels = plt.gca().get_legend_handles_labels()
+            handles_dict = {label: handles[idx]
+                            for idx, label in enumerate(labels)}
+            # Add legend for regions, if available
+            lgnd_labels = [handles_dict[r] for r in included_in_regions_legend]
+            lgnd_regions = [region for region in included_in_regions_legend]
+            region_legend = ax.legend(lgnd_labels,
+                                      lgnd_regions,
+                                      fancybox=True,
+                                      shadow=True,
+                                      loc='lower left',
+                                      bbox_to_anchor=(1.04, 0))
+            ax.add_artist(region_legend)
+    ax.set_ylim([min_value, max_value])
+    ax.set_xlabel('residue index')
+    ax.set_ylabel('prediction values')
+    ax.grid(axis='y')
+    plt.savefig(
+        os.path.join(
+            options.plot_output,
+            "{0}_{1}.png".format(
+                slugify(seq_name),
+                prediction_name)),
+        bbox_inches="tight")
+    plt.close()
+
+
+def df_dict_to_dict_of_values(df_dict, predictor):
+    results_dict = {}
+    for seq in df_dict.keys():
+        df = pd.read_csv(df_dict[seq], sep='\t')
+        results_dict[seq] = df[predictor]
+    return results_dict
+
+
+def main(options):
+    single_seq = SingleSeq(options.input_fasta)
+    b2b_tools = []
+    if options.dynamine:
+        b2b_tools.append('dynamine')
+    if options.disomine:
+        b2b_tools.append('disomine')
+    if options.efoldmine:
+        b2b_tools.append('efoldmine')
+    if options.agmata:
+        b2b_tools.append('agmata')
+
+    single_seq.predict(b2b_tools)
+    predictions = single_seq.get_all_predictions()
+    results_json = single_seq.get_all_predictions_json('all')
+    with open(options.json_output, 'w') as f:
+        f.write(results_json)
+    first_sequence_key = next(iter(predictions))
+    prediction_keys = predictions[first_sequence_key].keys()
+    df_dictionary = {}
+    for sequence_key, sequence_predictions in predictions.items():
+        residues = sequence_predictions['seq']
+        residues_count = len(residues)
+        sequence_df = pd.DataFrame(
+            columns=prediction_keys,
+            index=range(residues_count))
+        sequence_df.index.name = 'residue_index'
+        for predictor in prediction_keys:
+            sequence_df[predictor] = sequence_predictions[predictor]
+        sequence_df = sequence_df.rename(columns={"seq": "residue"})
+        sequence_df = sequence_df.round(decimals=2)
+        filename = f'{options.output}/{slugify(sequence_key)}.tsv'
+        df_dictionary[sequence_key] = filename
+        sequence_df.to_csv(filename, sep="\t")
+        # Plot each individual plot (compatible with plot all)
+        if options.plot:
+            for predictor in prediction_keys:
+                if predictor != 'seq':
+                    plot_prediction(prediction_name=predictor,
+                                    highlighting_regions=True,
+                                    pred_vals=sequence_predictions[predictor],
+                                    seq_name=sequence_key)
+    # Plot all together (compatible with plot individual)
+    if options.plot_all:
+        for predictor in prediction_keys:
+            if predictor != 'seq':
+                results_dictionary = df_dict_to_dict_of_values(
+                    df_dict=df_dictionary, predictor=predictor)
+                plot_prediction(prediction_name=predictor,
+                                highlighting_regions=True,
+                                pred_vals=results_dictionary,
+                                seq_name='all')
+
+
+if __name__ == "__main__":
+    parser = optparse.OptionParser()
+    parser.add_option("--dynamine", action="store_true", default=False)
+    parser.add_option("--disomine", action="store_true", default=False)
+    parser.add_option("--efoldmine", action="store_true", default=False)
+    parser.add_option("--agmata", action="store_true", default=False)
+    parser.add_option("--file", dest="input_fasta", default=False)
+    parser.add_option("--output", dest="output", default=False)
+    parser.add_option("--plot-output", dest="plot_output", default=False)
+
+    parser.add_option("--json", dest="json_output", default=False)
+    parser.add_option("--plot", action="store_true", default=False)
+    parser.add_option("--plot_all", action="store_true", default=False)
+    parser.add_option("--highlight", action="store_true", default=False)
+    options, _args = parser.parse_args()
+    main(options)