# view get_db_info.py @ 2:11a3752feb0a draft default tip
#
# planemo upload for repository https://github.com/brsynth/galaxytools/tree/main/tools commit 7f5d8b62d749a0c41110cd9c04e0254e4fd44893-dirty
# author tduigou
# date Wed, 15 Oct 2025 12:33:41 +0000
# parents 7680420caf9f
# children
# line wrap: on
# line source

import subprocess
import argparse
import time
import json
import os
import socket
import re
from Bio.Seq import Seq
import pandas as pd
from Bio.SeqRecord import SeqRecord
from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql import text
from sqlalchemy.exc import OperationalError


def fix_db_uri(uri):
    """Return *uri* with every ``__at__`` placeholder turned back into ``@``."""
    placeholder, real_char = "__at__", "@"
    restored = uri.replace(placeholder, real_char)
    return restored


def is_port_in_use(uri):
    """Check whether a TCP port already accepts connections.

    Args:
        uri: Either a SQLAlchemy database URI, from which host and port are
            parsed, or a bare integer port number, which is probed on
            127.0.0.1.

    Returns:
        True if a TCP connection to (host, port) succeeds within 2 seconds,
        False otherwise.

    Raises:
        ValueError: If a URI is given that lacks a host or a port.
    """
    if isinstance(uri, int):
        # A bare port number cannot be parsed by make_url; probe the local host.
        host, port = "127.0.0.1", uri
    else:
        url = make_url(uri)
        host, port = url.host, url.port
        if host is None or port is None:
            raise ValueError(f"Database URI must include a host and a port to probe: host={host!r}, port={port!r}")

    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(2)
        # connect_ex returns 0 on success and an errno value otherwise.
        return s.connect_ex((host, port)) == 0


def extract_db_name(uri):
    """Return the database component of a SQLAlchemy connection URI."""
    return make_url(uri).database


# This function starts the PostgreSQL Docker container when the DB runs in a container.
# NOTE: it is currently not called from main().
def start_postgres_container(db_name):
    """Start a PostgreSQL container with the given database name as the container name.

    Three cases are handled in order:
    1. a container matching the name is already running -> nothing to do;
    2. a matching container exists but is stopped -> ``docker start`` it;
    3. no matching container -> ``docker run`` a new ``postgres`` container,
       published on host port 5432, or 5433 if 5432 is already taken.

    Args:
        db_name: Database name, reused verbatim as the Docker container name.

    Returns:
        None. Progress and failures are reported on stdout.
    """
    container_name = db_name
    name_filter = f"name={container_name}"

    # Commands are passed as argument lists (no shell) so that a container
    # name containing shell metacharacters cannot inject commands.
    # NOTE(review): docker's name filter may match partial names - confirm
    # whether an exact match is required here.
    container_running = subprocess.run(
        ["docker", "ps", "-q", "-f", name_filter], capture_output=True, text=True
    )
    if container_running.stdout.strip():
        print(f"Container '{container_name}' is already running.")
        return

    # Same query including stopped containers (-a).
    container_exists = subprocess.run(
        ["docker", "ps", "-a", "-q", "-f", name_filter], capture_output=True, text=True
    )
    if container_exists.stdout.strip():
        print(f"Starting existing container '{container_name}'...")
        subprocess.run(["docker", "start", container_name])
        print(f"PostgreSQL Docker container '{container_name}' activated.")
        return

    # No container yet: create and start a new one.
    # is_port_in_use() parses a URI, so the port to probe is wrapped in one.
    port = 5432 if not is_port_in_use("postgresql://localhost:5432") else 5433
    # NOTE(review): falls back to a hard-coded password when POSTGRES_PASSWORD is unset.
    postgres_password = os.getenv("POSTGRES_PASSWORD", "RK17")

    start_command = [
        "docker", "run", "--name", container_name,
        "-e", f"POSTGRES_PASSWORD={postgres_password}",
        "-p", f"{port}:5432",
        "-d", "postgres"
    ]

    try:
        subprocess.run(start_command, check=True)
        print(f"PostgreSQL Docker container '{container_name}' started on port {port}.")
    except subprocess.CalledProcessError as e:
        print(f"Failed to start Docker container: {e}")


def wait_for_db(uri, timeout=60):
    """Try connecting to the DB until it works or the timeout elapses.

    Args:
        uri: SQLAlchemy database URI to connect to.
        timeout: Maximum number of seconds to keep retrying.

    Returns:
        None, as soon as one connection attempt succeeds.

    Raises:
        Exception: If no connection could be established before the timeout.
    """
    engine = create_engine(uri)
    # A monotonic clock keeps the deadline immune to wall-clock adjustments.
    deadline = time.monotonic() + timeout
    try:
        while time.monotonic() < deadline:
            try:
                with engine.connect():
                    print("Connected to database.")
                    return
            except OperationalError:
                print("Database not ready, retrying...")
                time.sleep(2)
    finally:
        # This engine only serves as a probe; release its connection pool.
        engine.dispose()
    raise Exception("Database connection failed after timeout.")


