From a69d11123afe86af82bf985c11168b88f2cb5346 Mon Sep 17 00:00:00 2001 From: David Brochart Date: Mon, 20 Mar 2023 21:11:51 +0100 Subject: [PATCH] Use watchfiles stop_event --- plugins/contents/fps_contents/fileid.py | 6 +++++- plugins/yjs/fps_yjs/main.py | 3 ++- tests/test_server.py | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/plugins/contents/fps_contents/fileid.py b/plugins/contents/fps_contents/fileid.py index fbfbe089..00c9597c 100644 --- a/plugins/contents/fps_contents/fileid.py +++ b/plugins/contents/fps_contents/fileid.py @@ -40,6 +40,8 @@ def __init__(self, db_path: str = "fileid.db"): self.initialized = asyncio.Event() self.watchers = {} self.watch_files_task = asyncio.create_task(self.watch_files()) + self.stop_watching_files = asyncio.Event() + self.stopped_watching_files = asyncio.Event() self.lock = asyncio.Lock() async def get_id(self, path: str) -> Optional[str]: @@ -96,7 +98,7 @@ async def watch_files(self): await db.commit() self.initialized.set() - async for changes in awatch("."): + async for changes in awatch(".", stop_event=self.stop_watching_files): async with self.lock: async with aiosqlite.connect(self.db_path) as db: deleted_paths = [] @@ -156,6 +158,8 @@ async def watch_files(self): for watcher in self.watchers.get(changed_path, []): watcher.notify(change) + self.stopped_watching_files.set() + def watch(self, path: str) -> Watcher: watcher = Watcher(path) self.watchers.setdefault(path, []).append(watcher) diff --git a/plugins/yjs/fps_yjs/main.py b/plugins/yjs/fps_yjs/main.py index 57d567aa..348275d6 100644 --- a/plugins/yjs/fps_yjs/main.py +++ b/plugins/yjs/fps_yjs/main.py @@ -25,4 +25,5 @@ async def start( yield - contents.file_id_manager.watch_files_task.cancel() + contents.file_id_manager.stop_watching_files.set() + await contents.file_id_manager.stopped_watching_files.wait() diff --git a/tests/test_server.py b/tests/test_server.py index fbd09bf1..fede89d2 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -84,7 +84,7 @@ async def test_rest_api(start_jupyverse): ydoc = Y.YDoc() WebsocketProvider(ydoc, websocket) # wait for file to be loaded and Y model to be created in server and client - await asyncio.sleep(0.5) + await asyncio.sleep(1) # execute notebook for cell_idx in range(3): response = requests.post( @@ -98,7 +98,7 @@ async def test_rest_api(start_jupyverse): ) print(f"{url}/api/kernels/{kernel_id}/execute", response.json()) # wait for Y model to be updated - await asyncio.sleep(0.5) + await asyncio.sleep(1) # retrieve cells cells = json.loads(ydoc.get_array("cells").to_json()) assert cells[0]["outputs"] == [