From 33c624ae1e9f5553302a6b6a086f4ed1314298fc Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Fri, 6 Jan 2023 23:28:57 +0300 Subject: [PATCH] PyTorch adapter: add a way to disable cache updates (#5549) This will let users run their PyTorch code without network access, provided that they have already cached the data. ### How has this been tested? Unit tests. --- CHANGELOG.md | 2 + cvat-sdk/cvat_sdk/pytorch/__init__.py | 1 + cvat-sdk/cvat_sdk/pytorch/caching.py | 222 +++++++++++++++++++ cvat-sdk/cvat_sdk/pytorch/common.py | 8 - cvat-sdk/cvat_sdk/pytorch/project_dataset.py | 27 +-- cvat-sdk/cvat_sdk/pytorch/task_dataset.py | 118 +++------- tests/python/sdk/test_pytorch.py | 50 +++++ 7 files changed, 317 insertions(+), 111 deletions(-) create mode 100644 cvat-sdk/cvat_sdk/pytorch/caching.py diff --git a/CHANGELOG.md b/CHANGELOG.md index cb76509d893..0fda7aa1a74 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 () - \[SDK\] Class to represent a project as a PyTorch dataset () +- \[SDK\] A PyTorch adapter setting to disable cache updates + () ### Changed - The Docker Compose files now use the Compose Specification version diff --git a/cvat-sdk/cvat_sdk/pytorch/__init__.py b/cvat-sdk/cvat_sdk/pytorch/__init__.py index cff084c1a98..ba6609b268a 100644 --- a/cvat-sdk/cvat_sdk/pytorch/__init__.py +++ b/cvat-sdk/cvat_sdk/pytorch/__init__.py @@ -2,6 +2,7 @@ # # SPDX-License-Identifier: MIT +from .caching import UpdatePolicy from .common import FrameAnnotations, Target, UnsupportedDatasetError from .project_dataset import ProjectVisionDataset from .task_dataset import TaskVisionDataset diff --git a/cvat-sdk/cvat_sdk/pytorch/caching.py b/cvat-sdk/cvat_sdk/pytorch/caching.py new file mode 100644 index 00000000000..47f46d759e9 --- /dev/null +++ b/cvat-sdk/cvat_sdk/pytorch/caching.py @@ -0,0 +1,222 @@ +# Copyright (C) 2023 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + 
+import base64 +import json +import shutil +from abc import ABCMeta, abstractmethod +from enum import Enum, auto +from pathlib import Path +from typing import Callable, Mapping, Type, TypeVar + +import cvat_sdk.models as models +from cvat_sdk.api_client.model_utils import OpenApiModel, to_json +from cvat_sdk.core.client import Client +from cvat_sdk.core.proxies.projects import Project +from cvat_sdk.core.proxies.tasks import Task +from cvat_sdk.core.utils import atomic_writer + + +class UpdatePolicy(Enum): + """ + Defines policies for when the local cache is updated from the CVAT server. + """ + + IF_MISSING_OR_STALE = auto() + """ + Update the cache whenever cached data is missing or the server has a newer version. + """ + + NEVER = auto() + """ + Never update the cache. If an operation requires data that is not cached, + it will fail. + + No network access will be performed if this policy is used. + """ + + +_ModelType = TypeVar("_ModelType", bound=OpenApiModel) + + +class CacheManager(metaclass=ABCMeta): + def __init__(self, client: Client) -> None: + self._client = client + self._logger = client.logger + + self._server_dir = client.config.cache_dir / f"servers/{self.server_dir_name}" + + @property + def server_dir_name(self) -> str: + # Base64-encode the name to avoid FS-unsafe characters (like slashes) + return base64.urlsafe_b64encode(self._client.api_map.host.encode()).rstrip(b"=").decode() + + def task_dir(self, task_id: int) -> Path: + return self._server_dir / f"tasks/{task_id}" + + def task_json_path(self, task_id: int) -> Path: + return self.task_dir(task_id) / "task.json" + + def chunk_dir(self, task_id: int) -> Path: + return self.task_dir(task_id) / "chunks" + + def project_dir(self, project_id: int) -> Path: + return self._server_dir / f"projects/{project_id}" + + def project_json_path(self, project_id: int) -> Path: + return self.project_dir(project_id) / "project.json" + + def load_model(self, path: Path, model_type: Type[_ModelType]) -> 
_ModelType: + with open(path, "rb") as f: + return model_type._new_from_openapi_data(**json.load(f)) + + def save_model(self, path: Path, model: OpenApiModel) -> None: + with atomic_writer(path, "w", encoding="UTF-8") as f: + json.dump(to_json(model), f, indent=4) + print(file=f) # add final newline + + @abstractmethod + def retrieve_task(self, task_id: int) -> Task: + ... + + @abstractmethod + def ensure_task_model( + self, + task_id: int, + filename: str, + model_type: Type[_ModelType], + downloader: Callable[[], _ModelType], + model_description: str, + ) -> _ModelType: + ... + + @abstractmethod + def ensure_chunk(self, task: Task, chunk_index: int) -> None: + ... + + @abstractmethod + def retrieve_project(self, project_id: int) -> Project: + ... + + +class _CacheManagerOnline(CacheManager): + def retrieve_task(self, task_id: int) -> Task: + self._logger.info(f"Fetching task {task_id}...") + task = self._client.tasks.retrieve(task_id) + + self._initialize_task_dir(task) + return task + + def _initialize_task_dir(self, task: Task) -> None: + task_dir = self.task_dir(task.id) + task_json_path = self.task_json_path(task.id) + + try: + saved_task = self.load_model(task_json_path, models.TaskRead) + except Exception: + self._logger.info(f"Task {task.id} is not yet cached or the cache is corrupted") + + # If the cache was corrupted, the directory might already be there; clear it. 
+ if task_dir.exists(): + shutil.rmtree(task_dir) + else: + if saved_task.updated_date < task.updated_date: + self._logger.info( + f"Task {task.id} has been updated on the server since it was cached; purging the cache" + ) + shutil.rmtree(task_dir) + + task_dir.mkdir(exist_ok=True, parents=True) + self.save_model(task_json_path, task._model) + + def ensure_task_model( + self, + task_id: int, + filename: str, + model_type: Type[_ModelType], + downloader: Callable[[], _ModelType], + model_description: str, + ) -> _ModelType: + path = self.task_dir(task_id) / filename + + try: + model = self.load_model(path, model_type) + self._logger.info(f"Loaded {model_description} from cache") + return model + except FileNotFoundError: + pass + except Exception: + self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True) + + self._logger.info(f"Downloading {model_description}...") + model = downloader() + self._logger.info(f"Downloaded {model_description}") + + self.save_model(path, model) + + return model + + def ensure_chunk(self, task: Task, chunk_index: int) -> None: + chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip" + if chunk_path.exists(): + return # already downloaded previously + + self._logger.info(f"Downloading chunk #{chunk_index}...") + + with atomic_writer(chunk_path, "wb") as chunk_file: + task.download_chunk(chunk_index, chunk_file, quality="original") + + def retrieve_project(self, project_id: int) -> Project: + self._logger.info(f"Fetching project {project_id}...") + project = self._client.projects.retrieve(project_id) + + project_dir = self.project_dir(project_id) + project_dir.mkdir(parents=True, exist_ok=True) + project_json_path = self.project_json_path(project_id) + + # There are currently no files cached alongside project.json, + # so we don't need to check if we need to purge them. 
+ + self.save_model(project_json_path, project._model) + + return project + + +class _CacheManagerOffline(CacheManager): + def retrieve_task(self, task_id: int) -> Task: + self._logger.info(f"Retrieving task {task_id} from cache...") + return Task(self._client, self.load_model(self.task_json_path(task_id), models.TaskRead)) + + def ensure_task_model( + self, + task_id: int, + filename: str, + model_type: Type[_ModelType], + downloader: Callable[[], _ModelType], + model_description: str, + ) -> _ModelType: + self._logger.info(f"Loading {model_description} from cache...") + return self.load_model(self.task_dir(task_id) / filename, model_type) + + def ensure_chunk(self, task: Task, chunk_index: int) -> None: + chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip" + + if not chunk_path.exists(): + raise FileNotFoundError(f"Chunk {chunk_index} of task {task.id} is not cached") + + def retrieve_project(self, project_id: int) -> Project: + self._logger.info(f"Retrieving project {project_id} from cache...") + return Project( + self._client, self.load_model(self.project_json_path(project_id), models.ProjectRead) + ) + + +_CACHE_MANAGER_CLASSES: Mapping[UpdatePolicy, Type[CacheManager]] = { + UpdatePolicy.IF_MISSING_OR_STALE: _CacheManagerOnline, + UpdatePolicy.NEVER: _CacheManagerOffline, +} + + +def make_cache_manager(client: Client, update_policy: UpdatePolicy) -> CacheManager: + return _CACHE_MANAGER_CLASSES[update_policy](client) diff --git a/cvat-sdk/cvat_sdk/pytorch/common.py b/cvat-sdk/cvat_sdk/pytorch/common.py index 3002e0c428a..ac5d8fb7ad9 100644 --- a/cvat-sdk/cvat_sdk/pytorch/common.py +++ b/cvat-sdk/cvat_sdk/pytorch/common.py @@ -2,8 +2,6 @@ # # SPDX-License-Identifier: MIT -import base64 -from pathlib import Path from typing import List, Mapping import attrs @@ -42,9 +40,3 @@ class Target: A mapping from label_id values in `LabeledImage` and `LabeledShape` objects to an integer index. This mapping is consistent across all samples for a given task. 
""" - - -def get_server_cache_dir(client: cvat_sdk.core.Client) -> Path: - # Base64-encode the name to avoid FS-unsafe characters (like slashes) - server_dir_name = base64.urlsafe_b64encode(client.api_map.host.encode()).rstrip(b"=").decode() - return client.config.cache_dir / f"servers/{server_dir_name}" diff --git a/cvat-sdk/cvat_sdk/pytorch/project_dataset.py b/cvat-sdk/cvat_sdk/pytorch/project_dataset.py index 3b4f9f4dc3d..421a17ff642 100644 --- a/cvat-sdk/cvat_sdk/pytorch/project_dataset.py +++ b/cvat-sdk/cvat_sdk/pytorch/project_dataset.py @@ -12,7 +12,7 @@ import cvat_sdk.core import cvat_sdk.core.exceptions import cvat_sdk.models as models -from cvat_sdk.pytorch.common import get_server_cache_dir +from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager from cvat_sdk.pytorch.task_dataset import TaskVisionDataset @@ -42,6 +42,7 @@ def __init__( label_name_to_index: Mapping[str, int] = None, task_filter: Optional[Callable[[models.ITaskRead], bool]] = None, include_subsets: Optional[Container[str]] = None, + update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE, ) -> None: """ Creates a dataset corresponding to the project with ID `project_id` on the @@ -61,29 +62,24 @@ def __init__( * If `include_subsets` is set to a container, then tasks whose subset is not a member of this container will be excluded. + + `update_policy` determines when and if the local cache will be updated. """ self._logger = client.logger - self._logger.info(f"Fetching project {project_id}...") - project = client.projects.retrieve(project_id) - - # We don't actually need to save anything to this directory (yet), - # but VisionDataset.__init__ requires a root, so make one. - # It could be useful in the future to store the project data for - # offline-only mode. 
- project_dir = get_server_cache_dir(client) / f"projects/{project_id}" - project_dir.mkdir(parents=True, exist_ok=True) + cache_manager = make_cache_manager(client, update_policy) + project = cache_manager.retrieve_project(project_id) super().__init__( - os.fspath(project_dir), + os.fspath(cache_manager.project_dir(project_id)), transforms=transforms, transform=transform, target_transform=target_transform, ) self._logger.info("Fetching project tasks...") - tasks = project.get_tasks() + tasks = [cache_manager.retrieve_task(task_id) for task_id in project.tasks] if task_filter is not None: tasks = list(filter(task_filter, tasks)) @@ -95,7 +91,12 @@ def __init__( self._underlying = torch.utils.data.ConcatDataset( [ - TaskVisionDataset(client, task.id, label_name_to_index=label_name_to_index) + TaskVisionDataset( + client, + task.id, + label_name_to_index=label_name_to_index, + update_policy=update_policy, + ) for task in tasks ] ) diff --git a/cvat-sdk/cvat_sdk/pytorch/task_dataset.py b/cvat-sdk/cvat_sdk/pytorch/task_dataset.py index bcda76f6db9..aecd6b74bea 100644 --- a/cvat-sdk/cvat_sdk/pytorch/task_dataset.py +++ b/cvat-sdk/cvat_sdk/pytorch/task_dataset.py @@ -3,13 +3,11 @@ # SPDX-License-Identifier: MIT import collections -import json import os -import shutil import types import zipfile from concurrent.futures import ThreadPoolExecutor -from typing import Callable, Dict, Mapping, Optional, Type, TypeVar +from typing import Callable, Dict, Mapping, Optional import PIL.Image import torchvision.datasets @@ -17,16 +15,8 @@ import cvat_sdk.core import cvat_sdk.core.exceptions import cvat_sdk.models as models -from cvat_sdk.api_client.model_utils import to_json -from cvat_sdk.core.utils import atomic_writer -from cvat_sdk.pytorch.common import ( - FrameAnnotations, - Target, - UnsupportedDatasetError, - get_server_cache_dir, -) - -_ModelType = TypeVar("_ModelType") +from cvat_sdk.pytorch.caching import UpdatePolicy, make_cache_manager +from cvat_sdk.pytorch.common 
import FrameAnnotations, Target, UnsupportedDatasetError _NUM_DOWNLOAD_THREADS = 4 @@ -44,7 +34,7 @@ class TaskVisionDataset(torchvision.datasets.VisionDataset): * target is a `Target` object containing annotations for the frame. This class caches all data and annotations for the task on the local file system - during construction. If the task is updated on the server, the cache is updated. + during construction. Limitations: @@ -61,6 +51,7 @@ def __init__( transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, label_name_to_index: Mapping[str, int] = None, + update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE, ) -> None: """ Creates a dataset corresponding to the task with ID `task_id` on the @@ -80,12 +71,14 @@ def __init__( will map each label ID to a distinct integer in the range [0, `num_labels`), where `num_labels` is the number of labels defined in the task. This mapping will be generally unpredictable, but consistent for a given task. + + `update_policy` determines when and if the local cache will be updated. 
""" self._logger = client.logger - self._logger.info(f"Fetching task {task_id}...") - self._task = client.tasks.retrieve(task_id) + cache_manager = make_cache_manager(client, update_policy) + self._task = cache_manager.retrieve_task(task_id) if not self._task.size or not self._task.data_chunk_size: raise UnsupportedDatasetError("The task has no data") @@ -96,18 +89,19 @@ def __init__( f" current chunk type is {self._task.data_original_chunk_type!r}" ) - self._task_dir = get_server_cache_dir(client) / f"tasks/{self._task.id}" - self._initialize_task_dir() - super().__init__( - os.fspath(self._task_dir), + os.fspath(cache_manager.task_dir(self._task.id)), transforms=transforms, transform=transform, target_transform=target_transform, ) - data_meta = self._ensure_model( - "data_meta.json", models.DataMetaRead, self._task.get_meta, "data metadata" + data_meta = cache_manager.ensure_task_model( + self._task.id, + "data_meta.json", + models.DataMetaRead, + self._task.get_meta, + "data metadata", ) self._active_frame_indexes = sorted( set(range(self._task.size)) - set(data_meta.deleted_frames) @@ -115,7 +109,7 @@ def __init__( self._logger.info("Downloading chunks...") - self._chunk_dir = self._task_dir / "chunks" + self._chunk_dir = cache_manager.chunk_dir(task_id) self._chunk_dir.mkdir(exist_ok=True, parents=True) needed_chunks = { @@ -123,7 +117,11 @@ def __init__( } with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool: - for _ in pool.map(self._ensure_chunk, sorted(needed_chunks)): + + def ensure_chunk(chunk_index): + cache_manager.ensure_chunk(self._task, chunk_index) + + for _ in pool.map(ensure_chunk, sorted(needed_chunks)): # just need to loop through all results so that any exceptions are propagated pass @@ -143,8 +141,12 @@ def __init__( {label.id: label_name_to_index[label.name] for label in self._task.labels} ) - annotations = self._ensure_model( - "annotations.json", models.LabeledData, self._task.get_annotations, "annotations" + annotations = 
cache_manager.ensure_task_model( + self._task.id, + "annotations.json", + models.LabeledData, + self._task.get_annotations, + "annotations", ) self._frame_annotations: Dict[int, FrameAnnotations] = collections.defaultdict( @@ -159,70 +161,6 @@ def __init__( # TODO: tracks? - def _initialize_task_dir(self) -> None: - task_json_path = self._task_dir / "task.json" - - try: - with open(task_json_path, "rb") as task_json_file: - saved_task = models.TaskRead._new_from_openapi_data(**json.load(task_json_file)) - except Exception: - self._logger.info("Task is not yet cached or the cache is corrupted") - - # If the cache was corrupted, the directory might already be there; clear it. - if self._task_dir.exists(): - shutil.rmtree(self._task_dir) - else: - if saved_task.updated_date < self._task.updated_date: - self._logger.info( - "Task has been updated on the server since it was cached; purging the cache" - ) - shutil.rmtree(self._task_dir) - - self._task_dir.mkdir(exist_ok=True, parents=True) - - with atomic_writer(task_json_path, "w", encoding="UTF-8") as task_json_file: - json.dump(to_json(self._task._model), task_json_file, indent=4) - print(file=task_json_file) # add final newline - - def _ensure_chunk(self, chunk_index: int) -> None: - chunk_path = self._chunk_dir / f"{chunk_index}.zip" - if chunk_path.exists(): - return # already downloaded previously - - self._logger.info(f"Downloading chunk #{chunk_index}...") - - with atomic_writer(chunk_path, "wb") as chunk_file: - self._task.download_chunk(chunk_index, chunk_file, quality="original") - - def _ensure_model( - self, - filename: str, - model_type: Type[_ModelType], - download: Callable[[], _ModelType], - model_description: str, - ) -> _ModelType: - path = self._task_dir / filename - - try: - with open(path, "rb") as f: - model = model_type._new_from_openapi_data(**json.load(f)) - self._logger.info(f"Loaded {model_description} from cache") - return model - except FileNotFoundError: - pass - except Exception: - 
self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True) - - self._logger.info(f"Downloading {model_description}...") - model = download() - self._logger.info(f"Downloaded {model_description}") - - with atomic_writer(path, "w", encoding="UTF-8") as f: - json.dump(to_json(model), f, indent=4) - print(file=f) # add final newline - - return model - def __getitem__(self, sample_index: int): """ Returns the sample with index `sample_index`. diff --git a/tests/python/sdk/test_pytorch.py b/tests/python/sdk/test_pytorch.py index c6e12ba5688..90c22657192 100644 --- a/tests/python/sdk/test_pytorch.py +++ b/tests/python/sdk/test_pytorch.py @@ -11,6 +11,7 @@ import pytest from cvat_sdk import Client, models +from cvat_sdk.api_client.rest import RESTClientObject from cvat_sdk.core.proxies.tasks import ResourceType try: @@ -42,6 +43,13 @@ def _common_setup( api_client.configuration.logger[k] = logger +def _disable_api_requests(monkeypatch: pytest.MonkeyPatch) -> None: + def disabled_request(*args, **kwargs): + raise RuntimeError("Disabled!") + + monkeypatch.setattr(RESTClientObject, "request", disabled_request) + + @pytest.mark.skipif(cvatpt is None, reason="PyTorch dependencies are not installed") class TestTaskVisionDataset: @pytest.fixture(autouse=True) @@ -226,6 +234,27 @@ def test_custom_label_mapping(self): assert target.label_id_to_index[label_name_to_id["person"]] == 123 assert target.label_id_to_index[label_name_to_id["car"]] == 456 + def test_offline(self, monkeypatch: pytest.MonkeyPatch): + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + update_policy=cvatpt.UpdatePolicy.IF_MISSING_OR_STALE, + ) + + fresh_samples = list(dataset) + + _disable_api_requests(monkeypatch) + + dataset = cvatpt.TaskVisionDataset( + self.client, + self.task.id, + update_policy=cvatpt.UpdatePolicy.NEVER, + ) + + cached_samples = list(dataset) + + assert fresh_samples == cached_samples + @pytest.mark.skipif(cvatpt is None, reason="PyTorch 
dependencies are not installed") class TestProjectVisionDataset: @@ -359,3 +388,24 @@ def test_combined_transforms(self): assert isinstance(dataset[0][0], cvatpt.Target) assert isinstance(dataset[0][1], PIL.Image.Image) + + def test_offline(self, monkeypatch: pytest.MonkeyPatch): + dataset = cvatpt.ProjectVisionDataset( + self.client, + self.project.id, + update_policy=cvatpt.UpdatePolicy.IF_MISSING_OR_STALE, + ) + + fresh_samples = list(dataset) + + _disable_api_requests(monkeypatch) + + dataset = cvatpt.ProjectVisionDataset( + self.client, + self.project.id, + update_policy=cvatpt.UpdatePolicy.NEVER, + ) + + cached_samples = list(dataset) + + assert fresh_samples == cached_samples