def _load_fragment_rows(engine, table_name, fragment_column_name):
    """Return ``{fragment id (str): {column name: value}}`` for every table row."""
    columns = [column['name'] for column in inspect(engine).get_columns(table_name)]
    if fragment_column_name not in columns:
        raise ValueError(f"Fragment column '{fragment_column_name}' not found in table '{table_name}'.")

    # Let the dialect quote the identifier when it needs quoting instead of
    # pasting the raw table name into the SQL text.
    quoted_table = engine.dialect.identifier_preparer.quote(table_name)
    with engine.connect() as connection:
        rows = connection.execute(text(f"SELECT * FROM {quoted_table}")).fetchall()

    fragment_map = {}
    for row in rows:
        # Assumes SELECT * yields columns in the order reported by the
        # inspector (the same positional assumption the lookup always made).
        values = dict(zip(columns, row))
        key = values[fragment_column_name]
        if key is not None:
            fragment_map[str(key)] = values
    return fragment_map


def _iter_csv_fragments(df):
    """Yield ``(backbone value, fragment id as str)`` for each non-empty cell outside column 0."""
    for _, row in df.iterrows():
        for col in df.columns:
            if col != 0 and pd.notna(row[col]):
                yield row[0], str(row[col])


def _format_origin_block(sequence):
    """Return the GenBank ORIGIN block for *sequence* (60 bases per line, groups of 10)."""
    if "ORIGIN" in sequence:
        # The stored value already carries its own ORIGIN block.
        return sequence.strip()
    lines = ["ORIGIN"]
    for start in range(0, len(sequence), 60):
        chunk = sequence[start:start + 60]
        groups = " ".join(chunk[j:j + 10] for j in range(0, len(chunk), 10))
        lines.append(f"{start + 1:>9} {groups}")
    return "\n".join(lines).strip()


def _write_genbank_file(output, record, annotation):
    """Write one ``<record.id>.gb`` file into *output* from a record and its annotation text."""
    sequence = str(record.seq)
    fragment_id = record.id

    # Reuse the LOCUS line stored in the annotation, else synthesize one.
    locus_line_match = re.search(r"LOCUS\s+.+", annotation)
    if locus_line_match:
        locus_line = locus_line_match.group()
    else:
        print(f"LOCUS info missing for fragment {fragment_id}")
        locus_line = f"LOCUS       {fragment_id: <20} {len(sequence)} bp    DNA     linear   UNK 01-JAN-2025"

    # Everything from the FEATURES keyword onwards is copied through.
    features_start = annotation.find("FEATURES")
    features_section = annotation[features_start:] if features_start != -1 else ""
    if features_section and not features_section.endswith("\n"):
        # Keep ORIGIN from being glued onto the last feature line.
        features_section += "\n"

    gb_filename = os.path.join(output, f"{fragment_id}.gb")
    with open(gb_filename, "w") as f:
        f.write(locus_line + "\n")
        f.write(f"DEFINITION  {record.description}\n")
        f.write(f"ACCESSION   {record.id}\n")
        f.write("VERSION     DB\n")
        f.write("KEYWORDS    .\n")
        f.write("SOURCE      .\n")
        f.write(features_section)
        f.write(_format_origin_block(sequence) + "\n")
        f.write("//\n")


def fetch_annotations(csv_file, sequence_column, annotation_columns, db_uri, table_name, fragment_column_name, output, output_report):
    """Fetch annotations from the database and save the result as GenBank files.

    The CSV is read without a header: column 0 holds the backbone identifier
    and every other non-empty cell is a fragment identifier. Cells whose value
    equals one of the column-0 identifiers are not treated as fragments.

    Args:
        csv_file: Path of the input CSV file.
        sequence_column: Name of the DB column holding the sequence.
        annotation_columns: Name of the DB column holding the annotation text
            (searched for a ``LOCUS`` line and a ``FEATURES`` section).
        db_uri: Database URI; ``__at__`` placeholders are restored to ``@``.
        table_name: Table to read the fragments from.
        fragment_column_name: Column of ``table_name`` holding fragment ids.
        output: Directory receiving one ``<fragment id>.gb`` file per fragment
            found in the DB; created if needed.
        output_report: File receiving, one per line, the CSV fragment ids that
            are absent from the DB (empty when none are missing).

    Raises:
        ValueError: If ``fragment_column_name`` is not a column of the table.
        Exception: Any error from the DB lookup, the report or the GenBank
            export is printed and re-raised so that the script fails visibly.
    """
    db_uri = fix_db_uri(db_uri)
    df = pd.read_csv(csv_file, sep=',', header=None)

    engine = create_engine(db_uri)
    try:
        fragment_map = _load_fragment_rows(engine, table_name, fragment_column_name)

        # Compare fragments between CSV and DB, on string form on both sides.
        backbone_ids = set(df[0].dropna().astype(str))
        csv_fragments = {
            fragment_id for _, fragment_id in _iter_csv_fragments(df)
            if fragment_id not in backbone_ids
        }
        missing_fragments = sorted(csv_fragments - fragment_map.keys())

        # Write report file (left empty when nothing is missing).
        with open(output_report, "w") as report_file:
            report_file.writelines(f"{frag}\n" for frag in missing_fragments)
    except Exception as e:
        print(f"Error occurred during annotation: {e}")
        raise  # Ensures the error exits the script
    finally:
        engine.dispose()

    # GenBank file generation per fragment
    try:
        os.makedirs(output, exist_ok=True)

        for backbone_id, fragment_id in _iter_csv_fragments(df):
            if fragment_id not in csv_fragments:
                continue  # value is a backbone identifier, not a fragment
            db_values = fragment_map.get(fragment_id)
            if db_values is None:
                continue  # missing from the DB: already listed in the report

            # NULL DB values are treated as empty text.
            raw_sequence = db_values.get(sequence_column)
            sequence = "" if raw_sequence is None else str(raw_sequence)
            raw_annotation = db_values.get(annotation_columns)
            annotation = "" if raw_annotation is None else str(raw_annotation)

            record = SeqRecord(
                Seq(sequence),
                id=fragment_id,
                name=fragment_id,
                description=f"Fragment {fragment_id} from Backbone {backbone_id}"
            )
            # Keep the annotation text on the record, keyed by its column name.
            if annotation_columns in db_values:
                record.annotations = {annotation_columns: annotation}

            _write_genbank_file(output, record, annotation)
    except Exception as e:
        print(f"Error saving GenBank files: {e}")
        raise  # Same policy as above: do not exit successfully on failure


