cvat/cvat-sdk/cvat_sdk/datasets/task_dataset.py

209 lines
7.3 KiB
Python
Raw Permalink Normal View History

2025-09-16 01:19:40 +00:00
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import zipfile
from collections.abc import Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor
import PIL.Image
import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.datasets.caching import CacheManager, UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.common import (
FrameAnnotations,
MediaDownloadPolicy,
MediaElement,
Sample,
UnsupportedDatasetError,
)
# Number of worker threads used by TaskDataset._ensure_chunks to download
# task chunks concurrently.
_NUM_DOWNLOAD_THREADS = 4
class TaskDataset:
    """
    Represents a task on a CVAT server as a collection of samples.

    Each sample corresponds to one frame in the task, and provides access to
    the corresponding annotations and media data. Deleted frames are omitted.

    During construction, this class caches the task's data metadata and annotations
    on the local file system, and, with `MediaDownloadPolicy.PRELOAD_ALL`, its media
    chunks as well.

    Limitations:

    * With `MediaDownloadPolicy.PRELOAD_ALL` (the default), only tasks with image
      (not video) chunks are supported.
    * Track annotations are currently not accessible.
    """

    class _TaskMediaElement(MediaElement):
        # Media handle for a single frame of the owning dataset; the image is only
        # read when `load_image` is called.

        def __init__(self, dataset: TaskDataset, frame_index: int) -> None:
            self._dataset = dataset
            self._frame_index = frame_index

        def load_image(self) -> PIL.Image.Image:
            # Dispatches to the loader selected in TaskDataset.__init__ according to
            # the media download policy (local chunk cache or server request).
            return self._dataset._load_frame_image(self._frame_index)

    def __init__(
        self,
        client: cvat_sdk.core.Client,
        task_id: int,
        *,
        update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
        load_annotations: bool = True,
        media_download_policy: MediaDownloadPolicy = MediaDownloadPolicy.PRELOAD_ALL,
    ) -> None:
        """
        Creates a dataset corresponding to the task with ID `task_id` on the
        server that `client` is connected to.

        `update_policy` determines when and if the local cache will be updated.

        `load_annotations` determines whether annotations will be loaded from
        the server. If set to False, the `annotations` field in the samples will
        be set to None.

        `media_download_policy` determines when media data is downloaded.
        `MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND` may not be used with `UpdatePolicy.NEVER`,
        as it requires network access.

        Raises:
            ValueError: if `media_download_policy` is
                `MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND` while `update_policy` is
                `UpdatePolicy.NEVER`, or if `media_download_policy` is not recognized.
            UnsupportedDatasetError: if the task has no data, or if media preloading
                is requested for a task whose original chunk type is not "imageset".
        """
        # Reject the invalid policy combination before any cache or network access.
        # This is a real exception rather than an `assert`, so it survives `python -O`.
        if (
            media_download_policy == MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND
            and update_policy == UpdatePolicy.NEVER
        ):
            raise ValueError(
                "MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND cannot be used with"
                " UpdatePolicy.NEVER, as it requires network access"
            )

        self._logger = client.logger

        cache_manager = make_cache_manager(client, update_policy)
        self._task = cache_manager.retrieve_task(task_id)

        if not self._task.size or not self._task.data_chunk_size:
            raise UnsupportedDatasetError("The task has no data")

        self._logger.info("Fetching labels...")
        self._labels = tuple(self._task.get_labels())

        data_meta = cache_manager.ensure_task_model(
            self._task.id,
            "data_meta.json",
            models.DataMetaRead,
            self._task.get_meta,
            "data metadata",
        )

        active_frame_indexes = set(range(self._task.size)) - set(data_meta.deleted_frames)

        if media_download_policy == MediaDownloadPolicy.PRELOAD_ALL:
            needed_chunks = {index // self._task.data_chunk_size for index in active_frame_indexes}
            self._ensure_chunks(task_id, cache_manager, needed_chunks)
            self._load_frame_image = self._load_frame_image_from_cache
        elif media_download_policy == MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND:
            self._load_frame_image = self._load_frame_image_from_server
        else:
            raise ValueError(f"Unknown media download policy: {media_download_policy!r}")

        # Insertion order of `_frame_annotations` defines the order of `samples`.
        sorted_frame_indexes = sorted(active_frame_indexes)
        if load_annotations:
            self._load_annotations(cache_manager, sorted_frame_indexes)
        else:
            # Every active frame maps to None, so each sample's `annotations` is None.
            self._frame_annotations = dict.fromkeys(sorted_frame_indexes)

        # TODO: tracks?

        # For image-set tasks, `data_meta.frames` is indexed per frame; for other
        # chunk types (presumably video) entry 0 is used for every frame's name.
        is_imageset = self._task.data_original_chunk_type == "imageset"

        self._samples = [
            Sample(
                frame_index=frame_index,
                frame_name=data_meta.frames[frame_index if is_imageset else 0].name,
                annotations=frame_annotations,
                media=self._TaskMediaElement(self, frame_index),
            )
            for frame_index, frame_annotations in self._frame_annotations.items()
        ]

    def _ensure_chunks(
        self, task_id: int, cache_manager: CacheManager, chunk_indexes: Iterable[int]
    ) -> None:
        """
        Makes sure every chunk in `chunk_indexes` is present in the local cache
        (via `cache_manager.ensure_chunk`), using a pool of worker threads.

        Also records the chunk directory in `self._chunk_dir`, which
        `_load_frame_image_from_cache` reads from.

        Raises UnsupportedDatasetError if the task's chunks are not image sets.
        """
        if self._task.data_original_chunk_type != "imageset":
            raise UnsupportedDatasetError(
                f"Preloading media data is only supported for tasks with image chunks;"
                f" current chunk type is {self._task.data_original_chunk_type!r}"
            )

        self._logger.info("Downloading chunks...")

        self._chunk_dir = cache_manager.chunk_dir(task_id)
        self._chunk_dir.mkdir(exist_ok=True, parents=True)

        def ensure_chunk(chunk_index: int) -> None:
            cache_manager.ensure_chunk(self._task, chunk_index)

        with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:
            # Executor.map re-raises a worker's exception only when its result is
            # retrieved, so every result must be consumed.
            for _ in pool.map(ensure_chunk, sorted(chunk_indexes)):
                pass

        self._logger.info("All chunks downloaded")

    def _load_annotations(self, cache_manager: CacheManager, frame_indexes: Iterable[int]) -> None:
        """
        Populates `self._frame_annotations` with one FrameAnnotations per frame in
        `frame_indexes` (in that order), filled with the task's tags and shapes.
        Track annotations are not loaded.
        """
        annotations = cache_manager.ensure_task_model(
            self._task.id,
            "annotations.json",
            models.LabeledData,
            self._task.get_annotations,
            "annotations",
        )

        self._frame_annotations = {frame_index: FrameAnnotations() for frame_index in frame_indexes}

        for tag in annotations.tags:
            # Some annotations may belong to deleted frames; skip those.
            if tag.frame in self._frame_annotations:
                self._frame_annotations[tag.frame].tags.append(tag)

        for shape in annotations.shapes:
            if shape.frame in self._frame_annotations:
                self._frame_annotations[shape.frame].shapes.append(shape)

    @property
    def labels(self) -> Sequence[models.ILabel]:
        """
        Returns the labels configured in the task.

        Clients must not modify the object returned by this property or its components.
        """
        return self._labels

    @property
    def samples(self) -> Sequence[Sample]:
        """
        Returns a sequence of all samples, in order of their frame indices.

        Note that the frame indices may not be contiguous, as deleted frames will not be included.

        Clients must not modify the object returned by this property or its components.
        """
        return self._samples

    def _load_frame_image_from_cache(self, frame_index: int) -> PIL.Image.Image:
        # Reads the frame from its locally cached chunk archive. This relies on the
        # archive members being stored in frame order, so that a member's position
        # equals the frame's offset within its chunk.
        assert frame_index in self._frame_annotations

        chunk_index, member_index = divmod(frame_index, self._task.data_chunk_size)

        with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
            with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
                image = PIL.Image.open(chunk_member)
                # Decode now, while the archive member is still open.
                image.load()

        return image

    def _load_frame_image_from_server(self, frame_index: int) -> PIL.Image.Image:
        # Requests the original-quality frame from the server on every call.
        assert frame_index in self._frame_annotations

        return PIL.Image.open(self._task.get_frame(frame_index, quality="original"))