Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug] Make _self_call_placeholder_ thread-safe #1034

Merged
merged 1 commit into from
Sep 24, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions guidance/_guidance.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import inspect
import threading

from ._grammar import DeferredReference, RawFunction, Terminal, string
from ._utils import strip_multiline_string_indents
Expand Down Expand Up @@ -40,6 +41,10 @@ def _decorator(f, *, stateless, cache, dedent, model):
if cache:
f = functools.cache(f)

# Use thread local to store the reference to the grammar node for recursive calls
# Otherwise, shared state between threads may otherwise trick us into thinking we are in a recursive call
thread_local = threading.local()

@functools.wraps(f)
def wrapped(*args, **kwargs):

Expand All @@ -49,7 +54,7 @@ def wrapped(*args, **kwargs):
):

# if we have a (deferred) reference set, then we must be in a recursive definition and so we return the reference
reference = getattr(f, "_self_call_reference_", None)
reference = getattr(thread_local, "_self_call_reference_", None)
if reference is not None:
return reference

Expand All @@ -59,7 +64,7 @@ def wrapped(*args, **kwargs):
# set a DeferredReference for recursive calls (only if we don't have arguments that might make caching a bad idea)
no_args = len(args) + len(kwargs) == 0
if no_args:
f._self_call_reference_ = DeferredReference()
thread_local._self_call_reference_ = DeferredReference()

try:
# call the function to get the grammar node
Expand All @@ -71,10 +76,10 @@ def wrapped(*args, **kwargs):
node.name = f.__name__
# set the reference value with our generated node
if no_args:
f._self_call_reference_.value = node
thread_local._self_call_reference_.value = node
finally:
if no_args:
del f._self_call_reference_
del thread_local._self_call_reference_

return node

Expand Down