Skip to content

Commit

Permalink
Tidy & add type annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
MetRonnie committed Oct 11, 2024
1 parent b6b3cc0 commit 8beab10
Showing 1 changed file with 61 additions and 29 deletions.
90 changes: 61 additions & 29 deletions cylc/flow/taskdef.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,17 @@
"""Task definition."""

from collections import deque
from typing import TYPE_CHECKING, Dict, List
from typing import (
TYPE_CHECKING,
Dict,
List,
NamedTuple,
Set,
Tuple,
)

import cylc.flow.flags
from cylc.flow.exceptions import TaskDefError
import cylc.flow.flags
from cylc.flow.task_id import TaskID
from cylc.flow.task_outputs import (
SORT_ORDERS,
Expand All @@ -31,17 +38,29 @@


if TYPE_CHECKING:
from cylc.flow.cycling import PointBase, SequenceBase
from cylc.flow.task_trigger import Dependency
from cylc.flow.cycling import (
PointBase,
SequenceBase,
)
from cylc.flow.task_trigger import (
Dependency,
TaskTrigger,
)


def generate_graph_children(tdef, point):
class TaskTuple(NamedTuple):
name: str
point: 'PointBase'
is_abs: bool


def generate_graph_children(
tdef: 'TaskDef', point: 'PointBase'
) -> Dict[str, List[TaskTuple]]:
"""Determine graph children of this task at point."""
graph_children = {}
graph_children: Dict[str, List[TaskTuple]] = {}
for seq, dout in tdef.graph_children.items():
for output, downs in dout.items():
if output not in graph_children:
graph_children[output] = []
for name, trigger in downs:
child_point = trigger.get_child_point(point, seq)
is_abs = (
Expand All @@ -55,7 +74,9 @@ def generate_graph_children(tdef, point):
# E.g.: foo should trigger only on T06:
# PT6H = "waz"
# T06 = "waz[-PT6H] => foo"
graph_children[output].append((name, child_point, is_abs))
graph_children.setdefault(output, []).append(
TaskTuple(name, child_point, is_abs)
)

if tdef.sequential:
# Add next-instance child.
Expand All @@ -66,20 +87,21 @@ def generate_graph_children(tdef, point):
# Within sequence bounds.
nexts.append(nxt)
if nexts:
if TASK_OUTPUT_SUCCEEDED not in graph_children:
graph_children[TASK_OUTPUT_SUCCEEDED] = []
graph_children[TASK_OUTPUT_SUCCEEDED].append(
(tdef.name, min(nexts), False))
graph_children.setdefault(TASK_OUTPUT_SUCCEEDED, []).append(
TaskTuple(tdef.name, min(nexts), False)
)

return graph_children


def generate_graph_parents(tdef, point, taskdefs):
def generate_graph_parents(
tdef: 'TaskDef', point: 'PointBase', taskdefs: Dict[str, 'TaskDef']
) -> Dict['SequenceBase', List[TaskTuple]]:
"""Determine concrete graph parents of task tdef at point.
Infer parents be reversing upstream triggers that lead to point/task.
"""
graph_parents = {}
graph_parents: Dict['SequenceBase', List[TaskTuple]] = {}
for seq, triggers in tdef.graph_parents.items():
if not seq.is_valid(point):
# Don't infer parents if the trigger belongs to a sequence that
Expand All @@ -104,7 +126,9 @@ def generate_graph_parents(tdef, point, taskdefs):
# TODO ideally validation would flag this as an error.
continue
is_abs = trigger.offset_is_absolute or trigger.offset_is_from_icp
graph_parents[seq].append((parent_name, parent_point, is_abs))
graph_parents[seq].append(
TaskTuple(parent_name, parent_point, is_abs)
)

if tdef.sequential:
# Add implicit previous-instance parent.
Expand All @@ -115,9 +139,9 @@ def generate_graph_parents(tdef, point, taskdefs):
# Within sequence bounds.
prevs.append(prev)
if prevs:
if seq not in graph_parents:
graph_parents[seq] = []
graph_parents[seq].append((tdef.name, min(prevs), False))
graph_parents.setdefault(seq, []).append(
TaskTuple(tdef.name, min(prevs), False)
)

return graph_parents

Expand Down Expand Up @@ -159,8 +183,12 @@ def __init__(self, name, rtcfg, run_mode, start_point, initial_point):
self.namespace_hierarchy = []
self.dependencies: Dict[SequenceBase, List[Dependency]] = {}
self.outputs = {} # {output: (message, is_required)}
self.graph_children = {}
self.graph_parents = {}
self.graph_children: Dict[
SequenceBase, Dict[str, List[Tuple[str, TaskTrigger]]]
] = {}
self.graph_parents: Dict[
SequenceBase, Set[Tuple[str, TaskTrigger]]
] = {}
self.param_var = {}
self.external_triggers = []
self.xtrig_labels = {} # {sequence: [labels]}
Expand Down Expand Up @@ -211,7 +239,9 @@ def tweak_outputs(self):
]:
self.set_required_output(output, True)

def add_graph_child(self, trigger, taskname, sequence):
def add_graph_child(
self, trigger: 'TaskTrigger', taskname: str, sequence: 'SequenceBase'
) -> None:
"""Record child task instances that depend on my outputs.
{sequence:
{
Expand All @@ -220,18 +250,20 @@ def add_graph_child(self, trigger, taskname, sequence):
}
"""
self.graph_children.setdefault(
sequence, {}).setdefault(
trigger.output, []).append((taskname, trigger))

def add_graph_parent(self, trigger, parent, sequence):
sequence, {}
).setdefault(
trigger.output, []
).append((taskname, trigger))

def add_graph_parent(
self, trigger: 'TaskTrigger', parent: str, sequence: 'SequenceBase'
) -> None:
"""Record task instances that I depend on.
{
sequence: set([(a,t1), (b,t2), ...]) # (task-name, trigger)
}
"""
if sequence not in self.graph_parents:
self.graph_parents[sequence] = set()
self.graph_parents[sequence].add((parent, trigger))
self.graph_parents.setdefault(sequence, set()).add((parent, trigger))

def add_dependency(self, dependency, sequence):
"""Add a dependency to a named sequence.
Expand Down

0 comments on commit 8beab10

Please sign in to comment.