Skip to content

Commit

Permalink
Merge pull request #3586 from Textualize/wokers-inside-workers
Browse files Browse the repository at this point in the history
Workers inside workers
  • Loading branch information
rodrigogiraoserrao authored Oct 30, 2023
2 parents 001881c + 841f726 commit 3f33cd1
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 4 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/).
- DataTable now has a max-height of 100vh rather than 100%, which doesn't work with auto
- Breaking change: empty rules now result in an error https://github.com/Textualize/textual/pull/3566
- Improved startup time by caching CSS parsing https://github.com/Textualize/textual/pull/3575
- Workers are now created/run in a thread-safe way https://github.com/Textualize/textual/pull/3586

### Added

Expand Down
13 changes: 10 additions & 3 deletions src/textual/dom.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
"""
A DOMNode is a base class for any object within the Textual Document Object Model,
which includes all Widgets, Screens, and Apps.
"""
Expand All @@ -8,7 +7,8 @@
from __future__ import annotations

import re
from functools import lru_cache
import threading
from functools import lru_cache, partial
from inspect import getfile
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -267,7 +267,14 @@ def run_worker(
Returns:
New Worker instance.
"""
worker: Worker[ResultType] = self.workers._new_worker(

# If we're running a worker from inside a secondary thread,
# do so in a thread-safe way.
if self.app._thread_id != threading.get_ident():
creator = partial(self.app.call_from_thread, self.workers._new_worker)
else:
creator = self.workers._new_worker
worker: Worker[ResultType] = creator(
work,
self,
name=name,
Expand Down
98 changes: 97 additions & 1 deletion tests/workers/test_work_decorator.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
from time import sleep
from typing import Callable
from typing import Callable, List, Tuple

import pytest

Expand Down Expand Up @@ -88,3 +88,99 @@ class _(App[None]):
@work(thread=False)
def foo(self) -> None:
pass


class NestedWorkersApp(App[None]):
def __init__(self, call_stack: List[str]):
self.call_stack = call_stack
super().__init__()

def call_from_stack(self):
if self.call_stack:
call_now = self.call_stack.pop()
getattr(self, call_now)()

@work(thread=False)
async def async_no_thread(self):
self.call_from_stack()

@work(thread=True)
async def async_thread(self):
self.call_from_stack()

@work(thread=True)
def thread(self):
self.call_from_stack()


@pytest.mark.parametrize(
"call_stack",
[ # from itertools import product; list(product("async_no_thread async_thread thread".split(), repeat=3))
("async_no_thread", "async_no_thread", "async_no_thread"),
("async_no_thread", "async_no_thread", "async_thread"),
("async_no_thread", "async_no_thread", "thread"),
("async_no_thread", "async_thread", "async_no_thread"),
("async_no_thread", "async_thread", "async_thread"),
("async_no_thread", "async_thread", "thread"),
("async_no_thread", "thread", "async_no_thread"),
("async_no_thread", "thread", "async_thread"),
("async_no_thread", "thread", "thread"),
("async_thread", "async_no_thread", "async_no_thread"),
("async_thread", "async_no_thread", "async_thread"),
("async_thread", "async_no_thread", "thread"),
("async_thread", "async_thread", "async_no_thread"),
("async_thread", "async_thread", "async_thread"),
("async_thread", "async_thread", "thread"),
("async_thread", "thread", "async_no_thread"),
("async_thread", "thread", "async_thread"),
("async_thread", "thread", "thread"),
("thread", "async_no_thread", "async_no_thread"),
("thread", "async_no_thread", "async_thread"),
("thread", "async_no_thread", "thread"),
("thread", "async_thread", "async_no_thread"),
("thread", "async_thread", "async_thread"),
("thread", "async_thread", "thread"),
("thread", "thread", "async_no_thread"),
("thread", "thread", "async_thread"),
("thread", "thread", "thread"),
( # Plus a longer chain to stress test this mechanism.
"async_no_thread",
"async_no_thread",
"thread",
"thread",
"async_thread",
"async_thread",
"async_no_thread",
"async_thread",
"async_no_thread",
"async_thread",
"thread",
"async_thread",
"async_thread",
"async_no_thread",
"async_no_thread",
"thread",
"thread",
"async_no_thread",
"async_no_thread",
"thread",
"async_no_thread",
"thread",
"thread",
),
],
)
async def test_calling_workers_from_within_workers(call_stack: Tuple[str]):
"""Regression test for https://github.com/Textualize/textual/issues/3472.
This makes sure we can nest worker calls without a problem.
"""
app = NestedWorkersApp(list(call_stack))
async with app.run_test():
app.call_from_stack()
# We need multiple awaits because we're creating a chain of workers that may
# have multiple async workers, each of which may need the await to have enough
# time to call the next one in the chain.
for _ in range(len(call_stack)):
await app.workers.wait_for_complete()
assert app.call_stack == []

0 comments on commit 3f33cd1

Please sign in to comment.