232 lines
7.8 KiB
Python
232 lines
7.8 KiB
Python
# Copyright (C) CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import io
|
|
from logging import Logger
|
|
from pathlib import Path
|
|
|
|
import cvat_sdk.datasets as cvatds
|
|
import PIL.Image
|
|
import pytest
|
|
from cvat_sdk import Client, models
|
|
from cvat_sdk.core.proxies.annotations import AnnotationUpdateAction
|
|
from cvat_sdk.core.proxies.tasks import ResourceType
|
|
|
|
from shared.utils.helpers import generate_image_files
|
|
|
|
from .util import restrict_api_requests
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _common_setup(
    tmp_path: Path,
    fxt_login: tuple[Client, str],
    fxt_logger: tuple[Logger, io.StringIO],
    restore_redis_ondisk_per_function,
    restore_redis_inmem_per_function,
):
    """Wire the logged-in client to the test logger and a per-test cache dir.

    Applied automatically to every test in this module. The two
    ``restore_redis_*`` fixtures are not used in the body; they are requested
    only so that they run for each test (presumably resetting Redis state --
    see their definitions to confirm).
    """
    sdk_client = fxt_login[0]
    test_logger = fxt_logger[0]

    sdk_client.logger = test_logger
    # Keep the dataset cache inside pytest's per-test temporary directory.
    sdk_client.config.cache_dir = tmp_path / "cache"

    # Replace every logger registered in the low-level API client
    # configuration with the test logger as well.
    registered_loggers = sdk_client.api_client.configuration.logger
    for logger_name in registered_loggers:
        registered_loggers[logger_name] = test_logger
