From 322a894ed7e7b7552ee3d55dd43c39c76686d4ba Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Fri, 11 Oct 2024 12:21:28 +0100 Subject: [PATCH 1/2] Tidy & add type annotations --- cylc/flow/taskdef.py | 86 +++++++++++++++++++++++++++++--------------- 1 file changed, 58 insertions(+), 28 deletions(-) diff --git a/cylc/flow/taskdef.py b/cylc/flow/taskdef.py index 448844c8cc3..e0adcaac686 100644 --- a/cylc/flow/taskdef.py +++ b/cylc/flow/taskdef.py @@ -17,10 +17,17 @@ """Task definition.""" from collections import deque -from typing import TYPE_CHECKING +from typing import ( + TYPE_CHECKING, + Dict, + List, + NamedTuple, + Set, + Tuple, +) -import cylc.flow.flags from cylc.flow.exceptions import TaskDefError +import cylc.flow.flags from cylc.flow.task_id import TaskID from cylc.flow.task_outputs import ( TASK_OUTPUT_SUBMITTED, @@ -30,16 +37,26 @@ ) if TYPE_CHECKING: - from cylc.flow.cycling import PointBase + from cylc.flow.cycling import ( + PointBase, + SequenceBase, + ) + from cylc.flow.task_trigger import TaskTrigger -def generate_graph_children(tdef, point): +class TaskTuple(NamedTuple): + name: str + point: 'PointBase' + is_abs: bool + + +def generate_graph_children( + tdef: 'TaskDef', point: 'PointBase' +) -> Dict[str, List[TaskTuple]]: """Determine graph children of this task at point.""" - graph_children = {} + graph_children: Dict[str, List[TaskTuple]] = {} for seq, dout in tdef.graph_children.items(): for output, downs in dout.items(): - if output not in graph_children: - graph_children[output] = [] for name, trigger in downs: child_point = trigger.get_child_point(point, seq) is_abs = ( @@ -53,7 +70,9 @@ def generate_graph_children(tdef, point): # E.g.: foo should trigger only on T06: # PT6H = "waz" # T06 = "waz[-PT6H] => foo" - graph_children[output].append((name, child_point, is_abs)) + graph_children.setdefault(output, []).append( + TaskTuple(name, child_point, is_abs) + ) if tdef.sequential: # Add next-instance child. @@ -64,20 +83,21 @@ def generate_graph_children(tdef, point): # Within sequence bounds. nexts.append(nxt) if nexts: - if TASK_OUTPUT_SUCCEEDED not in graph_children: - graph_children[TASK_OUTPUT_SUCCEEDED] = [] - graph_children[TASK_OUTPUT_SUCCEEDED].append( - (tdef.name, min(nexts), False)) + graph_children.setdefault(TASK_OUTPUT_SUCCEEDED, []).append( + TaskTuple(tdef.name, min(nexts), False) + ) return graph_children -def generate_graph_parents(tdef, point, taskdefs): +def generate_graph_parents( + tdef: 'TaskDef', point: 'PointBase', taskdefs: Dict[str, 'TaskDef'] +) -> Dict['SequenceBase', List[TaskTuple]]: """Determine concrete graph parents of task tdef at point. Infer parents be reversing upstream triggers that lead to point/task. """ - graph_parents = {} + graph_parents: Dict['SequenceBase', List[TaskTuple]] = {} for seq, triggers in tdef.graph_parents.items(): if not seq.is_valid(point): # Don't infer parents if the trigger belongs to a sequence that @@ -102,7 +122,9 @@ def generate_graph_parents(tdef, point, taskdefs): # TODO ideally validation would flag this as an error. continue is_abs = trigger.offset_is_absolute or trigger.offset_is_from_icp - graph_parents[seq].append((parent_name, parent_point, is_abs)) + graph_parents[seq].append( + TaskTuple(parent_name, parent_point, is_abs) + ) if tdef.sequential: # Add implicit previous-instance parent. @@ -113,9 +135,9 @@ def generate_graph_parents(tdef, point, taskdefs): # Within sequence bounds. prevs.append(prev) if prevs: - if seq not in graph_parents: - graph_parents[seq] = [] - graph_parents[seq].append((tdef.name, min(prevs), False)) + graph_parents.setdefault(seq, []).append( + TaskTuple(tdef.name, min(prevs), False) + ) return graph_parents @@ -157,8 +179,12 @@ def __init__(self, name, rtcfg, run_mode, start_point, initial_point): self.namespace_hierarchy = [] self.dependencies = {} self.outputs = {} # {output: (message, is_required)} - self.graph_children = {} - self.graph_parents = {} + self.graph_children: Dict[ + SequenceBase, Dict[str, List[Tuple[str, TaskTrigger]]] + ] = {} + self.graph_parents: Dict[ + SequenceBase, Set[Tuple[str, TaskTrigger]] + ] = {} self.param_var = {} self.external_triggers = [] self.xtrig_labels = {} # {sequence: [labels]} @@ -209,7 +235,9 @@ def tweak_outputs(self): ]: self.set_required_output(output, True) - def add_graph_child(self, trigger, taskname, sequence): + def add_graph_child( + self, trigger: 'TaskTrigger', taskname: str, sequence: 'SequenceBase' + ) -> None: """Record child task instances that depend on my outputs. {sequence: { @@ -218,18 +246,20 @@ def add_graph_child(self, trigger, taskname, sequence): } """ self.graph_children.setdefault( - sequence, {}).setdefault( - trigger.output, []).append((taskname, trigger)) - - def add_graph_parent(self, trigger, parent, sequence): + sequence, {} + ).setdefault( + trigger.output, [] + ).append((taskname, trigger)) + + def add_graph_parent( + self, trigger: 'TaskTrigger', parent: str, sequence: 'SequenceBase' + ) -> None: """Record task instances that I depend on. { sequence: set([(a,t1), (b,t2), ...]) # (task-name, trigger) } """ - if sequence not in self.graph_parents: - self.graph_parents[sequence] = set() - self.graph_parents[sequence].add((parent, trigger)) + self.graph_parents.setdefault(sequence, set()).add((parent, trigger)) def add_dependency(self, dependency, sequence): """Add a dependency to a named sequence. From 2790b5f77e75ef579ec747d13e42c3aff02ecd76 Mon Sep 17 00:00:00 2001 From: Ronnie Dutta <61982285+MetRonnie@users.noreply.github.com> Date: Fri, 11 Oct 2024 13:26:09 +0100 Subject: [PATCH 2/2] Fix possible type error --- cylc/flow/data_store_mgr.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cylc/flow/data_store_mgr.py b/cylc/flow/data_store_mgr.py index bfadadfb55a..3be29bc81a9 100644 --- a/cylc/flow/data_store_mgr.py +++ b/cylc/flow/data_store_mgr.py @@ -947,7 +947,7 @@ def increment_graph_window( ) for items in graph_children.values(): for child_name, child_point, _ in items: - if child_point > final_point: + if final_point and child_point > final_point: continue child_tokens = self.id_.duplicate( cycle=str(child_point), @@ -977,7 +977,7 @@ def increment_graph_window( taskdefs ).values(): for parent_name, parent_point, _ in items: - if parent_point > final_point: + if final_point and parent_point > final_point: continue parent_tokens = self.id_.duplicate( cycle=str(parent_point),