diff gstf_preparation.py @ 9:f4acbfe8d6fe draft

planemo upload for repository https://github.com/TGAC/earlham-galaxytools/tree/master/tools/gstf_preparation commit 2f56285b1ef694d732c8b2637e3e924f8a626e55
author earlhaminst
date Wed, 17 Oct 2018 07:31:29 -0400
parents 92f3966d5bc3
children e8e75a79de59
line wrap: on
line diff
--- a/gstf_preparation.py	Wed May 16 20:03:57 2018 -0400
+++ b/gstf_preparation.py	Wed Oct 17 07:31:29 2018 -0400
@@ -264,10 +264,10 @@
 
     cur.execute('SELECT species, seq_region_name FROM transcript_species WHERE transcript_id=?',
                 (transcript_id, ))
-    results = cur.fetchone()
-    if not results:
-        return None
-    return results
+    row = cur.fetchone()
+    if not row:
+        return (None, None)
+    return row
 
 
 def fetch_gene_id_for_transcript(conn, transcript_id):
@@ -275,17 +275,18 @@
 
     cur.execute('SELECT gene_id FROM transcript WHERE transcript_id=?',
                 (transcript_id, ))
-    results = cur.fetchone()
-    if not results:
+    row = cur.fetchone()
+    if not row:
         return None
-    return results[0]
+    return row[0]
 
 
-def remove_id_version(s):
+def remove_id_version(s, force=False):
     """
-    Remove the optional '.VERSION' from an Ensembl id.
+    Remove the optional '.VERSION' from an id if it's an Ensembl id or if
+    `force` is True.
     """
-    if s.startswith('ENS'):
+    if force or s.startswith('ENS'):
         return s.split('.')[0]
     else:
         return s
@@ -358,7 +359,7 @@
                     print("Line %i in file '%s': %s" % (i, filename, e), file=sys.stderr)
 
         for unimplemented_feature, nlines in unimplemented_feature_nlines_dict.items():
-            print("Skipped %d lines in file '%s': '%s' is not an implemented feature type" % (nlines, filename, unimplemented_feature), file=sys.stderr)
+            print("Skipped %d lines in GFF3 file '%s': '%s' is not an implemented feature type" % (nlines, filename, unimplemented_feature), file=sys.stderr)
 
         join_dicts(gene_dict, transcript_dict, exon_parent_dict, cds_parent_dict, five_prime_utr_parent_dict, three_prime_utr_parent_dict)
         write_gene_dict_to_db(conn, gene_dict)
@@ -367,47 +368,68 @@
         with open(json_arg) as f:
             write_gene_dict_to_db(conn, json.load(f))
 
-    if options.longestCDS:
-        gene_transcripts_dict = dict()
-        for fasta_arg in options.fasta:
-            for entry in FASTAReader_gen(fasta_arg):
-                # Extract the transcript id by removing everything after the first space and then removing the version if it is an Ensembl id
-                transcript_id = remove_id_version(entry.header[1:].lstrip().split(' ')[0])
+    # Read the FASTA files a first time to:
+    # - determine for each file if we need to force the removal of the version
+    #   from the transcript id
+    # - fill gene_transcripts_dict when keeping only the longest CDS per gene
+    force_remove_id_version_file_list = []
+    gene_transcripts_dict = dict()
+    for fasta_arg in options.fasta:
+        force_remove_id_version = False
+        found_gene_transcript = False
+        for entry in FASTAReader_gen(fasta_arg):
+            # Extract the transcript id by removing everything after the first space and then removing the version if needed
+            transcript_id = remove_id_version(entry.header[1:].lstrip().split(' ')[0], force_remove_id_version)
 
-                if len(entry.sequence) % 3 != 0:
-                    print("Transcript '%s' in file '%s' has a coding sequence length which is not multiple of 3" % (transcript_id, fasta_arg), file=sys.stderr)
-                    continue
+            if len(entry.sequence) % 3 != 0:
+                continue
 
+            gene_id = fetch_gene_id_for_transcript(conn, transcript_id)
+            if not gene_id and not found_gene_transcript:
+                # We have not found a proper gene transcript in this file yet,
+                # try to force the removal of the version from the transcript id
+                transcript_id = remove_id_version(entry.header[1:].lstrip().split(' ')[0], True)
                 gene_id = fetch_gene_id_for_transcript(conn, transcript_id)
-                if not gene_id:
-                    print("Transcript '%s' in file '%s' not found in the gene feature information" % (transcript_id, fasta_arg), file=sys.stderr)
-                    continue
+                # Remember that we need to force the removal for this file
+                if gene_id:
+                    force_remove_id_version = True
+                    force_remove_id_version_file_list.append(fasta_arg)
+                    print("Forcing removal of id version in FASTA file '%s'" % fasta_arg, file=sys.stderr)
+            if not gene_id:
+                print("Transcript '%s' in FASTA file '%s' not found in the gene feature information" % (transcript_id, fasta_arg), file=sys.stderr)
+                continue
+            if options.longestCDS:
+                found_gene_transcript = True
+            else:
+                break
 
-                if gene_id in gene_transcripts_dict:
-                    gene_transcripts_dict[gene_id].append((transcript_id, len(entry.sequence)))
-                else:
-                    gene_transcripts_dict[gene_id] = [(transcript_id, len(entry.sequence))]
+            if gene_id in gene_transcripts_dict:
+                gene_transcripts_dict[gene_id].append((transcript_id, len(entry.sequence)))
+            else:
+                gene_transcripts_dict[gene_id] = [(transcript_id, len(entry.sequence))]
 
-        # For each gene, select the transcript with the longest sequence
-        # If more than one transcripts have the same longest sequence for a gene, the
-        # first one to appear in the FASTA file is selected
+    if options.longestCDS:
+        # For each gene, select the transcript with the longest sequence.
+        # If more than one transcripts have the same longest sequence for a
+        # gene, the first one to appear in the FASTA file is selected
         selected_transcript_ids = [max(transcript_id_lengths, key=lambda _: _[1])[0] for transcript_id_lengths in gene_transcripts_dict.values()]
 
     regions = [_.strip().lower() for _ in options.regions.split(",")]
     with open(options.of, 'w') as output_fasta_file, open(options.ff, 'w') as filtered_fasta_file:
         for fasta_arg in options.fasta:
+            force_remove_id_version = fasta_arg in force_remove_id_version_file_list
             for entry in FASTAReader_gen(fasta_arg):
-                transcript_id = remove_id_version(entry.header[1:].lstrip().split(' ')[0])
+                transcript_id = remove_id_version(entry.header[1:].lstrip().split(' ')[0], force_remove_id_version)
                 if options.longestCDS and transcript_id not in selected_transcript_ids:
                     continue
 
                 if len(entry.sequence) % 3 != 0:
-                    print("Transcript '%s' in file '%s' has a coding sequence length which is not multiple of 3" % (transcript_id, fasta_arg), file=sys.stderr)
+                    print("Transcript '%s' in FASTA file '%s' has a coding sequence length which is not multiple of 3" % (transcript_id, fasta_arg), file=sys.stderr)
                     continue
 
                 species_for_transcript, seq_region_for_transcript = fetch_species_and_seq_region_for_transcript(conn, transcript_id)
                 if not species_for_transcript:
-                    print("Transcript '%s' in file '%s' not found in the gene feature information" % (transcript_id, fasta_arg), file=sys.stderr)
+                    print("Transcript '%s' in FASTA file '%s' not found in the gene feature information" % (transcript_id, fasta_arg), file=sys.stderr)
                     continue
 
                 if options.headers: