forked from cvat-ai/cvat
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
PyTorch adapter: add a way to disable cache updates (cvat-ai#5549)
This will let users to run their PyTorch code without network access, provided that they have already cached the data. ### How has this been tested? <!-- Please describe in detail how you tested your changes. Include details of your testing environment, and the tests you ran to see how your change affects other areas of the code, etc. --> Unit tests.
- Loading branch information
Showing
7 changed files
with
317 additions
and
111 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
# Copyright (C) 2023 CVAT.ai Corporation | ||
# | ||
# SPDX-License-Identifier: MIT | ||
|
||
import base64 | ||
import json | ||
import shutil | ||
from abc import ABCMeta, abstractmethod | ||
from enum import Enum, auto | ||
from pathlib import Path | ||
from typing import Callable, Mapping, Type, TypeVar | ||
|
||
import cvat_sdk.models as models | ||
from cvat_sdk.api_client.model_utils import OpenApiModel, to_json | ||
from cvat_sdk.core.client import Client | ||
from cvat_sdk.core.proxies.projects import Project | ||
from cvat_sdk.core.proxies.tasks import Task | ||
from cvat_sdk.core.utils import atomic_writer | ||
|
||
|
||
class UpdatePolicy(Enum): | ||
""" | ||
Defines policies for when the local cache is updated from the CVAT server. | ||
""" | ||
|
||
IF_MISSING_OR_STALE = auto() | ||
""" | ||
Update the cache whenever cached data is missing or the server has a newer version. | ||
""" | ||
|
||
NEVER = auto() | ||
""" | ||
Never update the cache. If an operation requires data that is not cached, | ||
it will fail. | ||
No network access will be performed if this policy is used. | ||
""" | ||
|
||
|
||
_ModelType = TypeVar("_ModelType", bound=OpenApiModel) | ||
|
||
|
||
class CacheManager(metaclass=ABCMeta): | ||
def __init__(self, client: Client) -> None: | ||
self._client = client | ||
self._logger = client.logger | ||
|
||
self._server_dir = client.config.cache_dir / f"servers/{self.server_dir_name}" | ||
|
||
@property | ||
def server_dir_name(self) -> str: | ||
# Base64-encode the name to avoid FS-unsafe characters (like slashes) | ||
return base64.urlsafe_b64encode(self._client.api_map.host.encode()).rstrip(b"=").decode() | ||
|
||
def task_dir(self, task_id: int) -> Path: | ||
return self._server_dir / f"tasks/{task_id}" | ||
|
||
def task_json_path(self, task_id: int) -> Path: | ||
return self.task_dir(task_id) / "task.json" | ||
|
||
def chunk_dir(self, task_id: int) -> Path: | ||
return self.task_dir(task_id) / "chunks" | ||
|
||
def project_dir(self, project_id: int) -> Path: | ||
return self._server_dir / f"projects/{project_id}" | ||
|
||
def project_json_path(self, project_id: int) -> Path: | ||
return self.project_dir(project_id) / "project.json" | ||
|
||
def load_model(self, path: Path, model_type: Type[_ModelType]) -> _ModelType: | ||
with open(path, "rb") as f: | ||
return model_type._new_from_openapi_data(**json.load(f)) | ||
|
||
def save_model(self, path: Path, model: OpenApiModel) -> None: | ||
with atomic_writer(path, "w", encoding="UTF-8") as f: | ||
json.dump(to_json(model), f, indent=4) | ||
print(file=f) # add final newline | ||
|
||
@abstractmethod | ||
def retrieve_task(self, task_id: int) -> Task: | ||
... | ||
|
||
@abstractmethod | ||
def ensure_task_model( | ||
self, | ||
task_id: int, | ||
filename: str, | ||
model_type: Type[_ModelType], | ||
downloader: Callable[[], _ModelType], | ||
model_description: str, | ||
) -> _ModelType: | ||
... | ||
|
||
@abstractmethod | ||
def ensure_chunk(self, task: Task, chunk_index: int) -> None: | ||
... | ||
|
||
@abstractmethod | ||
def retrieve_project(self, project_id: int) -> Project: | ||
... | ||
|
||
|
||
class _CacheManagerOnline(CacheManager): | ||
def retrieve_task(self, task_id: int) -> Task: | ||
self._logger.info(f"Fetching task {task_id}...") | ||
task = self._client.tasks.retrieve(task_id) | ||
|
||
self._initialize_task_dir(task) | ||
return task | ||
|
||
def _initialize_task_dir(self, task: Task) -> None: | ||
task_dir = self.task_dir(task.id) | ||
task_json_path = self.task_json_path(task.id) | ||
|
||
try: | ||
saved_task = self.load_model(task_json_path, models.TaskRead) | ||
except Exception: | ||
self._logger.info(f"Task {task.id} is not yet cached or the cache is corrupted") | ||
|
||
# If the cache was corrupted, the directory might already be there; clear it. | ||
if task_dir.exists(): | ||
shutil.rmtree(task_dir) | ||
else: | ||
if saved_task.updated_date < task.updated_date: | ||
self._logger.info( | ||
f"Task {task.id} has been updated on the server since it was cached; purging the cache" | ||
) | ||
shutil.rmtree(task_dir) | ||
|
||
task_dir.mkdir(exist_ok=True, parents=True) | ||
self.save_model(task_json_path, task._model) | ||
|
||
def ensure_task_model( | ||
self, | ||
task_id: int, | ||
filename: str, | ||
model_type: Type[_ModelType], | ||
downloader: Callable[[], _ModelType], | ||
model_description: str, | ||
) -> _ModelType: | ||
path = self.task_dir(task_id) / filename | ||
|
||
try: | ||
model = self.load_model(path, model_type) | ||
self._logger.info(f"Loaded {model_description} from cache") | ||
return model | ||
except FileNotFoundError: | ||
pass | ||
except Exception: | ||
self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True) | ||
|
||
self._logger.info(f"Downloading {model_description}...") | ||
model = downloader() | ||
self._logger.info(f"Downloaded {model_description}") | ||
|
||
self.save_model(path, model) | ||
|
||
return model | ||
|
||
def ensure_chunk(self, task: Task, chunk_index: int) -> None: | ||
chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip" | ||
if chunk_path.exists(): | ||
return # already downloaded previously | ||
|
||
self._logger.info(f"Downloading chunk #{chunk_index}...") | ||
|
||
with atomic_writer(chunk_path, "wb") as chunk_file: | ||
task.download_chunk(chunk_index, chunk_file, quality="original") | ||
|
||
def retrieve_project(self, project_id: int) -> Project: | ||
self._logger.info(f"Fetching project {project_id}...") | ||
project = self._client.projects.retrieve(project_id) | ||
|
||
project_dir = self.project_dir(project_id) | ||
project_dir.mkdir(parents=True, exist_ok=True) | ||
project_json_path = self.project_json_path(project_id) | ||
|
||
# There are currently no files cached alongside project.json, | ||
# so we don't need to check if we need to purge them. | ||
|
||
self.save_model(project_json_path, project._model) | ||
|
||
return project | ||
|
||
|
||
class _CacheManagerOffline(CacheManager): | ||
def retrieve_task(self, task_id: int) -> Task: | ||
self._logger.info(f"Retrieving task {task_id} from cache...") | ||
return Task(self._client, self.load_model(self.task_json_path(task_id), models.TaskRead)) | ||
|
||
def ensure_task_model( | ||
self, | ||
task_id: int, | ||
filename: str, | ||
model_type: Type[_ModelType], | ||
downloader: Callable[[], _ModelType], | ||
model_description: str, | ||
) -> _ModelType: | ||
self._logger.info(f"Loading {model_description} from cache...") | ||
return self.load_model(self.task_dir(task_id) / filename, model_type) | ||
|
||
def ensure_chunk(self, task: Task, chunk_index: int) -> None: | ||
chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip" | ||
|
||
if not chunk_path.exists(): | ||
raise FileNotFoundError(f"Chunk {chunk_index} of task {task.id} is not cached") | ||
|
||
def retrieve_project(self, project_id: int) -> Project: | ||
self._logger.info(f"Retrieving project {project_id} from cache...") | ||
return Project( | ||
self._client, self.load_model(self.project_json_path(project_id), models.ProjectRead) | ||
) | ||
|
||
|
||
_CACHE_MANAGER_CLASSES: Mapping[UpdatePolicy, Type[CacheManager]] = { | ||
UpdatePolicy.IF_MISSING_OR_STALE: _CacheManagerOnline, | ||
UpdatePolicy.NEVER: _CacheManagerOffline, | ||
} | ||
|
||
|
||
def make_cache_manager(client: Client, update_policy: UpdatePolicy) -> CacheManager: | ||
return _CACHE_MANAGER_CLASSES[update_policy](client) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.