# cvat/tests/python/rest_api/test_labels.py
# (captured from a web file view: 966 lines, 35 KiB, Python; snapshot dated 2025-09-16 01:19:40 +00:00)
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import itertools
import json
from copy import deepcopy
from http import HTTPStatus
from types import SimpleNamespace
from typing import Any, Optional
import pytest
from cvat_sdk import exceptions, models
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
from cvat_sdk.core.helpers import get_paginated_collection
from dateutil.parser import isoparse as parse_datetime
from deepdiff import DeepDiff
from pytest_cases import fixture, fixture_ref, parametrize
from shared.utils.config import delete_method, get_method, make_api_client, patch_method
from .utils import CollectionSimpleFilterTestBase, build_exclude_paths_expr, get_attrs
class _TestLabelsPermissionsBase:
    """
    Shared fixtures for the label permission tests.

    The ``*_case`` fixtures below each yield a ``SimpleNamespace`` with a label, the user who
    acts on it, and an ``is_staff`` flag telling whether that user is expected to have access.
    """

    @pytest.fixture
    def _base_setup(
        self,
        users,
        labels,
        jobs,
        tasks,
        projects,
        is_task_staff,
        is_project_staff,
        users_by_name,
        tasks_by_org,
        projects_by_org,
        memberships,
        org_staff,
    ):
        # Keep the requested fixtures on the instance, so that helpers and case fixtures
        # can use them without re-declaring them as arguments.
        self.users = users
        self.labels = labels
        self.jobs = jobs
        self.tasks = tasks
        self.projects = projects
        self.is_task_staff = is_task_staff
        self.is_project_staff = is_project_staff
        self.users_by_name = users_by_name
        self.tasks_by_org = tasks_by_org
        self.projects_by_org = projects_by_org
        self.memberships = memberships
        self.org_staff = org_staff

    @pytest.fixture(autouse=True)
    def setup(self, _base_setup):
        """
        This function only calls the _base_setup() fixture.
        It can be overridden in derived classes.
        """

    @staticmethod
    def _labels_by_source(labels: list[dict], *, source_key: str) -> dict[int, list[dict]]:
        """
        Group labels by the value of their ``source_key`` field (e.g. "task_id").

        Labels where this field is missing or falsy are skipped.
        """
        labels_by_source = {}
        for label in labels:
            label_source = label.get(source_key)
            if label_source:
                labels_by_source.setdefault(label_source, []).append(label)
        return labels_by_source

    def _get_source_info(self, source: str, *, org_id: Optional[int] = None):
        """
        Describe a label source type ("task" or "project").

        Returns a namespace with:
        - ``sources``: the sources of this type in ``org_id`` (looked up under the "" key
          when ``org_id`` is None, i.e. the sandbox)
        - ``is_source_staff``: the staff-check callable for this source type
        - ``label_source_key``: the label field that references a source of this type
        """
        if source == "task":
            sources = self.tasks_by_org
            is_source_staff = self.is_task_staff
            label_source_key = "task_id"
        elif source == "project":
            sources = self.projects_by_org
            is_source_staff = self.is_project_staff
            label_source_key = "project_id"
        else:
            assert False, f"Unknown label source type: {source}"

        sources = sources[org_id or ""]

        return SimpleNamespace(
            sources=sources, is_source_staff=is_source_staff, label_source_key=label_source_key
        )

    def _find_label_by_staff(
        self, source: str, user: str, is_staff: bool, *, org_id: Optional[int] = None
    ) -> dict:
        """
        Return the first label of the first ``source``-type object that has labels and
        for which ``user``'s staff status equals ``is_staff``.
        """
        sources, is_source_staff, label_source_key = get_attrs(
            self._get_source_info(source, org_id=org_id),
            ["sources", "is_source_staff", "label_source_key"],
        )

        labels_by_source = self._labels_by_source(self.labels, source_key=label_source_key)
        user_id = self.users_by_name[user]["id"]

        source_obj = next(
            s
            for s in sources
            if labels_by_source.get(s["id"]) and is_source_staff(user_id, s["id"]) == is_staff
        )
        return labels_by_source[source_obj["id"]][0]

    source_types = ["task", "project"]
    org_roles = ["worker", "supervisor", "maintainer", "owner"]

    @fixture
    @parametrize("source", source_types)
    @parametrize("user", ["admin1"])
    @parametrize("is_staff", [True, False])
    def admin_sandbox_case(self, user, source, is_staff):
        """An admin acting on a sandbox label, with or without being staff of its source."""
        label = self._find_label_by_staff(source, user, is_staff)

        yield SimpleNamespace(label=label, user=user, source=source, is_staff=is_staff)

    @fixture
    @parametrize("source", source_types)
    @parametrize("org_id", [2])
    @parametrize("user", ["admin2"])
    @parametrize("is_staff", [False])
    def admin_org_case(self, user, source, org_id, is_staff):
        """An admin acting on a label in organization ``org_id``."""
        label = self._find_label_by_staff(source, user, is_staff, org_id=org_id)

        yield SimpleNamespace(
            label=label, user=user, source=source, org_id=org_id, is_staff=is_staff
        )

    @fixture
    @parametrize("source", source_types)
    @parametrize("is_staff", [True, False])
    def user_sandbox_case(self, source, is_staff):
        """
        A regular user acting on a sandbox label whose source is owned by a regular user.

        The acting user is the source owner when ``is_staff`` is set, otherwise any
        other non-superuser.
        """
        sources, label_source_key = get_attrs(
            self._get_source_info(source),
            ["sources", "label_source_key"],
        )

        users = {u["id"]: u for u in self.users if not u["is_superuser"]}
        regular_users_sources = [
            s for s in sources if s["owner"]["id"] in users and s["organization"] is None
        ]
        labels_by_source = self._labels_by_source(self.labels, source_key=label_source_key)
        source_obj = next(s for s in regular_users_sources if labels_by_source.get(s["id"]))
        label = labels_by_source[source_obj["id"]][0]
        user = next(u for u in users.values() if (u["id"] == source_obj["owner"]["id"]) == is_staff)

        yield SimpleNamespace(label=label, user=user, is_staff=is_staff)

    @fixture
    @parametrize("source", source_types)
    @parametrize("org_id", [2])
    @parametrize(
        "role, src_staff", list(itertools.product(org_roles, [True, False])) + [(None, False)]
    )
    def user_org_case(self, source, src_staff, role, org_id):
        """
        A regular user acting on a label in organization ``org_id``.

        The user has the given organization ``role`` (any user when ``role`` is None) and is
        staff of the label source exactly when ``src_staff`` is set. Access is expected when
        the user is source staff or organization staff.
        """
        sources, is_source_staff, label_source_key = get_attrs(
            self._get_source_info(source, org_id=org_id),
            ["sources", "is_source_staff", "label_source_key"],
        )

        labels_by_source = self._labels_by_source(self.labels, source_key=label_source_key)

        users = {u["id"]: u for u in self.users_by_name.values() if not u["is_superuser"]}

        staff_by_role = {}
        for m in self.memberships:
            if m["organization"] == org_id:
                staff_by_role.setdefault(m["role"], []).append(
                    self.users_by_name[m["user"]["username"]]
                )
        role_members = staff_by_role.get(role, [])

        # Initialized so that an empty source list fails the assertion below, not with NameError
        user = None
        source_obj = None
        for source_obj in (s for s in sources if labels_by_source.get(s["id"])):
            user = next(
                (
                    u
                    for u in users.values()
                    # the relation to the source must match the requested one exactly,
                    # otherwise the expected access computed below would be wrong
                    if is_source_staff(u["id"], source_obj["id"]) == src_staff
                    if not role or u in role_members
                ),
                None,
            )
            if user:
                break
        assert user, "No user matches the requested role and source staff status"

        label = labels_by_source[source_obj["id"]][0]

        yield SimpleNamespace(
            label=label,
            user=user,
            org_id=org_id,
            is_staff=src_staff or user["id"] in self.org_staff(org_id),
        )
