cvat/tests/python/sdk/test_datasets.py

232 lines
7.8 KiB
Python
Raw Permalink Normal View History

2025-09-16 01:19:40 +00:00
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
from logging import Logger
from pathlib import Path
import cvat_sdk.datasets as cvatds
import PIL.Image
import pytest
from cvat_sdk import Client, models
from cvat_sdk.core.proxies.annotations import AnnotationUpdateAction
from cvat_sdk.core.proxies.tasks import ResourceType
from shared.utils.helpers import generate_image_files
from .util import restrict_api_requests
@pytest.fixture(autouse=True)
def _common_setup(
    tmp_path: Path,
    fxt_login: tuple[Client, str],
    fxt_logger: tuple[Logger, io.StringIO],
    restore_redis_ondisk_per_function,
    restore_redis_inmem_per_function,
):
    """Configure the logged-in SDK client for every test in this module.

    Routes the SDK client's logger, and every logger held by its low-level API
    client configuration, to the test logger from ``fxt_logger``. Also points the
    SDK cache directory into the per-test ``tmp_path``.

    The ``restore_redis_*`` fixtures are requested only so that they run for each
    test; their values are not used here.
    """
    test_logger = fxt_logger[0]
    sdk_client = fxt_login[0]

    sdk_client.logger = test_logger
    # Keep the SDK cache inside this test's temporary directory.
    sdk_client.config.cache_dir = tmp_path / "cache"

    # Replace every logger registered in the API client configuration,
    # keeping the existing keys.
    api_loggers = sdk_client.api_client.configuration.logger
    for logger_name in api_loggers:
        api_loggers[logger_name] = test_logger
