cvat/cvat-sdk/cvat_sdk/datasets/caching.py

339 lines
11 KiB
Python

# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import base64
import json
import shutil
from abc import ABCMeta, abstractmethod
from collections.abc import Mapping
from enum import Enum, auto
from pathlib import Path
from typing import Any, Callable, TypeVar, Union, cast
from attrs import define
import cvat_sdk.models as models
from cvat_sdk.api_client.model_utils import OpenApiModel, to_json
from cvat_sdk.core.client import Client
from cvat_sdk.core.proxies.projects import Project
from cvat_sdk.core.proxies.tasks import Task
from cvat_sdk.core.utils import atomic_writer
class UpdatePolicy(Enum):
    """
    Strategies that decide when the local cache is refreshed from the CVAT server.
    """

    IF_MISSING_OR_STALE = auto()
    """
    Refresh the cache when the requested data is absent locally
    or when the server holds a newer version of it.
    """

    NEVER = auto()
    """
    Leave the cache untouched. An operation that needs data which is not cached fails.
    This policy performs no network access.
    """
# JSON-compatible dictionary form in which cached objects are read from
# and written to disk.
_CacheObject = dict[str, Any]
class _CacheObjectModel(metaclass=ABCMeta):
    """
    Interface for models that know how to convert themselves
    to and from a cache object (a JSON-compatible dictionary).
    """

    @abstractmethod
    def dump(self) -> _CacheObject:
        """Return the cache object representing this model."""

    @classmethod
    @abstractmethod
    def load(cls, obj: _CacheObject):
        """Build an instance of this class from the cache object `obj`."""
