diff venn_diagram.py @ 0:57f01ca855cd draft default tip

"planemo upload commit 47d779aa1de5153673ac8bb1e37c9730210cbb5d"
author galaxyp
date Sat, 12 Jun 2021 18:06:28 +0000
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/venn_diagram.py	Sat Jun 12 18:06:28 2021 +0000
@@ -0,0 +1,204 @@
+#!/usr/bin/env python
+
+import argparse
+import csv
+import json
+import os
+import re
+from itertools import combinations
+
+
+CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
+
+########################################################################
+# FUNCTIONS
+########################################################################
+
+
+def isnumber(format, n):
+    """
+    Check if an element is integer or float
+    """
+    float_format = re.compile(r"^[-]?[1-9][0-9]*.?[0-9]+$")
+    int_format = re.compile(r"^[-]?[1-9][0-9]*$")
+    test = ""
+    if format == "int":
+        test = re.match(int_format, n)
+    elif format == "float":
+        test = re.match(float_format, n)
+    if test:
+        return True
+    else:
+        return False
+
+
+def input_to_dict(inputs):
+    """
+    Parse input and return a dictionary of name and data of each lists/files
+    """
+    comp_dict = {}
+    title_dict = {}
+    c = ["A", "B", "C", "D", "E", "F"]
+    for i in range(len(inputs)):
+        input_file = inputs[i][0]
+        name = inputs[i][1]
+        input_type = inputs[i][2]
+        title = c[i]
+        title_dict[title] = name
+        ids = set()
+        if input_type == "file":
+            header = inputs[i][3]
+            ncol = inputs[i][4]
+            with open(input_file, "r") as handle:
+                file_content = csv.reader(handle, delimiter="\t")
+                file_content = list(file_content)   # csv object to list
+
+                # Check if column number is in right form
+                if isnumber("int", ncol.replace("c", "")):
+                    if header == "true":
+                        # gets ids from defined column
+                        file_content = [x for x in [line[int(ncol.replace("c", ""))-1].split(";") for line in file_content[1:]]]  # noqa 501
+
+                    else:
+                        file_content = [x for x in [line[int(ncol.replace("c", ""))-1].split(";") for line in file_content]]  # noqa 501
+                else:
+                    raise ValueError("Please fill in the right format of column number")  # noqa 501
+        else:
+            ids = set()
+            file_content = inputs[i][0].split()
+            file_content = [x.split(";") for x in file_content]
+
+        # flat list of list of lists, remove empty items
+        file_content = [item.strip() for sublist in file_content for item in sublist if item != '']   # noqa 501
+        ids.update(file_content)
+        if 'NA' in ids:
+            ids.remove('NA')
+        comp_dict[title] = ids
+
+    return comp_dict, title_dict
+
+
+def intersect(comp_dict):
+    """
+    Calculate the intersections of input
+    """
+    names = set(comp_dict)
+    for i in range(1, len(comp_dict) + 1):
+        for group in combinations(sorted(comp_dict), i):
+            others = set()
+            [others.add(name) for name in names if name not in group]
+            difference = []
+            intersected = set.intersection(*(comp_dict[k] for k in group))
+            if len(others) > 0:
+                difference = intersected.difference(set.union(*(comp_dict[k] for k in others))) # noqa 501
+            yield group, list(intersected), list(difference)
+
+
+def diagram(comp_dict, title_dict):
+    """
+    Create json string for jvenn diagram plot
+    """
+    result = {}
+    result["name"] = {}
+    for k in comp_dict.keys():
+        result["name"][k] = title_dict[k]
+
+    result["data"] = {}
+    result["values"] = {}
+    for group, intersected, difference in intersect(comp_dict):
+        if len(group) == 1:
+            result["data"]["".join(group)] = sorted(difference)
+            result["values"]["".join(group)] = len(difference)
+        elif len(group) > 1 and len(group) < len(comp_dict):
+            result["data"]["".join(group)] = sorted(difference)
+            result["values"]["".join(group)] = len(difference)
+        elif len(group) == len(comp_dict):
+            result["data"]["".join(group)] = sorted(intersected)
+            result["values"]["".join(group)] = len(intersected)
+
+    return result
+
+# Write intersections of input to text output file
+
+
+def write_text_venn(json_result):
+    lines = []
+    result = dict((k, v) for k, v in json_result["data"].items() if v != [])  # noqa 501
+    for key in result:
+        if 'NA' in result[key]:
+            result[key].remove("NA")
+
+    list_names = dict((k, v) for k, v in json_result["name"].items() if v != [])  # noqa 501
+    nb_lines_max = max(len(v) for v in result.values())
+
+    # get list names associated to each column
+    column_dict = {}
+    for key in result:
+        if key in list_names:
+            column_dict[key] = list_names[key]
+        else:
+            keys = list(key)
+            column_dict[key] = "_".join([list_names[k] for k in keys])
+
+    # construct tsv
+    for key in result:
+        line = result[key]
+        if len(line) < nb_lines_max:
+            line.extend([''] * (nb_lines_max - len(line)))
+        line = [column_dict[key]] + line     # add header
+        lines.append(line)
+    # transpose tsv
+    lines = zip(*lines)
+
+    with open("venn_diagram_text_output.tsv", "w") as output:
+        tsv_output = csv.writer(output, delimiter='\t')
+        tsv_output.writerows(lines)
+
+
+def write_summary(summary_file, inputs):
+    """
+    Paste json string into template file
+    """
+    a, b = input_to_dict(inputs)
+    data = diagram(a, b)
+    write_text_venn(data)
+
+    to_replace = {
+        "series": [data],
+        "displayStat": "true",
+        "displaySwitch": "true",
+        "shortNumber": "true",
+    }
+
+    FH_summary_tpl = open(os.path.join(CURRENT_DIR, "jvenn_template.html"))
+    FH_summary_out = open(summary_file, "w")
+    for line in FH_summary_tpl:
+        if "###JVENN_DATA###" in line:
+            line = line.replace("###JVENN_DATA###", json.dumps(to_replace))
+        FH_summary_out.write(line)
+
+    FH_summary_out.close()
+    FH_summary_tpl.close()
+
+
+def process(args):
+    write_summary(args.summary, args.input)
+
+
+#####################################################################
+# MAIN
+#####################################################################
+if __name__ == '__main__':
+    # Parse parameters
+    parser = argparse.ArgumentParser(description='Filters an abundance file')
+    group_input = parser.add_argument_group('Inputs')
+    group_input.add_argument('--input', nargs="+", action="append",
+                             required=True, help="The input tabular file.")
+    group_output = parser.add_argument_group('Outputs')
+    group_output.add_argument('--summary', default="summary.html",
+                              help="The HTML file containing the graphs. \
+                                   [Default: %(default)s]")
+    args = parser.parse_args()
+
+    # Process
+    process(args)