Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

framework states rework #359

Merged
merged 5 commits into the base branch
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion dff/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@
from dff.pipeline import Pipeline
from dff.script import Context, Script

Script.model_rebuild()
import dff.__rebuild_pydantic_models__
9 changes: 9 additions & 0 deletions dff/__rebuild_pydantic_models__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# flake8: noqa: F401
"""Resolve pydantic forward references once every framework module is importable.

Importing this module (done from ``dff/__init__.py``) triggers
``model_rebuild()`` on the models whose annotations contain forward
references that can only be resolved after the full import graph exists.
"""

from dff.pipeline import Pipeline
from dff.pipeline.types import ExtraHandlerRuntimeInfo
from dff.script import Context, Script

# Rebuild each model now that all referenced types are defined.
for _model in (Script, Context, ExtraHandlerRuntimeInfo):
    _model.model_rebuild()
3 changes: 0 additions & 3 deletions dff/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
ComponentExecutionState,
GlobalExtraHandlerType,
ExtraHandlerType,
PIPELINE_STATE_KEY,
StartConditionCheckerFunction,
StartConditionCheckerAggregationFunction,
ExtraHandlerConditionFunction,
Expand All @@ -32,5 +31,3 @@
from .service.extra import BeforeHandler, AfterHandler
from .service.group import ServiceGroup
from .service.service import Service, to_service

ExtraHandlerRuntimeInfo.model_rebuild()
3 changes: 1 addition & 2 deletions dff/pipeline/conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
from dff.script import Context