class TestLabelsListFilters(CollectionSimpleFilterTestBase):
    """Checks the filters supported by the labels list endpoint."""

    @pytest.fixture(autouse=True)
    def setup(self, restore_db_per_class, admin_user, labels, jobs_wlc, tasks_wlc, projects_wlc):
        self.user = admin_user
        self.samples = labels
        self.job_samples = jobs_wlc
        self.task_samples = tasks_wlc
        self.project_samples = projects_wlc

    def _get_endpoint(self, api_client: ApiClient) -> Endpoint:
        return api_client.labels_api.list_endpoint

    def _get_field_samples(self, field: str) -> tuple[Any, list[dict[str, Any]]]:
        """
        Return a filter value for ``field`` and the labels expected to match it.

        "parent" and "job_id" are resolved here; other fields go to the base class.
        """
        if field == "parent":
            # The filter takes the parent's name, while samples are found by the parent's id
            parent_id, gt_objects = self._get_field_samples("parent_id")
            parent_name = self._get_field(
                next(
                    filter(
                        lambda p: parent_id == self._get_field(p, self._map_field("id")),
                        self.samples,
                    )
                ),
                self._map_field("name"),
            )
            return parent_name, gt_objects
        elif field == "job_id":
            # The labels expected for a job are those of its task or its project
            field_path = ["id"]
            field_value = self._find_valid_field_value(self.job_samples, field_path)
            job_sample = next(
                filter(lambda p: field_value == self._get_field(p, field_path), self.job_samples)
            )

            task_id = job_sample["task_id"]
            project_id = job_sample["project_id"]
            # Materialized as a list (as annotated), so callers can iterate it more than once
            label_samples = [
                label
                for label in self.samples
                if (task_id and task_id == label.get("task_id"))
                or (project_id and project_id == label.get("project_id"))
            ]
            return field_value, label_samples
        else:
            return super()._get_field_samples(field)

    @pytest.mark.parametrize(
        "field",
        ("name", "job_id", "task_id", "project_id", "type", "color"),
    )
    def test_can_use_simple_filter_for_object_list(self, field):
        return super()._test_can_use_simple_filter_for_object_list(field)

    @pytest.mark.parametrize(
        "key1, key2", itertools.combinations(["job_id", "task_id", "project_id"], 2)
    )
    def test_cant_mix_job_task_project_filters(self, key1, key2):
        samples_by_key = {
            "job_id": self.job_samples,
            "task_id": self.task_samples,
            "project_id": self.project_samples,
        }
        kwargs = {
            key: self._find_valid_field_value(samples_by_key[key], ["id"]) for key in [key1, key2]
        }

        with pytest.raises(exceptions.ApiException) as capture:
            self._retrieve_collection(**kwargs)

        assert capture.value.status == 400
        assert "cannot be used together" in capture.value.body

    @pytest.mark.parametrize("org_id", [None, 2])
    @pytest.mark.parametrize("dst, src", itertools.combinations(["job", "task", "project"], 2))
    def test_can_list_inherited_labels(self, org_id, dst, src):
        """Labels of a project/task are listed when filtering by a nested task/job."""
        kwargs = {}

        if src == "project":
            src_with_labels = next(
                p
                for p in self.project_samples
                if p["labels"]["count"] > 0
                and p["organization"] == org_id
                and p["tasks"]["count"] > 0
            )
        elif src == "task":
            src_with_labels = next(
                t
                for t in self.task_samples
                if t["labels"]["count"] > 0
                and t["organization"] == org_id
                and t["jobs"]["count"] > 0
                and not t.get("project_id")
            )
        else:
            assert False, f"Unexpected label source: {src}"

        labels = [
            label for label in self.samples if label.get(f"{src}_id") == src_with_labels["id"]
        ]

        if dst == "task":
            dst_obj = next(
                t for t in self.task_samples if t.get(f"{src}_id") == src_with_labels["id"]
            )
            kwargs["task_id"] = dst_obj["id"]
        elif dst == "job":
            dst_obj = next(
                j for j in self.job_samples if j.get(f"{src}_id") == src_with_labels["id"]
            )
            kwargs["job_id"] = dst_obj["id"]
        else:
            assert False, f"Unexpected label destination: {dst}"

        if org_id:
            kwargs["org_id"] = org_id

        retrieved_data = self._retrieve_collection(**kwargs)
        self._compare_results(labels, retrieved_data)
