comparison env/lib/python3.7/site-packages/cwltool/subgraph.py @ 5:9b1c78e6ba9c draft default tip

"planemo upload commit 6c0a8142489327ece472c84e558c47da711a9142"
author shellac
date Mon, 01 Jun 2020 08:59:25 -0400
parents 79f47841a781
children
comparison
equal deleted inserted replaced
4:79f47841a781 5:9b1c78e6ba9c
1 import copy
2 from .utils import aslist, json_dumps
3 from collections import namedtuple
4 from typing import (Dict, MutableMapping, MutableSequence, Set, Any, Text, Optional, Tuple)
5 from .process import shortname
6 from six import itervalues
7 from six.moves import urllib
8 from .workflow import Workflow
9 from ruamel.yaml.comments import CommentedMap
10
# A vertex of the workflow dependency graph.
#   up   -- ids of the nodes this node depends on (its sources)
#   down -- ids of the nodes that depend on this node
#   type -- one of INPUT, OUTPUT, STEP, or None while the kind is still unknown
Node = namedtuple("Node", ["up", "down", "type"])

# Traversal directions understood by subgraph_visit().
DOWN = "down"
UP = "up"

# Node kinds stored in Node.type.
STEP = "step"
OUTPUT = "output"
INPUT = "input"
17
def subgraph_visit(current,   # type: Text
                   nodes,     # type: MutableMapping[Text, Node]
                   visited,   # type: Set[Text]
                   direction  # type: Text
                   ):  # type: (...) -> None
    """Add to ``visited`` every node reachable from ``current``.

    :param current: id of the node to start from.
    :param nodes: graph, mapping node id to its ``Node`` record.
    :param visited: set of already-seen node ids; updated in place.  Nodes
        already present are not expanded again.
    :param direction: ``DOWN`` to follow each node's ``down`` links, ``UP``
        to follow its ``up`` links.
    :raises ValueError: if ``direction`` is neither ``UP`` nor ``DOWN``.
    :raises KeyError: if a reached node id is missing from ``nodes``.
    """
    if current in visited:
        return

    # Reject bad directions up front, before ``visited`` is modified, rather
    # than failing later on an unbound local.
    if direction not in (UP, DOWN):
        raise ValueError("Unknown traversal direction %r" % (direction,))

    # An explicit stack instead of recursion, so that long dependency chains
    # cannot exhaust the interpreter's recursion limit.
    pending = [current]
    while pending:
        nodeid = pending.pop()
        if nodeid in visited:
            continue
        visited.add(nodeid)

        node = nodes[nodeid]
        pending.extend(node.down if direction == DOWN else node.up)
34
def declare_node(nodes, nodeid, tp):
    # type: (Dict[Text, Node], Text, Optional[Text]) -> Node
    """Make sure ``nodeid`` has an entry in ``nodes`` and return that entry.

    An unknown id gets a fresh ``Node`` with empty ``up``/``down`` lists and
    kind ``tp``.  A known id whose kind is still ``None`` is re-stored with
    kind ``tp`` while keeping its existing ``up``/``down`` lists.  A known id
    that already has a kind is left untouched.
    """
    if nodeid not in nodes:
        nodes[nodeid] = Node([], [], tp)
        return nodes[nodeid]

    existing = nodes[nodeid]
    if existing.type is None:
        # Same link lists, only the kind is filled in.
        nodes[nodeid] = existing._replace(type=tp)
    return nodes[nodeid]
