save_to_db.py @ 0:7ff266aecf01 (draft, default, tip)

planemo upload for repository https://github.com/brsynth/galaxytools/tree/main/tools commit 3401816c949b538bd9c67e61cbe92badff6a4007-dirty

| author | tduigou |
|---|---|
| date | Wed, 11 Jun 2025 09:42:24 +0000 |
| parents | |
| children | |
```python
import subprocess
import argparse
import time
import os
import socket
import re
import json

from sqlalchemy import create_engine, inspect
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql import text
from sqlalchemy.exc import OperationalError
```
```python
def resolve_parameters(user_params: dict, json_params: dict, keys: list):
    """Resolve each key, preferring the user value over the JSON_-prefixed fallback."""
    resolved = {}
    for key in keys:
        # Prefer the user parameter if it is provided (not None or empty string)
        if key in user_params and user_params[key]:
            resolved[key] = user_params[key]
        else:
            resolved[key] = json_params.get(f"JSON_{key}")
    return resolved
```
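A quick illustration of the precedence rule, assuming the definition above is in scope (all values are made up):

```python
user = {"table": "", "db_uri": "postgresql://u:p@localhost:5432/db"}
conf = {"JSON_table": "fragments", "JSON_db_uri": "postgresql://x@y/z"}

# Empty user values fall back to the JSON_-prefixed entries
print(resolve_parameters(user, conf, ["table", "db_uri"]))
# {'table': 'fragments', 'db_uri': 'postgresql://u:p@localhost:5432/db'}
```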
```python
def fix_db_uri(uri):
    """Replace __at__ with @ in the URI if needed."""
    return uri.replace("__at__", "@")
```
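The `__at__` marker most likely comes from Galaxy's parameter sanitization, which escapes `@` in text inputs. A round-trip check (the URI is illustrative):

```python
uri = "postgresql://user:secret__at__localhost:5432/fragments_db"
assert fix_db_uri(uri) == "postgresql://user:secret@localhost:5432/fragments_db"
```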
```python
def is_port_in_use(host, port):
    """Check whether a TCP port is already in use on the given host."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.settimeout(2)
        return s.connect_ex((host, port)) == 0
```
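An illustrative check against the PostgreSQL default port on the local host:

```python
# True means something is already listening on 5432 (a server or container)
if is_port_in_use("localhost", 5432):
    print("Port 5432 is taken; another server or container is already listening.")
```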
```python
def extract_db_name(uri):
    """Extract the database name from the SQLAlchemy URI."""
    url = make_url(uri)
    return url.database
```
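`make_url` parses the URI into its components; the database name is the trailing path segment (the URI is illustrative):

```python
print(extract_db_name("postgresql://user:secret@localhost:5432/fragments_db"))
# fragments_db
```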
```python
# This function starts a Docker container when the database runs in one.
# NOTE: it is currently not called from main().
def start_postgres_container(db_name):
    """Start a PostgreSQL container with the given database name as the container name."""
    container_name = db_name

    # Check if the container is already running
    container_running = subprocess.run(
        f"docker ps -q -f name={container_name}", shell=True, capture_output=True, text=True
    )

    if container_running.stdout.strip():
        print(f"Container '{container_name}' is already running.")
        return

    # Check if the container exists but is stopped
    container_exists = subprocess.run(
        f"docker ps -a -q -f name={container_name}", shell=True, capture_output=True, text=True
    )

    if container_exists.stdout.strip():
        print(f"Starting existing container '{container_name}'...")
        subprocess.run(f"docker start {container_name}", shell=True)
        print(f"PostgreSQL Docker container '{container_name}' activated.")
        return

    # If the container does not exist, create and start a new one
    port = 5432 if not is_port_in_use("localhost", 5432) else 5433
    postgres_password = os.getenv("POSTGRES_PASSWORD", "RK17")

    start_command = [
        "docker", "run", "--name", container_name,
        "-e", f"POSTGRES_PASSWORD={postgres_password}",
        "-p", f"{port}:5432",
        "-d", "postgres"
    ]

    try:
        subprocess.run(start_command, check=True)
        print(f"PostgreSQL Docker container '{container_name}' started on port {port}.")
    except subprocess.CalledProcessError as e:
        print(f"Failed to start Docker container: {e}")
```
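This helper mirrors the container calls that main() keeps commented out. If container management were re-enabled, the wiring would look like this (a sketch; the URI is illustrative, and wait_for_db is defined just below):

```python
db_uri = "postgresql://postgres:RK17@localhost:5432/fragments_db"
start_postgres_container(extract_db_name(db_uri))  # container named after the DB
wait_for_db(db_uri)  # then block until the server accepts connections
```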
```python
def wait_for_db(uri, timeout=60):
    """Try connecting to the DB until it succeeds or the timeout expires."""
    engine = create_engine(uri)
    start_time = time.time()
    while time.time() - start_time < timeout:
        try:
            with engine.connect():
                print("Connected to database.")
                return
        except OperationalError:
            print("Database not ready, retrying...")
            time.sleep(2)
    raise Exception("Database connection failed after timeout.")
```
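An illustrative call (URI made up); this blocks until the server accepts a connection, retrying every 2 seconds, and raises once the timeout elapses:

```python
wait_for_db("postgresql://user:secret@localhost:5432/fragments_db", timeout=30)
```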
```python
def push_gb_annotations(gb_files, sequence_column, annotation_column, db_uri, table_name,
                        fragment_column_name, output, file_name_mapping):
    """Push GenBank file content into the database if the fragment is not already present."""
    db_uri = fix_db_uri(db_uri)
    engine = create_engine(db_uri)
    inserted_fragments = []

    try:
        # Parse the file_name_mapping string into a dict {base_file_name: fragment_name}
        file_name_mapping_dict = {
            os.path.basename(path): os.path.splitext(fragment_name)[0]
            for mapping in file_name_mapping.split(",")
            for path, fragment_name in [mapping.split(":")]
        }

        with engine.begin() as connection:
            inspector = inspect(engine)
            columns = [col['name'] for col in inspector.get_columns(table_name)]

            if fragment_column_name not in columns:
                raise ValueError(f"Fragment column '{fragment_column_name}' not found in table '{table_name}'.")

            # Get existing fragments
            all_rows = connection.execute(text(f"SELECT {fragment_column_name} FROM {table_name}")).fetchall()
            existing_fragments = {row[0] for row in all_rows}

            insert_rows = []

            for gb_file in gb_files:
                # Look up the fragment name by the base file name (not the full path)
                real_file_name = os.path.basename(gb_file)
                fragment_name = file_name_mapping_dict.get(real_file_name)

                print(f"Processing file: {real_file_name} ({fragment_name})")

                if not fragment_name:
                    raise ValueError(f"Fragment name not found for file '{real_file_name}' in file_name_mapping.")

                # If the fragment is already in the DB, raise an error and stop the process
                if fragment_name in existing_fragments:
                    raise RuntimeError(f"Fatal Error: Fragment '{fragment_name}' already exists in DB. Stopping the process.")

                with open(gb_file, "r") as f:
                    content = f.read()

                # Split the record at the ORIGIN line: annotations above, sequence below
                origin_match = re.search(r"^ORIGIN.*$", content, flags=re.MULTILINE)
                if not origin_match:
                    raise ValueError(f"ORIGIN section not found in file: {gb_file}")

                origin_start = origin_match.start()
                annotation_text = content[:origin_start].strip()
                sequence_text = content[origin_start:].strip()

                values = {
                    fragment_column_name: fragment_name,
                    annotation_column: annotation_text,
                    sequence_column: sequence_text,
                }

                insert_rows.append(values)
                inserted_fragments.append(fragment_name)

            # Insert the rows into the database
            for values in insert_rows:
                col_names = ", ".join(values.keys())
                placeholders = ", ".join([f":{key}" for key in values.keys()])
                insert_stmt = text(f"INSERT INTO {table_name} ({col_names}) VALUES ({placeholders})")
                connection.execute(insert_stmt, values)

            print(f"Inserted {len(insert_rows)} fragments.")

        # Write inserted fragment names to a text file
        with open(output, "w") as log_file:
            for frag in inserted_fragments:
                log_file.write(f"{frag}\n")
        print(f"Fragment names written to '{output}'.")

    except Exception as e:
        print(f"Error during GB file insertion: {e}")
        raise
```
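`file_name_mapping` is a comma-separated list of `path:fragment_name` pairs; the parser keys on the path's base name and drops the fragment name's extension, so the format cannot tolerate `:` or `,` inside paths or names. An illustrative value (names made up):

```python
mapping = "/data/upload/dataset_1.dat:pFrag-A.gb,/data/upload/dataset_2.dat:pFrag-B.gb"
# Parsed into: {"dataset_1.dat": "pFrag-A", "dataset_2.dat": "pFrag-B"}
```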
```python
def main():
    parser = argparse.ArgumentParser(description="Push GenBank file content into a PostgreSQL database.")
    parser.add_argument("--input", required=True, help="Comma-separated list of input GenBank files")
    parser.add_argument("--sequence_column", required=True, help="DB column holding the GenBank sequence")
    parser.add_argument("--annotation_column", required=True, help="DB column holding the GenBank header/annotations")
    parser.add_argument("--db_uri", required=True, help="Database URI connection string")
    parser.add_argument("--table", required=True, help="Table name in the database")
    parser.add_argument("--fragment_column", required=True, help="Fragment column name in the database")
    parser.add_argument("--output", required=True, help="Text report of inserted fragment names")
    parser.add_argument("--file_name_mapping", required=True, help="Comma-separated path:fragment_name pairs")
    parser.add_argument("--json_conf", required=False, help="JSON config file with DB parameters")
    args = parser.parse_args()

    # Load JSON config if provided (Galaxy may pass the literal string 'None')
    json_config = {}
    if args.json_conf and args.json_conf != 'None':
        with open(args.json_conf, "r") as f:
            json_config = json.load(f)
        if json_config.get("execution") == "false":
            print("Execution was blocked by config (execution = false)")
            return

    # Prefer user input; fall back to JSON_ values if not provided
    user_params = {
        "table": args.table,
        "sequence_column": args.sequence_column,
        "annotation_column": args.annotation_column,
        "fragment_column": args.fragment_column,
        "db_uri": args.db_uri
    }

    keys = ["table", "sequence_column", "annotation_column", "fragment_column", "db_uri"]
    resolved = resolve_parameters(user_params, json_config, keys)

    # Unpack resolved parameters
    table = resolved["table"]
    sequence_column = resolved["sequence_column"]
    annotation_column = resolved["annotation_column"]
    fragment_column = resolved["fragment_column"]
    db_uri = fix_db_uri(resolved["db_uri"])

    # Prepare GenBank files
    gb_file_list = [f.strip() for f in args.input.split(",") if f.strip()]

    # Wait for the DB (container startup is left disabled; see start_postgres_container)
    # db_name = extract_db_name(db_uri)
    # start_postgres_container(db_name)
    MAX_RETRIES = 3
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            wait_for_db(db_uri)
            break  # Success
        except Exception as e:
            print(f"Attempt {attempt} failed: could not connect to database at {db_uri}.")
            if attempt == MAX_RETRIES:
                raise e
            time.sleep(2)

    # Push annotations
    push_gb_annotations(
        gb_file_list,
        sequence_column,
        annotation_column,
        db_uri,
        table,
        fragment_column,
        args.output,
        args.file_name_mapping
    )
```
```python
if __name__ == "__main__":
    main()
```
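For reference, an illustrative invocation and matching `--json_conf` contents; all paths, names, and credentials here are made up:

```python
# python save_to_db.py \
#     --input frag_a.gb,frag_b.gb \
#     --sequence_column sequence \
#     --annotation_column annotation \
#     --db_uri "postgresql://user:secret__at__localhost:5432/fragments_db" \
#     --table fragments \
#     --fragment_column fragment_name \
#     --output inserted.txt \
#     --file_name_mapping "frag_a.gb:pFrag-A.gb,frag_b.gb:pFrag-B.gb"

# A --json_conf file supplies JSON_-prefixed fallbacks plus the optional
# "execution" kill switch checked in main(); this dict mirrors such a file.
example_conf = {
    "execution": "true",
    "JSON_table": "fragments",
    "JSON_sequence_column": "sequence",
    "JSON_annotation_column": "annotation",
    "JSON_fragment_column": "fragment_name",
    "JSON_db_uri": "postgresql://user:secret__at__localhost:5432/fragments_db",
}
```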