@pytest.mark.usefixtures("restore_db_per_class")
class TestListLabels(_TestLabelsPermissionsBase):
    """Checks listing labels of jobs, tasks and projects in an organization."""

    def _test_list_ok(self, user, data, **kwargs):
        """List labels as ``user`` with filters ``kwargs`` and expect exactly ``data``."""
        with make_api_client(user) as client:
            results = get_paginated_collection(
                client.labels_api.list_endpoint, **kwargs, return_json=True
            )
            assert (
                DeepDiff(
                    data,
                    results,
                    exclude_paths="root['updated_date']",
                    ignore_order=True,
                )
                == {}
            )

    def _test_list_denied(self, user, **kwargs):
        """List labels as ``user`` with filters ``kwargs`` and expect 403 Forbidden."""
        with make_api_client(user) as client:
            (_, response) = client.labels_api.list(
                **kwargs, _parse_response=False, _check_status=False
            )
            assert response.status == HTTPStatus.FORBIDDEN

    def _get_org_sources_with_labels(
        self, source_type, org_id, *, labels, jobs_wlc, tasks_wlc, projects_wlc
    ):
        """
        Find the sources of ``source_type`` ("job", "task" or "project") in ``org_id``
        that have labels.

        Returns a tuple of (sources, labels_by_source). The mapping goes from a source id to
        the labels expected in the list response. For a task, these are its own labels, or
        else its project's labels. For a job, they are its task's labels, or else its
        project's labels.
        """
        labels_by_project = self._labels_by_source(labels, source_key="project_id")
        labels_by_task = self._labels_by_source(labels, source_key="task_id")

        if source_type == "project":
            sources = [
                p for p in projects_wlc if p["labels"]["count"] > 0 and p["organization"] == org_id
            ]
            labels_by_source = labels_by_project
        elif source_type == "task":
            sources = [
                t for t in tasks_wlc if t["labels"]["count"] > 0 and t["organization"] == org_id
            ]
            labels_by_source = {
                task["id"]: (
                    labels_by_task.get(task["id"]) or labels_by_project.get(task.get("project_id"))
                )
                for task in sources
            }
        elif source_type == "job":
            # A job's organization is taken from its parent task; build the lookup once
            # instead of scanning all tasks for every job.
            task_org_by_id = {t["id"]: t["organization"] for t in tasks_wlc}
            sources = [
                j
                for j in jobs_wlc
                if j["labels"]["count"] > 0 and task_org_by_id[j["task_id"]] == org_id
            ]
            labels_by_source = {
                job["id"]: (
                    labels_by_task.get(job["task_id"]) or labels_by_project.get(job["project_id"])
                )
                for job in sources
            }
        else:
            assert False, f"Unknown source type: {source_type}"

        return sources, labels_by_source

    @pytest.mark.parametrize("org_id", [2])
    @pytest.mark.parametrize("source_type", ["job", "task", "project"])
    @pytest.mark.parametrize("role", ["worker", "supervisor"])
    @pytest.mark.parametrize("staff", [True, False])
    def test_staff_can_list_labels_in_org(
        self,
        org_id,
        source_type,
        role,
        staff,
        labels,
        jobs_wlc,
        tasks_wlc,
        projects_wlc,
        users,
        is_project_staff,
        is_task_staff,
        is_job_staff,
        memberships,
        users_by_name,
    ):
        sources, labels_by_source = self._get_org_sources_with_labels(
            source_type,
            org_id,
            labels=labels,
            jobs_wlc=jobs_wlc,
            tasks_wlc=tasks_wlc,
            projects_wlc=projects_wlc,
        )
        is_staff = {
            "project": is_project_staff,
            "task": is_task_staff,
            "job": is_job_staff,
        }[source_type]

        staff_by_role = {}
        for m in memberships:
            if m["organization"] == org_id:
                staff_by_role.setdefault(m["role"], []).append(users_by_name[m["user"]["username"]])
        role_members = staff_by_role.get(role, [])

        # Initialized so that an empty source list fails the assertions, not with NameError
        source = None
        user = None
        for source in sources:
            user = next(
                (
                    u
                    for u in users
                    if not u["is_superuser"]
                    if is_staff(u["id"], source["id"]) == staff
                    if u in role_members
                ),
                None,
            )
            if user:
                break

        assert source
        assert user

        expected_labels = labels_by_source[source["id"]]

        kwargs = {
            "org_id": org_id,
            f"{source_type}_id": source["id"],
        }

        if staff:
            self._test_list_ok(user["username"], expected_labels, **kwargs)
        else:
            self._test_list_denied(user["username"], **kwargs)

    @pytest.mark.parametrize("org_id", [2])
    @pytest.mark.parametrize("source_type", ["job", "task", "project"])
    def test_only_1st_level_labels_included(
        self, projects_wlc, tasks_wlc, jobs_wlc, labels, admin_user, source_type, org_id
    ):
        """Listing labels of a source with skeletons must not include the sublabels."""
        sources, labels_by_source = self._get_org_sources_with_labels(
            source_type,
            org_id,
            labels=labels,
            jobs_wlc=jobs_wlc,
            tasks_wlc=tasks_wlc,
            projects_wlc=projects_wlc,
        )

        source = next(
            s for s in sources if any(label["sublabels"] for label in labels_by_source[s["id"]])
        )
        source_labels = labels_by_source[source["id"]]
        assert not any(label["has_parent"] for label in source_labels)

        kwargs = {
            "org_id": org_id,
            f"{source_type}_id": source["id"],
        }

        self._test_list_ok(admin_user, source_labels, **kwargs)
