view validate_fasta.py @ 4:bc99b0a951ec draft

"planemo upload for repository https://github.com/usegalaxy-au/galaxy-local-tools commit 2d7c3d151feaafc3be33ebb0081ec640680fbb4d-dirty"
author galaxy-australia
date Thu, 10 Mar 2022 00:26:56 +0000
parents 6c92e000d684
children 04e95886cf24
line wrap: on
line source

"""Validate input FASTA sequence."""

import re
import argparse
from typing import List, TextIO


class Fasta:
    """One FASTA record: its header string and its amino-acid sequence.

    Both values are stored exactly as supplied; no cleaning happens here.
    """

    def __init__(self, header_str: str, seq_str: str):
        self.aa_seq = seq_str
        self.header = header_str


class FastaLoader:
    """Parse FASTA records (or a bare sequence) from a file into ``Fasta`` objects.

    Two input flavours are handled:
    * a normal text file, split on '\\n' with '>' marking header lines;
    * pasted text in which newlines and '>' arrive escaped as ``__cn__`` and
      ``__gt__`` (detected by the presence of ``__cn__``).
    """

    def __init__(self, fasta_path: str):
        """Initialize from FASTA file and echo the parsed records to stdout."""
        self.fastas = []
        self.load(fasta_path)
        print("Loaded FASTA sequences:")
        for f in self.fastas:
            print(f.header)
            print(f.aa_seq)

    def load(self, fasta_path: str):
        """Load bare or FASTA formatted sequence.

        Sets ``self.content``, ``self.newline``, ``self.caret`` and
        ``self.lines``, and passes every (header, sequence) pair found to
        update_fastas(). Sequence lines that appear before any header are
        collected under an empty header.
        """
        with open(fasta_path, 'r') as f:
            self.content = f.read()

        if "__cn__" in self.content:
            # Pasted content with escaped characters
            self.newline = '__cn__'
            self.caret = '__gt__'
        else:
            # Uploaded file with normal content
            self.newline = '\n'
            self.caret = '>'

        self.lines = self.content.split(self.newline)
        header, first_seq = self.interpret_first_line()
        seq_parts = [first_seq] if first_seq else []

        # interpret_first_line() has already consumed line 0; starting from
        # line 1 avoids appending a bare first sequence line a second time.
        for line in self.lines[1:]:
            if line.startswith(self.caret):
                self.update_fastas(header, ''.join(seq_parts))
                header = '>' + self.strip_header(line)
                seq_parts = []
            else:
                # Strip all surrounding whitespace (tabs too), so stray
                # whitespace does not end up in the sequence and later fail
                # alphabet validation.
                # NOTE(review): in pasted mode, other escaped tokens on
                # sequence lines are not removed - confirm Galaxy cannot emit
                # them here.
                seq_parts.append(line.strip())

        # after reading whole file, header & sequence buffers might be full
        self.update_fastas(header, ''.join(seq_parts))

    def interpret_first_line(self):
        """Turn the first line into the initial (header, sequence) buffers.

        Returns ``('>' + name, '')`` when the line is a header. Otherwise it
        returns ``('', stripped_line)``, because the input starts with bare
        sequence.
        """
        line = self.lines[0]
        if line.startswith(self.caret):
            header = '>' + self.strip_header(line)
            return header, ''
        return '', line.strip()

    def strip_header(self, line):
        """Strip characters escaped with underscores from pasted text.

        Removes every ``__xx__`` token (two characters between double
        underscores, e.g. the escaped caret ``__gt__``) and then any leading
        or trailing '>' characters.
        """
        return re.sub(r'\_\_.{2}\_\_', '', line).strip('>')

    def update_fastas(self, header: str, sequence: str):
        """Append a Fasta record for ``sequence`` if it is non-empty.

        When ``header`` is empty, a generic ``>sequence_N`` header is used,
        where N is the number of records stored so far.
        """
        # if we have a sequence
        if sequence:
            # create generic header if not exists
            if not header:
                fasta_count = len(self.fastas)
                header = f'>sequence_{fasta_count}'

            # Create new Fasta
            self.fastas.append(Fasta(header, sequence))


