diff --git a/tests/conftest.py b/tests/conftest.py index 37aba0b..a54456a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -68,27 +68,27 @@ def get_config() -> RabbitMQConnectionConfig: yield context -@pytest.fixture(name="managed_thread", scope="function") +@pytest.fixture(name="managed_thread", scope="function", autouse=True) def get_managed_thread_fixture( context: ApplicationContext, ) -> Generator[ManagedThread, Any, None]: - action = context.get(IManagedThreadAction) - managed_thread: ManagedThread = ManagedThread(action, "RabbitMQ Sync Thread") - - yield managed_thread - - managed_thread.stop() + thread: ManagedThread = ManagedThread( + context.get(IManagedThreadAction), + "RabbitMQ Sync Thread", + ) + thread.start() + yield thread + thread.stop() -@pytest.fixture(name="async_managed_thread", scope="function") +@pytest.fixture(name="async_managed_thread", scope="function", autouse=True) def get_async_managed_thread_fixture( context: ApplicationContext, ) -> Generator[AsyncManagedThread, Any, None]: - action = context.get(IAsyncManagedThreadAction) - async_managed_thread: AsyncManagedThread = AsyncManagedThread( - action, "RabbitMQ Async Thread" + thread: AsyncManagedThread = AsyncManagedThread( + context.get(IAsyncManagedThreadAction), + "RabbitMQ Async Thread", ) - - yield async_managed_thread - - async_managed_thread.stop() + thread.start() + yield thread + thread.stop() diff --git a/tests/test_event.py b/tests/test_event.py index de15af9..5f22ef3 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -1,32 +1,32 @@ +from time import sleep +from asyncio import sleep as asleep + import pytest from spakky.application.application_context import ApplicationContext from spakky.domain.ports.event.event_publisher import ( IAsyncEventPublisher, IEventPublisher, ) -from spakky.threading.managed_thread import AsyncManagedThread, ManagedThread from tests.apps.dummy import DummyEventHandler, SampleEvent -def test_synchronous_event_publish_and_consume( - context: ApplicationContext, managed_thread: ManagedThread -) -> None: - managed_thread.start() +def test_synchronous_event_publish_and_consume(context: ApplicationContext) -> None: publisher = context.get(IEventPublisher) publisher.publish(SampleEvent(message="Hello, World!")) publisher.publish(SampleEvent(message="Goodbye, World!")) + sleep(0.1) handler = context.get(DummyEventHandler) assert handler.count == 2 @pytest.mark.asyncio async def test_asynchronous_event_publish_and_consume( - context: ApplicationContext, async_managed_thread: AsyncManagedThread + context: ApplicationContext, ) -> None: - async_managed_thread.start() publisher = context.get(IAsyncEventPublisher) await publisher.publish(SampleEvent(message="Hello, World!")) await publisher.publish(SampleEvent(message="Goodbye, World!")) + await asleep(0.1) handler = context.get(DummyEventHandler) assert handler.count == 2