def main():
    """Parse the command line, wait for the database and export GenBank files.

    Database parameters come either from the individual command-line options
    or, when ``--use_json_paramers`` is ``true``, from the JSON file given by
    ``--json_conf`` (keys: ``table``, ``sequence_column``,
    ``annotation_column``, ``fragment_column``, ``db_uri``).

    Raises:
        ValueError: If ``--json_conf`` is missing in JSON mode, or if a
            required database parameter is absent or empty.
        Exception: Re-raised from ``wait_for_db`` once all connection
            attempts have failed.
    """
    parser = argparse.ArgumentParser(description="Fetch annotations from PostgreSQL database and save as GenBank files.")
    parser.add_argument("--input", required=True, help="Input CSV file")
    parser.add_argument("--use_json_paramers", required=True, help="Use parameters from JSON: true/false")
    parser.add_argument("--sequence_column", required=False, help="DB column contains sequence for ganbank file")
    parser.add_argument("--annotation_columns", required=False, help="DB column contains head for ganbank file")
    parser.add_argument("--db_uri", required=False, help="Database URI connection string")
    parser.add_argument("--table", required=False, help="Table name in the database")
    parser.add_argument("--fragment_column", required=False, help="Fragment column name in the database")
    parser.add_argument("--output", required=True, help="Output dir for gb files")
    parser.add_argument("--json_conf", required=False, help="JSON config file with DB parameters")
    parser.add_argument("--report", required=True, help="Output report for fragments checking in DB")
    args = parser.parse_args()

    # Resolve the parameters from the JSON config or from the CLI options.
    if args.use_json_paramers == 'true':
        if not args.json_conf:
            raise ValueError("You must provide --json_conf when --use_json_paramers is 'true'")
        with open(args.json_conf, "r") as f:
            config_params = json.load(f)
    else:
        config_params = {
            "table": args.table,
            "sequence_column": args.sequence_column,
            "annotation_column": args.annotation_columns,
            "fragment_column": args.fragment_column,
            "db_uri": args.db_uri,
        }

    # Fail early with a readable message instead of a KeyError/AttributeError.
    expected_keys = ("table", "sequence_column", "annotation_column", "fragment_column", "db_uri")
    absent_keys = [key for key in expected_keys if key not in config_params]
    if absent_keys:
        raise ValueError(f"Missing parameter(s) in configuration: {', '.join(absent_keys)}")
    empty_keys = [key for key in ("table", "fragment_column", "db_uri") if not config_params[key]]
    if empty_keys:
        raise ValueError(f"Parameter(s) must not be empty: {', '.join(empty_keys)}")

    # Extract final resolved parameters
    table = config_params["table"]
    sequence_column = config_params["sequence_column"]
    annotation_column = config_params["annotation_column"]
    fragment_column = config_params["fragment_column"]
    db_uri = fix_db_uri(config_params["db_uri"])

    # Wait until the database is ready
    max_retries = 3
    for attempt in range(1, max_retries + 1):
        try:
            wait_for_db(db_uri)
            break  # Success
        except Exception:
            if attempt == max_retries:
                # Report only host/port/database: the full URI may embed the password.
                try:
                    url = make_url(db_uri)
                    db_target = f"{url.host}:{url.port}/{url.database}"
                except Exception:
                    db_target = "<unparsable URI>"
                print(f"Attempt {attempt} failed: Could not connect to database at {db_target}.")
                raise
            time.sleep(2)

    # Fetch annotations from the database and save as gb
    fetch_annotations(args.input, sequence_column, annotation_column, db_uri, table, fragment_column, args.output, args.report)

# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()