class FastaValidator:
    """Check a list of parsed FASTA records against the tool's input rules.

    Every check raises ``Exception`` with a user-facing message on failure.
    """

    def __init__(self, fasta_list: 'List[Fasta]', min_length: int = 30,
                 max_length: int = 2000):
        """Store the records and the accepted sequence-length range.

        Args:
            fasta_list: parsed records; each must expose an ``aa_seq`` string.
            min_length: shortest accepted sequence length (inclusive).
            max_length: longest accepted sequence length (inclusive).
        """
        self.fasta_list = fasta_list
        self.min_length = min_length
        self.max_length = max_length
        self.iupac_characters = {
            'A', 'B', 'C', 'D', 'E', 'F', 'G',
            'H', 'I', 'K', 'L', 'M', 'N', 'P',
            'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X',
            'Y', 'Z', '-'
        }

    def validate(self):
        """Perform fasta validation; raises on the first failing check."""
        self.validate_num_seqs()
        self.validate_length()
        self.validate_alphabet()
        # 'X' residues are deliberately not rejected at the moment (see
        # validate_x); alphafold can throw an error if it doesn't like them.

    def validate_num_seqs(self) -> None:
        """Require exactly one sequence in the input."""
        num_seqs = len(self.fasta_list)
        if num_seqs > 1:
            raise Exception(f'Error encountered validating fasta: More than 1 sequence detected ({num_seqs}). Please use single fasta sequence as input')
        if num_seqs == 0:
            raise Exception('Error encountered validating fasta: input file has no fasta sequences')

    def validate_length(self):
        """Require min_length <= len(sequence) <= max_length (both inclusive)."""
        seq_len = len(self.fasta_list[0].aa_seq)
        if seq_len < self.min_length:
            raise Exception(f'Error encountered validating fasta: Sequence too short ({seq_len}aa). Must be >= {self.min_length}aa')
        if seq_len > self.max_length:
            raise Exception(f'Error encountered validating fasta: Sequence too long ({seq_len}aa). Must be <= {self.max_length}aa')

    def validate_alphabet(self):
        """
        Confirms whether the sequence conforms to IUPAC codes.
        If not, reports the first offending character and its 1-based position.
        Matching is case-insensitive.
        """
        fasta = self.fasta_list[0]
        for pos, char in enumerate(fasta.aa_seq.upper(), start=1):
            if char not in self.iupac_characters:
                raise Exception(f'Error encountered validating fasta: Invalid amino acid found at pos {pos}: "{char}"')

    def validate_x(self):
        """Reject any 'X' residue, reporting its 1-based position.

        Not called by validate(). TODO: check whether alphafold accepts X.
        """
        fasta = self.fasta_list[0]
        for pos, char in enumerate(fasta.aa_seq.upper(), start=1):
            if char == 'X':
                raise Exception(f'Error encountered validating fasta: Unsupported aa code "X" found at pos {pos}')


class FastaWriter:
    """Write a single FASTA record to a file with fixed-width sequence lines."""

    def __init__(self, outfile: str = 'alphafold.fasta',
                 formatted_line_len: int = 60) -> None:
        """Configure the output path and the number of residues per line.

        Args:
            outfile: destination path; it is overwritten by write().
            formatted_line_len: maximum sequence characters per output line.
        """
        self.outfile = outfile
        self.formatted_line_len = formatted_line_len

    def write(self, fasta: 'Fasta'):
        """Write ``fasta`` (header line, then wrapped sequence) to self.outfile.

        ``fasta`` needs ``header`` and ``aa_seq`` string attributes. The file
        ends with exactly one newline after the last sequence line.
        """
        seq = self.format_sequence(fasta.aa_seq)
        with open(self.outfile, 'w') as fp:
            # format_sequence() already newline-terminates every line, so no
            # extra newline is added (that would leave a trailing blank line).
            fp.write(fasta.header + '\n' + seq)

    def format_sequence(self, aa_seq: str):
        """Wrap aa_seq into lines of at most formatted_line_len characters.

        Every line, including the last, ends with a newline; an empty
        sequence yields ''.
        """
        width = self.formatted_line_len
        return ''.join(
            aa_seq[i:i + width] + '\n' for i in range(0, len(aa_seq), width)
        )


def main():
    """Load the input FASTA, validate it, then write the cleaned record."""
    input_path = parse_args().input_fasta
    loader = FastaLoader(input_path)

    # Validation raises on the first failure, so nothing is written for bad input.
    FastaValidator(loader.fastas).validate()

    # Validation guarantees exactly one record, so index 0 is that record.
    FastaWriter().write(loader.fastas[0])


def parse_args() -> argparse.Namespace:
    """Parse the command line: one positional path to the input FASTA file."""
    cli = argparse.ArgumentParser()
    cli.add_argument("input_fasta", type=str, help="input fasta file")
    parsed = cli.parse_args()
    return parsed



if __name__ == "__main__":
    # Run the validator only when executed as a script, not on import.
    main()