class TestGetLabels(_TestLabelsPermissionsBase):
    """Permission checks for retrieving a single label by id."""

    @pytest.fixture(autouse=True)
    def setup(self, restore_db_per_class, _base_setup):  # pylint: disable=arguments-differ
        """Request the ``restore_db_per_class`` fixture on top of the shared base setup."""

    def _test_get_ok(self, user, lid, data):
        """Fetch label ``lid`` as ``user``; expect 200 and a body equal to ``data``
        (ignoring ``updated_date``)."""
        with make_api_client(user) as client:
            _, response = client.labels_api.retrieve(lid)
            assert response.status == HTTPStatus.OK

            received = json.loads(response.data)
            difference = DeepDiff(
                data, received, exclude_paths="root['updated_date']", ignore_order=True
            )
            assert difference == {}

    def _test_get_denied(self, user, lid):
        """Fetch label ``lid`` as ``user`` and expect 403 Forbidden."""
        with make_api_client(user) as client:
            _, response = client.labels_api.retrieve(
                lid, _parse_response=False, _check_status=False
            )
            assert response.status == HTTPStatus.FORBIDDEN

    def test_admin_get_sandbox_label(self, admin_sandbox_case):
        label, user = get_attrs(admin_sandbox_case, ["label", "user"])
        self._test_get_ok(user, label["id"], label)

    def test_admin_get_org_label(self, admin_org_case):
        label, user = get_attrs(admin_org_case, ["label", "user"])
        self._test_get_ok(user, label["id"], label)

    def _check_regular_user_access(self, case):
        """Expect a successful GET when ``case.is_staff`` is set, otherwise 403."""
        label, user, is_staff = get_attrs(case, ["label", "user", "is_staff"])
        username, label_id = user["username"], label["id"]

        if not is_staff:
            self._test_get_denied(username, label_id)
            return

        self._test_get_ok(username, label_id, label)

    def test_regular_user_get_sandbox_label(self, user_sandbox_case):
        self._check_regular_user_access(user_sandbox_case)

    def test_regular_user_get_org_label(self, user_org_case):
        self._check_regular_user_access(user_org_case)
