# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import zipfile
from collections.abc import Iterable, Sequence
from concurrent.futures import ThreadPoolExecutor

import PIL.Image

import cvat_sdk.core
import cvat_sdk.core.exceptions
import cvat_sdk.models as models
from cvat_sdk.datasets.caching import CacheManager, UpdatePolicy, make_cache_manager
from cvat_sdk.datasets.common import (
    FrameAnnotations,
    MediaDownloadPolicy,
    MediaElement,
    Sample,
    UnsupportedDatasetError,
)

# Size of the thread pool used to download media chunks concurrently.
_NUM_DOWNLOAD_THREADS = 4


class TaskDataset:
    """
    Represents a task on a CVAT server as a collection of samples.

    Each sample corresponds to one frame in the task, and provides access to
    the corresponding annotations and media data. Deleted frames are omitted.

    This class caches all data and annotations for the task on the local file system
    during construction.

    Limitations:

    * Only tasks with image (not video) data are supported at the moment.
    * Track annotations are currently not accessible.
    """

    class _TaskMediaElement(MediaElement):
        """Media element that defers image loading to the owning dataset."""

        def __init__(self, dataset: TaskDataset, frame_index: int) -> None:
            self._dataset = dataset
            self._frame_index = frame_index

        def load_image(self) -> PIL.Image.Image:
            # `_load_frame_image` is bound in `TaskDataset.__init__` to either the
            # cache-backed or the server-backed loader, depending on the policy.
            return self._dataset._load_frame_image(self._frame_index)

    def __init__(
        self,
        client: cvat_sdk.core.Client,
        task_id: int,
        *,
        update_policy: UpdatePolicy = UpdatePolicy.IF_MISSING_OR_STALE,
        load_annotations: bool = True,
        media_download_policy: MediaDownloadPolicy = MediaDownloadPolicy.PRELOAD_ALL,
    ) -> None:
        """
        Creates a dataset corresponding to the task with ID `task_id` on the
        server that `client` is connected to.

        `update_policy` determines when and if the local cache will be updated.

        `load_annotations` determines whether annotations will be loaded from
        the server. If set to False, the `annotations` field in the samples will
        be set to None.

        `media_download_policy` determines when media data is downloaded.
        `MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND` may not be used with
        `UpdatePolicy.NEVER`, as it requires network access.

        Raises `UnsupportedDatasetError` if the task has no data, or if
        `MediaDownloadPolicy.PRELOAD_ALL` is requested for a task whose
        original chunk type is not "imageset".
        """

        self._logger = client.logger

        cache_manager = make_cache_manager(client, update_policy)
        self._task = cache_manager.retrieve_task(task_id)

        if not self._task.size or not self._task.data_chunk_size:
            raise UnsupportedDatasetError("The task has no data")

        self._logger.info("Fetching labels...")
        self._labels = tuple(self._task.get_labels())

        data_meta = cache_manager.ensure_task_model(
            self._task.id,
            "data_meta.json",
            models.DataMetaRead,
            self._task.get_meta,
            "data metadata",
        )

        # Frames marked as deleted in the metadata are excluded from the dataset.
        active_frame_indexes = set(range(self._task.size)) - set(data_meta.deleted_frames)
        ordered_frame_indexes = sorted(active_frame_indexes)

        if media_download_policy == MediaDownloadPolicy.PRELOAD_ALL:
            # Only download the chunks that contain at least one non-deleted frame.
            needed_chunks = {index // self._task.data_chunk_size for index in active_frame_indexes}
            self._ensure_chunks(task_id, cache_manager, needed_chunks)

            self._load_frame_image = self._load_frame_image_from_cache
        elif media_download_policy == MediaDownloadPolicy.FETCH_FRAMES_ON_DEMAND:
            # NOTE(review): this check disappears under `python -O`; kept as an
            # assertion so that the exception type seen by callers is unchanged.
            assert update_policy != UpdatePolicy.NEVER
            self._load_frame_image = self._load_frame_image_from_server
        else:
            # NOTE(review): also stripped under `python -O`, in which case
            # `_load_frame_image` would be left unset.
            assert False, "Unknown media download policy"

        if load_annotations:
            self._load_annotations(cache_manager, ordered_frame_indexes)
        else:
            # Keep one entry per active frame so that samples are still produced.
            self._frame_annotations = dict.fromkeys(ordered_frame_indexes)

        # TODO: tracks?

        is_imageset = self._task.data_original_chunk_type == "imageset"

        self._samples = [
            Sample(
                frame_index=k,
                # For non-imageset tasks every sample takes the name of the first
                # metadata entry (presumably the single source file - confirm).
                frame_name=data_meta.frames[k if is_imageset else 0].name,
                annotations=v,
                media=self._TaskMediaElement(self, k),
            )
            for k, v in self._frame_annotations.items()
        ]

    def _ensure_chunks(
        self,
        task_id: int,
        cache_manager: CacheManager,
        chunk_indexes: Iterable[int],
    ) -> None:
        """
        Makes sure the chunks with the given indexes are present in the local cache,
        downloading them on up to `_NUM_DOWNLOAD_THREADS` threads.

        Also records the chunk directory in `self._chunk_dir` for later reads.

        Raises `UnsupportedDatasetError` if the task's original chunk type
        is not "imageset".
        """
        if self._task.data_original_chunk_type != "imageset":
            raise UnsupportedDatasetError(
                "Preloading media data is only supported for tasks with image chunks;"
                f" current chunk type is {self._task.data_original_chunk_type!r}"
            )

        self._logger.info("Downloading chunks...")

        self._chunk_dir = cache_manager.chunk_dir(task_id)
        self._chunk_dir.mkdir(exist_ok=True, parents=True)

        with ThreadPoolExecutor(_NUM_DOWNLOAD_THREADS) as pool:

            def ensure_chunk(chunk_index: int) -> None:
                cache_manager.ensure_chunk(self._task, chunk_index)

            for _ in pool.map(ensure_chunk, sorted(chunk_indexes)):
                # just need to loop through all results so that any exceptions are propagated
                pass

        self._logger.info("All chunks downloaded")

    def _load_annotations(self, cache_manager: CacheManager, frame_indexes: Iterable[int]) -> None:
        """
        Fetches the task's annotations through `cache_manager` and groups tags and
        shapes by frame into `self._frame_annotations`.

        Only the frames listed in `frame_indexes` receive an entry.
        """
        annotations = cache_manager.ensure_task_model(
            self._task.id,
            "annotations.json",
            models.LabeledData,
            self._task.get_annotations,
            "annotations",
        )

        self._frame_annotations = {frame_index: FrameAnnotations() for frame_index in frame_indexes}

        for tag in annotations.tags:
            # Some annotations may belong to deleted frames; skip those.
            if tag.frame in self._frame_annotations:
                self._frame_annotations[tag.frame].tags.append(tag)

        for shape in annotations.shapes:
            # Same filtering as for tags.
            if shape.frame in self._frame_annotations:
                self._frame_annotations[shape.frame].shapes.append(shape)

    @property
    def labels(self) -> Sequence[models.ILabel]:
        """
        Returns the labels configured in the task.

        Clients must not modify the object returned by this property or its components.
        """
        return self._labels

    @property
    def samples(self) -> Sequence[Sample]:
        """
        Returns a sequence of all samples, in order of their frame indices.

        Note that the frame indices may not be contiguous, as deleted frames will not be included.

        Clients must not modify the object returned by this property or its components.
        """
        return self._samples

    def _load_frame_image_from_cache(self, frame_index: int) -> PIL.Image.Image:
        """
        Loads the image for `frame_index` from the chunk archive in `self._chunk_dir`.

        The frame is located by position: chunk `frame_index // data_chunk_size`,
        archive member number `frame_index % data_chunk_size`.
        """
        assert frame_index in self._frame_annotations

        chunk_index = frame_index // self._task.data_chunk_size
        member_index = frame_index % self._task.data_chunk_size

        with zipfile.ZipFile(self._chunk_dir / f"{chunk_index}.zip", "r") as chunk_zip:
            with chunk_zip.open(chunk_zip.infolist()[member_index]) as chunk_member:
                image = PIL.Image.open(chunk_member)
                # Force the pixel data to be read now, while the archive member is open.
                image.load()

        return image

    def _load_frame_image_from_server(self, frame_index: int) -> PIL.Image.Image:
        """
        Requests the original-quality image for `frame_index` from the task
        and opens it with PIL.
        """
        assert frame_index in self._frame_annotations

        return PIL.Image.open(self._task.get_frame(frame_index, quality="original"))