diff --git a/changes.d/6370.feat.md b/changes.d/6370.feat.md
new file mode 100644
index 00000000000..aa8d80d6140
--- /dev/null
+++ b/changes.d/6370.feat.md
@@ -0,0 +1,4 @@
+`cylc remove` improvements:
+- It can now remove tasks that are no longer active, making it look like they never ran.
+- Added the `--flow` option.
+- Removed tasks are now demoted to `flow=none`.
diff --git a/cylc/flow/command_validation.py b/cylc/flow/command_validation.py
index eb5e45c4734..7ba2dfb4f42 100644
--- a/cylc/flow/command_validation.py
+++ b/cylc/flow/command_validation.py
@@ -24,19 +24,34 @@
)
from cylc.flow.exceptions import InputError
-from cylc.flow.id import IDTokens, Tokens
+from cylc.flow.flow_mgr import (
+ FLOW_ALL,
+ FLOW_NEW,
+ FLOW_NONE,
+)
+from cylc.flow.id import (
+ IDTokens,
+ Tokens,
+)
from cylc.flow.task_outputs import TASK_OUTPUT_SUCCEEDED
-from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE
-ERR_OPT_FLOW_VAL = "Flow values must be an integer, or 'all', 'new', or 'none'"
+ERR_OPT_FLOW_VAL = (
+ f"Flow values must be an integer, or '{FLOW_ALL}', '{FLOW_NEW}', "
+ f"or '{FLOW_NONE}'"
+)
+ERR_OPT_FLOW_VAL_2 = f"Flow values must be an integer, or '{FLOW_ALL}'"
ERR_OPT_FLOW_COMBINE = "Cannot combine --flow={0} with other flow values"
ERR_OPT_FLOW_WAIT = (
f"--wait is not compatible with --flow={FLOW_NEW} or --flow={FLOW_NONE}"
)
-def flow_opts(flows: List[str], flow_wait: bool) -> None:
+def flow_opts(
+ flows: List[str],
+ flow_wait: bool,
+ allow_new_or_none: bool = True
+) -> None:
"""Check validity of flow-related CLI options.
Note the schema defaults flows to [].
@@ -63,6 +78,10 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None:
cylc.flow.exceptions.InputError: --wait is not compatible with
--flow=new or --flow=none
+ >>> flow_opts(["new"], False, allow_new_or_none=False)
+ Traceback (most recent call last):
+ cylc.flow.exceptions.InputError: ... must be an integer, or 'all'
+
"""
if not flows:
return
@@ -70,9 +89,12 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None:
flows = [val.strip() for val in flows]
for val in flows:
+ val = val.strip()
if val in {FLOW_NONE, FLOW_NEW, FLOW_ALL}:
if len(flows) != 1:
raise InputError(ERR_OPT_FLOW_COMBINE.format(val))
+ if not allow_new_or_none and val in {FLOW_NEW, FLOW_NONE}:
+ raise InputError(ERR_OPT_FLOW_VAL_2)
else:
try:
int(val)
diff --git a/cylc/flow/commands.py b/cylc/flow/commands.py
index b8b777e0957..0b3dbc7b8c8 100644
--- a/cylc/flow/commands.py
+++ b/cylc/flow/commands.py
@@ -53,18 +53,24 @@
"""
from contextlib import suppress
-from time import sleep, time
+from time import (
+ sleep,
+ time,
+)
from typing import (
+ TYPE_CHECKING,
AsyncGenerator,
Callable,
Dict,
Iterable,
List,
Optional,
- TYPE_CHECKING,
+ TypeVar,
Union,
)
+from metomi.isodatetime.parsers import TimePointParser
+
from cylc.flow import LOG
import cylc.flow.command_validation as validate
from cylc.flow.exceptions import (
@@ -73,23 +79,28 @@
CylcConfigError,
)
import cylc.flow.flags
+from cylc.flow.flow_mgr import get_flow_nums_set
from cylc.flow.log_level import log_level_to_verbosity
from cylc.flow.network.schema import WorkflowStopMode
from cylc.flow.parsec.exceptions import ParsecError
from cylc.flow.task_id import TaskID
-from cylc.flow.task_state import TASK_STATUSES_ACTIVE, TASK_STATUS_FAILED
-from cylc.flow.workflow_status import RunMode, StopMode
+from cylc.flow.task_state import (
+ TASK_STATUS_FAILED,
+ TASK_STATUSES_ACTIVE,
+)
+from cylc.flow.workflow_status import (
+ RunMode,
+ StopMode,
+)
-from metomi.isodatetime.parsers import TimePointParser
if TYPE_CHECKING:
from cylc.flow.scheduler import Scheduler
# define a type for command implementations
- Command = Callable[
- ...,
- AsyncGenerator,
- ]
+ Command = Callable[..., AsyncGenerator]
+ # define a generic type needed for the @_command decorator
+ _TCommand = TypeVar('_TCommand', bound=Command)
# a directory of registered commands (populated on module import)
COMMANDS: 'Dict[str, Command]' = {}
@@ -97,15 +108,15 @@
def _command(name: str):
"""Decorator to register a command."""
- def _command(fcn: 'Command'):
+ def _command(fcn: '_TCommand') -> '_TCommand':
nonlocal name
COMMANDS[name] = fcn
- fcn.command_name = name # type: ignore
+ fcn.command_name = name # type: ignore[attr-defined]
return fcn
return _command
-async def run_cmd(fcn, *args, **kwargs):
+async def run_cmd(bound_fcn: AsyncGenerator):
"""Run a command outside of the scheduler's main loop.
Normally commands are run via the Scheduler's command_queue (which is
@@ -120,10 +131,9 @@ async def run_cmd(fcn, *args, **kwargs):
For these purposes use "run_cmd", otherwise, queue commands via the
scheduler as normal.
"""
- cmd = fcn(*args, **kwargs)
- await cmd.__anext__() # validate
+ await bound_fcn.__anext__() # validate
with suppress(StopAsyncIteration):
- return await cmd.__anext__() # run
+ return await bound_fcn.__anext__() # run
@_command('set')
@@ -314,11 +324,15 @@ async def set_verbosity(schd: 'Scheduler', level: Union[int, str]):
@_command('remove_tasks')
-async def remove_tasks(schd: 'Scheduler', tasks: Iterable[str]):
+async def remove_tasks(
+ schd: 'Scheduler', tasks: Iterable[str], flow: List[str]
+):
"""Remove tasks."""
validate.is_tasks(tasks)
+ validate.flow_opts(flow, flow_wait=False, allow_new_or_none=False)
yield
- yield schd.pool.remove_tasks(tasks)
+ flow_nums = get_flow_nums_set(flow)
+ schd.pool.remove_tasks(tasks, flow_nums)
@_command('reload_workflow')
diff --git a/cylc/flow/data_store_mgr.py b/cylc/flow/data_store_mgr.py
index a4b28d44fdc..86f050f0299 100644
--- a/cylc/flow/data_store_mgr.py
+++ b/cylc/flow/data_store_mgr.py
@@ -2369,22 +2369,42 @@ def delta_task_queued(self, itask: TaskProxy) -> None:
self.updates_pending = True
def delta_task_flow_nums(self, itask: TaskProxy) -> None:
- """Create delta for change in task proxy flow_nums.
+ """Create delta for change in task proxy flow numbers.
Args:
- itask (cylc.flow.task_proxy.TaskProxy):
- Update task-node from corresponding task proxy
- objects from the workflow task pool.
-
+ itask: TaskProxy with updated flow numbers.
"""
tproxy: Optional[PbTaskProxy]
tp_id, tproxy = self.store_node_fetcher(itask.tokens)
if not tproxy:
return
- tp_delta = self.updated[TASK_PROXIES].setdefault(
- tp_id, PbTaskProxy(id=tp_id))
+ self._delta_task_flow_nums(tp_id, itask.flow_nums)
+
+ def delta_remove_task_flow_nums(
+ self, task: str, removed: 'FlowNums'
+ ) -> None:
+ """Create delta for removal of flow numbers from a task proxy.
+
+ Args:
+ task: Relative ID of task.
+ removed: Flow numbers to remove from the task proxy in the
+ data store.
+ """
+ tproxy: Optional[PbTaskProxy]
+ tp_id, tproxy = self.store_node_fetcher(
+ Tokens(task, relative=True).duplicate(**self.id_)
+ )
+ if not tproxy:
+ return
+ new_flow_nums = deserialise_set(tproxy.flow_nums).difference(removed)
+ self._delta_task_flow_nums(tp_id, new_flow_nums)
+
+ def _delta_task_flow_nums(self, tp_id: str, flow_nums: 'FlowNums') -> None:
+ tp_delta: PbTaskProxy = self.updated[TASK_PROXIES].setdefault(
+ tp_id, PbTaskProxy(id=tp_id)
+ )
tp_delta.stamp = f'{tp_id}@{time()}'
- tp_delta.flow_nums = serialise_set(itask.flow_nums)
+ tp_delta.flow_nums = serialise_set(flow_nums)
self.updates_pending = True
def delta_task_runahead(self, itask: TaskProxy) -> None:
diff --git a/cylc/flow/dbstatecheck.py b/cylc/flow/dbstatecheck.py
index fc2d9cf0da3..3fbad5c6723 100644
--- a/cylc/flow/dbstatecheck.py
+++ b/cylc/flow/dbstatecheck.py
@@ -28,7 +28,7 @@
IntegerPoint,
IntegerInterval
)
-from cylc.flow.flow_mgr import stringify_flow_nums
+from cylc.flow.flow_mgr import repr_flow_nums
from cylc.flow.pathutil import expand_path
from cylc.flow.rundb import CylcWorkflowDAO
from cylc.flow.task_outputs import (
@@ -318,7 +318,7 @@ def workflow_state_query(
if flow_num is not None and flow_num not in flow_nums:
# skip result, wrong flow
continue
- fstr = stringify_flow_nums(flow_nums)
+ fstr = repr_flow_nums(flow_nums)
if fstr:
res.append(fstr)
db_res.append(res)
diff --git a/cylc/flow/flow_mgr.py b/cylc/flow/flow_mgr.py
index 1cd1c1e8c70..67f816982ec 100644
--- a/cylc/flow/flow_mgr.py
+++ b/cylc/flow/flow_mgr.py
@@ -16,8 +16,15 @@
"""Manage flow counter and flow metadata."""
-from typing import Dict, Set, Optional, TYPE_CHECKING
import datetime
+from typing import (
+ TYPE_CHECKING,
+ Dict,
+ Iterable,
+ List,
+ Optional,
+ Set,
+)
from cylc.flow import LOG
@@ -55,36 +62,62 @@ def add_flow_opts(parser):
)
-def stringify_flow_nums(flow_nums: Set[int], full: bool = False) -> str:
- """Return a string representation of a set of flow numbers
+def get_flow_nums_set(flow: List[str]) -> FlowNums:
+ """Return set of integer flow numbers from list of strings.
- Return:
- - "none" for no flow
- - "" for the original flow (flows only matter if there are several)
- - otherwise e.g. "(flow=1,2,3)"
+ Returns an empty set if the input is empty or contains only "all".
+
+ >>> get_flow_nums_set(["1", "2", "3"])
+ {1, 2, 3}
+ >>> get_flow_nums_set([])
+ set()
+ >>> get_flow_nums_set(["all"])
+ set()
+ """
+ if flow == [FLOW_ALL]:
+ return set()
+ return {int(val.strip()) for val in flow}
+
+
+def stringify_flow_nums(flow_nums: Iterable[int]) -> str:
+ """Return the canonical string for a set of flow numbers.
Examples:
+ >>> stringify_flow_nums({1})
+ '1'
+
+ >>> stringify_flow_nums({3, 1, 2})
+ '1,2,3'
+
>>> stringify_flow_nums({})
+ ''
+
+ """
+ return ','.join(str(i) for i in sorted(flow_nums))
+
+
+def repr_flow_nums(flow_nums: FlowNums, full: bool = False) -> str:
+ """Return a representation of a set of flow numbers
+
+ If `full` is False, return an empty string for flows=1.
+
+ Examples:
+ >>> repr_flow_nums({})
'(flows=none)'
- >>> stringify_flow_nums({1})
+ >>> repr_flow_nums({1})
''
- >>> stringify_flow_nums({1}, True)
+ >>> repr_flow_nums({1}, full=True)
'(flows=1)'
- >>> stringify_flow_nums({1,2,3})
+ >>> repr_flow_nums({1,2,3})
'(flows=1,2,3)'
"""
if not full and flow_nums == {1}:
return ""
- else:
- return (
- "(flows="
- f"{','.join(str(i) for i in flow_nums) or 'none'}"
- ")"
- )
+ return f"(flows={stringify_flow_nums(flow_nums) or 'none'})"
class FlowMgr:
diff --git a/cylc/flow/id.py b/cylc/flow/id.py
index f2c8b05b4a1..68c62e3a118 100644
--- a/cylc/flow/id.py
+++ b/cylc/flow/id.py
@@ -22,6 +22,7 @@
from enum import Enum
import re
from typing import (
+ TYPE_CHECKING,
Iterable,
List,
Optional,
@@ -33,6 +34,10 @@
from cylc.flow import LOG
+if TYPE_CHECKING:
+ from cylc.flow.cycling import PointBase
+
+
class IDTokens(Enum):
"""Cylc object identifier tokens."""
@@ -524,14 +529,14 @@ def duplicate(
)
-def quick_relative_detokenise(cycle, task):
+def quick_relative_id(cycle: Union[str, int, 'PointBase'], task: str) -> str:
"""Generate a relative ID for a task.
This is a more efficient solution to `Tokens` for cases where
you only want the ID string and don't have any use for a Tokens object.
Example:
- >>> q = quick_relative_detokenise
+ >>> q = quick_relative_id
>>> q('1', 'a') == Tokens(cycle='1', task='a').relative_id
True
diff --git a/cylc/flow/network/schema.py b/cylc/flow/network/schema.py
index ab34def7f75..a1b9fc1c50c 100644
--- a/cylc/flow/network/schema.py
+++ b/cylc/flow/network/schema.py
@@ -22,8 +22,8 @@
from operator import attrgetter
from typing import (
TYPE_CHECKING,
- AsyncGenerator,
Any,
+ AsyncGenerator,
Dict,
List,
Optional,
@@ -34,44 +34,68 @@
import graphene
from graphene import (
- Boolean, Field, Float, ID, InputObjectType, Int,
- Mutation, ObjectType, Schema, String, Argument, Interface
+ ID,
+ Argument,
+ Boolean,
+ Field,
+ Float,
+ InputObjectType,
+ Int,
+ Interface,
+ Mutation,
+ ObjectType,
+ Schema,
+ String,
)
from graphene.types.generic import GenericScalar
from graphene.utils.str_converters import to_snake_case
from graphql.type.definition import get_named_type
from cylc.flow import LOG_LEVELS
-from cylc.flow.broadcast_mgr import ALL_CYCLE_POINTS_STRS, addict
+from cylc.flow.broadcast_mgr import (
+ ALL_CYCLE_POINTS_STRS,
+ addict,
+)
from cylc.flow.data_store_mgr import (
- FAMILIES, FAMILY_PROXIES, JOBS, TASKS, TASK_PROXIES,
- DELTA_ADDED, DELTA_UPDATED
+ DELTA_ADDED,
+ DELTA_UPDATED,
+ FAMILIES,
+ FAMILY_PROXIES,
+ JOBS,
+ TASK_PROXIES,
+ TASKS,
+)
+from cylc.flow.flow_mgr import (
+ FLOW_ALL,
+ FLOW_NEW,
+ FLOW_NONE,
)
-from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE
from cylc.flow.id import Tokens
from cylc.flow.task_outputs import SORT_ORDERS
from cylc.flow.task_state import (
- TASK_STATUSES_ORDERED,
TASK_STATUS_DESC,
- TASK_STATUS_WAITING,
TASK_STATUS_EXPIRED,
+ TASK_STATUS_FAILED,
TASK_STATUS_PREPARING,
+ TASK_STATUS_RUNNING,
TASK_STATUS_SUBMIT_FAILED,
TASK_STATUS_SUBMITTED,
- TASK_STATUS_RUNNING,
- TASK_STATUS_FAILED,
- TASK_STATUS_SUCCEEDED
+ TASK_STATUS_SUCCEEDED,
+ TASK_STATUS_WAITING,
+ TASK_STATUSES_ORDERED,
)
from cylc.flow.util import sstrip
from cylc.flow.workflow_status import StopMode
+
if TYPE_CHECKING:
from graphql import ResolveInfo
from graphql.type.definition import (
- GraphQLNamedType,
GraphQLList,
+ GraphQLNamedType,
GraphQLNonNull,
)
+
from cylc.flow.network.resolvers import BaseResolvers
@@ -2108,6 +2132,19 @@ class Meta:
''')
resolver = partial(mutator, command='remove_tasks')
+ class Arguments(TaskMutation.Arguments):
+ flow = graphene.List(
+ graphene.NonNull(Flow),
+ default_value=[FLOW_ALL],
+ description=sstrip(f'''
+ "Remove the task(s) from the specified flows. "
+
+ This should be a list of flow numbers, or '{FLOW_ALL}'
+ to remove the task(s) from all flows they belong to
+ (which is the default).
+ ''')
+ )
+
class SetPrereqsAndOutputs(Mutation, TaskMutation):
class Meta:
diff --git a/cylc/flow/prerequisite.py b/cylc/flow/prerequisite.py
index 486c7e84ab3..b4025a53367 100644
--- a/cylc/flow/prerequisite.py
+++ b/cylc/flow/prerequisite.py
@@ -39,7 +39,7 @@
from cylc.flow.cycling.loader import get_point
from cylc.flow.data_messages_pb2 import PbCondition, PbPrerequisite
from cylc.flow.exceptions import TriggerExpressionError
-from cylc.flow.id import quick_relative_detokenise
+from cylc.flow.id import quick_relative_id
if TYPE_CHECKING:
@@ -58,7 +58,7 @@ class PrereqMessage(NamedTuple):
def get_id(self) -> str:
"""Get the relative ID of the task in this prereq message."""
- return quick_relative_detokenise(self.point, self.task)
+ return quick_relative_id(self.point, self.task)
@staticmethod
def coerce(tuple_: AnyPrereqMessage) -> 'PrereqMessage':
@@ -92,7 +92,7 @@ class Prerequisite:
# Memory optimization - constrain possible attributes to this list.
__slots__ = (
"_satisfied",
- "_all_satisfied",
+ "_cached_satisfied",
"conditional_expression",
"point",
)
@@ -118,7 +118,7 @@ def __init__(self, point: 'PointBase'):
# * `None` (no cached state)
# * `True` (prerequisite satisfied)
# * `False` (prerequisite unsatisfied).
- self._all_satisfied: Optional[bool] = None
+ self._cached_satisfied: Optional[bool] = None
def instantaneous_hash(self) -> int:
"""Generate a hash of this prerequisite in its current state.
@@ -159,9 +159,9 @@ def __setitem__(
if value is True:
value = 'satisfied naturally'
self._satisfied[key] = value
- if not (self._all_satisfied and value):
+ if not (self._cached_satisfied and value):
# Force later recalculation of cached satisfaction state:
- self._all_satisfied = None
+ self._cached_satisfied = None
def __iter__(self) -> Iterator[PrereqMessage]:
return iter(self._satisfied)
@@ -185,7 +185,7 @@ def get_raw_conditional_expression(self):
def set_condition(self, expr):
"""Set the conditional expression for this prerequisite.
- Resets the cached state (self._all_satisfied).
+ Resets the cached state (self._cached_satisfied).
Examples:
# GH #3644 construct conditional expression when one task name
@@ -201,7 +201,7 @@ def set_condition(self, expr):
'bool(self._satisfied[("1", "xfoo", "succeeded")])']
"""
- self._all_satisfied = None
+ self._cached_satisfied = None
if '|' in expr:
# Make a Python expression so we can eval() the logic.
for message in self._satisfied:
@@ -221,14 +221,14 @@ def is_satisfied(self):
Return cached state if present, else evaluate the prerequisite.
"""
- if self._all_satisfied is not None:
+ if self._cached_satisfied is not None:
# Cached value.
- return self._all_satisfied
+ return self._cached_satisfied
if self._satisfied == {}:
# No prerequisites left after pre-initial simplification.
return True
- self._all_satisfied = self._eval_satisfied()
- return self._all_satisfied
+ self._cached_satisfied = self._eval_satisfied()
+ return self._cached_satisfied
def _eval_satisfied(self) -> bool:
"""Evaluate the prerequisite's condition expression.
@@ -253,12 +253,17 @@ def _eval_satisfied(self) -> bool:
) from None
return res
- def satisfy_me(self, outputs: Iterable['Tokens']) -> 'Set[Tokens]':
- """Attempt to satisfy me with given outputs.
+ def satisfy_me(
+ self, outputs: Iterable['Tokens'], forced: bool = False
+ ) -> 'Set[Tokens]':
+ """Set the given outputs as satisfied.
- Updates cache with the result.
Return outputs that match.
+ Args:
+ outputs: List of outputs to satisfy.
+ forced: If True, records that this should not be undone by
+ `cylc remove`.
"""
valid = set()
for output in outputs:
@@ -268,7 +273,10 @@ def satisfy_me(self, outputs: Iterable['Tokens']) -> 'Set[Tokens]':
if prereq not in self._satisfied:
continue
valid.add(output)
- self[prereq] = 'satisfied naturally'
+ if self._satisfied[prereq] != 'satisfied naturally':
+ self[prereq] = (
+ 'force satisfied' if forced else 'satisfied naturally'
+ )
return valid
def api_dump(self) -> Optional[PbPrerequisite]:
@@ -318,9 +326,9 @@ def set_satisfied(self) -> None:
if not self._satisfied[message]:
self._satisfied[message] = 'force satisfied'
if self.conditional_expression:
- self._all_satisfied = self._eval_satisfied()
+ self._cached_satisfied = self._eval_satisfied()
else:
- self._all_satisfied = True
+ self._cached_satisfied = True
def iter_target_point_strings(self):
yield from {
@@ -334,7 +342,7 @@ def get_target_points(self):
get_point(p) for p in self.iter_target_point_strings()
]
- def get_resolved_dependencies(self) -> List[str]:
+ def get_satisfied_dependencies(self) -> List[str]:
"""Return a list of satisfied dependencies.
E.G: ['1/foo', '2/bar']
@@ -345,3 +353,13 @@ def get_resolved_dependencies(self) -> List[str]:
for msg, satisfied in self._satisfied.items()
if satisfied
]
+
+ def unset_naturally_satisfied_dependency(self, id_: str) -> bool:
+ """Set the matching dependency to unsatisfied and return True only if
+ it was naturally satisfied."""
+ changed = False
+ for msg, sat in self._satisfied.items():
+ if msg.get_id() == id_ and sat and sat != 'force satisfied':
+ self[msg] = False
+ changed = True
+ return changed
diff --git a/cylc/flow/rundb.py b/cylc/flow/rundb.py
index c715079491b..315098c390d 100644
--- a/cylc/flow/rundb.py
+++ b/cylc/flow/rundb.py
@@ -15,6 +15,7 @@
# along with this program. If not, see .
"""Provide data access object for the workflow runtime database."""
+from collections import defaultdict
from contextlib import suppress
from dataclasses import dataclass
from os.path import expandvars
@@ -23,6 +24,8 @@
import traceback
from typing import (
TYPE_CHECKING,
+ Any,
+ DefaultDict,
Dict,
Iterable,
List,
@@ -30,11 +33,13 @@
Set,
Tuple,
Union,
+ cast,
)
from cylc.flow import LOG
from cylc.flow.exceptions import PlatformLookupError
import cylc.flow.flags
+from cylc.flow.flow_mgr import stringify_flow_nums
from cylc.flow.util import (
deserialise_set,
serialise_set,
@@ -47,6 +52,13 @@
from cylc.flow.flow_mgr import FlowNums
+DbArgDict = Dict[str, Any]
+DbUpdateTuple = Union[
+ Tuple[DbArgDict, DbArgDict],
+ Tuple[str, list]
+]
+
+
@dataclass
class CylcWorkflowDAOTableColumn:
"""Represent a column in a table."""
@@ -69,7 +81,7 @@ class CylcWorkflowDAOTable:
def __init__(self, name, column_items):
self.name = name
- self.columns = []
+ self.columns: List[CylcWorkflowDAOTableColumn] = []
for column_item in column_items:
name = column_item[0]
attrs = {}
@@ -81,7 +93,7 @@ def __init__(self, name, column_items):
attrs.get("is_primary_key", False)))
self.delete_queues = {}
self.insert_queue = []
- self.update_queues = {}
+ self.update_queues: DefaultDict[str, list] = defaultdict(list)
def get_create_stmt(self):
"""Return an SQL statement to create this table."""
@@ -150,14 +162,23 @@ def add_insert_item(self, args):
args.get(column.name, None) for column in self.columns]
self.insert_queue.append(stmt_args)
- def add_update_item(self, set_args, where_args):
+ def add_update_item(self, item: DbUpdateTuple) -> None:
"""Queue an UPDATE item.
+ If stmt is not a string, it should be a tuple (set_args, where_args) -
set_args should be a dict, with column keys and values to be set.
where_args should be a dict, update will only apply to rows matching
all these items.
"""
+ if isinstance(item[0], str):
+ stmt = item[0]
+ params = cast('list', item[1])
+ self.update_queues[stmt].extend(params)
+ return
+
+ set_args = item[0]
+ where_args = cast('DbArgDict', item[1])
set_strs = []
stmt_args = []
for column in self.columns:
@@ -177,9 +198,8 @@ def add_update_item(self, set_args, where_args):
stmt = self.FMT_UPDATE % {
"name": self.name,
"set_str": set_str,
- "where_str": where_str}
- if stmt not in self.update_queues:
- self.update_queues[stmt] = []
+ "where_str": where_str
+ }
self.update_queues[stmt].append(stmt_args)
@@ -407,15 +427,18 @@ def add_insert_item(self, table_name, args):
"""
self.tables[table_name].add_insert_item(args)
- def add_update_item(self, table_name, set_args, where_args=None):
+ def add_update_item(
+ self, table_name: str, item: DbUpdateTuple
+ ) -> None:
"""Queue an UPDATE item for a given table.
+ If stmt is not a string, it should be a tuple (set_args, where_args) -
set_args should be a dict, with column keys and values to be set.
where_args should be a dict, update will only apply to rows matching
all these items.
"""
- self.tables[table_name].add_update_item(set_args, where_args)
+ self.tables[table_name].add_update_item(item)
def close(self) -> None:
"""Explicitly close the connection."""
@@ -580,10 +603,10 @@ def select_workflow_params(self) -> Iterable[Tuple[str, Optional[str]]]:
key, value
FROM
{self.TABLE_WORKFLOW_PARAMS}
- ''' # nosec (table name is code constant)
+ ''' # nosec B608 (table name is code constant)
return self.connect().execute(stmt)
- def select_workflow_flows(self, flow_nums):
+ def select_workflow_flows(self, flow_nums: Iterable[int]):
"""Return flow data for selected flows."""
stmt = rf'''
SELECT
@@ -591,8 +614,8 @@ def select_workflow_flows(self, flow_nums):
FROM
{self.TABLE_WORKFLOW_FLOWS}
WHERE
- flow_num in ({','.join(str(f) for f in flow_nums)})
- ''' # nosec (table name is code constant, flow_nums just integers)
+ flow_num in ({stringify_flow_nums(flow_nums)})
+ ''' # nosec B608 (table name is code constant, flow_nums just ints)
flows = {}
for flow_num, start_time, descr in self.connect().execute(stmt):
flows[flow_num] = {
@@ -608,7 +631,7 @@ def select_workflow_flows_max_flow_num(self):
MAX(flow_num)
FROM
{self.TABLE_WORKFLOW_FLOWS}
- ''' # nosec (table name is code constant)
+ ''' # nosec B608 (table name is code constant)
return self.connect().execute(stmt).fetchone()[0]
def select_workflow_params_restart_count(self):
@@ -620,7 +643,7 @@ def select_workflow_params_restart_count(self):
{self.TABLE_WORKFLOW_PARAMS}
WHERE
key == 'n_restart'
- """ # nosec (table name is code constant)
+ """ # nosec B608 (table name is code constant)
result = self.connect().execute(stmt).fetchone()
return int(result[0]) if result else 0
@@ -636,7 +659,7 @@ def select_workflow_template_vars(self, callback):
key, value
FROM
{self.TABLE_WORKFLOW_TEMPLATE_VARS}
- ''' # nosec (table name is code constant)
+ ''' # nosec B608 (table name is code constant)
)):
callback(row_idx, list(row))
@@ -653,7 +676,7 @@ def select_task_action_timers(self, callback):
{",".join(attrs)}
FROM
{self.TABLE_TASK_ACTION_TIMERS}
- ''' # nosec
+ ''' # nosec B608
# * table name is code constant
# * attrs are code constants
for row_idx, row in enumerate(self.connect().execute(stmt)):
@@ -679,7 +702,7 @@ def select_task_job(self, cycle, name, submit_num=None):
AND name==?
ORDER BY
submit_num DESC LIMIT 1
- ''' # nosec
+ ''' # nosec B608
# * table name is code constant
# * keys are code constants
stmt_args = [cycle, name]
@@ -693,7 +716,7 @@ def select_task_job(self, cycle, name, submit_num=None):
cycle==?
AND name==?
AND submit_num==?
- ''' # nosec
+ ''' # nosec B608
# * table name is code constant
# * keys are code constants
stmt_args = [cycle, name, submit_num]
@@ -776,7 +799,7 @@ def select_task_job_platforms(self):
platform_name
FROM
{self.TABLE_TASK_JOBS}
- ''' # nosec (table name is code constant)
+ ''' # nosec B608 (table name is code constant)
return {i[0] for i in self.connect().execute(stmt)}
def select_prev_instances(
@@ -789,7 +812,7 @@ def select_prev_instances(
# Ignore bandit false positive: B608: hardcoded_sql_expressions
# Not an injection, simply putting the table name in the SQL query
# expression as a string constant local to this module.
- stmt = ( # nosec
+ stmt = ( # nosec B608
r"SELECT flow_nums,submit_num,flow_wait,status FROM %(name)s"
r" WHERE name==? AND cycle==?"
) % {"name": self.TABLE_TASK_STATES}
@@ -837,7 +860,7 @@ def select_task_outputs(
{self.TABLE_TASK_OUTPUTS}
WHERE
name==? AND cycle==?
- ''' # nosec (table name is code constant)
+ ''' # nosec B608 (table name is code constant)
return {
outputs: deserialise_set(flow_nums)
for flow_nums, outputs in self.connect().execute(
@@ -851,7 +874,7 @@ def select_xtriggers_for_restart(self, callback):
signature, results
FROM
{self.TABLE_XTRIGGERS}
- ''' # nosec (table name is code constant)
+ ''' # nosec B608 (table name is code constant)
for row_idx, row in enumerate(self.connect().execute(stmt, [])):
callback(row_idx, list(row))
@@ -861,7 +884,7 @@ def select_abs_outputs_for_restart(self, callback):
cycle, name, output
FROM
{self.TABLE_ABS_OUTPUTS}
- ''' # nosec (table name is code constant)
+ ''' # nosec B608 (table name is code constant)
for row_idx, row in enumerate(self.connect().execute(stmt, [])):
callback(row_idx, list(row))
@@ -974,7 +997,7 @@ def select_task_prerequisites(
cycle == ? AND
name == ? AND
flow_nums == ?
- """ # nosec (table name is code constant)
+ """ # nosec B608 (table name is code constant)
stmt_args = [cycle, name, flow_nums]
return list(self.connect().execute(stmt, stmt_args))
@@ -985,7 +1008,7 @@ def select_tasks_to_hold(self) -> List[Tuple[str, str]]:
name, cycle
FROM
{self.TABLE_TASKS_TO_HOLD}
- ''' # nosec (table name is code constant)
+ ''' # nosec B608 (table name is code constant)
return list(self.connect().execute(stmt))
def select_task_times(self):
@@ -1007,7 +1030,7 @@ def select_task_times(self):
{self.TABLE_TASK_JOBS}
WHERE
run_status = 0
- """ # nosec (table name is code constant)
+ """ # nosec B608 (table name is code constant)
columns = (
'name', 'cycle', 'host', 'job_runner',
'submit_time', 'start_time', 'succeed_time'
diff --git a/cylc/flow/scheduler.py b/cylc/flow/scheduler.py
index d3f7fd18a36..da871bd2cc5 100644
--- a/cylc/flow/scheduler.py
+++ b/cylc/flow/scheduler.py
@@ -32,7 +32,6 @@
from typing import (
Any,
AsyncGenerator,
- Callable,
Dict,
Iterable,
List,
@@ -544,7 +543,7 @@ async def configure(self, params):
elif self.config.cfg['scheduling']['hold after cycle point']:
holdcp = self.config.cfg['scheduling']['hold after cycle point']
if holdcp is not None:
- await commands.run_cmd(commands.set_hold_point, self, holdcp)
+ await commands.run_cmd(commands.set_hold_point(self, holdcp))
if self.options.paused_start:
self.pause_workflow('Paused on start up')
@@ -634,7 +633,7 @@ async def run_scheduler(self) -> None:
if self.pool.get_tasks():
# (If we're not restarting a finished workflow)
self.restart_remote_init()
- await commands.run_cmd(commands.poll_tasks, self, ['*/*'])
+ await commands.run_cmd(commands.poll_tasks(self, ['*/*']))
self.run_event_handlers(self.EVENT_STARTUP, 'workflow starting')
await asyncio.gather(
@@ -949,10 +948,6 @@ def process_queued_task_messages(self) -> None:
warn += f'\n {msg.job_id}: {msg.severity} - "{msg.message}"'
LOG.warning(warn)
- def get_command_method(self, command_name: str) -> Callable:
- """Return a command processing method or raise AttributeError."""
- return getattr(self, f'command_{command_name}')
-
async def process_command_queue(self) -> None:
"""Process queued commands."""
qsize = self.command_queue.qsize()
@@ -1391,8 +1386,8 @@ async def workflow_shutdown(self):
self.time_next_kill is not None
and time() > self.time_next_kill
):
- await commands.run_cmd(commands.poll_tasks, self, ['*/*'])
- await commands.run_cmd(commands.kill_tasks, self, ['*/*'])
+ await commands.run_cmd(commands.poll_tasks(self, ['*/*']))
+ await commands.run_cmd(commands.kill_tasks(self, ['*/*']))
self.time_next_kill = time() + self.INTERVAL_STOP_KILL
# Is the workflow set to auto stop [+restart] now ...
@@ -1534,7 +1529,7 @@ async def _main_loop(self) -> None:
self.broadcast_mgr.check_ext_triggers(
itask, self.ext_trigger_queue)
- if all(itask.is_ready_to_run()):
+ if itask.is_ready_to_run():
self.pool.queue_task(itask)
if self.xtrigger_mgr.sequential_spawn_next:
diff --git a/cylc/flow/scripts/remove.py b/cylc/flow/scripts/remove.py
index ef4c74d02c8..edd65e68df5 100755
--- a/cylc/flow/scripts/remove.py
+++ b/cylc/flow/scripts/remove.py
@@ -18,7 +18,26 @@
"""cylc remove [OPTIONS] ARGS
-Remove one or more task instances from a running workflow.
+Remove task instances from a running workflow and the workflow's history.
+
+This removes the task(s) from any specified flows. The task will still exist,
+just not in the specified flows, so will not influence the evolution of
+the workflow in those flows.
+
+If a task is removed from all flows, it and its outputs will be left in the
+`None` flow. This preserves a record that the task ran, but it will not
+influence any flows in any way.
+
+Examples:
+ # remove a task which has already run
+ # (any tasks downstream of this task which have already run or are currently
+ # running will be left alone The task and its outputs will be left in the
+ # None flow)
+ $ cylc remove
+
+ # remove a task from a specified flow
+ # (the task may remain in other flows)
+ $ cylc remove --flow=1
"""
from functools import partial
@@ -33,6 +52,7 @@
)
from cylc.flow.terminal import cli_function
+
if TYPE_CHECKING:
from optparse import Values
@@ -41,10 +61,12 @@
mutation (
$wFlows: [WorkflowID]!,
$tasks: [NamespaceIDGlob]!,
+ $flow: [Flow!],
) {
remove (
workflows: $wFlows,
tasks: $tasks,
+ flow: $flow
) {
result
}
@@ -61,6 +83,19 @@ def get_option_parser() -> COP:
argdoc=[FULL_ID_MULTI_ARG_DOC],
)
+ parser.add_option(
+ '--flow',
+ action='append',
+ dest='flow',
+ metavar='FLOW',
+ help=(
+ "Remove the task(s) from the specified flow number. "
+ "Reuse the option to remove the task(s) from multiple flows. "
+ "If the option is not used at all, the task(s) will be removed "
+ "from all flows."
+ ),
+ )
+
return parser
@@ -75,6 +110,7 @@ async def run(options: 'Values', workflow_id: str, *tokens_list):
tokens.relative_id_with_selectors
for tokens in tokens_list
],
+ 'flow': options.flow,
}
}
diff --git a/cylc/flow/scripts/show.py b/cylc/flow/scripts/show.py
index 7d7bab1dfdc..c9bf6670fcf 100755
--- a/cylc/flow/scripts/show.py
+++ b/cylc/flow/scripts/show.py
@@ -60,6 +60,7 @@
from cylc.flow.option_parsers import (
CylcOptionParser as COP,
ID_MULTI_ARG_DOC,
+ Options,
)
from cylc.flow.terminal import cli_function
from cylc.flow.util import BOOL_SYMBOLS
@@ -246,6 +247,9 @@ def get_option_parser():
return parser
+ShowOptions = Options(get_option_parser())
+
+
async def workflow_meta_query(workflow_id, pclient, options, json_filter):
query = WORKFLOW_META_QUERY
query_kwargs = {
diff --git a/cylc/flow/task_job_mgr.py b/cylc/flow/task_job_mgr.py
index 500aa830b2c..a38db016855 100644
--- a/cylc/flow/task_job_mgr.py
+++ b/cylc/flow/task_job_mgr.py
@@ -26,19 +26,26 @@
from contextlib import suppress
import json
-import os
from logging import (
CRITICAL,
DEBUG,
INFO,
- WARNING
+ WARNING,
)
+import os
from shutil import rmtree
from time import time
-from typing import TYPE_CHECKING, Any, Union, Optional
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Dict,
+ List,
+ Optional,
+ Union,
+)
from cylc.flow import LOG
-from cylc.flow.job_runner_mgr import JobPollContext
+from cylc.flow.cfgspec.globalcfg import SYSPATH
from cylc.flow.exceptions import (
NoHostsError,
NoPlatformsError,
@@ -48,9 +55,10 @@
)
from cylc.flow.hostuserutil import (
get_host,
- is_remote_platform
+ is_remote_platform,
)
from cylc.flow.job_file import JobFileWriter
+from cylc.flow.job_runner_mgr import JobPollContext
from cylc.flow.pathutil import get_remote_workflow_run_job_dir
from cylc.flow.platforms import (
get_host_from_platform,
@@ -64,50 +72,51 @@
from cylc.flow.subprocpool import SubProcPool
from cylc.flow.task_action_timer import (
TaskActionTimer,
- TimerFlags
+ TimerFlags,
)
from cylc.flow.task_events_mgr import (
TaskEventsManager,
- log_task_job_activity
+ log_task_job_activity,
)
from cylc.flow.task_job_logs import (
JOB_LOG_JOB,
NN,
get_task_job_activity_log,
get_task_job_job_log,
- get_task_job_log
+ get_task_job_log,
)
from cylc.flow.task_message import FAIL_MESSAGE_PREFIX
from cylc.flow.task_outputs import (
TASK_OUTPUT_FAILED,
TASK_OUTPUT_STARTED,
TASK_OUTPUT_SUBMITTED,
- TASK_OUTPUT_SUCCEEDED
+ TASK_OUTPUT_SUCCEEDED,
)
from cylc.flow.task_remote_mgr import (
+ REMOTE_FILE_INSTALL_255,
REMOTE_FILE_INSTALL_DONE,
REMOTE_FILE_INSTALL_FAILED,
REMOTE_FILE_INSTALL_IN_PROGRESS,
- REMOTE_INIT_IN_PROGRESS,
REMOTE_INIT_255,
- REMOTE_FILE_INSTALL_255,
- REMOTE_INIT_DONE, REMOTE_INIT_FAILED,
- TaskRemoteMgr
+ REMOTE_INIT_DONE,
+ REMOTE_INIT_FAILED,
+ REMOTE_INIT_IN_PROGRESS,
+ TaskRemoteMgr,
)
from cylc.flow.task_state import (
TASK_STATUS_PREPARING,
- TASK_STATUS_SUBMITTED,
TASK_STATUS_RUNNING,
+ TASK_STATUS_SUBMITTED,
TASK_STATUS_WAITING,
- TASK_STATUSES_ACTIVE
+ TASK_STATUSES_ACTIVE,
)
+from cylc.flow.util import serialise_set
from cylc.flow.wallclock import (
get_current_time_string,
get_time_string_from_unix_time,
- get_utc_mode
+ get_utc_mode,
)
-from cylc.flow.cfgspec.globalcfg import SYSPATH
-from cylc.flow.util import serialise_set
+
if TYPE_CHECKING:
from cylc.flow.task_proxy import TaskProxy
@@ -271,12 +280,10 @@ def submit_task_jobs(self, workflow, itasks, curve_auth,
if not prepared_tasks:
return bad_tasks
- auth_itasks = {} # {platform: [itask, ...], ...}
-
+ # Mapping of platforms to task proxies:
+ auth_itasks: Dict[str, List[TaskProxy]] = {}
for itask in prepared_tasks:
- platform_name = itask.platform['name']
- auth_itasks.setdefault(platform_name, [])
- auth_itasks[platform_name].append(itask)
+ auth_itasks.setdefault(itask.platform['name'], []).append(itask)
# Submit task jobs for each platform
# Non-prepared tasks can be considered done for now:
done_tasks = bad_tasks
diff --git a/cylc/flow/task_pool.py b/cylc/flow/task_pool.py
index f645403df15..a1468b1c9f3 100644
--- a/cylc/flow/task_pool.py
+++ b/cylc/flow/task_pool.py
@@ -16,34 +16,54 @@
"""Wrangle task proxies to manage the workflow."""
-from contextlib import suppress
from collections import Counter
+from contextlib import suppress
+import itertools
import json
+import logging
from textwrap import indent
from typing import (
+ TYPE_CHECKING,
Dict,
Iterable,
List,
NamedTuple,
Optional,
Set,
- TYPE_CHECKING,
Tuple,
Type,
Union,
)
-import logging
-import cylc.flow.flags
from cylc.flow import LOG
-from cylc.flow.cycling.loader import get_point, standardise_point_string
+from cylc.flow.cycling.loader import (
+ get_point,
+ standardise_point_string,
+)
from cylc.flow.exceptions import (
- WorkflowConfigError, PointParsingError, PlatformLookupError)
-from cylc.flow.id import Tokens, detokenise
+ PlatformLookupError,
+ PointParsingError,
+ WorkflowConfigError,
+)
+import cylc.flow.flags
+from cylc.flow.flow_mgr import (
+ FLOW_ALL,
+ FLOW_NEW,
+ FLOW_NONE,
+ repr_flow_nums,
+)
+from cylc.flow.id import (
+ Tokens,
+ detokenise,
+ quick_relative_id,
+)
from cylc.flow.id_cli import contains_fnmatch
from cylc.flow.id_match import filter_ids
-from cylc.flow.workflow_status import StopMode
-from cylc.flow.task_action_timer import TaskActionTimer, TimerFlags
+from cylc.flow.platforms import get_platform
+from cylc.flow.task_action_timer import (
+ TaskActionTimer,
+ TimerFlags,
+)
from cylc.flow.task_events_mgr import (
CustomTaskEventHandlerContext,
EventKey,
@@ -51,45 +71,42 @@
TaskJobLogsRetrieveContext,
)
from cylc.flow.task_id import TaskID
+from cylc.flow.task_outputs import (
+ TASK_OUTPUT_EXPIRED,
+ TASK_OUTPUT_FAILED,
+ TASK_OUTPUT_SUBMIT_FAILED,
+ TASK_OUTPUT_SUCCEEDED,
+)
from cylc.flow.task_proxy import TaskProxy
+from cylc.flow.task_queues.independent import IndepQueueManager
from cylc.flow.task_state import (
- TASK_STATUSES_ACTIVE,
- TASK_STATUSES_FINAL,
- TASK_STATUS_WAITING,
TASK_STATUS_EXPIRED,
+ TASK_STATUS_FAILED,
TASK_STATUS_PREPARING,
- TASK_STATUS_SUBMITTED,
TASK_STATUS_RUNNING,
+ TASK_STATUS_SUBMITTED,
TASK_STATUS_SUCCEEDED,
- TASK_STATUS_FAILED,
+ TASK_STATUS_WAITING,
+ TASK_STATUSES_ACTIVE,
+ TASK_STATUSES_FINAL,
)
from cylc.flow.task_trigger import TaskTrigger
-from cylc.flow.util import (
- serialise_set,
- deserialise_set
-)
-from cylc.flow.wallclock import get_current_time_string
-from cylc.flow.platforms import get_platform
-from cylc.flow.task_outputs import (
- TASK_OUTPUT_SUCCEEDED,
- TASK_OUTPUT_EXPIRED,
- TASK_OUTPUT_FAILED,
- TASK_OUTPUT_SUBMIT_FAILED,
-)
-from cylc.flow.task_queues.independent import IndepQueueManager
+from cylc.flow.taskdef import generate_graph_children
+from cylc.flow.util import deserialise_set
+from cylc.flow.workflow_status import StopMode
-from cylc.flow.flow_mgr import (
- stringify_flow_nums,
- FLOW_ALL,
- FLOW_NONE,
- FLOW_NEW
-)
if TYPE_CHECKING:
from cylc.flow.config import WorkflowConfig
- from cylc.flow.cycling import IntervalBase, PointBase
+ from cylc.flow.cycling import (
+ IntervalBase,
+ PointBase,
+ )
from cylc.flow.data_store_mgr import DataStoreMgr
- from cylc.flow.flow_mgr import FlowMgr, FlowNums
+ from cylc.flow.flow_mgr import (
+ FlowMgr,
+ FlowNums,
+ )
from cylc.flow.prerequisite import SatisfiedState
from cylc.flow.task_events_mgr import TaskEventsManager
from cylc.flow.taskdef import TaskDef
@@ -203,18 +220,7 @@ def db_add_new_flow_rows(self, itask: TaskProxy) -> None:
Call when a new task is spawned or a flow merge occurs.
"""
# Add row to task_states table.
- now = get_current_time_string()
- self.workflow_db_mgr.put_insert_task_states(
- itask,
- {
- "time_created": now,
- "time_updated": now,
- "status": itask.state.status,
- "flow_nums": serialise_set(itask.flow_nums),
- "flow_wait": itask.flow_wait,
- "is_manual_submit": itask.is_manual_submit
- }
- )
+ self.workflow_db_mgr.put_insert_task_states(itask)
# Add row to task_outputs table:
self.workflow_db_mgr.put_insert_task_outputs(itask)
@@ -441,13 +447,24 @@ def check_task_output(
output_msg: str,
flow_nums: 'FlowNums',
) -> 'SatisfiedState':
- """Returns truthy if the specified output is satisfied in the DB."""
+ """Returns truthy if the specified output is satisfied in the DB.
+
+ Args:
+ cycle: Cycle point of the task whose output is being checked.
+ task: Name of the task whose output is being checked.
+ output_msg: The output message to check for.
+ flow_nums: Flow numbers of the task whose output is being
+ checked. If this is empty it means 'none'; will return False.
+ """
+ if not flow_nums:
+ return False
+
for task_outputs, task_flow_nums in (
self.workflow_db_mgr.pri_dao.select_task_outputs(task, cycle)
).items():
# loop through matching tasks
+ # (if task_flow_nums is empty, it means the 'none' flow)
if flow_nums.intersection(task_flow_nums):
- # this task is in the right flow
# BACK COMPAT: In Cylc >8.0.0,<8.3.0, only the task
# messages were stored in the DB as a list.
# from: 8.0.0
@@ -709,7 +726,7 @@ def rh_release_and_queue(self, itask) -> None:
"""
if itask.state_reset(is_runahead=False):
self.data_store_mgr.delta_task_runahead(itask)
- if all(itask.is_ready_to_run()):
+ if itask.is_ready_to_run():
# (otherwise waiting on xtriggers etc.)
self.queue_task(itask)
@@ -736,9 +753,7 @@ def get_or_spawn_task(
It does not add a spawned task proxy to the pool.
"""
- ntask = self._get_task_by_id(
- Tokens(cycle=str(point), task=tdef.name).relative_id
- )
+ ntask = self.get_task(point, tdef.name)
is_in_pool = False
is_xtrig_sequential = False
if ntask is None:
@@ -817,7 +832,7 @@ def spawn_if_parentless(self, tdef, point, flow_nums):
if ntask is not None and not is_in_pool:
self.add_to_pool(ntask)
- def remove(self, itask, reason=None):
+ def remove(self, itask: 'TaskProxy', reason: Optional[str] = None) -> None:
"""Remove a task from the pool."""
if itask.state.is_runahead and itask.flow_nums:
@@ -829,11 +844,7 @@ def remove(self, itask, reason=None):
itask.flow_nums
)
- msg = "removed from active task pool"
- if reason is None:
- msg += ": completed"
- else:
- msg += f": {reason}"
+ msg = f"removed from active task pool: {reason or 'completed'}"
if itask.is_xtrigger_sequential:
self.xtrigger_mgr.sequential_spawn_next.discard(itask.identity)
@@ -887,40 +898,53 @@ def get_tasks(self) -> List[TaskProxy]:
# Cached list only for use internally in this method.
if self.active_tasks_changed:
self.active_tasks_changed = False
- self._active_tasks_list = []
- for itask_id_map in self.active_tasks.values():
- for itask in itask_id_map.values():
- self._active_tasks_list.append(itask)
+ self._active_tasks_list = [
+ itask
+ for itask_id_map in self.active_tasks.values()
+ for itask in itask_id_map.values()
+ ]
return self._active_tasks_list
def get_tasks_by_point(self) -> 'Dict[PointBase, List[TaskProxy]]':
"""Return a map of task proxies by cycle point."""
- point_itasks = {}
- for point, itask_id_map in self.active_tasks.items():
- point_itasks[point] = list(itask_id_map.values())
- return point_itasks
+ return {
+ point: list(itask_id_map.values())
+ for point, itask_id_map in self.active_tasks.items()
+ }
def get_task(self, point: 'PointBase', name: str) -> Optional[TaskProxy]:
"""Retrieve a task from the pool."""
rel_id = f'{point}/{name}'
tasks = self.active_tasks.get(point)
- if tasks and rel_id in tasks:
- return tasks[rel_id]
+ if tasks:
+ return tasks.get(rel_id)
return None
def _get_task_by_id(self, id_: str) -> Optional[TaskProxy]:
"""Return pool task by ID if it exists, or None."""
for itask_ids in self.active_tasks.values():
- with suppress(KeyError):
+ if id_ in itask_ids:
return itask_ids[id_]
return None
def queue_task(self, itask: TaskProxy) -> None:
- """Queue a task that is ready to run."""
+ """Queue a task that is ready to run.
+
+ If it is already queued, do nothing.
+ """
if itask.state_reset(is_queued=True):
self.data_store_mgr.delta_task_queued(itask)
self.task_queue_mgr.push_task(itask)
+ def unqueue_task(self, itask: TaskProxy) -> None:
+ """Un-queue a task that is no longer ready to run.
+
+ If it is not queued, do nothing.
+ """
+ if itask.state_reset(is_queued=False):
+ self.data_store_mgr.delta_task_queued(itask)
+ self.task_queue_mgr.remove_task(itask)
+
def release_queued_tasks(self):
"""Return list of queue-released tasks awaiting job prep.
@@ -1100,8 +1124,7 @@ def _reload_taskdefs(self) -> None:
if itask.state.is_queued:
# Already queued
continue
- ready_check_items = itask.is_ready_to_run()
- if all(ready_check_items) and not itask.state.is_runahead:
+ if itask.is_ready_to_run() and not itask.state.is_runahead:
self.queue_task(itask)
def set_stop_point(self, stop_point: 'PointBase') -> bool:
@@ -1243,7 +1266,7 @@ def log_unsatisfied_prereqs(self) -> bool:
LOG.warning(
"Partially satisfied prerequisites:\n"
+ "\n".join(
- f" * {id_} is waiting on {others}"
+ f" * {id_} is waiting on {sorted(others)}"
for id_, others in unsat.items()
)
)
@@ -1266,7 +1289,7 @@ def is_stalled(self) -> bool:
itask.state(TASK_STATUS_WAITING)
and not itask.state.is_runahead
# (avoid waiting pre-spawned absolute-triggered tasks:)
- and not itask.is_task_prereqs_not_done()
+ and itask.prereqs_are_satisfied()
) for itask in self.get_tasks()
):
return False
@@ -1284,7 +1307,7 @@ def hold_active_task(self, itask: TaskProxy) -> None:
def release_held_active_task(self, itask: TaskProxy) -> None:
if itask.state_reset(is_held=False):
self.data_store_mgr.delta_task_held(itask)
- if (not itask.state.is_runahead) and all(itask.is_ready_to_run()):
+ if (not itask.state.is_runahead) and itask.is_ready_to_run():
self.queue_task(itask)
self.tasks_to_hold.discard((itask.tdef.name, itask.point))
self.workflow_db_mgr.put_tasks_to_hold(self.tasks_to_hold)
@@ -1302,8 +1325,8 @@ def hold_tasks(self, items: Iterable[str]) -> int:
# Hold active tasks:
itasks, inactive_tasks, unmatched = self.filter_task_proxies(
items,
- warn=False,
- future=True,
+ warn_no_active=False,
+ inactive=True,
)
for itask in itasks:
self.hold_active_task(itask)
@@ -1320,8 +1343,8 @@ def release_held_tasks(self, items: Iterable[str]) -> int:
# Release active tasks:
itasks, inactive_tasks, unmatched = self.filter_task_proxies(
items,
- warn=False,
- future=True,
+ warn_no_active=False,
+ inactive=True,
)
for itask in itasks:
self.release_held_active_task(itask)
@@ -1400,12 +1423,7 @@ def spawn_on_output(self, itask: TaskProxy, output: str) -> None:
str(itask.point), itask.tdef.name, output)
self.workflow_db_mgr.process_queued_ops()
- c_taskid = Tokens(
- cycle=str(c_point),
- task=c_name,
- ).relative_id
-
- c_task = self._get_task_by_id(c_taskid)
+ c_task = self._get_task_by_id(quick_relative_id(c_point, c_name))
in_pool = c_task is not None
if c_task is not None and c_task != itask:
@@ -1422,7 +1440,7 @@ def spawn_on_output(self, itask: TaskProxy, output: str) -> None:
if is_abs:
tasks, *_ = self.filter_task_proxies(
[f'*/{c_name}'],
- warn=False,
+ warn_no_active=False,
)
if c_task not in tasks:
tasks.append(c_task)
@@ -1650,6 +1668,7 @@ def _load_historical_outputs(self, itask: 'TaskProxy') -> None:
else:
flow_seen = False
for outputs_str, fnums in info.items():
+ # (if fnums is empty, it means the 'none' flow)
if itask.flow_nums.intersection(fnums):
# DB row has overlap with itask's flows
flow_seen = True
@@ -1717,12 +1736,10 @@ def spawn_task(
and not itask.state.outputs.get_completed_outputs()
):
# If itask has any history in this flow but no completed outputs
- # we can infer it was deliberately removed, so don't respawn it.
+ # we can infer it has just been deliberately removed (N.B. not
+ # by `cylc remove`), so don't immediately respawn it.
# TODO (follow-up work):
# - this logic fails if task removed after some outputs completed
- # - this is does not conform to future "cylc remove" flow-erasure
- # behaviour which would result in respawning of the removed task
- # See github.com/cylc/cylc-flow/pull/6186/#discussion_r1669727292
LOG.debug(f"Not respawning {point}/{name} - task was removed")
return None
@@ -1737,7 +1754,7 @@ def spawn_task(
msg += " incomplete"
LOG.info(
- f"{msg} {stringify_flow_nums(flow_nums, full=True)})"
+ f"{msg} {repr_flow_nums(flow_nums, full=True)})"
)
if prev_flow_wait:
self._spawn_after_flow_wait(itask)
@@ -1822,9 +1839,6 @@ def _get_task_proxy_db_outputs(
self.xtrigger_mgr.xtriggers.sequential_xtrigger_labels
),
)
- if itask is None:
- return None
-
# Update it with outputs that were already completed.
self._load_historical_outputs(itask)
return itask
@@ -1929,8 +1943,8 @@ def set_prereqs_and_outputs(
# Get matching pool tasks and inactive task definitions.
itasks, inactive_tasks, unmatched = self.filter_task_proxies(
items,
- future=True,
- warn=False,
+ inactive=True,
+ warn_no_active=False,
)
flow_nums = self._get_flow_nums(flow, flow_descr)
@@ -1940,7 +1954,7 @@ def set_prereqs_and_outputs(
if flow == ['none'] and itask.flow_nums != set():
LOG.error(
f"[{itask}] ignoring 'flow=none' set: task already has"
- f" {stringify_flow_nums(itask.flow_nums, full=True)}"
+ f" {repr_flow_nums(itask.flow_nums, full=True)}"
)
continue
self.merge_flows(itask, flow_nums)
@@ -2023,7 +2037,7 @@ def _set_prereqs_itask(
# Attempt to set the given presrequisites.
# Log any that aren't valid for the task.
presus = self._standardise_prereqs(prereqs)
- unmatched = itask.satisfy_me(presus.keys())
+ unmatched = itask.satisfy_me(presus.keys(), forced=True)
for task_msg in unmatched:
LOG.warning(
f"{itask.identity} does not depend on"
@@ -2068,16 +2082,128 @@ def _get_active_flow_nums(self) -> 'FlowNums':
or {1}
)
- def remove_tasks(self, items):
- """Remove tasks from the pool (forced by command)."""
- itasks, _, bad_items = self.filter_task_proxies(items)
- for itask in itasks:
- # Spawn next occurrence of xtrigger sequential task.
- self.check_spawn_psx_task(itask)
- self.remove(itask, 'request')
- if self.compute_runahead():
+ def remove_tasks(
+ self, items: Iterable[str], flow_nums: Optional['FlowNums'] = None
+ ) -> None:
+ """Remove tasks from the pool (forced by command).
+
+ Args:
+ items: Relative IDs or globs.
+ flow_nums: Flows to remove the tasks from. If empty or None, it
+ means 'all'.
+ """
+ active, inactive, _unmatched = self.filter_task_proxies(
+ items, warn_no_active=False, inactive=True
+ )
+ if not (active or inactive):
+ return
+
+ if flow_nums is None:
+ flow_nums = set()
+ # Mapping of task IDs to removed flow numbers:
+ removed: Dict[str, FlowNums] = {}
+ not_removed: Set[str] = set()
+
+ for itask in active:
+ fnums_to_remove = itask.match_flows(flow_nums)
+ if not fnums_to_remove:
+ not_removed.add(itask.identity)
+ continue
+ removed[itask.identity] = fnums_to_remove
+ if fnums_to_remove == itask.flow_nums:
+ # Need to remove the task from the pool.
+ # Spawn next occurrence of xtrigger sequential task (otherwise
+ # this would not happen after removing this occurrence):
+ self.check_spawn_psx_task(itask)
+ self.remove(itask, 'request')
+ else:
+ itask.flow_nums.difference_update(fnums_to_remove)
+
+ matched_task_ids = {
+ *removed.keys(),
+ *(quick_relative_id(cycle, task) for task, cycle in inactive),
+ }
+
+ for id_ in matched_task_ids:
+ point_str, name = id_.split('/', 1)
+ tdef = self.config.taskdefs[name]
+ # Go through downstream tasks to see if any need to stand down
+ # as a result of this task being removed:
+ for child in set(itertools.chain.from_iterable(
+ generate_graph_children(tdef, get_point(point_str)).values()
+ )):
+ child_itask = self.get_task(child.point, child.name)
+ if not child_itask:
+ continue
+ fnums_to_remove = child_itask.match_flows(flow_nums)
+ if not fnums_to_remove:
+ continue
+ prereqs_changed = False
+ for prereq in (
+ *child_itask.state.prerequisites,
+ *child_itask.state.suicide_prerequisites,
+ ):
+ # Unset any prereqs naturally satisfied by these tasks
+ # (do not unset those satisfied by `cylc set --pre`):
+ if prereq.unset_naturally_satisfied_dependency(id_):
+ prereqs_changed = True
+ removed.setdefault(id_, set()).update(fnums_to_remove)
+ if not prereqs_changed:
+ continue
+ self.data_store_mgr.delta_task_prerequisite(child_itask)
+ # Check if downstream task is still ready to run:
+ if (
+ child_itask.state.is_gte(TASK_STATUS_PREPARING)
+ # Still ready if the task exists in other flows:
+ or child_itask.flow_nums != fnums_to_remove
+ or child_itask.state.prerequisites_all_satisfied()
+ ):
+ continue
+ # No longer ready to run
+ self.unqueue_task(child_itask)
+ # Check if downstream task should remain spawned:
+ if (
+ # Ignoring tasks we are already dealing with:
+ child_itask.identity in matched_task_ids
+ or child_itask.state.any_satisfied_prerequisite_tasks()
+ ):
+ continue
+ # No longer has reason to be in pool:
+ self.remove(child_itask, 'upstream task(s) removed')
+ # Remove from DB tables to ensure it is not skipped if it
+ # respawns in future:
+ self.workflow_db_mgr.remove_task_from_flows(
+ str(child.point), child.name, fnums_to_remove
+ )
+
+ # Remove from DB tables:
+ db_removed_fnums = self.workflow_db_mgr.remove_task_from_flows(
+ point_str, name, flow_nums
+ )
+ if db_removed_fnums:
+ removed.setdefault(id_, set()).update(db_removed_fnums)
+
+ if removed:
+ tasks_str_list = []
+ for task, fnums in removed.items():
+ self.data_store_mgr.delta_remove_task_flow_nums(task, fnums)
+ tasks_str_list.append(
+ f"{task} {repr_flow_nums(fnums, full=True)}"
+ )
+ LOG.info(f"Removed task(s): {', '.join(sorted(tasks_str_list))}")
+
+ not_removed.update(matched_task_ids.difference(removed))
+ if not_removed:
+ fnums_str = (
+ repr_flow_nums(flow_nums, full=True) if flow_nums else ''
+ )
+ LOG.warning(
+ "Task(s) not removable: "
+ f"{', '.join(sorted(not_removed))} {fnums_str}"
+ )
+
+ if removed and self.compute_runahead():
self.release_runahead_tasks()
- return len(bad_items)
def _get_flow_nums(
self,
@@ -2153,8 +2279,8 @@ def force_trigger_tasks(
"""
# Get matching tasks proxies, and matching inactive task IDs.
- existing_tasks, future_ids, unmatched = self.filter_task_proxies(
- items, future=True, warn=False,
+ existing_tasks, inactive_ids, unmatched = self.filter_task_proxies(
+ items, inactive=True, warn_no_active=False,
)
flow_nums = self._get_flow_nums(flow, flow_descr)
@@ -2164,7 +2290,7 @@ def force_trigger_tasks(
if flow == ['none'] and itask.flow_nums != set():
LOG.error(
f"[{itask}] ignoring 'flow=none' trigger: task already has"
- f" {stringify_flow_nums(itask.flow_nums, full=True)}"
+ f" {repr_flow_nums(itask.flow_nums, full=True)}"
)
continue
if itask.state(TASK_STATUS_PREPARING, *TASK_STATUSES_ACTIVE):
@@ -2177,7 +2303,7 @@ def force_trigger_tasks(
if not flow:
# default: assign to all active flows
flow_nums = self._get_active_flow_nums()
- for name, point in future_ids:
+ for name, point in inactive_ids:
if not self.can_be_spawned(name, point):
continue
submit_num, _, prev_fwait = (
@@ -2303,41 +2429,41 @@ def log_task_pool(self, log_lvl=logging.DEBUG):
def filter_task_proxies(
self,
ids: Iterable[str],
- warn: bool = True,
- future: bool = False,
+ warn_no_active: bool = True,
+ inactive: bool = False,
) -> 'Tuple[List[TaskProxy], Set[Tuple[str, PointBase]], List[str]]':
"""Return task proxies that match names, points, states in items.
Args:
ids:
ID strings.
- warn:
- Whether to log a warning if no matching tasks are found.
- future:
+ warn_no_active:
+ Whether to log a warning if no matching active tasks are found.
+ inactive:
If True, unmatched IDs will be checked against taskdefs
- and cycle, task pairs will be provided in the future_matched
- argument providing the ID
+ and cycle, and any matches will be returned in the second
+ return value, provided that the ID:
* Specifies a cycle point.
* Is not a pattern. (e.g. `*/foo`).
* Does not contain a state selector (e.g. `:failed`).
Returns:
- (matched, future_matched, unmatched)
+ (matched, inactive_matched, unmatched)
"""
matched, unmatched = filter_ids(
self.active_tasks,
ids,
- warn=warn,
+ warn=warn_no_active,
)
- future_matched: 'Set[Tuple[str, PointBase]]' = set()
- if future and unmatched:
- future_matched, unmatched = self.match_inactive_tasks(
+ inactive_matched: 'Set[Tuple[str, PointBase]]' = set()
+ if inactive and unmatched:
+ inactive_matched, unmatched = self.match_inactive_tasks(
unmatched
)
- return matched, future_matched, unmatched
+ return matched, inactive_matched, unmatched
def match_inactive_tasks(
self,
diff --git a/cylc/flow/task_proxy.py b/cylc/flow/task_proxy.py
index 73d819c2fce..a6e7d8a1c7c 100644
--- a/cylc/flow/task_proxy.py
+++ b/cylc/flow/task_proxy.py
@@ -30,32 +30,35 @@
List,
Optional,
Set,
- Tuple,
)
from metomi.isodatetime.timezone import get_local_time_zone
from cylc.flow import LOG
-from cylc.flow.flow_mgr import stringify_flow_nums
+from cylc.flow.cycling.iso8601 import (
+ interval_parse,
+ point_parse,
+)
+from cylc.flow.flow_mgr import repr_flow_nums
from cylc.flow.platforms import get_platform
from cylc.flow.task_action_timer import TimerFlags
from cylc.flow.task_state import (
- TaskState,
- TASK_STATUS_WAITING,
TASK_STATUS_EXPIRED,
+ TASK_STATUS_WAITING,
+ TaskState,
)
from cylc.flow.taskdef import generate_graph_children
from cylc.flow.wallclock import get_unix_time_from_time_string as str2time
-from cylc.flow.cycling.iso8601 import (
- point_parse,
- interval_parse,
-)
+
if TYPE_CHECKING:
from cylc.flow.cycling import PointBase
from cylc.flow.flow_mgr import FlowNums
from cylc.flow.id import Tokens
- from cylc.flow.prerequisite import PrereqMessage, SatisfiedState
+ from cylc.flow.prerequisite import (
+ PrereqMessage,
+ SatisfiedState,
+ )
from cylc.flow.simulation import ModeSettings
from cylc.flow.task_action_timer import TaskActionTimer
from cylc.flow.taskdef import TaskDef
@@ -148,7 +151,7 @@ class TaskProxy:
.graph_children (dict)
graph children: {msg: [(name, point), ...]}
.flow_nums:
- flows I belong to
+ flows I belong to (if empty, belongs to 'none' flow)
flow_wait:
wait for flow merge before spawning children
.waiting_on_job_prep:
@@ -211,7 +214,7 @@ def __init__(
scheduler_tokens: 'Tokens',
tdef: 'TaskDef',
start_point: 'PointBase',
- flow_nums: Optional[Set[int]] = None,
+ flow_nums: Optional['FlowNums'] = None,
status: str = TASK_STATUS_WAITING,
is_held: bool = False,
submit_num: int = 0,
@@ -306,7 +309,7 @@ def __init__(
)
def __repr__(self) -> str:
- return f"<{self.__class__.__name__} '{self.tokens}'>"
+ return f"<{self.__class__.__name__} {self.identity}>"
def __str__(self) -> str:
"""Stringify with tokens, state, submit_num, and flow_nums.
@@ -317,11 +320,11 @@ def __str__(self) -> str:
"""
id_ = self.identity
if self.transient:
- return f"{id_}{stringify_flow_nums(self.flow_nums)}"
+ return f"{id_}{repr_flow_nums(self.flow_nums)}"
if not self.state(TASK_STATUS_WAITING, TASK_STATUS_EXPIRED):
id_ += f"/{self.submit_num:02d}"
return (
- f"{id_}{stringify_flow_nums(self.flow_nums)}:{self.state}"
+ f"{id_}{repr_flow_nums(self.flow_nums)}:{self.state}"
)
def copy_to_reload_successor(
@@ -454,7 +457,7 @@ def next_point(self):
"""Return the next cycle point."""
return self.tdef.next_point(self.point)
- def is_ready_to_run(self) -> Tuple[bool, ...]:
+ def is_ready_to_run(self) -> bool:
"""Is this task ready to run?
Takes account of all dependence: on other tasks, xtriggers, and
@@ -463,16 +466,18 @@ def is_ready_to_run(self) -> Tuple[bool, ...]:
"""
if self.is_manual_submit:
# Manually triggered, ignore unsatisfied prerequisites.
- return (True,)
+ return True
if self.state.is_held:
# A held task is not ready to run.
- return (False,)
+ return False
if self.state.status in self.try_timers:
# A try timer is still active.
- return (self.try_timers[self.state.status].is_delay_done(),)
+ return self.try_timers[self.state.status].is_delay_done()
return (
- self.state(TASK_STATUS_WAITING),
- self.is_waiting_prereqs_done()
+ self.state(TASK_STATUS_WAITING)
+ and self.prereqs_are_satisfied()
+ and self.state.external_triggers_all_satisfied()
+ and self.state.xtriggers_all_satisfied()
)
def set_summary_time(self, event_key, time_str=None):
@@ -486,18 +491,9 @@ def set_summary_time(self, event_key, time_str=None):
self.summary[event_key + '_time'] = float(str2time(time_str))
self.summary[event_key + '_time_string'] = time_str
- def is_task_prereqs_not_done(self):
- """Are some task prerequisites not satisfied?"""
- return (not all(pre.is_satisfied()
- for pre in self.state.prerequisites))
-
- def is_waiting_prereqs_done(self):
- """Are ALL prerequisites satisfied?"""
- return (
- all(pre.is_satisfied() for pre in self.state.prerequisites)
- and self.state.external_triggers_all_satisfied()
- and self.state.xtriggers_all_satisfied()
- )
+ def prereqs_are_satisfied(self) -> bool:
+ """Are all task prerequisites satisfied?"""
+ return all(pre.is_satisfied() for pre in self.state.prerequisites)
def reset_try_timers(self):
# unset any retry delay timers
@@ -522,6 +518,17 @@ def name_match(
match_func(ns, value) for ns in self.tdef.namespace_hierarchy
)
+ def match_flows(self, flow_nums: 'FlowNums') -> 'FlowNums':
+ """Return which of the given flow numbers the task belongs to.
+
+ NOTE: If `flow_nums` is empty, it means 'all', whereas
+ if `self.flow_nums` is empty, it means this task is in the 'none' flow
+ and will not match.
+ """
+ if not flow_nums or not self.flow_nums:
+ return self.flow_nums
+ return self.flow_nums.intersection(flow_nums)
+
def merge_flows(self, flow_nums: Set) -> None:
"""Merge another set of flow_nums with mine."""
self.flow_nums.update(flow_nums)
@@ -553,7 +560,7 @@ def state_reset(
return False
def satisfy_me(
- self, task_messages: 'Iterable[Tokens]'
+ self, task_messages: 'Iterable[Tokens]', forced: bool = False
) -> 'Set[Tokens]':
"""Try to satisfy my prerequisites with given output messages.
@@ -563,7 +570,7 @@ def satisfy_me(
Return a set of unmatched task messages.
"""
- used = self.state.satisfy_me(task_messages)
+ used = self.state.satisfy_me(task_messages, forced)
return set(task_messages) - used
def clock_expire(self) -> bool:
diff --git a/cylc/flow/task_state.py b/cylc/flow/task_state.py
index 9ecd9414d17..7c1c1e69753 100644
--- a/cylc/flow/task_state.py
+++ b/cylc/flow/task_state.py
@@ -324,7 +324,8 @@ def __call__(
def satisfy_me(
self,
- outputs: Iterable['Tokens']
+ outputs: Iterable['Tokens'],
+ forced: bool = False,
) -> Set['Tokens']:
"""Try to satisfy my prerequisites with given outputs.
@@ -333,7 +334,7 @@ def satisfy_me(
valid: Set[Tokens] = set()
for prereq in (*self.prerequisites, *self.suicide_prerequisites):
valid.update(
- prereq.satisfy_me(outputs)
+ prereq.satisfy_me(outputs, forced)
)
return valid
@@ -393,7 +394,7 @@ def get_resolved_dependencies(self):
return sorted(
dep
for prereq in self.prerequisites
- for dep in prereq.get_resolved_dependencies()
+ for dep in prereq.get_satisfied_dependencies()
)
def reset(
@@ -527,3 +528,12 @@ def get_unsatisfied_prerequisites(self) -> List['PrereqMessage']:
for prereq in self.prerequisites if not prereq.is_satisfied()
for key, satisfied in prereq.items() if not satisfied
]
+
+ def any_satisfied_prerequisite_tasks(self) -> bool:
+ """Return True if any of this task's prerequisite tasks are
+ satisfied."""
+ return any(
+ satisfied
+ for prereq in self.prerequisites
+ for satisfied in prereq._satisfied.values()
+ )
diff --git a/cylc/flow/util.py b/cylc/flow/util.py
index b7e1e0e0c73..b649265dbd9 100644
--- a/cylc/flow/util.py
+++ b/cylc/flow/util.py
@@ -17,7 +17,10 @@
import ast
from contextlib import suppress
-from functools import partial
+from functools import (
+ lru_cache,
+ partial,
+)
import json
import re
from textwrap import dedent
@@ -31,6 +34,7 @@
Tuple,
)
+
BOOL_SYMBOLS: Dict[bool, str] = {
# U+2A2F (vector cross product)
False: '⨯',
@@ -163,15 +167,23 @@ def serialise_set(flow_nums: Optional[set] = None) -> str:
'[]'
"""
- return json.dumps(sorted(flow_nums or ()))
+ return _serialise_set(tuple(sorted(flow_nums or ())))
+
+
+@lru_cache(maxsize=100)
+def _serialise_set(flow_nums: tuple) -> str:
+ return json.dumps(flow_nums)
+@lru_cache(maxsize=100)
def deserialise_set(flow_num_str: str) -> set:
"""Convert json string to set.
Example:
- >>> sorted(deserialise_set('[2, 3]'))
- [2, 3]
+ >>> deserialise_set('[2, 3]') == {2, 3}
+ True
+ >>> deserialise_set('[]')
+ set()
"""
return set(json.loads(flow_num_str))
diff --git a/cylc/flow/workflow_db_mgr.py b/cylc/flow/workflow_db_mgr.py
index 5f1f673440c..ebe176b85ac 100644
--- a/cylc/flow/workflow_db_mgr.py
+++ b/cylc/flow/workflow_db_mgr.py
@@ -23,39 +23,66 @@
* Manage existing run database files on restart.
"""
+from collections import defaultdict
import json
import os
-from shutil import copy, rmtree
+from shutil import (
+ copy,
+ rmtree,
+)
from sqlite3 import OperationalError
from tempfile import mkstemp
from typing import (
- Any, AnyStr, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union
+ TYPE_CHECKING,
+ Any,
+ AnyStr,
+ DefaultDict,
+ Dict,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Union,
)
from packaging.version import parse as parse_version
-from cylc.flow import LOG
+from cylc.flow import (
+ LOG,
+ __version__ as CYLC_VERSION,
+)
from cylc.flow.broadcast_report import get_broadcast_change_iter
+from cylc.flow.exceptions import (
+ CylcError,
+ ServiceFileError,
+)
from cylc.flow.rundb import CylcWorkflowDAO
-from cylc.flow import __version__ as CYLC_VERSION
-from cylc.flow.wallclock import get_current_time_string, get_utc_mode
-from cylc.flow.exceptions import CylcError, ServiceFileError
-from cylc.flow.util import serialise_set, deserialise_set
+from cylc.flow.util import (
+ deserialise_set,
+ serialise_set,
+)
+from cylc.flow.wallclock import (
+ get_current_time_string,
+ get_utc_mode,
+)
+
if TYPE_CHECKING:
from pathlib import Path
+
+ from packaging.version import Version
+
from cylc.flow.cycling import PointBase
+ from cylc.flow.flow_mgr import FlowNums
+ from cylc.flow.rundb import (
+ DbArgDict,
+ DbUpdateTuple,
+ )
from cylc.flow.scheduler import Scheduler
- from cylc.flow.task_pool import TaskPool
from cylc.flow.task_events_mgr import EventKey
+ from cylc.flow.task_pool import TaskPool
from cylc.flow.task_proxy import TaskProxy
-Version = Any
-# TODO: narrow down Any (should be str | int) after implementing type
-# annotations in cylc.flow.task_state.TaskState
-DbArgDict = Dict[str, Any]
-DbUpdateTuple = Tuple[DbArgDict, DbArgDict]
-
PERM_PRIVATE = 0o600 # -rw-------
@@ -141,7 +168,9 @@ def __init__(self, pri_d=None, pub_d=None):
self.TABLE_TASKS_TO_HOLD: [],
self.TABLE_XTRIGGERS: [],
self.TABLE_ABS_OUTPUTS: []}
- self.db_updates_map: Dict[str, List[DbUpdateTuple]] = {}
+ self.db_updates_map: DefaultDict[
+ str, List[DbUpdateTuple]
+ ] = defaultdict(list)
def copy_pri_to_pub(self) -> None:
"""Copy content of primary database file to public database file."""
@@ -232,29 +261,23 @@ def process_queued_ops(self) -> None:
# Record workflow parameters and tasks in pool
# Record any broadcast settings to be dumped out
if any(self.db_deletes_map.values()):
- for table_name, db_deletes in sorted(
- self.db_deletes_map.items()):
+ for table_name, db_deletes in sorted(self.db_deletes_map.items()):
while db_deletes:
where_args = db_deletes.pop(0)
self.pri_dao.add_delete_item(table_name, where_args)
self.pub_dao.add_delete_item(table_name, where_args)
if any(self.db_inserts_map.values()):
- for table_name, db_inserts in sorted(
- self.db_inserts_map.items()):
+ for table_name, db_inserts in sorted(self.db_inserts_map.items()):
while db_inserts:
db_insert = db_inserts.pop(0)
self.pri_dao.add_insert_item(table_name, db_insert)
self.pub_dao.add_insert_item(table_name, db_insert)
- if (hasattr(self, 'db_updates_map') and
- any(self.db_updates_map.values())):
- for table_name, db_updates in sorted(
- self.db_updates_map.items()):
+ if any(self.db_updates_map.values()):
+ for table_name, db_updates in sorted(self.db_updates_map.items()):
while db_updates:
- set_args, where_args = db_updates.pop(0)
- self.pri_dao.add_update_item(
- table_name, set_args, where_args)
- self.pub_dao.add_update_item(
- table_name, set_args, where_args)
+ db_update = db_updates.pop(0)
+ self.pri_dao.add_update_item(table_name, db_update)
+ self.pub_dao.add_update_item(table_name, db_update)
# Previously, we used a separate thread for database writes. This has
# now been removed. For the private database, there is no real
@@ -422,7 +445,7 @@ def put_xtriggers(self, sat_xtrig):
"signature": sig,
"results": json.dumps(res)})
- def put_update_task_state(self, itask):
+ def put_update_task_state(self, itask: 'TaskProxy') -> None:
"""Update task_states table for current state of itask.
NOTE the task_states table is normally updated along with the task pool
@@ -443,9 +466,9 @@ def put_update_task_state(self, itask):
"name": itask.tdef.name,
"flow_nums": serialise_set(itask.flow_nums),
}
- self.db_updates_map.setdefault(self.TABLE_TASK_STATES, [])
self.db_updates_map[self.TABLE_TASK_STATES].append(
- (set_args, where_args))
+ (set_args, where_args)
+ )
def put_update_task_flow_wait(self, itask):
"""Update flow_wait status of a task, in the task_states table.
@@ -463,7 +486,6 @@ def put_update_task_flow_wait(self, itask):
"name": itask.tdef.name,
"flow_nums": serialise_set(itask.flow_nums),
}
- self.db_updates_map.setdefault(self.TABLE_TASK_STATES, [])
self.db_updates_map[self.TABLE_TASK_STATES].append(
(set_args, where_args))
@@ -491,7 +513,6 @@ def put_task_pool(self, pool: 'TaskPool') -> None:
prereq.items()
):
self.put_insert_task_prerequisites(itask, {
- "flow_nums": serialise_set(itask.flow_nums),
"prereq_name": p_name,
"prereq_cycle": p_cycle,
"prereq_output": p_output,
@@ -547,7 +568,6 @@ def put_task_pool(self, pool: 'TaskPool') -> None:
"name": itask.tdef.name,
"flow_nums": serialise_set(itask.flow_nums)
}
- self.db_updates_map.setdefault(self.TABLE_TASK_STATES, [])
self.db_updates_map[self.TABLE_TASK_STATES].append(
(set_args, where_args)
)
@@ -581,12 +601,25 @@ def put_insert_task_jobs(self, itask, args):
"""Put INSERT statement for task_jobs table."""
self._put_insert_task_x(CylcWorkflowDAO.TABLE_TASK_JOBS, itask, args)
- def put_insert_task_states(self, itask, args):
+ def put_insert_task_states(self, itask: 'TaskProxy') -> None:
"""Put INSERT statement for task_states table."""
- self._put_insert_task_x(CylcWorkflowDAO.TABLE_TASK_STATES, itask, args)
+ now = get_current_time_string()
+ self._put_insert_task_x(
+ CylcWorkflowDAO.TABLE_TASK_STATES,
+ itask,
+ {
+ "time_created": now,
+ "time_updated": now,
+ "status": itask.state.status,
+ "flow_nums": serialise_set(itask.flow_nums),
+ "flow_wait": itask.flow_wait,
+ "is_manual_submit": itask.is_manual_submit,
+ },
+ )
def put_insert_task_prerequisites(self, itask, args):
"""Put INSERT statement for task_prerequisites table."""
+ args.setdefault("flow_nums", serialise_set(itask.flow_nums))
self._put_insert_task_x(self.TABLE_TASK_PREREQUISITES, itask, args)
def put_insert_task_outputs(self, itask):
@@ -623,15 +656,16 @@ def put_insert_workflow_flows(self, flow_num, flow_metadata):
}
)
- def _put_insert_task_x(self, table_name, itask, args):
+ def _put_insert_task_x(
+ self, table_name: str, itask: 'TaskProxy', args: 'DbArgDict'
+ ) -> None:
"""Put INSERT statement for a task_* table."""
args.update({
"name": itask.tdef.name,
- "cycle": str(itask.point)})
- if "submit_num" not in args:
- args["submit_num"] = itask.submit_num
- self.db_inserts_map.setdefault(table_name, [])
- self.db_inserts_map[table_name].append(args)
+ "cycle": str(itask.point),
+ })
+ args.setdefault("submit_num", itask.submit_num)
+ self.db_inserts_map.setdefault(table_name, []).append(args)
def put_update_task_jobs(self, itask, set_args):
"""Put UPDATE statement for task_jobs table."""
@@ -650,22 +684,107 @@ def put_update_task_outputs(self, itask: 'TaskProxy') -> None:
"name": itask.tdef.name,
"flow_nums": serialise_set(itask.flow_nums),
}
- self.db_updates_map.setdefault(self.TABLE_TASK_OUTPUTS, []).append(
+ self.db_updates_map[self.TABLE_TASK_OUTPUTS].append(
(set_args, where_args)
)
- def _put_update_task_x(self, table_name, itask, set_args):
+ def _put_update_task_x(
+ self, table_name: str, itask: 'TaskProxy', set_args: 'DbArgDict'
+ ) -> None:
"""Put UPDATE statement for a task_* table."""
where_args = {
"cycle": str(itask.point),
- "name": itask.tdef.name}
+ "name": itask.tdef.name,
+ }
if "submit_num" not in set_args:
where_args["submit_num"] = itask.submit_num
if "flow_nums" not in set_args:
where_args["flow_nums"] = serialise_set(itask.flow_nums)
- self.db_updates_map.setdefault(table_name, [])
self.db_updates_map[table_name].append((set_args, where_args))
+ def remove_task_from_flows(
+ self, point: str, name: str, flow_nums: 'FlowNums'
+ ) -> 'FlowNums':
+ """Remove flow numbers for a task in the task_states and task_outputs
+ tables.
+
+ Args:
+ point: Cycle point of the task.
+ name: Name of the task.
+ flow_nums: Flow numbers to remove. If empty, remove all
+ flow numbers.
+
+ Returns the flow numbers that were removed, if any.
+
+ N.B. the task_prerequisites table is automatically updated separately
+ during the main loop.
+ """
+ removed_flow_nums: FlowNums = set()
+ for table in (
+ self.TABLE_TASK_STATES,
+ self.TABLE_TASK_OUTPUTS,
+ ):
+ fnums_select_stmt = rf'''
+ SELECT
+ flow_nums
+ FROM
+ {table}
+ WHERE
+ cycle = ?
+ AND name = ?
+ ''' # nosec B608 (table name is a code constant)
+ fnums_select_cursor = self.pri_dao.connect().execute(
+ fnums_select_stmt, (point, name)
+ )
+
+ if not flow_nums:
+ for db_fnums_str, *_ in fnums_select_cursor:
+ removed_flow_nums.update(deserialise_set(db_fnums_str))
+
+ stmt = rf'''
+ UPDATE OR REPLACE
+ {table}
+ SET
+ flow_nums = ?
+ WHERE
+ cycle = ?
+ AND name = ?
+ ''' # nosec B608 (table name is a code constant)
+ params: List[tuple] = [(serialise_set(), point, name)]
+ else:
+ # Mapping of existing flow nums to what should be left after
+ # removing the specified flow nums:
+ flow_nums_map: Dict[str, FlowNums] = {}
+ for db_fnums_str, *_ in fnums_select_cursor:
+ db_fnums: FlowNums = deserialise_set(db_fnums_str)
+ fnums_to_remove = db_fnums.intersection(flow_nums)
+ if fnums_to_remove:
+ flow_nums_map[db_fnums_str] = db_fnums.difference(
+ flow_nums
+ )
+ removed_flow_nums.update(fnums_to_remove)
+
+ stmt = rf'''
+ UPDATE OR REPLACE
+ {table}
+ SET
+ flow_nums = ?
+ WHERE
+ cycle = ?
+ AND name = ?
+ AND flow_nums = ?
+ ''' # nosec B608 (table name is a code constant)
+ params = [
+ (serialise_set(new), point, name, old)
+ for old, new in flow_nums_map.items()
+ ]
+
+ self.db_updates_map[table].append(
+ (stmt, params)
+ )
+
+ return removed_flow_nums
+
def recover_pub_from_pri(self):
"""Recover public database from private database."""
if self.pub_dao.n_tries >= self.pub_dao.MAX_TRIES:
@@ -690,7 +809,7 @@ def restart_check(self) -> None:
self.process_queued_ops()
@classmethod
- def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> Version:
+ def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> 'Version':
"""Return the version of Cylc this DB was last run with.
Args:
@@ -706,7 +825,7 @@ def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> Version:
{cls.TABLE_WORKFLOW_PARAMS}
WHERE
key == ?
- ''', # nosec (table name is a code constant)
+ ''', # nosec B608 (table name is a code constant)
[cls.KEY_CYLC_VERSION]
).fetchone()[0]
except (TypeError, OperationalError) as exc:
@@ -781,7 +900,7 @@ def upgrade(cls, db_file: Union['Path', str]) -> None:
cls.upgrade_pre_810(pri_dao)
@classmethod
- def check_db_compatibility(cls, db_file: Union['Path', str]) -> Version:
+ def check_db_compatibility(cls, db_file: Union['Path', str]) -> 'Version':
"""Check this DB is compatible with this Cylc version.
Raises:
diff --git a/tests/conftest.py b/tests/conftest.py
index d07f788ba0b..8e6b988d5c2 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -99,31 +99,30 @@ def _inner(cached=False):
@pytest.fixture
-def log_filter():
- """Filter caplog record_tuples.
+def log_filter(caplog: pytest.LogCaptureFixture):
+ """Filter caplog record_tuples (also discarding the log name entry).
Args:
- log: The caplog instance.
- name: Filter out records if they don't match this logger name.
level: Filter out records if they aren't at this logging level.
contains: Filter out records if this string is not in the message.
regex: Filter out records if the message doesn't match this regex.
exact_match: Filter out records if the message does not exactly match
this string.
+ log: A caplog instance.
"""
def _log_filter(
- log: pytest.LogCaptureFixture,
- name: Optional[str] = None,
level: Optional[int] = None,
contains: Optional[str] = None,
regex: Optional[str] = None,
exact_match: Optional[str] = None,
- ) -> List[Tuple[str, int, str]]:
+ log: Optional[pytest.LogCaptureFixture] = None
+ ) -> List[Tuple[int, str]]:
+ if log is None:
+ log = caplog
return [
- (log_name, log_level, log_message)
- for log_name, log_level, log_message in log.record_tuples
- if (name is None or name == log_name)
- and (level is None or level == log_level)
+ (log_level, log_message)
+ for _, log_level, log_message in log.record_tuples
+ if (level is None or level == log_level)
and (contains is None or contains in log_message)
and (regex is None or re.search(regex, log_message))
and (exact_match is None or exact_match == log_message)
diff --git a/tests/functional/cylc-remove/00-simple/flow.cylc b/tests/functional/cylc-remove/00-simple/flow.cylc
index 84c740ad421..0ee53fafed6 100644
--- a/tests/functional/cylc-remove/00-simple/flow.cylc
+++ b/tests/functional/cylc-remove/00-simple/flow.cylc
@@ -1,13 +1,15 @@
# Abort on stall timeout unless we remove unhandled failed and waiting task.
[scheduler]
[[events]]
- stall timeout = PT20S
+ stall timeout = PT30S
abort on stall timeout = True
expected task failures = 1/b
[scheduling]
[[graph]]
- R1 = """a => b => c
- cleaner"""
+ R1 = """
+ a => b => c
+ cleaner
+ """
[runtime]
[[a,c]]
script = true
@@ -15,10 +17,10 @@
script = false
[[cleaner]]
script = """
-cylc__job__poll_grep_workflow_log -E '1/b/01:running.* \(received\)failed'
-# Remove the unhandled failed task
-cylc remove "$CYLC_WORKFLOW_ID//1/b"
-# Remove waiting 1/c
-# (not auto-removed because parent 1/b, an unhandled fail, is not finished.)
-cylc remove "$CYLC_WORKFLOW_ID//1/c:waiting"
-"""
+ cylc__job__poll_grep_workflow_log -E '1/b/01:running.* \(received\)failed'
+ # Remove the unhandled failed task
+ cylc remove "$CYLC_WORKFLOW_ID//1/b"
+ # Remove waiting 1/c
+ # (not auto-removed because parent 1/b, an unhandled fail, is not finished.)
+ cylc remove "$CYLC_WORKFLOW_ID//1/c:waiting"
+ """
diff --git a/tests/functional/cylc-remove/02-cycling/flow.cylc b/tests/functional/cylc-remove/02-cycling/flow.cylc
index 3b6c1051493..249f0676cc4 100644
--- a/tests/functional/cylc-remove/02-cycling/flow.cylc
+++ b/tests/functional/cylc-remove/02-cycling/flow.cylc
@@ -28,18 +28,6 @@
[[foo, waz]]
script = true
[[bar]]
- script = """
- if [[ $CYLC_TASK_CYCLE_POINT == 2020 ]]; then
- false
- else
- true
- fi
- """
+ script = [[ $CYLC_TASK_CYCLE_POINT != 2020 ]]
[[baz]]
- script = """
- if [[ $CYLC_TASK_CYCLE_POINT == 2021 ]]; then
- false
- else
- true
- fi
- """
+ script = [[ $CYLC_TASK_CYCLE_POINT != 2021 ]]
diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py
index bce6ea64e9f..fe4b19ab92b 100644
--- a/tests/integration/conftest.py
+++ b/tests/integration/conftest.py
@@ -18,21 +18,34 @@
import asyncio
from functools import partial
from pathlib import Path
-import pytest
+import re
from shutil import rmtree
from time import time
-from typing import List, TYPE_CHECKING, Set, Tuple, Union
+from typing import (
+ TYPE_CHECKING,
+ List,
+ Set,
+ Tuple,
+ Union,
+)
+
+import pytest
from cylc.flow.config import WorkflowConfig
from cylc.flow.id import Tokens
+from cylc.flow.network.client import WorkflowRuntimeClient
from cylc.flow.option_parsers import Options
from cylc.flow.pathutil import get_cylc_run_dir
from cylc.flow.rundb import CylcWorkflowDAO
-from cylc.flow.scripts.validate import ValidateOptions
from cylc.flow.scripts.install import (
+ get_option_parser as install_gop,
install as cylc_install,
- get_option_parser as install_gop
)
+from cylc.flow.scripts.show import (
+ ShowOptions,
+ prereqs_and_outputs_query,
+)
+from cylc.flow.scripts.validate import ValidateOptions
from cylc.flow.util import serialise_set
from cylc.flow.wallclock import get_current_time_string
from cylc.flow.workflow_files import infer_latest_run_from_id
@@ -41,14 +54,14 @@
from .utils import _rm_if_empty
from .utils.flow_tools import (
_make_flow,
- _make_src_flow,
_make_scheduler,
+ _make_src_flow,
_run_flow,
_start_flow,
)
+
if TYPE_CHECKING:
- from cylc.flow.network.client import WorkflowRuntimeClient
from cylc.flow.scheduler import Scheduler
from cylc.flow.task_proxy import TaskProxy
@@ -116,7 +129,11 @@ def ses_test_dir(request, run_dir):
@pytest.fixture(scope='module')
def mod_test_dir(request, ses_test_dir):
"""The root run dir for test flows in this test module."""
- path = Path(ses_test_dir, request.module.__name__)
+ path = Path(
+ ses_test_dir,
+ # Shorten path by dropping `integration.` prefix:
+ re.sub(r'^integration\.', '', request.module.__name__)
+ )
path.mkdir(exist_ok=True)
yield path
if _pytest_passed(request):
@@ -510,6 +527,10 @@ def reflog():
Note, you'll need to call this on the scheduler *after* you have started
it.
+ N.B. Trigger order is not stable; using a set ensures that tests check
+ trigger logic rather than binding to specific trigger order which could
+ change in the future, breaking the test.
+
Args:
schd:
The scheduler to capture triggering information for.
@@ -588,6 +609,9 @@ async def _complete(
async_timeout (handles shutdown logic more cleanly).
"""
+ if schd.is_paused:
+ raise Exception("Cannot wait for completion of a paused scheduler")
+
start_time = time()
tokens_list: List[Tokens] = []
@@ -622,11 +646,16 @@ def _set_stop(mode=None):
# determine the completion condition
def done():
if wait_tokens:
- return not tokens_list
+ if not tokens_list:
+ return True
+ if not schd.contact_data:
+ raise AssertionError(
+ "Scheduler shut down before tasks completed: " +
+ ", ".join(map(str, tokens_list))
+ )
+ return False
# otherwise wait for the scheduler to shut down
- if not schd.contact_data:
- return True
- return stop_requested
+ return stop_requested or not schd.contact_data
with pytest.MonkeyPatch.context() as mp:
mp.setattr(schd.pool, 'remove_if_complete', _remove_if_complete)
@@ -672,3 +701,23 @@ async def _reftest(
return triggers
return _reftest
+
+
+@pytest.fixture
+def cylc_show():
+ """Fixture that runs `cylc show` on a scheduler, returning JSON object."""
+
+ async def _cylc_show(schd: 'Scheduler', *task_ids: str) -> dict:
+ pclient = WorkflowRuntimeClient(schd.workflow)
+ await schd.update_data_structure()
+ json_filter: dict = {}
+ await prereqs_and_outputs_query(
+ schd.id,
+ [Tokens(id_, relative=True) for id_ in task_ids],
+ pclient,
+ ShowOptions(json=True),
+ json_filter,
+ )
+ return json_filter
+
+ return _cylc_show
diff --git a/tests/integration/events/test_task_events.py b/tests/integration/events/test_task_events.py
index 3ae30c1fe73..81bbfd4316e 100644
--- a/tests/integration/events/test_task_events.py
+++ b/tests/integration/events/test_task_events.py
@@ -52,7 +52,7 @@ async def test_mail_footer_template(
# start the workflow and get it to send an email
ctx = SimpleNamespace(mail_to=None, mail_from=None)
id_keys = [EventKey('none', 'failed', 'failed', Tokens('//1/a'))]
- async with start(mod_one) as one_log:
+ async with start(mod_one):
mod_one.task_events_mgr._process_event_email(mod_one, ctx, id_keys)
# warnings should appear only when the template is invalid
@@ -60,11 +60,9 @@ async def test_mail_footer_template(
# check that template issues are handled correctly
assert bool(log_filter(
- one_log,
contains='Ignoring bad mail footer template',
)) == should_log
assert bool(log_filter(
- one_log,
contains=template,
)) == should_log
diff --git a/tests/integration/events/test_workflow_events.py b/tests/integration/events/test_workflow_events.py
index 6b742264636..4569cbaed8b 100644
--- a/tests/integration/events/test_workflow_events.py
+++ b/tests/integration/events/test_workflow_events.py
@@ -69,11 +69,9 @@ async def test_mail_footer_template(
# check that template issues are handled correctly
assert bool(log_filter(
- one_log,
contains='Ignoring bad mail footer template',
)) == should_log
assert bool(log_filter(
- one_log,
contains=template,
)) == should_log
@@ -114,10 +112,8 @@ async def test_custom_event_handler_template(
# check that template issues are handled correctly
assert bool(log_filter(
- one_log,
contains='bad template',
)) == should_log
assert bool(log_filter(
- one_log,
contains=template,
)) == should_log
diff --git a/tests/integration/main_loop/test_auto_restart.py b/tests/integration/main_loop/test_auto_restart.py
index 20cc4ea81c6..e03c97485a5 100644
--- a/tests/integration/main_loop/test_auto_restart.py
+++ b/tests/integration/main_loop/test_auto_restart.py
@@ -45,6 +45,6 @@ async def test_no_detach(
id_: str = flow(one_conf)
schd: Scheduler = scheduler(id_, paused_start=True, no_detach=True)
with pytest.raises(MainLoopPluginException) as exc:
- async with run(schd) as log:
+ async with run(schd):
await asyncio.sleep(2)
- assert log_filter(log, contains=f"Workflow shutting down - {exc.value}")
+ assert log_filter(contains=f"Workflow shutting down - {exc.value}")
diff --git a/tests/integration/scripts/test_set.py b/tests/integration/scripts/test_set.py
index af368aab0d5..9c2a2da9d3a 100644
--- a/tests/integration/scripts/test_set.py
+++ b/tests/integration/scripts/test_set.py
@@ -149,9 +149,9 @@ async def test_incomplete_detection(
):
"""It should detect and log finished tasks left with incomplete outputs."""
schd = scheduler(flow(one_conf))
- async with start(schd) as log:
+ async with start(schd):
schd.pool.set_prereqs_and_outputs(['1/one'], ['failed'], None, ['1'])
- assert log_filter(log, contains='1/one did not complete')
+ assert log_filter(contains='1/one did not complete')
async def test_pre_all(flow, scheduler, run):
diff --git a/tests/integration/test_config.py b/tests/integration/test_config.py
index c75797e9cbb..e22965afd6f 100644
--- a/tests/integration/test_config.py
+++ b/tests/integration/test_config.py
@@ -151,7 +151,6 @@ def test_validate_param_env_templ(
one_conf,
validate,
env_val,
- caplog,
log_filter,
):
"""It should validate parameter environment templates."""
@@ -166,8 +165,8 @@ def test_validate_param_env_templ(
}
})
validate(id_)
- assert log_filter(caplog, contains='bad parameter environment template')
- assert log_filter(caplog, contains=env_val)
+ assert log_filter(contains='bad parameter environment template')
+ assert log_filter(contains=env_val)
def test_no_graph(flow, validate):
diff --git a/tests/integration/test_data_store_mgr.py b/tests/integration/test_data_store_mgr.py
index 4f45e3c0b0e..93076ae7604 100644
--- a/tests/integration/test_data_store_mgr.py
+++ b/tests/integration/test_data_store_mgr.py
@@ -331,9 +331,8 @@ def test_delta_task_prerequisite(harness):
for itask in schd.pool.get_tasks():
# set prereqs as not-satisfied
for prereq in itask.state.prerequisites:
- prereq._all_satisfied = False
for key in prereq:
- prereq._satisfied[key] = False
+ prereq[key] = False
schd.data_store_mgr.delta_task_prerequisite(itask)
assert not any(p.satisfied for p in get_pb_prereqs(schd))
diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples.py
index dc3495fe39f..a0d15ee7289 100644
--- a/tests/integration/test_examples.py
+++ b/tests/integration/test_examples.py
@@ -23,9 +23,11 @@
import asyncio
import logging
from pathlib import Path
+
import pytest
from cylc.flow import __version__
+from cylc.flow.scheduler import Scheduler
async def test_create_flow(flow, run_dir):
@@ -62,9 +64,9 @@ async def test_logging(flow, scheduler, start, one_conf, log_filter):
# Ensure that the cylc version is logged on startup.
id_ = flow(one_conf)
schd = scheduler(id_)
- async with start(schd) as log:
+ async with start(schd):
# this returns a list of log records containing __version__
- assert log_filter(log, contains=__version__)
+ assert log_filter(contains=__version__)
async def test_scheduler_arguments(flow, scheduler, start, one_conf):
@@ -159,16 +161,12 @@ def killer():
# make sure that this error causes the flow to shutdown
with pytest.raises(MyException):
- async with run(one) as log:
+ async with run(one):
# The `run` fixture's shutdown logic waits for the main loop to run
pass
# make sure the exception was logged
- assert len(log_filter(
- log,
- level=logging.CRITICAL,
- contains='mess'
- )) == 1
+ assert len(log_filter(logging.CRITICAL, contains='mess')) == 1
# make sure the server socket has closed - a good indication of a
# successful clean shutdown
@@ -290,3 +288,11 @@ async def test_reftest(flow, scheduler, reftest):
('1/a', None),
('1/b', ('1/a',)),
}
+
+
+async def test_show(one: Scheduler, start, cylc_show):
+ """Demonstrate the `cylc_show` fixture"""
+ async with start(one):
+ out = await cylc_show(one, '1/one')
+ assert list(out.keys()) == ['1/one']
+ assert out['1/one']['state'] == 'waiting'
diff --git a/tests/integration/test_flow_assignment.py b/tests/integration/test_flow_assignment.py
index 6c0c58a8758..ea729efeb7b 100644
--- a/tests/integration/test_flow_assignment.py
+++ b/tests/integration/test_flow_assignment.py
@@ -27,7 +27,7 @@
FLOW_ALL,
FLOW_NEW,
FLOW_NONE,
- stringify_flow_nums
+ repr_flow_nums
)
from cylc.flow.scheduler import Scheduler
@@ -110,7 +110,7 @@ async def test_flow_assignment(
}
id_ = flow(conf)
schd: Scheduler = scheduler(id_, run_mode='simulation', paused_start=True)
- async with start(schd) as log:
+ async with start(schd):
if command == 'set':
do_command: Callable = functools.partial(
schd.pool.set_prereqs_and_outputs, outputs=['x'], prereqs=[]
@@ -137,10 +137,9 @@ async def test_flow_assignment(
do_command([active_a.identity], flow=[FLOW_NONE])
assert active_a.flow_nums == {1, 2}
assert log_filter(
- log,
contains=(
f'[{active_a}] ignoring \'flow=none\' {command}: '
- f'task already has {stringify_flow_nums(active_a.flow_nums)}'
+ f'task already has {repr_flow_nums(active_a.flow_nums)}'
),
level=logging.ERROR
)
diff --git a/tests/integration/test_queues.py b/tests/integration/test_queues.py
index fc94c4c4a3d..7da83e1a1aa 100644
--- a/tests/integration/test_queues.py
+++ b/tests/integration/test_queues.py
@@ -120,7 +120,7 @@ async def test_queue_held_tasks(
# hold all tasks and resume the workflow
# (nothing should have run yet because the workflow started paused)
- await commands.run_cmd(commands.hold, schd, ['*/*'])
+ await commands.run_cmd(commands.hold(schd, ['*/*']))
schd.resume_workflow()
# release queued tasks
@@ -129,7 +129,7 @@ async def test_queue_held_tasks(
assert len(submitted_tasks) == 0
# un-hold tasks
- await commands.run_cmd(commands.release, schd, ['*/*'])
+ await commands.run_cmd(commands.release(schd, ['*/*']))
# release queued tasks
# (tasks should now be released from the queues)
diff --git a/tests/integration/test_reload.py b/tests/integration/test_reload.py
index 65960ffcdb7..ad96b187722 100644
--- a/tests/integration/test_reload.py
+++ b/tests/integration/test_reload.py
@@ -89,7 +89,7 @@ def change_state(_=0):
change_state()
# reload the workflow
- await commands.run_cmd(commands.reload_workflow, schd)
+ await commands.run_cmd(commands.reload_workflow(schd))
# the task should end in the submitted state
assert foo.state(TASK_STATUS_SUBMITTED)
@@ -127,18 +127,17 @@ async def test_reload_failure(
"""
id_ = flow(one_conf)
schd = scheduler(id_)
- async with start(schd) as log:
+ async with start(schd):
# corrupt the config by removing the scheduling section
two_conf = {**one_conf, 'scheduling': {}}
flow(two_conf, id_=id_)
# reload the workflow
- await commands.run_cmd(commands.reload_workflow, schd)
+ await commands.run_cmd(commands.reload_workflow(schd))
# the reload should have failed but the workflow should still be
# running
assert log_filter(
- log,
contains=(
'Reload failed - WorkflowConfigError:'
' missing [scheduling][[graph]] section'
diff --git a/tests/integration/test_remove.py b/tests/integration/test_remove.py
new file mode 100644
index 00000000000..a2c8b044f8f
--- /dev/null
+++ b/tests/integration/test_remove.py
@@ -0,0 +1,408 @@
+# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
+# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import logging
+from typing import (
+ Dict,
+ List,
+ NamedTuple,
+ Set,
+)
+
+import pytest
+
+from cylc.flow.commands import (
+ force_trigger_tasks,
+ remove_tasks,
+ run_cmd,
+)
+from cylc.flow.cycling.integer import IntegerPoint
+from cylc.flow.flow_mgr import FLOW_ALL
+from cylc.flow.scheduler import Scheduler
+from cylc.flow.task_outputs import TASK_OUTPUT_SUCCEEDED
+from cylc.flow.task_proxy import TaskProxy
+
+
+def get_pool_tasks(schd: Scheduler) -> Set[str]:
+ return {itask.identity for itask in schd.pool.get_tasks()}
+
+
+class CylcShowPrereqs(NamedTuple):
+ prereqs: List[bool]
+ conditions: List[Dict[str, bool]]
+
+
+@pytest.fixture
+async def cylc_show_prereqs(cylc_show):
+ """Fixture that returns the prereq info from `cylc show` in an
+ easy-to-use format."""
+ async def inner(schd: Scheduler, task: str):
+ prerequisites = (await cylc_show(schd, task))[task]['prerequisites']
+ return [
+ (
+ p['satisfied'],
+ {c['taskId']: c['satisfied'] for c in p['conditions']},
+ )
+ for p in prerequisites
+ ]
+
+ return inner
+
+
+@pytest.fixture
+def example_workflow(flow):
+ return flow({
+ 'scheduling': {
+ 'graph': {
+ # Note: test both `&` and separate arrows for combining
+ # dependencies
+ 'R1': '''
+ a1 & a2 => b
+ a3 => b
+ ''',
+ },
+ },
+ })
+
+
+def get_data_store_flow_nums(schd: Scheduler, itask: TaskProxy):
+ _, ds_tproxy = schd.data_store_mgr.store_node_fetcher(itask.tokens)
+ if ds_tproxy:
+ return ds_tproxy.flow_nums
+
+
+async def test_basic(
+ example_workflow, scheduler, start, db_select
+):
+ """Test removing a task from all flows."""
+ schd: Scheduler = scheduler(example_workflow)
+ async with start(schd):
+ a1 = schd.pool._get_task_by_id('1/a1')
+ a3 = schd.pool._get_task_by_id('1/a3')
+ schd.pool.spawn_on_output(a1, TASK_OUTPUT_SUCCEEDED)
+ schd.pool.spawn_on_output(a3, TASK_OUTPUT_SUCCEEDED)
+ await schd.update_data_structure()
+
+ assert a1 in schd.pool.get_tasks()
+ for table in ('task_states', 'task_outputs'):
+ assert db_select(schd, True, table, 'flow_nums', name='a1') == [
+ ('[1]',),
+ ]
+ assert db_select(
+ schd, True, 'task_prerequisites', 'satisfied', prereq_name='a1'
+ ) == [
+ ('satisfied naturally',),
+ ]
+ assert get_data_store_flow_nums(schd, a1) == '[1]'
+
+ await run_cmd(remove_tasks(schd, ['1/a1'], [FLOW_ALL]))
+ await schd.update_data_structure()
+
+ assert a1 not in schd.pool.get_tasks() # removed from pool
+ for table in ('task_states', 'task_outputs'):
+ assert db_select(schd, True, table, 'flow_nums', name='a1') == [
+ ('[]',), # removed from all flows
+ ]
+ assert db_select(
+ schd, True, 'task_prerequisites', 'satisfied', prereq_name='a1'
+ ) == [
+ ('0',), # prereq is now unsatisfied
+ ]
+ assert get_data_store_flow_nums(schd, a1) == '[]'
+
+
+async def test_specific_flow(
+ example_workflow, scheduler, start, db_select
+):
+ """Test removing a task from a specific flow."""
+ schd: Scheduler = scheduler(example_workflow)
+
+ def select_prereqs():
+ return db_select(
+ schd,
+ True,
+ 'task_prerequisites',
+ 'flow_nums',
+ 'satisfied',
+ prereq_name='a1',
+ )
+
+ async with start(schd):
+ a1 = schd.pool._get_task_by_id('1/a1')
+ schd.pool.force_trigger_tasks(['1/a1'], ['1', '2'])
+ schd.pool.spawn_on_output(a1, TASK_OUTPUT_SUCCEEDED)
+ await schd.update_data_structure()
+
+ assert a1 in schd.pool.get_tasks()
+ assert a1.flow_nums == {1, 2}
+ for table in ('task_states', 'task_outputs'):
+ assert sorted(
+ db_select(schd, True, table, 'flow_nums', name='a1')
+ ) == [
+ ('[1, 2]',), # triggered task
+ ('[1]',), # original spawned task
+ ]
+ assert select_prereqs() == [
+ ('[1, 2]', 'satisfied naturally'),
+ ]
+ assert get_data_store_flow_nums(schd, a1) == '[1, 2]'
+
+ await run_cmd(remove_tasks(schd, ['1/a1'], ['1']))
+ await schd.update_data_structure()
+
+ assert a1 in schd.pool.get_tasks() # still in pool
+ assert a1.flow_nums == {2}
+ for table in ('task_states', 'task_outputs'):
+ assert sorted(
+ db_select(schd, True, table, 'flow_nums', name='a1')
+ ) == [
+ ('[2]',),
+ ('[]',),
+ ]
+ assert select_prereqs() == [
+ ('[1, 2]', '0'),
+ ]
+ assert get_data_store_flow_nums(schd, a1) == '[2]'
+
+
+async def test_unset_prereq(example_workflow, scheduler, start):
+ """Test removing a task unsets any prerequisites it satisfied."""
+ schd: Scheduler = scheduler(example_workflow)
+ async with start(schd):
+ for task in ('a1', 'a2', 'a3'):
+ schd.pool.spawn_on_output(
+ schd.pool.get_task(IntegerPoint('1'), task),
+ TASK_OUTPUT_SUCCEEDED,
+ )
+ b = schd.pool.get_task(IntegerPoint('1'), 'b')
+ assert b.prereqs_are_satisfied()
+
+ await run_cmd(remove_tasks(schd, ['1/a1'], [FLOW_ALL]))
+
+ assert not b.prereqs_are_satisfied()
+
+
+async def test_not_unset_prereq(
+ example_workflow, scheduler, start, db_select
+):
+ """Test removing a task does not unset a force-satisfied prerequisite
+ (one that was satisfied by `cylc set --pre`)."""
+ schd: Scheduler = scheduler(example_workflow)
+ async with start(schd):
+ # This set prereq should not be unset by removing a1:
+ schd.pool.set_prereqs_and_outputs(
+ ['1/b'], outputs=[], prereqs=['1/a1'], flow=[FLOW_ALL]
+ )
+ # Whereas the prereq satisfied by this set output *should* be unset
+ # by removing a2:
+ schd.pool.set_prereqs_and_outputs(
+ ['1/a2'], outputs=['succeeded'], prereqs=[], flow=[FLOW_ALL]
+ )
+ await schd.update_data_structure()
+
+ assert sorted(
+ db_select(
+ schd, True, 'task_prerequisites', 'prereq_name', 'satisfied'
+ )
+ ) == [
+ ('a1', 'force satisfied'),
+ ('a2', 'satisfied naturally'),
+ ('a3', '0'),
+ ]
+
+ await run_cmd(remove_tasks(schd, ['1/a1', '1/a2'], [FLOW_ALL]))
+ await schd.update_data_structure()
+
+ assert sorted(
+ db_select(
+ schd, True, 'task_prerequisites', 'prereq_name', 'satisfied'
+ )
+ ) == [
+ ('a1', 'force satisfied'),
+ ('a2', '0'),
+ ('a3', '0'),
+ ]
+
+
+async def test_logging(flow, scheduler, start, log_filter):
+ """Test logging of a mixture of valid and invalid task removals."""
+ schd: Scheduler = scheduler(
+ flow({
+ 'scheduler': {
+ 'cycle point format': 'CCYY',
+ },
+ 'scheduling': {
+ 'initial cycle point': '2000',
+ 'graph': {
+ 'R3//P1Y': 'b[-P1Y] => a & b',
+ },
+ },
+ })
+ )
+ tasks_to_remove = [
+ # Active, removable tasks:
+ '2000/*',
+ # Future, non-removable tasks:
+ '2001/a', '2001/b',
+ # Glob that doesn't match any active tasks:
+ '2002/*',
+ # Invalid tasks:
+ '2005/a', '2000/doh',
+ ]
+ async with start(schd):
+ await run_cmd(remove_tasks(schd, tasks_to_remove, [FLOW_ALL]))
+
+ assert log_filter(
+ logging.INFO, "Removed task(s): 2000/a (flows=1), 2000/b (flows=1)"
+ )
+
+ assert log_filter(logging.WARNING, "Task(s) not removable: 2001/a, 2001/b")
+ assert log_filter(logging.WARNING, "No active tasks matching: 2002/*")
+ assert log_filter(logging.WARNING, "Invalid cycle point for task: a, 2005")
+ assert log_filter(logging.WARNING, "No matching tasks found: doh")
+
+
+async def test_logging_flow_nums(
+ example_workflow, scheduler, start, log_filter
+):
+ """Test logging of task removals involving flow numbers."""
+ schd: Scheduler = scheduler(example_workflow)
+ async with start(schd):
+ schd.pool.force_trigger_tasks(['1/a1'], ['1', '2'])
+ # Removing from flow that doesn't exist doesn't work:
+ await run_cmd(remove_tasks(schd, ['1/a1'], ['3']))
+ assert log_filter(
+ logging.WARNING, "Task(s) not removable: 1/a1 (flows=3)"
+ )
+
+ # But if a valid flow is included, it will be removed from that flow:
+ await run_cmd(remove_tasks(schd, ['1/a1'], ['2', '3']))
+ assert log_filter(logging.INFO, "Removed task(s): 1/a1 (flows=2)")
+ assert schd.pool._get_task_by_id('1/a1').flow_nums == {1}
+
+
+async def test_retrigger(flow, scheduler, run, reflog, complete):
+ """Test prereqs & re-run behaviour when removing tasks."""
+ schd: Scheduler = scheduler(
+ flow('a => b => c'),
+ paused_start=False,
+ )
+ async with run(schd):
+ reflog_triggers: set = reflog(schd)
+ await complete(schd, '1/b')
+
+ await run_cmd(remove_tasks(schd, ['1/a', '1/b'], [FLOW_ALL]))
+ schd.process_workflow_db_queue()
+ # Removing 1/b should un-queue 1/c:
+ assert len(schd.pool.task_queue_mgr.queues['default'].deque) == 0
+
+ assert reflog_triggers == {
+ ('1/a', None),
+ ('1/b', ('1/a',)),
+ }
+ reflog_triggers.clear()
+
+ await run_cmd(force_trigger_tasks(schd, ['1/a'], []))
+ await complete(schd)
+
+ assert reflog_triggers == {
+ ('1/a', None),
+ # 1/b should have run again after 1/a on the re-trigger in flow 1:
+ ('1/b', ('1/a',)),
+ ('1/c', ('1/b',)),
+ }
+
+
+async def test_prereqs(
+ flow, scheduler, run, complete, cylc_show_prereqs, log_filter
+):
+ """Test prereqs & stall behaviour when removing tasks."""
+ schd: Scheduler = scheduler(
+ flow('(a1 | a2) & b => x'),
+ paused_start=False,
+ )
+ async with run(schd):
+ await complete(schd, '1/a1', '1/a2', '1/b')
+
+ await run_cmd(remove_tasks(schd, ['1/a1'], [FLOW_ALL]))
+ assert not schd.pool.is_stalled()
+ assert len(schd.pool.task_queue_mgr.queues['default'].deque)
+ # `cylc show` should reflect the now-unsatisfied condition:
+ assert await cylc_show_prereqs(schd, '1/x') == [
+ (True, {'1/a1': False, '1/a2': True, '1/b': True})
+ ]
+
+ await run_cmd(remove_tasks(schd, ['1/b'], [FLOW_ALL]))
+ # Should cause stall now because 1/c prereq is unsatisfied:
+ assert len(schd.pool.task_queue_mgr.queues['default'].deque) == 0
+ assert schd.pool.is_stalled()
+ assert log_filter(
+ logging.WARNING,
+ "1/x is waiting on ['1/a1:succeeded', '1/b:succeeded']",
+ )
+ assert await cylc_show_prereqs(schd, '1/x') == [
+ (False, {'1/a1': False, '1/a2': True, '1/b': False})
+ ]
+
+ assert schd.pool._get_task_by_id('1/x')
+ await run_cmd(remove_tasks(schd, ['1/a2'], [FLOW_ALL]))
+ # Should cause 1/x to be removed from the pool as it no longer has
+ # any satisfied prerequisite tasks:
+ assert not schd.pool._get_task_by_id('1/x')
+
+
+async def test_downstream_preparing(flow, scheduler, start):
+ """Downstream dependents should not be removed if they are already
+ preparing."""
+ schd: Scheduler = scheduler(
+ flow('''
+ a => x
+ a => y
+ '''),
+ )
+ async with start(schd):
+ a = schd.pool._get_task_by_id('1/a')
+ schd.pool.spawn_on_output(a, TASK_OUTPUT_SUCCEEDED)
+ assert get_pool_tasks(schd) == {'1/a', '1/x', '1/y'}
+
+ schd.pool._get_task_by_id('1/y').state_reset('preparing')
+ await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL]))
+ assert get_pool_tasks(schd) == {'1/y'}
+
+
+async def test_suicide(flow, scheduler, run, reflog, complete):
+ """Test that suicide prereqs are unset by `cylc remove`."""
+ schd: Scheduler = scheduler(
+ flow('''
+ a => b => c => d => x
+ a & c => !x
+ '''),
+ paused_start=False,
+ )
+ async with run(schd):
+ reflog_triggers: set = reflog(schd)
+ await complete(schd, '1/b')
+ await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL]))
+ await complete(schd)
+
+ assert reflog_triggers == {
+ ('1/a', None),
+ ('1/b', ('1/a',)),
+ ('1/c', ('1/b',)),
+ ('1/d', ('1/c',)),
+ # 1/x not suicided as 1/a was removed:
+ ('1/x', ('1/d',)),
+ }
diff --git a/tests/integration/test_resolvers.py b/tests/integration/test_resolvers.py
index 981237a4d2a..4fa21dadbb9 100644
--- a/tests/integration/test_resolvers.py
+++ b/tests/integration/test_resolvers.py
@@ -234,7 +234,7 @@ async def test_command_logging(mock_flow, caplog, log_filter):
{'mode': StopMode.REQUEST_CLEAN.value},
meta,
)
- assert log_filter(caplog, contains='Command "stop" received')
+ assert log_filter(contains='Command "stop" received')
# put_messages: only log for owner
kwargs = {
@@ -244,12 +244,11 @@ async def test_command_logging(mock_flow, caplog, log_filter):
}
meta["auth_user"] = mock_flow.owner
await mock_flow.resolvers._mutation_mapper("put_messages", kwargs, meta)
- assert not log_filter(caplog, contains='Command "put_messages" received:')
+ assert not log_filter(contains='Command "put_messages" received:')
meta["auth_user"] = "Dr Spock"
await mock_flow.resolvers._mutation_mapper("put_messages", kwargs, meta)
- assert log_filter(
- caplog, contains='Command "put_messages" received from Dr Spock')
+ assert log_filter(contains='Command "put_messages" received from Dr Spock')
async def test_command_validation_failure(
diff --git a/tests/integration/test_scheduler.py b/tests/integration/test_scheduler.py
index 79b7327206a..6f1f581e899 100644
--- a/tests/integration/test_scheduler.py
+++ b/tests/integration/test_scheduler.py
@@ -174,7 +174,7 @@ async def test_holding_tasks_whilst_scheduler_paused(
assert submitted_tasks == set()
# hold all tasks & resume the workflow
- await commands.run_cmd(commands.hold, one, ['*/*'])
+ await commands.run_cmd(commands.hold(one, ['*/*']))
one.resume_workflow()
# release queued tasks
@@ -183,7 +183,7 @@ async def test_holding_tasks_whilst_scheduler_paused(
assert submitted_tasks == set()
# release all tasks
- await commands.run_cmd(commands.release, one, ['*/*'])
+ await commands.run_cmd(commands.release(one, ['*/*']))
# release queued tasks
# (the task should be submitted)
@@ -219,12 +219,12 @@ async def test_no_poll_waiting_tasks(
polled_tasks = capture_polling(one)
# Waiting tasks should not be polled.
- await commands.run_cmd(commands.poll_tasks, one, ['*/*'])
+ await commands.run_cmd(commands.poll_tasks(one, ['*/*']))
assert polled_tasks == set()
# Even if they have a submit number.
task.submit_num = 1
- await commands.run_cmd(commands.poll_tasks, one, ['*/*'])
+ await commands.run_cmd(commands.poll_tasks(one, ['*/*']))
assert len(polled_tasks) == 0
# But these states should be:
@@ -235,7 +235,7 @@ async def test_no_poll_waiting_tasks(
TASK_STATUS_RUNNING
]:
task.state.status = state
- await commands.run_cmd(commands.poll_tasks, one, ['*/*'])
+ await commands.run_cmd(commands.poll_tasks(one, ['*/*']))
assert len(polled_tasks) == 1
polled_tasks.clear()
@@ -267,7 +267,7 @@ def raise_ParsecError(*a, **k):
pass
assert log_filter(
- log, level=logging.CRITICAL,
+ logging.CRITICAL,
exact_match="Workflow shutting down - Mock error"
)
assert TRACEBACK_MSG in log.text
@@ -295,7 +295,7 @@ def mock_auto_restart(*a, **k):
async with run(one) as log:
pass
- assert log_filter(log, level=logging.ERROR, contains=err_msg)
+ assert log_filter(logging.ERROR, err_msg)
assert TRACEBACK_MSG in log.text
@@ -361,20 +361,19 @@ async def test_restart_timeout(
# restart the completed workflow
schd = scheduler(id_)
- async with run(schd) as log:
+ async with run(schd):
# it should detect that the workflow has completed and alert the user
assert log_filter(
- log,
contains='This workflow already ran to completion.'
)
# it should activate a timeout
- assert log_filter(log, contains='restart timer starts NOW')
+ assert log_filter(contains='restart timer starts NOW')
# when we trigger tasks the timeout should be cleared
schd.pool.force_trigger_tasks(['1/one'], {1})
await asyncio.sleep(0) # yield control to the main loop
- assert log_filter(log, contains='restart timer stopped')
+ assert log_filter(contains='restart timer stopped')
@pytest.mark.parametrize("signal", ((SIGHUP), (SIGINT), (SIGTERM)))
@@ -387,14 +386,14 @@ async def test_signal_escallation(one, start, signal, log_filter):
See https://github.com/cylc/cylc-flow/pull/6444
"""
- async with start(one) as log:
+ async with start(one):
# put the workflow in the stopping state
one._set_stop(StopMode.REQUEST_CLEAN)
assert one.stop_mode.name == 'REQUEST_CLEAN'
# one signal should escalate this from CLEAN to NOW
one._handle_signal(signal, None)
- assert log_filter(log, contains=signal.name)
+ assert log_filter(contains=signal.name)
assert one.stop_mode.name == 'REQUEST_NOW'
# two signals should escalate this from NOW to NOW_NOW
diff --git a/tests/integration/test_simulation.py b/tests/integration/test_simulation.py
index 49bc76ce5e2..b50acbb084d 100644
--- a/tests/integration/test_simulation.py
+++ b/tests/integration/test_simulation.py
@@ -344,7 +344,7 @@ async def test_settings_reload(
conf_file.read_text().replace('False', 'True'))
# Reload Workflow:
- await commands.run_cmd(commands.reload_workflow, schd)
+ await commands.run_cmd(commands.reload_workflow(schd))
# Submit second psuedo-job and "run" to success:
itask = run_simjob(schd, one_1066.point, 'one')
diff --git a/tests/integration/test_stop_after_cycle_point.py b/tests/integration/test_stop_after_cycle_point.py
index f92e8d449f0..90bab288515 100644
--- a/tests/integration/test_stop_after_cycle_point.py
+++ b/tests/integration/test_stop_after_cycle_point.py
@@ -119,10 +119,11 @@ def get_db_value(schd) -> Optional[str]:
# override this value whilst the workflow is running
await commands.run_cmd(
- commands.stop,
- schd,
- cycle_point=IntegerPoint('4'),
- mode=StopMode.REQUEST_CLEAN,
+ commands.stop(
+ schd,
+ cycle_point=IntegerPoint('4'),
+ mode=StopMode.REQUEST_CLEAN,
+ )
)
assert schd.config.stop_point == IntegerPoint('4')
diff --git a/tests/integration/test_subprocctx.py b/tests/integration/test_subprocctx.py
index c4d0c4ca28a..871106dcd53 100644
--- a/tests/integration/test_subprocctx.py
+++ b/tests/integration/test_subprocctx.py
@@ -47,7 +47,7 @@ def myxtrigger():
return True, {}
"""))
schd = scheduler(id_)
- async with start(schd, level=DEBUG) as log:
+ async with start(schd, level=DEBUG):
# Set off check for x-trigger:
task = schd.pool.get_tasks()[0]
schd.xtrigger_mgr.call_xtriggers_async(task)
@@ -59,4 +59,4 @@ def myxtrigger():
# Assert that both stderr and out from the print statement
# in our xtrigger appear in the log.
for expected in ['Hello World', 'Hello Hades']:
- assert log_filter(log, contains=expected, level=DEBUG)
+ assert log_filter(DEBUG, expected)
diff --git a/tests/integration/test_task_job_mgr.py b/tests/integration/test_task_job_mgr.py
index 48a49eb30aa..b1cf1347071 100644
--- a/tests/integration/test_task_job_mgr.py
+++ b/tests/integration/test_task_job_mgr.py
@@ -23,7 +23,6 @@
from cylc.flow.task_state import TASK_STATUS_RUNNING
-
async def test_run_job_cmd_no_hosts_error(
flow,
scheduler,
@@ -92,7 +91,6 @@ async def test_run_job_cmd_no_hosts_error(
# ...but the failure should be logged
assert log_filter(
- log,
contains='No available hosts for no-host-platform',
)
log.clear()
@@ -105,7 +103,6 @@ async def test_run_job_cmd_no_hosts_error(
# ...but the failure should be logged
assert log_filter(
- log,
contains='No available hosts for no-host-platform',
)
@@ -217,7 +214,7 @@ async def test_broadcast_platform_change(
schd = scheduler(id_, run_mode='live')
- async with start(schd) as log:
+ async with start(schd):
# Change the task platform with broadcast:
schd.broadcast_mgr.put_broadcast(
['1'], ['mytask'], [{'platform': 'foo'}])
@@ -235,4 +232,4 @@ async def test_broadcast_platform_change(
# Check that task platform hasn't become "localhost":
assert schd.pool.get_tasks()[0].platform['name'] == 'foo'
# ... and that remote init failed because all hosts bad:
- assert log_filter(log, contains="(no hosts were reachable)")
+ assert log_filter(regex=r"platform: foo .*\(no hosts were reachable\)")
diff --git a/tests/integration/test_task_pool.py b/tests/integration/test_task_pool.py
index f2a0e650dcd..a524e0b5a63 100644
--- a/tests/integration/test_task_pool.py
+++ b/tests/integration/test_task_pool.py
@@ -63,9 +63,6 @@
# immediately too, because we spawn autospawn absolute-triggered tasks as
# well as parentless tasks. 3/asd does not spawn at start, however.
EXAMPLE_FLOW_CFG = {
- 'scheduler': {
- 'allow implicit tasks': True
- },
'scheduling': {
'cycling mode': 'integer',
'initial cycle point': 1,
@@ -86,7 +83,6 @@
EXAMPLE_FLOW_2_CFG = {
'scheduler': {
- 'allow implicit tasks': True,
'UTC mode': True
},
'scheduling': {
@@ -142,7 +138,7 @@ def assert_expected_log(
@pytest.fixture(scope='module')
async def mod_example_flow(
mod_flow: Callable, mod_scheduler: Callable, mod_run: Callable
-) -> 'Scheduler':
+) -> AsyncGenerator['Scheduler', None]:
"""Return a scheduler for interrogating its task pool.
This is module-scoped so faster than example_flow, but should only be used
@@ -178,7 +174,7 @@ async def example_flow(
@pytest.fixture(scope='module')
async def mod_example_flow_2(
mod_flow: Callable, mod_scheduler: Callable, mod_run: Callable
-) -> 'Scheduler':
+) -> AsyncGenerator['Scheduler', None]:
"""Return a scheduler for interrogating its task pool.
This is module-scoped so faster than example_flow, but should only be used
@@ -570,7 +566,7 @@ async def test_reload_stopcp(
schd: 'Scheduler' = scheduler(flow(cfg))
async with start(schd):
assert str(schd.pool.stop_point) == '2020'
- await commands.run_cmd(commands.reload_workflow, schd)
+ await commands.run_cmd(commands.reload_workflow(schd))
assert str(schd.pool.stop_point) == '2020'
@@ -841,7 +837,7 @@ async def test_reload_prereqs(
flow(conf, id_=id_)
# Reload the workflow config
- await commands.run_cmd(commands.reload_workflow, schd)
+ await commands.run_cmd(commands.reload_workflow(schd))
assert list_tasks(schd) == expected_3
# Check resulting dependencies of task z
@@ -973,7 +969,7 @@ async def test_graph_change_prereq_satisfaction(
flow(conf, id_=id_)
# Reload the workflow config
- await commands.run_cmd(commands.reload_workflow, schd)
+ await commands.run_cmd(commands.reload_workflow(schd))
await test.asend(schd)
@@ -1183,9 +1179,6 @@ async def test_detect_incomplete_tasks(
TASK_STATUS_SUBMIT_FAILED: TaskEventsManager.EVENT_SUBMIT_FAILED
}
id_ = flow({
- 'scheduler': {
- 'allow implicit tasks': 'True',
- },
'scheduling': {
'graph': {
# a workflow with one task for each of the final task states
@@ -1206,7 +1199,6 @@ async def test_detect_incomplete_tasks(
# ensure that it is correctly identified as incomplete
assert not itask.state.outputs.is_complete()
assert log_filter(
- log,
contains=(
f"[{itask}] did not complete the required outputs:"
),
@@ -1228,9 +1220,6 @@ async def test_future_trigger_final_point(
"""
id_ = flow(
{
- 'scheduler': {
- 'allow implicit tasks': 'True',
- },
'scheduling': {
'cycling mode': 'integer',
'initial cycle point': 1,
@@ -1246,7 +1235,6 @@ async def test_future_trigger_final_point(
for itask in schd.pool.get_tasks():
schd.pool.spawn_on_output(itask, "succeeded")
assert log_filter(
- log,
regex=(
".*1/baz.*not spawned: a prerequisite is beyond"
r" the workflow stop point \(1\)"
@@ -1271,17 +1259,17 @@ async def test_set_failed_complete(
schd.pool.task_events_mgr.process_message(one, 1, "failed")
assert log_filter(
- log, regex="1/one.* setting implied output: submitted")
+ regex="1/one.* setting implied output: submitted")
assert log_filter(
- log, regex="1/one.* setting implied output: started")
+ regex="1/one.* setting implied output: started")
assert log_filter(
- log, regex="failed.* did not complete the required outputs")
+ regex="failed.* did not complete the required outputs")
# Set failed task complete via default "set" args.
schd.pool.set_prereqs_and_outputs([one.identity], None, None, ['all'])
assert log_filter(
- log, contains=f'[{one}] removed from active task pool: completed')
+ contains=f'[{one}] removed from active task pool: completed')
db_outputs = db_select(
schd, True, 'task_outputs', 'outputs',
@@ -1305,9 +1293,6 @@ async def test_set_prereqs(
"""
id_ = flow(
{
- 'scheduler': {
- 'allow implicit tasks': 'True',
- },
'scheduling': {
'initial cycle point': '2040',
'graph': {
@@ -1339,7 +1324,8 @@ async def test_set_prereqs(
schd.pool.set_prereqs_and_outputs(
["20400101T0000Z/qux"], None, ["20400101T0000Z/foo:a"], ['all'])
assert log_filter(
- log, contains='20400101T0000Z/qux does not depend on "20400101T0000Z/foo:a"')
+ contains='20400101T0000Z/qux does not depend on "20400101T0000Z/foo:a"'
+ )
# it should not add 20400101T0000Z/qux to the pool
assert (
@@ -1390,7 +1376,6 @@ async def test_set_bad_prereqs(
"""
id_ = flow({
'scheduler': {
- 'allow implicit tasks': 'True',
'cycle point format': '%Y'},
'scheduling': {
'initial cycle point': '2040',
@@ -1406,11 +1391,11 @@ def set_prereqs(prereqs):
async with start(schd) as log:
# Invalid: task name wildcard:
set_prereqs(["2040/*"])
- assert log_filter(log, contains='Invalid prerequisite task name')
+ assert log_filter(contains='Invalid prerequisite task name')
# Invalid: cycle point wildcard.
set_prereqs(["*/foo"])
- assert log_filter(log, contains='Invalid prerequisite cycle point')
+ assert log_filter(contains='Invalid prerequisite cycle point')
async def test_set_outputs_live(
@@ -1424,9 +1409,6 @@ async def test_set_outputs_live(
"""
id_ = flow(
{
- 'scheduler': {
- 'allow implicit tasks': 'True',
- },
'scheduling': {
'graph': {
'R1': """
@@ -1476,15 +1458,12 @@ async def test_set_outputs_live(
)
# it should complete implied outputs (submitted, started) too
- assert log_filter(
- log, contains="setting implied output: submitted")
- assert log_filter(
- log, contains="setting implied output: started")
+ assert log_filter(contains="setting implied output: submitted")
+ assert log_filter(contains="setting implied output: started")
# set foo (default: all required outputs) to complete y.
schd.pool.set_prereqs_and_outputs(["1/foo"], None, None, ['all'])
- assert log_filter(
- log, contains="output 1/foo:succeeded completed")
+ assert log_filter(contains="output 1/foo:succeeded completed")
assert (
pool_get_task_ids(schd.pool) == ["1/bar", "1/baz"]
)
@@ -1501,7 +1480,6 @@ async def test_set_outputs_live2(
"""
id_ = flow(
{
- 'scheduler': {'allow implicit tasks': 'True'},
'scheduling': {'graph': {
'R1': """
foo:a => apple
@@ -1517,7 +1495,6 @@ async def test_set_outputs_live2(
async with start(schd) as log:
schd.pool.set_prereqs_and_outputs(["1/foo"], None, None, ['all'])
assert not log_filter(
- log,
contains="did not complete required outputs: ['a', 'b']"
)
@@ -1533,9 +1510,6 @@ async def test_set_outputs_future(
"""
id_ = flow(
{
- 'scheduler': {
- 'allow implicit tasks': 'True',
- },
'scheduling': {
'graph': {
'R1': "a:x & a:y => b => c"
@@ -1571,9 +1545,9 @@ async def test_set_outputs_future(
prereqs=None,
flow=['all']
)
- assert log_filter(log, contains="output 1/a:cheese not found")
- assert log_filter(log, contains="completed output x")
- assert log_filter(log, contains="completed output y")
+ assert log_filter(contains="output 1/a:cheese not found")
+ assert log_filter(contains="completed output x")
+ assert log_filter(contains="completed output y")
async def test_prereq_satisfaction(
@@ -1587,9 +1561,6 @@ async def test_prereq_satisfaction(
"""
id_ = flow(
{
- 'scheduler': {
- 'allow implicit tasks': 'True',
- },
'scheduling': {
'graph': {
'R1': "a:x & a:y => b"
@@ -1605,8 +1576,8 @@ async def test_prereq_satisfaction(
}
}
)
- schd = scheduler(id_)
- async with start(schd) as log:
+ schd: Scheduler = scheduler(id_)
+ async with start(schd):
# it should start up with just 1/a
assert pool_get_task_ids(schd.pool) == ["1/a"]
# spawn b
@@ -1617,21 +1588,19 @@ async def test_prereq_satisfaction(
b = schd.pool.get_task(IntegerPoint("1"), "b")
- assert not b.is_waiting_prereqs_done()
+ assert not b.prereqs_are_satisfied()
# set valid and invalid prerequisites, by label and message.
schd.pool.set_prereqs_and_outputs(
prereqs=["1/a:xylophone", "1/a:y", "1/a:w", "1/a:z"],
items=["1/b"], outputs=None, flow=['all']
)
- assert log_filter(log, contains="1/a:z not found")
- assert log_filter(log, contains="1/a:w not found")
- assert not log_filter(log, contains='1/b does not depend on "1/a:x"')
- assert not log_filter(
- log, contains='1/b does not depend on "1/a:xylophone"')
- assert not log_filter(log, contains='1/b does not depend on "1/a:y"')
+ assert log_filter(contains="1/a:z not found")
+ assert log_filter(contains="1/a:w not found")
+ # FIXME: testing that something is *not* logged is extremely fragile:
+ assert not log_filter(regex='.*does not depend on.*')
- assert b.is_waiting_prereqs_done()
+ assert b.prereqs_are_satisfied()
@pytest.mark.parametrize('compat_mode', ['compat-mode', 'normal-mode'])
@@ -1910,7 +1879,6 @@ async def test_fast_respawn(
async def test_remove_active_task(
example_flow: 'Scheduler',
- caplog: pytest.LogCaptureFixture,
log_filter: Callable,
) -> None:
"""Test warning on removing an active task."""
@@ -1925,7 +1893,6 @@ async def test_remove_active_task(
assert foo not in task_pool.get_tasks()
assert log_filter(
- caplog,
regex=(
"1/foo.*removed from active task pool:"
" request - active job orphaned"
@@ -1947,7 +1914,6 @@ async def test_remove_by_suicide(
* Removing a task manually (cylc remove) should work the same.
"""
id_ = flow({
- 'scheduler': {'allow implicit tasks': 'True'},
'scheduling': {
'graph': {
'R1': '''
@@ -1966,7 +1932,6 @@ async def test_remove_by_suicide(
# mark 1/a as failed and ensure 1/b is removed by suicide trigger
schd.pool.spawn_on_output(a, TASK_OUTPUT_FAILED)
assert log_filter(
- log,
regex="1/b.*removed from active task pool: suicide trigger"
)
assert pool_get_task_ids(schd.pool) == ["1/a"]
@@ -1975,14 +1940,14 @@ async def test_remove_by_suicide(
log.clear()
schd.pool.force_trigger_tasks(['1/b'], ['1'])
assert log_filter(
- log,
regex='1/b.*added to active task pool',
)
# remove 1/b by request (cylc remove)
- await commands.run_cmd(commands.remove_tasks, schd, ['1/b'])
+ await commands.run_cmd(
+ commands.remove_tasks(schd, ['1/b'], [FLOW_ALL])
+ )
assert log_filter(
- log,
regex='1/b.*removed from active task pool: request',
)
@@ -1990,55 +1955,10 @@ async def test_remove_by_suicide(
log.clear()
schd.pool.force_trigger_tasks(['1/b'], ['1'])
assert log_filter(
- log,
regex='1/b.*added to active task pool',
)
-async def test_remove_no_respawn(flow, scheduler, start, log_filter):
- """Ensure that removed tasks stay removed.
-
- If a task is removed by suicide trigger or "cylc remove", then it should
- not be automatically spawned at a later time.
- """
- id_ = flow({
- 'scheduling': {
- 'graph': {
- 'R1': 'a & b => z',
- },
- },
- })
- schd: 'Scheduler' = scheduler(id_)
- async with start(schd, level=logging.DEBUG) as log:
- a1 = schd.pool.get_task(IntegerPoint("1"), "a")
- b1 = schd.pool.get_task(IntegerPoint("1"), "b")
- assert a1, '1/a should have been spawned on startup'
- assert b1, '1/b should have been spawned on startup'
-
- # mark one of the upstream tasks as succeeded, 1/z should spawn
- schd.pool.spawn_on_output(a1, TASK_OUTPUT_SUCCEEDED)
- schd.workflow_db_mgr.process_queued_ops()
- z1 = schd.pool.get_task(IntegerPoint("1"), "z")
- assert z1, '1/z should have been spawned after 1/a succeeded'
-
- # manually remove 1/z, it should be removed from the pool
- await commands.run_cmd(commands.remove_tasks, schd, ['1/z'])
- schd.workflow_db_mgr.process_queued_ops()
- z1 = schd.pool.get_task(IntegerPoint("1"), "z")
- assert z1 is None, '1/z should have been removed (by request)'
-
- # mark the other upstream task as succeeded, 1/z should not be
- # respawned as a result
- schd.pool.spawn_on_output(b1, TASK_OUTPUT_SUCCEEDED)
- assert log_filter(
- log, contains='Not respawning 1/z - task was removed'
- )
- z1 = schd.pool.get_task(IntegerPoint("1"), "z")
- assert (
- z1 is None
- ), '1/z should have stayed removed (but has been added back into the pool'
-
-
async def test_set_future_flow(flow, scheduler, start, log_filter):
"""Manually-set outputs for new flow num must be recorded in the DB.
@@ -2170,7 +2090,7 @@ async def list_data_store():
# reload
flow(config, id_=id_)
- await commands.run_cmd(commands.reload_workflow, schd)
+ await commands.run_cmd(commands.reload_workflow(schd))
# check xtrigs post-reload
assert list_xtrig_mgr() == {
@@ -2182,7 +2102,7 @@ async def list_data_store():
'c': 'wall_clock(trigger_time=946688400)',
}
-
+
async def test_trigger_unqueued(flow, scheduler, start):
"""Test triggering an unqueued active task.
@@ -2247,7 +2167,7 @@ async def test_expire_dequeue_with_retries(flow, scheduler, start, expire_type):
if expire_type == 'clock-expire':
conf['scheduling']['special tasks'] = {'clock-expire': 'foo(PT0S)'}
- method = lambda schd: schd.pool.clock_expire_tasks()
+ method = lambda schd: schd.pool.clock_expire_tasks()
else:
method = lambda schd: schd.pool.set_prereqs_and_outputs(
['2000/foo'], prereqs=[], outputs=['expired'], flow=['1']
diff --git a/tests/integration/test_workflow_db_mgr.py b/tests/integration/test_workflow_db_mgr.py
index 774d8c21fac..b994f88377f 100644
--- a/tests/integration/test_workflow_db_mgr.py
+++ b/tests/integration/test_workflow_db_mgr.py
@@ -34,12 +34,12 @@ async def test(expected_restart_num: int, do_reload: bool = False):
"""(Re)start the workflow and check the restart number is as expected.
"""
schd: 'Scheduler' = scheduler(id_, paused_start=True)
- async with start(schd) as log:
+ async with start(schd):
if do_reload:
- await commands.run_cmd(commands.reload_workflow, schd)
+ await commands.run_cmd(commands.reload_workflow(schd))
assert schd.workflow_db_mgr.n_restart == expected_restart_num
assert log_filter(
- log, contains=f"(re)start number={expected_restart_num + 1}"
+ contains=f"(re)start number={expected_restart_num + 1}"
# (In the log, it's 1 higher than backend value)
)
assert ('n_restart', f'{expected_restart_num}') in db_select(
diff --git a/tests/integration/test_workflow_files.py b/tests/integration/test_workflow_files.py
index 5b036ef2c11..77937ca66dd 100644
--- a/tests/integration/test_workflow_files.py
+++ b/tests/integration/test_workflow_files.py
@@ -151,7 +151,7 @@ def test_detect_old_contact_file_old_run(workflow, caplog, log_filter):
# as a side effect the contact file should have been removed
assert not workflow.contact_file.exists()
- assert log_filter(caplog, contains='Removed contact file')
+ assert log_filter(contains='Removed contact file')
def test_detect_old_contact_file_none(workflow):
@@ -260,11 +260,9 @@ def _unlink(*args):
# check the appropriate messages were logged
assert bool(log_filter(
- caplog,
contains='Removed contact file',
)) is remove_succeeded
assert bool(log_filter(
- caplog,
contains=(
f'Failed to remove contact file for {workflow.id_}:'
'\nmocked-os-error'
diff --git a/tests/integration/tui/conftest.py b/tests/integration/tui/conftest.py
index 55f83c143f3..86d2267da1e 100644
--- a/tests/integration/tui/conftest.py
+++ b/tests/integration/tui/conftest.py
@@ -4,7 +4,7 @@
from pathlib import Path
import re
from time import sleep
-from uuid import uuid1
+from secrets import token_hex
import pytest
from urwid.display import html_fragment
@@ -211,7 +211,7 @@ def wait_until_loaded(self, *ids, retries=20):
)
if exc:
msg += f'\n{exc}'
- self.compare_screenshot(f'fail-{uuid1()}', msg, 1)
+ self.compare_screenshot(f'fail-{token_hex(4)}', msg, 1)
@pytest.fixture
diff --git a/tests/integration/tui/test_mutations.py b/tests/integration/tui/test_mutations.py
index e9a41466d70..b87622bca5f 100644
--- a/tests/integration/tui/test_mutations.py
+++ b/tests/integration/tui/test_mutations.py
@@ -59,7 +59,7 @@ async def test_online_mutation(
id_ = flow(one_conf, name='one')
schd = scheduler(id_)
with rakiura(size='80,15') as rk:
- async with start(schd) as schd_log:
+ async with start(schd):
await schd.update_data_structure()
assert schd.command_queue.empty()
@@ -91,7 +91,7 @@ async def test_online_mutation(
# the mutation should be in the scheduler's command_queue
await asyncio.sleep(0)
- assert log_filter(schd_log, contains="hold(tasks=['1/one'])")
+ assert log_filter(contains="hold(tasks=['1/one'])")
# close the dialogue and re-run the hold mutation
rk.user_input('q', 'q', 'enter')
@@ -127,7 +127,7 @@ def standardise_cli_cmds(monkeypatch):
"""This remove the variable bit of the workflow ID from CLI commands.
The workflow ID changes from run to run. In order to make screenshots
- stable, this
+ stable, this
"""
from cylc.flow.tui.data import extract_context
def _extract_context(selection):
diff --git a/tests/integration/utils/flow_tools.py b/tests/integration/utils/flow_tools.py
index 34b80a25882..7419fd4fe14 100644
--- a/tests/integration/utils/flow_tools.py
+++ b/tests/integration/utils/flow_tools.py
@@ -28,7 +28,7 @@
import logging
import pytest
from typing import Any, Optional, Union
-from uuid import uuid1
+from secrets import token_hex
from cylc.flow import CYLC_LOG
from cylc.flow.workflow_files import WorkflowFiles
@@ -41,7 +41,7 @@
def _make_src_flow(src_path, conf, filename=WorkflowFiles.FLOW_FILE):
"""Construct a workflow on the filesystem"""
- flow_src_dir = (src_path / str(uuid1()))
+ flow_src_dir = (src_path / token_hex(4))
flow_src_dir.mkdir(parents=True, exist_ok=True)
if isinstance(conf, dict):
conf = flow_config_str(conf)
@@ -53,7 +53,7 @@ def _make_src_flow(src_path, conf, filename=WorkflowFiles.FLOW_FILE):
def _make_flow(
cylc_run_dir: Union[Path, str],
test_dir: Path,
- conf: dict,
+ conf: Union[dict, str],
name: Optional[str] = None,
id_: Optional[str] = None,
defaults: Optional[bool] = True,
@@ -62,6 +62,8 @@ def _make_flow(
"""Construct a workflow on the filesystem.
Args:
+ conf: Either a workflow config dictionary, or a graph string to be
+ used as the R1 graph in the workflow config.
defaults: Set up a common defaults.
* [scheduling]allow implicit tasks = true
@@ -71,10 +73,18 @@ def _make_flow(
flow_run_dir = (cylc_run_dir / id_)
else:
if name is None:
- name = str(uuid1())
+ name = token_hex(4)
flow_run_dir = (test_dir / name)
flow_run_dir.mkdir(parents=True, exist_ok=True)
id_ = str(flow_run_dir.relative_to(cylc_run_dir))
+ if isinstance(conf, str):
+ conf = {
+ 'scheduling': {
+ 'graph': {
+ 'R1': conf
+ }
+ }
+ }
if defaults:
# set the default simulation runtime to zero (can be overridden)
(
diff --git a/tests/unit/cycling/test_iso8601.py b/tests/unit/cycling/test_iso8601.py
index ae0eb957f47..b7a134f7882 100644
--- a/tests/unit/cycling/test_iso8601.py
+++ b/tests/unit/cycling/test_iso8601.py
@@ -601,12 +601,12 @@ def test_exclusion_zero_duration_warning(set_cycling_type, caplog, log_filter):
set_cycling_type(ISO8601_CYCLING_TYPE, "+05")
with pytest.raises(Exception):
ISO8601Sequence('3000', '2999')
- assert log_filter(caplog, contains='zero-duration')
+ assert log_filter(contains='zero-duration')
# parsing a point in an exclusion should not
caplog.clear()
ISO8601Sequence('P1Y ! 3000', '2999')
- assert not log_filter(caplog, contains='zero-duration')
+ assert not log_filter(contains='zero-duration')
def test_simple(set_cycling_type):
diff --git a/tests/unit/post_install/test_log_vc_info.py b/tests/unit/post_install/test_log_vc_info.py
index 59511db5461..67e204db747 100644
--- a/tests/unit/post_install/test_log_vc_info.py
+++ b/tests/unit/post_install/test_log_vc_info.py
@@ -279,13 +279,13 @@ def test_no_base_commit_git(tmp_path: Path):
@require_svn
def test_untracked_svn_subdir(
- svn_source_repo: Tuple[str, str, str], caplog, log_filter
+ svn_source_repo: Tuple[str, str, str], log_filter
):
repo_dir, *_ = svn_source_repo
source_dir = Path(repo_dir, 'jar_jar_binks')
source_dir.mkdir()
assert get_vc_info(source_dir) is None
- assert log_filter(caplog, level=logging.WARNING, contains="$ svn info")
+ assert log_filter(logging.WARNING, contains="$ svn info")
def test_not_installed(
@@ -306,7 +306,6 @@ def test_not_installed(
caplog.set_level(logging.DEBUG)
assert get_vc_info(tmp_path) is None
assert log_filter(
- caplog,
- level=logging.DEBUG,
+ logging.DEBUG,
contains=f"{fake_vcs} does not appear to be installed",
)
diff --git a/tests/unit/test_clean.py b/tests/unit/test_clean.py
index 285bfa6a23f..2308fe318f0 100644
--- a/tests/unit/test_clean.py
+++ b/tests/unit/test_clean.py
@@ -945,7 +945,7 @@ def mocked_remote_clean_cmd_side_effect(id_, platform, timeout, rm_dirs):
id_, platform_names, timeout='irrelevant', rm_dirs=rm_dirs
)
for msg in expected_err_msgs:
- assert log_filter(caplog, level=logging.ERROR, contains=msg)
+ assert log_filter(logging.ERROR, msg)
if expected_platforms:
for p_name in expected_platforms:
mocked_remote_clean_cmd.assert_any_call(
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 4a820fdb126..18c5f37bd8b 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -1400,7 +1400,7 @@ def test_implicit_tasks(
@pytest.mark.parametrize('workflow_meta', [True, False])
@pytest.mark.parametrize('url_type', ['good', 'bad', 'ugly', 'broken'])
-def test_process_urls(caplog, log_filter, workflow_meta, url_type):
+def test_process_urls(log_filter, workflow_meta, url_type):
if url_type == 'good':
# valid cylc 8 syntax
@@ -1436,7 +1436,6 @@ def test_process_urls(caplog, log_filter, workflow_meta, url_type):
elif url_type == 'ugly':
WorkflowConfig.process_metadata_urls(config)
assert log_filter(
- caplog,
contains='Detected deprecated template variables',
)
@@ -1464,7 +1463,6 @@ def test_zero_interval(
should_warn: bool,
opts: Values,
tmp_flow_config: Callable,
- caplog: pytest.LogCaptureFixture,
log_filter: Callable,
):
"""Test that a zero-duration recurrence with >1 repetition gets an
@@ -1482,7 +1480,6 @@ def test_zero_interval(
""")
WorkflowConfig(id_, flow_file, options=opts)
logged = log_filter(
- caplog,
level=logging.WARNING,
contains="Cannot have more than 1 repetition for zero-duration"
)
diff --git a/tests/unit/test_id.py b/tests/unit/test_id.py
index 2d50c9a2706..e3011ee1d57 100644
--- a/tests/unit/test_id.py
+++ b/tests/unit/test_id.py
@@ -17,12 +17,15 @@
import pytest
+from cylc.flow.cycling.integer import IntegerPoint
+from cylc.flow.cycling.iso8601 import ISO8601Point
from cylc.flow.id import (
LEGACY_CYCLE_SLASH_TASK,
LEGACY_TASK_DOT_CYCLE,
RELATIVE_ID,
- Tokens,
UNIVERSAL_ID,
+ Tokens,
+ quick_relative_id,
)
@@ -392,3 +395,14 @@ def test_task_property():
) == (
Tokens('//c:cs/t:ts/j:js', relative=True)
)
+
+
+@pytest.mark.parametrize('cycle, expected', [
+ ('2000', '2000/foo'),
+ (2001, '2001/foo'),
+ (IntegerPoint('3'), '3/foo'),
+ # NOTE: ISO8601Points are not standardised by this function:
+ (ISO8601Point('2002'), '2002/foo'),
+])
+def test_quick_relative_id(cycle, expected):
+ assert quick_relative_id(cycle, 'foo') == expected
diff --git a/tests/unit/test_prerequisite.py b/tests/unit/test_prerequisite.py
index 105e5e85401..828cd0a96bc 100644
--- a/tests/unit/test_prerequisite.py
+++ b/tests/unit/test_prerequisite.py
@@ -14,12 +14,17 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
+from functools import partial
+
import pytest
from cylc.flow.cycling.integer import IntegerPoint
from cylc.flow.cycling.loader import ISO8601_CYCLING_TYPE, get_point
-from cylc.flow.prerequisite import Prerequisite
-from cylc.flow.id import Tokens
+from cylc.flow.id import Tokens, detokenise
+from cylc.flow.prerequisite import Prerequisite, SatisfiedState
+
+
+detok = partial(detokenise, selectors=True, relative=True)
@pytest.fixture
@@ -43,10 +48,10 @@ def test_satisfied(prereq: Prerequisite):
('2001', 'd', 'custom'): False,
}
# No cached satisfaction state yet:
- assert prereq._all_satisfied is None
+ assert prereq._cached_satisfied is None
# Calling self.is_satisfied() should cache the result:
assert not prereq.is_satisfied()
- assert prereq._all_satisfied is False
+ assert prereq._cached_satisfied is False
# mark two prerequisites as satisfied
prereq.satisfy_me([
@@ -63,7 +68,7 @@ def test_satisfied(prereq: Prerequisite):
('2001', 'd', 'custom'): False,
}
# Should have reset cached satisfaction state:
- assert prereq._all_satisfied is None
+ assert prereq._cached_satisfied is None
assert not prereq.is_satisfied()
# mark all prereqs as satisfied
@@ -78,7 +83,7 @@ def test_satisfied(prereq: Prerequisite):
('2001', 'd', 'custom'): 'force satisfied',
}
# Should have set cached satisfaction state as must be true now:
- assert prereq._all_satisfied is True
+ assert prereq._cached_satisfied is True
assert prereq.is_satisfied()
@@ -138,14 +143,100 @@ def test_get_target_points(prereq):
}
-def test_get_resolved_dependencies():
+@pytest.fixture
+def satisfied_states_prereq():
+ """Fixture for testing the full range of possible satisfied states."""
prereq = Prerequisite(IntegerPoint('2'))
prereq[('1', 'a', 'x')] = True
prereq[('1', 'b', 'x')] = False
prereq[('1', 'c', 'x')] = 'satisfied from database'
prereq[('1', 'd', 'x')] = 'force satisfied'
- assert prereq.get_resolved_dependencies() == [
+ return prereq
+
+
+def test_get_satisfied_dependencies(satisfied_states_prereq: Prerequisite):
+ assert satisfied_states_prereq.get_satisfied_dependencies() == [
'1/a',
'1/c',
'1/d',
]
+
+
+def test_unset_naturally_satisfied_dependency(
+ satisfied_states_prereq: Prerequisite
+):
+ satisfied_states_prereq[('1', 'a', 'y')] = True
+ satisfied_states_prereq[('1', 'a', 'z')] = 'force satisfied'
+ for id_, expected in [
+ ('1/a', True),
+ ('1/b', False),
+ ('1/c', True),
+ ('1/d', False),
+ ]:
+ assert (
+ satisfied_states_prereq.unset_naturally_satisfied_dependency(id_)
+ == expected
+ )
+ assert satisfied_states_prereq._satisfied == {
+ ('1', 'a', 'x'): False,
+ ('1', 'a', 'y'): False,
+ ('1', 'a', 'z'): 'force satisfied',
+ ('1', 'b', 'x'): False,
+ ('1', 'c', 'x'): False,
+ ('1', 'd', 'x'): 'force satisfied',
+ }
+
+
+def test_satisfy_me():
+ prereq = Prerequisite(IntegerPoint('2'))
+ for task_name in ('a', 'b', 'c'):
+ prereq[('1', task_name, 'x')] = False
+ assert not prereq.is_satisfied()
+ assert prereq._cached_satisfied is False
+
+ valid = prereq.satisfy_me(
+ [Tokens('//1/a:x'), Tokens('//1/d:x'), Tokens('//1/c:y')],
+ )
+ assert {detok(tokens) for tokens in valid} == {'1/a:x'}
+ assert prereq._satisfied == {
+ ('1', 'a', 'x'): 'satisfied naturally',
+ ('1', 'b', 'x'): False,
+ ('1', 'c', 'x'): False,
+ }
+ # should have reset cached satisfaction state
+ assert prereq._cached_satisfied is None
+
+ valid = prereq.satisfy_me(
+ [Tokens('//1/a:x'), Tokens('//1/b:x')],
+ forced=True,
+ )
+ assert {detok(tokens) for tokens in valid} == {'1/a:x', '1/b:x'}
+ assert prereq._satisfied == {
+ # 1/a:x unaffected as already satisfied
+ ('1', 'a', 'x'): 'satisfied naturally',
+ ('1', 'b', 'x'): 'force satisfied',
+ ('1', 'c', 'x'): False,
+ }
+
+
+@pytest.mark.parametrize('forced', [False, True])
+@pytest.mark.parametrize('existing, expected_when_forced', [
+ (False, 'force satisfied'),
+ ('satisfied from database', 'force satisfied'),
+ ('force satisfied', 'force satisfied'),
+ ('satisfied naturally', 'satisfied naturally'),
+])
+def test_satisfy_me__override(
+ forced: bool,
+ existing: SatisfiedState,
+ expected_when_forced: SatisfiedState,
+):
+ """Test that satisfying a prereq with a different state works as expected
+ with and without the `forced` arg."""
+ prereq = Prerequisite(IntegerPoint('2'))
+ prereq[('1', 'a', 'x')] = existing
+
+ prereq.satisfy_me([Tokens('//1/a:x')], forced)
+ assert prereq[('1', 'a', 'x')] == (
+ expected_when_forced if forced else 'satisfied naturally'
+ )
diff --git a/tests/unit/test_rundb.py b/tests/unit/test_rundb.py
index 06aba70699f..44db75fb2e5 100644
--- a/tests/unit/test_rundb.py
+++ b/tests/unit/test_rundb.py
@@ -112,7 +112,9 @@ def test_operational_error(tmp_path, caplog):
# stage some stuff
dao.add_delete_item(CylcWorkflowDAO.TABLE_TASK_JOBS)
dao.add_insert_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub'])
- dao.add_update_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub'])
+ dao.add_update_item(
+ CylcWorkflowDAO.TABLE_TASK_JOBS, ({'pub': None}, {})
+ )
# connect the to DB
dao.connect()
diff --git a/tests/unit/test_scheduler.py b/tests/unit/test_scheduler.py
index 36beb121d3b..ccd5f5dfed5 100644
--- a/tests/unit/test_scheduler.py
+++ b/tests/unit/test_scheduler.py
@@ -133,7 +133,7 @@ def _select_workflow_host(cached=False):
)
caplog.set_level(logging.ERROR, CYLC_LOG)
assert not Scheduler.workflow_auto_restart(schd, max_retries=2)
- assert log_filter(caplog, contains='elephant')
+ assert log_filter(contains='elephant')
def test_auto_restart_popen_error(monkeypatch, caplog, log_filter):
@@ -166,4 +166,4 @@ def _popen(*args, **kwargs):
)
caplog.set_level(logging.ERROR, CYLC_LOG)
assert not Scheduler.workflow_auto_restart(schd, max_retries=2)
- assert log_filter(caplog, contains='mystderr')
+ assert log_filter(contains='mystderr')
diff --git a/tests/unit/test_task_pool.py b/tests/unit/test_task_pool.py
index b32781895bc..d1ab1642442 100644
--- a/tests/unit/test_task_pool.py
+++ b/tests/unit/test_task_pool.py
@@ -21,6 +21,7 @@
import pytest
from cylc.flow.flow_mgr import FlowNums
+from cylc.flow.prerequisite import SatisfiedState
from cylc.flow.task_pool import TaskPool
@@ -55,3 +56,29 @@ def test_get_active_flow_nums(
)
assert TaskPool._get_active_flow_nums(mock_task_pool) == expected
+
+
+@pytest.mark.parametrize('output_msg, flow_nums, db_flow_nums, expected', [
+ ('foo', set(), {1}, False),
+ ('foo', set(), set(), False),
+ ('foo', {1, 3}, {1}, 'satisfied from database'),
+ ('goo', {1, 3}, {1, 2}, 'satisfied from database'),
+ ('foo', {1, 3}, set(), False),
+ ('foo', {2}, {1}, False),
+ ('foo', {2}, {1, 2}, 'satisfied from database'),
+ ('f', {1}, {1}, False),
+])
+def test_check_output(
+ output_msg: str,
+ flow_nums: set,
+ db_flow_nums: set,
+ expected: SatisfiedState,
+):
+ mock_task_pool = Mock()
+ mock_task_pool.workflow_db_mgr.pri_dao.select_task_outputs.return_value = {
+ '{"f": "foo", "g": "goo"}': db_flow_nums,
+ }
+
+ assert TaskPool.check_task_output(
+ mock_task_pool, '2000', 'haddock', output_msg, flow_nums
+ ) == expected
diff --git a/tests/unit/test_task_proxy.py b/tests/unit/test_task_proxy.py
index 98695ecd13f..4c5513be5cd 100644
--- a/tests/unit/test_task_proxy.py
+++ b/tests/unit/test_task_proxy.py
@@ -14,13 +14,15 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
-import pytest
-from pytest import param
from typing import Callable, Optional
from unittest.mock import Mock
+import pytest
+from pytest import param
+
from cylc.flow.cycling import PointBase
from cylc.flow.cycling.iso8601 import ISO8601Point
+from cylc.flow.flow_mgr import FlowNums
from cylc.flow.task_proxy import TaskProxy
@@ -101,3 +103,18 @@ def test_status_match(status_str: Optional[str], expected: bool):
mock_itask = Mock(state=Mock(status='waiting'))
assert TaskProxy.status_match(mock_itask, status_str) is expected
+
+
+@pytest.mark.parametrize('itask_flow_nums, flow_nums, expected', [
+ param({1, 2}, {2}, {2}, id="subset"),
+ param({2}, {1, 2}, {2}, id="superset"),
+ param({1, 2}, {3, 4}, set(), id="disjoint"),
+ param({1, 2}, set(), {1, 2}, id="all-matches-num"),
+ param(set(), {1, 2}, set(), id="num-doesnt-match-none"),
+ param(set(), set(), set(), id="all-doesnt-match-none"),
+])
+def test_match_flows(
+ itask_flow_nums: FlowNums, flow_nums: FlowNums, expected: FlowNums
+):
+ mock_itask = Mock(flow_nums=itask_flow_nums)
+ assert TaskProxy.match_flows(mock_itask, flow_nums) == expected
diff --git a/tests/unit/test_workflow_db_mgr.py b/tests/unit/test_workflow_db_mgr.py
new file mode 100644
index 00000000000..f642eecda61
--- /dev/null
+++ b/tests/unit/test_workflow_db_mgr.py
@@ -0,0 +1,80 @@
+# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE.
+# Copyright (C) NIWA & British Crown (Met Office) & Contributors.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+from pathlib import Path
+from typing import (
+ List,
+ Set,
+)
+from unittest.mock import Mock
+
+import pytest
+from pytest import param
+
+from cylc.flow.cycling.integer import IntegerPoint
+from cylc.flow.flow_mgr import FlowNums
+from cylc.flow.id import Tokens
+from cylc.flow.task_proxy import TaskProxy
+from cylc.flow.taskdef import TaskDef
+from cylc.flow.util import serialise_set
+from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager
+
+
+@pytest.mark.parametrize('flow_nums, expected_removed', [
+ param(set(), {1, 2, 5}, id='all'),
+ param({1}, {1}, id='subset'),
+ param({1, 2, 5}, {1, 2, 5}, id='complete-set'),
+ param({1, 3, 5}, {1, 5}, id='intersect'),
+ param({3, 4}, set(), id='disjoint'),
+])
+def test_remove_task_from_flows(
+ tmp_path: Path, flow_nums: FlowNums, expected_removed: FlowNums
+):
+ db_flows: List[FlowNums] = [
+ {1, 2},
+ {5},
+ set(), # FLOW_NONE
+ ]
+ db_mgr = WorkflowDatabaseManager(tmp_path)
+ schd_tokens = Tokens('~asterix/gaul')
+ tdef = TaskDef('a', {}, None, None, None)
+ with db_mgr.get_pri_dao() as dao:
+ db_mgr.pri_dao = dao
+ db_mgr.pub_dao = Mock()
+ for flow in db_flows:
+ db_mgr.put_insert_task_states(
+ TaskProxy(
+ schd_tokens,
+ tdef,
+ IntegerPoint('1'),
+ flow_nums=flow,
+ ),
+ )
+ db_mgr.process_queued_ops()
+
+ removed_fnums = db_mgr.remove_task_from_flows('1', 'a', flow_nums)
+ assert removed_fnums == expected_removed
+
+ db_mgr.process_queued_ops()
+ remaining_fnums: Set[str] = {
+ fnums_str
+ for fnums_str, *_ in dao.connect().execute(
+ 'SELECT flow_nums FROM task_states'
+ )
+ }
+ assert remaining_fnums == {
+ serialise_set(flow - expected_removed) for flow in db_flows
+ }
diff --git a/tests/unit/test_workflow_events.py b/tests/unit/test_workflow_events.py
index 89449953f20..c9d08791781 100644
--- a/tests/unit/test_workflow_events.py
+++ b/tests/unit/test_workflow_events.py
@@ -83,13 +83,13 @@ def test_process_mail_footer(caplog, log_filter):
assert process_mail_footer(
'%(host)s|%(port)s|%(owner)s|%(suite)s|%(workflow)s', template_vars
) == 'myhost|42|me|my_workflow|my_workflow\n'
- assert not log_filter(caplog, contains='Ignoring bad mail footer template')
+ assert not log_filter(contains='Ignoring bad mail footer template')
# test invalid variable
assert process_mail_footer('%(invalid)s', template_vars) == ''
- assert log_filter(caplog, contains='Ignoring bad mail footer template')
+ assert log_filter(contains='Ignoring bad mail footer template')
# test broken template
caplog.clear()
assert process_mail_footer('%(invalid)s', template_vars) == ''
- assert log_filter(caplog, contains='Ignoring bad mail footer template')
+ assert log_filter(contains='Ignoring bad mail footer template')
diff --git a/tests/unit/test_workflow_files.py b/tests/unit/test_workflow_files.py
index b2b33e495aa..2e85d4180a0 100644
--- a/tests/unit/test_workflow_files.py
+++ b/tests/unit/test_workflow_files.py
@@ -196,7 +196,6 @@ def test_infer_latest_run(
@pytest.mark.parametrize('warn_arg', [True, False])
def test_infer_latest_run_warns_for_runN(
warn_arg: bool,
- caplog: pytest.LogCaptureFixture,
log_filter: Callable,
tmp_run_dir: Callable,
):
@@ -206,8 +205,7 @@ def test_infer_latest_run_warns_for_runN(
runN_path.symlink_to('run1')
infer_latest_run(runN_path, warn_runN=warn_arg)
filtered_log = log_filter(
- caplog, level=logging.WARNING,
- contains="You do not need to include runN in the workflow ID"
+ logging.WARNING, "You do not need to include runN in the workflow ID"
)
assert filtered_log if warn_arg else not filtered_log