class TestPatchLabels(_TestLabelsPermissionsBase):
    """Checks partial updates (PATCH) of labels and their permissions."""

    @pytest.fixture(autouse=True)
    def setup(self, restore_db_per_function, _base_setup):  # pylint: disable=arguments-differ
        # Fields that are never compared between the expected and the returned label
        self.ignore_fields = ["updated_date"]

    def _build_exclude_paths_expr(self, ignore_fields=None):
        """Build the DeepDiff exclusion expression for ``ignore_fields``
        (``self.ignore_fields`` when not given)."""
        if ignore_fields is None:
            ignore_fields = self.ignore_fields
        return build_exclude_paths_expr(ignore_fields)

    def _test_update_ok(self, user, lid, data, *, expected_data=None, ignore_fields=None, **kwargs):
        """
        PATCH label ``lid`` with ``data`` as ``user``.

        Expects 200 OK and a response body equal to ``expected_data`` (``data`` when not
        given), ignoring ``ignore_fields``. Returns the raw response.
        """
        with make_api_client(user) as client:
            (_, response) = client.labels_api.partial_update(
                lid, patched_label_request=models.PatchedLabelRequest(**deepcopy(data)), **kwargs
            )
            assert response.status == HTTPStatus.OK
            assert (
                DeepDiff(
                    expected_data if expected_data is not None else data,
                    json.loads(response.data),
                    exclude_regex_paths=self._build_exclude_paths_expr(ignore_fields),
                    ignore_order=True,
                )
                == {}
            )
        return response

    def _test_update_denied(self, user, lid, data, expected_status=HTTPStatus.FORBIDDEN, **kwargs):
        """PATCH label ``lid`` with ``data`` as ``user``, expect ``expected_status`` and
        return the raw response."""
        with make_api_client(user) as client:
            (_, response) = client.labels_api.partial_update(
                lid,
                patched_label_request=models.PatchedLabelRequest(**deepcopy(data)),
                **kwargs,
                _check_status=False,
                _parse_response=False,
            )
            assert response.status == expected_status

        return response

    def _get_patch_data(
        self, original_data: dict[str, Any], **overrides
    ) -> tuple[dict[str, Any], dict[str, Any], list[str]]:
        """
        Build a PATCH payload for a label and the label expected in the response.

        :param original_data: the current label data
        :param overrides: label fields to change; when empty, the whole label is resent
        :returns: (expected response data, request payload, fields to ignore when comparing)
        """
        result = deepcopy(original_data)
        result.update(overrides)

        ignore_fields = self.ignore_fields.copy()
        if overrides:
            payload = deepcopy(overrides)

            if overridden_attributes := deepcopy(overrides.get("attributes", [])):
                combined_attributes = deepcopy(original_data.get("attributes", []))

                # Attributes with an "id" modify the matching existing attribute,
                # attributes without one are appended as new attributes
                mapping = {attr["id"]: attr for attr in overridden_attributes if "id" in attr}

                # No existing attributes are updated, so the ids of the new ones
                # (presumably assigned by the server) cannot be predicted
                if not mapping:
                    ignore_fields.append("attributes.id")

                for attr in combined_attributes:
                    if attr["id"] in mapping:
                        attr.update(mapping[attr["id"]])

                for attr in overridden_attributes:
                    if attr not in combined_attributes:
                        combined_attributes.append(attr)

                payload["attributes"] = deepcopy(combined_attributes)
                result["attributes"] = deepcopy(combined_attributes)

            # Changing a label type to "skeleton" is not supported, the type must stay the same
            if overrides.get("type") == "skeleton":
                result["type"] = original_data["type"]

            # NOTE(review): presumably the server may change the color on rename — confirm
            if "name" in overrides:
                ignore_fields.append("color")
        else:
            payload = deepcopy(original_data)

        return result, payload, ignore_fields

    @parametrize(
        "param, newvalue",
        list(
            itertools.chain.from_iterable(
                itertools.product([k], values)
                for k, values in {
                    "attributes": [
                        [
                            {
                                "default_value": "mazda_new",
                                "input_type": "select",
                                "mutable": True,
                                "name": "model_new",
                                "values": ["mazda_new", "bmw"],
                            }
                        ],
                    ],
                    "color": ["#2000c0"],
                    "name": ["modified"],
                    "type": [
                        "any",
                        "cuboid",
                        "ellipse",
                        "mask",
                        "points",
                        "polygon",
                        "polyline",
                        "rectangle",
                        "skeleton",
                        "tag",
                    ],
                }.items()
            )
        ),
    )
    @parametrize("source", _TestLabelsPermissionsBase.source_types)
    def test_can_patch_label_field(self, source, admin_user, param, newvalue):
        user = admin_user
        label = next(
            iter(
                self._labels_by_source(
                    self.labels, source_key=self._get_source_info(source).label_source_key
                ).values()
            )
        )[0]

        expected_data, patch_data, ignore_fields = self._get_patch_data(label, **{param: newvalue})

        self._test_update_ok(
            user, label["id"], patch_data, expected_data=expected_data, ignore_fields=ignore_fields
        )

    @parametrize("source", _TestLabelsPermissionsBase.source_types)
    def test_can_patch_attribute_name(self, source: str, admin_user: str):
        source_key = self._get_source_info(source).label_source_key
        label = next(
            l
            for l in self.labels
            if l.get(source_key) and not l["has_parent"] and l.get("attributes")
        )

        attributes = deepcopy(label["attributes"])
        for attribute in attributes:
            attribute["name"] += "_updated"

        expected_data, patch_data, ignore_fields = self._get_patch_data(
            label, attributes=attributes
        )

        self._test_update_ok(
            admin_user,
            label["id"],
            patch_data,
            expected_data=expected_data,
            ignore_fields=ignore_fields,
        )

    @parametrize("source", _TestLabelsPermissionsBase.source_types)
    def test_cannot_patch_sublabel_directly(self, admin_user, source):
        user = admin_user
        label = next(
            sublabel
            for source_labels in self._labels_by_source(
                self.labels, source_key=self._get_source_info(source).label_source_key
            ).values()
            for label in source_labels
            for sublabel in label["sublabels"]
        )

        with make_api_client(user) as client:
            (_, response) = client.labels_api.partial_update(
                label["id"],
                patched_label_request=models.PatchedLabelRequest(**label),
                _parse_response=False,
                _check_status=False,
            )

        assert response.status == HTTPStatus.BAD_REQUEST
        assert "Sublabels cannot be modified this way." in response.data.decode()

    @parametrize("user", [fixture_ref("admin_user")])
    @parametrize("source_type", _TestLabelsPermissionsBase.source_types)
    def test_cannot_rename_label_to_duplicate_name(self, source_type, user):
        source_info = self._get_source_info(source_type)
        labels_by_source = self._labels_by_source(
            self.labels, source_key=source_info.label_source_key
        )
        labels = next(ls for ls in labels_by_source.values() if len(ls) >= 2)

        payload = {"name": labels[1]["name"]}

        response = self._test_update_denied(
            user, lid=labels[0]["id"], data=payload, expected_status=HTTPStatus.BAD_REQUEST
        )
        assert "All label names must be unique" in response.data.decode()

    def test_admin_patch_sandbox_label(self, admin_sandbox_case):
        label, user = get_attrs(admin_sandbox_case, ["label", "user"])

        expected_data, patch_data, *_ = self._get_patch_data(label)

        self._test_update_ok(user, label["id"], patch_data, expected_data=expected_data)

    def test_admin_patch_org_label(self, admin_org_case):
        label, user = get_attrs(admin_org_case, ["label", "user"])

        expected_data, patch_data, *_ = self._get_patch_data(label)

        self._test_update_ok(user, label["id"], patch_data, expected_data=expected_data)

    def test_regular_user_patch_sandbox_label(self, user_sandbox_case):
        label, user, is_staff = get_attrs(user_sandbox_case, ["label", "user", "is_staff"])

        expected_data, patch_data, *_ = self._get_patch_data(label)

        if is_staff:
            self._test_update_ok(
                user["username"], label["id"], patch_data, expected_data=expected_data
            )
        else:
            self._test_update_denied(user["username"], label["id"], patch_data)

    def test_regular_user_patch_org_label(self, user_org_case):
        label, user, is_staff = get_attrs(user_org_case, ["label", "user", "is_staff"])

        expected_data, patch_data, *_ = self._get_patch_data(label)

        if is_staff:
            self._test_update_ok(
                user["username"], label["id"], patch_data, expected_data=expected_data
            )
        else:
            self._test_update_denied(user["username"], label["id"], patch_data)