44
def get_subgraph(roots,  # type: MutableSequence[Text]
                 tool    # type: Workflow
                 ):  # type: (...) -> Optional[CommentedMap]
    """Extract the part of a workflow that is connected to ``roots``.

    A dependency graph is built from the workflow's inputs, outputs and
    steps.  Starting from each root, the graph is walked upstream when the
    root is a workflow output and downstream otherwise.  Every selected step
    or output then pulls in the workflow inputs it reads directly; any other
    unselected source feeding a selected step is replaced by a new workflow
    input (the source id with ``/`` in its fragment turned into ``_``).

    :param roots: ids of the nodes (inputs, outputs, steps or step outputs)
        to start from.
    :param tool: the workflow; ``tool.tool`` is its document and
        ``tool.steps`` its step objects.
    :returns: a ``CommentedMap`` with the same keys as ``tool.tool`` in which
        ``steps``, ``inputs`` and ``outputs`` are reduced to the selected
        entries, plus one added input per rewired source.
    :raises Exception: if ``tool`` is not a Workflow, or if a selected step
        cannot be found in ``tool.steps``.
    :raises KeyError: if a root is not a node of the workflow graph.

    NOTE: the ``source`` fields of the selected steps are rewritten in place,
    so the step documents inside ``tool.tool`` are modified by this call.
    """
    if tool.tool["class"] != "Workflow":
        raise Exception("Can only extract subgraph from workflow")

    nodes = {}  # type: Dict[Text, Node]

    for inp in tool.tool["inputs"]:
        declare_node(nodes, inp["id"], INPUT)

    for out in tool.tool["outputs"]:
        declare_node(nodes, out["id"], OUTPUT)
        for i in aslist(out.get("outputSource", [])):
            # source is upstream from output (dependency)
            nodes[out["id"]].up.append(i)
            # output is downstream from source
            declare_node(nodes, i, None)
            nodes[i].down.append(out["id"])

    for st in tool.tool["steps"]:
        step = declare_node(nodes, st["id"], STEP)
        for i in st["in"]:
            if "source" not in i:
                continue
            for src in aslist(i["source"]):
                # source is upstream from step (dependency)
                step.up.append(src)
                # step is downstream from source
                declare_node(nodes, src, None)
                nodes[src].down.append(st["id"])
        for out in st["out"]:
            # A step output may be a bare id or a mapping carrying an "id";
            # a mapping is unhashable and cannot itself be a node key.
            out_id = out["id"] if isinstance(out, MutableMapping) else out
            # output is downstream from step
            step.down.append(out_id)
            # step is upstream from output
            declare_node(nodes, out_id, None)
            nodes[out_id].up.append(st["id"])

    # Find all the nodes connected to the starting points: upstream of
    # workflow outputs, downstream of everything else.
    visited_down = set()  # type: Set[Text]
    for r in roots:
        if nodes[r].type == OUTPUT:
            subgraph_visit(r, nodes, visited_down, UP)
        else:
            subgraph_visit(r, nodes, visited_down, DOWN)

    def find_step(stepid):  # type: (Text) -> Optional[MutableMapping[Text, Any]]
        """Return the tool document of the step with id ``stepid``, if any."""
        for st in tool.steps:
            if st.tool["id"] == stepid:
                return st.tool
        return None

    # Now make sure all the nodes are connected to upstream inputs
    visited = set()  # type: Set[Text]
    rewire = {}  # type: Dict[Text, Tuple[Text, Text]]
    # Sorted so that the inputs appended below come out in a stable order
    # instead of following set iteration order.
    for v in sorted(visited_down):
        visited.add(v)
        if nodes[v].type not in (STEP, OUTPUT):
            continue
        for u in nodes[v].up:
            if u in visited_down:
                continue
            if nodes[u].type == INPUT:
                visited.add(u)
                continue
            if nodes[v].type != STEP:
                continue

            # rewire: the unselected source becomes a new workflow input
            wfstep = find_step(v)
            if wfstep is None:
                raise Exception("Could not find step %s" % v)
            df = urllib.parse.urldefrag(u)
            rn = df[0] + "#" + df[1].replace("/", "_")
            for inp in wfstep["inputs"]:
                # aslist() gives exact id matching for string sources
                # (no substring hits); .get() tolerates inputs that have
                # no source at all.
                if u in aslist(inp.get("source", [])):
                    rewire[u] = (rn, inp["type"])
                    break

    extracted = CommentedMap()
    for f in tool.tool:
        if f not in ("steps", "inputs", "outputs"):
            extracted[f] = tool.tool[f]
            continue

        extracted[f] = []
        for i in tool.tool[f]:
            if i["id"] not in visited:
                continue
            if f == "steps":
                for inport in i["in"]:
                    if "source" not in inport:
                        continue
                    if isinstance(inport["source"], MutableSequence):
                        # Rewired sources get their new id; every other
                        # source is kept unchanged rather than dropped.
                        inport["source"] = [
                            rewire[s][0] if s in rewire else s
                            for s in inport["source"]
                        ]
                    elif inport["source"] in rewire:
                        inport["source"] = rewire[inport["source"]][0]
            extracted[f].append(i)

    for rv in itervalues(rewire):
        extracted["inputs"].append({
            "id": rv[0],
            "type": rv[1]
        })

    return extracted