Skip to content

Commit

Permalink
refact: merge Context and Session to simplify the workflows api (#15709)
Browse files Browse the repository at this point in the history
* refact: merge Context and Session to simplify the workflows api

* add deprecation warnings

* remove session usage

* remove session usage from docs

* remove usage of self.data
  • Loading branch information
masci authored Aug 29, 2024
1 parent ff4fd29 commit 5102092
Show file tree
Hide file tree
Showing 14 changed files with 225 additions and 301 deletions.
4 changes: 2 additions & 2 deletions docs/docs/examples/workflow/workflows_cookbook.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,7 @@
" @step\n",
" async def setup(self, ctx: Context, ev: StartEvent) -> StopEvent:\n",
" if hasattr(ev, \"data\"):\n",
" self.data = ev.data\n",
" await ctx.set(\"data\", ev.data)\n",
"\n",
" return StopEvent(result=None)\n",
"\n",
Expand All @@ -330,7 +330,7 @@
" if hasattr(ev, \"query\"):\n",
" # do we have any data?\n",
" if hasattr(self, \"data\"):\n",
" data = self.data\n",
" data = await ctx.get(\"data\")\n",
" return StopEvent(result=f\"Got the data {data}\")\n",
" else:\n",
" # there's no data yet\n",
Expand Down
6 changes: 3 additions & 3 deletions docs/docs/module_guides/workflow/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ Using `ctx.collect_events()` we can buffer and wait for ALL expected events to a

## Manually Triggering Events

Normally, events are triggered by returning another event during a step. However, events can also be manually dispatched using the `ctx.session.send_event(event)` method within a workflow.
Normally, events are triggered by returning another event during a step. However, events can also be manually dispatched using the `ctx.send_event(event)` method within a workflow.

Here is a short toy example showing how this would be used:

Expand All @@ -259,8 +259,8 @@ class MyWorkflow(Workflow):
async def dispatch_step(
self, ctx: Context, ev: StartEvent
) -> MyEvent | GatherEvent:
ctx.session.send_event(MyEvent())
ctx.session.send_event(MyEvent())
ctx.send_event(MyEvent())
ctx.send_event(MyEvent())

return GatherEvent()

Expand Down
18 changes: 9 additions & 9 deletions docs/docs/understanding/workflows/concurrent_execution.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ In our examples so far, we've only emitted one event from each step. But there a
class ParallelFlow(Workflow):
@step
async def start(self, ctx: Context, ev: StartEvent) -> StepTwoEvent:
ctx.session.send_event(StepTwoEvent(query="Query 1"))
ctx.session.send_event(StepTwoEvent(query="Query 2"))
ctx.session.send_event(StepTwoEvent(query="Query 3"))
ctx.send_event(StepTwoEvent(query="Query 1"))
ctx.send_event(StepTwoEvent(query="Query 2"))
ctx.send_event(StepTwoEvent(query="Query 3"))

@step(num_workers=4)
async def step_two(self, ctx: Context, ev: StepTwoEvent) -> StopEvent:
Expand All @@ -32,9 +32,9 @@ If you execute the previous example, you'll note that the workflow stops after w
class ConcurrentFlow(Workflow):
@step
async def start(self, ctx: Context, ev: StartEvent) -> StepTwoEvent:
ctx.session.send_event(StepTwoEvent(query="Query 1"))
ctx.session.send_event(StepTwoEvent(query="Query 2"))
ctx.session.send_event(StepTwoEvent(query="Query 3"))
ctx.send_event(StepTwoEvent(query="Query 1"))
ctx.send_event(StepTwoEvent(query="Query 2"))
ctx.send_event(StepTwoEvent(query="Query 3"))

@step(num_workers=4)
async def step_two(self, ctx: Context, ev: StepTwoEvent) -> StepThreeEvent:
Expand Down Expand Up @@ -70,9 +70,9 @@ class ConcurrentFlow(Workflow):
async def start(
self, ctx: Context, ev: StartEvent
) -> StepAEvent | StepBEvent | StepCEvent:
ctx.session.send_event(StepAEvent(query="Query 1"))
ctx.session.send_event(StepBEvent(query="Query 2"))
ctx.session.send_event(StepCEvent(query="Query 3"))
ctx.send_event(StepAEvent(query="Query 1"))
ctx.send_event(StepBEvent(query="Query 2"))
ctx.send_event(StepCEvent(query="Query 3"))

@step
async def step_a(self, ctx: Context, ev: StepAEvent) -> StepACompleteEvent:
Expand Down
6 changes: 3 additions & 3 deletions docs/docs/understanding/workflows/observability.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class ConcurrentFlow(Workflow):
async def start(
self, ctx: Context, ev: StartEvent
) -> StepAEvent | StepBEvent | StepCEvent:
ctx.session.send_event(StepAEvent(query="Query 1"))
ctx.session.send_event(StepBEvent(query="Query 2"))
ctx.session.send_event(StepCEvent(query="Query 3"))
ctx.send_event(StepAEvent(query="Query 1"))
ctx.send_event(StepBEvent(query="Query 2"))
ctx.send_event(StepCEvent(query="Query 3"))

@step
async def step_a(self, ctx: Context, ev: StepAEvent) -> StepACompleteEvent:
Expand Down
8 changes: 4 additions & 4 deletions docs/docs/understanding/workflows/stream.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Streaming events

Workflows can be complex -- they are designed to handle complex, branching, concurrent logic -- which means they can take time to fully execute. To provide your user with a good experience, you may want to provide an indication of progress by streaming events as they occur. Workflows have built-in support for this on the `Context.session` object.
Workflows can be complex -- they are designed to handle complex, branching, concurrent logic -- which means they can take time to fully execute. To provide your user with a good experience, you may want to provide an indication of progress by streaming events as they occur. Workflows have built-in support for this on the `Context` object.

To get this done, let's bring in all the deps we need:

Expand Down Expand Up @@ -36,7 +36,7 @@ And define a workflow class that sends events:
class MyWorkflow(Workflow):
@step
async def step_one(self, ctx: Context, ev: StartEvent) -> FirstEvent:
ctx.session.write_event_to_stream(Event(msg="Step one is happening"))
ctx.write_event_to_stream(Event(msg="Step one is happening"))
return FirstEvent(first_output="First step complete.")

@step
Expand All @@ -47,15 +47,15 @@ class MyWorkflow(Workflow):
)
async for response in generator:
# Allow the workflow to stream this piece of response
ctx.session.write_event_to_stream(Event(msg=response.delta))
ctx.write_event_to_stream(Event(msg=response.delta))
return SecondEvent(
second_output="Second step complete, full response attached",
response=str(response),
)