class TestDeleteLabels(_TestLabelsPermissionsBase):
    """Checks deleting labels and the permissions for it."""

    @pytest.fixture(autouse=True)
    def setup(self, restore_db_per_function, _base_setup):  # pylint: disable=arguments-differ
        pass

    def _test_delete_ok(self, user, lid, **kwargs):
        """DELETE label ``lid`` as ``user`` and expect 204 No Content."""
        with make_api_client(user) as client:
            (_, response) = client.labels_api.destroy(lid, **kwargs)
            assert response.status == HTTPStatus.NO_CONTENT

    def _test_delete_denied(self, user, lid, **kwargs):
        """DELETE label ``lid`` as ``user`` and expect 403 Forbidden."""
        with make_api_client(user) as client:
            # This must be a DELETE request: a PATCH would check the update permission instead
            (_, response) = client.labels_api.destroy(
                lid,
                **kwargs,
                _check_status=False,
                _parse_response=False,
            )
            assert response.status == HTTPStatus.FORBIDDEN

    @parametrize("source", _TestLabelsPermissionsBase.source_types)
    def test_can_delete_label(self, admin_user, source):
        user = admin_user
        label = next(
            iter(
                self._labels_by_source(
                    self.labels, source_key=self._get_source_info(source).label_source_key
                ).values()
            )
        )[0]

        with make_api_client(user) as client:
            (_, response) = client.labels_api.destroy(label["id"])
            assert response.status == HTTPStatus.NO_CONTENT

            (_, response) = client.labels_api.retrieve(
                label["id"], _check_status=False, _parse_response=False
            )
            assert response.status == HTTPStatus.NOT_FOUND

    @parametrize("source", _TestLabelsPermissionsBase.source_types)
    def test_cannot_delete_sublabel_directly(self, admin_user, source):
        user = admin_user
        label = next(
            sublabel
            for source_labels in self._labels_by_source(
                self.labels, source_key=self._get_source_info(source).label_source_key
            ).values()
            for label in source_labels
            for sublabel in label["sublabels"]
        )

        with make_api_client(user) as client:
            (_, response) = client.labels_api.destroy(
                label["id"], _check_status=False, _parse_response=False
            )

        assert response.status == HTTPStatus.BAD_REQUEST
        assert "Sublabels cannot be deleted this way." in response.data.decode()

    def test_admin_delete_sandbox_label(self, admin_sandbox_case):
        label, user = get_attrs(admin_sandbox_case, ["label", "user"])

        self._test_delete_ok(user, label["id"])

    def test_admin_delete_org_label(self, admin_org_case):
        label, user = get_attrs(admin_org_case, ["label", "user"])

        self._test_delete_ok(user, label["id"])

    def test_regular_user_delete_sandbox_label(self, user_sandbox_case):
        label, user, is_staff = get_attrs(user_sandbox_case, ["label", "user", "is_staff"])

        if is_staff:
            self._test_delete_ok(user["username"], label["id"])
        else:
            self._test_delete_denied(user["username"], label["id"])

    def test_regular_user_delete_org_label(self, user_org_case):
        label, user, is_staff = get_attrs(user_org_case, ["label", "user", "is_staff"])

        if is_staff:
            self._test_delete_ok(user["username"], label["id"])
        else:
            self._test_delete_denied(user["username"], label["id"])
