339 lines
11 KiB
Python
339 lines
11 KiB
Python
|
|
# Copyright (C) CVAT.ai Corporation
|
||
|
|
#
|
||
|
|
# SPDX-License-Identifier: MIT
|
||
|
|
|
||
|
|
import base64
|
||
|
|
import json
|
||
|
|
import shutil
|
||
|
|
from abc import ABCMeta, abstractmethod
|
||
|
|
from collections.abc import Mapping
|
||
|
|
from enum import Enum, auto
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Callable, TypeVar, Union, cast
|
||
|
|
|
||
|
|
from attrs import define
|
||
|
|
|
||
|
|
import cvat_sdk.models as models
|
||
|
|
from cvat_sdk.api_client.model_utils import OpenApiModel, to_json
|
||
|
|
from cvat_sdk.core.client import Client
|
||
|
|
from cvat_sdk.core.proxies.projects import Project
|
||
|
|
from cvat_sdk.core.proxies.tasks import Task
|
||
|
|
from cvat_sdk.core.utils import atomic_writer
|
||
|
|
|
||
|
|
|
||
|
|
class UpdatePolicy(Enum):
    """
    Policies that determine when the local cache is refreshed from the CVAT server.

    The chosen policy selects which cache manager implementation
    `make_cache_manager` instantiates.
    """

    IF_MISSING_OR_STALE = auto()
    """
    Update the cache whenever cached data is missing or the server has a newer version.
    """

    NEVER = auto()
    """
    Never update the cache. If an operation requires data that is not cached,
    it will fail.

    No network access will be performed if this policy is used.
    """
|
||
|
|
|
||
|
|
|
||
|
|
# JSON-compatible dictionary form in which a model is written to / read from a cache file.
_CacheObject = dict[str, Any]
|
||
|
|
|
||
|
|
|
||
|
|
class _CacheObjectModel(metaclass=ABCMeta):
    """
    Interface for cacheable objects that are not themselves OpenAPI models.

    Implementations convert between an instance and the dictionary
    representation that is stored in a cache file.
    """

    @abstractmethod
    def dump(self) -> _CacheObject:
        """Return the dictionary representation of this object."""

    @classmethod
    @abstractmethod
    def load(cls, obj: _CacheObject):
        """Build an instance from a dictionary representation of the object."""
