Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make Choice states chainable #132

Merged
merged 9 commits into from
May 25, 2021
18 changes: 16 additions & 2 deletions src/stepfunctions/steps/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,9 +218,23 @@ def next(self, next_step):
Returns:
State or Chain: Next state or chain that will be transitioned to.
"""
if self.type in ('Choice', 'Succeed', 'Fail'):
if self.type in ('Succeed', 'Fail'):
raise ValueError('Unexpected State instance `{step}`, State type `{state_type}` does not support method `next`.'.format(step=next_step, state_type=self.type))

# By design, Choice states do not have the Next field. When used in a chain, the subsequent step becomes the
# default choice that executes if none of the specified rules match.
# See language spec for more info: https://states-language.net/spec.html#choice-state
if self.type is 'Choice':
if self.default is not None:
logger.warning(
"Chaining Choice state: Overwriting %s's current default_choice (%s) with %s",
Copy link
Contributor

@yuan-bwn yuan-bwn May 18, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: not a blocker but can try using f-string next time: https://www.python.org/dev/peps/pep-0498/
Got this recommendation from Shiv and think this could also help here. Same for the unit test string

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! Using fstring will make it more readable and faster! Making changes in the next commit!

self.state_id,
self.default.state_id,
next_step.state_id
)
self.default_choice(next_step)
return self.default

self.next_step = next_step
return self.next_step

Expand Down Expand Up @@ -402,7 +416,7 @@ def allowed_fields(self):
class Choice(State):

"""
Choice state adds branching logic to a state machine. The state holds a list of *rule* and *next_step* pairs. The interpreter attempts pattern-matches against the rules in list order and transitions to the state or chain specified in the *next_step* field on the first *rule* where there is an exact match between the input value and a member of the comparison-operator array.
Choice state adds branching logic to a state machine. The state holds a list of *rule* and *next_step* pairs. The interpreter attempts pattern-matches against the rules in list order and transitions to the state or chain specified in the *next_step* field on the first *rule* where there is an exact match between the input value and a member of the comparison-operator array. When used in a chain, the subsequent step becomes the default choice that executes if none of the specified rules match.
"""

def __init__(self, state_id, **kwargs):
Expand Down
41 changes: 35 additions & 6 deletions tests/unit/test_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# permissions and limitations under the License.
from __future__ import absolute_import

import logging
import pytest

from stepfunctions.exceptions import DuplicateStatesInChain
Expand Down Expand Up @@ -328,12 +329,6 @@ def test_append_states_after_terminal_state_will_fail():
chain.append(Succeed('Succeed'))
chain.append(Pass('Pass2'))

with pytest.raises(ValueError):
chain = Chain()
chain.append(Pass('Pass'))
chain.append(Choice('Choice'))
chain.append(Pass('Pass2'))


def test_chaining_steps():
s1 = Pass('Step - One')
Expand Down Expand Up @@ -372,6 +367,40 @@ def test_chaining_steps():
assert s1.next_step == s2
assert s2.next_step == s3


def test_chaining_choice_sets_default_field():
s1_pass = Pass('Step - One')
s2_choice = Choice('Step - Two')
s3_pass = Pass('Step - Three')

chain1 = Chain([s1_pass, s2_choice, s3_pass])
assert chain1.steps == [s1_pass, s2_choice, s3_pass]
assert s1_pass.next_step == s2_choice
assert s2_choice.default == s3_pass
assert s2_choice.next_step is None # Choice steps do not have next_step
assert s3_pass.next_step is None
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When writing unit tests for these modules we should consider making assertions on the generated JSON too. One of the main responsibilities is generating ASL so we should make sure it's correct.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed! Will include assertions in the generated JSON in the next commit.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As discussed offline, there is already a unit test (test_choice_state_creation) that ensures the Choice state creation generates the expected ASL



def test_chaining_choice_with_existing_default_overrides_value(caplog):
s1_pass = Pass('Step - One')
s2_choice = Choice('Step - Two')
s3_pass = Pass('Step - Three')

s2_choice.default_choice(s3_pass)

# Chain s2_choice when default_choice is already set will trigger Warning
with caplog.at_level(logging.WARNING):
wong-a marked this conversation as resolved.
Show resolved Hide resolved
Chain([s2_choice, s1_pass])
expected_warning = (
"Chaining Choice state: Overwriting %s's current default_choice (%s) with %s" %
(s2_choice.state_id, s3_pass.state_id, s1_pass.state_id)
)
assert expected_warning in caplog.text
assert 'WARNING' in caplog.text
assert s2_choice.default == s1_pass
assert s2_choice.next_step is None # Choice steps do not have next_step


def test_catch_fail_for_unsupported_state():
s1 = Pass('Step - One')

Expand Down