@pytest.mark.usefixtures("restore_db_per_function")
class TestLabelUpdates:
    """Checks that label changes propagate the update time to nested tasks and jobs."""

    @staticmethod
    def _apply_label_update(
        user: str, owner_endpoint: str, owner_labels: list[dict], update_kind: str
    ) -> dict:
        """
        Change the labels of a project or task and return the owner object afterwards.

        :param owner_endpoint: REST path of the label owner, e.g. "projects/1"
        :param owner_labels: the current labels of the owner
        :param update_kind: "addition", "modification" or "removal"
        """
        if update_kind == "addition":
            response = patch_method(user, owner_endpoint, {"labels": [{"name": "dog2"}]})
            # Check the status explicitly, so a failed request is not reported
            # as a confusing timestamp mismatch later
            assert response.status_code == HTTPStatus.OK, response.text
            return response.json()
        elif update_kind == "modification":
            label = owner_labels[0]
            response = patch_method(
                user, f'labels/{label["id"]}', {"name": label["name"] + "-updated"}
            )
            assert response.status_code == HTTPStatus.OK, response.text
        elif update_kind == "removal":
            label = owner_labels[0]
            response = delete_method(user, f'labels/{label["id"]}')
            assert response.status_code == HTTPStatus.NO_CONTENT, response.text
        else:
            assert False, f"Unknown update kind: {update_kind}"

        response = get_method(user, owner_endpoint)
        assert response.status_code == HTTPStatus.OK, response.text
        return response.json()

    @pytest.mark.parametrize("update_kind", ["addition", "removal", "modification"])
    def test_project_label_update_triggers_nested_task_and_job_update(
        self, update_kind, admin_user, labels, projects_wlc, tasks, jobs
    ):
        # Checks for regressions against the issue https://github.com/cvat-ai/cvat/issues/6871

        project = next(p for p in projects_wlc if p["tasks"]["count"] and p["labels"]["count"])
        project_labels = [label for label in labels if label.get("project_id") == project["id"]]
        nested_tasks = [t for t in tasks if t["project_id"] == project["id"]]
        nested_task_ids = {t["id"] for t in nested_tasks}
        nested_jobs = [j for j in jobs if j["task_id"] in nested_task_ids]

        updated_project = self._apply_label_update(
            admin_user, f'projects/{project["id"]}', project_labels, update_kind
        )

        with make_api_client(admin_user) as api_client:
            updated_tasks = get_paginated_collection(
                api_client.tasks_api.list_endpoint, project_id=project["id"], return_json=True
            )

            updated_jobs = [
                j
                for j in get_paginated_collection(
                    api_client.jobs_api.list_endpoint, return_json=True
                )
                if j["task_id"] in nested_task_ids
            ]

        assert parse_datetime(project["updated_date"]) < parse_datetime(
            updated_project["updated_date"]
        )
        assert len(updated_tasks) == len(nested_tasks)
        assert len(updated_jobs) == len(nested_jobs)
        for entity in updated_tasks + updated_jobs:
            assert updated_project["updated_date"] == entity["updated_date"]

    @pytest.mark.parametrize("update_kind", ["addition", "removal", "modification"])
    def test_task_label_update_triggers_nested_task_and_job_update(
        self, update_kind, admin_user, labels, tasks_wlc, jobs
    ):
        # Checks for regressions against the issue https://github.com/cvat-ai/cvat/issues/6871

        task = next(
            t
            for t in tasks_wlc
            if t["jobs"]["count"] and t["labels"]["count"] and not t["project_id"]
        )
        task_labels = [label for label in labels if label.get("task_id") == task["id"]]
        nested_jobs = [j for j in jobs if j["task_id"] == task["id"]]

        updated_task = self._apply_label_update(
            admin_user, f'tasks/{task["id"]}', task_labels, update_kind
        )

        with make_api_client(admin_user) as api_client:
            updated_jobs = get_paginated_collection(
                api_client.jobs_api.list_endpoint, task_id=task["id"], return_json=True
            )

        assert parse_datetime(task["updated_date"]) < parse_datetime(updated_task["updated_date"])
        assert len(updated_jobs) == len(nested_jobs)
        for job in updated_jobs:
            assert updated_task["updated_date"] == job["updated_date"]