diff --git a/src/stepfunctions/steps/states.py b/src/stepfunctions/steps/states.py index 7a1f093..9c86457 100644 --- a/src/stepfunctions/steps/states.py +++ b/src/stepfunctions/steps/states.py @@ -218,9 +218,18 @@ def next(self, next_step): Returns: State or Chain: Next state or chain that will be transitioned to. """ - if self.type in ('Choice', 'Succeed', 'Fail'): + if self.type in ('Succeed', 'Fail'): raise ValueError('Unexpected State instance `{step}`, State type `{state_type}` does not support method `next`.'.format(step=next_step, state_type=self.type)) + # By design, Choice states do not have the Next field. When used in a chain, the subsequent step becomes the + # default choice that executes if none of the specified rules match. + # See language spec for more info: https://states-language.net/spec.html#choice-state + if self.type is 'Choice': + if self.default is not None: + logger.warning(f'Chaining Choice state: Overwriting {self.state_id}\'s current default_choice ({self.default.state_id}) with {next_step.state_id}') + self.default_choice(next_step) + return self.default + self.next_step = next_step return self.next_step @@ -402,7 +411,7 @@ def allowed_fields(self): class Choice(State): """ - Choice state adds branching logic to a state machine. The state holds a list of *rule* and *next_step* pairs. The interpreter attempts pattern-matches against the rules in list order and transitions to the state or chain specified in the *next_step* field on the first *rule* where there is an exact match between the input value and a member of the comparison-operator array. + Choice state adds branching logic to a state machine. The state holds a list of *rule* and *next_step* pairs. The interpreter attempts pattern-matches against the rules in list order and transitions to the state or chain specified in the *next_step* field on the first *rule* where there is an exact match between the input value and a member of the comparison-operator array. When used in a chain, the subsequent step becomes the default choice that executes if none of the specified rules match. """ def __init__(self, state_id, **kwargs): diff --git a/tests/unit/test_steps.py b/tests/unit/test_steps.py index dfc6262..5c86279 100644 --- a/tests/unit/test_steps.py +++ b/tests/unit/test_steps.py @@ -12,6 +12,7 @@ # permissions and limitations under the License. from __future__ import absolute_import +import logging import pytest from stepfunctions.exceptions import DuplicateStatesInChain @@ -346,12 +347,6 @@ def test_append_states_after_terminal_state_will_fail(): chain.append(Succeed('Succeed')) chain.append(Pass('Pass2')) - with pytest.raises(ValueError): - chain = Chain() - chain.append(Pass('Pass')) - chain.append(Choice('Choice')) - chain.append(Pass('Pass2')) - def test_chaining_steps(): s1 = Pass('Step - One') @@ -391,6 +386,36 @@ def test_chaining_steps(): assert s2.next_step == s3 +def test_chaining_choice_sets_default_field(): + s1_pass = Pass('Step - One') + s2_choice = Choice('Step - Two') + s3_pass = Pass('Step - Three') + + chain1 = Chain([s1_pass, s2_choice, s3_pass]) + assert chain1.steps == [s1_pass, s2_choice, s3_pass] + assert s1_pass.next_step == s2_choice + assert s2_choice.default == s3_pass + assert s2_choice.next_step is None # Choice steps do not have next_step + assert s3_pass.next_step is None + + +def test_chaining_choice_with_existing_default_overrides_value(caplog): + s1_pass = Pass('Step - One') + s2_choice = Choice('Step - Two') + s3_pass = Pass('Step - Three') + + s2_choice.default_choice(s3_pass) + + # Chain s2_choice when default_choice is already set will trigger Warning message + with caplog.at_level(logging.WARNING): + Chain([s2_choice, s1_pass]) + expected_warning = f'Chaining Choice state: Overwriting {s2_choice.state_id}\'s current default_choice ({s3_pass.state_id}) with {s1_pass.state_id}' + assert expected_warning in caplog.text + assert 'WARNING' in caplog.text + assert s2_choice.default == s1_pass + assert s2_choice.next_step is None # Choice steps do not have next_step + + def test_catch_fail_for_unsupported_state(): s1 = Pass('Step - One')