class TestTaskDataset:
    """Tests for ``cvat_sdk.datasets.TaskDataset``.

    Each test runs against a freshly created task with 10 generated images
    (chunk size 3) and two labels. Frame 8 has one tag per label, and frame 6
    has a single rectangle.
    """

    @pytest.fixture(autouse=True)
    def setup(
        self,
        tmp_path: Path,
        fxt_login: tuple[Client, str],
    ):
        """Create the task under test from local image files and add its annotations."""
        self.client = fxt_login[0]
        self.images = generate_image_files(10)

        image_dir = tmp_path / "images"
        image_dir.mkdir()

        image_paths = []
        for image in self.images:
            image_path = image_dir / image.name
            image_path.write_bytes(image.getbuffer())
            image_paths.append(image_path)

        self.task = self.client.tasks.create_from_data(
            models.TaskWriteRequest(
                "Dataset layer test task",
                labels=[
                    models.PatchedLabelRequest(name="person"),
                    models.PatchedLabelRequest(name="car"),
                ],
            ),
            resource_type=ResourceType.LOCAL,
            resources=image_paths,
            data_params={"chunk_size": 3},
        )

        # Sorted by id so tests can compare against the dataset's labels in a fixed order.
        self.expected_labels = sorted(self.task.get_labels(), key=lambda label: label.id)

        self.task.update_annotations(
            models.PatchedLabeledDataRequest(
                tags=[
                    models.LabeledImageRequest(frame=8, label_id=self.expected_labels[0].id),
                    models.LabeledImageRequest(frame=8, label_id=self.expected_labels[1].id),
                ],
                shapes=[
                    models.LabeledShapeRequest(
                        frame=6,
                        label_id=self.expected_labels[1].id,
                        type=models.ShapeType("rectangle"),
                        points=[1.0, 2.0, 3.0, 4.0],
                    ),
                ],
            ),
            action=AnnotationUpdateAction.CREATE,
        )

    @pytest.mark.parametrize("media_download_policy", cvatds.MediaDownloadPolicy)
    def test_basic(self, media_download_policy: cvatds.MediaDownloadPolicy):
        """Check labels, frame metadata, media and annotations of every sample."""
        dataset = cvatds.TaskDataset(
            self.client, self.task.id, media_download_policy=media_download_policy
        )

        # verify that the cache is not empty
        assert list(self.client.config.cache_dir.iterdir())

        # zip() stops at the shorter input, so compare the counts first;
        # otherwise a missing or extra label would go unnoticed.
        actual_labels = sorted(dataset.labels, key=lambda label: label.id)
        assert len(actual_labels) == len(self.expected_labels)
        for expected_label, actual_label in zip(self.expected_labels, actual_labels):
            assert expected_label.id == actual_label.id
            assert expected_label.name == actual_label.name

        assert len(dataset.samples) == self.task.size

        for index, sample in enumerate(dataset.samples):
            assert sample.frame_index == index
            assert sample.frame_name == self.images[index].name

            actual_image = sample.media.load_image()
            expected_image = PIL.Image.open(self.images[index])

            assert actual_image == expected_image

        assert not dataset.samples[0].annotations.tags
        assert not dataset.samples[1].annotations.shapes

        # Frame 8 carries one tag per label and no shapes.
        assert {tag.label_id for tag in dataset.samples[8].annotations.tags} == {
            label.id for label in self.expected_labels
        }
        assert not dataset.samples[8].annotations.shapes

        # Frame 6 carries exactly the rectangle created in setup().
        assert not dataset.samples[6].annotations.tags
        assert len(dataset.samples[6].annotations.shapes) == 1
        assert dataset.samples[6].annotations.shapes[0].type.value == "rectangle"
        assert dataset.samples[6].annotations.shapes[0].points == [1.0, 2.0, 3.0, 4.0]

    @pytest.mark.parametrize("media_download_policy", cvatds.MediaDownloadPolicy)
    def test_deleted_frame(self, media_download_policy: cvatds.MediaDownloadPolicy):
        """Check that a deleted frame is skipped and later samples shift down by one."""
        self.task.remove_frames_by_ids([1])

        dataset = cvatds.TaskDataset(
            self.client, self.task.id, media_download_policy=media_download_policy
        )

        assert len(dataset.samples) == self.task.size - 1
        # the deleted frame must not be exposed as any sample
        assert 1 not in {sample.frame_index for sample in dataset.samples}

        # sample #0 is still frame #0
        assert dataset.samples[0].frame_index == 0
        assert dataset.samples[0].media.load_image() == PIL.Image.open(self.images[0])

        # sample #1 is now frame #2
        assert dataset.samples[1].frame_index == 2
        assert dataset.samples[1].media.load_image() == PIL.Image.open(self.images[2])

        # sample #5 is now frame #6
        assert dataset.samples[5].frame_index == 6
        assert dataset.samples[5].media.load_image() == PIL.Image.open(self.images[6])
        assert len(dataset.samples[5].annotations.shapes) == 1

    def test_offline(self, monkeypatch: pytest.MonkeyPatch):
        """Check that a cached dataset reloads identically with UpdatePolicy.NEVER.

        The reload happens while API requests are restricted via
        ``restrict_api_requests``.
        """
        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
            update_policy=cvatds.UpdatePolicy.IF_MISSING_OR_STALE,
        )

        fresh_samples = list(dataset.samples)
        assert len(fresh_samples) == self.task.size

        restrict_api_requests(monkeypatch)

        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
            update_policy=cvatds.UpdatePolicy.NEVER,
        )

        cached_samples = list(dataset.samples)

        # zip() would silently ignore missing cached samples, so compare counts first.
        assert len(cached_samples) == len(fresh_samples)
        for fresh_sample, cached_sample in zip(fresh_samples, cached_samples):
            assert fresh_sample.frame_index == cached_sample.frame_index
            assert fresh_sample.annotations == cached_sample.annotations
            assert fresh_sample.media.load_image() == cached_sample.media.load_image()

    def test_update(self, monkeypatch: pytest.MonkeyPatch):
        """Check cache reuse for an unchanged task and refresh after an annotation change."""
        # Populate the cache; the instance itself is not needed afterwards.
        cvatds.TaskDataset(
            self.client,
            self.task.id,
        )

        # Recreating the dataset should only result in minimal requests.
        restrict_api_requests(
            monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"}
        )

        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
        )

        assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[1].id

        # After an update, the annotations should be redownloaded.
        # Lift the request restriction before changing the task on the server.
        monkeypatch.undo()

        self.task.update_annotations(
            models.PatchedLabeledDataRequest(
                shapes=[
                    models.LabeledShapeRequest(
                        id=dataset.samples[6].annotations.shapes[0].id,
                        frame=6,
                        label_id=self.expected_labels[0].id,
                        type=models.ShapeType("rectangle"),
                        points=[1.0, 2.0, 3.0, 4.0],
                    ),
                ]
            )
        )

        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
        )

        assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[0].id

    def test_no_annotations(self):
        """Check that load_annotations=False still yields media but no annotations."""
        dataset = cvatds.TaskDataset(self.client, self.task.id, load_annotations=False)

        # Guard against the loop below passing vacuously on an empty dataset.
        assert len(dataset.samples) == self.task.size

        for index, sample in enumerate(dataset.samples):
            assert sample.frame_index == index
            assert sample.frame_name == self.images[index].name

            actual_image = sample.media.load_image()
            expected_image = PIL.Image.open(self.images[index])

            assert actual_image == expected_image

            assert sample.annotations is None