diff vsnp_determine_ref_from_data.py @ 4:36bdf8b439ed draft

Uploaded
author greg
date Sun, 03 Jan 2021 16:13:22 +0000
parents ebc08e5ce646
children
--- a/vsnp_determine_ref_from_data.py	Mon Nov 23 21:42:34 2020 +0000
+++ b/vsnp_determine_ref_from_data.py	Sun Jan 03 16:13:22 2021 +0000
@@ -2,30 +2,21 @@
 
 import argparse
 import gzip
-import multiprocessing
 import os
-import queue
+from collections import OrderedDict
+
 import yaml
 from Bio.SeqIO.QualityIO import FastqGeneralIterator
-from collections import OrderedDict
 
-INPUT_READS_DIR = 'input_reads'
 OUTPUT_DBKEY_DIR = 'output_dbkey'
 OUTPUT_METRICS_DIR = 'output_metrics'
 
 
-def get_base_file_name(file_path):
+def get_sample_name(file_path):
     base_file_name = os.path.basename(file_path)
     if base_file_name.find(".") > 0:
         # Eliminate the extension.
         return os.path.splitext(base_file_name)[0]
-    elif base_file_name.find("_") > 0:
-        # The dot extension was likely changed to
-        # the "_" character.
-        items = base_file_name.split("_")
-        no_ext = "_".join(items[0:-2])
-        if len(no_ext) > 0:
-            return no_ext
     return base_file_name
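
For context, a minimal sketch of how the simplified get_sample_name behaves (the read paths are hypothetical):

    get_sample_name("/reads/sample1.fastq")     # -> "sample1"
    get_sample_name("/reads/sample1.fastq.gz")  # -> "sample1.fastq" (only the last extension is stripped)
    get_sample_name("/reads/sample1")           # -> "sample1" (no extension, name returned unchanged)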
 
 
@@ -91,18 +82,6 @@
     return group, dbkey
 
 
-def get_group_and_dbkey_for_collection(task_queue, finished_queue, dnaprints_dict, timeout):
-    while True:
-        try:
-            tup = task_queue.get(block=True, timeout=timeout)
-        except queue.Empty:
-            break
-        fastq_file, count_list, brucella_string, brucella_sum, bovis_string, bovis_sum, para_string, para_sum = tup
-        group, dbkey = get_group_and_dbkey(dnaprints_dict, brucella_string, brucella_sum, bovis_string, bovis_sum, para_string, para_sum)
-        finished_queue.put((fastq_file, count_list, group, dbkey))
-        task_queue.task_done()
-
-
 def get_oligo_dict():
     oligo_dict = {}
     oligo_dict["01_ab1"] = "AATTGTCGGATAGCCTGGCGATAACGACGC"
@@ -138,7 +117,7 @@
 def get_seq_counts(value, fastq_list, gzipped):
     count = 0
     for fastq_file in fastq_list:
-        if gzipped == "true":
+        if gzipped:
             with gzip.open(fastq_file, 'rt') as fh:
                 for title, seq, qual in FastqGeneralIterator(fh):
                     count += seq.count(value)
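
The move from the old string test (gzipped == "true") to a real boolean pairs with the store_true change to --gzipped below. A minimal sketch of the open-by-flag pattern this relies on, with a hypothetical helper name:

    import gzip

    def open_fastq(path, gzipped):
        # Open gzipped files in text mode ('rt') so FastqGeneralIterator
        # receives str records, matching a plain text-mode open().
        return gzip.open(path, "rt") if gzipped else open(path, "r")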
@@ -166,17 +145,6 @@
     return count_summary, count_list, brucella_sum, bovis_sum, para_sum
 
 
-def get_species_counts_for_collection(task_queue, finished_queue, gzipped, timeout):
-    while True:
-        try:
-            fastq_file = task_queue.get(block=True, timeout=timeout)
-        except queue.Empty:
-            break
-        count_summary, count_list, brucella_sum, bovis_sum, para_sum = get_species_counts([fastq_file], gzipped)
-        finished_queue.put((fastq_file, count_summary, count_list, brucella_sum, bovis_sum, para_sum))
-        task_queue.task_done()
-
-
 def get_species_strings(count_summary):
     binary_dictionary = {}
     for k, v in count_summary.items():
@@ -197,56 +165,20 @@
     return brucella_string, bovis_string, para_string
 
 
-def get_species_strings_for_collection(task_queue, finished_queue, timeout):
-    while True:
-        try:
-            tup = task_queue.get(block=True, timeout=timeout)
-        except queue.Empty:
-            break
-        fastq_file, count_summary, count_list, brucella_sum, bovis_sum, para_sum = tup
-        brucella_string, bovis_string, para_string = get_species_strings(count_summary)
-        finished_queue.put((fastq_file, count_list, brucella_string, brucella_sum, bovis_string, bovis_sum, para_string, para_sum))
-        task_queue.task_done()
-
-
-def output_dbkey(file_name, dbkey, output_file=None):
+def output_dbkey(file_name, dbkey, output_file):
     # Output the dbkey.
-    if output_file is None:
-        # We're producing a dataset collection.
-        output_file = os.path.join(OUTPUT_DBKEY_DIR, "%s.txt" % file_name)
     with open(output_file, "w") as fh:
         fh.write("%s" % dbkey)
 
 
-def output_files(fastq_file, count_list, group, dbkey, dbkey_file=None, metrics_file=None):
-    base_file_name = get_base_file_name(fastq_file)
-    if dbkey_file is not None:
-        # We're dealing with a single read or
-        # a set of paired reads.  If the latter,
-        # the following will hopefully produce a
-        # good sample string.
-        if base_file_name.find("_") > 0:
-            base_file_name = base_file_name.split("_")[0]
+def output_files(fastq_file, count_list, group, dbkey, dbkey_file, metrics_file):
+    base_file_name = get_sample_name(fastq_file)
     output_dbkey(base_file_name, dbkey, dbkey_file)
     output_metrics(base_file_name, count_list, group, dbkey, metrics_file)
 
 
-def output_files_for_collection(task_queue, timeout):
-    while True:
-        try:
-            tup = task_queue.get(block=True, timeout=timeout)
-        except queue.Empty:
-            break
-        fastq_file, count_list, group, dbkey = tup
-        output_files(fastq_file, count_list, group, dbkey)
-        task_queue.task_done()
-
-
-def output_metrics(file_name, count_list, group, dbkey, output_file=None):
+def output_metrics(file_name, count_list, group, dbkey, output_file):
     # Output the metrics.
-    if output_file is None:
-        # We're producing a dataset collection.
-        output_file = os.path.join(OUTPUT_METRICS_DIR, "%s.txt" % file_name)
     with open(output_file, "w") as fh:
         fh.write("Sample: %s\n" % file_name)
         fh.write("Brucella counts: ")
@@ -262,42 +194,21 @@
         fh.write("\ndbkey: %s\n" % dbkey)
 
 
-def set_num_cpus(num_files, processes):
-    num_cpus = int(multiprocessing.cpu_count())
-    if num_files < num_cpus and num_files < processes:
-        return num_files
-    if num_cpus < processes:
-        half_cpus = int(num_cpus / 2)
-        if num_files < half_cpus:
-            return num_files
-        return half_cpus
-    return processes
-
-
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
 
     parser.add_argument('--dnaprint_fields', action='append', dest='dnaprint_fields', nargs=2, help="List of dnaprints data table value, name and path fields")
-    parser.add_argument('--read1', action='store', dest='read1', required=False, default=None, help='Required: single read')
+    parser.add_argument('--read1', action='store', dest='read1', help='Required: single read')
     parser.add_argument('--read2', action='store', dest='read2', required=False, default=None, help='Optional: paired read')
-    parser.add_argument('--gzipped', action='store', dest='gzipped', help='Input files are gzipped')
-    parser.add_argument('--output_dbkey', action='store', dest='output_dbkey', required=False, default=None, help='Output reference file')
-    parser.add_argument('--output_metrics', action='store', dest='output_metrics', required=False, default=None, help='Output metrics file')
-    parser.add_argument('--processes', action='store', dest='processes', type=int, help='User-selected number of processes to use for job splitting')
+    parser.add_argument('--gzipped', action='store_true', dest='gzipped', help='Input files are gzipped')
+    parser.add_argument('--output_dbkey', action='store', dest='output_dbkey', help='Output reference file')
+    parser.add_argument('--output_metrics', action='store', dest='output_metrics', help='Output metrics file')
 
     args = parser.parse_args()
 
-    collection = False
-    fastq_list = []
-    if args.read1 is not None:
-        fastq_list.append(args.read1)
-        if args.read2 is not None:
-            fastq_list.append(args.read2)
-    else:
-        collection = True
-        for file_name in sorted(os.listdir(INPUT_READS_DIR)):
-            file_path = os.path.abspath(os.path.join(INPUT_READS_DIR, file_name))
-            fastq_list.append(file_path)
+    fastq_list = [args.read1]
+    if args.read2 is not None:
+        fastq_list.append(args.read2)
 
     # The value of dnaprint_fields is a list of lists, where each list is
     # the [value, name, path] components of the vsnp_dnaprints data table.
@@ -306,62 +217,9 @@
     # table to ensure a proper mapping for discovering the dbkey.
     dnaprints_dict = get_dnaprints_dict(args.dnaprint_fields)
 
-    if collection:
-        # Here fastq_list consists of any number of
-        # reads, so each file will be processed and
-        # dataset collections will be produced as outputs.
-        multiprocessing.set_start_method('spawn')
-        queue1 = multiprocessing.JoinableQueue()
-        queue2 = multiprocessing.JoinableQueue()
-        num_files = len(fastq_list)
-        cpus = set_num_cpus(num_files, args.processes)
-        # Set a timeout for get()s in the queue.
-        timeout = 0.05
-
-        for fastq_file in fastq_list:
-            queue1.put(fastq_file)
-
-        # Complete the get_species_counts task.
-        processes = [multiprocessing.Process(target=get_species_counts_for_collection, args=(queue1, queue2, args.gzipped, timeout, )) for _ in range(cpus)]
-        for p in processes:
-            p.start()
-        for p in processes:
-            p.join()
-        queue1.join()
-
-        # Complete the get_species_strings task.
-        processes = [multiprocessing.Process(target=get_species_strings_for_collection, args=(queue2, queue1, timeout, )) for _ in range(cpus)]
-        for p in processes:
-            p.start()
-        for p in processes:
-            p.join()
-        queue2.join()
-
-        # Complete the get_group_and_dbkey task.
-        processes = [multiprocessing.Process(target=get_group_and_dbkey_for_collection, args=(queue1, queue2, dnaprints_dict, timeout, )) for _ in range(cpus)]
-        for p in processes:
-            p.start()
-        for p in processes:
-            p.join()
-        queue1.join()
-
-        # Complete the output_files task.
-        processes = [multiprocessing.Process(target=output_files_for_collection, args=(queue2, timeout, )) for _ in range(cpus)]
-        for p in processes:
-            p.start()
-        for p in processes:
-            p.join()
-        queue2.join()
-
-        if queue1.empty() and queue2.empty():
-            queue1.close()
-            queue1.join_thread()
-            queue2.close()
-            queue2.join_thread()
-    else:
-        # Here fastq_list consists of either a single read
-        # or a set of paired reads, producing single outputs.
-        count_summary, count_list, brucella_sum, bovis_sum, para_sum = get_species_counts(fastq_list, args.gzipped)
-        brucella_string, bovis_string, para_string = get_species_strings(count_summary)
-        group, dbkey = get_group_and_dbkey(dnaprints_dict, brucella_string, brucella_sum, bovis_string, bovis_sum, para_string, para_sum)
-        output_files(args.read1, count_list, group, dbkey, dbkey_file=args.output_dbkey, metrics_file=args.output_metrics)
+    # Here fastq_list consists of either a single read
+    # or a set of paired reads, producing single outputs.
+    count_summary, count_list, brucella_sum, bovis_sum, para_sum = get_species_counts(fastq_list, args.gzipped)
+    brucella_string, bovis_string, para_string = get_species_strings(count_summary)
+    group, dbkey = get_group_and_dbkey(dnaprints_dict, brucella_string, brucella_sum, bovis_string, bovis_sum, para_string, para_sum)
+    output_files(args.read1, count_list, group, dbkey, dbkey_file=args.output_dbkey, metrics_file=args.output_metrics)
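
For reference, a hypothetical invocation of the simplified script; the read names, the dnaprint value and path, and the output file names are illustrative. Note that --gzipped is now a bare flag (store_true) rather than --gzipped true, and that --dnaprint_fields takes two values per occurrence and may be repeated (action='append' with nargs=2):

    python vsnp_determine_ref_from_data.py \
        --read1 sample_R1.fastq.gz \
        --read2 sample_R2.fastq.gz \
        --gzipped \
        --dnaprint_fields AF2122 /data/dnaprints/AF2122.yml \
        --output_dbkey dbkey.txt \
        --output_metrics metrics.txt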