|
||
|
|
|
||
|
|
|
||
|
|
# Any model the cache can store: a generated OpenAPI model or a _CacheObjectModel.
_ModelType = TypeVar("_ModelType", bound=Union[OpenApiModel, _CacheObjectModel])
|
||
|
|
|
||
|
|
|
||
|
|
class CacheManager(metaclass=ABCMeta):
    """
    Base class for managers of the on-disk cache of CVAT server data.

    This class defines the cache directory layout and the JSON
    (de)serialization helpers. Subclasses decide how tasks, projects,
    task-related models and data chunks are obtained.
    """

    def __init__(self, client: Client) -> None:
        """
        Args:
            client: client whose logger, configured cache directory and
                server host are used by this manager.
        """
        self._client = client
        self._logger = client.logger

        self._server_dir = client.config.cache_dir / f"servers/{self.server_dir_name}"

    @property
    def server_dir_name(self) -> str:
        """Name of the cache subdirectory that belongs to the client's server."""
        # Base64-encode the name to avoid FS-unsafe characters (like slashes)
        return base64.urlsafe_b64encode(self._client.api_map.host.encode()).rstrip(b"=").decode()

    def task_dir(self, task_id: int) -> Path:
        """Return the cache directory of the given task."""
        return self._server_dir / f"tasks/{task_id}"

    def task_json_path(self, task_id: int) -> Path:
        """Return the path of the file that stores the cached task model."""
        return self.task_dir(task_id) / "task.json"

    def chunk_dir(self, task_id: int) -> Path:
        """Return the directory that holds the cached data chunks of the given task."""
        return self.task_dir(task_id) / "chunks"

    def project_dir(self, project_id: int) -> Path:
        """Return the cache directory of the given project."""
        return self._server_dir / f"projects/{project_id}"

    def project_json_path(self, project_id: int) -> Path:
        """Return the path of the file that stores the cached project model."""
        return self.project_dir(project_id) / "project.json"

    def _load_object(self, path: Path) -> _CacheObject:
        """Parse the JSON file at `path` and return its contents."""
        with open(path, "rb") as f:
            return json.load(f)

    def _save_object(self, path: Path, obj: _CacheObject) -> None:
        """Write `obj` to `path` as indented JSON, going through `atomic_writer`."""
        with atomic_writer(path, "w", encoding="UTF-8") as f:
            json.dump(obj, f, indent=4)
            print(file=f)  # add final newline

    def _deserialize_model(self, obj: _CacheObject, model_type: type[_ModelType]) -> _ModelType:
        """
        Convert a dictionary loaded from the cache into an instance of `model_type`.

        Raises:
            NotImplementedError: if `model_type` is neither an OpenAPI model class
                nor a `_CacheObjectModel` subclass.
        """
        # `model_type` is a class (it is inspected with issubclass), hence type[...].
        if issubclass(model_type, OpenApiModel):
            return cast(type[OpenApiModel], model_type)._new_from_openapi_data(**obj)
        elif issubclass(model_type, _CacheObjectModel):
            return cast(type[_CacheObjectModel], model_type).load(obj)
        else:
            raise NotImplementedError("Unexpected model type")

    def _serialize_model(self, model: _ModelType) -> _CacheObject:
        """
        Convert a model instance into the dictionary form stored in the cache.

        Raises:
            NotImplementedError: if `model` is neither an OpenAPI model
                nor a `_CacheObjectModel`.
        """
        if isinstance(model, OpenApiModel):
            return to_json(model)
        elif isinstance(model, _CacheObjectModel):
            return model.dump()
        else:
            raise NotImplementedError("Unexpected model type")

    def load_model(self, path: Path, model_type: type[_ModelType]) -> _ModelType:
        """Load the file at `path` and deserialize it as `model_type`."""
        return self._deserialize_model(self._load_object(path), model_type)

    def save_model(self, path: Path, model: _ModelType) -> None:
        """Serialize `model` and write it to `path`."""
        self._save_object(path, self._serialize_model(model))

    @abstractmethod
    def retrieve_task(self, task_id: int) -> Task:
        """Return the task with the given ID."""

    @abstractmethod
    def ensure_task_model(
        self,
        task_id: int,
        filename: str,
        model_type: type[_ModelType],
        downloader: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        """
        Return a model that is cached in the given task's directory.

        Args:
            task_id: ID of the task the model belongs to.
            filename: name of the model's file inside the task's cache directory.
            model_type: class the cached data is deserialized into.
            downloader: callable that produces the model when an implementation
                needs to obtain it from outside the cache.
            model_description: human-readable name of the model for log messages.
        """

    @abstractmethod
    def ensure_chunk(self, task: Task, chunk_index: int) -> None:
        """Make sure that the given data chunk of the task is present in the cache."""

    @abstractmethod
    def retrieve_project(self, project_id: int) -> Project:
        """Return the project with the given ID."""
|
||
|
|
|
||
|
|
|
||
|
|
class _CacheManagerOnline(CacheManager):
    """
    Cache manager that contacts the server and fills or refreshes the cache as needed.
    """

    def retrieve_task(self, task_id: int) -> Task:
        """Fetch the task from the server and bring its cache directory up to date."""
        self._logger.info(f"Fetching task {task_id}...")
        task = self._client.tasks.retrieve(task_id)

        self._initialize_task_dir(task)
        return task

    def _initialize_task_dir(self, task: Task) -> None:
        """
        Prepare the cache directory of `task`.

        The directory is purged when no valid cached task model can be loaded,
        or when the cached model is older than the one just fetched. Afterwards
        the directory is (re)created and the current task model is saved into it.
        """
        task_dir = self.task_dir(task.id)
        task_json_path = self.task_json_path(task.id)

        try:
            saved_task = self.load_model(task_json_path, _OfflineTaskModel)
        except Exception:
            # Deliberately broad: any failure to load means the cache is unusable.
            self._logger.info(f"Task {task.id} is not yet cached or the cache is corrupted")

            # If the cache was corrupted, the directory might already be there; clear it.
            if task_dir.exists():
                shutil.rmtree(task_dir)
        else:
            if saved_task.api_model.updated_date < task.updated_date:
                self._logger.info(
                    f"Task {task.id} has been updated on the server since it was cached; purging the cache"
                )
                shutil.rmtree(task_dir)

        task_dir.mkdir(exist_ok=True, parents=True)
        self.save_model(task_json_path, _OfflineTaskModel.from_entity(task))

    def ensure_task_model(
        self,
        task_id: int,
        filename: str,
        model_type: type[_ModelType],
        downloader: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        """
        Return the model stored as `filename` in the task's cache directory.

        If the file is missing or cannot be loaded, the model is obtained
        by calling `downloader` and then saved to the cache.
        """
        path = self.task_dir(task_id) / filename

        try:
            model = self.load_model(path, model_type)
            self._logger.info(f"Loaded {model_description} from cache")
            return model
        except FileNotFoundError:
            # Not cached yet; this is the normal first-use case, so no warning.
            pass
        except Exception:
            self._logger.warning(f"Failed to load {model_description} from cache", exc_info=True)

        self._logger.info(f"Downloading {model_description}...")
        model = downloader()
        self._logger.info(f"Downloaded {model_description}")

        self.save_model(path, model)

        return model

    def ensure_chunk(self, task: Task, chunk_index: int) -> None:
        """Download the given chunk of `task` into the cache unless it is already there."""
        chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip"
        if chunk_path.exists():
            return  # already downloaded previously

        self._logger.info(f"Downloading chunk #{chunk_index}...")

        # Nothing else in this class creates the chunk directory, and it disappears
        # whenever the task directory is purged, so make sure it exists before writing.
        chunk_path.parent.mkdir(parents=True, exist_ok=True)

        with atomic_writer(chunk_path, "wb") as chunk_file:
            task.download_chunk(chunk_index, chunk_file, quality="original")

    def retrieve_project(self, project_id: int) -> Project:
        """Fetch the project from the server and save its model to the cache."""
        self._logger.info(f"Fetching project {project_id}...")
        project = self._client.projects.retrieve(project_id)

        project_dir = self.project_dir(project_id)
        project_dir.mkdir(parents=True, exist_ok=True)
        project_json_path = self.project_json_path(project_id)

        # There are currently no files cached alongside project.json,
        # so we don't need to check if we need to purge them.

        self.save_model(project_json_path, _OfflineProjectModel.from_entity(project))

        return project
|
||
|
|
|
||
|
|
|
||
|
|
class _CacheManagerOffline(CacheManager):
    """
    Cache manager that serves everything from the cache and never downloads.

    Operations fail when the data they need has not been cached.
    """

    def retrieve_task(self, task_id: int) -> Task:
        """Build a task proxy from the cached task model."""
        self._logger.info(f"Retrieving task {task_id} from cache...")

        json_path = self.task_json_path(task_id)
        offline_model = self.load_model(json_path, _OfflineTaskModel)
        return _OfflineTaskProxy(self._client, offline_model, cache_manager=self)

    def ensure_task_model(
        self,
        task_id: int,
        filename: str,
        model_type: type[_ModelType],
        downloader: Callable[[], _ModelType],
        model_description: str,
    ) -> _ModelType:
        """Load the model from the task's cache directory; `downloader` is not used."""
        self._logger.info(f"Loading {model_description} from cache...")

        model_path = self.task_dir(task_id) / filename
        return self.load_model(model_path, model_type)

    def ensure_chunk(self, task: Task, chunk_index: int) -> None:
        """Raise FileNotFoundError unless the given chunk is already cached."""
        chunk_path = self.chunk_dir(task.id) / f"{chunk_index}.zip"

        if chunk_path.exists():
            return

        raise FileNotFoundError(f"Chunk {chunk_index} of task {task.id} is not cached")

    def retrieve_project(self, project_id: int) -> Project:
        """Build a project proxy from the cached project model."""
        self._logger.info(f"Retrieving project {project_id} from cache...")

        json_path = self.project_json_path(project_id)
        offline_model = self.load_model(json_path, _OfflineProjectModel)
        return _OfflineProjectProxy(self._client, offline_model, cache_manager=self)
|
||
|
|
|
||
|
|
|
||
|
|
@define
class _OfflineTaskModel(_CacheObjectModel):
    """Everything that is cached about a task: its API model and its labels."""

    # Task model as returned by the API.
    api_model: models.ITaskRead
    # Labels of the task.
    labels: list[models.ILabel]

    def dump(self) -> _CacheObject:
        """Return the dictionary stored in the cache, with keys "model" and "labels"."""
        serialized: _CacheObject = {}
        serialized["model"] = to_json(self.api_model)
        serialized["labels"] = to_json(self.labels)
        return serialized

    @classmethod
    def load(cls, obj: _CacheObject):
        """Rebuild an instance from a dictionary with the layout produced by `dump`."""
        task_model = models.TaskRead._from_openapi_data(**obj["model"])
        task_labels = [models.Label._from_openapi_data(**raw_label) for raw_label in obj["labels"]]
        return cls(api_model=task_model, labels=task_labels)

    @classmethod
    def from_entity(cls, entity: Task):
        """Capture the model and the labels of a task entity."""
        return cls(api_model=entity._model, labels=entity.get_labels())
|
||
|
|
|
||
|
|
|
||
|
|
class _OfflineTaskProxy(Task):
    """Task whose model and labels come from a cached `_OfflineTaskModel`."""

    def __init__(
        self, client: Client, cached_model: _OfflineTaskModel, *, cache_manager: CacheManager
    ) -> None:
        """
        Args:
            client: client passed on to the base task class.
            cached_model: cached data this proxy is built from.
            cache_manager: cache manager that created this proxy.
        """
        super().__init__(client, cached_model.api_model)
        self._cache_manager = cache_manager
        self._offline_model = cached_model

    def get_labels(self) -> list[models.ILabel]:
        """Return the labels stored in the cached model."""
        offline_model = self._offline_model
        return offline_model.labels
|
||
|
|
|
||
|
|
|
||
|
|
@define
class _OfflineProjectModel(_CacheObjectModel):
    """Everything that is cached about a project: its API model, task IDs and labels."""

    # Project model as returned by the API.
    api_model: models.IProjectRead
    # IDs of the tasks of the project.
    task_ids: list[int]
    # Labels of the project.
    labels: list[models.ILabel]

    def dump(self) -> _CacheObject:
        """Return the dictionary stored in the cache, with keys "model", "tasks" and "labels"."""
        serialized: _CacheObject = {}
        serialized["model"] = to_json(self.api_model)
        serialized["tasks"] = self.task_ids
        serialized["labels"] = to_json(self.labels)
        return serialized

    @classmethod
    def load(cls, obj: _CacheObject):
        """Rebuild an instance from a dictionary with the layout produced by `dump`."""
        project_model = models.ProjectRead._from_openapi_data(**obj["model"])
        cached_task_ids = obj["tasks"]
        project_labels = [
            models.Label._from_openapi_data(**raw_label) for raw_label in obj["labels"]
        ]
        return cls(api_model=project_model, task_ids=cached_task_ids, labels=project_labels)

    @classmethod
    def from_entity(cls, entity: Project):
        """Capture the model, the task IDs and the labels of a project entity."""
        project_model = entity._model
        project_task_ids = [task.id for task in entity.get_tasks()]
        return cls(
            api_model=project_model,
            task_ids=project_task_ids,
            labels=entity.get_labels(),
        )
|
||
|
|
|
||
|
|
|
||
|
|
class _OfflineProjectProxy(Project):
    """Project whose model, tasks and labels come from a cached `_OfflineProjectModel`."""

    def __init__(
        self, client: Client, cached_model: _OfflineProjectModel, *, cache_manager: CacheManager
    ) -> None:
        """
        Args:
            client: client passed on to the base project class.
            cached_model: cached data this proxy is built from.
            cache_manager: cache manager used to retrieve the project's tasks.
        """
        super().__init__(client, cached_model.api_model)
        self._cache_manager = cache_manager
        self._offline_model = cached_model

    def get_tasks(self) -> list[Task]:
        """Retrieve every task listed in the cached model through the cache manager."""
        tasks: list[Task] = []
        for task_id in self._offline_model.task_ids:
            tasks.append(self._cache_manager.retrieve_task(task_id))
        return tasks

    def get_labels(self) -> list[models.ILabel]:
        """Return the labels stored in the cached model."""
        offline_model = self._offline_model
        return offline_model.labels
|
||
|
|
|
||
|
|
|
||
|
|
# Cache manager implementation to instantiate for each update policy.
_CACHE_MANAGER_CLASSES: Mapping[UpdatePolicy, type[CacheManager]] = {
    UpdatePolicy.IF_MISSING_OR_STALE: _CacheManagerOnline,
    UpdatePolicy.NEVER: _CacheManagerOffline,
}
|
||
|
|
|
||
|
|
|
||
|
|
def make_cache_manager(client: Client, update_policy: UpdatePolicy) -> CacheManager:
    """
    Create the cache manager that implements the given update policy.

    Args:
        client: client the cache manager is constructed with.
        update_policy: policy used to look up the implementation class.

    Raises:
        KeyError: if no implementation is registered for `update_policy`.
    """
    manager_class = _CACHE_MANAGER_CLASSES[update_policy]
    return manager_class(client)
|