view infer_rnaformer.xml @ 1:6c310acd2f87 draft default tip

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/rna_tools/rnaformer commit a0639f646e953fe490559b10bb7271a67d07cac4
author iuc
date Fri, 07 Mar 2025 15:22:24 +0000
parents 02b0ecc34d9a
children
line wrap: on
line source

<tool id="infer_rnaformer" name="@EXECUTABLE@" version="@TOOL_VERSION@+galaxy1" profile="22.05">
    <description>Predict the secondary structure of an RNA with RNAformer</description>
    <macros>
        <import>macros.xml</import>
    </macros>
    <expand macro="requirements">
        <requirement type="package" version="1.83">biopython</requirement>
        <requirement type="package" version="3.7.2">matplotlib</requirement>
        <requirement type="package" version="0.13.2">seaborn</requirement>
        <requirement type="package" version="2.32.3">requests</requirement>
    </expand>
    <command detect_errors="exit_code"><![CDATA[python '$script_file' > '$output']]></command>
<configfiles>
        <configfile name="script_file"><![CDATA[import RNAformer
import os
import argparse
import torch
import urllib.request
import logging
from collections import defaultdict
import torch.cuda
import loralib as lora
from RNAformer.model.RNAformer import RiboFormer
from RNAformer.utils.configuration import Config

from Bio import SeqIO

import logging
import sys

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
import shutil
import datetime
import requests

def is_valid_rna_sequence(sequence):
    """Return True when *sequence* consists solely of RNA base letters.

    The comparison is case-insensitive. Besides A, C, G and U the letter N
    (unknown base) is accepted. An empty sequence has no offending
    characters and is therefore reported as valid.
    """
    allowed_bases = set("ACGUN")
    return set(sequence.upper()) <= allowed_bases

def download_file(url, destination):
    """Stream the resource at *url* into the local file *destination*.

    Args:
        url: HTTP(S) address of the file to fetch.
        destination: Path of the file to create or overwrite.

    Raises:
        requests.exceptions.HTTPError: The server answered with an error status.
        requests.exceptions.Timeout: Connecting took longer than 10 seconds or
            the server stayed silent for more than 300 seconds.
    """
    # Without a timeout a stalled server would block the job forever; the
    # context manager returns the streamed connection even when writing fails.
    with requests.get(url, stream=True, timeout=(10, 300)) as response:
        response.raise_for_status()
        with open(destination, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)

# Remote locations of the published 32M-parameter weights and their
# configuration; used only when no usable local file is supplied.
model_url = "https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth"
config_url = "https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml"

# Default download targets inside the working directory.
model_file_path = "./model/RNAformer_32M_state_dict_intra_family_finetuned.pth"
config_file_path = "./model/RNAformer_32M_config_intra_family_finetuned.yml"

os.makedirs("./model", exist_ok=True)

# The quoted placeholders are filled in by the tool template before the
# script runs. NOTE(review): an unset optional input presumably renders as
# the text None, which the existence checks below treat as "not supplied".
model = '$model'
model_config = '$model_config'

# Use the supplied weights when the path exists, otherwise download them.
if model and os.path.exists(model):
    model_file_path = model
else:
    download_file(model_url, model_file_path)

# Same fallback logic for the configuration file.
if model_config and os.path.exists(model_config):
    config_file_path = model_config
else:
    download_file(config_url, config_file_path)

config = Config(config_file=config_file_path)
# Overrides the cycling value read from the configuration file.
config.RNAformer.cycling = 6
# From here on the name model refers to the network, no longer to a path.
model = RiboFormer(config.RNAformer)
state_dict_file = model_file_path

# Heatmap options selected in the tool form: colouring mode and file format.
matrix_color_type = '$matrix_color_type'
matrix_out_type = '$matrix_out_type'

#if str($input_type.input_type) == 'True'
fasta_path = '$input_type.fasta_input'
sequences = [str(record.seq) for record in SeqIO.parse(fasta_path, 'fasta')]
#else
sequence_string = "$input_type.rna_input_string"
# Blank entries (for example from a trailing comma) are dropped.
sequences = [seq.strip() for seq in sequence_string.split(',') if seq.strip()]
#end if

# The token lookup used during inference only knows upper-case bases, while
# validation is case-insensitive. Normalising here keeps lower-case input
# from passing validation and then failing when the tensor is built.
sequences = [seq.upper() for seq in sequences]

if not sequences:
    print("No RNA sequences were provided. Please supply at least one RNA sequence.", file=sys.stderr)
    sys.exit(1)

for seq in sequences:
    if not seq:
        print("Empty RNA sequence detected. Please ensure every input record contains a sequence.", file=sys.stderr)
        sys.exit(1)
    if not is_valid_rna_sequence(seq):
        print(f"Invalid RNA sequence detected: {seq}. Please ensure only RNA sequences are used as input.", file=sys.stderr)
        sys.exit(1)

# LoRA hyper-parameters taken from the loaded model configuration.
lora_config = {
    "r": config.r,
    "lora_alpha": config.lora_alpha,
    "lora_dropout": config.lora_dropout,
}

# Swap every layer whose qualified name contains one of the keys in
# config.replace_layer for its LoRA counterpart, carrying over the freshly
# initialised weights, so the module tree matches the layout of the
# fine-tuned checkpoint that is loaded afterwards.
with torch.no_grad():
    for name, module in model.named_modules():
        if not any(replace_key in name for replace_key in config.replace_layer):
            continue

        parent_name, _, target_name = name.rpartition(".")
        parent = model.get_submodule(parent_name)
        target = model.get_submodule(name)
        has_bias = target.bias is not None if hasattr(target, "bias") else False

        if isinstance(target, torch.nn.Linear) and "qkv" in name:
            new_module = lora.MergedLinear(target.in_features, target.out_features,
                                           bias=has_bias,
                                           enable_lora=[True, True, True], **lora_config)
            weight_holder = new_module
        elif isinstance(target, torch.nn.Linear):
            new_module = lora.Linear(target.in_features, target.out_features,
                                     bias=has_bias, **lora_config)
            weight_holder = new_module
        elif isinstance(target, torch.nn.Conv2d):
            kernel_size = target.kernel_size[0]
            new_module = lora.Conv2d(target.in_channels, target.out_channels, kernel_size,
                                     padding=(kernel_size - 1) // 2, bias=has_bias,
                                     **lora_config)
            # The LoRA convolution keeps the wrapped layer in its conv attribute.
            weight_holder = new_module.conv
        else:
            # Name matched but the module is not a layer type handled here.
            # Skipping avoids installing the replacement built for a previous
            # layer (or hitting an undefined name on the first match).
            continue

        weight_holder.weight.copy_(target.weight)
        if has_bias:
            weight_holder.bias.copy_(target.bias)
        setattr(parent, target_name, new_module)

# The checkpoint can be a user-supplied dataset. weights_only=True limits
# unpickling to tensors and plain containers so that a crafted file cannot
# execute arbitrary code while being loaded (needs PyTorch 1.13 or newer).
state_dict = torch.load(state_dict_file, map_location=torch.device('cpu'), weights_only=True)

# strict=True: every key must match the LoRA-adapted module tree exactly.
model.load_state_dict(state_dict, strict=True)
model_name = state_dict_file.split(".pth")[0]

if torch.cuda.is_available():
    model = model.cuda()

    # Run in reduced precision on the GPU: bf16 when supported, fp16 otherwise.
    if torch.cuda.is_bf16_supported():
        model = model.bfloat16()
    else:
        model = model.half()

model.eval()
# Job name as entered in the tool form; filled in by the tool template.
job_name = '$job_name'

# The job name is used as a directory name and as the root of the zip
# archive. Path separators and the special names "." and ".." must not reach
# the filesystem calls below, otherwise a parent directory could be archived.
job_name = job_name.replace("/", "_")
if job_name.strip() in ("", ".", ".."):
    job_name = f"RNAformer_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

os.makedirs(f'./{job_name}', exist_ok=True)

# Token ids follow the position of each base in this list.
seq_vocab = ['A', 'C', 'G', 'U', 'N']
seq_stoi = dict(zip(seq_vocab, range(len(seq_vocab))))

# Inputs must live on the same device as the model weights; the model may
# have been moved to the GPU earlier in the script.
device = next(model.parameters()).device

total_output_buffer = []

for count, orig_seq in enumerate(sequences):
    seq_dir = f'./{job_name}/{job_name}_{count}'
    os.makedirs(seq_dir, exist_ok=True)

    with torch.no_grad():
        src_seq = torch.LongTensor(list(map(seq_stoi.get, orig_seq)))
        sequence = src_seq.unsqueeze(0).to(device)
        src_len = torch.LongTensor([sequence.shape[-1]]).to(device)
        # NOTE(review): constant third model argument kept from the original
        # script; its meaning is defined by RNAformer - confirm upstream.
        pdb_sample = torch.FloatTensor([[1]]).to(device)

        logits, pair_mask = model(sequence, src_len, pdb_sample)

        # Bring the probabilities back as float32 on the CPU so that they can
        # be plotted whatever device and precision the model ran with.
        raw_pred_mat = torch.sigmoid(logits[0, :, :, -1]).float().cpu()

    pred_mat = raw_pred_mat > 0.5
    pos1_id, pos2_id = (index.tolist() for index in torch.where(pred_mat))
    pairs = [[a, b] for a, b in zip(pos1_id, pos2_id)]

    # Greedy dot-bracket rendering: a pair is drawn only when both positions
    # are still free; every other predicted pair is reported separately.
    # NOTE(review): pairs are taken in matrix order without checking which
    # index is smaller or whether pairs cross - confirm this is intended.
    seqlen = len(src_seq)
    dot_bracket = ['.'] * seqlen
    pk_list = []
    for pair in pairs:
        open_index, close_index = pair
        if 0 <= open_index < seqlen and 0 <= close_index < seqlen:
            if dot_bracket[open_index] == '.' and dot_bracket[close_index] == '.':
                dot_bracket[open_index] = '('
                dot_bracket[close_index] = ')'
            else:
                pk_list.append(pair)
    pk_count = len(pk_list)
    dot_bracket_str_pred = ''.join(dot_bracket)

    output_buffer = [
        f"Job name: {job_name}_{count}\n",
        f"Sequence: {orig_seq}\n",
        f"Length: {len(orig_seq)}\n",
        f"Base pairs: {str(pairs)}\n",
        f"Predicted Structure: {dot_bracket_str_pred}\n",
    ]
    if pk_count > 0:
        output_buffer.append(f"NOTE: {pk_count} pseudoknots and/or multiplets present in predicted structure excluded from dot-bracket notation: {pk_list}\n")

    # Heatmap of either the raw probabilities or the thresholded matrix.
    if matrix_color_type == 'color':
        heatmap_values = raw_pred_mat
        heatmap_cmap = "inferno"
        heatmap_title = "RNAformer Base-pair Probability Matrix"
        heatmap_kind = "color"
    else:
        heatmap_values = pred_mat
        heatmap_cmap = "gray"
        heatmap_title = "RNAformer Base-pair Binary Probability Matrix"
        heatmap_kind = "binary"

    plt.figure(figsize=(12, 10))
    sns.heatmap(heatmap_values, cmap=heatmap_cmap, vmin=0.0, vmax=1.0,
                xticklabels=list(orig_seq),
                yticklabels=list(orig_seq)
    )
    plt.title(heatmap_title)
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.savefig(f'{seq_dir}/RNAformer_structure_adjacency_matrix_{heatmap_kind}_{count}.{matrix_out_type}', dpi=150)
    plt.close()

    full_text = "".join(output_buffer)
    with open(f"{seq_dir}/RNAformer_output.txt", "w", encoding="utf-8") as txt_file:
        txt_file.write(full_text)
        txt_file.write("\n")

    total_output_buffer.append(full_text)
    total_output_buffer.append("\n")

# Combined report for all sequences, written once after the loop.
full_job_output = "".join(total_output_buffer)
with open("./RNAformer_job_output.txt", "w", encoding="utf-8") as txt_file:
    txt_file.write(full_job_output)

shutil.make_archive('output', "zip", job_name)

]]></configfile>
    </configfiles>
    <inputs>
        <conditional name="input_type">
            <param name="input_type" type="select" label="Input from FASTA file">
                <option value="False">Provide one or more RNA sequences as text</option>
                <option value="True">Provide a FASTA file</option>
            </param>
            <when value="False">
                <param name="rna_input_string" label="Sequence(s) to fold" type="text" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA" help="Enter RNA sequences. Separate multiple RNA sequences by commas.">
                    <sanitizer>
                        <valid>
                            <add value="ACGUacgu,"/>
                        </valid>
                    </sanitizer>
                </param>
            </when>
            <when value="True">
                <param format="fasta" name="fasta_input" type="data" label="Sequence to fold (FASTA file)"/>
            </when>
        </conditional>
            <param name="job_name" type="text" label="Job name" value="" help="Please edit job name for output files. Default will be RNAformer_{date}_{time}."/>
            <param name="model" type="data" format="binary" value="None" optional="true" label="Model" help="Manually download saved RNAformer model file to save time: https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth"/>
            <param name="model_config" type="data" format="yml" value="None" optional="true" label="Model configuration" help="Manually download saved RNAformer model configuration file to save time: https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml"/>
            <param name="matrix_color_type" type="select" label="Coloring of base probability matrix">
                <option value="color">Color: base pair probabilities are shown and colored</option>
                <option value="binary">Binary: probabilities are binarized to show only final predicted structure without probabilities</option>
            </param>
            <param name="matrix_out_type" type="select" label="Base probability adjacency matrix heatmap format">
                <option value="pdf">PDF</option>
                <option value="eps">EPS</option>
                <option value="png">PNG</option>
                <option value="svg">SVG</option>
            </param>
    </inputs>
    <outputs>
        <data name="output" format="txt" label="RNAformer Base Pair Predictions" from_work_dir="RNAformer_job_output.txt"/>
        <data name="output_files" format="zip" label="RNAformer Predicted Structures Output" from_work_dir="output.zip"/>
    </outputs>
    <tests>
        <!-- Test 1: Single sequence as text input, color matrix PDF -->
        <test>
            <param name="input_type" value="False"/>
            <param name="rna_input_string" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA"/>
            <param name="job_name" value="RNAformer_Prediction_Test_1"/>
            <param name="matrix_out_type" value="pdf"/>
            <param name="matrix_color_type" value="color"/>
            <output name="output" file="RNAformer_job_output.txt"/>
            <output name="output_files">
                <assert_contents>
                    <has_size min="1" /> 
                </assert_contents>
            </output>
        </test>
    </tests>
    <help><![CDATA[
    **RNAformer**
        The tool reads RNA sequences and predicts their secondary structure using RNAformer. Unlike conventional methods, RNAformer can predict all possible secondary structure motifs, including pseudoknots and multiplets. These motifs cannot currently be represented in dot-bracket notation, so in such cases the dot-bracket output is partial: the excluded base pairs are listed in the output file below the (partial) dot-bracket structure. The full structure is still represented in the adjacency matrix heatmap. Tip: To speed up inference, you can manually download the model and model configuration files from the following URLs (select "Upload Data" -> "Paste/Fetch Data" and copy the following URLs): `https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml`
    
    **Input format**

    RNAformer requires one or more RNA sequences either as a single FASTA file or as plain text.

    **Outputs**

    - Predicted secondary structures as a text file for all sequences provided containing the following:
        - Job name (with index based on order of input sequences in either plain text or FASTA file)
        - RNA input sequence
        - Length of input sequence
        - Base pairs of predicted secondary structure
        - Predicted secondary structure in dot-bracket notation (excluding pseudoknots and multiplets)
        - Optional: pseudoknots and/or multiplets present in predicted structure excluded from dot-bracket notation
    - A zip file containing:
        - Sub-directory for each input sequence with the name <job_name>_<index> containing:
            - Predicted secondary structure text file (same as above)
            - Heatmap of base pair probability matrix


    ]]></help>
    <expand macro="citations" />
</tool>