diff vsnp_get_snps.py @ 9:0fe292b20b9d draft

"planemo upload for repository https://github.com/gregvonkuster/galaxy_tools/tree/master/tools/sequence_analysis/vsnp/vsnp_get_snps commit 3b7fef2d17fec96647345e89c774d4af417d23d7"
author greg
date Thu, 29 Jul 2021 13:16:03 +0000
parents 5e4595b9f63c
children be5875f29ea4
line wrap: on
line diff
--- a/vsnp_get_snps.py	Thu Jul 29 12:50:01 2021 +0000
+++ b/vsnp_get_snps.py	Thu Jul 29 13:16:03 2021 +0000
@@ -4,7 +4,9 @@
 # and output alignment files in fasta format.
 
 import argparse
+import multiprocessing
 import os
+import queue
 import shutil
 import sys
 import time
@@ -19,6 +21,18 @@
     return datetime.fromtimestamp(time.time()).strftime('%Y-%m-%d %H-%M-%S')
 
 
+def set_num_cpus(num_files, processes):
+    num_cpus = int(multiprocessing.cpu_count())
+    if num_files < num_cpus and num_files < processes:
+        return num_files
+    if num_cpus < processes:
+        half_cpus = int(num_cpus / 2)
+        if num_files < half_cpus:
+            return num_files
+        return half_cpus
+    return processes
+
+
 def setup_all_vcfs(vcf_files, vcf_dirs):
     # Create the all_vcfs directory and link
     # all input vcf files into it for processing.
@@ -326,37 +340,43 @@
         except ValueError:
             return []
 
-    def get_snps(self, group_dir):
-        # Parse all vcf files to accumulate
-        # the SNPs into a data frame.
-        positions_dict = {}
-        group_files = []
-        for file_name in os.listdir(os.path.abspath(group_dir)):
-            file_path = os.path.abspath(os.path.join(group_dir, file_name))
-            group_files.append(file_path)
-        for file_name in group_files:
-            found_positions, found_positions_mix = self.find_initial_positions(file_name)
-            positions_dict.update(found_positions)
-        # Order before adding to file to match
-        # with ordering of individual samples.
-        # all_positions is abs_pos:REF
-        self.all_positions = OrderedDict(sorted(positions_dict.items()))
-        ref_positions_df = pandas.DataFrame(self.all_positions, index=['root'])
-        all_map_qualities = {}
-        df_list = []
-        for file_name in group_files:
-            sample_df, file_name_base, sample_map_qualities = self.decide_snps(file_name)
-            df_list.append(sample_df)
-            all_map_qualities.update({file_name_base: sample_map_qualities})
-        all_sample_df = pandas.concat(df_list)
-        # All positions have now been selected for each sample,
-        # so select parisomony informative SNPs.  This removes
-        # columns where all fields are the same.
-        # Add reference to top row.
-        prefilter_df = pandas.concat([ref_positions_df, all_sample_df], join='inner')
-        all_mq_df = pandas.DataFrame.from_dict(all_map_qualities)
-        mq_averages = all_mq_df.mean(axis=1).astype(int)
-        self.gather_and_filter(prefilter_df, mq_averages, group_dir)
+    def get_snps(self, task_queue, timeout):
+        while True:
+            try:
+                group_dir = task_queue.get(block=True, timeout=timeout)
+            except queue.Empty:
+                break
+            # Parse all vcf files to accumulate
+            # the SNPs into a data frame.
+            positions_dict = {}
+            group_files = []
+            for file_name in os.listdir(os.path.abspath(group_dir)):
+                file_path = os.path.abspath(os.path.join(group_dir, file_name))
+                group_files.append(file_path)
+            for file_name in group_files:
+                found_positions, found_positions_mix = self.find_initial_positions(file_name)
+                positions_dict.update(found_positions)
+            # Order before adding to file to match
+            # with ordering of individual samples.
+            # all_positions is abs_pos:REF
+            self.all_positions = OrderedDict(sorted(positions_dict.items()))
+            ref_positions_df = pandas.DataFrame(self.all_positions, index=['root'])
+            all_map_qualities = {}
+            df_list = []
+            for file_name in group_files:
+                sample_df, file_name_base, sample_map_qualities = self.decide_snps(file_name)
+                df_list.append(sample_df)
+                all_map_qualities.update({file_name_base: sample_map_qualities})
+            all_sample_df = pandas.concat(df_list)
+            # All positions have now been selected for each sample,
+            # so select parsimony-informative SNPs.  This removes
+            # columns where all fields are the same.
+            # Add reference to top row.
+            prefilter_df = pandas.concat([ref_positions_df, all_sample_df], join='inner')
+            all_mq_df = pandas.DataFrame.from_dict(all_map_qualities)
+            mq_averages = all_mq_df.mean(axis=1).astype(int)
+            self.gather_and_filter(prefilter_df, mq_averages, group_dir)
+            task_queue.task_done()
 
     def group_vcfs(self, vcf_files):
         # Parse an excel file to produce a
@@ -441,7 +461,13 @@
     for file_name in os.listdir(args.input_vcf_dir):
         file_path = os.path.abspath(os.path.join(args.input_vcf_dir, file_name))
         vcf_files.append(file_path)
+
+    multiprocessing.set_start_method('spawn')
+    queue1 = multiprocessing.JoinableQueue()
     num_files = len(vcf_files)
+    cpus = set_num_cpus(num_files, args.processes)
+    # Set a timeout for get()s in the queue.
+    timeout = 0.05
 
     # Initialize the snp_finder object.
     snp_finder = SnpFinder(num_files, args.dbkey, args.input_excel, args.all_isolates, args.ac, args.min_mq, args.quality_score_n_threshold, args.min_quality_score, args.input_vcf_dir, args.output_json_avg_mq_dir, args.output_json_snps_dir, args.output_snps_dir, args.output_summary)
@@ -464,8 +490,17 @@
         group_dirs = [d for d in os.listdir(os.getcwd()) if os.path.isdir(d) and d in snp_finder.groups]
         vcf_dirs.extend(group_dirs)
 
+    # Populate the queue for job splitting.
     for vcf_dir in vcf_dirs:
-        snp_finder.get_snps(vcf_dir)
+        queue1.put(vcf_dir)
+
+    # Start worker processes that drain the queue via get_snps.
+    processes = [multiprocessing.Process(target=snp_finder.get_snps, args=(queue1, timeout, )) for _ in range(cpus)]
+    for p in processes:
+        p.start()
+    for p in processes:
+        p.join()
+    queue1.join()
 
     # Finish summary log.
     snp_finder.append_to_summary("<br/><b>Time finished:</b> %s<br/>\n" % get_time_stamp())