From 3ab6090fd40033a46a2dba30a09129357406ec2e Mon Sep 17 00:00:00 2001 From: RobinGeens Date: Sun, 27 Oct 2024 14:18:55 +0100 Subject: [PATCH] make layer stack user input more tolerant --- .../generation/layer_stacks_generation.py | 29 +++++++++++++++---- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/stream/stages/generation/layer_stacks_generation.py b/stream/stages/generation/layer_stacks_generation.py index 906f38f..9a1e47a 100644 --- a/stream/stages/generation/layer_stacks_generation.py +++ b/stream/stages/generation/layer_stacks_generation.py @@ -12,6 +12,8 @@ class LayerStacksGenerationStage(Stage): + layer_stacks: list[tuple[int, ...]] | None + def __init__( self, list_of_callables: list[StageCallable], @@ -50,6 +52,9 @@ def run(self): self.layer_stacks = self.get_layer_stacks_fused_multiple_fixed() else: self.layer_stacks = self.get_layer_stacks_fused_single() + else: + self.layer_stacks = self.fill_layer_stacks_to_completion() + elif self.mode == "lbl": self.layer_stacks = self.get_layer_stacks_lbl() else: @@ -69,19 +74,33 @@ def run(self): def only_keep_computation_node_ids(self): """! Update the layer stacks to only keep ids of ComputationNodes""" - updated_layer_stacks = [] + assert self.layer_stacks is not None + updated_layer_stacks: list[tuple[int, ...]] = [] for stack in self.layer_stacks: - update_stack = [] + update_stack: list[tuple[int, ...]] = [] for layer_id in stack: - n = next(n for n in self.workload.node_list if n.id == layer_id) - if isinstance(n, ComputationNode): - update_stack.append(layer_id) + try: + # Ignore node ids that do not exist + n = next(n for n in self.workload.node_list if n.id == layer_id) + if isinstance(n, ComputationNode): + update_stack.append(layer_id) + except StopIteration: + pass updated_layer_stacks.append(tuple(update_stack)) self.layer_stacks = updated_layer_stacks def get_layer_stacks_lbl(self): return [(id,) for id in sorted([n.id for n in self.workload.node_list if isinstance(n, ComputationNode)])] + def fill_layer_stacks_to_completion(self): + assert self.layer_stacks is not None + stacks: list[tuple[int, ...]] = self.layer_stacks + + for node in self.workload.node_list: + if not any(node.id in stack for stack in stacks): + stacks += [(node.id,)] + return stacks + def get_layer_stacks_fused(self): cumsum = 0 stacks = []