from .types import (
PIPELINE_STATE_KEY,
StartConditionCheckerFunction,
ComponentExecutionState,
StartConditionCheckerAggregationFunction,
Expand Down Expand Up @@ -41,7 +40,7 @@ def service_successful_condition(path: Optional[str] = None) -> StartConditionCh
"""

def check_service_state(ctx: Context, _: Pipeline):
state = ctx.framework_states[PIPELINE_STATE_KEY].get(path, ComponentExecutionState.NOT_RUN)
state = ctx.framework_data.service_states.get(path, ComponentExecutionState.NOT_RUN)
return ComponentExecutionState[state] == ComponentExecutionState.FINISHED

return check_service_state
Expand Down
90 changes: 42 additions & 48 deletions dff/pipeline/pipeline/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,6 @@ def __init__(
self._clean_turn_cache = True

async def __call__(self, pipeline: Pipeline, ctx: Context):
# context init
self._context_init(ctx)
await self._run_handlers(ctx, pipeline, ActorStage.CONTEXT_INIT)

# get previous node
Expand All @@ -121,7 +119,7 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):
self._get_next_node(ctx)
await self._run_handlers(ctx, pipeline, ActorStage.GET_NEXT_NODE)

ctx.add_label(ctx.framework_states["actor"]["next_label"][:2])
ctx.add_label(ctx.framework_data.actor_data["next_label"][:2])

# rewrite next node
self._rewrite_next_node(ctx)
Expand All @@ -132,89 +130,85 @@ async def __call__(self, pipeline: Pipeline, ctx: Context):
await self._run_handlers(ctx, pipeline, ActorStage.RUN_PRE_RESPONSE_PROCESSING)

# create response
ctx.framework_states["actor"]["response"] = await self.run_response(
ctx.framework_states["actor"]["pre_response_processed_node"].response, ctx, pipeline
ctx.framework_data.actor_data["response"] = await self.run_response(
ctx.framework_data.actor_data["pre_response_processed_node"].response, ctx, pipeline
)
await self._run_handlers(ctx, pipeline, ActorStage.CREATE_RESPONSE)
ctx.add_response(ctx.framework_states["actor"]["response"])
ctx.add_response(ctx.framework_data.actor_data["response"])

await self._run_handlers(ctx, pipeline, ActorStage.FINISH_TURN)
if self._clean_turn_cache:
cache_clear()

del ctx.framework_states["actor"]

@staticmethod
def _context_init(ctx: Optional[Union[Context, dict, str]] = None):
ctx.framework_states["actor"] = {}
ctx.framework_data.actor_data.clear()

def _get_previous_node(self, ctx: Context):
ctx.framework_states["actor"]["previous_label"] = (
ctx.framework_data.actor_data["previous_label"] = (
normalize_label(ctx.last_label) if ctx.last_label else self.start_label
)
ctx.framework_states["actor"]["previous_node"] = self.script.get(
ctx.framework_states["actor"]["previous_label"][0], {}
).get(ctx.framework_states["actor"]["previous_label"][1], Node())
ctx.framework_data.actor_data["previous_node"] = self.script.get(
ctx.framework_data.actor_data["previous_label"][0], {}
).get(ctx.framework_data.actor_data["previous_label"][1], Node())

async def _get_true_labels(self, ctx: Context, pipeline: Pipeline):
# GLOBAL
ctx.framework_states["actor"]["global_transitions"] = (
ctx.framework_data.actor_data["global_transitions"] = (
self.script.get(GLOBAL, {}).get(GLOBAL, Node()).transitions
)
ctx.framework_states["actor"]["global_true_label"] = await self._get_true_label(
ctx.framework_states["actor"]["global_transitions"], ctx, pipeline, GLOBAL, "global"
ctx.framework_data.actor_data["global_true_label"] = await self._get_true_label(
ctx.framework_data.actor_data["global_transitions"], ctx, pipeline, GLOBAL, "global"
)

# LOCAL
ctx.framework_states["actor"]["local_transitions"] = (
self.script.get(ctx.framework_states["actor"]["previous_label"][0], {}).get(LOCAL, Node()).transitions
ctx.framework_data.actor_data["local_transitions"] = (
self.script.get(ctx.framework_data.actor_data["previous_label"][0], {}).get(LOCAL, Node()).transitions
)
ctx.framework_states["actor"]["local_true_label"] = await self._get_true_label(
ctx.framework_states["actor"]["local_transitions"],
ctx.framework_data.actor_data["local_true_label"] = await self._get_true_label(
ctx.framework_data.actor_data["local_transitions"],
ctx,
pipeline,
ctx.framework_states["actor"]["previous_label"][0],
ctx.framework_data.actor_data["previous_label"][0],
"local",
)

# NODE
ctx.framework_states["actor"]["node_transitions"] = ctx.framework_states["actor"][
ctx.framework_data.actor_data["node_transitions"] = ctx.framework_data.actor_data[
"pre_transitions_processed_node"
].transitions
ctx.framework_states["actor"]["node_true_label"] = await self._get_true_label(
ctx.framework_states["actor"]["node_transitions"],
ctx.framework_data.actor_data["node_true_label"] = await self._get_true_label(
ctx.framework_data.actor_data["node_transitions"],
ctx,
pipeline,
ctx.framework_states["actor"]["previous_label"][0],
ctx.framework_data.actor_data["previous_label"][0],
"node",
)

def _get_next_node(self, ctx: Context):
# choose next label
ctx.framework_states["actor"]["next_label"] = self._choose_label(
ctx.framework_states["actor"]["node_true_label"], ctx.framework_states["actor"]["local_true_label"]
ctx.framework_data.actor_data["next_label"] = self._choose_label(
ctx.framework_data.actor_data["node_true_label"], ctx.framework_data.actor_data["local_true_label"]
)
ctx.framework_states["actor"]["next_label"] = self._choose_label(
ctx.framework_states["actor"]["next_label"], ctx.framework_states["actor"]["global_true_label"]
ctx.framework_data.actor_data["next_label"] = self._choose_label(
ctx.framework_data.actor_data["next_label"], ctx.framework_data.actor_data["global_true_label"]
)
# get next node
ctx.framework_states["actor"]["next_node"] = self.script.get(
ctx.framework_states["actor"]["next_label"][0], {}
).get(ctx.framework_states["actor"]["next_label"][1])
ctx.framework_data.actor_data["next_node"] = self.script.get(
ctx.framework_data.actor_data["next_label"][0], {}
).get(ctx.framework_data.actor_data["next_label"][1])

def _rewrite_previous_node(self, ctx: Context):
node = ctx.framework_states["actor"]["previous_node"]
flow_label = ctx.framework_states["actor"]["previous_label"][0]
ctx.framework_states["actor"]["previous_node"] = self._overwrite_node(
node = ctx.framework_data.actor_data["previous_node"]
flow_label = ctx.framework_data.actor_data["previous_label"][0]
ctx.framework_data.actor_data["previous_node"] = self._overwrite_node(
node,
flow_label,
only_current_node_transitions=True,
)

def _rewrite_next_node(self, ctx: Context):
node = ctx.framework_states["actor"]["next_node"]
flow_label = ctx.framework_states["actor"]["next_label"][0]
ctx.framework_states["actor"]["next_node"] = self._overwrite_node(node, flow_label)
node = ctx.framework_data.actor_data["next_node"]
flow_label = ctx.framework_data.actor_data["next_label"][0]
ctx.framework_data.actor_data["next_node"] = self._overwrite_node(node, flow_label)

def _overwrite_node(
self,
Expand Down Expand Up @@ -290,18 +284,18 @@ async def _run_pre_transitions_processing(self, ctx: Context, pipeline: Pipeline
The execution order depends on the value of the :py:class:`.Pipeline`'s
`parallelize_processing` flag.
"""
ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["previous_node"])
pre_transitions_processing = ctx.framework_states["actor"]["previous_node"].pre_transitions_processing
ctx.framework_data.actor_data["processed_node"] = copy.deepcopy(ctx.framework_data.actor_data["previous_node"])
pre_transitions_processing = ctx.framework_data.actor_data["previous_node"].pre_transitions_processing

if pipeline.parallelize_processing:
await self._run_processing_parallel(pre_transitions_processing, ctx, pipeline)
else:
await self._run_processing_sequential(pre_transitions_processing, ctx, pipeline)