# Type variable for the models that can be cached: either a generated
# OpenAPI model or an implementation of _CacheObjectModel.
_ModelType = TypeVar("_ModelType", bound=Union[OpenApiModel, _CacheObjectModel])
class CacheManager(metaclass=ABCMeta):
    """
    Base class for objects that manage the on-disk cache of a single CVAT server.

    The class resolves the location of cached files under
    `client.config.cache_dir / "servers" / <encoded host>` and converts models
    to and from JSON files. How tasks, projects, task models and chunks are
    obtained is left to subclasses.
    """

    def __init__(self, client: Client) -> None:
        """
        Remember `client` and its logger, and compute the cache directory
        dedicated to the server the client is connected to.
        """
        self._client = client
        self._logger = client.logger

        # server_dir_name reads self._client, so _client must be assigned first.
        self._server_dir = client.config.cache_dir / f"servers/{self.server_dir_name}"

    @property
    def server_dir_name(self) -> str:
        """Name of the cache subdirectory for the client's server, derived from the host."""
        # Base64-encode the name to avoid FS-unsafe characters (like slashes).
        # The "=" padding is stripped from the encoded name.
        return base64.urlsafe_b64encode(self._client.api_map.host.encode()).rstrip(b"=").decode()

    def task_dir(self, task_id: int) -> Path:
        """Return the directory holding everything cached for task `task_id`."""
        return self._server_dir / f"tasks/{task_id}"

    def task_json_path(self, task_id: int) -> Path:
        """Return the path of the `task.json` file of task `task_id`."""
        return self.task_dir(task_id) / "task.json"

    def chunk_dir(self, task_id: int) -> Path:
        """Return the directory holding the cached chunks of task `task_id`."""
        return self.task_dir(task_id) / "chunks"

    def project_dir(self, project_id: int) -> Path:
        """Return the directory holding everything cached for project `project_id`."""
        return self._server_dir / f"projects/{project_id}"

    def project_json_path(self, project_id: int) -> Path:
        """Return the path of the `project.json` file of project `project_id`."""
        return self.project_dir(project_id) / "project.json"

    def _load_object(self, path: Path) -> _CacheObject:
        """
        Parse the JSON file at `path` and return its contents.

        Errors from opening the file (such as `FileNotFoundError`)
        and from JSON parsing propagate to the caller.
        """
        with open(path, "rb") as f:
            return json.load(f)

    def _save_object(self, path: Path, obj: _CacheObject) -> None:
        """
        Write `obj` to `path` as UTF-8 JSON indented by 4 spaces.

        The file is written through `atomic_writer`.
        """
        with atomic_writer(path, "w", encoding="UTF-8") as f:
            json.dump(obj, f, indent=4)
            print(file=f)  # add final newline

    def _deserialize_model(self, obj: _CacheObject, model_type: type[_ModelType]) -> _ModelType:
        """
        Build an instance of `model_type` from the cache object `obj`.

        `model_type` is a class, not an instance: OpenAPI model classes are
        constructed from the object's items as keyword arguments, while
        `_CacheObjectModel` subclasses are constructed through their `load` method.

        Raises `NotImplementedError` if `model_type` is of neither kind.
        """
        if issubclass(model_type, OpenApiModel):
            return cast(type[OpenApiModel], model_type)._new_from_openapi_data(**obj)
        elif issubclass(model_type, _CacheObjectModel):
            return cast(type[_CacheObjectModel], model_type).load(obj)
        else:
            raise NotImplementedError("Unexpected model type")

    def _serialize_model(self, model: _ModelType) -> _CacheObject:
        """
        Convert `model` into a cache object.

        OpenAPI models are converted with `to_json`;
        `_CacheObjectModel` instances are converted with their `dump` method.

        Raises `NotImplementedError` if `model` is of neither kind.
        """
        if isinstance(model, OpenApiModel):
            return to_json(model)
        elif isinstance(model, _CacheObjectModel):
            return model.dump()
        else:
            raise NotImplementedError("Unexpected model type")

    def load_model(self, path: Path, model_type: type[_ModelType]) -> _ModelType:
        """Read the JSON file at `path` and return it as an instance of `model_type`."""
        return self._deserialize_model(self._load_object(path), model_type)

    def save_model(self, path: Path, model: _ModelType) -> None:
        """Serialize `model` and write it to the JSON file at `path`."""
        self._save_object(path, self._serialize_model(model))

    @abstractmethod
    def retrieve_task(self, task_id: int) -> Task:
        """Return the task with the ID `task_id`."""

    @abstractmethod
    def ensure_task_model(
        self,
        task_id: int,
        filename: str,
        model_type: type[_ModelType],
        downloader: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        """
        Return the model of type `model_type` that is cached in the file `filename`
        inside the directory of task `task_id`.

        `downloader` is a callable that produces the model;
        `model_description` is a human-readable name used in log messages.
        """

    @abstractmethod
    def ensure_chunk(self, task: Task, chunk_index: int) -> None:
        """Make sure that chunk `chunk_index` of `task` is present in the cache."""

    @abstractmethod
    def retrieve_project(self, project_id: int) -> Project:
        """Return the project with the ID `project_id`."""
class _CacheManagerOnline(CacheManager):
    """
    Cache manager that fetches data from the server and keeps the local cache
    in step with it: missing data is downloaded, and the cached files of a task
    are purged when the server reports a newer `updated_date`.
    """

    def retrieve_task(self, task_id: int) -> Task:
        """Fetch task `task_id` from the server and bring its cache directory up to date."""
        self._logger.info(f"Fetching task {task_id}...")
        task = self._client.tasks.retrieve(task_id)

        self._initialize_task_dir(task)
        return task

    def _initialize_task_dir(self, task: Task) -> None:
        """
        Prepare the cache directory of `task`.

        The directory is removed if the saved `task.json` cannot be loaded,
        or if it records an older `updated_date` than `task` has.
        Afterwards the directory is (re)created if needed
        and `task.json` is rewritten from `task`.
        """
        task_dir = self.task_dir(task.id)
        task_json_path = self.task_json_path(task.id)

        try:
            saved_task = self.load_model(task_json_path, _OfflineTaskModel)
        except Exception:
            self._logger.info(f"Task {task.id} is not yet cached or the cache is corrupted")

            # If the cache was corrupted, the directory might already be there; clear it.
            if task_dir.exists():
                shutil.rmtree(task_dir)
        else:
            if saved_task.api_model.updated_date < task.updated_date:
                self._logger.info(
                    f"Task {task.id} has been updated on the server since it was cached; purging the cache"
                )
                shutil.rmtree(task_dir)

        task_dir.mkdir(exist_ok=True, parents=True)
        self.save_model(task_json_path, _OfflineTaskModel.from_entity(task))

    def ensure_task_model(
        self,
        task_id: int,
        filename: str,
        model_type: type[_ModelType],
        downloader: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        """
        Return the model stored in `filename` inside the directory of task `task_id`.

        If the file does not exist or cannot be loaded, `downloader` is called
        to obtain the model, which is then saved to the file and returned.
        A load failure other than a missing file is logged as a warning.
        """
        path = self.task_dir(task_id) / filename

        # Only the load itself is guarded, so that a failure while logging
        # is not mistaken for a cache failure.
        try:
            model = self.load_model(path, model_type)
        except FileNotFoundError:
            pass
        except Exception:
            self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True)
        else:
            self._logger.info(f"Loaded {model_description} from cache")
            return model

        self._logger.info(f"Downloading {model_description}...")
        model = downloader()
        self._logger.info(f"Downloaded {model_description}")

        self.save_model(path, model)
        return model

    def ensure_chunk(self, task: Task, chunk_index: int) -> None:
        """
        Download chunk `chunk_index` of `task` in original quality,
        unless its file is already present in the cache.
        """
        chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip"
        if chunk_path.exists():
            return  # already downloaded previously

        self._logger.info(f"Downloading chunk #{chunk_index}...")

        # Nothing else in this class creates the chunk directory,
        # so make sure it exists before the chunk file is written into it.
        chunk_path.parent.mkdir(parents=True, exist_ok=True)

        with atomic_writer(chunk_path, "wb") as chunk_file:
            task.download_chunk(chunk_index, chunk_file, quality="original")

    def retrieve_project(self, project_id: int) -> Project:
        """Fetch project `project_id` from the server and save it to `project.json`."""
        self._logger.info(f"Fetching project {project_id}...")
        project = self._client.projects.retrieve(project_id)

        project_dir = self.project_dir(project_id)
        project_dir.mkdir(parents=True, exist_ok=True)
        project_json_path = self.project_json_path(project_id)

        # There are currently no files cached alongside project.json,
        # so we don't need to check if we need to purge them.

        self.save_model(project_json_path, _OfflineProjectModel.from_entity(project))

        return project
class _CacheManagerOffline(CacheManager):
    """
    Cache manager that answers every request from files already present
    in the local cache. Nothing is downloaded; a request for data
    that is not cached fails.
    """

    def retrieve_task(self, task_id: int) -> Task:
        """Return a proxy for task `task_id` built from its cached `task.json`."""
        self._logger.info(f"Retrieving task {task_id} from cache...")

        json_path = self.task_json_path(task_id)
        offline_model = self.load_model(json_path, _OfflineTaskModel)
        return _OfflineTaskProxy(self._client, offline_model, cache_manager=self)

    def ensure_task_model(
        self,
        task_id: int,
        filename: str,
        model_type: type[_ModelType],
        downloader: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        """
        Load the model stored in `filename` inside the directory of task `task_id`.

        `downloader` is never called; any error raised while loading propagates.
        """
        self._logger.info(f"Loading {model_description} from cache...")

        cached_path = self.task_dir(task_id) / filename
        return self.load_model(cached_path, model_type)

    def ensure_chunk(self, task: Task, chunk_index: int) -> None:
        """Raise `FileNotFoundError` unless chunk `chunk_index` of `task` is cached."""
        chunk_file = self.chunk_dir(task.id) / f"{chunk_index}.zip"
        if chunk_file.exists():
            return

        raise FileNotFoundError(f"Chunk {chunk_index} of task {task.id} is not cached")

    def retrieve_project(self, project_id: int) -> Project:
        """Return a proxy for project `project_id` built from its cached `project.json`."""
        self._logger.info(f"Retrieving project {project_id} from cache...")

        json_path = self.project_json_path(project_id)
        offline_model = self.load_model(json_path, _OfflineProjectModel)
        return _OfflineProjectProxy(self._client, offline_model, cache_manager=self)
@define
class _OfflineTaskModel(_CacheObjectModel):
    """Cacheable snapshot of a task: its API model together with its labels."""

    # The task as described by the server API.
    api_model: models.ITaskRead
    # The labels returned by the task's get_labels().
    labels: list[models.ILabel]

    def dump(self) -> _CacheObject:
        """Return the snapshot as a cache object with "model" and "labels" entries."""
        serialized: _CacheObject = {}
        serialized["model"] = to_json(self.api_model)
        serialized["labels"] = to_json(self.labels)
        return serialized

    @classmethod
    def load(cls, obj: _CacheObject):
        """Rebuild a snapshot from a cache object produced by `dump`."""
        task_model = models.TaskRead._from_openapi_data(**obj["model"])

        task_labels = []
        for label_data in obj["labels"]:
            task_labels.append(models.Label._from_openapi_data(**label_data))

        return cls(api_model=task_model, labels=task_labels)

    @classmethod
    def from_entity(cls, entity: Task):
        """Take a snapshot of the task `entity`."""
        task_model = entity._model
        task_labels = entity.get_labels()
        return cls(api_model=task_model, labels=task_labels)
class _OfflineTaskProxy(Task):
    """
    Task whose API model and labels come from a cached `_OfflineTaskModel`.
    """

    def __init__(
        self, client: Client, cached_model: _OfflineTaskModel, *, cache_manager: CacheManager
    ) -> None:
        """
        Initialize the task from `cached_model.api_model` and keep references
        to the cached model and to the cache manager that created this proxy.
        """
        super().__init__(client, cached_model.api_model)

        self._cache_manager = cache_manager
        self._offline_model = cached_model

    def get_labels(self) -> list[models.ILabel]:
        """Return the labels stored in the cached model."""
        offline_model = self._offline_model
        return offline_model.labels
@define
class _OfflineProjectModel(_CacheObjectModel):
    """
    Cacheable snapshot of a project: its API model, the IDs of its tasks
    and its labels.
    """

    # The project as described by the server API.
    api_model: models.IProjectRead
    # IDs of the tasks returned by the project's get_tasks().
    task_ids: list[int]
    # The labels returned by the project's get_labels().
    labels: list[models.ILabel]

    def dump(self) -> _CacheObject:
        """Return the snapshot as a cache object with "model", "tasks" and "labels" entries."""
        serialized: _CacheObject = {}
        serialized["model"] = to_json(self.api_model)
        serialized["tasks"] = self.task_ids
        serialized["labels"] = to_json(self.labels)
        return serialized

    @classmethod
    def load(cls, obj: _CacheObject):
        """Rebuild a snapshot from a cache object produced by `dump`."""
        project_model = models.ProjectRead._from_openapi_data(**obj["model"])
        cached_task_ids = obj["tasks"]

        project_labels = []
        for label_data in obj["labels"]:
            project_labels.append(models.Label._from_openapi_data(**label_data))

        return cls(api_model=project_model, task_ids=cached_task_ids, labels=project_labels)

    @classmethod
    def from_entity(cls, entity: Project):
        """Take a snapshot of the project `entity`."""
        project_model = entity._model

        collected_ids = []
        for task in entity.get_tasks():
            collected_ids.append(task.id)

        project_labels = entity.get_labels()
        return cls(api_model=project_model, task_ids=collected_ids, labels=project_labels)
class _OfflineProjectProxy(Project):
    """
    Project whose API model, task list and labels come from a cached
    `_OfflineProjectModel`.
    """

    def __init__(
        self, client: Client, cached_model: _OfflineProjectModel, *, cache_manager: CacheManager
    ) -> None:
        """
        Initialize the project from `cached_model.api_model` and keep references
        to the cached model and to the cache manager used to retrieve its tasks.
        """
        super().__init__(client, cached_model.api_model)

        self._cache_manager = cache_manager
        self._offline_model = cached_model

    def get_tasks(self) -> list[Task]:
        """Retrieve every task listed in the cached model through the cache manager."""
        tasks = []
        for task_id in self._offline_model.task_ids:
            tasks.append(self._cache_manager.retrieve_task(task_id))
        return tasks

    def get_labels(self) -> list[models.ILabel]:
        """Return the labels stored in the cached model."""
        offline_model = self._offline_model
        return offline_model.labels
# Which CacheManager implementation serves each update policy.
_CACHE_MANAGER_CLASSES: Mapping[UpdatePolicy, type[CacheManager]] = {
    UpdatePolicy.IF_MISSING_OR_STALE: _CacheManagerOnline,
    UpdatePolicy.NEVER: _CacheManagerOffline,
}
def make_cache_manager(client: Client, update_policy: UpdatePolicy) -> CacheManager:
    """
    Create the cache manager registered for `update_policy`, bound to `client`.

    Raises `KeyError` if no manager class is registered for `update_policy`.
    """
    manager_class = _CACHE_MANAGER_CLASSES[update_policy]
    return manager_class(client)