@step
async def step_three(self, ctx: Context, ev: SecondEvent) -> StopEvent:
ctx.session.write_event_to_stream(Event(msg="Step three is happening"))
ctx.write_event_to_stream(Event(msg="Step three is happening"))
return StopEvent(result="Workflow complete.")
```

Expand Down
110 changes: 72 additions & 38 deletions llama-index-core/llama_index/core/workflow/context.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from collections import defaultdict
import asyncio
from typing import Dict, Any, Optional, List, Type, TYPE_CHECKING
import warnings
from collections import defaultdict
from typing import Dict, Any, Optional, List, Type, TYPE_CHECKING, Set, Tuple

from .decorators import StepConfig
from .events import Event
from .errors import WorkflowRuntimeError

if TYPE_CHECKING: # pragma: no cover
from .session import WorkflowSession
from .workflow import Workflow


class Context:
Expand All @@ -19,27 +22,21 @@ class Context:
Both `set` and `get` operations on global data are governed by a lock, and considered coroutine-safe.
"""

def __init__(
self,
session: Optional["WorkflowSession"] = None,
parent: Optional["Context"] = None,
) -> None:
def __init__(self, workflow: "Workflow") -> None:
self._workflow = workflow
# Broker machinery
self._queues: Dict[str, asyncio.Queue] = {}
self._tasks: Set[asyncio.Task] = set()
self._broker_log: List[Event] = []
self._step_flags: Dict[str, asyncio.Event] = {}
self._accepted_events: List[Tuple[str, str]] = []
self._retval: Any = None
# Streaming machinery
self._streaming_queue: asyncio.Queue = asyncio.Queue()
# Global data storage
if parent is not None:
self._globals = parent._globals
else:
self._globals: Dict[str, Any] = {}
self._lock = asyncio.Lock()
if session is None:
msg = "A workflow session is needed to create a root context"
raise ValueError(msg)
self._session = session

# Local data storage
self._locals: Dict[str, Any] = {}

self._lock = asyncio.Lock()
self._globals: Dict[str, Any] = {}
# Step-specific instance
self._parent: Optional[Context] = parent
self._events_buffer: Dict[Type[Event], List[Event]] = defaultdict(list)

async def set(self, key: str, value: Any, make_private: bool = False) -> None:
Expand All @@ -48,17 +45,14 @@ async def set(self, key: str, value: Any, make_private: bool = False) -> None:
Args:
key: A unique string to identify the value stored.
value: The data to be stored.
make_private: Make the value only accessible from the step that stored it.
Raises:
ValueError: When make_private is True but a key already exists in the global storage.
"""
if make_private:
if key in self._globals:
msg = f"A key named '{key}' already exists in the Context storage."
raise ValueError(msg)
self._locals[key] = value
return
warnings.warn(
"`make_private` is deprecated and will be ignored", DeprecationWarning
)

async with self.lock:
self._globals[key] = value
Expand All @@ -73,13 +67,11 @@ async def get(self, key: str, default: Optional[Any] = None) -> Any:
Raises:
ValueError: When there's not value accessible corresponding to `key`.
"""
if key in self._locals:
return self._locals[key]
elif key in self._globals:
async with self.lock:
async with self.lock:
if key in self._globals:
return self._globals[key]
elif default is not None:
return default
elif default is not None:
return default

msg = f"Key '{key}' not found in Context"
raise ValueError(msg)
Expand All @@ -90,17 +82,21 @@ def data(self):
Use `get` and `set` instead.
"""
msg = "`data` is deprecated, please use the `get` and `set` method to store data into the Context."
warnings.warn(msg, DeprecationWarning)
return self._globals

@property
def lock(self) -> asyncio.Lock:
"""Returns a mutex to lock the Context."""
return self._parent._lock if self._parent else self._lock
return self._lock

@property
def session(self) -> "WorkflowSession":
"""Returns a mutex to lock the Context."""
return self._parent._session if self._parent else self._session
def session(self) -> "Context":
"""This property is provided for backward compatibility."""
msg = "`session` is deprecated, please use the Context instance directly."
warnings.warn(msg, DeprecationWarning)
return self

def collect_events(
self, ev: Event, expected: List[Type[Event]]
Expand All @@ -121,3 +117,41 @@ def collect_events(
self._events_buffer[type(ev)].append(ev)

return None

def send_event(self, message: Event, step: Optional[str] = None) -> None:
    """Sends an event to a specific step in the workflow.

    If step is None, the event is sent to all the receivers and we let
    them discard events they don't want.

    Args:
        message: The event instance to dispatch.
        step: Name of the target step, or None to broadcast to every step.

    Raises:
        WorkflowRuntimeError: If `step` names a step that does not exist,
            or the named step does not accept events of this type.
    """
    if step is None:
        # Broadcast: put the event on every step's queue; steps whose
        # accepted_events don't include this type simply discard it.
        for queue in self._queues.values():
            queue.put_nowait(message)
    else:
        if step not in self._workflow._get_steps():
            raise WorkflowRuntimeError(f"Step {step} does not exist")

        step_func = self._workflow._get_steps()[step]
        # The @step decorator attaches a StepConfig describing which
        # event types the step accepts — NOTE(review): assumed from the
        # `__step_config` attribute read here; confirm against decorators.py.
        step_config: Optional[StepConfig] = getattr(
            step_func, "__step_config", None
        )

        if step_config and type(message) in step_config.accepted_events:
            self._queues[step].put_nowait(message)
        else:
            raise WorkflowRuntimeError(
                f"Step {step} does not accept event of type {type(message)}"
            )

    # Record every dispatched event for later inspection/debugging.
    self._broker_log.append(message)

def write_event_to_stream(self, ev: Optional[Event]) -> None:
    """Puts `ev` on the streaming queue so consumers can observe workflow progress.

    Consumers read these events via the `streaming_queue` property.
    A None value may be written — NOTE(review): presumably used as an
    end-of-stream sentinel; confirm against the consumer side.
    """
    self._streaming_queue.put_nowait(ev)

def get_result(self) -> Any:
    """Returns the result of the workflow.

    `_retval` is initialized to None, so this returns None if called
    before the workflow has produced a result.
    """
    return self._retval

@property
def streaming_queue(self) -> asyncio.Queue:
    """The queue that `write_event_to_stream` writes to; read from it to stream events."""
    return self._streaming_queue
77 changes: 0 additions & 77 deletions llama-index-core/llama_index/core/workflow/session.py

This file was deleted.

3 changes: 1 addition & 2 deletions llama-index-core/llama_index/core/workflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from llama_index.core.bridge.pydantic import BaseModel, ConfigDict

from .context import Context
from .events import Event, EventType
from .errors import WorkflowValidationError

Expand Down Expand Up @@ -57,7 +56,7 @@ def inspect_signature(fn: Callable) -> StepSignatureSpec:
continue

# Get name and type of the Context param
if t.annotation == Context:
if hasattr(t.annotation, "__name__") and t.annotation.__name__ == "Context":
context_parameter = name
continue

Expand Down
Loading

0 comments on commit 5102092

Please sign in to comment.