ctx.framework_states["actor"]["pre_transitions_processed_node"] = ctx.framework_states["actor"][
ctx.framework_data.actor_data["pre_transitions_processed_node"] = ctx.framework_data.actor_data[
"processed_node"
]
del ctx.framework_states["actor"]["processed_node"]
del ctx.framework_data.actor_data["processed_node"]

async def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline) -> None:
"""
Expand All @@ -312,16 +306,16 @@ async def _run_pre_response_processing(self, ctx: Context, pipeline: Pipeline) -
The execution order depends on the value of the :py:class:`.Pipeline`'s
`parallelize_processing` flag.
"""
ctx.framework_states["actor"]["processed_node"] = copy.deepcopy(ctx.framework_states["actor"]["next_node"])
pre_response_processing = ctx.framework_states["actor"]["next_node"].pre_response_processing
ctx.framework_data.actor_data["processed_node"] = copy.deepcopy(ctx.framework_data.actor_data["next_node"])
pre_response_processing = ctx.framework_data.actor_data["next_node"].pre_response_processing

if pipeline.parallelize_processing:
await self._run_processing_parallel(pre_response_processing, ctx, pipeline)
else:
await self._run_processing_sequential(pre_response_processing, ctx, pipeline)

ctx.framework_states["actor"]["pre_response_processed_node"] = ctx.framework_states["actor"]["processed_node"]
del ctx.framework_states["actor"]["processed_node"]
ctx.framework_data.actor_data["pre_response_processed_node"] = ctx.framework_data.actor_data["processed_node"]
del ctx.framework_data.actor_data["processed_node"]

async def _get_true_label(
self,
Expand Down
16 changes: 5 additions & 11 deletions dff/pipeline/pipeline/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,13 @@
import logging
import abc
import asyncio
import copy
from typing import Optional, Awaitable, TYPE_CHECKING

from dff.script import Context

from ..service.extra import BeforeHandler, AfterHandler
from ..conditions import always_start_condition
from ..types import (
PIPELINE_STATE_KEY,
StartConditionCheckerFunction,
ComponentExecutionState,
ServiceRuntimeInfo,
Expand Down Expand Up @@ -109,28 +107,24 @@ def __init__(

def _set_state(self, ctx: Context, value: ComponentExecutionState):
"""
Method for component runtime state setting, state is preserved in `ctx.framework_states` dict,
in subdict, dedicated to this library.
Method for component runtime state setting, state is preserved in `ctx.framework_data`.

:param ctx: :py:class:`~.Context` to keep state in.
:param value: State to set.
:return: `None`
"""
if PIPELINE_STATE_KEY not in ctx.framework_states:
ctx.framework_states[PIPELINE_STATE_KEY] = {}
ctx.framework_states[PIPELINE_STATE_KEY][self.path] = value
ctx.framework_data.service_states[self.path] = value

def get_state(self, ctx: Context, default: Optional[ComponentExecutionState] = None) -> ComponentExecutionState:
"""
Method for component runtime state getting, state is preserved in `ctx.framework_states` dict,
in subdict, dedicated to this library.
Method for component runtime state getting, state is preserved in `ctx.framework_data`.

:param ctx: :py:class:`~.Context` to get state from.
:param default: Default to return if no record found
(usually it's :py:attr:`~.pipeline.types.ComponentExecutionState.NOT_RUN`).
:return: :py:class:`~pipeline.types.ComponentExecutionState` of this service or default if not found.
"""
return ctx.framework_states[PIPELINE_STATE_KEY].get(self.path, default if default is not None else None)
return ctx.framework_data.service_states.get(self.path, default if default is not None else None)

@property
def asynchronous(self) -> bool:
Expand Down Expand Up @@ -218,7 +212,7 @@ def _get_runtime_info(self, ctx: Context) -> ServiceRuntimeInfo:
path=self.path if self.path is not None else "[None]",
timeout=self.timeout,
asynchronous=self.asynchronous,
execution_state=copy.deepcopy(ctx.framework_states[PIPELINE_STATE_KEY]),
execution_state=ctx.framework_data.service_states.copy(),
)

@property
Expand Down
4 changes: 1 addition & 3 deletions dff/pipeline/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
ExtraHandlerFunction,
ExtraHandlerBuilder,
)
from ..types import PIPELINE_STATE_KEY
from .utils import finalize_service_group, pretty_format_component_info_dict
from dff.pipeline.pipeline.actor import Actor

Expand Down Expand Up @@ -320,14 +319,13 @@ async def _run_pipeline(
if update_ctx_misc is not None:
ctx.misc.update(update_ctx_misc)

ctx.framework_states[PIPELINE_STATE_KEY] = {}
ctx.add_request(request)
result = await self._services_pipeline(ctx, self)

if asyncio.iscoroutine(result):
await result

del ctx.framework_states[PIPELINE_STATE_KEY]
ctx.framework_data.service_states.clear()

if isinstance(self.context_storage, DBContextStorage):
await self.context_storage.set_item_async(ctx_id, ctx)
Expand Down
Loading
Loading