diff uniprot.py @ 9:468c71dac78a draft

planemo upload for repository https://github.com/bgruening/galaxytools/tree/master/tools/uniprot_rest_interface commit da476148d1c609f5c26e880a3e593f0fa71ff2f6
author bgruening
date Wed, 22 May 2024 21:18:15 +0000
parents af5eccf83605
children 95fb5712344f
--- a/uniprot.py	Mon Nov 21 22:02:41 2022 +0000
+++ b/uniprot.py	Wed May 22 21:18:15 2024 +0000
@@ -1,92 +1,266 @@
-#!/usr/bin/env python
-"""
-uniprot python interface
-to access the uniprot database
-
-Based on work from Jan Rudolph: https://github.com/jdrudolph/uniprot
-available services:
-    map
-    retrieve
-
-rewitten using inspiration form: https://findwork.dev/blog/advanced-usage-python-requests-timeouts-retries-hooks/
-"""
 import argparse
+import json
+import re
 import sys
+import time
+import zlib
+from urllib.parse import (
+    parse_qs,
+    urlencode,
+    urlparse,
+)
+from xml.etree import ElementTree
 
 import requests
-from requests.adapters import HTTPAdapter
-from requests.packages.urllib3.util.retry import Retry
-
-
-DEFAULT_TIMEOUT = 5  # seconds
-URL = 'https://legacy.uniprot.org/'
-
-retry_strategy = Retry(
-    total=5,
-    backoff_factor=2,
-    status_forcelist=[429, 500, 502, 503, 504],
-    allowed_methods=["HEAD", "GET", "OPTIONS", "POST"]
+from requests.adapters import (
+    HTTPAdapter,
+    Retry,
 )
 
 
-class TimeoutHTTPAdapter(HTTPAdapter):
-    def __init__(self, *args, **kwargs):
-        self.timeout = DEFAULT_TIMEOUT
-        if "timeout" in kwargs:
-            self.timeout = kwargs["timeout"]
-            del kwargs["timeout"]
-        super().__init__(*args, **kwargs)
+POLLING_INTERVAL = 3
+API_URL = "https://rest.uniprot.org"
+
+
+retries = Retry(total=5, backoff_factor=0.25, status_forcelist=[500, 502, 503, 504])
+session = requests.Session()
+session.mount("https://", HTTPAdapter(max_retries=retries))
+
+
+def check_response(response):
+    try:
+        response.raise_for_status()
+    except requests.HTTPError:
+        print(response.json())
+        raise
+
+
+def submit_id_mapping(from_db, to_db, ids):
+    print(f"{from_db} {to_db}")
+    request = requests.post(
+        f"{API_URL}/idmapping/run",
+        data={"from": from_db, "to": to_db, "ids": ",".join(ids)},
+    )
+    check_response(request)
+    return request.json()["jobId"]
+
+
+def get_next_link(headers):
+    re_next_link = re.compile(r'<(.+)>; rel="next"')
+    if "Link" in headers:
+        match = re_next_link.match(headers["Link"])
+        if match:
+            return match.group(1)
+
+
+def check_id_mapping_results_ready(job_id):
+    while True:
+        request = session.get(f"{API_URL}/idmapping/status/{job_id}")
+        check_response(request)
+        j = request.json()
+        if "jobStatus" in j:
+            if j["jobStatus"] == "RUNNING":
+                print(f"Retrying in {POLLING_INTERVAL}s")
+                time.sleep(POLLING_INTERVAL)
+            else:
+                raise Exception(j["jobStatus"])
+        else:
+            return bool(j["results"] or j["failedIds"])
+
+
+def get_batch(batch_response, file_format, compressed):
+    batch_url = get_next_link(batch_response.headers)
+    while batch_url:
+        batch_response = session.get(batch_url)
+        batch_response.raise_for_status()
+        yield decode_results(batch_response, file_format, compressed)
+        batch_url = get_next_link(batch_response.headers)
 
-    def send(self, request, **kwargs):
-        timeout = kwargs.get("timeout")
-        if timeout is None:
-            kwargs["timeout"] = self.timeout
-        return super().send(request, **kwargs)
+
+def combine_batches(all_results, batch_results, file_format):
+    if file_format == "json":
+        for key in ("results", "failedIds"):
+            if key in batch_results and batch_results[key]:
+                all_results[key] += batch_results[key]
+    elif file_format == "tsv":
+        return all_results + batch_results[1:]
+    else:
+        return all_results + batch_results
+    return all_results
+
+
+def get_id_mapping_results_link(job_id):
+    url = f"{API_URL}/idmapping/details/{job_id}"
+    request = session.get(url)
+    check_response(request)
+    return request.json()["redirectURL"]
+
+
+def decode_results(response, file_format, compressed):
+    if compressed:
+        decompressed = zlib.decompress(response.content, 16 + zlib.MAX_WBITS)
+        if file_format == "json":
+            j = json.loads(decompressed.decode("utf-8"))
+            return j
+        elif file_format == "tsv":
+            return [line for line in decompressed.decode("utf-8").split("\n") if line]
+        elif file_format == "xlsx":
+            return [decompressed]
+        elif file_format == "xml":
+            return [decompressed.decode("utf-8")]
+        else:
+            return decompressed.decode("utf-8")
+    elif file_format == "json":
+        return response.json()
+    elif file_format == "tsv":
+        return [line for line in response.text.split("\n") if line]
+    elif file_format == "xlsx":
+        return [response.content]
+    elif file_format == "xml":
+        return [response.text]
+    return response.text
+
+
+def get_xml_namespace(element):
+    m = re.match(r"\{(.*)\}", element.tag)
+    return m.groups()[0] if m else ""
+
+
+def merge_xml_results(xml_results):
+    merged_root = ElementTree.fromstring(xml_results[0])
+    for result in xml_results[1:]:
+        root = ElementTree.fromstring(result)
+        for child in root.findall("{http://uniprot.org/uniprot}entry"):
+            merged_root.insert(-1, child)
+    ElementTree.register_namespace("", get_xml_namespace(merged_root[0]))
+    return ElementTree.tostring(merged_root, encoding="utf-8", xml_declaration=True)
 
 
-def _map(query, f, t, format='tab', chunk_size=100):
-    """ _map is not meant for use with the python interface, use `map` instead
-    """
-    tool = 'uploadlists/'
-    data = {'format': format, 'from': f, 'to': t}
+def print_progress_batches(batch_index, size, total):
+    n_fetched = min((batch_index + 1) * size, total)
+    print(f"Fetched: {n_fetched} / {total}")
+
 
-    req = []
-    for i in range(0, len(query), chunk_size):
-        q = query[i:i + chunk_size]
-        req.append(dict([("url", URL + tool),
-                         ('data', data),
-                         ("files", {'file': ' '.join(q)})]))
-    return req
-    response = requests.post(URL + tool, data=data)
-    response.raise_for_status()
-    page = response.text
-    if "The service is temporarily unavailable" in page:
-        exit("The UNIPROT service is temporarily unavailable. Please try again later.")
-    return page
+def get_id_mapping_results_search(url):
+    parsed = urlparse(url)
+    query = parse_qs(parsed.query)
+    file_format = query["format"][0] if "format" in query else "json"
+    if "size" in query:
+        size = int(query["size"][0])
+    else:
+        size = 500
+        query["size"] = size
+    compressed = (
+        query["compressed"][0].lower() == "true" if "compressed" in query else False
+    )
+    parsed = parsed._replace(query=urlencode(query, doseq=True))
+    url = parsed.geturl()
+    request = session.get(url)
+    check_response(request)
+    results = decode_results(request, file_format, compressed)
+    total = int(request.headers["x-total-results"])
+    print_progress_batches(0, size, total)
+    for i, batch in enumerate(get_batch(request, file_format, compressed), 1):
+        results = combine_batches(results, batch, file_format)
+        print_progress_batches(i, size, total)
+    if file_format == "xml":
+        return merge_xml_results(results)
+    return results
 
 
-if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='retrieve uniprot mapping')
-    subparsers = parser.add_subparsers(dest='tool')
+# print(results)
+# {'results': [{'from': 'P05067', 'to': 'CHEMBL2487'}], 'failedIds': ['P12345']}
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="retrieve uniprot mapping")
+    subparsers = parser.add_subparsers(dest="tool")
 
-    mapping = subparsers.add_parser('map')
-    mapping.add_argument('f', help='from')
-    mapping.add_argument('t', help='to')
-    mapping.add_argument('inp', nargs='?', type=argparse.FileType('r'),
-                         default=sys.stdin, help='input file (default: stdin)')
-    mapping.add_argument('out', nargs='?', type=argparse.FileType('w'),
-                         default=sys.stdout, help='output file (default: stdout)')
-    mapping.add_argument('--format', default='tab', help='output format')
+    mapping = subparsers.add_parser("map")
+    mapping.add_argument("f", help="from")
+    mapping.add_argument("t", help="to")
+    mapping.add_argument(
+        "inp",
+        nargs="?",
+        type=argparse.FileType("r"),
+        default=sys.stdin,
+        help="input file (default: stdin)",
+    )
+    mapping.add_argument(
+        "out",
+        nargs="?",
+        type=argparse.FileType("w"),
+        default=sys.stdout,
+        help="output file (default: stdout)",
+    )
+    mapping.add_argument("--format", default="tab", help="output format")
 
-    retrieve = subparsers.add_parser('retrieve')
-    retrieve.add_argument('inp', metavar='in', nargs='?', type=argparse.FileType('r'),
-                          default=sys.stdin, help='input file (default: stdin)')
-    retrieve.add_argument('out', nargs='?', type=argparse.FileType('w'),
-                          default=sys.stdout, help='output file (default: stdout)')
-    retrieve.add_argument('-f', '--format', help='specify output format', default='txt')
+    retrieve = subparsers.add_parser("retrieve")
+    retrieve.add_argument(
+        "inp",
+        metavar="in",
+        nargs="?",
+        type=argparse.FileType("r"),
+        default=sys.stdin,
+        help="input file (default: stdin)",
+    )
+    retrieve.add_argument(
+        "out",
+        nargs="?",
+        type=argparse.FileType("w"),
+        default=sys.stdout,
+        help="output file (default: stdout)",
+    )
+    retrieve.add_argument("-f", "--format", help="specify output format", default="txt")
+    mapping = subparsers.add_parser("menu")
 
     args = parser.parse_args()
 
+    # code for auto generating the from - to conditional
+    if args.tool == "menu":
+        from lxml import etree
+
+        request = session.get("https://rest.uniprot.org/configure/idmapping/fields")
+        check_response(request)
+        fields = request.json()
+
+        tos = dict()
+        from_cond = etree.Element("conditional", name="from_cond")
+        from_select = etree.SubElement(
+            from_cond, "param", name="from", type="select", label="Source database:"
+        )
+
+        rules = dict()
+        for rule in fields["rules"]:
+            rules[rule["ruleId"]] = rule["tos"]
+
+        for group in fields["groups"]:
+            group_name = group["groupName"]
+            group_name = group_name.replace("databases", "DBs")
+            for item in group["items"]:
+                if item["to"]:
+                    tos[item["name"]] = f"{group_name} - {item['displayName']}"
+
+        for group in fields["groups"]:
+            group_name = group["groupName"]
+            group_name = group_name.replace("databases", "DBs")
+            for item in group["items"]:
+                if not item["from"]:
+                    continue
+                option = etree.SubElement(from_select, "option", value=item["name"])
+                option.text = f"{group_name} - {item['displayName']}"
+                when = etree.SubElement(from_cond, "when", value=item["name"])
+
+                to_select = etree.SubElement(
+                    when, "param", name="to", type="select", label="Target database:"
+                )
+                ruleId = item["ruleId"]
+                for to in rules[ruleId]:
+                    option = etree.SubElement(to_select, "option", value=to)
+                    option.text = tos[to]
+        etree.indent(from_cond, space="    ")
+        print(etree.tostring(from_cond, pretty_print=True, encoding="unicode"))
+        sys.exit(0)
+
     # get the IDs from the file as sorted list
     # (sorted is convenient for testing)
     query = set()
@@ -94,15 +268,19 @@
         query.add(line.strip())
     query = sorted(query)
 
-    if args.tool == 'map':
-        pload = _map(query, args.f, args.t, chunk_size=100)
-    elif args.tool == 'retrieve':
-        pload = _map(query, 'ACC+ID', 'ACC', args.format, chunk_size=100)
+    if args.tool == "map":
+        job_id = submit_id_mapping(from_db=args.f, to_db=args.t, ids=query)
+    elif args.tool == "retrieve":
+        job_id = submit_id_mapping(
+            from_db="UniProtKB_AC-ID", to_db="UniProtKB", ids=query
+        )
 
-    adapter = TimeoutHTTPAdapter(max_retries=retry_strategy)
-    http = requests.Session()
-    http.mount("https://", adapter)
-    for i, p in enumerate(pload):
-        response = http.post(**p)
-        args.out.write(response.text)
-    http.close()
+    if check_id_mapping_results_ready(job_id):
+        link = get_id_mapping_results_link(job_id)
+        link = f"{link}?format={args.format}"
+        print(link)
+        results = get_id_mapping_results_search(link)
+
+    if not isinstance(results, str):
+        results = "\n".join(results)
+    args.out.write(f"{results}\n")
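
For orientation, the new code replaces the legacy uploadlists endpoint with UniProt's asynchronous ID-mapping REST API: submit a job to /idmapping/run, poll /idmapping/status/<jobId> until it is no longer RUNNING, read the redirectURL from /idmapping/details/<jobId>, and then page through the results. The standalone sketch below mirrors that flow; the accessions P05067 and P12345 and the tsv format are illustrative examples only, not part of the commit.

    import time
    import requests

    API_URL = "https://rest.uniprot.org"

    # Submit an ID-mapping job; P05067/P12345 are example accessions.
    job = requests.post(
        f"{API_URL}/idmapping/run",
        data={"from": "UniProtKB_AC-ID", "to": "UniProtKB", "ids": "P05067,P12345"},
    )
    job.raise_for_status()
    job_id = job.json()["jobId"]

    # Poll until the job is no longer RUNNING (the script waits POLLING_INTERVAL seconds).
    while True:
        status = requests.get(f"{API_URL}/idmapping/status/{job_id}")
        status.raise_for_status()
        if status.json().get("jobStatus") == "RUNNING":
            time.sleep(3)
        else:
            break

    # Look up the redirect URL for the finished job and fetch the results as TSV.
    details = requests.get(f"{API_URL}/idmapping/details/{job_id}")
    details.raise_for_status()
    results = requests.get(f"{details.json()['redirectURL']}?format=tsv")
    results.raise_for_status()
    print(results.text)

Driven through the script itself, an equivalent call would look roughly like `python uniprot.py map UniProtKB_AC-ID UniProtKB ids.txt mapped.tsv --format tsv`, following the argparse definitions in the hunk above.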