|
|
|
|
|
|
class TestTaskDataset:
    """Tests for ``cvatds.TaskDataset`` against a task created per test."""

    @pytest.fixture(autouse=True)
    def setup(
        self,
        tmp_path: Path,
        fxt_login: tuple[Client, str],
    ):
        """Create the task, labels and annotations shared by all tests.

        The task is built from the files returned by
        ``generate_image_files(10)``, written to ``tmp_path / "images"``,
        with two labels ("person" and "car"). Two tags (one per label) are
        placed on frame 8 and one rectangle with the second label (by id
        order) on frame 6.
        """
        self.client = fxt_login[0]
        self.images = generate_image_files(10)

        image_dir = tmp_path / "images"
        image_dir.mkdir()

        image_paths = []
        for image in self.images:
            image_path = image_dir / image.name
            image_path.write_bytes(image.getbuffer())
            image_paths.append(image_path)

        self.task = self.client.tasks.create_from_data(
            models.TaskWriteRequest(
                "Dataset layer test task",
                labels=[
                    models.PatchedLabelRequest(name="person"),
                    models.PatchedLabelRequest(name="car"),
                ],
            ),
            resource_type=ResourceType.LOCAL,
            resources=image_paths,
            data_params={"chunk_size": 3},
        )

        # Sort by id so that tests can refer to labels by a stable position.
        self.expected_labels = sorted(self.task.get_labels(), key=lambda label: label.id)

        self.task.update_annotations(
            models.PatchedLabeledDataRequest(
                tags=[
                    models.LabeledImageRequest(frame=8, label_id=self.expected_labels[0].id),
                    models.LabeledImageRequest(frame=8, label_id=self.expected_labels[1].id),
                ],
                shapes=[
                    models.LabeledShapeRequest(
                        frame=6,
                        label_id=self.expected_labels[1].id,
                        type=models.ShapeType("rectangle"),
                        points=[1.0, 2.0, 3.0, 4.0],
                    ),
                ],
            ),
            action=AnnotationUpdateAction.CREATE,
        )

    def _assert_sample_matches_source(self, sample, index: int) -> None:
        """Assert that ``sample`` is frame ``index`` with the uploaded image."""
        assert sample.frame_index == index
        assert sample.frame_name == self.images[index].name

        actual_image = sample.media.load_image()
        expected_image = PIL.Image.open(self.images[index])

        assert actual_image == expected_image

    @pytest.mark.parametrize("media_download_policy", cvatds.MediaDownloadPolicy)
    def test_basic(self, media_download_policy: cvatds.MediaDownloadPolicy):
        """Labels, samples, media and annotations match what ``setup`` created."""
        dataset = cvatds.TaskDataset(
            self.client, self.task.id, media_download_policy=media_download_policy
        )

        # verify that the cache is not empty
        assert list(self.client.config.cache_dir.iterdir())

        actual_labels = sorted(dataset.labels, key=lambda label: label.id)

        # zip() stops at the shorter input, so check the count explicitly;
        # otherwise missing labels would go unnoticed.
        assert len(actual_labels) == len(self.expected_labels)

        for expected_label, actual_label in zip(self.expected_labels, actual_labels):
            assert expected_label.id == actual_label.id
            assert expected_label.name == actual_label.name

        assert len(dataset.samples) == self.task.size

        for index, sample in enumerate(dataset.samples):
            self._assert_sample_matches_source(sample, index)

        assert not dataset.samples[0].annotations.tags
        assert not dataset.samples[1].annotations.shapes

        assert {tag.label_id for tag in dataset.samples[8].annotations.tags} == {
            label.id for label in self.expected_labels
        }
        assert not dataset.samples[8].annotations.shapes

        assert not dataset.samples[6].annotations.tags
        assert len(dataset.samples[6].annotations.shapes) == 1
        assert dataset.samples[6].annotations.shapes[0].type.value == "rectangle"
        assert dataset.samples[6].annotations.shapes[0].points == [1.0, 2.0, 3.0, 4.0]

    @pytest.mark.parametrize("media_download_policy", cvatds.MediaDownloadPolicy)
    def test_deleted_frame(self, media_download_policy: cvatds.MediaDownloadPolicy):
        """A removed frame is skipped and later samples shift down by one."""
        self.task.remove_frames_by_ids([1])

        dataset = cvatds.TaskDataset(
            self.client, self.task.id, media_download_policy=media_download_policy
        )

        assert len(dataset.samples) == self.task.size - 1

        # sample #0 is still frame #0
        assert dataset.samples[0].frame_index == 0
        assert dataset.samples[0].media.load_image() == PIL.Image.open(self.images[0])

        # sample #1 is now frame #2
        assert dataset.samples[1].frame_index == 2
        assert dataset.samples[1].media.load_image() == PIL.Image.open(self.images[2])

        # sample #5 is now frame #6
        assert dataset.samples[5].frame_index == 6
        assert dataset.samples[5].media.load_image() == PIL.Image.open(self.images[6])
        assert len(dataset.samples[5].annotations.shapes) == 1

    def test_offline(self, monkeypatch: pytest.MonkeyPatch):
        """A dataset built with ``UpdatePolicy.NEVER`` matches the fresh one.

        The second dataset is created after ``restrict_api_requests`` has been
        applied with no allowed paths.
        """
        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
            update_policy=cvatds.UpdatePolicy.IF_MISSING_OR_STALE,
        )

        fresh_samples = list(dataset.samples)

        restrict_api_requests(monkeypatch)

        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
            update_policy=cvatds.UpdatePolicy.NEVER,
        )

        cached_samples = list(dataset.samples)

        # zip() stops at the shorter input, so a truncated (or empty) cached
        # sample list would otherwise pass unnoticed.
        assert len(cached_samples) == len(fresh_samples)

        for fresh_sample, cached_sample in zip(fresh_samples, cached_samples):
            assert fresh_sample.frame_index == cached_sample.frame_index
            assert fresh_sample.annotations == cached_sample.annotations
            assert fresh_sample.media.load_image() == cached_sample.media.load_image()

    def test_update(self, monkeypatch: pytest.MonkeyPatch):
        """Annotations are served from cache, then refreshed after an update."""
        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
        )

        # Recreating the dataset should only result in minimal requests.
        restrict_api_requests(
            monkeypatch, allow_paths={f"/api/tasks/{self.task.id}", "/api/labels"}
        )

        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
        )

        assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[1].id

        # After an update, the annotations should be redownloaded.
        monkeypatch.undo()

        # Re-submit the existing rectangle (same id) with the other label.
        self.task.update_annotations(
            models.PatchedLabeledDataRequest(
                shapes=[
                    models.LabeledShapeRequest(
                        id=dataset.samples[6].annotations.shapes[0].id,
                        frame=6,
                        label_id=self.expected_labels[0].id,
                        type=models.ShapeType("rectangle"),
                        points=[1.0, 2.0, 3.0, 4.0],
                    ),
                ]
            )
        )

        dataset = cvatds.TaskDataset(
            self.client,
            self.task.id,
        )

        assert dataset.samples[6].annotations.shapes[0].label_id == self.expected_labels[0].id

    def test_no_annotations(self):
        """With ``load_annotations=False`` every sample has ``annotations is None``."""
        dataset = cvatds.TaskDataset(self.client, self.task.id, load_annotations=False)

        # Without this check, an empty sample list would make the loop below
        # pass vacuously.
        assert len(dataset.samples) == self.task.size

        for index, sample in enumerate(dataset.samples):
            self._assert_sample_matches_source(sample, index)

            assert sample.annotations is None
|