Mercurial > repos > rnateam > infer_rnaformer
view infer_rnaformer.xml @ 1:6c310acd2f87 draft default tip
planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/rna_tools/rnaformer commit a0639f646e953fe490559b10bb7271a67d07cac4
| author | iuc |
|---|---|
| date | Fri, 07 Mar 2025 15:22:24 +0000 |
| parents | 02b0ecc34d9a |
| children | |
line wrap: on
line source
<tool id="infer_rnaformer" name="@EXECUTABLE@" version="@TOOL_VERSION@+galaxy1" profile="22.05">
    <description>Predict the secondary structure of an RNA with RNAformer</description>
    <macros>
        <import>macros.xml</import>
    </macros>
    <expand macro="requirements">
        <requirement type="package" version="1.83">biopython</requirement>
        <requirement type="package" version="3.7.2">matplotlib</requirement>
        <requirement type="package" version="0.13.2">seaborn</requirement>
        <requirement type="package" version="2.32.3">requests</requirement>
    </expand>
    <command detect_errors="exit_code"><![CDATA[python '$script_file' > '$output']]></command>
    <configfiles>
        <!--
            Inference script rendered by Cheetah before execution.
            Maintainer notes:
              * Cheetah substitutes the tool parameters (dollar-prefixed names) and the
                if/else/end-if directives before Python ever sees the file.
              * Keep a space after the hash sign of every Python comment so Cheetah does
                not mistake it for a directive, and never put a dollar sign in a comment.
              * Do not chain Python attribute access directly onto a substituted
                parameter; assign it to a variable first.
        -->
        <configfile name="script_file"><![CDATA[import RNAformer
import os
import argparse
import torch
import urllib.request
import logging
from collections import defaultdict
import torch.cuda
import loralib as lora
from RNAformer.model.RNAformer import RiboFormer
from RNAformer.utils.configuration import Config
from Bio import SeqIO
import sys
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
import shutil
import datetime
import requests

# Alphabet understood by the model; 'N' stands for an unknown base.
SEQ_VOCAB = ['A', 'C', 'G', 'U', 'N']
VALID_BASES = frozenset(SEQ_VOCAB)
# Seconds to wait for the download server (connect and per-read) before giving up.
DOWNLOAD_TIMEOUT = 60


def is_valid_rna_sequence(sequence):
    """Check if the sequence contains only RNA bases (case-insensitive)."""
    return all(base in VALID_BASES for base in sequence.upper())


def download_file(url, destination):
    """Stream the resource at url into the file destination.

    Raises requests.HTTPError on a non-success status and
    requests.Timeout when the server stalls, so a dead mirror fails the
    job instead of hanging it.
    """
    with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
        response.raise_for_status()
        with open(destination, 'wb') as file:
            for chunk in response.iter_content(chunk_size=8192):
                file.write(chunk)


def pairs_to_dot_bracket(pairs, seqlen):
    """Convert base pairs to dot-bracket notation.

    Returns a tuple (dot_bracket_string, skipped_pairs). A pair is skipped
    when one of its positions is already bracketed, i.e. it cannot be
    expressed in plain dot-bracket notation (pseudoknot or multiplet).
    Pairs with an out-of-range index are ignored.

    NOTE(review): pairs are consumed exactly as produced by torch.where on
    the full matrix. Should the model emit a symmetric matrix, the mirrored
    entry (j, i) of every pair (i, j) ends up in skipped_pairs. Confirm
    against upstream output before changing this.
    """
    dot_bracket = ['.'] * seqlen
    skipped_pairs = []
    for open_index, close_index in pairs:
        if 0 <= open_index < seqlen and 0 <= close_index < seqlen:
            if dot_bracket[open_index] == '.' and dot_bracket[close_index] == '.':
                dot_bracket[open_index] = '('
                dot_bracket[close_index] = ')'
            else:
                skipped_pairs.append([open_index, close_index])
    return ''.join(dot_bracket), skipped_pairs


model_url = "https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth"
config_url = "https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml"
model_file_path = "./model/RNAformer_32M_state_dict_intra_family_finetuned.pth"
config_file_path = "./model/RNAformer_32M_config_intra_family_finetuned.yml"
os.makedirs("./model", exist_ok=True)

# Optional datasets render as the text None when unset, which fails the
# existence check and triggers the download fallback.
user_model_path = '$model'
user_config_path = '$model_config'

if user_model_path and os.path.exists(user_model_path):
    model_file_path = user_model_path
else:
    download_file(model_url, model_file_path)

if user_config_path and os.path.exists(user_config_path):
    config_file_path = user_config_path
else:
    download_file(config_url, config_file_path)

config = Config(config_file=config_file_path)
config.RNAformer.cycling = 6
model = RiboFormer(config.RNAformer)
state_dict_file = model_file_path

matrix_color_type = '$matrix_color_type'
matrix_out_type = '$matrix_out_type'

#if str($input_type.input_type) == 'True'
fasta_path = '$input_type.fasta_input'
sequences = [str(record.seq) for record in SeqIO.parse(fasta_path, 'fasta')]
#else
sequence_string = "$input_type.rna_input_string"
sequences = [seq.strip() for seq in sequence_string.split(',')]
#end if

# Normalise once: validation is case-insensitive but the vocabulary lookup
# is not, so lowercase input used to pass validation and then crash.
# Empty entries (trailing comma, empty FASTA record) are dropped.
sequences = [seq.strip().upper() for seq in sequences]
sequences = [seq for seq in sequences if seq]

if not sequences:
    print("No RNA sequences provided. Please supply at least one RNA sequence as input.", file=sys.stderr)
    sys.exit(1)

for seq in sequences:
    if not is_valid_rna_sequence(seq):
        print(f"Invalid RNA sequence detected: {seq}. Please ensure only RNA sequences are used as input.", file=sys.stderr)
        sys.exit(1)

lora_config = {
    "r": config.r,
    "lora_alpha": config.lora_alpha,
    "lora_dropout": config.lora_dropout,
}

# Swap the configured layers for their LoRA counterparts (copying the
# existing weights) so the fine-tuned state dict loads with strict=True.
with torch.no_grad():
    for name, module in model.named_modules():
        if any(replace_key in name for replace_key in config.replace_layer):
            parent = model.get_submodule(".".join(name.split(".")[:-1]))
            target_name = name.split(".")[-1]
            target = model.get_submodule(name)
            if isinstance(target, torch.nn.Linear) and "qkv" in name:
                new_module = lora.MergedLinear(target.in_features, target.out_features,
                                               bias=target.bias is not None,
                                               enable_lora=[True, True, True], **lora_config)
                new_module.weight.copy_(target.weight)
                if target.bias is not None:
                    new_module.bias.copy_(target.bias)
            elif isinstance(target, torch.nn.Linear):
                new_module = lora.Linear(target.in_features, target.out_features,
                                         bias=target.bias is not None, **lora_config)
                new_module.weight.copy_(target.weight)
                if target.bias is not None:
                    new_module.bias.copy_(target.bias)
            elif isinstance(target, torch.nn.Conv2d):
                kernel_size = target.kernel_size[0]
                new_module = lora.Conv2d(target.in_channels, target.out_channels, kernel_size,
                                         padding=(kernel_size - 1) // 2,
                                         bias=target.bias is not None, **lora_config)
                new_module.conv.weight.copy_(target.weight)
                if target.bias is not None:
                    new_module.conv.bias.copy_(target.bias)
            setattr(parent, target_name, new_module)

# NOTE(security): torch.load unpickles the file, and the model may be a
# user-supplied dataset. Consider weights_only=True once the pinned torch
# version is confirmed to support it.
state_dict = torch.load(state_dict_file, map_location=torch.device('cpu'))
model.load_state_dict(state_dict, strict=True)

if torch.cuda.is_available():
    model = model.cuda()
    # check GPU can do bf16
    if torch.cuda.is_bf16_supported():
        model = model.bfloat16()
    else:
        model = model.half()
model.eval()

# Inputs must live on the same device as the model; hard-coding the CPU
# here broke inference on every host where CUDA is available.
device = next(model.parameters()).device
seq_stoi = dict(zip(SEQ_VOCAB, range(len(SEQ_VOCAB))))

job_name = '$job_name'
# The job name becomes a directory name: keep it to a single path component.
job_name = job_name.strip().replace("/", "_")
if job_name in ("", ".", ".."):
    job_name = f"RNAformer_{datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}"

os.makedirs(f'./{job_name}', exist_ok=True)
total_output_buffer = []

for count, orig_seq in enumerate(sequences):
    seq_dir = f'./{job_name}/{job_name}_{count}'
    os.makedirs(seq_dir, exist_ok=True)

    with torch.no_grad():
        src_seq = torch.LongTensor([seq_stoi[base] for base in orig_seq])
        batch = src_seq.unsqueeze(0).to(device)
        src_len = torch.LongTensor([batch.shape[-1]]).to(device)
        pdb_sample = torch.FloatTensor([[1]]).to(device)

        logits, pair_mask = model(batch, src_len, pdb_sample)
        # Bring the result back as float32 on the CPU so that numpy and
        # seaborn can consume it whatever device and dtype the model uses.
        raw_pred_mat = torch.sigmoid(logits[0, :, :, -1].float()).cpu()

    pred_mat = raw_pred_mat > 0.5
    pos1_id, pos2_id = (index.tolist() for index in torch.where(pred_mat))
    pairs = [[a, b] for a, b in zip(pos1_id, pos2_id)]

    dot_bracket_str_pred, pk_list = pairs_to_dot_bracket(pairs, len(orig_seq))
    pk_count = len(pk_list)

    output_buffer = [
        f"Job name: {job_name}_{count}\n",
        f"Sequence: {orig_seq}\n",
        f"Length: {len(orig_seq)}\n",
        f"Base pairs: {str(pairs)}\n",
        f"Predicted Structure: {dot_bracket_str_pred}\n",
    ]
    if pk_count > 0:
        output_buffer.append(f"NOTE: {pk_count} pseudoknots and/or multiplets present in predicted structure excluded from dot-bracket notation: {pk_list}\n")

    if matrix_color_type == 'color':
        heatmap_data = raw_pred_mat
        heatmap_cmap = "inferno"
        heatmap_title = "RNAformer Base-pair Probability Matrix"
        heatmap_tag = "color"
    else:
        heatmap_data = pred_mat.float()
        heatmap_cmap = "gray"
        heatmap_title = "RNAformer Base-pair Binary Probability Matrix"
        heatmap_tag = "binary"

    plt.figure(figsize=(12, 10))
    sns.heatmap(heatmap_data.numpy(), cmap=heatmap_cmap, vmin=0.0, vmax=1.0,
                xticklabels=list(orig_seq), yticklabels=list(orig_seq))
    plt.title(heatmap_title)
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.savefig(f'{seq_dir}/RNAformer_structure_adjacency_matrix_{heatmap_tag}_{count}.{matrix_out_type}', dpi=150)
    plt.close()

    full_text = "".join(output_buffer)
    with open(f"{seq_dir}/RNAformer_output.txt", "w", encoding="utf-8") as txt_file:
        txt_file.write(full_text)
        txt_file.write("\n")

    total_output_buffer.append(full_text)
    total_output_buffer.append("\n")

# Written once, after all sequences, instead of being rewritten per sequence.
full_job_output = "".join(total_output_buffer)
with open("./RNAformer_job_output.txt", "w", encoding="utf-8") as txt_file:
    txt_file.write(full_job_output)

shutil.make_archive('output', "zip", job_name)
]]></configfile>
    </configfiles>
    <inputs>
        <conditional name="input_type">
            <param name="input_type" type="select" label="Input from FASTA file">
                <option value="False">Provide a single RNA sequence string as text</option>
                <option value="True">Provide a FASTA file</option>
            </param>
            <when value="False">
                <param name="rna_input_string" label="Sequence(s) to fold" type="text" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA" help="Enter RNA sequences. Separate multiple RNA sequences by commas.">
                    <sanitizer>
                        <valid>
                            <add value="ACGUacgu,"/>
                        </valid>
                    </sanitizer>
                </param>
            </when>
            <when value="True">
                <param format="fasta" name="fasta_input" type="data" label="Sequence to fold (FASTA file)"/>
            </when>
        </conditional>
        <param name="job_name" type="text" label="Job name" value="" help="Please edit job name for output files. Default will be RNAformer_{date}_{time}."/>
        <param name="model" type="data" format="binary" value="None" optional="true" label="Model" help="Manually download saved RNAformer model file to save time: https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth"/>
        <param name="model_config" type="data" format="yml" value="None" optional="true" label="Model configuration" help="Manually download saved RNAformer model configuration file to save time: https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml"/>
        <param name="matrix_color_type" type="select" label="Coloring of base probability matrix">
            <option value="color">Color: base pair probabilities are shown and colored</option>
            <option value="binary">Binary: probabilities are binarized to show only final predicted structure without probabilities</option>
        </param>
        <param name="matrix_out_type" type="select" label="Base probability adjacency matrix heatmap format">
            <option value="pdf">PDF</option>
            <option value="eps">EPS</option>
            <option value="png">PNG</option>
            <option value="svg">SVG</option>
        </param>
    </inputs>
    <outputs>
        <data name="output" format="txt" label="RNAformer Base Pair Predictions" from_work_dir="RNAformer_job_output.txt"/>
        <data name="output_files" format="zip" label="RNAformer Predicted Structures Output" from_work_dir="output.zip"/>
    </outputs>
    <tests>
        <!-- Test 1: Single sequence as text input, color matrix PDF -->
        <test>
            <param name="input_type" value="False"/>
            <param name="rna_input_string" value="GCCCGCAUGGUGAAAUCGGUAAACACAUCGCACUAAUGCGCCGCCUCUGGCUUGCCGGUUCAAGUCCGGCUGCGGGCACCA"/>
            <param name="job_name" value="RNAformer_Prediction_Test_1"/>
            <param name="matrix_out_type" value="pdf"/>
            <param name="matrix_color_type" value="color"/>
            <output name="output" file="RNAformer_job_output.txt"/>
            <output name="output_files">
                <assert_contents>
                    <has_size min="1" />
                </assert_contents>
            </output>
        </test>
    </tests>
    <help><![CDATA[
**RNAformer**

The tool reads RNA sequences and predicts their secondary structure using RNAformer. Note that unlike conventional methods, RNAformer is capable of predicting all possible secondary structure motifs, including pseudoknots and multiplets. These currently will not be represented in dot-bracket notation and thus the output will be partial in these cases, excluding these which will be noted in the output file below the (partial) dot-bracket structure. However, the full structure will be represented in the adjacency matrix heatmap.

Tip: To speed up inference time, you can manually download the model and model configuration files from the following URLs (select "Upload Data" -> "Paste/Fetch Data" and copy the following URLs):

`https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_state_dict_intra_family_finetuned.pth
https://ml.informatik.uni-freiburg.de/research-artifacts/RNAformer/models/RNAformer_32M_config_intra_family_finetuned.yml`

**Input format**

RNAformer requires one or more RNA sequences either as a single FASTA file or as plain text.

**Outputs**

- Predicted secondary structures as a text file for all sequences provided containing the following:

  - Job name (with index based on order of input sequences in either plain text or FASTA file)
  - RNA input sequence
  - Length of input sequence
  - Base pairs of predicted secondary structure
  - Predicted secondary structure in dot-bracket notation (excluding pseudoknots and multiplets)
  - Optional: pseudoknots and/or multiplets present in predicted structure excluded from dot-bracket notation

- A zip file containing:

  - Sub-directory for each input sequence with the name <job_name>_<index> containing:

    - Predicted secondary structure text file (same as above)
    - Heatmap of base pair probability matrix
    ]]></help>
    <expand macro="citations" />
</tool>