diff --git a/changes.d/6370.feat.md b/changes.d/6370.feat.md new file mode 100644 index 00000000000..aa8d80d6140 --- /dev/null +++ b/changes.d/6370.feat.md @@ -0,0 +1,4 @@ +`cylc remove` improvements: +- It can now remove tasks that are no longer active, making it look like they never ran. +- Added the `--flow` option. +- Removed tasks are now demoted to `flow=none`. diff --git a/cylc/flow/command_validation.py b/cylc/flow/command_validation.py index eb5e45c4734..7ba2dfb4f42 100644 --- a/cylc/flow/command_validation.py +++ b/cylc/flow/command_validation.py @@ -24,19 +24,34 @@ ) from cylc.flow.exceptions import InputError -from cylc.flow.id import IDTokens, Tokens +from cylc.flow.flow_mgr import ( + FLOW_ALL, + FLOW_NEW, + FLOW_NONE, +) +from cylc.flow.id import ( + IDTokens, + Tokens, +) from cylc.flow.task_outputs import TASK_OUTPUT_SUCCEEDED -from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE -ERR_OPT_FLOW_VAL = "Flow values must be an integer, or 'all', 'new', or 'none'" +ERR_OPT_FLOW_VAL = ( + f"Flow values must be an integer, or '{FLOW_ALL}', '{FLOW_NEW}', " + f"or '{FLOW_NONE}'" +) +ERR_OPT_FLOW_VAL_2 = f"Flow values must be an integer, or '{FLOW_ALL}'" ERR_OPT_FLOW_COMBINE = "Cannot combine --flow={0} with other flow values" ERR_OPT_FLOW_WAIT = ( f"--wait is not compatible with --flow={FLOW_NEW} or --flow={FLOW_NONE}" ) -def flow_opts(flows: List[str], flow_wait: bool) -> None: +def flow_opts( + flows: List[str], + flow_wait: bool, + allow_new_or_none: bool = True +) -> None: """Check validity of flow-related CLI options. Note the schema defaults flows to []. @@ -63,6 +78,10 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None: cylc.flow.exceptions.InputError: --wait is not compatible with --flow=new or --flow=none + >>> flow_opts(["new"], False, allow_new_or_none=False) + Traceback (most recent call last): + cylc.flow.exceptions.InputError: ... 
must be an integer, or 'all' + """ if not flows: return @@ -70,9 +89,12 @@ def flow_opts(flows: List[str], flow_wait: bool) -> None: flows = [val.strip() for val in flows] for val in flows: + val = val.strip() if val in {FLOW_NONE, FLOW_NEW, FLOW_ALL}: if len(flows) != 1: raise InputError(ERR_OPT_FLOW_COMBINE.format(val)) + if not allow_new_or_none and val in {FLOW_NEW, FLOW_NONE}: + raise InputError(ERR_OPT_FLOW_VAL_2) else: try: int(val) diff --git a/cylc/flow/commands.py b/cylc/flow/commands.py index b8b777e0957..0b3dbc7b8c8 100644 --- a/cylc/flow/commands.py +++ b/cylc/flow/commands.py @@ -53,18 +53,24 @@ """ from contextlib import suppress -from time import sleep, time +from time import ( + sleep, + time, +) from typing import ( + TYPE_CHECKING, AsyncGenerator, Callable, Dict, Iterable, List, Optional, - TYPE_CHECKING, + TypeVar, Union, ) +from metomi.isodatetime.parsers import TimePointParser + from cylc.flow import LOG import cylc.flow.command_validation as validate from cylc.flow.exceptions import ( @@ -73,23 +79,28 @@ CylcConfigError, ) import cylc.flow.flags +from cylc.flow.flow_mgr import get_flow_nums_set from cylc.flow.log_level import log_level_to_verbosity from cylc.flow.network.schema import WorkflowStopMode from cylc.flow.parsec.exceptions import ParsecError from cylc.flow.task_id import TaskID -from cylc.flow.task_state import TASK_STATUSES_ACTIVE, TASK_STATUS_FAILED -from cylc.flow.workflow_status import RunMode, StopMode +from cylc.flow.task_state import ( + TASK_STATUS_FAILED, + TASK_STATUSES_ACTIVE, +) +from cylc.flow.workflow_status import ( + RunMode, + StopMode, +) -from metomi.isodatetime.parsers import TimePointParser if TYPE_CHECKING: from cylc.flow.scheduler import Scheduler # define a type for command implementations - Command = Callable[ - ..., - AsyncGenerator, - ] + Command = Callable[..., AsyncGenerator] + # define a generic type needed for the @_command decorator + _TCommand = TypeVar('_TCommand', bound=Command) # a directory of 
registered commands (populated on module import) COMMANDS: 'Dict[str, Command]' = {} @@ -97,15 +108,15 @@ def _command(name: str): """Decorator to register a command.""" - def _command(fcn: 'Command'): + def _command(fcn: '_TCommand') -> '_TCommand': nonlocal name COMMANDS[name] = fcn - fcn.command_name = name # type: ignore + fcn.command_name = name # type: ignore[attr-defined] return fcn return _command -async def run_cmd(fcn, *args, **kwargs): +async def run_cmd(bound_fcn: AsyncGenerator): """Run a command outside of the scheduler's main loop. Normally commands are run via the Scheduler's command_queue (which is @@ -120,10 +131,9 @@ async def run_cmd(fcn, *args, **kwargs): For these purposes use "run_cmd", otherwise, queue commands via the scheduler as normal. """ - cmd = fcn(*args, **kwargs) - await cmd.__anext__() # validate + await bound_fcn.__anext__() # validate with suppress(StopAsyncIteration): - return await cmd.__anext__() # run + return await bound_fcn.__anext__() # run @_command('set') @@ -314,11 +324,15 @@ async def set_verbosity(schd: 'Scheduler', level: Union[int, str]): @_command('remove_tasks') -async def remove_tasks(schd: 'Scheduler', tasks: Iterable[str]): +async def remove_tasks( + schd: 'Scheduler', tasks: Iterable[str], flow: List[str] +): """Remove tasks.""" validate.is_tasks(tasks) + validate.flow_opts(flow, flow_wait=False, allow_new_or_none=False) yield - yield schd.pool.remove_tasks(tasks) + flow_nums = get_flow_nums_set(flow) + schd.pool.remove_tasks(tasks, flow_nums) @_command('reload_workflow') diff --git a/cylc/flow/data_store_mgr.py b/cylc/flow/data_store_mgr.py index a4b28d44fdc..86f050f0299 100644 --- a/cylc/flow/data_store_mgr.py +++ b/cylc/flow/data_store_mgr.py @@ -2369,22 +2369,42 @@ def delta_task_queued(self, itask: TaskProxy) -> None: self.updates_pending = True def delta_task_flow_nums(self, itask: TaskProxy) -> None: - """Create delta for change in task proxy flow_nums. 
+ """Create delta for change in task proxy flow numbers. Args: - itask (cylc.flow.task_proxy.TaskProxy): - Update task-node from corresponding task proxy - objects from the workflow task pool. - + itask: TaskProxy with updated flow numbers. """ tproxy: Optional[PbTaskProxy] tp_id, tproxy = self.store_node_fetcher(itask.tokens) if not tproxy: return - tp_delta = self.updated[TASK_PROXIES].setdefault( - tp_id, PbTaskProxy(id=tp_id)) + self._delta_task_flow_nums(tp_id, itask.flow_nums) + + def delta_remove_task_flow_nums( + self, task: str, removed: 'FlowNums' + ) -> None: + """Create delta for removal of flow numbers from a task proxy. + + Args: + task: Relative ID of task. + removed: Flow numbers to remove from the task proxy in the + data store. + """ + tproxy: Optional[PbTaskProxy] + tp_id, tproxy = self.store_node_fetcher( + Tokens(task, relative=True).duplicate(**self.id_) + ) + if not tproxy: + return + new_flow_nums = deserialise_set(tproxy.flow_nums).difference(removed) + self._delta_task_flow_nums(tp_id, new_flow_nums) + + def _delta_task_flow_nums(self, tp_id: str, flow_nums: 'FlowNums') -> None: + tp_delta: PbTaskProxy = self.updated[TASK_PROXIES].setdefault( + tp_id, PbTaskProxy(id=tp_id) + ) tp_delta.stamp = f'{tp_id}@{time()}' - tp_delta.flow_nums = serialise_set(itask.flow_nums) + tp_delta.flow_nums = serialise_set(flow_nums) self.updates_pending = True def delta_task_runahead(self, itask: TaskProxy) -> None: diff --git a/cylc/flow/dbstatecheck.py b/cylc/flow/dbstatecheck.py index fc2d9cf0da3..3fbad5c6723 100644 --- a/cylc/flow/dbstatecheck.py +++ b/cylc/flow/dbstatecheck.py @@ -28,7 +28,7 @@ IntegerPoint, IntegerInterval ) -from cylc.flow.flow_mgr import stringify_flow_nums +from cylc.flow.flow_mgr import repr_flow_nums from cylc.flow.pathutil import expand_path from cylc.flow.rundb import CylcWorkflowDAO from cylc.flow.task_outputs import ( @@ -318,7 +318,7 @@ def workflow_state_query( if flow_num is not None and flow_num not in flow_nums: # skip 
result, wrong flow continue - fstr = stringify_flow_nums(flow_nums) + fstr = repr_flow_nums(flow_nums) if fstr: res.append(fstr) db_res.append(res) diff --git a/cylc/flow/flow_mgr.py b/cylc/flow/flow_mgr.py index 1cd1c1e8c70..67f816982ec 100644 --- a/cylc/flow/flow_mgr.py +++ b/cylc/flow/flow_mgr.py @@ -16,8 +16,15 @@ """Manage flow counter and flow metadata.""" -from typing import Dict, Set, Optional, TYPE_CHECKING import datetime +from typing import ( + TYPE_CHECKING, + Dict, + Iterable, + List, + Optional, + Set, +) from cylc.flow import LOG @@ -55,36 +62,62 @@ def add_flow_opts(parser): ) -def stringify_flow_nums(flow_nums: Set[int], full: bool = False) -> str: - """Return a string representation of a set of flow numbers +def get_flow_nums_set(flow: List[str]) -> FlowNums: + """Return set of integer flow numbers from list of strings. - Return: - - "none" for no flow - - "" for the original flow (flows only matter if there are several) - - otherwise e.g. "(flow=1,2,3)" + Returns an empty set if the input is empty or contains only "all". + + >>> get_flow_nums_set(["1", "2", "3"]) + {1, 2, 3} + >>> get_flow_nums_set([]) + set() + >>> get_flow_nums_set(["all"]) + set() + """ + if flow == [FLOW_ALL]: + return set() + return {int(val.strip()) for val in flow} + + +def stringify_flow_nums(flow_nums: Iterable[int]) -> str: + """Return the canonical string for a set of flow numbers. Examples: + >>> stringify_flow_nums({1}) + '1' + + >>> stringify_flow_nums({3, 1, 2}) + '1,2,3' + >>> stringify_flow_nums({}) + '' + + """ + return ','.join(str(i) for i in sorted(flow_nums)) + + +def repr_flow_nums(flow_nums: FlowNums, full: bool = False) -> str: + """Return a representation of a set of flow numbers + + If `full` is False, return an empty string for flows=1. 
+ + Examples: + >>> repr_flow_nums({}) '(flows=none)' - >>> stringify_flow_nums({1}) + >>> repr_flow_nums({1}) '' - >>> stringify_flow_nums({1}, True) + >>> repr_flow_nums({1}, full=True) '(flows=1)' - >>> stringify_flow_nums({1,2,3}) + >>> repr_flow_nums({1,2,3}) '(flows=1,2,3)' """ if not full and flow_nums == {1}: return "" - else: - return ( - "(flows=" - f"{','.join(str(i) for i in flow_nums) or 'none'}" - ")" - ) + return f"(flows={stringify_flow_nums(flow_nums) or 'none'})" class FlowMgr: diff --git a/cylc/flow/id.py b/cylc/flow/id.py index f2c8b05b4a1..68c62e3a118 100644 --- a/cylc/flow/id.py +++ b/cylc/flow/id.py @@ -22,6 +22,7 @@ from enum import Enum import re from typing import ( + TYPE_CHECKING, Iterable, List, Optional, @@ -33,6 +34,10 @@ from cylc.flow import LOG +if TYPE_CHECKING: + from cylc.flow.cycling import PointBase + + class IDTokens(Enum): """Cylc object identifier tokens.""" @@ -524,14 +529,14 @@ def duplicate( ) -def quick_relative_detokenise(cycle, task): +def quick_relative_id(cycle: Union[str, int, 'PointBase'], task: str) -> str: """Generate a relative ID for a task. This is a more efficient solution to `Tokens` for cases where you only want the ID string and don't have any use for a Tokens object. 
Example: - >>> q = quick_relative_detokenise + >>> q = quick_relative_id >>> q('1', 'a') == Tokens(cycle='1', task='a').relative_id True diff --git a/cylc/flow/network/schema.py b/cylc/flow/network/schema.py index ab34def7f75..a1b9fc1c50c 100644 --- a/cylc/flow/network/schema.py +++ b/cylc/flow/network/schema.py @@ -22,8 +22,8 @@ from operator import attrgetter from typing import ( TYPE_CHECKING, - AsyncGenerator, Any, + AsyncGenerator, Dict, List, Optional, @@ -34,44 +34,68 @@ import graphene from graphene import ( - Boolean, Field, Float, ID, InputObjectType, Int, - Mutation, ObjectType, Schema, String, Argument, Interface + ID, + Argument, + Boolean, + Field, + Float, + InputObjectType, + Int, + Interface, + Mutation, + ObjectType, + Schema, + String, ) from graphene.types.generic import GenericScalar from graphene.utils.str_converters import to_snake_case from graphql.type.definition import get_named_type from cylc.flow import LOG_LEVELS -from cylc.flow.broadcast_mgr import ALL_CYCLE_POINTS_STRS, addict +from cylc.flow.broadcast_mgr import ( + ALL_CYCLE_POINTS_STRS, + addict, +) from cylc.flow.data_store_mgr import ( - FAMILIES, FAMILY_PROXIES, JOBS, TASKS, TASK_PROXIES, - DELTA_ADDED, DELTA_UPDATED + DELTA_ADDED, + DELTA_UPDATED, + FAMILIES, + FAMILY_PROXIES, + JOBS, + TASK_PROXIES, + TASKS, +) +from cylc.flow.flow_mgr import ( + FLOW_ALL, + FLOW_NEW, + FLOW_NONE, ) -from cylc.flow.flow_mgr import FLOW_ALL, FLOW_NEW, FLOW_NONE from cylc.flow.id import Tokens from cylc.flow.task_outputs import SORT_ORDERS from cylc.flow.task_state import ( - TASK_STATUSES_ORDERED, TASK_STATUS_DESC, - TASK_STATUS_WAITING, TASK_STATUS_EXPIRED, + TASK_STATUS_FAILED, TASK_STATUS_PREPARING, + TASK_STATUS_RUNNING, TASK_STATUS_SUBMIT_FAILED, TASK_STATUS_SUBMITTED, - TASK_STATUS_RUNNING, - TASK_STATUS_FAILED, - TASK_STATUS_SUCCEEDED + TASK_STATUS_SUCCEEDED, + TASK_STATUS_WAITING, + TASK_STATUSES_ORDERED, ) from cylc.flow.util import sstrip from cylc.flow.workflow_status import StopMode 
+ if TYPE_CHECKING: from graphql import ResolveInfo from graphql.type.definition import ( - GraphQLNamedType, GraphQLList, + GraphQLNamedType, GraphQLNonNull, ) + from cylc.flow.network.resolvers import BaseResolvers @@ -2108,6 +2132,19 @@ class Meta: ''') resolver = partial(mutator, command='remove_tasks') + class Arguments(TaskMutation.Arguments): + flow = graphene.List( + graphene.NonNull(Flow), + default_value=[FLOW_ALL], + description=sstrip(f''' + "Remove the task(s) from the specified flows. " + + This should be a list of flow numbers, or '{FLOW_ALL}' + to remove the task(s) from all flows they belong to + (which is the default). + ''') + ) + class SetPrereqsAndOutputs(Mutation, TaskMutation): class Meta: diff --git a/cylc/flow/prerequisite.py b/cylc/flow/prerequisite.py index 486c7e84ab3..b4025a53367 100644 --- a/cylc/flow/prerequisite.py +++ b/cylc/flow/prerequisite.py @@ -39,7 +39,7 @@ from cylc.flow.cycling.loader import get_point from cylc.flow.data_messages_pb2 import PbCondition, PbPrerequisite from cylc.flow.exceptions import TriggerExpressionError -from cylc.flow.id import quick_relative_detokenise +from cylc.flow.id import quick_relative_id if TYPE_CHECKING: @@ -58,7 +58,7 @@ class PrereqMessage(NamedTuple): def get_id(self) -> str: """Get the relative ID of the task in this prereq message.""" - return quick_relative_detokenise(self.point, self.task) + return quick_relative_id(self.point, self.task) @staticmethod def coerce(tuple_: AnyPrereqMessage) -> 'PrereqMessage': @@ -92,7 +92,7 @@ class Prerequisite: # Memory optimization - constrain possible attributes to this list. __slots__ = ( "_satisfied", - "_all_satisfied", + "_cached_satisfied", "conditional_expression", "point", ) @@ -118,7 +118,7 @@ def __init__(self, point: 'PointBase'): # * `None` (no cached state) # * `True` (prerequisite satisfied) # * `False` (prerequisite unsatisfied). 
- self._all_satisfied: Optional[bool] = None + self._cached_satisfied: Optional[bool] = None def instantaneous_hash(self) -> int: """Generate a hash of this prerequisite in its current state. @@ -159,9 +159,9 @@ def __setitem__( if value is True: value = 'satisfied naturally' self._satisfied[key] = value - if not (self._all_satisfied and value): + if not (self._cached_satisfied and value): # Force later recalculation of cached satisfaction state: - self._all_satisfied = None + self._cached_satisfied = None def __iter__(self) -> Iterator[PrereqMessage]: return iter(self._satisfied) @@ -185,7 +185,7 @@ def get_raw_conditional_expression(self): def set_condition(self, expr): """Set the conditional expression for this prerequisite. - Resets the cached state (self._all_satisfied). + Resets the cached state (self._cached_satisfied). Examples: # GH #3644 construct conditional expression when one task name @@ -201,7 +201,7 @@ def set_condition(self, expr): 'bool(self._satisfied[("1", "xfoo", "succeeded")])'] """ - self._all_satisfied = None + self._cached_satisfied = None if '|' in expr: # Make a Python expression so we can eval() the logic. for message in self._satisfied: @@ -221,14 +221,14 @@ def is_satisfied(self): Return cached state if present, else evaluate the prerequisite. """ - if self._all_satisfied is not None: + if self._cached_satisfied is not None: # Cached value. - return self._all_satisfied + return self._cached_satisfied if self._satisfied == {}: # No prerequisites left after pre-initial simplification. return True - self._all_satisfied = self._eval_satisfied() - return self._all_satisfied + self._cached_satisfied = self._eval_satisfied() + return self._cached_satisfied def _eval_satisfied(self) -> bool: """Evaluate the prerequisite's condition expression. @@ -253,12 +253,17 @@ def _eval_satisfied(self) -> bool: ) from None return res - def satisfy_me(self, outputs: Iterable['Tokens']) -> 'Set[Tokens]': - """Attempt to satisfy me with given outputs. 
+ def satisfy_me( + self, outputs: Iterable['Tokens'], forced: bool = False + ) -> 'Set[Tokens]': + """Set the given outputs as satisfied. - Updates cache with the result. Return outputs that match. + Args: + outputs: List of outputs to satisfy. + forced: If True, records that this should not be undone by + `cylc remove`. """ valid = set() for output in outputs: @@ -268,7 +273,10 @@ def satisfy_me(self, outputs: Iterable['Tokens']) -> 'Set[Tokens]': if prereq not in self._satisfied: continue valid.add(output) - self[prereq] = 'satisfied naturally' + if self._satisfied[prereq] != 'satisfied naturally': + self[prereq] = ( + 'force satisfied' if forced else 'satisfied naturally' + ) return valid def api_dump(self) -> Optional[PbPrerequisite]: @@ -318,9 +326,9 @@ def set_satisfied(self) -> None: if not self._satisfied[message]: self._satisfied[message] = 'force satisfied' if self.conditional_expression: - self._all_satisfied = self._eval_satisfied() + self._cached_satisfied = self._eval_satisfied() else: - self._all_satisfied = True + self._cached_satisfied = True def iter_target_point_strings(self): yield from { @@ -334,7 +342,7 @@ def get_target_points(self): get_point(p) for p in self.iter_target_point_strings() ] - def get_resolved_dependencies(self) -> List[str]: + def get_satisfied_dependencies(self) -> List[str]: """Return a list of satisfied dependencies. 
E.G: ['1/foo', '2/bar'] @@ -345,3 +353,13 @@ def get_resolved_dependencies(self) -> List[str]: for msg, satisfied in self._satisfied.items() if satisfied ] + + def unset_naturally_satisfied_dependency(self, id_: str) -> bool: + """Set the matching dependency to unsatisfied and return True only if + it was naturally satisfied.""" + changed = False + for msg, sat in self._satisfied.items(): + if msg.get_id() == id_ and sat and sat != 'force satisfied': + self[msg] = False + changed = True + return changed diff --git a/cylc/flow/rundb.py b/cylc/flow/rundb.py index c715079491b..315098c390d 100644 --- a/cylc/flow/rundb.py +++ b/cylc/flow/rundb.py @@ -15,6 +15,7 @@ # along with this program. If not, see . """Provide data access object for the workflow runtime database.""" +from collections import defaultdict from contextlib import suppress from dataclasses import dataclass from os.path import expandvars @@ -23,6 +24,8 @@ import traceback from typing import ( TYPE_CHECKING, + Any, + DefaultDict, Dict, Iterable, List, @@ -30,11 +33,13 @@ Set, Tuple, Union, + cast, ) from cylc.flow import LOG from cylc.flow.exceptions import PlatformLookupError import cylc.flow.flags +from cylc.flow.flow_mgr import stringify_flow_nums from cylc.flow.util import ( deserialise_set, serialise_set, @@ -47,6 +52,13 @@ from cylc.flow.flow_mgr import FlowNums +DbArgDict = Dict[str, Any] +DbUpdateTuple = Union[ + Tuple[DbArgDict, DbArgDict], + Tuple[str, list] +] + + @dataclass class CylcWorkflowDAOTableColumn: """Represent a column in a table.""" @@ -69,7 +81,7 @@ class CylcWorkflowDAOTable: def __init__(self, name, column_items): self.name = name - self.columns = [] + self.columns: List[CylcWorkflowDAOTableColumn] = [] for column_item in column_items: name = column_item[0] attrs = {} @@ -81,7 +93,7 @@ def __init__(self, name, column_items): attrs.get("is_primary_key", False))) self.delete_queues = {} self.insert_queue = [] - self.update_queues = {} + self.update_queues: DefaultDict[str, list] = 
defaultdict(list) def get_create_stmt(self): """Return an SQL statement to create this table.""" @@ -150,14 +162,23 @@ def add_insert_item(self, args): args.get(column.name, None) for column in self.columns] self.insert_queue.append(stmt_args) - def add_update_item(self, set_args, where_args): + def add_update_item(self, item: DbUpdateTuple) -> None: """Queue an UPDATE item. + If stmt is not a string, it should be a tuple (set_args, where_args) - set_args should be a dict, with column keys and values to be set. where_args should be a dict, update will only apply to rows matching all these items. """ + if isinstance(item[0], str): + stmt = item[0] + params = cast('list', item[1]) + self.update_queues[stmt].extend(params) + return + + set_args = item[0] + where_args = cast('DbArgDict', item[1]) set_strs = [] stmt_args = [] for column in self.columns: @@ -177,9 +198,8 @@ def add_update_item(self, set_args, where_args): stmt = self.FMT_UPDATE % { "name": self.name, "set_str": set_str, - "where_str": where_str} - if stmt not in self.update_queues: - self.update_queues[stmt] = [] + "where_str": where_str + } self.update_queues[stmt].append(stmt_args) @@ -407,15 +427,18 @@ def add_insert_item(self, table_name, args): """ self.tables[table_name].add_insert_item(args) - def add_update_item(self, table_name, set_args, where_args=None): + def add_update_item( + self, table_name: str, item: DbUpdateTuple + ) -> None: """Queue an UPDATE item for a given table. + If stmt is not a string, it should be a tuple (set_args, where_args) - set_args should be a dict, with column keys and values to be set. where_args should be a dict, update will only apply to rows matching all these items. 
""" - self.tables[table_name].add_update_item(set_args, where_args) + self.tables[table_name].add_update_item(item) def close(self) -> None: """Explicitly close the connection.""" @@ -580,10 +603,10 @@ def select_workflow_params(self) -> Iterable[Tuple[str, Optional[str]]]: key, value FROM {self.TABLE_WORKFLOW_PARAMS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return self.connect().execute(stmt) - def select_workflow_flows(self, flow_nums): + def select_workflow_flows(self, flow_nums: Iterable[int]): """Return flow data for selected flows.""" stmt = rf''' SELECT @@ -591,8 +614,8 @@ def select_workflow_flows(self, flow_nums): FROM {self.TABLE_WORKFLOW_FLOWS} WHERE - flow_num in ({','.join(str(f) for f in flow_nums)}) - ''' # nosec (table name is code constant, flow_nums just integers) + flow_num in ({stringify_flow_nums(flow_nums)}) + ''' # nosec B608 (table name is code constant, flow_nums just ints) flows = {} for flow_num, start_time, descr in self.connect().execute(stmt): flows[flow_num] = { @@ -608,7 +631,7 @@ def select_workflow_flows_max_flow_num(self): MAX(flow_num) FROM {self.TABLE_WORKFLOW_FLOWS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return self.connect().execute(stmt).fetchone()[0] def select_workflow_params_restart_count(self): @@ -620,7 +643,7 @@ def select_workflow_params_restart_count(self): {self.TABLE_WORKFLOW_PARAMS} WHERE key == 'n_restart' - """ # nosec (table name is code constant) + """ # nosec B608 (table name is code constant) result = self.connect().execute(stmt).fetchone() return int(result[0]) if result else 0 @@ -636,7 +659,7 @@ def select_workflow_template_vars(self, callback): key, value FROM {self.TABLE_WORKFLOW_TEMPLATE_VARS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) )): callback(row_idx, list(row)) @@ -653,7 +676,7 @@ def select_task_action_timers(self, callback): 
{",".join(attrs)} FROM {self.TABLE_TASK_ACTION_TIMERS} - ''' # nosec + ''' # nosec B608 # * table name is code constant # * attrs are code constants for row_idx, row in enumerate(self.connect().execute(stmt)): @@ -679,7 +702,7 @@ def select_task_job(self, cycle, name, submit_num=None): AND name==? ORDER BY submit_num DESC LIMIT 1 - ''' # nosec + ''' # nosec B608 # * table name is code constant # * keys are code constants stmt_args = [cycle, name] @@ -693,7 +716,7 @@ def select_task_job(self, cycle, name, submit_num=None): cycle==? AND name==? AND submit_num==? - ''' # nosec + ''' # nosec B608 # * table name is code constant # * keys are code constants stmt_args = [cycle, name, submit_num] @@ -776,7 +799,7 @@ def select_task_job_platforms(self): platform_name FROM {self.TABLE_TASK_JOBS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return {i[0] for i in self.connect().execute(stmt)} def select_prev_instances( @@ -789,7 +812,7 @@ def select_prev_instances( # Ignore bandit false positive: B608: hardcoded_sql_expressions # Not an injection, simply putting the table name in the SQL query # expression as a string constant local to this module. - stmt = ( # nosec + stmt = ( # nosec B608 r"SELECT flow_nums,submit_num,flow_wait,status FROM %(name)s" r" WHERE name==? AND cycle==?" ) % {"name": self.TABLE_TASK_STATES} @@ -837,7 +860,7 @@ def select_task_outputs( {self.TABLE_TASK_OUTPUTS} WHERE name==? AND cycle==? 
- ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return { outputs: deserialise_set(flow_nums) for flow_nums, outputs in self.connect().execute( @@ -851,7 +874,7 @@ def select_xtriggers_for_restart(self, callback): signature, results FROM {self.TABLE_XTRIGGERS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) for row_idx, row in enumerate(self.connect().execute(stmt, [])): callback(row_idx, list(row)) @@ -861,7 +884,7 @@ def select_abs_outputs_for_restart(self, callback): cycle, name, output FROM {self.TABLE_ABS_OUTPUTS} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) for row_idx, row in enumerate(self.connect().execute(stmt, [])): callback(row_idx, list(row)) @@ -974,7 +997,7 @@ def select_task_prerequisites( cycle == ? AND name == ? AND flow_nums == ? - """ # nosec (table name is code constant) + """ # nosec B608 (table name is code constant) stmt_args = [cycle, name, flow_nums] return list(self.connect().execute(stmt, stmt_args)) @@ -985,7 +1008,7 @@ def select_tasks_to_hold(self) -> List[Tuple[str, str]]: name, cycle FROM {self.TABLE_TASKS_TO_HOLD} - ''' # nosec (table name is code constant) + ''' # nosec B608 (table name is code constant) return list(self.connect().execute(stmt)) def select_task_times(self): @@ -1007,7 +1030,7 @@ def select_task_times(self): {self.TABLE_TASK_JOBS} WHERE run_status = 0 - """ # nosec (table name is code constant) + """ # nosec B608 (table name is code constant) columns = ( 'name', 'cycle', 'host', 'job_runner', 'submit_time', 'start_time', 'succeed_time' diff --git a/cylc/flow/scheduler.py b/cylc/flow/scheduler.py index d3f7fd18a36..da871bd2cc5 100644 --- a/cylc/flow/scheduler.py +++ b/cylc/flow/scheduler.py @@ -32,7 +32,6 @@ from typing import ( Any, AsyncGenerator, - Callable, Dict, Iterable, List, @@ -544,7 +543,7 @@ async def configure(self, params): elif self.config.cfg['scheduling']['hold 
after cycle point']: holdcp = self.config.cfg['scheduling']['hold after cycle point'] if holdcp is not None: - await commands.run_cmd(commands.set_hold_point, self, holdcp) + await commands.run_cmd(commands.set_hold_point(self, holdcp)) if self.options.paused_start: self.pause_workflow('Paused on start up') @@ -634,7 +633,7 @@ async def run_scheduler(self) -> None: if self.pool.get_tasks(): # (If we're not restarting a finished workflow) self.restart_remote_init() - await commands.run_cmd(commands.poll_tasks, self, ['*/*']) + await commands.run_cmd(commands.poll_tasks(self, ['*/*'])) self.run_event_handlers(self.EVENT_STARTUP, 'workflow starting') await asyncio.gather( @@ -949,10 +948,6 @@ def process_queued_task_messages(self) -> None: warn += f'\n {msg.job_id}: {msg.severity} - "{msg.message}"' LOG.warning(warn) - def get_command_method(self, command_name: str) -> Callable: - """Return a command processing method or raise AttributeError.""" - return getattr(self, f'command_{command_name}') - async def process_command_queue(self) -> None: """Process queued commands.""" qsize = self.command_queue.qsize() @@ -1391,8 +1386,8 @@ async def workflow_shutdown(self): self.time_next_kill is not None and time() > self.time_next_kill ): - await commands.run_cmd(commands.poll_tasks, self, ['*/*']) - await commands.run_cmd(commands.kill_tasks, self, ['*/*']) + await commands.run_cmd(commands.poll_tasks(self, ['*/*'])) + await commands.run_cmd(commands.kill_tasks(self, ['*/*'])) self.time_next_kill = time() + self.INTERVAL_STOP_KILL # Is the workflow set to auto stop [+restart] now ... 
@@ -1534,7 +1529,7 @@ async def _main_loop(self) -> None: self.broadcast_mgr.check_ext_triggers( itask, self.ext_trigger_queue) - if all(itask.is_ready_to_run()): + if itask.is_ready_to_run(): self.pool.queue_task(itask) if self.xtrigger_mgr.sequential_spawn_next: diff --git a/cylc/flow/scripts/remove.py b/cylc/flow/scripts/remove.py index ef4c74d02c8..edd65e68df5 100755 --- a/cylc/flow/scripts/remove.py +++ b/cylc/flow/scripts/remove.py @@ -18,7 +18,26 @@ """cylc remove [OPTIONS] ARGS -Remove one or more task instances from a running workflow. +Remove task instances from a running workflow and the workflow's history. + +This removes the task(s) from any specified flows. The task will still exist, +just not in the specified flows, so will not influence the evolution of +the workflow in those flows. + +If a task is removed from all flows, it and its outputs will be left in the +`None` flow. This preserves a record that the task ran, but it will not +influence any flows in any way. + +Examples: + # remove a task which has already run + # (any tasks downstream of this task which have already run or are currently + # running will be left alone. The task and its outputs will be left in the + # None flow) + $ cylc remove + + # remove a task from a specified flow + # (the task may remain in other flows) + $ cylc remove --flow=1 """ from functools import partial @@ -33,6 +52,7 @@ ) from cylc.flow.terminal import cli_function + if TYPE_CHECKING: from optparse import Values @@ -41,10 +61,12 @@ mutation ( $wFlows: [WorkflowID]!, $tasks: [NamespaceIDGlob]!, + $flow: [Flow!], ) { remove ( workflows: $wFlows, tasks: $tasks, + flow: $flow ) { result } @@ -61,6 +83,19 @@ def get_option_parser() -> COP: argdoc=[FULL_ID_MULTI_ARG_DOC], ) + parser.add_option( + '--flow', + action='append', + dest='flow', + metavar='FLOW', + help=( + "Remove the task(s) from the specified flow number. " + "Reuse the option to remove the task(s) from multiple flows. 
" + "If the option is not used at all, the task(s) will be removed " + "from all flows." + ), + ) + return parser @@ -75,6 +110,7 @@ async def run(options: 'Values', workflow_id: str, *tokens_list): tokens.relative_id_with_selectors for tokens in tokens_list ], + 'flow': options.flow, } } diff --git a/cylc/flow/scripts/show.py b/cylc/flow/scripts/show.py index 7d7bab1dfdc..c9bf6670fcf 100755 --- a/cylc/flow/scripts/show.py +++ b/cylc/flow/scripts/show.py @@ -60,6 +60,7 @@ from cylc.flow.option_parsers import ( CylcOptionParser as COP, ID_MULTI_ARG_DOC, + Options, ) from cylc.flow.terminal import cli_function from cylc.flow.util import BOOL_SYMBOLS @@ -246,6 +247,9 @@ def get_option_parser(): return parser +ShowOptions = Options(get_option_parser()) + + async def workflow_meta_query(workflow_id, pclient, options, json_filter): query = WORKFLOW_META_QUERY query_kwargs = { diff --git a/cylc/flow/task_job_mgr.py b/cylc/flow/task_job_mgr.py index 500aa830b2c..a38db016855 100644 --- a/cylc/flow/task_job_mgr.py +++ b/cylc/flow/task_job_mgr.py @@ -26,19 +26,26 @@ from contextlib import suppress import json -import os from logging import ( CRITICAL, DEBUG, INFO, - WARNING + WARNING, ) +import os from shutil import rmtree from time import time -from typing import TYPE_CHECKING, Any, Union, Optional +from typing import ( + TYPE_CHECKING, + Any, + Dict, + List, + Optional, + Union, +) from cylc.flow import LOG -from cylc.flow.job_runner_mgr import JobPollContext +from cylc.flow.cfgspec.globalcfg import SYSPATH from cylc.flow.exceptions import ( NoHostsError, NoPlatformsError, @@ -48,9 +55,10 @@ ) from cylc.flow.hostuserutil import ( get_host, - is_remote_platform + is_remote_platform, ) from cylc.flow.job_file import JobFileWriter +from cylc.flow.job_runner_mgr import JobPollContext from cylc.flow.pathutil import get_remote_workflow_run_job_dir from cylc.flow.platforms import ( get_host_from_platform, @@ -64,50 +72,51 @@ from cylc.flow.subprocpool import SubProcPool from 
cylc.flow.task_action_timer import ( TaskActionTimer, - TimerFlags + TimerFlags, ) from cylc.flow.task_events_mgr import ( TaskEventsManager, - log_task_job_activity + log_task_job_activity, ) from cylc.flow.task_job_logs import ( JOB_LOG_JOB, NN, get_task_job_activity_log, get_task_job_job_log, - get_task_job_log + get_task_job_log, ) from cylc.flow.task_message import FAIL_MESSAGE_PREFIX from cylc.flow.task_outputs import ( TASK_OUTPUT_FAILED, TASK_OUTPUT_STARTED, TASK_OUTPUT_SUBMITTED, - TASK_OUTPUT_SUCCEEDED + TASK_OUTPUT_SUCCEEDED, ) from cylc.flow.task_remote_mgr import ( + REMOTE_FILE_INSTALL_255, REMOTE_FILE_INSTALL_DONE, REMOTE_FILE_INSTALL_FAILED, REMOTE_FILE_INSTALL_IN_PROGRESS, - REMOTE_INIT_IN_PROGRESS, REMOTE_INIT_255, - REMOTE_FILE_INSTALL_255, - REMOTE_INIT_DONE, REMOTE_INIT_FAILED, - TaskRemoteMgr + REMOTE_INIT_DONE, + REMOTE_INIT_FAILED, + REMOTE_INIT_IN_PROGRESS, + TaskRemoteMgr, ) from cylc.flow.task_state import ( TASK_STATUS_PREPARING, - TASK_STATUS_SUBMITTED, TASK_STATUS_RUNNING, + TASK_STATUS_SUBMITTED, TASK_STATUS_WAITING, - TASK_STATUSES_ACTIVE + TASK_STATUSES_ACTIVE, ) +from cylc.flow.util import serialise_set from cylc.flow.wallclock import ( get_current_time_string, get_time_string_from_unix_time, - get_utc_mode + get_utc_mode, ) -from cylc.flow.cfgspec.globalcfg import SYSPATH -from cylc.flow.util import serialise_set + if TYPE_CHECKING: from cylc.flow.task_proxy import TaskProxy @@ -271,12 +280,10 @@ def submit_task_jobs(self, workflow, itasks, curve_auth, if not prepared_tasks: return bad_tasks - auth_itasks = {} # {platform: [itask, ...], ...} - + # Mapping of platforms to task proxies: + auth_itasks: Dict[str, List[TaskProxy]] = {} for itask in prepared_tasks: - platform_name = itask.platform['name'] - auth_itasks.setdefault(platform_name, []) - auth_itasks[platform_name].append(itask) + auth_itasks.setdefault(itask.platform['name'], []).append(itask) # Submit task jobs for each platform # Non-prepared tasks can be considered done 
for now: done_tasks = bad_tasks diff --git a/cylc/flow/task_pool.py b/cylc/flow/task_pool.py index f645403df15..a1468b1c9f3 100644 --- a/cylc/flow/task_pool.py +++ b/cylc/flow/task_pool.py @@ -16,34 +16,54 @@ """Wrangle task proxies to manage the workflow.""" -from contextlib import suppress from collections import Counter +from contextlib import suppress +import itertools import json +import logging from textwrap import indent from typing import ( + TYPE_CHECKING, Dict, Iterable, List, NamedTuple, Optional, Set, - TYPE_CHECKING, Tuple, Type, Union, ) -import logging -import cylc.flow.flags from cylc.flow import LOG -from cylc.flow.cycling.loader import get_point, standardise_point_string +from cylc.flow.cycling.loader import ( + get_point, + standardise_point_string, +) from cylc.flow.exceptions import ( - WorkflowConfigError, PointParsingError, PlatformLookupError) -from cylc.flow.id import Tokens, detokenise + PlatformLookupError, + PointParsingError, + WorkflowConfigError, +) +import cylc.flow.flags +from cylc.flow.flow_mgr import ( + FLOW_ALL, + FLOW_NEW, + FLOW_NONE, + repr_flow_nums, +) +from cylc.flow.id import ( + Tokens, + detokenise, + quick_relative_id, +) from cylc.flow.id_cli import contains_fnmatch from cylc.flow.id_match import filter_ids -from cylc.flow.workflow_status import StopMode -from cylc.flow.task_action_timer import TaskActionTimer, TimerFlags +from cylc.flow.platforms import get_platform +from cylc.flow.task_action_timer import ( + TaskActionTimer, + TimerFlags, +) from cylc.flow.task_events_mgr import ( CustomTaskEventHandlerContext, EventKey, @@ -51,45 +71,42 @@ TaskJobLogsRetrieveContext, ) from cylc.flow.task_id import TaskID +from cylc.flow.task_outputs import ( + TASK_OUTPUT_EXPIRED, + TASK_OUTPUT_FAILED, + TASK_OUTPUT_SUBMIT_FAILED, + TASK_OUTPUT_SUCCEEDED, +) from cylc.flow.task_proxy import TaskProxy +from cylc.flow.task_queues.independent import IndepQueueManager from cylc.flow.task_state import ( - TASK_STATUSES_ACTIVE, - 
TASK_STATUSES_FINAL, - TASK_STATUS_WAITING, TASK_STATUS_EXPIRED, + TASK_STATUS_FAILED, TASK_STATUS_PREPARING, - TASK_STATUS_SUBMITTED, TASK_STATUS_RUNNING, + TASK_STATUS_SUBMITTED, TASK_STATUS_SUCCEEDED, - TASK_STATUS_FAILED, + TASK_STATUS_WAITING, + TASK_STATUSES_ACTIVE, + TASK_STATUSES_FINAL, ) from cylc.flow.task_trigger import TaskTrigger -from cylc.flow.util import ( - serialise_set, - deserialise_set -) -from cylc.flow.wallclock import get_current_time_string -from cylc.flow.platforms import get_platform -from cylc.flow.task_outputs import ( - TASK_OUTPUT_SUCCEEDED, - TASK_OUTPUT_EXPIRED, - TASK_OUTPUT_FAILED, - TASK_OUTPUT_SUBMIT_FAILED, -) -from cylc.flow.task_queues.independent import IndepQueueManager +from cylc.flow.taskdef import generate_graph_children +from cylc.flow.util import deserialise_set +from cylc.flow.workflow_status import StopMode -from cylc.flow.flow_mgr import ( - stringify_flow_nums, - FLOW_ALL, - FLOW_NONE, - FLOW_NEW -) if TYPE_CHECKING: from cylc.flow.config import WorkflowConfig - from cylc.flow.cycling import IntervalBase, PointBase + from cylc.flow.cycling import ( + IntervalBase, + PointBase, + ) from cylc.flow.data_store_mgr import DataStoreMgr - from cylc.flow.flow_mgr import FlowMgr, FlowNums + from cylc.flow.flow_mgr import ( + FlowMgr, + FlowNums, + ) from cylc.flow.prerequisite import SatisfiedState from cylc.flow.task_events_mgr import TaskEventsManager from cylc.flow.taskdef import TaskDef @@ -203,18 +220,7 @@ def db_add_new_flow_rows(self, itask: TaskProxy) -> None: Call when a new task is spawned or a flow merge occurs. """ # Add row to task_states table. 
- now = get_current_time_string() - self.workflow_db_mgr.put_insert_task_states( - itask, - { - "time_created": now, - "time_updated": now, - "status": itask.state.status, - "flow_nums": serialise_set(itask.flow_nums), - "flow_wait": itask.flow_wait, - "is_manual_submit": itask.is_manual_submit - } - ) + self.workflow_db_mgr.put_insert_task_states(itask) # Add row to task_outputs table: self.workflow_db_mgr.put_insert_task_outputs(itask) @@ -441,13 +447,24 @@ def check_task_output( output_msg: str, flow_nums: 'FlowNums', ) -> 'SatisfiedState': - """Returns truthy if the specified output is satisfied in the DB.""" + """Returns truthy if the specified output is satisfied in the DB. + + Args: + cycle: Cycle point of the task whose output is being checked. + task: Name of the task whose output is being checked. + output_msg: The output message to check for. + flow_nums: Flow numbers of the task whose output is being + checked. If this is empty it means 'none'; will return False. + """ + if not flow_nums: + return False + for task_outputs, task_flow_nums in ( self.workflow_db_mgr.pri_dao.select_task_outputs(task, cycle) ).items(): # loop through matching tasks + # (if task_flow_nums is empty, it means the 'none' flow) if flow_nums.intersection(task_flow_nums): - # this task is in the right flow # BACK COMPAT: In Cylc >8.0.0,<8.3.0, only the task # messages were stored in the DB as a list. # from: 8.0.0 @@ -709,7 +726,7 @@ def rh_release_and_queue(self, itask) -> None: """ if itask.state_reset(is_runahead=False): self.data_store_mgr.delta_task_runahead(itask) - if all(itask.is_ready_to_run()): + if itask.is_ready_to_run(): # (otherwise waiting on xtriggers etc.) self.queue_task(itask) @@ -736,9 +753,7 @@ def get_or_spawn_task( It does not add a spawned task proxy to the pool. 
""" - ntask = self._get_task_by_id( - Tokens(cycle=str(point), task=tdef.name).relative_id - ) + ntask = self.get_task(point, tdef.name) is_in_pool = False is_xtrig_sequential = False if ntask is None: @@ -817,7 +832,7 @@ def spawn_if_parentless(self, tdef, point, flow_nums): if ntask is not None and not is_in_pool: self.add_to_pool(ntask) - def remove(self, itask, reason=None): + def remove(self, itask: 'TaskProxy', reason: Optional[str] = None) -> None: """Remove a task from the pool.""" if itask.state.is_runahead and itask.flow_nums: @@ -829,11 +844,7 @@ def remove(self, itask, reason=None): itask.flow_nums ) - msg = "removed from active task pool" - if reason is None: - msg += ": completed" - else: - msg += f": {reason}" + msg = f"removed from active task pool: {reason or 'completed'}" if itask.is_xtrigger_sequential: self.xtrigger_mgr.sequential_spawn_next.discard(itask.identity) @@ -887,40 +898,53 @@ def get_tasks(self) -> List[TaskProxy]: # Cached list only for use internally in this method. 
if self.active_tasks_changed: self.active_tasks_changed = False - self._active_tasks_list = [] - for itask_id_map in self.active_tasks.values(): - for itask in itask_id_map.values(): - self._active_tasks_list.append(itask) + self._active_tasks_list = [ + itask + for itask_id_map in self.active_tasks.values() + for itask in itask_id_map.values() + ] return self._active_tasks_list def get_tasks_by_point(self) -> 'Dict[PointBase, List[TaskProxy]]': """Return a map of task proxies by cycle point.""" - point_itasks = {} - for point, itask_id_map in self.active_tasks.items(): - point_itasks[point] = list(itask_id_map.values()) - return point_itasks + return { + point: list(itask_id_map.values()) + for point, itask_id_map in self.active_tasks.items() + } def get_task(self, point: 'PointBase', name: str) -> Optional[TaskProxy]: """Retrieve a task from the pool.""" rel_id = f'{point}/{name}' tasks = self.active_tasks.get(point) - if tasks and rel_id in tasks: - return tasks[rel_id] + if tasks: + return tasks.get(rel_id) return None def _get_task_by_id(self, id_: str) -> Optional[TaskProxy]: """Return pool task by ID if it exists, or None.""" for itask_ids in self.active_tasks.values(): - with suppress(KeyError): + if id_ in itask_ids: return itask_ids[id_] return None def queue_task(self, itask: TaskProxy) -> None: - """Queue a task that is ready to run.""" + """Queue a task that is ready to run. + + If it is already queued, do nothing. + """ if itask.state_reset(is_queued=True): self.data_store_mgr.delta_task_queued(itask) self.task_queue_mgr.push_task(itask) + def unqueue_task(self, itask: TaskProxy) -> None: + """Un-queue a task that is no longer ready to run. + + If it is not queued, do nothing. + """ + if itask.state_reset(is_queued=False): + self.data_store_mgr.delta_task_queued(itask) + self.task_queue_mgr.remove_task(itask) + def release_queued_tasks(self): """Return list of queue-released tasks awaiting job prep. 
@@ -1100,8 +1124,7 @@ def _reload_taskdefs(self) -> None: if itask.state.is_queued: # Already queued continue - ready_check_items = itask.is_ready_to_run() - if all(ready_check_items) and not itask.state.is_runahead: + if itask.is_ready_to_run() and not itask.state.is_runahead: self.queue_task(itask) def set_stop_point(self, stop_point: 'PointBase') -> bool: @@ -1243,7 +1266,7 @@ def log_unsatisfied_prereqs(self) -> bool: LOG.warning( "Partially satisfied prerequisites:\n" + "\n".join( - f" * {id_} is waiting on {others}" + f" * {id_} is waiting on {sorted(others)}" for id_, others in unsat.items() ) ) @@ -1266,7 +1289,7 @@ def is_stalled(self) -> bool: itask.state(TASK_STATUS_WAITING) and not itask.state.is_runahead # (avoid waiting pre-spawned absolute-triggered tasks:) - and not itask.is_task_prereqs_not_done() + and itask.prereqs_are_satisfied() ) for itask in self.get_tasks() ): return False @@ -1284,7 +1307,7 @@ def hold_active_task(self, itask: TaskProxy) -> None: def release_held_active_task(self, itask: TaskProxy) -> None: if itask.state_reset(is_held=False): self.data_store_mgr.delta_task_held(itask) - if (not itask.state.is_runahead) and all(itask.is_ready_to_run()): + if (not itask.state.is_runahead) and itask.is_ready_to_run(): self.queue_task(itask) self.tasks_to_hold.discard((itask.tdef.name, itask.point)) self.workflow_db_mgr.put_tasks_to_hold(self.tasks_to_hold) @@ -1302,8 +1325,8 @@ def hold_tasks(self, items: Iterable[str]) -> int: # Hold active tasks: itasks, inactive_tasks, unmatched = self.filter_task_proxies( items, - warn=False, - future=True, + warn_no_active=False, + inactive=True, ) for itask in itasks: self.hold_active_task(itask) @@ -1320,8 +1343,8 @@ def release_held_tasks(self, items: Iterable[str]) -> int: # Release active tasks: itasks, inactive_tasks, unmatched = self.filter_task_proxies( items, - warn=False, - future=True, + warn_no_active=False, + inactive=True, ) for itask in itasks: self.release_held_active_task(itask) @@ 
-1400,12 +1423,7 @@ def spawn_on_output(self, itask: TaskProxy, output: str) -> None: str(itask.point), itask.tdef.name, output) self.workflow_db_mgr.process_queued_ops() - c_taskid = Tokens( - cycle=str(c_point), - task=c_name, - ).relative_id - - c_task = self._get_task_by_id(c_taskid) + c_task = self._get_task_by_id(quick_relative_id(c_point, c_name)) in_pool = c_task is not None if c_task is not None and c_task != itask: @@ -1422,7 +1440,7 @@ def spawn_on_output(self, itask: TaskProxy, output: str) -> None: if is_abs: tasks, *_ = self.filter_task_proxies( [f'*/{c_name}'], - warn=False, + warn_no_active=False, ) if c_task not in tasks: tasks.append(c_task) @@ -1650,6 +1668,7 @@ def _load_historical_outputs(self, itask: 'TaskProxy') -> None: else: flow_seen = False for outputs_str, fnums in info.items(): + # (if fnums is empty, it means the 'none' flow) if itask.flow_nums.intersection(fnums): # DB row has overlap with itask's flows flow_seen = True @@ -1717,12 +1736,10 @@ def spawn_task( and not itask.state.outputs.get_completed_outputs() ): # If itask has any history in this flow but no completed outputs - # we can infer it was deliberately removed, so don't respawn it. + # we can infer it has just been deliberately removed (N.B. not + # by `cylc remove`), so don't immediately respawn it. 
# TODO (follow-up work): # - this logic fails if task removed after some outputs completed - # - this is does not conform to future "cylc remove" flow-erasure - # behaviour which would result in respawning of the removed task - # See github.com/cylc/cylc-flow/pull/6186/#discussion_r1669727292 LOG.debug(f"Not respawning {point}/{name} - task was removed") return None @@ -1737,7 +1754,7 @@ def spawn_task( msg += " incomplete" LOG.info( - f"{msg} {stringify_flow_nums(flow_nums, full=True)})" + f"{msg} {repr_flow_nums(flow_nums, full=True)})" ) if prev_flow_wait: self._spawn_after_flow_wait(itask) @@ -1822,9 +1839,6 @@ def _get_task_proxy_db_outputs( self.xtrigger_mgr.xtriggers.sequential_xtrigger_labels ), ) - if itask is None: - return None - # Update it with outputs that were already completed. self._load_historical_outputs(itask) return itask @@ -1929,8 +1943,8 @@ def set_prereqs_and_outputs( # Get matching pool tasks and inactive task definitions. itasks, inactive_tasks, unmatched = self.filter_task_proxies( items, - future=True, - warn=False, + inactive=True, + warn_no_active=False, ) flow_nums = self._get_flow_nums(flow, flow_descr) @@ -1940,7 +1954,7 @@ def set_prereqs_and_outputs( if flow == ['none'] and itask.flow_nums != set(): LOG.error( f"[{itask}] ignoring 'flow=none' set: task already has" - f" {stringify_flow_nums(itask.flow_nums, full=True)}" + f" {repr_flow_nums(itask.flow_nums, full=True)}" ) continue self.merge_flows(itask, flow_nums) @@ -2023,7 +2037,7 @@ def _set_prereqs_itask( # Attempt to set the given presrequisites. # Log any that aren't valid for the task. 
presus = self._standardise_prereqs(prereqs) - unmatched = itask.satisfy_me(presus.keys()) + unmatched = itask.satisfy_me(presus.keys(), forced=True) for task_msg in unmatched: LOG.warning( f"{itask.identity} does not depend on" @@ -2068,16 +2082,128 @@ def _get_active_flow_nums(self) -> 'FlowNums': or {1} ) - def remove_tasks(self, items): - """Remove tasks from the pool (forced by command).""" - itasks, _, bad_items = self.filter_task_proxies(items) - for itask in itasks: - # Spawn next occurrence of xtrigger sequential task. - self.check_spawn_psx_task(itask) - self.remove(itask, 'request') - if self.compute_runahead(): + def remove_tasks( + self, items: Iterable[str], flow_nums: Optional['FlowNums'] = None + ) -> None: + """Remove tasks from the pool (forced by command). + + Args: + items: Relative IDs or globs. + flow_nums: Flows to remove the tasks from. If empty or None, it + means 'all'. + """ + active, inactive, _unmatched = self.filter_task_proxies( + items, warn_no_active=False, inactive=True + ) + if not (active or inactive): + return + + if flow_nums is None: + flow_nums = set() + # Mapping of task IDs to removed flow numbers: + removed: Dict[str, FlowNums] = {} + not_removed: Set[str] = set() + + for itask in active: + fnums_to_remove = itask.match_flows(flow_nums) + if not fnums_to_remove: + not_removed.add(itask.identity) + continue + removed[itask.identity] = fnums_to_remove + if fnums_to_remove == itask.flow_nums: + # Need to remove the task from the pool. 
+ # Spawn next occurrence of xtrigger sequential task (otherwise + # this would not happen after removing this occurrence): + self.check_spawn_psx_task(itask) + self.remove(itask, 'request') + else: + itask.flow_nums.difference_update(fnums_to_remove) + + matched_task_ids = { + *removed.keys(), + *(quick_relative_id(cycle, task) for task, cycle in inactive), + } + + for id_ in matched_task_ids: + point_str, name = id_.split('/', 1) + tdef = self.config.taskdefs[name] + # Go through downstream tasks to see if any need to stand down + # as a result of this task being removed: + for child in set(itertools.chain.from_iterable( + generate_graph_children(tdef, get_point(point_str)).values() + )): + child_itask = self.get_task(child.point, child.name) + if not child_itask: + continue + fnums_to_remove = child_itask.match_flows(flow_nums) + if not fnums_to_remove: + continue + prereqs_changed = False + for prereq in ( + *child_itask.state.prerequisites, + *child_itask.state.suicide_prerequisites, + ): + # Unset any prereqs naturally satisfied by these tasks + # (do not unset those satisfied by `cylc set --pre`): + if prereq.unset_naturally_satisfied_dependency(id_): + prereqs_changed = True + removed.setdefault(id_, set()).update(fnums_to_remove) + if not prereqs_changed: + continue + self.data_store_mgr.delta_task_prerequisite(child_itask) + # Check if downstream task is still ready to run: + if ( + child_itask.state.is_gte(TASK_STATUS_PREPARING) + # Still ready if the task exists in other flows: + or child_itask.flow_nums != fnums_to_remove + or child_itask.state.prerequisites_all_satisfied() + ): + continue + # No longer ready to run + self.unqueue_task(child_itask) + # Check if downstream task should remain spawned: + if ( + # Ignoring tasks we are already dealing with: + child_itask.identity in matched_task_ids + or child_itask.state.any_satisfied_prerequisite_tasks() + ): + continue + # No longer has reason to be in pool: + self.remove(child_itask, 'upstream task(s) 
removed') + # Remove from DB tables to ensure it is not skipped if it + # respawns in future: + self.workflow_db_mgr.remove_task_from_flows( + str(child.point), child.name, fnums_to_remove + ) + + # Remove from DB tables: + db_removed_fnums = self.workflow_db_mgr.remove_task_from_flows( + point_str, name, flow_nums + ) + if db_removed_fnums: + removed.setdefault(id_, set()).update(db_removed_fnums) + + if removed: + tasks_str_list = [] + for task, fnums in removed.items(): + self.data_store_mgr.delta_remove_task_flow_nums(task, fnums) + tasks_str_list.append( + f"{task} {repr_flow_nums(fnums, full=True)}" + ) + LOG.info(f"Removed task(s): {', '.join(sorted(tasks_str_list))}") + + not_removed.update(matched_task_ids.difference(removed)) + if not_removed: + fnums_str = ( + repr_flow_nums(flow_nums, full=True) if flow_nums else '' + ) + LOG.warning( + "Task(s) not removable: " + f"{', '.join(sorted(not_removed))} {fnums_str}" + ) + + if removed and self.compute_runahead(): self.release_runahead_tasks() - return len(bad_items) def _get_flow_nums( self, @@ -2153,8 +2279,8 @@ def force_trigger_tasks( """ # Get matching tasks proxies, and matching inactive task IDs. 
- existing_tasks, future_ids, unmatched = self.filter_task_proxies( - items, future=True, warn=False, + existing_tasks, inactive_ids, unmatched = self.filter_task_proxies( + items, inactive=True, warn_no_active=False, ) flow_nums = self._get_flow_nums(flow, flow_descr) @@ -2164,7 +2290,7 @@ def force_trigger_tasks( if flow == ['none'] and itask.flow_nums != set(): LOG.error( f"[{itask}] ignoring 'flow=none' trigger: task already has" - f" {stringify_flow_nums(itask.flow_nums, full=True)}" + f" {repr_flow_nums(itask.flow_nums, full=True)}" ) continue if itask.state(TASK_STATUS_PREPARING, *TASK_STATUSES_ACTIVE): @@ -2177,7 +2303,7 @@ def force_trigger_tasks( if not flow: # default: assign to all active flows flow_nums = self._get_active_flow_nums() - for name, point in future_ids: + for name, point in inactive_ids: if not self.can_be_spawned(name, point): continue submit_num, _, prev_fwait = ( @@ -2303,41 +2429,41 @@ def log_task_pool(self, log_lvl=logging.DEBUG): def filter_task_proxies( self, ids: Iterable[str], - warn: bool = True, - future: bool = False, + warn_no_active: bool = True, + inactive: bool = False, ) -> 'Tuple[List[TaskProxy], Set[Tuple[str, PointBase]], List[str]]': """Return task proxies that match names, points, states in items. Args: ids: ID strings. - warn: - Whether to log a warning if no matching tasks are found. - future: + warn_no_active: + Whether to log a warning if no matching active tasks are found. + inactive: If True, unmatched IDs will be checked against taskdefs - and cycle, task pairs will be provided in the future_matched - argument providing the ID + and cycle, and any matches will be returned in the second + return value, provided that the ID: * Specifies a cycle point. * Is not a pattern. (e.g. `*/foo`). * Does not contain a state selector (e.g. `:failed`). 
Returns: - (matched, future_matched, unmatched) + (matched, inactive_matched, unmatched) """ matched, unmatched = filter_ids( self.active_tasks, ids, - warn=warn, + warn=warn_no_active, ) - future_matched: 'Set[Tuple[str, PointBase]]' = set() - if future and unmatched: - future_matched, unmatched = self.match_inactive_tasks( + inactive_matched: 'Set[Tuple[str, PointBase]]' = set() + if inactive and unmatched: + inactive_matched, unmatched = self.match_inactive_tasks( unmatched ) - return matched, future_matched, unmatched + return matched, inactive_matched, unmatched def match_inactive_tasks( self, diff --git a/cylc/flow/task_proxy.py b/cylc/flow/task_proxy.py index 73d819c2fce..a6e7d8a1c7c 100644 --- a/cylc/flow/task_proxy.py +++ b/cylc/flow/task_proxy.py @@ -30,32 +30,35 @@ List, Optional, Set, - Tuple, ) from metomi.isodatetime.timezone import get_local_time_zone from cylc.flow import LOG -from cylc.flow.flow_mgr import stringify_flow_nums +from cylc.flow.cycling.iso8601 import ( + interval_parse, + point_parse, +) +from cylc.flow.flow_mgr import repr_flow_nums from cylc.flow.platforms import get_platform from cylc.flow.task_action_timer import TimerFlags from cylc.flow.task_state import ( - TaskState, - TASK_STATUS_WAITING, TASK_STATUS_EXPIRED, + TASK_STATUS_WAITING, + TaskState, ) from cylc.flow.taskdef import generate_graph_children from cylc.flow.wallclock import get_unix_time_from_time_string as str2time -from cylc.flow.cycling.iso8601 import ( - point_parse, - interval_parse, -) + if TYPE_CHECKING: from cylc.flow.cycling import PointBase from cylc.flow.flow_mgr import FlowNums from cylc.flow.id import Tokens - from cylc.flow.prerequisite import PrereqMessage, SatisfiedState + from cylc.flow.prerequisite import ( + PrereqMessage, + SatisfiedState, + ) from cylc.flow.simulation import ModeSettings from cylc.flow.task_action_timer import TaskActionTimer from cylc.flow.taskdef import TaskDef @@ -148,7 +151,7 @@ class TaskProxy: .graph_children (dict) graph 
children: {msg: [(name, point), ...]} .flow_nums: - flows I belong to + flows I belong to (if empty, belongs to 'none' flow) flow_wait: wait for flow merge before spawning children .waiting_on_job_prep: @@ -211,7 +214,7 @@ def __init__( scheduler_tokens: 'Tokens', tdef: 'TaskDef', start_point: 'PointBase', - flow_nums: Optional[Set[int]] = None, + flow_nums: Optional['FlowNums'] = None, status: str = TASK_STATUS_WAITING, is_held: bool = False, submit_num: int = 0, @@ -306,7 +309,7 @@ def __init__( ) def __repr__(self) -> str: - return f"<{self.__class__.__name__} '{self.tokens}'>" + return f"<{self.__class__.__name__} {self.identity}>" def __str__(self) -> str: """Stringify with tokens, state, submit_num, and flow_nums. @@ -317,11 +320,11 @@ def __str__(self) -> str: """ id_ = self.identity if self.transient: - return f"{id_}{stringify_flow_nums(self.flow_nums)}" + return f"{id_}{repr_flow_nums(self.flow_nums)}" if not self.state(TASK_STATUS_WAITING, TASK_STATUS_EXPIRED): id_ += f"/{self.submit_num:02d}" return ( - f"{id_}{stringify_flow_nums(self.flow_nums)}:{self.state}" + f"{id_}{repr_flow_nums(self.flow_nums)}:{self.state}" ) def copy_to_reload_successor( @@ -454,7 +457,7 @@ def next_point(self): """Return the next cycle point.""" return self.tdef.next_point(self.point) - def is_ready_to_run(self) -> Tuple[bool, ...]: + def is_ready_to_run(self) -> bool: """Is this task ready to run? Takes account of all dependence: on other tasks, xtriggers, and @@ -463,16 +466,18 @@ def is_ready_to_run(self) -> Tuple[bool, ...]: """ if self.is_manual_submit: # Manually triggered, ignore unsatisfied prerequisites. - return (True,) + return True if self.state.is_held: # A held task is not ready to run. - return (False,) + return False if self.state.status in self.try_timers: # A try timer is still active. 
- return (self.try_timers[self.state.status].is_delay_done(),) + return self.try_timers[self.state.status].is_delay_done() return ( - self.state(TASK_STATUS_WAITING), - self.is_waiting_prereqs_done() + self.state(TASK_STATUS_WAITING) + and self.prereqs_are_satisfied() + and self.state.external_triggers_all_satisfied() + and self.state.xtriggers_all_satisfied() ) def set_summary_time(self, event_key, time_str=None): @@ -486,18 +491,9 @@ def set_summary_time(self, event_key, time_str=None): self.summary[event_key + '_time'] = float(str2time(time_str)) self.summary[event_key + '_time_string'] = time_str - def is_task_prereqs_not_done(self): - """Are some task prerequisites not satisfied?""" - return (not all(pre.is_satisfied() - for pre in self.state.prerequisites)) - - def is_waiting_prereqs_done(self): - """Are ALL prerequisites satisfied?""" - return ( - all(pre.is_satisfied() for pre in self.state.prerequisites) - and self.state.external_triggers_all_satisfied() - and self.state.xtriggers_all_satisfied() - ) + def prereqs_are_satisfied(self) -> bool: + """Are all task prerequisites satisfied?""" + return all(pre.is_satisfied() for pre in self.state.prerequisites) def reset_try_timers(self): # unset any retry delay timers @@ -522,6 +518,17 @@ def name_match( match_func(ns, value) for ns in self.tdef.namespace_hierarchy ) + def match_flows(self, flow_nums: 'FlowNums') -> 'FlowNums': + """Return which of the given flow numbers the task belongs to. + + NOTE: If `flow_nums` is empty, it means 'all', whereas + if `self.flow_nums` is empty, it means this task is in the 'none' flow + and will not match. 
+ """ + if not flow_nums or not self.flow_nums: + return self.flow_nums + return self.flow_nums.intersection(flow_nums) + def merge_flows(self, flow_nums: Set) -> None: """Merge another set of flow_nums with mine.""" self.flow_nums.update(flow_nums) @@ -553,7 +560,7 @@ def state_reset( return False def satisfy_me( - self, task_messages: 'Iterable[Tokens]' + self, task_messages: 'Iterable[Tokens]', forced: bool = False ) -> 'Set[Tokens]': """Try to satisfy my prerequisites with given output messages. @@ -563,7 +570,7 @@ def satisfy_me( Return a set of unmatched task messages. """ - used = self.state.satisfy_me(task_messages) + used = self.state.satisfy_me(task_messages, forced) return set(task_messages) - used def clock_expire(self) -> bool: diff --git a/cylc/flow/task_state.py b/cylc/flow/task_state.py index 9ecd9414d17..7c1c1e69753 100644 --- a/cylc/flow/task_state.py +++ b/cylc/flow/task_state.py @@ -324,7 +324,8 @@ def __call__( def satisfy_me( self, - outputs: Iterable['Tokens'] + outputs: Iterable['Tokens'], + forced: bool = False, ) -> Set['Tokens']: """Try to satisfy my prerequisites with given outputs. 
@@ -333,7 +334,7 @@ def satisfy_me( valid: Set[Tokens] = set() for prereq in (*self.prerequisites, *self.suicide_prerequisites): valid.update( - prereq.satisfy_me(outputs) + prereq.satisfy_me(outputs, forced) ) return valid @@ -393,7 +394,7 @@ def get_resolved_dependencies(self): return sorted( dep for prereq in self.prerequisites - for dep in prereq.get_resolved_dependencies() + for dep in prereq.get_satisfied_dependencies() ) def reset( @@ -527,3 +528,12 @@ def get_unsatisfied_prerequisites(self) -> List['PrereqMessage']: for prereq in self.prerequisites if not prereq.is_satisfied() for key, satisfied in prereq.items() if not satisfied ] + + def any_satisfied_prerequisite_tasks(self) -> bool: + """Return True if any of this task's prerequisite tasks are + satisfied.""" + return any( + satisfied + for prereq in self.prerequisites + for satisfied in prereq._satisfied.values() + ) diff --git a/cylc/flow/util.py b/cylc/flow/util.py index b7e1e0e0c73..b649265dbd9 100644 --- a/cylc/flow/util.py +++ b/cylc/flow/util.py @@ -17,7 +17,10 @@ import ast from contextlib import suppress -from functools import partial +from functools import ( + lru_cache, + partial, +) import json import re from textwrap import dedent @@ -31,6 +34,7 @@ Tuple, ) + BOOL_SYMBOLS: Dict[bool, str] = { # U+2A2F (vector cross product) False: '⨯', @@ -163,15 +167,23 @@ def serialise_set(flow_nums: Optional[set] = None) -> str: '[]' """ - return json.dumps(sorted(flow_nums or ())) + return _serialise_set(tuple(sorted(flow_nums or ()))) + + +@lru_cache(maxsize=100) +def _serialise_set(flow_nums: tuple) -> str: + return json.dumps(flow_nums) +@lru_cache(maxsize=100) def deserialise_set(flow_num_str: str) -> set: """Convert json string to set. 
Example: - >>> sorted(deserialise_set('[2, 3]')) - [2, 3] + >>> deserialise_set('[2, 3]') == {2, 3} + True + >>> deserialise_set('[]') + set() """ return set(json.loads(flow_num_str)) diff --git a/cylc/flow/workflow_db_mgr.py b/cylc/flow/workflow_db_mgr.py index 5f1f673440c..ebe176b85ac 100644 --- a/cylc/flow/workflow_db_mgr.py +++ b/cylc/flow/workflow_db_mgr.py @@ -23,39 +23,66 @@ * Manage existing run database files on restart. """ +from collections import defaultdict import json import os -from shutil import copy, rmtree +from shutil import ( + copy, + rmtree, +) from sqlite3 import OperationalError from tempfile import mkstemp from typing import ( - Any, AnyStr, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union + TYPE_CHECKING, + Any, + AnyStr, + DefaultDict, + Dict, + List, + Optional, + Set, + Tuple, + Union, ) from packaging.version import parse as parse_version -from cylc.flow import LOG +from cylc.flow import ( + LOG, + __version__ as CYLC_VERSION, +) from cylc.flow.broadcast_report import get_broadcast_change_iter +from cylc.flow.exceptions import ( + CylcError, + ServiceFileError, +) from cylc.flow.rundb import CylcWorkflowDAO -from cylc.flow import __version__ as CYLC_VERSION -from cylc.flow.wallclock import get_current_time_string, get_utc_mode -from cylc.flow.exceptions import CylcError, ServiceFileError -from cylc.flow.util import serialise_set, deserialise_set +from cylc.flow.util import ( + deserialise_set, + serialise_set, +) +from cylc.flow.wallclock import ( + get_current_time_string, + get_utc_mode, +) + if TYPE_CHECKING: from pathlib import Path + + from packaging.version import Version + from cylc.flow.cycling import PointBase + from cylc.flow.flow_mgr import FlowNums + from cylc.flow.rundb import ( + DbArgDict, + DbUpdateTuple, + ) from cylc.flow.scheduler import Scheduler - from cylc.flow.task_pool import TaskPool from cylc.flow.task_events_mgr import EventKey + from cylc.flow.task_pool import TaskPool from cylc.flow.task_proxy import 
TaskProxy -Version = Any -# TODO: narrow down Any (should be str | int) after implementing type -# annotations in cylc.flow.task_state.TaskState -DbArgDict = Dict[str, Any] -DbUpdateTuple = Tuple[DbArgDict, DbArgDict] - PERM_PRIVATE = 0o600 # -rw------- @@ -141,7 +168,9 @@ def __init__(self, pri_d=None, pub_d=None): self.TABLE_TASKS_TO_HOLD: [], self.TABLE_XTRIGGERS: [], self.TABLE_ABS_OUTPUTS: []} - self.db_updates_map: Dict[str, List[DbUpdateTuple]] = {} + self.db_updates_map: DefaultDict[ + str, List[DbUpdateTuple] + ] = defaultdict(list) def copy_pri_to_pub(self) -> None: """Copy content of primary database file to public database file.""" @@ -232,29 +261,23 @@ def process_queued_ops(self) -> None: # Record workflow parameters and tasks in pool # Record any broadcast settings to be dumped out if any(self.db_deletes_map.values()): - for table_name, db_deletes in sorted( - self.db_deletes_map.items()): + for table_name, db_deletes in sorted(self.db_deletes_map.items()): while db_deletes: where_args = db_deletes.pop(0) self.pri_dao.add_delete_item(table_name, where_args) self.pub_dao.add_delete_item(table_name, where_args) if any(self.db_inserts_map.values()): - for table_name, db_inserts in sorted( - self.db_inserts_map.items()): + for table_name, db_inserts in sorted(self.db_inserts_map.items()): while db_inserts: db_insert = db_inserts.pop(0) self.pri_dao.add_insert_item(table_name, db_insert) self.pub_dao.add_insert_item(table_name, db_insert) - if (hasattr(self, 'db_updates_map') and - any(self.db_updates_map.values())): - for table_name, db_updates in sorted( - self.db_updates_map.items()): + if any(self.db_updates_map.values()): + for table_name, db_updates in sorted(self.db_updates_map.items()): while db_updates: - set_args, where_args = db_updates.pop(0) - self.pri_dao.add_update_item( - table_name, set_args, where_args) - self.pub_dao.add_update_item( - table_name, set_args, where_args) + db_update = db_updates.pop(0) + 
self.pri_dao.add_update_item(table_name, db_update) + self.pub_dao.add_update_item(table_name, db_update) # Previously, we used a separate thread for database writes. This has # now been removed. For the private database, there is no real @@ -422,7 +445,7 @@ def put_xtriggers(self, sat_xtrig): "signature": sig, "results": json.dumps(res)}) - def put_update_task_state(self, itask): + def put_update_task_state(self, itask: 'TaskProxy') -> None: """Update task_states table for current state of itask. NOTE the task_states table is normally updated along with the task pool @@ -443,9 +466,9 @@ def put_update_task_state(self, itask): "name": itask.tdef.name, "flow_nums": serialise_set(itask.flow_nums), } - self.db_updates_map.setdefault(self.TABLE_TASK_STATES, []) self.db_updates_map[self.TABLE_TASK_STATES].append( - (set_args, where_args)) + (set_args, where_args) + ) def put_update_task_flow_wait(self, itask): """Update flow_wait status of a task, in the task_states table. @@ -463,7 +486,6 @@ def put_update_task_flow_wait(self, itask): "name": itask.tdef.name, "flow_nums": serialise_set(itask.flow_nums), } - self.db_updates_map.setdefault(self.TABLE_TASK_STATES, []) self.db_updates_map[self.TABLE_TASK_STATES].append( (set_args, where_args)) @@ -491,7 +513,6 @@ def put_task_pool(self, pool: 'TaskPool') -> None: prereq.items() ): self.put_insert_task_prerequisites(itask, { - "flow_nums": serialise_set(itask.flow_nums), "prereq_name": p_name, "prereq_cycle": p_cycle, "prereq_output": p_output, @@ -547,7 +568,6 @@ def put_task_pool(self, pool: 'TaskPool') -> None: "name": itask.tdef.name, "flow_nums": serialise_set(itask.flow_nums) } - self.db_updates_map.setdefault(self.TABLE_TASK_STATES, []) self.db_updates_map[self.TABLE_TASK_STATES].append( (set_args, where_args) ) @@ -581,12 +601,25 @@ def put_insert_task_jobs(self, itask, args): """Put INSERT statement for task_jobs table.""" self._put_insert_task_x(CylcWorkflowDAO.TABLE_TASK_JOBS, itask, args) - def 
put_insert_task_states(self, itask, args): + def put_insert_task_states(self, itask: 'TaskProxy') -> None: """Put INSERT statement for task_states table.""" - self._put_insert_task_x(CylcWorkflowDAO.TABLE_TASK_STATES, itask, args) + now = get_current_time_string() + self._put_insert_task_x( + CylcWorkflowDAO.TABLE_TASK_STATES, + itask, + { + "time_created": now, + "time_updated": now, + "status": itask.state.status, + "flow_nums": serialise_set(itask.flow_nums), + "flow_wait": itask.flow_wait, + "is_manual_submit": itask.is_manual_submit, + }, + ) def put_insert_task_prerequisites(self, itask, args): """Put INSERT statement for task_prerequisites table.""" + args.setdefault("flow_nums", serialise_set(itask.flow_nums)) self._put_insert_task_x(self.TABLE_TASK_PREREQUISITES, itask, args) def put_insert_task_outputs(self, itask): @@ -623,15 +656,16 @@ def put_insert_workflow_flows(self, flow_num, flow_metadata): } ) - def _put_insert_task_x(self, table_name, itask, args): + def _put_insert_task_x( + self, table_name: str, itask: 'TaskProxy', args: 'DbArgDict' + ) -> None: """Put INSERT statement for a task_* table.""" args.update({ "name": itask.tdef.name, - "cycle": str(itask.point)}) - if "submit_num" not in args: - args["submit_num"] = itask.submit_num - self.db_inserts_map.setdefault(table_name, []) - self.db_inserts_map[table_name].append(args) + "cycle": str(itask.point), + }) + args.setdefault("submit_num", itask.submit_num) + self.db_inserts_map.setdefault(table_name, []).append(args) def put_update_task_jobs(self, itask, set_args): """Put UPDATE statement for task_jobs table.""" @@ -650,22 +684,107 @@ def put_update_task_outputs(self, itask: 'TaskProxy') -> None: "name": itask.tdef.name, "flow_nums": serialise_set(itask.flow_nums), } - self.db_updates_map.setdefault(self.TABLE_TASK_OUTPUTS, []).append( + self.db_updates_map[self.TABLE_TASK_OUTPUTS].append( (set_args, where_args) ) - def _put_update_task_x(self, table_name, itask, set_args): + def 
_put_update_task_x( + self, table_name: str, itask: 'TaskProxy', set_args: 'DbArgDict' + ) -> None: """Put UPDATE statement for a task_* table.""" where_args = { "cycle": str(itask.point), - "name": itask.tdef.name} + "name": itask.tdef.name, + } if "submit_num" not in set_args: where_args["submit_num"] = itask.submit_num if "flow_nums" not in set_args: where_args["flow_nums"] = serialise_set(itask.flow_nums) - self.db_updates_map.setdefault(table_name, []) self.db_updates_map[table_name].append((set_args, where_args)) + def remove_task_from_flows( + self, point: str, name: str, flow_nums: 'FlowNums' + ) -> 'FlowNums': + """Remove flow numbers for a task in the task_states and task_outputs + tables. + + Args: + point: Cycle point of the task. + name: Name of the task. + flow_nums: Flow numbers to remove. If empty, remove all + flow numbers. + + Returns the flow numbers that were removed, if any. + + N.B. the task_prerequisites table is automatically updated separately + during the main loop. + """ + removed_flow_nums: FlowNums = set() + for table in ( + self.TABLE_TASK_STATES, + self.TABLE_TASK_OUTPUTS, + ): + fnums_select_stmt = rf''' + SELECT + flow_nums + FROM + {table} + WHERE + cycle = ? + AND name = ? + ''' # nosec B608 (table name is a code constant) + fnums_select_cursor = self.pri_dao.connect().execute( + fnums_select_stmt, (point, name) + ) + + if not flow_nums: + for db_fnums_str, *_ in fnums_select_cursor: + removed_flow_nums.update(deserialise_set(db_fnums_str)) + + stmt = rf''' + UPDATE OR REPLACE + {table} + SET + flow_nums = ? + WHERE + cycle = ? + AND name = ? 
+ ''' # nosec B608 (table name is a code constant) + params: List[tuple] = [(serialise_set(), point, name)] + else: + # Mapping of existing flow nums to what should be left after + # removing the specified flow nums: + flow_nums_map: Dict[str, FlowNums] = {} + for db_fnums_str, *_ in fnums_select_cursor: + db_fnums: FlowNums = deserialise_set(db_fnums_str) + fnums_to_remove = db_fnums.intersection(flow_nums) + if fnums_to_remove: + flow_nums_map[db_fnums_str] = db_fnums.difference( + flow_nums + ) + removed_flow_nums.update(fnums_to_remove) + + stmt = rf''' + UPDATE OR REPLACE + {table} + SET + flow_nums = ? + WHERE + cycle = ? + AND name = ? + AND flow_nums = ? + ''' # nosec B608 (table name is a code constant) + params = [ + (serialise_set(new), point, name, old) + for old, new in flow_nums_map.items() + ] + + self.db_updates_map[table].append( + (stmt, params) + ) + + return removed_flow_nums + def recover_pub_from_pri(self): """Recover public database from private database.""" if self.pub_dao.n_tries >= self.pub_dao.MAX_TRIES: @@ -690,7 +809,7 @@ def restart_check(self) -> None: self.process_queued_ops() @classmethod - def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> Version: + def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> 'Version': """Return the version of Cylc this DB was last run with. Args: @@ -706,7 +825,7 @@ def _get_last_run_version(cls, pri_dao: CylcWorkflowDAO) -> Version: {cls.TABLE_WORKFLOW_PARAMS} WHERE key == ? - ''', # nosec (table name is a code constant) + ''', # nosec B608 (table name is a code constant) [cls.KEY_CYLC_VERSION] ).fetchone()[0] except (TypeError, OperationalError) as exc: @@ -781,7 +900,7 @@ def upgrade(cls, db_file: Union['Path', str]) -> None: cls.upgrade_pre_810(pri_dao) @classmethod - def check_db_compatibility(cls, db_file: Union['Path', str]) -> Version: + def check_db_compatibility(cls, db_file: Union['Path', str]) -> 'Version': """Check this DB is compatible with this Cylc version. 
Raises: diff --git a/tests/conftest.py b/tests/conftest.py index d07f788ba0b..8e6b988d5c2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -99,31 +99,30 @@ def _inner(cached=False): @pytest.fixture -def log_filter(): - """Filter caplog record_tuples. +def log_filter(caplog: pytest.LogCaptureFixture): + """Filter caplog record_tuples (also discarding the log name entry). Args: - log: The caplog instance. - name: Filter out records if they don't match this logger name. level: Filter out records if they aren't at this logging level. contains: Filter out records if this string is not in the message. regex: Filter out records if the message doesn't match this regex. exact_match: Filter out records if the message does not exactly match this string. + log: A caplog instance. """ def _log_filter( - log: pytest.LogCaptureFixture, - name: Optional[str] = None, level: Optional[int] = None, contains: Optional[str] = None, regex: Optional[str] = None, exact_match: Optional[str] = None, - ) -> List[Tuple[str, int, str]]: + log: Optional[pytest.LogCaptureFixture] = None + ) -> List[Tuple[int, str]]: + if log is None: + log = caplog return [ - (log_name, log_level, log_message) - for log_name, log_level, log_message in log.record_tuples - if (name is None or name == log_name) - and (level is None or level == log_level) + (log_level, log_message) + for _, log_level, log_message in log.record_tuples + if (level is None or level == log_level) and (contains is None or contains in log_message) and (regex is None or re.search(regex, log_message)) and (exact_match is None or exact_match == log_message) diff --git a/tests/functional/cylc-remove/00-simple/flow.cylc b/tests/functional/cylc-remove/00-simple/flow.cylc index 84c740ad421..0ee53fafed6 100644 --- a/tests/functional/cylc-remove/00-simple/flow.cylc +++ b/tests/functional/cylc-remove/00-simple/flow.cylc @@ -1,13 +1,15 @@ # Abort on stall timeout unless we remove unhandled failed and waiting task. 
[scheduler] [[events]] - stall timeout = PT20S + stall timeout = PT30S abort on stall timeout = True expected task failures = 1/b [scheduling] [[graph]] - R1 = """a => b => c - cleaner""" + R1 = """ + a => b => c + cleaner + """ [runtime] [[a,c]] script = true @@ -15,10 +17,10 @@ script = false [[cleaner]] script = """ -cylc__job__poll_grep_workflow_log -E '1/b/01:running.* \(received\)failed' -# Remove the unhandled failed task -cylc remove "$CYLC_WORKFLOW_ID//1/b" -# Remove waiting 1/c -# (not auto-removed because parent 1/b, an unhandled fail, is not finished.) -cylc remove "$CYLC_WORKFLOW_ID//1/c:waiting" -""" + cylc__job__poll_grep_workflow_log -E '1/b/01:running.* \(received\)failed' + # Remove the unhandled failed task + cylc remove "$CYLC_WORKFLOW_ID//1/b" + # Remove waiting 1/c + # (not auto-removed because parent 1/b, an unhandled fail, is not finished.) + cylc remove "$CYLC_WORKFLOW_ID//1/c:waiting" + """ diff --git a/tests/functional/cylc-remove/02-cycling/flow.cylc b/tests/functional/cylc-remove/02-cycling/flow.cylc index 3b6c1051493..249f0676cc4 100644 --- a/tests/functional/cylc-remove/02-cycling/flow.cylc +++ b/tests/functional/cylc-remove/02-cycling/flow.cylc @@ -28,18 +28,6 @@ [[foo, waz]] script = true [[bar]] - script = """ - if [[ $CYLC_TASK_CYCLE_POINT == 2020 ]]; then - false - else - true - fi - """ + script = [[ $CYLC_TASK_CYCLE_POINT != 2020 ]] [[baz]] - script = """ - if [[ $CYLC_TASK_CYCLE_POINT == 2021 ]]; then - false - else - true - fi - """ + script = [[ $CYLC_TASK_CYCLE_POINT != 2021 ]] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index bce6ea64e9f..fe4b19ab92b 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -18,21 +18,34 @@ import asyncio from functools import partial from pathlib import Path -import pytest +import re from shutil import rmtree from time import time -from typing import List, TYPE_CHECKING, Set, Tuple, Union +from typing import ( + TYPE_CHECKING, + List, 
+ Set, + Tuple, + Union, +) + +import pytest from cylc.flow.config import WorkflowConfig from cylc.flow.id import Tokens +from cylc.flow.network.client import WorkflowRuntimeClient from cylc.flow.option_parsers import Options from cylc.flow.pathutil import get_cylc_run_dir from cylc.flow.rundb import CylcWorkflowDAO -from cylc.flow.scripts.validate import ValidateOptions from cylc.flow.scripts.install import ( + get_option_parser as install_gop, install as cylc_install, - get_option_parser as install_gop ) +from cylc.flow.scripts.show import ( + ShowOptions, + prereqs_and_outputs_query, +) +from cylc.flow.scripts.validate import ValidateOptions from cylc.flow.util import serialise_set from cylc.flow.wallclock import get_current_time_string from cylc.flow.workflow_files import infer_latest_run_from_id @@ -41,14 +54,14 @@ from .utils import _rm_if_empty from .utils.flow_tools import ( _make_flow, - _make_src_flow, _make_scheduler, + _make_src_flow, _run_flow, _start_flow, ) + if TYPE_CHECKING: - from cylc.flow.network.client import WorkflowRuntimeClient from cylc.flow.scheduler import Scheduler from cylc.flow.task_proxy import TaskProxy @@ -116,7 +129,11 @@ def ses_test_dir(request, run_dir): @pytest.fixture(scope='module') def mod_test_dir(request, ses_test_dir): """The root run dir for test flows in this test module.""" - path = Path(ses_test_dir, request.module.__name__) + path = Path( + ses_test_dir, + # Shorten path by dropping `integration.` prefix: + re.sub(r'^integration\.', '', request.module.__name__) + ) path.mkdir(exist_ok=True) yield path if _pytest_passed(request): @@ -510,6 +527,10 @@ def reflog(): Note, you'll need to call this on the scheduler *after* you have started it. + N.B. Trigger order is not stable; using a set ensures that tests check + trigger logic rather than binding to specific trigger order which could + change in the future, breaking the test. + Args: schd: The scheduler to capture triggering information for. 
@@ -588,6 +609,9 @@ async def _complete( async_timeout (handles shutdown logic more cleanly). """ + if schd.is_paused: + raise Exception("Cannot wait for completion of a paused scheduler") + start_time = time() tokens_list: List[Tokens] = [] @@ -622,11 +646,16 @@ def _set_stop(mode=None): # determine the completion condition def done(): if wait_tokens: - return not tokens_list + if not tokens_list: + return True + if not schd.contact_data: + raise AssertionError( + "Scheduler shut down before tasks completed: " + + ", ".join(map(str, tokens_list)) + ) + return False # otherwise wait for the scheduler to shut down - if not schd.contact_data: - return True - return stop_requested + return stop_requested or not schd.contact_data with pytest.MonkeyPatch.context() as mp: mp.setattr(schd.pool, 'remove_if_complete', _remove_if_complete) @@ -672,3 +701,23 @@ async def _reftest( return triggers return _reftest + + +@pytest.fixture +def cylc_show(): + """Fixture that runs `cylc show` on a scheduler, returning JSON object.""" + + async def _cylc_show(schd: 'Scheduler', *task_ids: str) -> dict: + pclient = WorkflowRuntimeClient(schd.workflow) + await schd.update_data_structure() + json_filter: dict = {} + await prereqs_and_outputs_query( + schd.id, + [Tokens(id_, relative=True) for id_ in task_ids], + pclient, + ShowOptions(json=True), + json_filter, + ) + return json_filter + + return _cylc_show diff --git a/tests/integration/events/test_task_events.py b/tests/integration/events/test_task_events.py index 3ae30c1fe73..81bbfd4316e 100644 --- a/tests/integration/events/test_task_events.py +++ b/tests/integration/events/test_task_events.py @@ -52,7 +52,7 @@ async def test_mail_footer_template( # start the workflow and get it to send an email ctx = SimpleNamespace(mail_to=None, mail_from=None) id_keys = [EventKey('none', 'failed', 'failed', Tokens('//1/a'))] - async with start(mod_one) as one_log: + async with start(mod_one): mod_one.task_events_mgr._process_event_email(mod_one, 
ctx, id_keys) # warnings should appear only when the template is invalid @@ -60,11 +60,9 @@ async def test_mail_footer_template( # check that template issues are handled correctly assert bool(log_filter( - one_log, contains='Ignoring bad mail footer template', )) == should_log assert bool(log_filter( - one_log, contains=template, )) == should_log diff --git a/tests/integration/events/test_workflow_events.py b/tests/integration/events/test_workflow_events.py index 6b742264636..4569cbaed8b 100644 --- a/tests/integration/events/test_workflow_events.py +++ b/tests/integration/events/test_workflow_events.py @@ -69,11 +69,9 @@ async def test_mail_footer_template( # check that template issues are handled correctly assert bool(log_filter( - one_log, contains='Ignoring bad mail footer template', )) == should_log assert bool(log_filter( - one_log, contains=template, )) == should_log @@ -114,10 +112,8 @@ async def test_custom_event_handler_template( # check that template issues are handled correctly assert bool(log_filter( - one_log, contains='bad template', )) == should_log assert bool(log_filter( - one_log, contains=template, )) == should_log diff --git a/tests/integration/main_loop/test_auto_restart.py b/tests/integration/main_loop/test_auto_restart.py index 20cc4ea81c6..e03c97485a5 100644 --- a/tests/integration/main_loop/test_auto_restart.py +++ b/tests/integration/main_loop/test_auto_restart.py @@ -45,6 +45,6 @@ async def test_no_detach( id_: str = flow(one_conf) schd: Scheduler = scheduler(id_, paused_start=True, no_detach=True) with pytest.raises(MainLoopPluginException) as exc: - async with run(schd) as log: + async with run(schd): await asyncio.sleep(2) - assert log_filter(log, contains=f"Workflow shutting down - {exc.value}") + assert log_filter(contains=f"Workflow shutting down - {exc.value}") diff --git a/tests/integration/scripts/test_set.py b/tests/integration/scripts/test_set.py index af368aab0d5..9c2a2da9d3a 100644 --- a/tests/integration/scripts/test_set.py 
+++ b/tests/integration/scripts/test_set.py @@ -149,9 +149,9 @@ async def test_incomplete_detection( ): """It should detect and log finished tasks left with incomplete outputs.""" schd = scheduler(flow(one_conf)) - async with start(schd) as log: + async with start(schd): schd.pool.set_prereqs_and_outputs(['1/one'], ['failed'], None, ['1']) - assert log_filter(log, contains='1/one did not complete') + assert log_filter(contains='1/one did not complete') async def test_pre_all(flow, scheduler, run): diff --git a/tests/integration/test_config.py b/tests/integration/test_config.py index c75797e9cbb..e22965afd6f 100644 --- a/tests/integration/test_config.py +++ b/tests/integration/test_config.py @@ -151,7 +151,6 @@ def test_validate_param_env_templ( one_conf, validate, env_val, - caplog, log_filter, ): """It should validate parameter environment templates.""" @@ -166,8 +165,8 @@ def test_validate_param_env_templ( } }) validate(id_) - assert log_filter(caplog, contains='bad parameter environment template') - assert log_filter(caplog, contains=env_val) + assert log_filter(contains='bad parameter environment template') + assert log_filter(contains=env_val) def test_no_graph(flow, validate): diff --git a/tests/integration/test_data_store_mgr.py b/tests/integration/test_data_store_mgr.py index 4f45e3c0b0e..93076ae7604 100644 --- a/tests/integration/test_data_store_mgr.py +++ b/tests/integration/test_data_store_mgr.py @@ -331,9 +331,8 @@ def test_delta_task_prerequisite(harness): for itask in schd.pool.get_tasks(): # set prereqs as not-satisfied for prereq in itask.state.prerequisites: - prereq._all_satisfied = False for key in prereq: - prereq._satisfied[key] = False + prereq[key] = False schd.data_store_mgr.delta_task_prerequisite(itask) assert not any(p.satisfied for p in get_pb_prereqs(schd)) diff --git a/tests/integration/test_examples.py b/tests/integration/test_examples.py index dc3495fe39f..a0d15ee7289 100644 --- a/tests/integration/test_examples.py +++ 
b/tests/integration/test_examples.py @@ -23,9 +23,11 @@ import asyncio import logging from pathlib import Path + import pytest from cylc.flow import __version__ +from cylc.flow.scheduler import Scheduler async def test_create_flow(flow, run_dir): @@ -62,9 +64,9 @@ async def test_logging(flow, scheduler, start, one_conf, log_filter): # Ensure that the cylc version is logged on startup. id_ = flow(one_conf) schd = scheduler(id_) - async with start(schd) as log: + async with start(schd): # this returns a list of log records containing __version__ - assert log_filter(log, contains=__version__) + assert log_filter(contains=__version__) async def test_scheduler_arguments(flow, scheduler, start, one_conf): @@ -159,16 +161,12 @@ def killer(): # make sure that this error causes the flow to shutdown with pytest.raises(MyException): - async with run(one) as log: + async with run(one): # The `run` fixture's shutdown logic waits for the main loop to run pass # make sure the exception was logged - assert len(log_filter( - log, - level=logging.CRITICAL, - contains='mess' - )) == 1 + assert len(log_filter(logging.CRITICAL, contains='mess')) == 1 # make sure the server socket has closed - a good indication of a # successful clean shutdown @@ -290,3 +288,11 @@ async def test_reftest(flow, scheduler, reftest): ('1/a', None), ('1/b', ('1/a',)), } + + +async def test_show(one: Scheduler, start, cylc_show): + """Demonstrate the `cylc_show` fixture""" + async with start(one): + out = await cylc_show(one, '1/one') + assert list(out.keys()) == ['1/one'] + assert out['1/one']['state'] == 'waiting' diff --git a/tests/integration/test_flow_assignment.py b/tests/integration/test_flow_assignment.py index 6c0c58a8758..ea729efeb7b 100644 --- a/tests/integration/test_flow_assignment.py +++ b/tests/integration/test_flow_assignment.py @@ -27,7 +27,7 @@ FLOW_ALL, FLOW_NEW, FLOW_NONE, - stringify_flow_nums + repr_flow_nums ) from cylc.flow.scheduler import Scheduler @@ -110,7 +110,7 @@ async def 
test_flow_assignment( } id_ = flow(conf) schd: Scheduler = scheduler(id_, run_mode='simulation', paused_start=True) - async with start(schd) as log: + async with start(schd): if command == 'set': do_command: Callable = functools.partial( schd.pool.set_prereqs_and_outputs, outputs=['x'], prereqs=[] @@ -137,10 +137,9 @@ async def test_flow_assignment( do_command([active_a.identity], flow=[FLOW_NONE]) assert active_a.flow_nums == {1, 2} assert log_filter( - log, contains=( f'[{active_a}] ignoring \'flow=none\' {command}: ' - f'task already has {stringify_flow_nums(active_a.flow_nums)}' + f'task already has {repr_flow_nums(active_a.flow_nums)}' ), level=logging.ERROR ) diff --git a/tests/integration/test_queues.py b/tests/integration/test_queues.py index fc94c4c4a3d..7da83e1a1aa 100644 --- a/tests/integration/test_queues.py +++ b/tests/integration/test_queues.py @@ -120,7 +120,7 @@ async def test_queue_held_tasks( # hold all tasks and resume the workflow # (nothing should have run yet because the workflow started paused) - await commands.run_cmd(commands.hold, schd, ['*/*']) + await commands.run_cmd(commands.hold(schd, ['*/*'])) schd.resume_workflow() # release queued tasks @@ -129,7 +129,7 @@ async def test_queue_held_tasks( assert len(submitted_tasks) == 0 # un-hold tasks - await commands.run_cmd(commands.release, schd, ['*/*']) + await commands.run_cmd(commands.release(schd, ['*/*'])) # release queued tasks # (tasks should now be released from the queues) diff --git a/tests/integration/test_reload.py b/tests/integration/test_reload.py index 65960ffcdb7..ad96b187722 100644 --- a/tests/integration/test_reload.py +++ b/tests/integration/test_reload.py @@ -89,7 +89,7 @@ def change_state(_=0): change_state() # reload the workflow - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) # the task should end in the submitted state assert foo.state(TASK_STATUS_SUBMITTED) @@ -127,18 +127,17 @@ async def 
test_reload_failure( """ id_ = flow(one_conf) schd = scheduler(id_) - async with start(schd) as log: + async with start(schd): # corrupt the config by removing the scheduling section two_conf = {**one_conf, 'scheduling': {}} flow(two_conf, id_=id_) # reload the workflow - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) # the reload should have failed but the workflow should still be # running assert log_filter( - log, contains=( 'Reload failed - WorkflowConfigError:' ' missing [scheduling][[graph]] section' diff --git a/tests/integration/test_remove.py b/tests/integration/test_remove.py new file mode 100644 index 00000000000..a2c8b044f8f --- /dev/null +++ b/tests/integration/test_remove.py @@ -0,0 +1,408 @@ +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . 
+ +import logging +from typing import ( + Dict, + List, + NamedTuple, + Set, +) + +import pytest + +from cylc.flow.commands import ( + force_trigger_tasks, + remove_tasks, + run_cmd, +) +from cylc.flow.cycling.integer import IntegerPoint +from cylc.flow.flow_mgr import FLOW_ALL +from cylc.flow.scheduler import Scheduler +from cylc.flow.task_outputs import TASK_OUTPUT_SUCCEEDED +from cylc.flow.task_proxy import TaskProxy + + +def get_pool_tasks(schd: Scheduler) -> Set[str]: + return {itask.identity for itask in schd.pool.get_tasks()} + + +class CylcShowPrereqs(NamedTuple): + prereqs: List[bool] + conditions: List[Dict[str, bool]] + + +@pytest.fixture +async def cylc_show_prereqs(cylc_show): + """Fixture that returns the prereq info from `cylc show` in an + easy-to-use format.""" + async def inner(schd: Scheduler, task: str): + prerequisites = (await cylc_show(schd, task))[task]['prerequisites'] + return [ + ( + p['satisfied'], + {c['taskId']: c['satisfied'] for c in p['conditions']}, + ) + for p in prerequisites + ] + + return inner + + +@pytest.fixture +def example_workflow(flow): + return flow({ + 'scheduling': { + 'graph': { + # Note: test both `&` and separate arrows for combining + # dependencies + 'R1': ''' + a1 & a2 => b + a3 => b + ''', + }, + }, + }) + + +def get_data_store_flow_nums(schd: Scheduler, itask: TaskProxy): + _, ds_tproxy = schd.data_store_mgr.store_node_fetcher(itask.tokens) + if ds_tproxy: + return ds_tproxy.flow_nums + + +async def test_basic( + example_workflow, scheduler, start, db_select +): + """Test removing a task from all flows.""" + schd: Scheduler = scheduler(example_workflow) + async with start(schd): + a1 = schd.pool._get_task_by_id('1/a1') + a3 = schd.pool._get_task_by_id('1/a3') + schd.pool.spawn_on_output(a1, TASK_OUTPUT_SUCCEEDED) + schd.pool.spawn_on_output(a3, TASK_OUTPUT_SUCCEEDED) + await schd.update_data_structure() + + assert a1 in schd.pool.get_tasks() + for table in ('task_states', 'task_outputs'): + assert 
db_select(schd, True, table, 'flow_nums', name='a1') == [ + ('[1]',), + ] + assert db_select( + schd, True, 'task_prerequisites', 'satisfied', prereq_name='a1' + ) == [ + ('satisfied naturally',), + ] + assert get_data_store_flow_nums(schd, a1) == '[1]' + + await run_cmd(remove_tasks(schd, ['1/a1'], [FLOW_ALL])) + await schd.update_data_structure() + + assert a1 not in schd.pool.get_tasks() # removed from pool + for table in ('task_states', 'task_outputs'): + assert db_select(schd, True, table, 'flow_nums', name='a1') == [ + ('[]',), # removed from all flows + ] + assert db_select( + schd, True, 'task_prerequisites', 'satisfied', prereq_name='a1' + ) == [ + ('0',), # prereq is now unsatisfied + ] + assert get_data_store_flow_nums(schd, a1) == '[]' + + +async def test_specific_flow( + example_workflow, scheduler, start, db_select +): + """Test removing a task from a specific flow.""" + schd: Scheduler = scheduler(example_workflow) + + def select_prereqs(): + return db_select( + schd, + True, + 'task_prerequisites', + 'flow_nums', + 'satisfied', + prereq_name='a1', + ) + + async with start(schd): + a1 = schd.pool._get_task_by_id('1/a1') + schd.pool.force_trigger_tasks(['1/a1'], ['1', '2']) + schd.pool.spawn_on_output(a1, TASK_OUTPUT_SUCCEEDED) + await schd.update_data_structure() + + assert a1 in schd.pool.get_tasks() + assert a1.flow_nums == {1, 2} + for table in ('task_states', 'task_outputs'): + assert sorted( + db_select(schd, True, table, 'flow_nums', name='a1') + ) == [ + ('[1, 2]',), # triggered task + ('[1]',), # original spawned task + ] + assert select_prereqs() == [ + ('[1, 2]', 'satisfied naturally'), + ] + assert get_data_store_flow_nums(schd, a1) == '[1, 2]' + + await run_cmd(remove_tasks(schd, ['1/a1'], ['1'])) + await schd.update_data_structure() + + assert a1 in schd.pool.get_tasks() # still in pool + assert a1.flow_nums == {2} + for table in ('task_states', 'task_outputs'): + assert sorted( + db_select(schd, True, table, 'flow_nums', name='a1') + ) 
== [ + ('[2]',), + ('[]',), + ] + assert select_prereqs() == [ + ('[1, 2]', '0'), + ] + assert get_data_store_flow_nums(schd, a1) == '[2]' + + +async def test_unset_prereq(example_workflow, scheduler, start): + """Test removing a task unsets any prerequisites it satisfied.""" + schd: Scheduler = scheduler(example_workflow) + async with start(schd): + for task in ('a1', 'a2', 'a3'): + schd.pool.spawn_on_output( + schd.pool.get_task(IntegerPoint('1'), task), + TASK_OUTPUT_SUCCEEDED, + ) + b = schd.pool.get_task(IntegerPoint('1'), 'b') + assert b.prereqs_are_satisfied() + + await run_cmd(remove_tasks(schd, ['1/a1'], [FLOW_ALL])) + + assert not b.prereqs_are_satisfied() + + +async def test_not_unset_prereq( + example_workflow, scheduler, start, db_select +): + """Test removing a task does not unset a force-satisfied prerequisite + (one that was satisfied by `cylc set --pre`).""" + schd: Scheduler = scheduler(example_workflow) + async with start(schd): + # This set prereq should not be unset by removing a1: + schd.pool.set_prereqs_and_outputs( + ['1/b'], outputs=[], prereqs=['1/a1'], flow=[FLOW_ALL] + ) + # Whereas the prereq satisfied by this set output *should* be unset + # by removing a2: + schd.pool.set_prereqs_and_outputs( + ['1/a2'], outputs=['succeeded'], prereqs=[], flow=[FLOW_ALL] + ) + await schd.update_data_structure() + + assert sorted( + db_select( + schd, True, 'task_prerequisites', 'prereq_name', 'satisfied' + ) + ) == [ + ('a1', 'force satisfied'), + ('a2', 'satisfied naturally'), + ('a3', '0'), + ] + + await run_cmd(remove_tasks(schd, ['1/a1', '1/a2'], [FLOW_ALL])) + await schd.update_data_structure() + + assert sorted( + db_select( + schd, True, 'task_prerequisites', 'prereq_name', 'satisfied' + ) + ) == [ + ('a1', 'force satisfied'), + ('a2', '0'), + ('a3', '0'), + ] + + +async def test_logging(flow, scheduler, start, log_filter): + """Test logging of a mixture of valid and invalid task removals.""" + schd: Scheduler = scheduler( + flow({ + 
'scheduler': { + 'cycle point format': 'CCYY', + }, + 'scheduling': { + 'initial cycle point': '2000', + 'graph': { + 'R3//P1Y': 'b[-P1Y] => a & b', + }, + }, + }) + ) + tasks_to_remove = [ + # Active, removable tasks: + '2000/*', + # Future, non-removable tasks: + '2001/a', '2001/b', + # Glob that doesn't match any active tasks: + '2002/*', + # Invalid tasks: + '2005/a', '2000/doh', + ] + async with start(schd): + await run_cmd(remove_tasks(schd, tasks_to_remove, [FLOW_ALL])) + + assert log_filter( + logging.INFO, "Removed task(s): 2000/a (flows=1), 2000/b (flows=1)" + ) + + assert log_filter(logging.WARNING, "Task(s) not removable: 2001/a, 2001/b") + assert log_filter(logging.WARNING, "No active tasks matching: 2002/*") + assert log_filter(logging.WARNING, "Invalid cycle point for task: a, 2005") + assert log_filter(logging.WARNING, "No matching tasks found: doh") + + +async def test_logging_flow_nums( + example_workflow, scheduler, start, log_filter +): + """Test logging of task removals involving flow numbers.""" + schd: Scheduler = scheduler(example_workflow) + async with start(schd): + schd.pool.force_trigger_tasks(['1/a1'], ['1', '2']) + # Removing from flow that doesn't exist doesn't work: + await run_cmd(remove_tasks(schd, ['1/a1'], ['3'])) + assert log_filter( + logging.WARNING, "Task(s) not removable: 1/a1 (flows=3)" + ) + + # But if a valid flow is included, it will be removed from that flow: + await run_cmd(remove_tasks(schd, ['1/a1'], ['2', '3'])) + assert log_filter(logging.INFO, "Removed task(s): 1/a1 (flows=2)") + assert schd.pool._get_task_by_id('1/a1').flow_nums == {1} + + +async def test_retrigger(flow, scheduler, run, reflog, complete): + """Test prereqs & re-run behaviour when removing tasks.""" + schd: Scheduler = scheduler( + flow('a => b => c'), + paused_start=False, + ) + async with run(schd): + reflog_triggers: set = reflog(schd) + await complete(schd, '1/b') + + await run_cmd(remove_tasks(schd, ['1/a', '1/b'], [FLOW_ALL])) + 
schd.process_workflow_db_queue() + # Removing 1/b should un-queue 1/c: + assert len(schd.pool.task_queue_mgr.queues['default'].deque) == 0 + + assert reflog_triggers == { + ('1/a', None), + ('1/b', ('1/a',)), + } + reflog_triggers.clear() + + await run_cmd(force_trigger_tasks(schd, ['1/a'], [])) + await complete(schd) + + assert reflog_triggers == { + ('1/a', None), + # 1/b should have run again after 1/a on the re-trigger in flow 1: + ('1/b', ('1/a',)), + ('1/c', ('1/b',)), + } + + +async def test_prereqs( + flow, scheduler, run, complete, cylc_show_prereqs, log_filter +): + """Test prereqs & stall behaviour when removing tasks.""" + schd: Scheduler = scheduler( + flow('(a1 | a2) & b => x'), + paused_start=False, + ) + async with run(schd): + await complete(schd, '1/a1', '1/a2', '1/b') + + await run_cmd(remove_tasks(schd, ['1/a1'], [FLOW_ALL])) + assert not schd.pool.is_stalled() + assert len(schd.pool.task_queue_mgr.queues['default'].deque) + # `cylc show` should reflect the now-unsatisfied condition: + assert await cylc_show_prereqs(schd, '1/x') == [ + (True, {'1/a1': False, '1/a2': True, '1/b': True}) + ] + + await run_cmd(remove_tasks(schd, ['1/b'], [FLOW_ALL])) + # Should cause stall now because 1/c prereq is unsatisfied: + assert len(schd.pool.task_queue_mgr.queues['default'].deque) == 0 + assert schd.pool.is_stalled() + assert log_filter( + logging.WARNING, + "1/x is waiting on ['1/a1:succeeded', '1/b:succeeded']", + ) + assert await cylc_show_prereqs(schd, '1/x') == [ + (False, {'1/a1': False, '1/a2': True, '1/b': False}) + ] + + assert schd.pool._get_task_by_id('1/x') + await run_cmd(remove_tasks(schd, ['1/a2'], [FLOW_ALL])) + # Should cause 1/x to be removed from the pool as it no longer has + # any satisfied prerequisite tasks: + assert not schd.pool._get_task_by_id('1/x') + + +async def test_downstream_preparing(flow, scheduler, start): + """Downstream dependents should not be removed if they are already + preparing.""" + schd: Scheduler = scheduler( + 
flow(''' + a => x + a => y + '''), + ) + async with start(schd): + a = schd.pool._get_task_by_id('1/a') + schd.pool.spawn_on_output(a, TASK_OUTPUT_SUCCEEDED) + assert get_pool_tasks(schd) == {'1/a', '1/x', '1/y'} + + schd.pool._get_task_by_id('1/y').state_reset('preparing') + await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL])) + assert get_pool_tasks(schd) == {'1/y'} + + +async def test_suicide(flow, scheduler, run, reflog, complete): + """Test that suicide prereqs are unset by `cylc remove`.""" + schd: Scheduler = scheduler( + flow(''' + a => b => c => d => x + a & c => !x + '''), + paused_start=False, + ) + async with run(schd): + reflog_triggers: set = reflog(schd) + await complete(schd, '1/b') + await run_cmd(remove_tasks(schd, ['1/a'], [FLOW_ALL])) + await complete(schd) + + assert reflog_triggers == { + ('1/a', None), + ('1/b', ('1/a',)), + ('1/c', ('1/b',)), + ('1/d', ('1/c',)), + # 1/x not suicided as 1/a was removed: + ('1/x', ('1/d',)), + } diff --git a/tests/integration/test_resolvers.py b/tests/integration/test_resolvers.py index 981237a4d2a..4fa21dadbb9 100644 --- a/tests/integration/test_resolvers.py +++ b/tests/integration/test_resolvers.py @@ -234,7 +234,7 @@ async def test_command_logging(mock_flow, caplog, log_filter): {'mode': StopMode.REQUEST_CLEAN.value}, meta, ) - assert log_filter(caplog, contains='Command "stop" received') + assert log_filter(contains='Command "stop" received') # put_messages: only log for owner kwargs = { @@ -244,12 +244,11 @@ async def test_command_logging(mock_flow, caplog, log_filter): } meta["auth_user"] = mock_flow.owner await mock_flow.resolvers._mutation_mapper("put_messages", kwargs, meta) - assert not log_filter(caplog, contains='Command "put_messages" received:') + assert not log_filter(contains='Command "put_messages" received:') meta["auth_user"] = "Dr Spock" await mock_flow.resolvers._mutation_mapper("put_messages", kwargs, meta) - assert log_filter( - caplog, contains='Command "put_messages" received from 
Dr Spock') + assert log_filter(contains='Command "put_messages" received from Dr Spock') async def test_command_validation_failure( diff --git a/tests/integration/test_scheduler.py b/tests/integration/test_scheduler.py index 79b7327206a..6f1f581e899 100644 --- a/tests/integration/test_scheduler.py +++ b/tests/integration/test_scheduler.py @@ -174,7 +174,7 @@ async def test_holding_tasks_whilst_scheduler_paused( assert submitted_tasks == set() # hold all tasks & resume the workflow - await commands.run_cmd(commands.hold, one, ['*/*']) + await commands.run_cmd(commands.hold(one, ['*/*'])) one.resume_workflow() # release queued tasks @@ -183,7 +183,7 @@ async def test_holding_tasks_whilst_scheduler_paused( assert submitted_tasks == set() # release all tasks - await commands.run_cmd(commands.release, one, ['*/*']) + await commands.run_cmd(commands.release(one, ['*/*'])) # release queued tasks # (the task should be submitted) @@ -219,12 +219,12 @@ async def test_no_poll_waiting_tasks( polled_tasks = capture_polling(one) # Waiting tasks should not be polled. - await commands.run_cmd(commands.poll_tasks, one, ['*/*']) + await commands.run_cmd(commands.poll_tasks(one, ['*/*'])) assert polled_tasks == set() # Even if they have a submit number. 
task.submit_num = 1 - await commands.run_cmd(commands.poll_tasks, one, ['*/*']) + await commands.run_cmd(commands.poll_tasks(one, ['*/*'])) assert len(polled_tasks) == 0 # But these states should be: @@ -235,7 +235,7 @@ async def test_no_poll_waiting_tasks( TASK_STATUS_RUNNING ]: task.state.status = state - await commands.run_cmd(commands.poll_tasks, one, ['*/*']) + await commands.run_cmd(commands.poll_tasks(one, ['*/*'])) assert len(polled_tasks) == 1 polled_tasks.clear() @@ -267,7 +267,7 @@ def raise_ParsecError(*a, **k): pass assert log_filter( - log, level=logging.CRITICAL, + logging.CRITICAL, exact_match="Workflow shutting down - Mock error" ) assert TRACEBACK_MSG in log.text @@ -295,7 +295,7 @@ def mock_auto_restart(*a, **k): async with run(one) as log: pass - assert log_filter(log, level=logging.ERROR, contains=err_msg) + assert log_filter(logging.ERROR, err_msg) assert TRACEBACK_MSG in log.text @@ -361,20 +361,19 @@ async def test_restart_timeout( # restart the completed workflow schd = scheduler(id_) - async with run(schd) as log: + async with run(schd): # it should detect that the workflow has completed and alert the user assert log_filter( - log, contains='This workflow already ran to completion.' 
) # it should activate a timeout - assert log_filter(log, contains='restart timer starts NOW') + assert log_filter(contains='restart timer starts NOW') # when we trigger tasks the timeout should be cleared schd.pool.force_trigger_tasks(['1/one'], {1}) await asyncio.sleep(0) # yield control to the main loop - assert log_filter(log, contains='restart timer stopped') + assert log_filter(contains='restart timer stopped') @pytest.mark.parametrize("signal", ((SIGHUP), (SIGINT), (SIGTERM))) @@ -387,14 +386,14 @@ async def test_signal_escallation(one, start, signal, log_filter): See https://github.com/cylc/cylc-flow/pull/6444 """ - async with start(one) as log: + async with start(one): # put the workflow in the stopping state one._set_stop(StopMode.REQUEST_CLEAN) assert one.stop_mode.name == 'REQUEST_CLEAN' # one signal should escalate this from CLEAN to NOW one._handle_signal(signal, None) - assert log_filter(log, contains=signal.name) + assert log_filter(contains=signal.name) assert one.stop_mode.name == 'REQUEST_NOW' # two signals should escalate this from NOW to NOW_NOW diff --git a/tests/integration/test_simulation.py b/tests/integration/test_simulation.py index 49bc76ce5e2..b50acbb084d 100644 --- a/tests/integration/test_simulation.py +++ b/tests/integration/test_simulation.py @@ -344,7 +344,7 @@ async def test_settings_reload( conf_file.read_text().replace('False', 'True')) # Reload Workflow: - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) # Submit second psuedo-job and "run" to success: itask = run_simjob(schd, one_1066.point, 'one') diff --git a/tests/integration/test_stop_after_cycle_point.py b/tests/integration/test_stop_after_cycle_point.py index f92e8d449f0..90bab288515 100644 --- a/tests/integration/test_stop_after_cycle_point.py +++ b/tests/integration/test_stop_after_cycle_point.py @@ -119,10 +119,11 @@ def get_db_value(schd) -> Optional[str]: # override this value whilst the workflow is 
running await commands.run_cmd( - commands.stop, - schd, - cycle_point=IntegerPoint('4'), - mode=StopMode.REQUEST_CLEAN, + commands.stop( + schd, + cycle_point=IntegerPoint('4'), + mode=StopMode.REQUEST_CLEAN, + ) ) assert schd.config.stop_point == IntegerPoint('4') diff --git a/tests/integration/test_subprocctx.py b/tests/integration/test_subprocctx.py index c4d0c4ca28a..871106dcd53 100644 --- a/tests/integration/test_subprocctx.py +++ b/tests/integration/test_subprocctx.py @@ -47,7 +47,7 @@ def myxtrigger(): return True, {} """)) schd = scheduler(id_) - async with start(schd, level=DEBUG) as log: + async with start(schd, level=DEBUG): # Set off check for x-trigger: task = schd.pool.get_tasks()[0] schd.xtrigger_mgr.call_xtriggers_async(task) @@ -59,4 +59,4 @@ def myxtrigger(): # Assert that both stderr and out from the print statement # in our xtrigger appear in the log. for expected in ['Hello World', 'Hello Hades']: - assert log_filter(log, contains=expected, level=DEBUG) + assert log_filter(DEBUG, expected) diff --git a/tests/integration/test_task_job_mgr.py b/tests/integration/test_task_job_mgr.py index 48a49eb30aa..b1cf1347071 100644 --- a/tests/integration/test_task_job_mgr.py +++ b/tests/integration/test_task_job_mgr.py @@ -23,7 +23,6 @@ from cylc.flow.task_state import TASK_STATUS_RUNNING - async def test_run_job_cmd_no_hosts_error( flow, scheduler, @@ -92,7 +91,6 @@ async def test_run_job_cmd_no_hosts_error( # ...but the failure should be logged assert log_filter( - log, contains='No available hosts for no-host-platform', ) log.clear() @@ -105,7 +103,6 @@ async def test_run_job_cmd_no_hosts_error( # ...but the failure should be logged assert log_filter( - log, contains='No available hosts for no-host-platform', ) @@ -217,7 +214,7 @@ async def test_broadcast_platform_change( schd = scheduler(id_, run_mode='live') - async with start(schd) as log: + async with start(schd): # Change the task platform with broadcast: schd.broadcast_mgr.put_broadcast( ['1'], 
['mytask'], [{'platform': 'foo'}]) @@ -235,4 +232,4 @@ async def test_broadcast_platform_change( # Check that task platform hasn't become "localhost": assert schd.pool.get_tasks()[0].platform['name'] == 'foo' # ... and that remote init failed because all hosts bad: - assert log_filter(log, contains="(no hosts were reachable)") + assert log_filter(regex=r"platform: foo .*\(no hosts were reachable\)") diff --git a/tests/integration/test_task_pool.py b/tests/integration/test_task_pool.py index f2a0e650dcd..a524e0b5a63 100644 --- a/tests/integration/test_task_pool.py +++ b/tests/integration/test_task_pool.py @@ -63,9 +63,6 @@ # immediately too, because we spawn autospawn absolute-triggered tasks as # well as parentless tasks. 3/asd does not spawn at start, however. EXAMPLE_FLOW_CFG = { - 'scheduler': { - 'allow implicit tasks': True - }, 'scheduling': { 'cycling mode': 'integer', 'initial cycle point': 1, @@ -86,7 +83,6 @@ EXAMPLE_FLOW_2_CFG = { 'scheduler': { - 'allow implicit tasks': True, 'UTC mode': True }, 'scheduling': { @@ -142,7 +138,7 @@ def assert_expected_log( @pytest.fixture(scope='module') async def mod_example_flow( mod_flow: Callable, mod_scheduler: Callable, mod_run: Callable -) -> 'Scheduler': +) -> AsyncGenerator['Scheduler', None]: """Return a scheduler for interrogating its task pool. This is module-scoped so faster than example_flow, but should only be used @@ -178,7 +174,7 @@ async def example_flow( @pytest.fixture(scope='module') async def mod_example_flow_2( mod_flow: Callable, mod_scheduler: Callable, mod_run: Callable -) -> 'Scheduler': +) -> AsyncGenerator['Scheduler', None]: """Return a scheduler for interrogating its task pool. 
This is module-scoped so faster than example_flow, but should only be used @@ -570,7 +566,7 @@ async def test_reload_stopcp( schd: 'Scheduler' = scheduler(flow(cfg)) async with start(schd): assert str(schd.pool.stop_point) == '2020' - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) assert str(schd.pool.stop_point) == '2020' @@ -841,7 +837,7 @@ async def test_reload_prereqs( flow(conf, id_=id_) # Reload the workflow config - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) assert list_tasks(schd) == expected_3 # Check resulting dependencies of task z @@ -973,7 +969,7 @@ async def test_graph_change_prereq_satisfaction( flow(conf, id_=id_) # Reload the workflow config - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) await test.asend(schd) @@ -1183,9 +1179,6 @@ async def test_detect_incomplete_tasks( TASK_STATUS_SUBMIT_FAILED: TaskEventsManager.EVENT_SUBMIT_FAILED } id_ = flow({ - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'graph': { # a workflow with one task for each of the final task states @@ -1206,7 +1199,6 @@ async def test_detect_incomplete_tasks( # ensure that it is correctly identified as incomplete assert not itask.state.outputs.is_complete() assert log_filter( - log, contains=( f"[{itask}] did not complete the required outputs:" ), @@ -1228,9 +1220,6 @@ async def test_future_trigger_final_point( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'cycling mode': 'integer', 'initial cycle point': 1, @@ -1246,7 +1235,6 @@ async def test_future_trigger_final_point( for itask in schd.pool.get_tasks(): schd.pool.spawn_on_output(itask, "succeeded") assert log_filter( - log, regex=( ".*1/baz.*not spawned: a prerequisite is beyond" r" the workflow stop point \(1\)" @@ -1271,17 +1259,17 @@ async def 
test_set_failed_complete( schd.pool.task_events_mgr.process_message(one, 1, "failed") assert log_filter( - log, regex="1/one.* setting implied output: submitted") + regex="1/one.* setting implied output: submitted") assert log_filter( - log, regex="1/one.* setting implied output: started") + regex="1/one.* setting implied output: started") assert log_filter( - log, regex="failed.* did not complete the required outputs") + regex="failed.* did not complete the required outputs") # Set failed task complete via default "set" args. schd.pool.set_prereqs_and_outputs([one.identity], None, None, ['all']) assert log_filter( - log, contains=f'[{one}] removed from active task pool: completed') + contains=f'[{one}] removed from active task pool: completed') db_outputs = db_select( schd, True, 'task_outputs', 'outputs', @@ -1305,9 +1293,6 @@ async def test_set_prereqs( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'initial cycle point': '2040', 'graph': { @@ -1339,7 +1324,8 @@ async def test_set_prereqs( schd.pool.set_prereqs_and_outputs( ["20400101T0000Z/qux"], None, ["20400101T0000Z/foo:a"], ['all']) assert log_filter( - log, contains='20400101T0000Z/qux does not depend on "20400101T0000Z/foo:a"') + contains='20400101T0000Z/qux does not depend on "20400101T0000Z/foo:a"' + ) # it should not add 20400101T0000Z/qux to the pool assert ( @@ -1390,7 +1376,6 @@ async def test_set_bad_prereqs( """ id_ = flow({ 'scheduler': { - 'allow implicit tasks': 'True', 'cycle point format': '%Y'}, 'scheduling': { 'initial cycle point': '2040', @@ -1406,11 +1391,11 @@ def set_prereqs(prereqs): async with start(schd) as log: # Invalid: task name wildcard: set_prereqs(["2040/*"]) - assert log_filter(log, contains='Invalid prerequisite task name') + assert log_filter(contains='Invalid prerequisite task name') # Invalid: cycle point wildcard. 
set_prereqs(["*/foo"]) - assert log_filter(log, contains='Invalid prerequisite cycle point') + assert log_filter(contains='Invalid prerequisite cycle point') async def test_set_outputs_live( @@ -1424,9 +1409,6 @@ async def test_set_outputs_live( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'graph': { 'R1': """ @@ -1476,15 +1458,12 @@ async def test_set_outputs_live( ) # it should complete implied outputs (submitted, started) too - assert log_filter( - log, contains="setting implied output: submitted") - assert log_filter( - log, contains="setting implied output: started") + assert log_filter(contains="setting implied output: submitted") + assert log_filter(contains="setting implied output: started") # set foo (default: all required outputs) to complete y. schd.pool.set_prereqs_and_outputs(["1/foo"], None, None, ['all']) - assert log_filter( - log, contains="output 1/foo:succeeded completed") + assert log_filter(contains="output 1/foo:succeeded completed") assert ( pool_get_task_ids(schd.pool) == ["1/bar", "1/baz"] ) @@ -1501,7 +1480,6 @@ async def test_set_outputs_live2( """ id_ = flow( { - 'scheduler': {'allow implicit tasks': 'True'}, 'scheduling': {'graph': { 'R1': """ foo:a => apple @@ -1517,7 +1495,6 @@ async def test_set_outputs_live2( async with start(schd) as log: schd.pool.set_prereqs_and_outputs(["1/foo"], None, None, ['all']) assert not log_filter( - log, contains="did not complete required outputs: ['a', 'b']" ) @@ -1533,9 +1510,6 @@ async def test_set_outputs_future( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'graph': { 'R1': "a:x & a:y => b => c" @@ -1571,9 +1545,9 @@ async def test_set_outputs_future( prereqs=None, flow=['all'] ) - assert log_filter(log, contains="output 1/a:cheese not found") - assert log_filter(log, contains="completed output x") - assert log_filter(log, contains="completed output y") + assert log_filter(contains="output 1/a:cheese not 
found") + assert log_filter(contains="completed output x") + assert log_filter(contains="completed output y") async def test_prereq_satisfaction( @@ -1587,9 +1561,6 @@ async def test_prereq_satisfaction( """ id_ = flow( { - 'scheduler': { - 'allow implicit tasks': 'True', - }, 'scheduling': { 'graph': { 'R1': "a:x & a:y => b" @@ -1605,8 +1576,8 @@ async def test_prereq_satisfaction( } } ) - schd = scheduler(id_) - async with start(schd) as log: + schd: Scheduler = scheduler(id_) + async with start(schd): # it should start up with just 1/a assert pool_get_task_ids(schd.pool) == ["1/a"] # spawn b @@ -1617,21 +1588,19 @@ async def test_prereq_satisfaction( b = schd.pool.get_task(IntegerPoint("1"), "b") - assert not b.is_waiting_prereqs_done() + assert not b.prereqs_are_satisfied() # set valid and invalid prerequisites, by label and message. schd.pool.set_prereqs_and_outputs( prereqs=["1/a:xylophone", "1/a:y", "1/a:w", "1/a:z"], items=["1/b"], outputs=None, flow=['all'] ) - assert log_filter(log, contains="1/a:z not found") - assert log_filter(log, contains="1/a:w not found") - assert not log_filter(log, contains='1/b does not depend on "1/a:x"') - assert not log_filter( - log, contains='1/b does not depend on "1/a:xylophone"') - assert not log_filter(log, contains='1/b does not depend on "1/a:y"') + assert log_filter(contains="1/a:z not found") + assert log_filter(contains="1/a:w not found") + # FIXME: testing that something is *not* logged is extremely fragile: + assert not log_filter(regex='.*does not depend on.*') - assert b.is_waiting_prereqs_done() + assert b.prereqs_are_satisfied() @pytest.mark.parametrize('compat_mode', ['compat-mode', 'normal-mode']) @@ -1910,7 +1879,6 @@ async def test_fast_respawn( async def test_remove_active_task( example_flow: 'Scheduler', - caplog: pytest.LogCaptureFixture, log_filter: Callable, ) -> None: """Test warning on removing an active task.""" @@ -1925,7 +1893,6 @@ async def test_remove_active_task( assert foo not in 
task_pool.get_tasks() assert log_filter( - caplog, regex=( "1/foo.*removed from active task pool:" " request - active job orphaned" @@ -1947,7 +1914,6 @@ async def test_remove_by_suicide( * Removing a task manually (cylc remove) should work the same. """ id_ = flow({ - 'scheduler': {'allow implicit tasks': 'True'}, 'scheduling': { 'graph': { 'R1': ''' @@ -1966,7 +1932,6 @@ async def test_remove_by_suicide( # mark 1/a as failed and ensure 1/b is removed by suicide trigger schd.pool.spawn_on_output(a, TASK_OUTPUT_FAILED) assert log_filter( - log, regex="1/b.*removed from active task pool: suicide trigger" ) assert pool_get_task_ids(schd.pool) == ["1/a"] @@ -1975,14 +1940,14 @@ async def test_remove_by_suicide( log.clear() schd.pool.force_trigger_tasks(['1/b'], ['1']) assert log_filter( - log, regex='1/b.*added to active task pool', ) # remove 1/b by request (cylc remove) - await commands.run_cmd(commands.remove_tasks, schd, ['1/b']) + await commands.run_cmd( + commands.remove_tasks(schd, ['1/b'], [FLOW_ALL]) + ) assert log_filter( - log, regex='1/b.*removed from active task pool: request', ) @@ -1990,55 +1955,10 @@ async def test_remove_by_suicide( log.clear() schd.pool.force_trigger_tasks(['1/b'], ['1']) assert log_filter( - log, regex='1/b.*added to active task pool', ) -async def test_remove_no_respawn(flow, scheduler, start, log_filter): - """Ensure that removed tasks stay removed. - - If a task is removed by suicide trigger or "cylc remove", then it should - not be automatically spawned at a later time. 
- """ - id_ = flow({ - 'scheduling': { - 'graph': { - 'R1': 'a & b => z', - }, - }, - }) - schd: 'Scheduler' = scheduler(id_) - async with start(schd, level=logging.DEBUG) as log: - a1 = schd.pool.get_task(IntegerPoint("1"), "a") - b1 = schd.pool.get_task(IntegerPoint("1"), "b") - assert a1, '1/a should have been spawned on startup' - assert b1, '1/b should have been spawned on startup' - - # mark one of the upstream tasks as succeeded, 1/z should spawn - schd.pool.spawn_on_output(a1, TASK_OUTPUT_SUCCEEDED) - schd.workflow_db_mgr.process_queued_ops() - z1 = schd.pool.get_task(IntegerPoint("1"), "z") - assert z1, '1/z should have been spawned after 1/a succeeded' - - # manually remove 1/z, it should be removed from the pool - await commands.run_cmd(commands.remove_tasks, schd, ['1/z']) - schd.workflow_db_mgr.process_queued_ops() - z1 = schd.pool.get_task(IntegerPoint("1"), "z") - assert z1 is None, '1/z should have been removed (by request)' - - # mark the other upstream task as succeeded, 1/z should not be - # respawned as a result - schd.pool.spawn_on_output(b1, TASK_OUTPUT_SUCCEEDED) - assert log_filter( - log, contains='Not respawning 1/z - task was removed' - ) - z1 = schd.pool.get_task(IntegerPoint("1"), "z") - assert ( - z1 is None - ), '1/z should have stayed removed (but has been added back into the pool' - - async def test_set_future_flow(flow, scheduler, start, log_filter): """Manually-set outputs for new flow num must be recorded in the DB. @@ -2170,7 +2090,7 @@ async def list_data_store(): # reload flow(config, id_=id_) - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) # check xtrigs post-reload assert list_xtrig_mgr() == { @@ -2182,7 +2102,7 @@ async def list_data_store(): 'c': 'wall_clock(trigger_time=946688400)', } - + async def test_trigger_unqueued(flow, scheduler, start): """Test triggering an unqueued active task. 
@@ -2247,7 +2167,7 @@ async def test_expire_dequeue_with_retries(flow, scheduler, start, expire_type): if expire_type == 'clock-expire': conf['scheduling']['special tasks'] = {'clock-expire': 'foo(PT0S)'} - method = lambda schd: schd.pool.clock_expire_tasks() + method = lambda schd: schd.pool.clock_expire_tasks() else: method = lambda schd: schd.pool.set_prereqs_and_outputs( ['2000/foo'], prereqs=[], outputs=['expired'], flow=['1'] diff --git a/tests/integration/test_workflow_db_mgr.py b/tests/integration/test_workflow_db_mgr.py index 774d8c21fac..b994f88377f 100644 --- a/tests/integration/test_workflow_db_mgr.py +++ b/tests/integration/test_workflow_db_mgr.py @@ -34,12 +34,12 @@ async def test(expected_restart_num: int, do_reload: bool = False): """(Re)start the workflow and check the restart number is as expected. """ schd: 'Scheduler' = scheduler(id_, paused_start=True) - async with start(schd) as log: + async with start(schd): if do_reload: - await commands.run_cmd(commands.reload_workflow, schd) + await commands.run_cmd(commands.reload_workflow(schd)) assert schd.workflow_db_mgr.n_restart == expected_restart_num assert log_filter( - log, contains=f"(re)start number={expected_restart_num + 1}" + contains=f"(re)start number={expected_restart_num + 1}" # (In the log, it's 1 higher than backend value) ) assert ('n_restart', f'{expected_restart_num}') in db_select( diff --git a/tests/integration/test_workflow_files.py b/tests/integration/test_workflow_files.py index 5b036ef2c11..77937ca66dd 100644 --- a/tests/integration/test_workflow_files.py +++ b/tests/integration/test_workflow_files.py @@ -151,7 +151,7 @@ def test_detect_old_contact_file_old_run(workflow, caplog, log_filter): # as a side effect the contact file should have been removed assert not workflow.contact_file.exists() - assert log_filter(caplog, contains='Removed contact file') + assert log_filter(contains='Removed contact file') def test_detect_old_contact_file_none(workflow): @@ -260,11 +260,9 @@ def 
_unlink(*args): # check the appropriate messages were logged assert bool(log_filter( - caplog, contains='Removed contact file', )) is remove_succeeded assert bool(log_filter( - caplog, contains=( f'Failed to remove contact file for {workflow.id_}:' '\nmocked-os-error' diff --git a/tests/integration/tui/conftest.py b/tests/integration/tui/conftest.py index 55f83c143f3..86d2267da1e 100644 --- a/tests/integration/tui/conftest.py +++ b/tests/integration/tui/conftest.py @@ -4,7 +4,7 @@ from pathlib import Path import re from time import sleep -from uuid import uuid1 +from secrets import token_hex import pytest from urwid.display import html_fragment @@ -211,7 +211,7 @@ def wait_until_loaded(self, *ids, retries=20): ) if exc: msg += f'\n{exc}' - self.compare_screenshot(f'fail-{uuid1()}', msg, 1) + self.compare_screenshot(f'fail-{token_hex(4)}', msg, 1) @pytest.fixture diff --git a/tests/integration/tui/test_mutations.py b/tests/integration/tui/test_mutations.py index e9a41466d70..b87622bca5f 100644 --- a/tests/integration/tui/test_mutations.py +++ b/tests/integration/tui/test_mutations.py @@ -59,7 +59,7 @@ async def test_online_mutation( id_ = flow(one_conf, name='one') schd = scheduler(id_) with rakiura(size='80,15') as rk: - async with start(schd) as schd_log: + async with start(schd): await schd.update_data_structure() assert schd.command_queue.empty() @@ -91,7 +91,7 @@ async def test_online_mutation( # the mutation should be in the scheduler's command_queue await asyncio.sleep(0) - assert log_filter(schd_log, contains="hold(tasks=['1/one'])") + assert log_filter(contains="hold(tasks=['1/one'])") # close the dialogue and re-run the hold mutation rk.user_input('q', 'q', 'enter') @@ -127,7 +127,7 @@ def standardise_cli_cmds(monkeypatch): """This remove the variable bit of the workflow ID from CLI commands. The workflow ID changes from run to run. 
In order to make screenshots - stable, this + stable, this """ from cylc.flow.tui.data import extract_context def _extract_context(selection): diff --git a/tests/integration/utils/flow_tools.py b/tests/integration/utils/flow_tools.py index 34b80a25882..7419fd4fe14 100644 --- a/tests/integration/utils/flow_tools.py +++ b/tests/integration/utils/flow_tools.py @@ -28,7 +28,7 @@ import logging import pytest from typing import Any, Optional, Union -from uuid import uuid1 +from secrets import token_hex from cylc.flow import CYLC_LOG from cylc.flow.workflow_files import WorkflowFiles @@ -41,7 +41,7 @@ def _make_src_flow(src_path, conf, filename=WorkflowFiles.FLOW_FILE): """Construct a workflow on the filesystem""" - flow_src_dir = (src_path / str(uuid1())) + flow_src_dir = (src_path / token_hex(4)) flow_src_dir.mkdir(parents=True, exist_ok=True) if isinstance(conf, dict): conf = flow_config_str(conf) @@ -53,7 +53,7 @@ def _make_src_flow(src_path, conf, filename=WorkflowFiles.FLOW_FILE): def _make_flow( cylc_run_dir: Union[Path, str], test_dir: Path, - conf: dict, + conf: Union[dict, str], name: Optional[str] = None, id_: Optional[str] = None, defaults: Optional[bool] = True, @@ -62,6 +62,8 @@ def _make_flow( """Construct a workflow on the filesystem. Args: + conf: Either a workflow config dictionary, or a graph string to be + used as the R1 graph in the workflow config. defaults: Set up a common defaults. 
* [scheduling]allow implicit tasks = true @@ -71,10 +73,18 @@ def _make_flow( flow_run_dir = (cylc_run_dir / id_) else: if name is None: - name = str(uuid1()) + name = token_hex(4) flow_run_dir = (test_dir / name) flow_run_dir.mkdir(parents=True, exist_ok=True) id_ = str(flow_run_dir.relative_to(cylc_run_dir)) + if isinstance(conf, str): + conf = { + 'scheduling': { + 'graph': { + 'R1': conf + } + } + } if defaults: # set the default simulation runtime to zero (can be overridden) ( diff --git a/tests/unit/cycling/test_iso8601.py b/tests/unit/cycling/test_iso8601.py index ae0eb957f47..b7a134f7882 100644 --- a/tests/unit/cycling/test_iso8601.py +++ b/tests/unit/cycling/test_iso8601.py @@ -601,12 +601,12 @@ def test_exclusion_zero_duration_warning(set_cycling_type, caplog, log_filter): set_cycling_type(ISO8601_CYCLING_TYPE, "+05") with pytest.raises(Exception): ISO8601Sequence('3000', '2999') - assert log_filter(caplog, contains='zero-duration') + assert log_filter(contains='zero-duration') # parsing a point in an exclusion should not caplog.clear() ISO8601Sequence('P1Y ! 
3000', '2999') - assert not log_filter(caplog, contains='zero-duration') + assert not log_filter(contains='zero-duration') def test_simple(set_cycling_type): diff --git a/tests/unit/post_install/test_log_vc_info.py b/tests/unit/post_install/test_log_vc_info.py index 59511db5461..67e204db747 100644 --- a/tests/unit/post_install/test_log_vc_info.py +++ b/tests/unit/post_install/test_log_vc_info.py @@ -279,13 +279,13 @@ def test_no_base_commit_git(tmp_path: Path): @require_svn def test_untracked_svn_subdir( - svn_source_repo: Tuple[str, str, str], caplog, log_filter + svn_source_repo: Tuple[str, str, str], log_filter ): repo_dir, *_ = svn_source_repo source_dir = Path(repo_dir, 'jar_jar_binks') source_dir.mkdir() assert get_vc_info(source_dir) is None - assert log_filter(caplog, level=logging.WARNING, contains="$ svn info") + assert log_filter(logging.WARNING, contains="$ svn info") def test_not_installed( @@ -306,7 +306,6 @@ def test_not_installed( caplog.set_level(logging.DEBUG) assert get_vc_info(tmp_path) is None assert log_filter( - caplog, - level=logging.DEBUG, + logging.DEBUG, contains=f"{fake_vcs} does not appear to be installed", ) diff --git a/tests/unit/test_clean.py b/tests/unit/test_clean.py index 285bfa6a23f..2308fe318f0 100644 --- a/tests/unit/test_clean.py +++ b/tests/unit/test_clean.py @@ -945,7 +945,7 @@ def mocked_remote_clean_cmd_side_effect(id_, platform, timeout, rm_dirs): id_, platform_names, timeout='irrelevant', rm_dirs=rm_dirs ) for msg in expected_err_msgs: - assert log_filter(caplog, level=logging.ERROR, contains=msg) + assert log_filter(logging.ERROR, msg) if expected_platforms: for p_name in expected_platforms: mocked_remote_clean_cmd.assert_any_call( diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py index 4a820fdb126..18c5f37bd8b 100644 --- a/tests/unit/test_config.py +++ b/tests/unit/test_config.py @@ -1400,7 +1400,7 @@ def test_implicit_tasks( @pytest.mark.parametrize('workflow_meta', [True, False]) 
@pytest.mark.parametrize('url_type', ['good', 'bad', 'ugly', 'broken']) -def test_process_urls(caplog, log_filter, workflow_meta, url_type): +def test_process_urls(log_filter, workflow_meta, url_type): if url_type == 'good': # valid cylc 8 syntax @@ -1436,7 +1436,6 @@ def test_process_urls(caplog, log_filter, workflow_meta, url_type): elif url_type == 'ugly': WorkflowConfig.process_metadata_urls(config) assert log_filter( - caplog, contains='Detected deprecated template variables', ) @@ -1464,7 +1463,6 @@ def test_zero_interval( should_warn: bool, opts: Values, tmp_flow_config: Callable, - caplog: pytest.LogCaptureFixture, log_filter: Callable, ): """Test that a zero-duration recurrence with >1 repetition gets an @@ -1482,7 +1480,6 @@ def test_zero_interval( """) WorkflowConfig(id_, flow_file, options=opts) logged = log_filter( - caplog, level=logging.WARNING, contains="Cannot have more than 1 repetition for zero-duration" ) diff --git a/tests/unit/test_id.py b/tests/unit/test_id.py index 2d50c9a2706..e3011ee1d57 100644 --- a/tests/unit/test_id.py +++ b/tests/unit/test_id.py @@ -17,12 +17,15 @@ import pytest +from cylc.flow.cycling.integer import IntegerPoint +from cylc.flow.cycling.iso8601 import ISO8601Point from cylc.flow.id import ( LEGACY_CYCLE_SLASH_TASK, LEGACY_TASK_DOT_CYCLE, RELATIVE_ID, - Tokens, UNIVERSAL_ID, + Tokens, + quick_relative_id, ) @@ -392,3 +395,14 @@ def test_task_property(): ) == ( Tokens('//c:cs/t:ts/j:js', relative=True) ) + + +@pytest.mark.parametrize('cycle, expected', [ + ('2000', '2000/foo'), + (2001, '2001/foo'), + (IntegerPoint('3'), '3/foo'), + # NOTE: ISO8601Points are not standardised by this function: + (ISO8601Point('2002'), '2002/foo'), +]) +def test_quick_relative_id(cycle, expected): + assert quick_relative_id(cycle, 'foo') == expected diff --git a/tests/unit/test_prerequisite.py b/tests/unit/test_prerequisite.py index 105e5e85401..828cd0a96bc 100644 --- a/tests/unit/test_prerequisite.py +++ b/tests/unit/test_prerequisite.py 
@@ -14,12 +14,17 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . +from functools import partial + import pytest from cylc.flow.cycling.integer import IntegerPoint from cylc.flow.cycling.loader import ISO8601_CYCLING_TYPE, get_point -from cylc.flow.prerequisite import Prerequisite -from cylc.flow.id import Tokens +from cylc.flow.id import Tokens, detokenise +from cylc.flow.prerequisite import Prerequisite, SatisfiedState + + +detok = partial(detokenise, selectors=True, relative=True) @pytest.fixture @@ -43,10 +48,10 @@ def test_satisfied(prereq: Prerequisite): ('2001', 'd', 'custom'): False, } # No cached satisfaction state yet: - assert prereq._all_satisfied is None + assert prereq._cached_satisfied is None # Calling self.is_satisfied() should cache the result: assert not prereq.is_satisfied() - assert prereq._all_satisfied is False + assert prereq._cached_satisfied is False # mark two prerequisites as satisfied prereq.satisfy_me([ @@ -63,7 +68,7 @@ def test_satisfied(prereq: Prerequisite): ('2001', 'd', 'custom'): False, } # Should have reset cached satisfaction state: - assert prereq._all_satisfied is None + assert prereq._cached_satisfied is None assert not prereq.is_satisfied() # mark all prereqs as satisfied @@ -78,7 +83,7 @@ def test_satisfied(prereq: Prerequisite): ('2001', 'd', 'custom'): 'force satisfied', } # Should have set cached satisfaction state as must be true now: - assert prereq._all_satisfied is True + assert prereq._cached_satisfied is True assert prereq.is_satisfied() @@ -138,14 +143,100 @@ def test_get_target_points(prereq): } -def test_get_resolved_dependencies(): +@pytest.fixture +def satisfied_states_prereq(): + """Fixture for testing the full range of possible satisfied states.""" prereq = Prerequisite(IntegerPoint('2')) prereq[('1', 'a', 'x')] = True prereq[('1', 'b', 'x')] = False prereq[('1', 'c', 'x')] = 'satisfied from database' prereq[('1', 'd', 'x')] = 'force satisfied' 
- assert prereq.get_resolved_dependencies() == [ + return prereq + + +def test_get_satisfied_dependencies(satisfied_states_prereq: Prerequisite): + assert satisfied_states_prereq.get_satisfied_dependencies() == [ '1/a', '1/c', '1/d', ] + + +def test_unset_naturally_satisfied_dependency( + satisfied_states_prereq: Prerequisite +): + satisfied_states_prereq[('1', 'a', 'y')] = True + satisfied_states_prereq[('1', 'a', 'z')] = 'force satisfied' + for id_, expected in [ + ('1/a', True), + ('1/b', False), + ('1/c', True), + ('1/d', False), + ]: + assert ( + satisfied_states_prereq.unset_naturally_satisfied_dependency(id_) + == expected + ) + assert satisfied_states_prereq._satisfied == { + ('1', 'a', 'x'): False, + ('1', 'a', 'y'): False, + ('1', 'a', 'z'): 'force satisfied', + ('1', 'b', 'x'): False, + ('1', 'c', 'x'): False, + ('1', 'd', 'x'): 'force satisfied', + } + + +def test_satisfy_me(): + prereq = Prerequisite(IntegerPoint('2')) + for task_name in ('a', 'b', 'c'): + prereq[('1', task_name, 'x')] = False + assert not prereq.is_satisfied() + assert prereq._cached_satisfied is False + + valid = prereq.satisfy_me( + [Tokens('//1/a:x'), Tokens('//1/d:x'), Tokens('//1/c:y')], + ) + assert {detok(tokens) for tokens in valid} == {'1/a:x'} + assert prereq._satisfied == { + ('1', 'a', 'x'): 'satisfied naturally', + ('1', 'b', 'x'): False, + ('1', 'c', 'x'): False, + } + # should have reset cached satisfaction state + assert prereq._cached_satisfied is None + + valid = prereq.satisfy_me( + [Tokens('//1/a:x'), Tokens('//1/b:x')], + forced=True, + ) + assert {detok(tokens) for tokens in valid} == {'1/a:x', '1/b:x'} + assert prereq._satisfied == { + # 1/a:x unaffected as already satisfied + ('1', 'a', 'x'): 'satisfied naturally', + ('1', 'b', 'x'): 'force satisfied', + ('1', 'c', 'x'): False, + } + + +@pytest.mark.parametrize('forced', [False, True]) +@pytest.mark.parametrize('existing, expected_when_forced', [ + (False, 'force satisfied'), + ('satisfied from database', 
'force satisfied'), + ('force satisfied', 'force satisfied'), + ('satisfied naturally', 'satisfied naturally'), +]) +def test_satisfy_me__override( + forced: bool, + existing: SatisfiedState, + expected_when_forced: SatisfiedState, +): + """Test that satisfying a prereq with a different state works as expected + with and without the `forced` arg.""" + prereq = Prerequisite(IntegerPoint('2')) + prereq[('1', 'a', 'x')] = existing + + prereq.satisfy_me([Tokens('//1/a:x')], forced) + assert prereq[('1', 'a', 'x')] == ( + expected_when_forced if forced else 'satisfied naturally' + ) diff --git a/tests/unit/test_rundb.py b/tests/unit/test_rundb.py index 06aba70699f..44db75fb2e5 100644 --- a/tests/unit/test_rundb.py +++ b/tests/unit/test_rundb.py @@ -112,7 +112,9 @@ def test_operational_error(tmp_path, caplog): # stage some stuff dao.add_delete_item(CylcWorkflowDAO.TABLE_TASK_JOBS) dao.add_insert_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub']) - dao.add_update_item(CylcWorkflowDAO.TABLE_TASK_JOBS, ['pub']) + dao.add_update_item( + CylcWorkflowDAO.TABLE_TASK_JOBS, ({'pub': None}, {}) + ) # connect the to DB dao.connect() diff --git a/tests/unit/test_scheduler.py b/tests/unit/test_scheduler.py index 36beb121d3b..ccd5f5dfed5 100644 --- a/tests/unit/test_scheduler.py +++ b/tests/unit/test_scheduler.py @@ -133,7 +133,7 @@ def _select_workflow_host(cached=False): ) caplog.set_level(logging.ERROR, CYLC_LOG) assert not Scheduler.workflow_auto_restart(schd, max_retries=2) - assert log_filter(caplog, contains='elephant') + assert log_filter(contains='elephant') def test_auto_restart_popen_error(monkeypatch, caplog, log_filter): @@ -166,4 +166,4 @@ def _popen(*args, **kwargs): ) caplog.set_level(logging.ERROR, CYLC_LOG) assert not Scheduler.workflow_auto_restart(schd, max_retries=2) - assert log_filter(caplog, contains='mystderr') + assert log_filter(contains='mystderr') diff --git a/tests/unit/test_task_pool.py b/tests/unit/test_task_pool.py index b32781895bc..d1ab1642442 100644 --- 
a/tests/unit/test_task_pool.py +++ b/tests/unit/test_task_pool.py @@ -21,6 +21,7 @@ import pytest from cylc.flow.flow_mgr import FlowNums +from cylc.flow.prerequisite import SatisfiedState from cylc.flow.task_pool import TaskPool @@ -55,3 +56,29 @@ def test_get_active_flow_nums( ) assert TaskPool._get_active_flow_nums(mock_task_pool) == expected + + +@pytest.mark.parametrize('output_msg, flow_nums, db_flow_nums, expected', [ + ('foo', set(), {1}, False), + ('foo', set(), set(), False), + ('foo', {1, 3}, {1}, 'satisfied from database'), + ('goo', {1, 3}, {1, 2}, 'satisfied from database'), + ('foo', {1, 3}, set(), False), + ('foo', {2}, {1}, False), + ('foo', {2}, {1, 2}, 'satisfied from database'), + ('f', {1}, {1}, False), +]) +def test_check_output( + output_msg: str, + flow_nums: set, + db_flow_nums: set, + expected: SatisfiedState, +): + mock_task_pool = Mock() + mock_task_pool.workflow_db_mgr.pri_dao.select_task_outputs.return_value = { + '{"f": "foo", "g": "goo"}': db_flow_nums, + } + + assert TaskPool.check_task_output( + mock_task_pool, '2000', 'haddock', output_msg, flow_nums + ) == expected diff --git a/tests/unit/test_task_proxy.py b/tests/unit/test_task_proxy.py index 98695ecd13f..4c5513be5cd 100644 --- a/tests/unit/test_task_proxy.py +++ b/tests/unit/test_task_proxy.py @@ -14,13 +14,15 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . 
-import pytest -from pytest import param from typing import Callable, Optional from unittest.mock import Mock +import pytest +from pytest import param + from cylc.flow.cycling import PointBase from cylc.flow.cycling.iso8601 import ISO8601Point +from cylc.flow.flow_mgr import FlowNums from cylc.flow.task_proxy import TaskProxy @@ -101,3 +103,18 @@ def test_status_match(status_str: Optional[str], expected: bool): mock_itask = Mock(state=Mock(status='waiting')) assert TaskProxy.status_match(mock_itask, status_str) is expected + + +@pytest.mark.parametrize('itask_flow_nums, flow_nums, expected', [ + param({1, 2}, {2}, {2}, id="subset"), + param({2}, {1, 2}, {2}, id="superset"), + param({1, 2}, {3, 4}, set(), id="disjoint"), + param({1, 2}, set(), {1, 2}, id="all-matches-num"), + param(set(), {1, 2}, set(), id="num-doesnt-match-none"), + param(set(), set(), set(), id="all-doesnt-match-none"), +]) +def test_match_flows( + itask_flow_nums: FlowNums, flow_nums: FlowNums, expected: FlowNums +): + mock_itask = Mock(flow_nums=itask_flow_nums) + assert TaskProxy.match_flows(mock_itask, flow_nums) == expected diff --git a/tests/unit/test_workflow_db_mgr.py b/tests/unit/test_workflow_db_mgr.py new file mode 100644 index 00000000000..f642eecda61 --- /dev/null +++ b/tests/unit/test_workflow_db_mgr.py @@ -0,0 +1,80 @@ +# THIS FILE IS PART OF THE CYLC WORKFLOW ENGINE. +# Copyright (C) NIWA & British Crown (Met Office) & Contributors. +# +# This program is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# This program is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. 
+# +# You should have received a copy of the GNU General Public License +# along with this program. If not, see . + +from pathlib import Path +from typing import ( + List, + Set, +) +from unittest.mock import Mock + +import pytest +from pytest import param + +from cylc.flow.cycling.integer import IntegerPoint +from cylc.flow.flow_mgr import FlowNums +from cylc.flow.id import Tokens +from cylc.flow.task_proxy import TaskProxy +from cylc.flow.taskdef import TaskDef +from cylc.flow.util import serialise_set +from cylc.flow.workflow_db_mgr import WorkflowDatabaseManager + + +@pytest.mark.parametrize('flow_nums, expected_removed', [ + param(set(), {1, 2, 5}, id='all'), + param({1}, {1}, id='subset'), + param({1, 2, 5}, {1, 2, 5}, id='complete-set'), + param({1, 3, 5}, {1, 5}, id='intersect'), + param({3, 4}, set(), id='disjoint'), +]) +def test_remove_task_from_flows( + tmp_path: Path, flow_nums: FlowNums, expected_removed: FlowNums +): + db_flows: List[FlowNums] = [ + {1, 2}, + {5}, + set(), # FLOW_NONE + ] + db_mgr = WorkflowDatabaseManager(tmp_path) + schd_tokens = Tokens('~asterix/gaul') + tdef = TaskDef('a', {}, None, None, None) + with db_mgr.get_pri_dao() as dao: + db_mgr.pri_dao = dao + db_mgr.pub_dao = Mock() + for flow in db_flows: + db_mgr.put_insert_task_states( + TaskProxy( + schd_tokens, + tdef, + IntegerPoint('1'), + flow_nums=flow, + ), + ) + db_mgr.process_queued_ops() + + removed_fnums = db_mgr.remove_task_from_flows('1', 'a', flow_nums) + assert removed_fnums == expected_removed + + db_mgr.process_queued_ops() + remaining_fnums: Set[str] = { + fnums_str + for fnums_str, *_ in dao.connect().execute( + 'SELECT flow_nums FROM task_states' + ) + } + assert remaining_fnums == { + serialise_set(flow - expected_removed) for flow in db_flows + } diff --git a/tests/unit/test_workflow_events.py b/tests/unit/test_workflow_events.py index 89449953f20..c9d08791781 100644 --- a/tests/unit/test_workflow_events.py +++ b/tests/unit/test_workflow_events.py @@ -83,13 
+83,13 @@ def test_process_mail_footer(caplog, log_filter): assert process_mail_footer( '%(host)s|%(port)s|%(owner)s|%(suite)s|%(workflow)s', template_vars ) == 'myhost|42|me|my_workflow|my_workflow\n' - assert not log_filter(caplog, contains='Ignoring bad mail footer template') + assert not log_filter(contains='Ignoring bad mail footer template') # test invalid variable assert process_mail_footer('%(invalid)s', template_vars) == '' - assert log_filter(caplog, contains='Ignoring bad mail footer template') + assert log_filter(contains='Ignoring bad mail footer template') # test broken template caplog.clear() assert process_mail_footer('%(invalid)s', template_vars) == '' - assert log_filter(caplog, contains='Ignoring bad mail footer template') + assert log_filter(contains='Ignoring bad mail footer template') diff --git a/tests/unit/test_workflow_files.py b/tests/unit/test_workflow_files.py index b2b33e495aa..2e85d4180a0 100644 --- a/tests/unit/test_workflow_files.py +++ b/tests/unit/test_workflow_files.py @@ -196,7 +196,6 @@ def test_infer_latest_run( @pytest.mark.parametrize('warn_arg', [True, False]) def test_infer_latest_run_warns_for_runN( warn_arg: bool, - caplog: pytest.LogCaptureFixture, log_filter: Callable, tmp_run_dir: Callable, ): @@ -206,8 +205,7 @@ def test_infer_latest_run_warns_for_runN( runN_path.symlink_to('run1') infer_latest_run(runN_path, warn_runN=warn_arg) filtered_log = log_filter( - caplog, level=logging.WARNING, - contains="You do not need to include runN in the workflow ID" + logging.WARNING, "You do not need to include runN in the workflow ID" ) assert filtered_log if warn_arg else not filtered_log