Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StateManager: Pass intermediates using normalized values, rename coerce to normalize #298

Merged
merged 13 commits into from
Oct 21, 2022
181 changes: 56 additions & 125 deletions poetry.lock

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,17 @@ rds-graphile-worker-client = {version = "^0.1.1", optional = true}
semver = "3.0.0.dev3"
simplejson = "^3.17.5"
jsonschema = "^4.1.2"
artifax = {version = "0.4", optional = true}
# artifax = {git = "https://github.com/curvewise-forks/artifax.git", rev = "2e496525a04185525cfedbb3a9375101ec7faec1", optional = true}
# artifax = {version = "0.4", optional = true}
artifax = {git = "https://github.com/curvewise-forks/artifax.git", rev = "3cd288a1698a798c0e488eba31e8e9ce3f075e1c", optional = true}
# Temporarily declare artifax dependencies until we publish an artifax fork.
# pathos = {version = "*", optional = true}
# exos = {version = "*", optional = true}
pathos = {version = "*", optional = true}
exos = {version = "*", optional = true}

[tool.poetry.extras]
aws_lambda_build = ["executor"]
client = ["boto3"]
compute_graph = ["artifax"]
# compute_graph = ["artifax", "pathos", "exos"]
# compute_graph = ["artifax"]
compute_graph = ["artifax", "pathos", "exos"]
lambda_common = ["harrison"]
cli = ["click"]
rds_graphile_worker = ["rds-graphile-worker-client"]
Expand Down
4 changes: 2 additions & 2 deletions werkit/compute/graph/_custom_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def deserialize(cls, json_data: JSONType) -> CanonicalType:

@classmethod
@abstractmethod
def coerce(cls, value: t.Any) -> CanonicalType:
def normalize(cls, value: t.Any) -> CanonicalType:
"""
Coerce the given value to the canonical native type. Raise an exception
if it can't be coerced.
if it can't be normalized.
"""

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions werkit/compute/graph/_dependency_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def deserialize(self, value: JSONType) -> t.Any:
value_type.validate(value)
return value_type.deserialize(value)

def coerce(self, name: str, value: t.Any) -> t.Any:
def normalize(self, name: str, value: t.Any) -> t.Any:
if self.value_type_is_built_in:
return coerce_value_to_builtin_type(
name=name,
Expand All @@ -60,7 +60,7 @@ def coerce(self, name: str, value: t.Any) -> t.Any:
)
else:
# TODO: Perhaps catch and re-throw to improve the error message.
return t.cast(t.Type[CustomType], self.value_type).coerce(value)
return t.cast(t.Type[CustomType], self.value_type).normalize(value)

def serialize_value(self, value: t.Any) -> JSONType:
if self.value_type_is_built_in:
Expand Down
28 changes: 20 additions & 8 deletions werkit/compute/graph/_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,23 @@ def deserialize(self, **kwargs: t.Dict) -> None:
}
self.store.update(deserialized)

def coerce(self, **kwargs: t.Dict) -> t.Dict:
def normalize(self, **kwargs: t.Dict) -> t.Dict:
return {
name: self.dependency_graph.all_nodes[name].coerce(name=name, value=value)
name: self.dependency_graph.all_nodes[name].normalize(
name=name, value=value
)
for name, value in kwargs.items()
}

def set(self, **kwargs: t.Dict) -> None:
self._assert_known_keys(kwargs.keys())
coerced = self.coerce(**kwargs)
self.store.update(coerced)
normalized = self.normalize(**kwargs)
self.store.update(normalized)

def evaluate(
self, targets: t.List[str] = None, handle_exceptions: bool = False
) -> None:
import functools
from artifax import Artifax

if targets is not None:
Expand All @@ -47,18 +50,27 @@ def evaluate(
else:
self._assert_known_keys(targets)

def wrap_node(name, node):
wrapped = node.bind(self.instance)

def wrapper(*args):
value = wrapped(*args)
return node.normalize(name, value)

functools.update_wrapper(wrapper, wrapped)
return wrapper

afx = Artifax(
{
name: self.dependency_graph.compute_nodes[name].bind(self.instance)
for name in self.dependency_graph.compute_nodes.keys()
name: wrap_node(name, node)
for name, node in self.dependency_graph.compute_nodes.items()
}
)
if self.store:
afx.set(**self.store)
afx.build(targets=targets)
# TODO: `afx.build()` should always return an object.
paulmelnikow marked this conversation as resolved.
Show resolved Hide resolved
coerced = self.coerce(**afx._result)
self.store.update(coerced)
self.store.update(**afx._result)

def serialize(self, targets: t.List[str] = None) -> t.Dict:
if targets is not None:
Expand Down
9 changes: 9 additions & 0 deletions werkit/compute/graph/test_state_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,9 +149,18 @@ def test_state_manager_with_custom_type() -> None:
assert thing.description == "Example description"
assert thing.count == 25

# Due to rounding in normalize(), other_thing should be rounded.
assert state_manager.store["other_thing"] == (1.52, 2.52, 3.52)


def test_state_manager_propagates_normalized_value() -> None:
state_manager = MyComputeProcessWithCustomType().state_manager
state_manager.set(a=1, b=2)
state_manager.evaluate()

assert state_manager.store["further_derived_thing"] == "(1.52, 2.52, 3.52)"


def test_state_manager_deserializes_custom_type() -> None:
state_manager = MyComputeProcessWithCustomType().state_manager

Expand Down
16 changes: 12 additions & 4 deletions werkit/compute/graph/testing_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,11 @@ def name(cls) -> str:
return "MyModel"

@classmethod
def coerce(cls, value: t.Any) -> MyModel:
def normalize(cls, value: t.Any) -> MyModel:
if not isinstance(value, MyModel):
raise ValueError(f"Can't coerce {type(value).__name__} to {cls.__name__}")
raise ValueError(
f"Can't normalize {type(value).__name__} to {cls.__name__}"
)
return value

@classmethod
Expand All @@ -93,9 +95,11 @@ class Vector3(CustomType[tuple]):
DECIMALS = 2

@classmethod
def coerce(cls, value: t.Any) -> tuple:
def normalize(cls, value: t.Any) -> tuple:
if not isinstance(value, tuple):
raise ValueError(f"Can't coerce {type(value).__name__} to {cls.__name__}")
raise ValueError(
f"Can't normalize {type(value).__name__} to {cls.__name__}"
)
elif not len(value) == 3:
raise ValueError("Excepted tuple to have length 3")
return tuple(round(coord, cls.DECIMALS) for coord in value)
Expand All @@ -120,3 +124,7 @@ def thing(self) -> MyModel:
@output(value_type=Vector3)
def other_thing(self) -> tuple:
return (1.5151, 2.5151, 3.5151)

@output(value_type=str)
def further_derived_thing(self, other_thing) -> str:
return str(other_thing)