209 lines
7.3 KiB
Python
209 lines
7.3 KiB
Python
# Copyright (C) CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from __future__ import annotations
|
|
|
|
import zipfile
|
|
from collections.abc import Iterable, Sequence
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
|
|
import PIL.Image
|
|
|
|
import cvat_sdk.core
|
|
import cvat_sdk.core.exceptions
|
|
import cvat_sdk.models as models
|
|
from cvat_sdk.datasets.caching import CacheManager, UpdatePolicy, make_cache_manager
|
|
from cvat_sdk.datasets.common import (
|
|
FrameAnnotations,
|
|
MediaDownloadPolicy,
|
|
MediaElement,
|
|
Sample,
|
|
UnsupportedDatasetError,
|
|
)
|
|
|
|
_NUM_DOWNLOAD_THREADS = 4
|
|
|
|
|
|
class TaskDataset:
    """
    Represents a task on a CVAT server as a collection of samples.

    Each sample corresponds to one frame in the task, and provides access to
    the corresponding annotations and media data. Deleted frames are omitted.

    Task metadata and (unless disabled) annotations are retrieved through the
    local cache during construction. With `MediaDownloadPolicy.PRELOAD_ALL`,
    every media chunk needed by the non-deleted frames is also cached during
    construction; with `MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND`, each frame
    image is requested from the server when it is loaded.

    Limitations:

    * Preloading media (`MediaDownloadPolicy.PRELOAD_ALL`) is only supported for
      tasks with image (not video) chunks.
    * Track annotations are currently not accessible.
    """

    class _TaskMediaElement(MediaElement):
        """Media accessor for a single frame; defers image loading to the dataset."""

        def __init__(self, dataset: TaskDataset, frame_index: int) -> None:
            self._dataset = dataset
            self._frame_index = frame_index

        def load_image(self) -> PIL.Image.Image:
            # Delegates to whichever loader the dataset selected in its constructor
            # (local chunk cache or on-demand server fetch).
            return self._dataset._load_frame_image(self._frame_index)

    def __init__(
        self,
        client: cvat_sdk.core.Client,
        task_id: int,
        *,
        update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
        load_annotations: bool = True,
        media_download_policy: MediaDownloadPolicy = MediaDownloadPolicy.PRELOAD_ALL,
    ) -> None:
        """
        Creates a dataset corresponding to the task with ID `task_id` on the
        server that `client` is connected to.

        `update_policy` determines when and if the local cache will be updated.

        `load_annotations` determines whether annotations will be loaded from
        the server. If set to False, the `annotations` field in the samples will
        be set to None.

        `media_download_policy` determines when media data is downloaded.

        `MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND` may not be used with `UpdatePolicy.NEVER`,
        as it requires network access.

        Raises:
            ValueError: if `media_download_policy` is not a known policy, or if
                `MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND` is combined with
                `UpdatePolicy.NEVER`.
            UnsupportedDatasetError: if the task has no data, or if media
                preloading is requested for a task whose chunks are not image sets.
        """

        # Validate arguments before touching the cache or the network. These are
        # explicit raises rather than asserts, because asserts are stripped under
        # `python -O`, which would let invalid combinations through silently.
        if media_download_policy == MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND:
            if update_policy == UpdatePolicy.NEVER:
                raise ValueError(
                    "MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND may not be used with"
                    " UpdatePolicy.NEVER, as it requires network access"
                )
        elif media_download_policy != MediaDownloadPolicy.PRELOAD_ALL:
            raise ValueError(f"Unknown media download policy: {media_download_policy!r}")

        self._logger = client.logger

        cache_manager = make_cache_manager(client, update_policy)
        self._task = cache_manager.retrieve_task(task_id)

        if not self._task.size or not self._task.data_chunk_size:
            raise UnsupportedDatasetError("The task has no data")

        self._logger.info("Fetching labels...")
        self._labels = tuple(self._task.get_labels())

        data_meta = cache_manager.ensure_task_model(
            self._task.id,
            "data_meta.json",
            models.DataMetaRead,
            self._task.get_meta,
            "data metadata",
        )

        active_frame_indexes = set(range(self._task.size)) - set(data_meta.deleted_frames)

        if media_download_policy == MediaDownloadPolicy.PRELOAD_ALL:
            # Only download the chunks that contain at least one non-deleted frame.
            needed_chunks = {index // self._task.data_chunk_size for index in active_frame_indexes}
            self._ensure_chunks(task_id, cache_manager, needed_chunks)
            self._load_frame_image = self._load_frame_image_from_cache
        else:
            # MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND (validated above).
            self._load_frame_image = self._load_frame_image_from_server

        if load_annotations:
            self._load_annotations(cache_manager, sorted(active_frame_indexes))
        else:
            self._frame_annotations = {
                frame_index: None for frame_index in sorted(active_frame_indexes)
            }

        # TODO: tracks?

        # For non-imageset chunk types, only the first `frames` entry is used
        # for every sample's name.
        is_imageset = self._task.data_original_chunk_type == "imageset"

        self._samples = [
            Sample(
                frame_index=frame_index,
                frame_name=data_meta.frames[frame_index if is_imageset else 0].name,
                annotations=frame_annotations,
                media=self._TaskMediaElement(self, frame_index),
            )
            for frame_index, frame_annotations in self._frame_annotations.items()
        ]

    def _ensure_chunks(
        self,
        task_id: int,
        cache_manager: CacheManager,
        chunk_indexes: Iterable[int],
    ) -> None:
        """
        Makes sure every chunk in `chunk_indexes` is present in the local cache,
        downloading missing ones in parallel.

        Raises UnsupportedDatasetError if the task's chunks are not image sets.
        """
        if self._task.data_original_chunk_type != "imageset":
            raise UnsupportedDatasetError(
                "Preloading media data is only supported for tasks with image chunks;"
                f" current chunk type is {self._task.data_original_chunk_type!r}"
            )

        self._logger.info("Downloading chunks...")

        self._chunk_dir = cache_manager.chunk_dir(task_id)
        self._chunk_dir.mkdir(exist_ok=True, parents=True)

        with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:

            def ensure_chunk(chunk_index):
                cache_manager.ensure_chunk(self._task, chunk_index)

            for _ in pool.map(ensure_chunk, sorted(chunk_indexes)):
                # just need to loop through all results so that any exceptions are propagated
                pass

        self._logger.info("All chunks downloaded")

    def _load_annotations(self, cache_manager: CacheManager, frame_indexes: Iterable[int]) -> None:
        """
        Loads the task's annotations (via the cache) and distributes tags and
        shapes into per-frame FrameAnnotations objects for `frame_indexes`.
        """
        annotations = cache_manager.ensure_task_model(
            self._task.id,
            "annotations.json",
            models.LabeledData,
            self._task.get_annotations,
            "annotations",
        )

        self._frame_annotations = {frame_index: FrameAnnotations() for frame_index in frame_indexes}

        for tag in annotations.tags:
            # Some annotations may belong to deleted frames; skip those.
            if tag.frame in self._frame_annotations:
                self._frame_annotations[tag.frame].tags.append(tag)

        for shape in annotations.shapes:
            if shape.frame in self._frame_annotations:
                self._frame_annotations[shape.frame].shapes.append(shape)

    @property
    def labels(self) -> Sequence[models.ILabel]:
        """
        Returns the labels configured in the task.

        Clients must not modify the object returned by this property or its components.
        """
        return self._labels

    @property
    def samples(self) -> Sequence[Sample]:
        """
        Returns a sequence of all samples, in order of their frame indices.

        Note that the frame indices may not be contiguous, as deleted frames will not be included.

        Clients must not modify the object returned by this property or its components.
        """
        return self._samples

    def _load_frame_image_from_cache(self, frame_index: int) -> PIL.Image.Image:
        """Reads one frame's image from its locally cached chunk archive."""
        assert frame_index in self._frame_annotations

        chunk_index = frame_index // self._task.data_chunk_size
        member_index = frame_index % self._task.data_chunk_size

        with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
            with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
                image = PIL.Image.open(chunk_member)
                # Force decoding now: the archive member is closed when the
                # `with` blocks exit, and PIL decodes lazily by default.
                image.load()

        return image

    def _load_frame_image_from_server(self, frame_index: int) -> PIL.Image.Image:
        """Requests one frame's original-quality image from the server."""
        assert frame_index in self._frame_annotations

        return PIL.Image.open(self._task.get_frame(frame_index, quality="original"))
|