env/lib/python3.7/site-packages/cwltool/workflow.py @ 2:6af9afd405e9 (draft)

"planemo upload commit 0a63dd5f4d38a1f6944587f52a8cd79874177fc1"
author shellac
date Thu, 14 May 2020 14:56:58 -0400 (2020-05-14)
parents 26e78fe6e8c4
children
from __future__ import absolute_import

import copy
import datetime
import functools
import logging
import random
import tempfile
from collections import namedtuple
from typing import (Any, Callable, Dict, Generator, Iterable, List,
                    Mapping, MutableMapping, MutableSequence,
                    Optional, Sequence, Tuple, Union, cast)
from uuid import UUID  # pylint: disable=unused-import

import threading
from ruamel.yaml.comments import CommentedMap
from schema_salad import validate
from schema_salad.sourceline import SourceLine, indent
from six import string_types, iteritems
from six.moves import range
from future.utils import raise_from
from typing_extensions import Text  # pylint: disable=unused-import
# move to a regular typing import when Python 3.3-3.6 is no longer supported

from . import command_line_tool, context, expression, procgenerator
from .command_line_tool import CallbackJob, ExpressionTool
from .job import JobBase
from .builder import content_limit_respected_read
from .checker import can_assign_src_to_sink, static_checker
from .context import LoadingContext  # pylint: disable=unused-import
from .context import RuntimeContext, getdefault
from .errors import WorkflowException
from .load_tool import load_tool
from .loghandler import _logger
from .mutation import MutationManager  # pylint: disable=unused-import
from .pathmapper import adjustDirObjs, get_listing
from .process import Process, get_overrides, shortname, uniquename
from .provenance import ProvenanceProfile
from .software_requirements import (  # pylint: disable=unused-import
    DependenciesConfiguration)
from .stdfsaccess import StdFsAccess
from .utils import DEFAULT_TMP_PREFIX, aslist, json_dumps

WorkflowStateItem = namedtuple('WorkflowStateItem', ['parameter', 'value', 'success'])


def default_make_tool(toolpath_object,  # type: MutableMapping[Text, Any]
                      loadingContext    # type: LoadingContext
                     ):  # type: (...) -> Process
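    """Instantiate a Process object from a CWL document, dispatching on its "class" field."""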
    if not isinstance(toolpath_object, MutableMapping):
        raise WorkflowException(u"Not a dict: '%s'" % toolpath_object)
    if "class" in toolpath_object:
        if toolpath_object["class"] == "CommandLineTool":
            return command_line_tool.CommandLineTool(toolpath_object, loadingContext)
        if toolpath_object["class"] == "ExpressionTool":
            return command_line_tool.ExpressionTool(toolpath_object, loadingContext)
        if toolpath_object["class"] == "Workflow":
            return Workflow(toolpath_object, loadingContext)
        if toolpath_object["class"] == "ProcessGenerator":
            return procgenerator.ProcessGenerator(toolpath_object, loadingContext)

    raise WorkflowException(
        u"Missing or invalid 'class' field in "
        "%s, expecting one of: CommandLineTool, ExpressionTool, Workflow" %
        toolpath_object["id"])


context.default_make_tool = default_make_tool


def findfiles(wo, fn=None):  # type: (Any, Optional[List[MutableMapping[Text, Any]]]) -> List[MutableMapping[Text, Any]]
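    """Recursively collect every "File" object (including secondaryFiles) found in wo."""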
    if fn is None:
        fn = []
    if isinstance(wo, MutableMapping):
        if wo.get("class") == "File":
            fn.append(wo)
            findfiles(wo.get("secondaryFiles", None), fn)
        else:
            for w in wo.values():
                findfiles(w, fn)
    elif isinstance(wo, MutableSequence):
        for w in wo:
            findfiles(w, fn)
    return fn


def match_types(sinktype,   # type: Union[List[Text], Text]
                src,        # type: WorkflowStateItem
                iid,        # type: Text
                inputobj,   # type: Dict[Text, Any]
                linkMerge,  # type: Text
                valueFrom   # type: Optional[Text]
               ):  # type: (...) -> bool
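    """Check whether a source value can feed the given sink type.

    On a match, the source value is stored (or link-merged) into
    inputobj[iid] and True is returned.
    """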
    if isinstance(sinktype, MutableSequence):
        # Sink is union type
        for st in sinktype:
            if match_types(st, src, iid, inputobj, linkMerge, valueFrom):
                return True
    elif isinstance(src.parameter["type"], MutableSequence):
        # Source is union type
        # Check that at least one source type is compatible with the sink.
        original_types = src.parameter["type"]
        for source_type in original_types:
            src.parameter["type"] = source_type
            match = match_types(
                sinktype, src, iid, inputobj, linkMerge, valueFrom)
            if match:
                src.parameter["type"] = original_types
                return True
        src.parameter["type"] = original_types
        return False
    elif linkMerge:
        if iid not in inputobj:
            inputobj[iid] = []
        if linkMerge == "merge_nested":
            inputobj[iid].append(src.value)
        elif linkMerge == "merge_flattened":
            if isinstance(src.value, MutableSequence):
                inputobj[iid].extend(src.value)
            else:
                inputobj[iid].append(src.value)
        else:
            raise WorkflowException(u"Unrecognized linkMerge enum '%s'" % linkMerge)
        return True
    elif valueFrom is not None \
            or can_assign_src_to_sink(src.parameter["type"], sinktype) \
            or sinktype == "Any":
        # simply assign the value from state to input
        inputobj[iid] = copy.deepcopy(src.value)
        return True
    return False


def object_from_state(state,                  # type: Dict[Text, Optional[WorkflowStateItem]]
                      parms,                  # type: List[Dict[Text, Any]]
                      frag_only,              # type: bool
                      supportsMultipleInput,  # type: bool
                      sourceField,            # type: Text
                      incomplete=False        # type: bool
                     ):  # type: (...) -> Optional[Dict[Text, Any]]
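    """Assemble an input object for 'parms' from the current workflow state.

    Returns None when a connected source is not ready yet, unless
    'incomplete' is True.
    """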
    inputobj = {}  # type: Dict[Text, Any]
    for inp in parms:
        iid = inp["id"]
        if frag_only:
            iid = shortname(iid)
        if sourceField in inp:
            connections = aslist(inp[sourceField])
            if (len(connections) > 1 and
                    not supportsMultipleInput):
                raise WorkflowException(
                    "Workflow contains multiple inbound links to a single "
                    "parameter but MultipleInputFeatureRequirement is not "
                    "declared.")
            for src in connections:
                a_state = state.get(src, None)
                if a_state is not None and (a_state.success == "success" or incomplete):
                    if not match_types(
                            inp["type"], a_state, iid, inputobj,
                            inp.get("linkMerge", ("merge_nested"
                                                  if len(connections) > 1 else None)),
                            valueFrom=inp.get("valueFrom")):
                        raise WorkflowException(
                            u"Type mismatch between source '%s' (%s) and "
                            "sink '%s' (%s)" % (src,
                                                a_state.parameter["type"], inp["id"],
                                                inp["type"]))
                elif src not in state:
                    raise WorkflowException(
                        u"Connected source '%s' on parameter '%s' does not "
                        "exist" % (src, inp["id"]))
                elif not incomplete:
                    return None

        if inputobj.get(iid) is None and "default" in inp:
            inputobj[iid] = inp["default"]

        if iid not in inputobj and ("valueFrom" in inp or incomplete):
            inputobj[iid] = None

        if iid not in inputobj:
            raise WorkflowException(u"Value for %s not specified" % (inp["id"]))
    return inputobj


class WorkflowJobStep(object):
    def __init__(self, step):
        # type: (WorkflowStep) -> None
        """Initialize this WorkflowJobStep."""
        self.step = step
        self.tool = step.tool
        self.id = step.id
        self.submitted = False
        self.completed = False
        self.iterable = None  # type: Optional[Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]]
        self.name = uniquename(u"step %s" % shortname(self.id))
        self.prov_obj = step.prov_obj
        self.parent_wf = step.parent_wf

    def job(self,
            joborder,         # type: Mapping[Text, Text]
            output_callback,  # type: functools.partial[None]
            runtimeContext    # type: RuntimeContext
           ):
        # type: (...) -> Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob], None, None]
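        """Yield the jobs of the wrapped step under a step-scoped RuntimeContext."""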
        runtimeContext = runtimeContext.copy()
        runtimeContext.part_of = self.name
        runtimeContext.name = shortname(self.id)

        _logger.info(u"[%s] start", self.name)

        for j in self.step.job(joborder, output_callback, runtimeContext):
            yield j


class WorkflowJob(object):
    def __init__(self, workflow, runtimeContext):
        # type: (Workflow, RuntimeContext) -> None
        """Initialize this WorkflowJob."""
        self.workflow = workflow
        self.prov_obj = None  # type: Optional[ProvenanceProfile]
        self.parent_wf = None  # type: Optional[ProvenanceProfile]
        self.tool = workflow.tool
        if runtimeContext.research_obj is not None:
            self.prov_obj = workflow.provenance_object
            self.parent_wf = workflow.parent_wf
        self.steps = [WorkflowJobStep(s) for s in workflow.steps]
        self.state = {}  # type: Dict[Text, Optional[WorkflowStateItem]]
        self.processStatus = u""
        self.did_callback = False
        self.made_progress = None  # type: Optional[bool]

        if runtimeContext.outdir is not None:
            self.outdir = runtimeContext.outdir
        else:
            self.outdir = tempfile.mkdtemp(
                prefix=getdefault(runtimeContext.tmp_outdir_prefix, DEFAULT_TMP_PREFIX))

        self.name = uniquename(u"workflow {}".format(
            getdefault(runtimeContext.name,
                       shortname(self.workflow.tool.get("id", "embedded")))))

        _logger.debug(
            u"[%s] initialized from %s", self.name,
            self.tool.get("id", "workflow embedded in %s" % runtimeContext.part_of))

    def do_output_callback(self, final_output_callback):
        # type: (Callable[[Any, Any], Any]) -> None
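        """Collect the workflow-level outputs from state and hand them to the callback."""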
        supportsMultipleInput = bool(self.workflow.get_requirement("MultipleInputFeatureRequirement")[0])

        wo = None  # type: Optional[Dict[Text, Text]]
        try:
            wo = object_from_state(
                self.state, self.tool["outputs"], True, supportsMultipleInput,
                "outputSource", incomplete=True)
        except WorkflowException as err:
            _logger.error(
                u"[%s] Cannot collect workflow output: %s", self.name, Text(err))
            self.processStatus = "permanentFail"
        if self.prov_obj and self.parent_wf \
                and self.prov_obj.workflow_run_uri != self.parent_wf.workflow_run_uri:
            process_run_id = None
            self.prov_obj.generate_output_prov(wo or {}, process_run_id, self.name)
            self.prov_obj.document.wasEndedBy(
                self.prov_obj.workflow_run_uri, None, self.prov_obj.engine_uuid,
                datetime.datetime.now())
            prov_ids = self.prov_obj.finalize_prov_profile(self.name)
            # Tell parent to associate our provenance files with our wf run
            self.parent_wf.activity_has_provenance(self.prov_obj.workflow_run_uri, prov_ids)

        _logger.info(u"[%s] completed %s", self.name, self.processStatus)
        if _logger.isEnabledFor(logging.DEBUG):
            _logger.debug(u"[%s] %s", self.name, json_dumps(wo, indent=4))

        self.did_callback = True

        final_output_callback(wo, self.processStatus)

    def receive_output(self, step, outputparms, final_output_callback, jobout, processStatus):
        # type: (WorkflowJobStep, List[Dict[Text, Text]], Callable[[Any, Any], Any], Dict[Text, Text], Text) -> None
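        """Record a completed step's outputs in the workflow state and mark it done."""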
        for i in outputparms:
            if "id" in i:
                if i["id"] in jobout:
                    self.state[i["id"]] = WorkflowStateItem(i, jobout[i["id"]], processStatus)
                else:
                    _logger.error(u"[%s] Output is missing expected field %s", step.name, i["id"])
                    processStatus = "permanentFail"
        if _logger.isEnabledFor(logging.DEBUG):
            _logger.debug(u"[%s] produced output %s", step.name,
                          json_dumps(jobout, indent=4))

        if processStatus != "success":
            if self.processStatus != "permanentFail":
                self.processStatus = processStatus

            _logger.warning(u"[%s] completed %s", step.name, processStatus)
        else:
            _logger.info(u"[%s] completed %s", step.name, processStatus)

        step.completed = True
        # Release the iterable related to this step to reclaim memory.
        step.iterable = None
        self.made_progress = True

        completed = sum(1 for s in self.steps if s.completed)
        if completed == len(self.steps):
            self.do_output_callback(final_output_callback)

    def try_make_job(self,
                     step,                   # type: WorkflowJobStep
                     final_output_callback,  # type: Callable[[Any, Any], Any]
                     runtimeContext          # type: RuntimeContext
                    ):  # type: (...) -> Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]
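        """Yield the job(s) for a step once its inputs are ready, handling scatter and valueFrom evaluation."""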
        inputparms = step.tool["inputs"]
        outputparms = step.tool["outputs"]

        supportsMultipleInput = bool(self.workflow.get_requirement(
            "MultipleInputFeatureRequirement")[0])

        try:
            inputobj = object_from_state(
                self.state, inputparms, False, supportsMultipleInput, "source")
            if inputobj is None:
                _logger.debug(u"[%s] job step %s not ready", self.name, step.id)
                return

            if step.submitted:
                return
            _logger.info(u"[%s] starting %s", self.name, step.name)

            callback = functools.partial(self.receive_output, step, outputparms, final_output_callback)

            valueFrom = {
                i["id"]: i["valueFrom"] for i in step.tool["inputs"]
                if "valueFrom" in i}

            loadContents = set(i["id"] for i in step.tool["inputs"]
                               if i.get("loadContents"))

            if len(valueFrom) > 0 and not bool(self.workflow.get_requirement("StepInputExpressionRequirement")[0]):
                raise WorkflowException(
                    "Workflow step contains valueFrom but StepInputExpressionRequirement not in requirements")

            vfinputs = {shortname(k): v for k, v in iteritems(inputobj)}

            def postScatterEval(io):
                # type: (MutableMapping[Text, Any]) -> Dict[Text, Any]
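                """Evaluate valueFrom expressions (loading file contents where requested) on a post-scatter input object."""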
                shortio = {shortname(k): v for k, v in iteritems(io)}

                fs_access = getdefault(runtimeContext.make_fs_access, StdFsAccess)("")
                for k, v in io.items():
                    if k in loadContents and v.get("contents") is None:
                        with fs_access.open(v["location"], "rb") as f:
                            v["contents"] = content_limit_respected_read(f)

                def valueFromFunc(k, v):  # type: (Any, Any) -> Any
                    if k in valueFrom:
                        adjustDirObjs(v, functools.partial(get_listing,
                                                           fs_access, recursive=True))
                        return expression.do_eval(
                            valueFrom[k], shortio, self.workflow.requirements,
                            None, None, {}, context=v,
                            debug=runtimeContext.debug,
                            js_console=runtimeContext.js_console,
                            timeout=runtimeContext.eval_timeout)
                    return v

                return {k: valueFromFunc(k, v) for k, v in io.items()}

            if "scatter" in step.tool:
                scatter = aslist(step.tool["scatter"])
                method = step.tool.get("scatterMethod")
                if method is None and len(scatter) != 1:
                    raise WorkflowException("Must specify scatterMethod when scattering over multiple inputs")
                runtimeContext = runtimeContext.copy()
                runtimeContext.postScatterEval = postScatterEval

                emptyscatter = [shortname(s) for s in scatter if len(inputobj[s]) == 0]
                if emptyscatter:
                    _logger.warning(
                        "[job %s] Notice: scattering over empty input in "
                        "'%s'. All outputs will be empty.", step.name,
                        "', '".join(emptyscatter))

                if method == "dotproduct" or method is None:
                    jobs = dotproduct_scatter(
                        step, inputobj, scatter, callback, runtimeContext)
                elif method == "nested_crossproduct":
                    jobs = nested_crossproduct_scatter(
                        step, inputobj, scatter, callback, runtimeContext)
                elif method == "flat_crossproduct":
                    jobs = flat_crossproduct_scatter(
                        step, inputobj, scatter, callback, runtimeContext)
            else:
                if _logger.isEnabledFor(logging.DEBUG):
                    _logger.debug(u"[job %s] job input %s", step.name,
                                  json_dumps(inputobj, indent=4))

                inputobj = postScatterEval(inputobj)

                if _logger.isEnabledFor(logging.DEBUG):
                    _logger.debug(u"[job %s] evaluated job input to %s",
                                  step.name, json_dumps(inputobj, indent=4))
                jobs = step.job(inputobj, callback, runtimeContext)

            step.submitted = True

            for j in jobs:
                yield j
        except WorkflowException:
            raise
        except Exception:
            _logger.exception("Unhandled exception")
            self.processStatus = "permanentFail"
            step.completed = True

    def run(self,
            runtimeContext,   # type: RuntimeContext
            tmpdir_lock=None  # type: Optional[threading.Lock]
           ):  # type: (...) -> None
        """Log the start of each workflow."""
        _logger.info(u"[%s] start", self.name)

    def job(self,
            joborder,         # type: Mapping[Text, Any]
            output_callback,  # type: Callable[[Any, Any], Any]
            runtimeContext    # type: RuntimeContext
           ):  # type: (...) -> Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]
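        """Initialize the workflow state from the job order, then repeatedly
        poll the steps, yielding runnable jobs until every step completes."""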
        self.state = {}
        self.processStatus = "success"

        if _logger.isEnabledFor(logging.DEBUG):
            _logger.debug(u"[%s] %s", self.name, json_dumps(joborder, indent=4))

        runtimeContext = runtimeContext.copy()
        runtimeContext.outdir = None

        for index, inp in enumerate(self.tool["inputs"]):
            with SourceLine(self.tool["inputs"], index, WorkflowException,
                            _logger.isEnabledFor(logging.DEBUG)):
                inp_id = shortname(inp["id"])
                if inp_id in joborder:
                    self.state[inp["id"]] = WorkflowStateItem(
                        inp, joborder[inp_id], "success")
                elif "default" in inp:
                    self.state[inp["id"]] = WorkflowStateItem(
                        inp, inp["default"], "success")
                else:
                    raise WorkflowException(
                        u"Input '%s' not in input object and does not have a "
                        "default value." % (inp["id"]))

        for step in self.steps:
            for out in step.tool["outputs"]:
                self.state[out["id"]] = None

        completed = 0
        while completed < len(self.steps):
            self.made_progress = False

            for step in self.steps:
                if getdefault(runtimeContext.on_error, "stop") == "stop" and self.processStatus != "success":
                    break

                if not step.submitted:
                    try:
                        step.iterable = self.try_make_job(
                            step, output_callback, runtimeContext)
                    except WorkflowException as exc:
                        _logger.error(u"[%s] Cannot make job: %s", step.name, Text(exc))
                        _logger.debug("", exc_info=True)
                        self.processStatus = "permanentFail"

                if step.iterable is not None:
                    try:
                        for newjob in step.iterable:
                            if getdefault(runtimeContext.on_error, "stop") == "stop" \
                                    and self.processStatus != "success":
                                break
                            if newjob is not None:
                                self.made_progress = True
                                yield newjob
                            else:
                                break
                    except WorkflowException as exc:
                        _logger.error(u"[%s] Cannot make job: %s", step.name, Text(exc))
                        _logger.debug("", exc_info=True)
                        self.processStatus = "permanentFail"

            completed = sum(1 for s in self.steps if s.completed)

            if not self.made_progress and completed < len(self.steps):
                if self.processStatus != "success":
                    break
                else:
                    yield None

        if not self.did_callback:
            # receive_output may already have called do_output_callback;
            # whichever happens first (all steps completed or all outputs
            # produced) triggers the callback.
            self.do_output_callback(output_callback)


class Workflow(Process):
    def __init__(self,
                 toolpath_object,  # type: MutableMapping[Text, Any]
                 loadingContext    # type: LoadingContext
                ):  # type: (...) -> None
        """Initialize this Workflow."""
        super(Workflow, self).__init__(
            toolpath_object, loadingContext)
        self.provenance_object = None  # type: Optional[ProvenanceProfile]
        if loadingContext.research_obj is not None:
            run_uuid = None  # type: Optional[UUID]
            is_master = not loadingContext.prov_obj  # Not yet set
            if is_master:
                run_uuid = loadingContext.research_obj.ro_uuid

            self.provenance_object = ProvenanceProfile(
                loadingContext.research_obj,
                full_name=loadingContext.cwl_full_name,
                host_provenance=loadingContext.host_provenance,
                user_provenance=loadingContext.user_provenance,
                orcid=loadingContext.orcid,
                run_uuid=run_uuid,
                fsaccess=loadingContext.research_obj.fsaccess)  # inherit RO UUID for master wf run
            # TODO: Is Workflow(..) only called when we are the master workflow?
            self.parent_wf = self.provenance_object

        # FIXME: Won't this overwrite prov_obj for nested workflows?
        loadingContext.prov_obj = self.provenance_object
        loadingContext = loadingContext.copy()
        loadingContext.requirements = self.requirements
        loadingContext.hints = self.hints

        self.steps = []  # type: List[WorkflowStep]
        validation_errors = []
        for index, step in enumerate(self.tool.get("steps", [])):
            try:
                self.steps.append(self.make_workflow_step(step, index, loadingContext,
                                                          loadingContext.prov_obj))
            except validate.ValidationException as vexc:
                if _logger.isEnabledFor(logging.DEBUG):
                    _logger.exception("Validation failed at")
                validation_errors.append(vexc)

        if validation_errors:
            raise validate.ValidationException("\n".join(str(v) for v in validation_errors))

        random.shuffle(self.steps)

        # Statically validate data links instead of doing it at runtime.
        workflow_inputs = self.tool["inputs"]
        workflow_outputs = self.tool["outputs"]

        step_inputs = []  # type: List[Any]
        step_outputs = []  # type: List[Any]
        param_to_step = {}  # type: Dict[Text, Dict[Text, Any]]
        for step in self.steps:
            step_inputs.extend(step.tool["inputs"])
            step_outputs.extend(step.tool["outputs"])
            for s in step.tool["inputs"]:
                param_to_step[s["id"]] = step.tool

        if getdefault(loadingContext.do_validate, True):
            static_checker(workflow_inputs, workflow_outputs, step_inputs, step_outputs, param_to_step)

    def make_workflow_step(self,
                           toolpath_object,         # type: Dict[Text, Any]
                           pos,                     # type: int
                           loadingContext,          # type: LoadingContext
                           parentworkflowProv=None  # type: Optional[ProvenanceProfile]
                          ):  # type: (...) -> WorkflowStep
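        """Create a WorkflowStep for the step at position 'pos'; a hook that subclasses can override."""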
        return WorkflowStep(toolpath_object, pos, loadingContext, parentworkflowProv)

    def job(self,
            job_order,         # type: Mapping[Text, Any]
            output_callbacks,  # type: Callable[[Any, Any], Any]
            runtimeContext     # type: RuntimeContext
           ):  # type: (...) -> Generator[Union[WorkflowJob, ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]
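        """Yield a WorkflowJob for this workflow, followed by the jobs it generates."""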
        builder = self._init_job(job_order, runtimeContext)

        if runtimeContext.research_obj is not None:
            if runtimeContext.toplevel:
                # Record primary-job.json
                runtimeContext.research_obj.fsaccess = runtimeContext.make_fs_access('')
                runtimeContext.research_obj.create_job(builder.job, self.job)

        job = WorkflowJob(self, runtimeContext)
        yield job

        runtimeContext = runtimeContext.copy()
        runtimeContext.part_of = u"workflow %s" % job.name
        runtimeContext.toplevel = False

        for wjob in job.job(builder.job, output_callbacks, runtimeContext):
            yield wjob

    def visit(self, op):  # type: (Callable[[MutableMapping[Text, Any]], Any]) -> None
        op(self.tool)
        for step in self.steps:
            step.visit(op)


class WorkflowStep(Process):
    def __init__(self,
                 toolpath_object,         # type: Dict[Text, Any]
                 pos,                     # type: int
                 loadingContext,          # type: LoadingContext
                 parentworkflowProv=None  # type: Optional[ProvenanceProfile]
                ):  # type: (...) -> None
        """Initialize this WorkflowStep."""
        if "id" in toolpath_object:
            self.id = toolpath_object["id"]
        else:
            self.id = "#step" + Text(pos)

        loadingContext = loadingContext.copy()

        loadingContext.requirements = copy.deepcopy(getdefault(loadingContext.requirements, []))
        assert loadingContext.requirements is not None  # nosec
        loadingContext.requirements.extend(toolpath_object.get("requirements", []))
        loadingContext.requirements.extend(get_overrides(getdefault(loadingContext.overrides_list, []),
                                                         self.id).get("requirements", []))

        hints = copy.deepcopy(getdefault(loadingContext.hints, []))
        hints.extend(toolpath_object.get("hints", []))
        loadingContext.hints = hints

        try:
            if isinstance(toolpath_object["run"], MutableMapping):
                self.embedded_tool = loadingContext.construct_tool_object(
                    toolpath_object["run"], loadingContext)  # type: Process
            else:
                loadingContext.metadata = {}
                self.embedded_tool = load_tool(
                    toolpath_object["run"], loadingContext)
        except validate.ValidationException as vexc:
            if loadingContext.debug:
                _logger.exception("Validation exception")
            raise_from(WorkflowException(
                u"Tool definition %s failed validation:\n%s" %
                (toolpath_object["run"], indent(str(vexc)))), vexc)

        validation_errors = []
        self.tool = toolpath_object = copy.deepcopy(toolpath_object)
        bound = set()
        for stepfield, toolfield in (("in", "inputs"), ("out", "outputs")):
            toolpath_object[toolfield] = []
            for index, step_entry in enumerate(toolpath_object[stepfield]):
                if isinstance(step_entry, string_types):
                    param = CommentedMap()  # type: CommentedMap
                    inputid = step_entry
                else:
                    param = CommentedMap(iteritems(step_entry))
                    inputid = step_entry["id"]

                shortinputid = shortname(inputid)
                found = False
                for tool_entry in self.embedded_tool.tool[toolfield]:
                    frag = shortname(tool_entry["id"])
                    if frag == shortinputid:
                        # If the step has a default for a parameter, we do not
                        # want the tool's default to override it.
                        step_default = None
                        if "default" in param and "default" in tool_entry:
                            step_default = param["default"]
                        param.update(tool_entry)
                        param["_tool_entry"] = tool_entry
                        if step_default is not None:
                            param["default"] = step_default
                        found = True
                        bound.add(frag)
                        break
                if not found:
                    if stepfield == "in":
                        param["type"] = "Any"
                        param["not_connected"] = True
                    else:
                        if isinstance(step_entry, Mapping):
                            step_entry_name = step_entry['id']
                        else:
                            step_entry_name = step_entry
                        validation_errors.append(
                            SourceLine(self.tool["out"], index).makeError(
                                "Workflow step output '%s' does not correspond to"
                                % shortname(step_entry_name))
                            + "\n" + SourceLine(self.embedded_tool.tool, "outputs").makeError(
                                " tool output (expected '%s')" % (
                                    "', '".join(
                                        [shortname(tool_entry["id"]) for tool_entry in
                                         self.embedded_tool.tool['outputs']]))))
                param["id"] = inputid
                param.lc.line = toolpath_object[stepfield].lc.data[index][0]
                param.lc.col = toolpath_object[stepfield].lc.data[index][1]
                param.lc.filename = toolpath_object[stepfield].lc.filename
                toolpath_object[toolfield].append(param)

        missing_values = []
        for _, tool_entry in enumerate(self.embedded_tool.tool["inputs"]):
            if shortname(tool_entry["id"]) not in bound:
                if "null" not in tool_entry["type"] and "default" not in tool_entry:
                    missing_values.append(shortname(tool_entry["id"]))

        if missing_values:
            validation_errors.append(SourceLine(self.tool, "in").makeError(
                "Step is missing required parameter%s '%s'" %
                ("s" if len(missing_values) > 1 else "", "', '".join(missing_values))))

        if validation_errors:
            raise validate.ValidationException("\n".join(validation_errors))

        super(WorkflowStep, self).__init__(toolpath_object, loadingContext)

        if self.embedded_tool.tool["class"] == "Workflow":
            (feature, _) = self.get_requirement("SubworkflowFeatureRequirement")
            if not feature:
                raise WorkflowException(
                    "Workflow contains embedded workflow but "
                    "SubworkflowFeatureRequirement not in requirements")

        if "scatter" in self.tool:
            (feature, _) = self.get_requirement("ScatterFeatureRequirement")
            if not feature:
                raise WorkflowException(
                    "Workflow contains scatter but ScatterFeatureRequirement "
                    "not in requirements")

            inputparms = copy.deepcopy(self.tool["inputs"])
            outputparms = copy.deepcopy(self.tool["outputs"])
            scatter = aslist(self.tool["scatter"])

            method = self.tool.get("scatterMethod")
            if method is None and len(scatter) != 1:
                raise validate.ValidationException(
                    "Must specify scatterMethod when scattering over multiple inputs")

            inp_map = {i["id"]: i for i in inputparms}
            for inp in scatter:
                if inp not in inp_map:
                    raise validate.ValidationException(
                        SourceLine(self.tool, "scatter").makeError(
                            "Scatter parameter '%s' does not correspond to "
                            "an input parameter of this step, expecting '%s'"
                            % (shortname(inp), "', '".join(
                                shortname(k) for k in inp_map.keys()))))

                inp_map[inp]["type"] = {"type": "array", "items": inp_map[inp]["type"]}

            if self.tool.get("scatterMethod") == "nested_crossproduct":
                nesting = len(scatter)
            else:
                nesting = 1

            for _ in range(0, nesting):
                for oparam in outputparms:
                    oparam["type"] = {"type": "array", "items": oparam["type"]}
            self.tool["inputs"] = inputparms
            self.tool["outputs"] = outputparms
        self.prov_obj = None  # type: Optional[ProvenanceProfile]
        if loadingContext.research_obj is not None:
            self.prov_obj = parentworkflowProv
            if self.embedded_tool.tool["class"] == "Workflow":
                self.parent_wf = self.embedded_tool.parent_wf
            else:
                self.parent_wf = self.prov_obj

    def receive_output(self, output_callback, jobout, processStatus):
        # type: (Callable[..., Any], Dict[Text, Text], Text) -> None
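        """Map the embedded tool's output names back to this step's output ids before relaying them."""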
        output = {}
        for i in self.tool["outputs"]:
            field = shortname(i["id"])
            if field in jobout:
                output[i["id"]] = jobout[field]
            else:
                processStatus = "permanentFail"
        output_callback(output, processStatus)

    def job(self,
            job_order,         # type: Mapping[Text, Text]
            output_callbacks,  # type: Callable[[Any, Any], Any]
            runtimeContext,    # type: RuntimeContext
           ):  # type: (...) -> Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob], None, None]
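        """Run the embedded tool against this step's (short-named) inputs, yielding its jobs."""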
        # Initialize sub-workflow as a step in the parent profile.
        if self.embedded_tool.tool["class"] == "Workflow" \
                and runtimeContext.research_obj and self.prov_obj \
                and self.embedded_tool.provenance_object:
            self.embedded_tool.parent_wf = self.prov_obj
            process_name = self.tool["id"].split("#")[1]
            self.prov_obj.start_process(
                process_name, datetime.datetime.now(),
                self.embedded_tool.provenance_object.workflow_run_uri)

        step_input = {}
        for inp in self.tool["inputs"]:
            field = shortname(inp["id"])
            if not inp.get("not_connected"):
                step_input[field] = job_order[inp["id"]]

        try:
            for tool in self.embedded_tool.job(
                    step_input,
                    functools.partial(self.receive_output, output_callbacks),
                    runtimeContext):
                yield tool
        except WorkflowException:
            _logger.error(u"Exception on step '%s'", runtimeContext.name)
            raise
        except Exception as exc:
            _logger.exception("Unexpected exception")
            raise_from(WorkflowException(Text(exc)), exc)

    def visit(self, op):  # type: (Callable[[MutableMapping[Text, Any]], Any]) -> None
        self.embedded_tool.visit(op)


class ReceiveScatterOutput(object):
    def __init__(self,
                 output_callback,  # type: Callable[..., Any]
                 dest,             # type: Dict[Text, List[Optional[Text]]]
                 total             # type: int
                ):  # type: (...) -> None
        """Initialize."""
        self.dest = dest
        self.completed = 0
        self.processStatus = u"success"
        self.total = total
        self.output_callback = output_callback
        self.steps = []  # type: List[Optional[Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]]]

    def receive_scatter_output(self, index, jobout, processStatus):
        # type: (int, Dict[Text, Text], Text) -> None
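        """Record the output of one scatter slot; fire the callback once every slot has reported."""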
        for key, val in jobout.items():
            self.dest[key][index] = val

        # Release the iterable related to this step to reclaim memory.
        if self.steps:
            self.steps[index] = None

        if processStatus != "success":
            if self.processStatus != "permanentFail":
                self.processStatus = processStatus

        self.completed += 1

        if self.completed == self.total:
            self.output_callback(self.dest, self.processStatus)

    def setTotal(self, total, steps):  # type: (int, List[Optional[Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]]]) -> None
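        """Set the expected number of outputs along with their generators, firing the callback if everything has already arrived."""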
        self.total = total
        self.steps = steps
        if self.completed == self.total:
            self.output_callback(self.dest, self.processStatus)


def parallel_steps(steps, rc, runtimeContext):
    # type: (List[Optional[Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]]], ReceiveScatterOutput, RuntimeContext) -> Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]
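    """Multiplex the given job generators, yielding from whichever one can currently make progress."""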
    while rc.completed < rc.total:
        made_progress = False
        for index, step in enumerate(steps):
            if getdefault(runtimeContext.on_error, "stop") == "stop" and rc.processStatus != "success":
                break
            if step is None:
                continue
            try:
                for j in step:
                    if getdefault(runtimeContext.on_error, "stop") == "stop" and rc.processStatus != "success":
                        break
                    if j is not None:
                        made_progress = True
                        yield j
                    else:
                        break
                if made_progress:
                    break
            except WorkflowException as exc:
                _logger.error(u"Cannot make scatter job: %s", Text(exc))
                _logger.debug("", exc_info=True)
                rc.receive_scatter_output(index, {}, "permanentFail")
        if not made_progress and rc.completed < rc.total:
            yield None


def dotproduct_scatter(process,          # type: WorkflowJobStep
                       joborder,         # type: MutableMapping[Text, Any]
                       scatter_keys,     # type: MutableSequence[Text]
                       output_callback,  # type: Callable[..., Any]
                       runtimeContext    # type: RuntimeContext
                      ):  # type: (...) -> Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]
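    """Scatter 'process' over equal-length input arrays, pairing elements index by index."""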
    jobl = None  # type: Optional[int]
    for key in scatter_keys:
        if jobl is None:
            jobl = len(joborder[key])
        elif jobl != len(joborder[key]):
            raise WorkflowException(
                "Length of input arrays must be equal when performing "
                "dotproduct scatter.")
    if jobl is None:
        raise Exception("Impossible codepath")

    output = {}  # type: Dict[Text, List[Optional[Text]]]
    for i in process.tool["outputs"]:
        output[i["id"]] = [None] * jobl

    rc = ReceiveScatterOutput(output_callback, output, jobl)

    steps = []  # type: List[Optional[Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]]]
    for index in range(0, jobl):
        sjobo = copy.copy(joborder)
        for key in scatter_keys:
            sjobo[key] = joborder[key][index]

        if runtimeContext.postScatterEval is not None:
            sjobo = runtimeContext.postScatterEval(sjobo)

        steps.append(process.job(
            sjobo, functools.partial(rc.receive_scatter_output, index),
            runtimeContext))

    rc.setTotal(jobl, steps)
    return parallel_steps(steps, rc, runtimeContext)


def nested_crossproduct_scatter(process,          # type: WorkflowJobStep
                                joborder,         # type: MutableMapping[Text, Any]
                                scatter_keys,     # type: MutableSequence[Text]
                                output_callback,  # type: Callable[..., Any]
                                runtimeContext    # type: RuntimeContext
                               ):  # type: (...) -> Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]
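    """Scatter 'process' over the cross product of the inputs, nesting the output lists one level per scatter key."""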
    scatter_key = scatter_keys[0]
    jobl = len(joborder[scatter_key])
    output = {}  # type: Dict[Text, List[Optional[Text]]]
    for i in process.tool["outputs"]:
        output[i["id"]] = [None] * jobl

    rc = ReceiveScatterOutput(output_callback, output, jobl)

    steps = []  # type: List[Optional[Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]]]
    for index in range(0, jobl):
        sjob = copy.copy(joborder)
        sjob[scatter_key] = joborder[scatter_key][index]

        if len(scatter_keys) == 1:
            if runtimeContext.postScatterEval is not None:
                sjob = runtimeContext.postScatterEval(sjob)
            steps.append(process.job(
                sjob, functools.partial(rc.receive_scatter_output, index),
                runtimeContext))
        else:
            steps.append(nested_crossproduct_scatter(
                process, sjob, scatter_keys[1:],
                functools.partial(rc.receive_scatter_output, index),
                runtimeContext))

    rc.setTotal(jobl, steps)
    return parallel_steps(steps, rc, runtimeContext)


def crossproduct_size(joborder, scatter_keys):
    # type: (MutableMapping[Text, Any], MutableSequence[Text]) -> int
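    """Compute the number of jobs a flat cross product over 'scatter_keys' will produce."""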
    scatter_key = scatter_keys[0]
    if len(scatter_keys) == 1:
        ssum = len(joborder[scatter_key])
    else:
        ssum = 0
        for _ in range(0, len(joborder[scatter_key])):
            ssum += crossproduct_size(joborder, scatter_keys[1:])
    return ssum


def flat_crossproduct_scatter(process,          # type: WorkflowJobStep
                              joborder,         # type: MutableMapping[Text, Any]
                              scatter_keys,     # type: MutableSequence[Text]
                              output_callback,  # type: Callable[..., Any]
                              runtimeContext    # type: RuntimeContext
                             ):  # type: (...) -> Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]
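    """Scatter 'process' over the cross product of the inputs, flattening the outputs into a single list."""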
    output = {}  # type: Dict[Text, List[Optional[Text]]]
    for i in process.tool["outputs"]:
        output[i["id"]] = [None] * crossproduct_size(joborder, scatter_keys)
    callback = ReceiveScatterOutput(output_callback, output, 0)
    (steps, total) = _flat_crossproduct_scatter(
        process, joborder, scatter_keys, callback, 0, runtimeContext)
    callback.setTotal(total, steps)
    return parallel_steps(steps, callback, runtimeContext)


def _flat_crossproduct_scatter(process,        # type: WorkflowJobStep
                               joborder,       # type: MutableMapping[Text, Any]
                               scatter_keys,   # type: MutableSequence[Text]
                               callback,       # type: ReceiveScatterOutput
                               startindex,     # type: int
                               runtimeContext  # type: RuntimeContext
                              ):  # type: (...) -> Tuple[List[Optional[Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]]], int]
    """Inner loop."""
    scatter_key = scatter_keys[0]
    jobl = len(joborder[scatter_key])
    steps = []  # type: List[Optional[Generator[Union[ExpressionTool.ExpressionJob, JobBase, CallbackJob, None], None, None]]]
    put = startindex
    for index in range(0, jobl):
        sjob = copy.copy(joborder)
        sjob[scatter_key] = joborder[scatter_key][index]

        if len(scatter_keys) == 1:
            if runtimeContext.postScatterEval is not None:
                sjob = runtimeContext.postScatterEval(sjob)
            steps.append(process.job(
                sjob, functools.partial(callback.receive_scatter_output, put),
                runtimeContext))
            put += 1
        else:
            (add, _) = _flat_crossproduct_scatter(
                process, sjob, scatter_keys[1:], callback, put, runtimeContext)
            put += len(add)
            steps.extend(add)

    return (steps, put)