2257 lines
86 KiB
Python
2257 lines
86 KiB
Python
|
|
# Copyright (C) CVAT.ai Corporation
|
||
|
|
#
|
||
|
|
# SPDX-License-Identifier: MIT
|
||
|
|
|
||
|
|
import json
|
||
|
|
import math
|
||
|
|
from collections.abc import Collection, Iterable
|
||
|
|
from copy import deepcopy
|
||
|
|
from functools import partial
|
||
|
|
from http import HTTPStatus
|
||
|
|
from itertools import groupby, product
|
||
|
|
from typing import Any, Callable, Optional
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from cvat_sdk.api_client import exceptions, models
|
||
|
|
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
|
||
|
|
from cvat_sdk.core.helpers import get_paginated_collection
|
||
|
|
from deepdiff import DeepDiff
|
||
|
|
|
||
|
|
from shared.tasks.utils import parse_frame_step
|
||
|
|
from shared.utils.config import make_api_client
|
||
|
|
|
||
|
|
from .utils import (
|
||
|
|
CollectionSimpleFilterTestBase,
|
||
|
|
invite_user_to_org,
|
||
|
|
register_new_user,
|
||
|
|
wait_background_request,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class _PermissionTestBase:
|
||
|
|
def create_quality_report(
|
||
|
|
self, *, user: str, task_id: Optional[int] = None, project_id: Optional[int] = None
|
||
|
|
) -> dict:
|
||
|
|
assert task_id is not None or project_id is not None
|
||
|
|
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.create_report(
|
||
|
|
quality_report_create_request=models.QualityReportCreateRequest(
|
||
|
|
**{"task_id": task_id} if task_id else {},
|
||
|
|
**{"project_id": project_id} if project_id else {},
|
||
|
|
),
|
||
|
|
_parse_response=False,
|
||
|
|
)
|
||
|
|
assert response.status == HTTPStatus.ACCEPTED
|
||
|
|
rq_id = json.loads(response.data)["rq_id"]
|
||
|
|
|
||
|
|
background_request, _ = wait_background_request(api_client, rq_id)
|
||
|
|
assert (
|
||
|
|
background_request.status.value
|
||
|
|
== models.RequestStatus.allowed_values[("value",)]["FINISHED"]
|
||
|
|
)
|
||
|
|
report_id = background_request.result_id
|
||
|
|
|
||
|
|
_, response = api_client.quality_api.retrieve_report(report_id, _parse_response=False)
|
||
|
|
|
||
|
|
return json.loads(response.data)
|
||
|
|
|
||
|
|
def create_gt_job(self, user: str, task_id: int, *, complete: bool = True) -> models.IJobRead:
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(meta, _) = api_client.tasks_api.retrieve_data_meta(task_id)
|
||
|
|
start_frame = meta.start_frame
|
||
|
|
|
||
|
|
(job, _) = api_client.jobs_api.create(
|
||
|
|
models.JobWriteRequest(
|
||
|
|
type="ground_truth",
|
||
|
|
task_id=task_id,
|
||
|
|
frame_selection_method="manual",
|
||
|
|
frames=[start_frame],
|
||
|
|
)
|
||
|
|
)
|
||
|
|
|
||
|
|
if complete:
|
||
|
|
(labels, _) = api_client.labels_api.list(
|
||
|
|
**({"project_id": job.project_id} if job.project_id else {"task_id": task_id})
|
||
|
|
)
|
||
|
|
|
||
|
|
api_client.jobs_api.update_annotations(
|
||
|
|
job.id,
|
||
|
|
labeled_data_request=dict(
|
||
|
|
shapes=[
|
||
|
|
dict(
|
||
|
|
frame=start_frame,
|
||
|
|
label_id=labels.results[0].id,
|
||
|
|
type="rectangle",
|
||
|
|
points=[1, 1, 2, 2],
|
||
|
|
),
|
||
|
|
],
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
api_client.jobs_api.partial_update(
|
||
|
|
job.id,
|
||
|
|
patched_job_write_request={
|
||
|
|
"stage": "acceptance",
|
||
|
|
"state": "completed",
|
||
|
|
},
|
||
|
|
)
|
||
|
|
|
||
|
|
return job
|
||
|
|
|
||
|
|
@pytest.fixture(scope="class")
|
||
|
|
def find_sandbox_task(self, tasks, jobs, users, is_task_staff):
|
||
|
|
def _find(
|
||
|
|
is_staff: bool, *, has_gt_jobs: Optional[bool] = None
|
||
|
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||
|
|
task = next(
|
||
|
|
t
|
||
|
|
for t in tasks
|
||
|
|
if t["organization"] is None
|
||
|
|
and not users[t["owner"]["id"]]["is_superuser"]
|
||
|
|
and (
|
||
|
|
has_gt_jobs is None
|
||
|
|
or has_gt_jobs
|
||
|
|
== any(
|
||
|
|
j for j in jobs if j["task_id"] == t["id"] and j["type"] == "ground_truth"
|
||
|
|
)
|
||
|
|
)
|
||
|
|
)
|
||
|
|
|
||
|
|
if is_staff:
|
||
|
|
user = task["owner"]
|
||
|
|
else:
|
||
|
|
user = next(u for u in users if not is_task_staff(u["id"], task["id"]))
|
||
|
|
|
||
|
|
return task, user
|
||
|
|
|
||
|
|
return _find
|
||
|
|
|
||
|
|
@pytest.fixture(scope="class")
|
||
|
|
def find_sandbox_task_without_gt(self, find_sandbox_task):
|
||
|
|
return partial(find_sandbox_task, has_gt_jobs=False)
|
||
|
|
|
||
|
|
@pytest.fixture(scope="class")
|
||
|
|
def find_org_task(self, tasks, jobs, users, is_org_member, is_task_staff):
|
||
|
|
def _find(
|
||
|
|
is_staff: bool, user_org_role: str, *, has_gt_jobs: Optional[bool] = None
|
||
|
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||
|
|
for user in users:
|
||
|
|
if user["is_superuser"]:
|
||
|
|
continue
|
||
|
|
|
||
|
|
task = next(
|
||
|
|
(
|
||
|
|
t
|
||
|
|
for t in tasks
|
||
|
|
if t["organization"] is not None
|
||
|
|
and is_task_staff(user["id"], t["id"]) == is_staff
|
||
|
|
and is_org_member(user["id"], t["organization"], role=user_org_role)
|
||
|
|
and (
|
||
|
|
has_gt_jobs is None
|
||
|
|
or has_gt_jobs
|
||
|
|
== any(
|
||
|
|
j
|
||
|
|
for j in jobs
|
||
|
|
if j["task_id"] == t["id"] and j["type"] == "ground_truth"
|
||
|
|
)
|
||
|
|
)
|
||
|
|
),
|
||
|
|
None,
|
||
|
|
)
|
||
|
|
if task is not None:
|
||
|
|
break
|
||
|
|
|
||
|
|
assert task
|
||
|
|
|
||
|
|
return task, user
|
||
|
|
|
||
|
|
return _find
|
||
|
|
|
||
|
|
@pytest.fixture(scope="class")
|
||
|
|
def find_org_task_without_gt(self, find_org_task):
|
||
|
|
return partial(find_org_task, has_gt_jobs=False)
|
||
|
|
|
||
|
|
@pytest.fixture(scope="class")
|
||
|
|
def find_sandbox_project(self, projects, tasks, users, labels, is_project_staff):
|
||
|
|
def _find(
|
||
|
|
is_staff: bool, has_gt_jobs: Optional[bool] = False
|
||
|
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||
|
|
project = next(
|
||
|
|
p
|
||
|
|
for p in projects
|
||
|
|
if p["organization"] is None
|
||
|
|
if not users[p["owner"]["id"]]["is_superuser"]
|
||
|
|
if any(l for l in labels if l.get("project_id") == p["id"])
|
||
|
|
if any(t for t in tasks if t["project_id"] == p["id"] and t["size"])
|
||
|
|
if (
|
||
|
|
has_gt_jobs is None
|
||
|
|
or has_gt_jobs
|
||
|
|
== any(t["validation_mode"] for t in tasks if t["project_id"] == p["id"])
|
||
|
|
)
|
||
|
|
)
|
||
|
|
|
||
|
|
if is_staff:
|
||
|
|
user = project["owner"]
|
||
|
|
else:
|
||
|
|
user = next(u for u in users if not is_project_staff(u["id"], project["id"]))
|
||
|
|
|
||
|
|
return project, user
|
||
|
|
|
||
|
|
return _find
|
||
|
|
|
||
|
|
@pytest.fixture(scope="class")
|
||
|
|
def find_sandbox_project_without_validation(self, find_sandbox_project):
|
||
|
|
return partial(find_sandbox_project, has_gt_jobs=False)
|
||
|
|
|
||
|
|
@pytest.fixture
|
||
|
|
def find_org_project(
|
||
|
|
self,
|
||
|
|
restore_db_per_function,
|
||
|
|
admin_user,
|
||
|
|
projects,
|
||
|
|
tasks,
|
||
|
|
users,
|
||
|
|
labels,
|
||
|
|
is_org_member,
|
||
|
|
is_project_staff,
|
||
|
|
):
|
||
|
|
def _find(
|
||
|
|
is_staff: bool, user_org_role: str, has_gt_jobs: Optional[bool] = False
|
||
|
|
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||
|
|
project = None
|
||
|
|
|
||
|
|
for user in users:
|
||
|
|
if user["is_superuser"]:
|
||
|
|
continue
|
||
|
|
|
||
|
|
project = next(
|
||
|
|
(
|
||
|
|
p
|
||
|
|
for p in projects
|
||
|
|
if p["organization"] is not None
|
||
|
|
if any(l for l in labels if l.get("project_id") == p["id"])
|
||
|
|
if (
|
||
|
|
has_gt_jobs is None
|
||
|
|
or has_gt_jobs
|
||
|
|
== any(
|
||
|
|
t["validation_mode"] for t in tasks if t["project_id"] == p["id"]
|
||
|
|
)
|
||
|
|
)
|
||
|
|
if any(t for t in tasks if t["project_id"] == p["id"] and t["size"])
|
||
|
|
if is_project_staff(user["id"], p["id"]) == is_staff
|
||
|
|
if is_org_member(user["id"], p["organization"], role=user_org_role)
|
||
|
|
),
|
||
|
|
None,
|
||
|
|
)
|
||
|
|
if project is not None:
|
||
|
|
break
|
||
|
|
|
||
|
|
if not project:
|
||
|
|
project = next(
|
||
|
|
p
|
||
|
|
for p in projects
|
||
|
|
if p["organization"] is not None
|
||
|
|
if any(l for l in labels if l.get("project_id") == p["id"])
|
||
|
|
if any(t for t in tasks if t["project_id"] == p["id"] and t["size"])
|
||
|
|
if (
|
||
|
|
has_gt_jobs is None
|
||
|
|
or has_gt_jobs
|
||
|
|
== any(t["validation_mode"] for t in tasks if t["project_id"] == p["id"])
|
||
|
|
)
|
||
|
|
)
|
||
|
|
user = next(
|
||
|
|
u
|
||
|
|
for u in users
|
||
|
|
if is_org_member(u["id"], project["organization"], role=user_org_role)
|
||
|
|
)
|
||
|
|
|
||
|
|
if is_staff:
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
api_client.projects_api.partial_update(
|
||
|
|
project["id"],
|
||
|
|
patched_project_write_request=models.PatchedProjectWriteRequest(
|
||
|
|
assignee_id=user["id"]
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
return project, user
|
||
|
|
|
||
|
|
return _find
|
||
|
|
|
||
|
|
@pytest.fixture
|
||
|
|
def find_org_project_without_validation(self, find_org_project):
|
||
|
|
return partial(find_org_project, has_gt_jobs=False)
|
||
|
|
|
||
|
|
_default_sandbox_cases = ("is_staff, allow", [(True, True), (False, False)])
|
||
|
|
|
||
|
|
_default_org_cases = (
|
||
|
|
"org_role, is_staff, allow",
|
||
|
|
[
|
||
|
|
("owner", True, True),
|
||
|
|
("owner", False, True),
|
||
|
|
("maintainer", True, True),
|
||
|
|
("maintainer", False, True),
|
||
|
|
("supervisor", True, True),
|
||
|
|
("supervisor", False, False),
|
||
|
|
("worker", True, True),
|
||
|
|
("worker", False, False),
|
||
|
|
],
|
||
|
|
)
|
||
|
|
_default_org_roles = ("owner", "maintainer", "supervisor", "worker")
|
||
|
|
|
||
|
|
key_field_for_target = {
|
||
|
|
"project": "project_id",
|
||
|
|
"task": "task_id",
|
||
|
|
"job": "job_id",
|
||
|
|
}
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_class")
|
||
|
|
class TestListQualityReports(_PermissionTestBase):
|
||
|
|
def _test_list_reports_200(self, user, *, expected_data=None, **kwargs):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
results = get_paginated_collection(
|
||
|
|
api_client.quality_api.list_reports_endpoint,
|
||
|
|
return_json=True,
|
||
|
|
**kwargs,
|
||
|
|
)
|
||
|
|
|
||
|
|
if expected_data is not None:
|
||
|
|
assert DeepDiff(expected_data, results) == {}
|
||
|
|
|
||
|
|
def _test_list_reports_403(self, user, **kwargs):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.list_reports(
|
||
|
|
**kwargs, _parse_response=False, _check_status=False
|
||
|
|
)
|
||
|
|
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
def test_can_list_quality_reports(self, admin_user, quality_reports):
|
||
|
|
reports = sorted(quality_reports, key=lambda r: -r["id"])
|
||
|
|
|
||
|
|
self._test_list_reports_200(admin_user, sort="-id", expected_data=reports)
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
@pytest.mark.parametrize("target", ["project", "task", "job"])
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_sandbox_cases)
|
||
|
|
def test_user_list_reports_in_sandbox(
|
||
|
|
self,
|
||
|
|
is_staff,
|
||
|
|
allow,
|
||
|
|
admin_user,
|
||
|
|
tasks,
|
||
|
|
find_sandbox_task_without_gt,
|
||
|
|
find_sandbox_project_without_validation,
|
||
|
|
target,
|
||
|
|
):
|
||
|
|
if target == "project":
|
||
|
|
project, user = find_sandbox_project_without_validation(is_staff)
|
||
|
|
task = next(t for t in tasks if t["project_id"] == project["id"] and t["size"])
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, project_id=project["id"])
|
||
|
|
target_id = project["id"]
|
||
|
|
else:
|
||
|
|
task, user = find_sandbox_task_without_gt(is_staff)
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task["id"])
|
||
|
|
target_id = task["id"]
|
||
|
|
|
||
|
|
if target == "job":
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
report = json.loads(
|
||
|
|
api_client.quality_api.list_reports(target="job", parent_id=report["id"])[
|
||
|
|
1
|
||
|
|
].data
|
||
|
|
)["results"][0]
|
||
|
|
target_id = report["job_id"]
|
||
|
|
|
||
|
|
list_kwargs = {
|
||
|
|
"user": user["username"],
|
||
|
|
"target": target,
|
||
|
|
self.key_field_for_target[target]: target_id,
|
||
|
|
}
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_list_reports_200(expected_data=[report], **list_kwargs)
|
||
|
|
else:
|
||
|
|
self._test_list_reports_403(**list_kwargs)
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
@pytest.mark.parametrize("target", ["project", "task", "job"])
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_org_cases)
|
||
|
|
def test_user_list_reports_in_org(
|
||
|
|
self,
|
||
|
|
find_org_task_without_gt,
|
||
|
|
find_org_project_without_validation,
|
||
|
|
tasks,
|
||
|
|
org_role,
|
||
|
|
is_staff,
|
||
|
|
allow,
|
||
|
|
admin_user,
|
||
|
|
target,
|
||
|
|
):
|
||
|
|
if target == "project":
|
||
|
|
project, user = find_org_project_without_validation(is_staff, org_role)
|
||
|
|
task = next(t for t in tasks if t["project_id"] == project["id"] and t["size"])
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, project_id=project["id"])
|
||
|
|
target_id = project["id"]
|
||
|
|
else:
|
||
|
|
task, user = find_org_task_without_gt(is_staff, org_role)
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task["id"])
|
||
|
|
target_id = task["id"]
|
||
|
|
|
||
|
|
if target == "job":
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
report = json.loads(
|
||
|
|
api_client.quality_api.list_reports(target="job", parent_id=report["id"])[
|
||
|
|
1
|
||
|
|
].data
|
||
|
|
)["results"][0]
|
||
|
|
target_id = report["job_id"]
|
||
|
|
|
||
|
|
list_kwargs = {
|
||
|
|
"user": user["username"],
|
||
|
|
"target": target,
|
||
|
|
self.key_field_for_target[target]: target_id,
|
||
|
|
}
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_list_reports_200(expected_data=[report], **list_kwargs)
|
||
|
|
else:
|
||
|
|
self._test_list_reports_403(**list_kwargs)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_class")
|
||
|
|
class TestGetQualityReports(_PermissionTestBase):
|
||
|
|
def _test_get_report_200(
|
||
|
|
self, user: str, obj_id: int, *, expected_data: Optional[dict[str, Any]] = None, **kwargs
|
||
|
|
):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report(obj_id, **kwargs)
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
|
||
|
|
if expected_data is not None:
|
||
|
|
assert DeepDiff(expected_data, json.loads(response.data), ignore_order=True) == {}
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
def _test_get_report_403(self, user: str, obj_id: int, **kwargs):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report(
|
||
|
|
obj_id, **kwargs, _parse_response=False, _check_status=False
|
||
|
|
)
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_sandbox_cases)
|
||
|
|
def test_user_get_report_in_sandbox_task(
|
||
|
|
self, is_staff, allow, admin_user, find_sandbox_task_without_gt
|
||
|
|
):
|
||
|
|
task, user = find_sandbox_task_without_gt(is_staff)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task["id"])
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_get_report_200(user["username"], report["id"], expected_data=report)
|
||
|
|
else:
|
||
|
|
self._test_get_report_403(user["username"], report["id"])
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_org_cases)
|
||
|
|
def test_user_get_report_in_org_task(
|
||
|
|
self,
|
||
|
|
find_org_task_without_gt,
|
||
|
|
org_role,
|
||
|
|
is_staff,
|
||
|
|
allow,
|
||
|
|
admin_user,
|
||
|
|
):
|
||
|
|
task, user = find_org_task_without_gt(is_staff, org_role)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task["id"])
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_get_report_200(user["username"], report["id"], expected_data=report)
|
||
|
|
else:
|
||
|
|
self._test_get_report_403(user["username"], report["id"])
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_class")
|
||
|
|
class TestGetQualityReportData(_PermissionTestBase):
|
||
|
|
def _test_get_report_data_200(
|
||
|
|
self, user: str, obj_id: int, *, expected_data: Optional[dict[str, Any]] = None, **kwargs
|
||
|
|
):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report_data(obj_id, **kwargs)
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
|
||
|
|
if expected_data is not None:
|
||
|
|
assert DeepDiff(expected_data, json.loads(response.data), ignore_order=True) == {}
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
def _test_get_report_data_403(self, user: str, obj_id: int, **kwargs):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report_data(
|
||
|
|
obj_id, **kwargs, _parse_response=False, _check_status=False
|
||
|
|
)
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("target", ["project", "task", "job"])
|
||
|
|
def test_can_get_full_report_data(self, admin_user, target, quality_reports):
|
||
|
|
report = next(
|
||
|
|
r for r in quality_reports if r[self.key_field_for_target[target]] is not None
|
||
|
|
)
|
||
|
|
report_id = report["id"]
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
(report_data, response) = api_client.quality_api.retrieve_report_data(report_id)
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
|
||
|
|
# Just check several keys exist
|
||
|
|
for key in ["parameters", "comparison_summary"] + (
|
||
|
|
["frame_results"] if target != "project" else []
|
||
|
|
):
|
||
|
|
assert key in report_data.keys(), key
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_sandbox_cases)
|
||
|
|
def test_user_get_report_data_in_sandbox_task(
|
||
|
|
self, is_staff, allow, admin_user, find_sandbox_task_without_gt
|
||
|
|
):
|
||
|
|
task, user = find_sandbox_task_without_gt(is_staff)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task["id"])
|
||
|
|
report_data = json.loads(self._test_get_report_data_200(admin_user, report["id"]).data)
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_get_report_data_200(
|
||
|
|
user["username"], report["id"], expected_data=report_data
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
self._test_get_report_data_403(user["username"], report["id"])
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_org_cases)
|
||
|
|
def test_user_get_report_data_in_org_task(
|
||
|
|
self,
|
||
|
|
find_org_task_without_gt,
|
||
|
|
org_role,
|
||
|
|
is_staff,
|
||
|
|
allow,
|
||
|
|
admin_user,
|
||
|
|
):
|
||
|
|
task, user = find_org_task_without_gt(is_staff, org_role)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task["id"])
|
||
|
|
report_data = json.loads(self._test_get_report_data_200(admin_user, report["id"]).data)
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_get_report_data_200(
|
||
|
|
user["username"], report["id"], expected_data=report_data
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
self._test_get_report_data_403(user["username"], report["id"])
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
@pytest.mark.parametrize("has_assignee", [False, True])
|
||
|
|
def test_can_get_report_data_with_job_assignees(
|
||
|
|
self, admin_user, jobs, users_by_name, has_assignee
|
||
|
|
):
|
||
|
|
gt_job = next(
|
||
|
|
j
|
||
|
|
for j in jobs
|
||
|
|
if j["type"] == "ground_truth"
|
||
|
|
and j["stage"] == "acceptance"
|
||
|
|
and j["state"] == "completed"
|
||
|
|
)
|
||
|
|
task_id = gt_job["task_id"]
|
||
|
|
|
||
|
|
normal_job = next(j for j in jobs if j["type"] == "annotation" and j["task_id"] == task_id)
|
||
|
|
if has_assignee:
|
||
|
|
new_assignee = users_by_name[admin_user]
|
||
|
|
else:
|
||
|
|
new_assignee = None
|
||
|
|
|
||
|
|
if bool(normal_job["assignee"]) != has_assignee:
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
api_client.jobs_api.partial_update(
|
||
|
|
normal_job["id"],
|
||
|
|
patched_job_write_request={
|
||
|
|
"assignee": new_assignee["id"] if new_assignee else None
|
||
|
|
},
|
||
|
|
)
|
||
|
|
|
||
|
|
task_report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
job_report = api_client.quality_api.list_reports(
|
||
|
|
job_id=normal_job["id"], parent_id=task_report["id"]
|
||
|
|
)[0].results[0]
|
||
|
|
|
||
|
|
report_data = json.loads(self._test_get_report_data_200(admin_user, job_report["id"]).data)
|
||
|
|
assert (
|
||
|
|
DeepDiff(
|
||
|
|
(
|
||
|
|
{
|
||
|
|
k: v
|
||
|
|
for k, v in new_assignee.items()
|
||
|
|
if k in ["id", "username", "first_name", "last_name"]
|
||
|
|
}
|
||
|
|
if new_assignee
|
||
|
|
else None
|
||
|
|
),
|
||
|
|
report_data["assignee"],
|
||
|
|
)
|
||
|
|
== {}
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
class TestPostQualityReports(_PermissionTestBase):
|
||
|
|
def test_can_create_report(self, admin_user, jobs):
|
||
|
|
gt_job = next(
|
||
|
|
j
|
||
|
|
for j in jobs
|
||
|
|
if j["type"] == "ground_truth"
|
||
|
|
and j["stage"] == "acceptance"
|
||
|
|
and j["state"] == "completed"
|
||
|
|
)
|
||
|
|
task_id = gt_job["task_id"]
|
||
|
|
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
assert models.QualityReport._from_openapi_data(**report)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("has_assignee", [False, True])
|
||
|
|
def test_can_create_report_with_job_assignees(
|
||
|
|
self, admin_user, jobs, users_by_name, has_assignee
|
||
|
|
):
|
||
|
|
gt_job = next(
|
||
|
|
j
|
||
|
|
for j in jobs
|
||
|
|
if j["type"] == "ground_truth"
|
||
|
|
and j["stage"] == "acceptance"
|
||
|
|
and j["state"] == "completed"
|
||
|
|
)
|
||
|
|
task_id = gt_job["task_id"]
|
||
|
|
|
||
|
|
normal_job = next(j for j in jobs if j["type"] == "annotation")
|
||
|
|
if bool(normal_job["assignee"]) != has_assignee:
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
api_client.jobs_api.partial_update(
|
||
|
|
normal_job["id"],
|
||
|
|
patched_job_write_request={
|
||
|
|
"assignee": users_by_name[admin_user]["id"] if has_assignee else None
|
||
|
|
},
|
||
|
|
)
|
||
|
|
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
assert models.QualityReport._from_openapi_data(**report)
|
||
|
|
|
||
|
|
def test_cannot_create_report_without_gt_job(self, admin_user, tasks):
|
||
|
|
task_id = next(t["id"] for t in tasks if t["jobs"]["count"] == 1)
|
||
|
|
|
||
|
|
with pytest.raises(exceptions.ApiException) as capture:
|
||
|
|
self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
assert (
|
||
|
|
"Quality reports require a Ground Truth job in the task at the acceptance "
|
||
|
|
"stage and in the completed state"
|
||
|
|
) in capture.value.body
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(
|
||
|
|
"field_name, field_value",
|
||
|
|
[
|
||
|
|
("stage", "annotation"),
|
||
|
|
("stage", "validation"),
|
||
|
|
("state", "new"),
|
||
|
|
("state", "in progress"),
|
||
|
|
("state", "rejected"),
|
||
|
|
],
|
||
|
|
)
|
||
|
|
def test_cannot_create_report_with_incomplete_gt_job(
|
||
|
|
self, admin_user, jobs, field_name, field_value
|
||
|
|
):
|
||
|
|
gt_job = next(
|
||
|
|
j
|
||
|
|
for j in jobs
|
||
|
|
if j["type"] == "ground_truth"
|
||
|
|
and j["stage"] == "acceptance"
|
||
|
|
and j["state"] == "completed"
|
||
|
|
)
|
||
|
|
task_id = gt_job["task_id"]
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
api_client.jobs_api.partial_update(
|
||
|
|
gt_job["id"], patched_job_write_request={field_name: field_value}
|
||
|
|
)
|
||
|
|
|
||
|
|
with pytest.raises(exceptions.ApiException) as capture:
|
||
|
|
self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
assert (
|
||
|
|
"Quality reports require a Ground Truth job in the task at the acceptance "
|
||
|
|
"stage and in the completed state"
|
||
|
|
) in capture.value.body
|
||
|
|
|
||
|
|
def _test_create_report_200(self, user: str, task_id: int):
|
||
|
|
return self.create_quality_report(user=user, task_id=task_id)
|
||
|
|
|
||
|
|
def _test_create_report_403(self, user: str, task_id: int):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.create_report(
|
||
|
|
quality_report_create_request=models.QualityReportCreateRequest(task_id=task_id),
|
||
|
|
_parse_response=False,
|
||
|
|
_check_status=False,
|
||
|
|
)
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_sandbox_cases)
|
||
|
|
def test_user_create_report_in_sandbox_task(
|
||
|
|
self, is_staff, allow, admin_user, find_sandbox_task_without_gt
|
||
|
|
):
|
||
|
|
task, user = find_sandbox_task_without_gt(is_staff)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_create_report_200(user["username"], task["id"])
|
||
|
|
else:
|
||
|
|
self._test_create_report_403(user["username"], task["id"])
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_org_cases)
|
||
|
|
def test_user_create_report_in_org_task(
|
||
|
|
self,
|
||
|
|
find_org_task_without_gt,
|
||
|
|
org_role,
|
||
|
|
is_staff,
|
||
|
|
allow,
|
||
|
|
admin_user,
|
||
|
|
):
|
||
|
|
task, user = find_org_task_without_gt(is_staff, org_role)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_create_report_200(user["username"], task["id"])
|
||
|
|
else:
|
||
|
|
self._test_create_report_403(user["username"], task["id"])
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _initialize_report_creation(task_id: int, user: str) -> str:
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.create_report(
|
||
|
|
quality_report_create_request=models.QualityReportCreateRequest(task_id=task_id),
|
||
|
|
_parse_response=False,
|
||
|
|
)
|
||
|
|
rq_id = json.loads(response.data)["rq_id"]
|
||
|
|
assert rq_id
|
||
|
|
|
||
|
|
return rq_id
|
||
|
|
|
||
|
|
# users with task:view rights can check status of report creation
|
||
|
|
def _test_check_status_of_report_creation(
|
||
|
|
self,
|
||
|
|
rq_id: str,
|
||
|
|
*,
|
||
|
|
task_staff: str,
|
||
|
|
another_user: str,
|
||
|
|
another_user_status: int = HTTPStatus.FORBIDDEN,
|
||
|
|
):
|
||
|
|
with make_api_client(another_user) as api_client:
|
||
|
|
(_, response) = api_client.requests_api.retrieve(
|
||
|
|
rq_id, _parse_response=False, _check_status=False
|
||
|
|
)
|
||
|
|
assert response.status == another_user_status
|
||
|
|
|
||
|
|
with make_api_client(task_staff) as api_client:
|
||
|
|
wait_background_request(api_client, rq_id)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(
|
||
|
|
"role",
|
||
|
|
# owner and maintainer have rights even without being assigned to a task
|
||
|
|
("supervisor", "worker"),
|
||
|
|
)
|
||
|
|
def test_task_assignee_can_check_status_of_report_creation_in_org(
|
||
|
|
self,
|
||
|
|
find_org_task_without_gt: Callable[[bool, str], tuple[dict[str, Any], dict[str, Any]]],
|
||
|
|
role: str,
|
||
|
|
admin_user: str,
|
||
|
|
):
|
||
|
|
task, another_user = find_org_task_without_gt(is_staff=False, user_org_role=role)
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
|
||
|
|
task_owner = task["owner"]
|
||
|
|
|
||
|
|
rq_id = self._initialize_report_creation(task_id=task["id"], user=task_owner["username"])
|
||
|
|
self._test_check_status_of_report_creation(
|
||
|
|
rq_id,
|
||
|
|
task_staff=task_owner["username"],
|
||
|
|
another_user=another_user["username"],
|
||
|
|
)
|
||
|
|
|
||
|
|
with make_api_client(task_owner["username"]) as api_client:
|
||
|
|
api_client.tasks_api.partial_update(
|
||
|
|
task["id"],
|
||
|
|
patched_task_write_request=models.PatchedTaskWriteRequest(
|
||
|
|
assignee_id=another_user["id"]
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
self._test_check_status_of_report_creation(
|
||
|
|
rq_id,
|
||
|
|
task_staff=task_owner["username"],
|
||
|
|
another_user=another_user["username"],
|
||
|
|
another_user_status=HTTPStatus.OK,
|
||
|
|
)
|
||
|
|
|
||
|
|
def test_user_without_rights_cannot_check_status_of_report_creation_in_sandbox(
|
||
|
|
self,
|
||
|
|
find_sandbox_task_without_gt: Callable[[bool], tuple[dict[str, Any], dict[str, Any]]],
|
||
|
|
admin_user: str,
|
||
|
|
users: Iterable,
|
||
|
|
):
|
||
|
|
task, task_staff = find_sandbox_task_without_gt(is_staff=True)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
|
||
|
|
another_user = next(
|
||
|
|
u
|
||
|
|
for u in users
|
||
|
|
if (
|
||
|
|
u["id"] != task_staff["id"]
|
||
|
|
and not u["is_superuser"]
|
||
|
|
and u["id"] != task["owner"]["id"]
|
||
|
|
and u["id"] != (task["assignee"] or {}).get("id")
|
||
|
|
)
|
||
|
|
)
|
||
|
|
rq_id = self._initialize_report_creation(task["id"], task_staff["username"])
|
||
|
|
self._test_check_status_of_report_creation(
|
||
|
|
rq_id, task_staff=task_staff["username"], another_user=another_user["username"]
|
||
|
|
)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(
|
||
|
|
"same_org, role",
|
||
|
|
[
|
||
|
|
pair
|
||
|
|
for pair in product([True, False], _PermissionTestBase._default_org_roles)
|
||
|
|
if not (pair[0] and pair[1] in ["owner", "maintainer"])
|
||
|
|
],
|
||
|
|
)
|
||
|
|
def test_user_without_rights_cannot_check_status_of_report_creation_in_org(
|
||
|
|
self,
|
||
|
|
same_org: bool,
|
||
|
|
role: str,
|
||
|
|
admin_user: str,
|
||
|
|
find_org_task_without_gt: Callable[[bool, str], tuple[dict[str, Any], dict[str, Any]]],
|
||
|
|
organizations,
|
||
|
|
):
|
||
|
|
task, task_staff = find_org_task_without_gt(is_staff=True, user_org_role="supervisor")
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
|
||
|
|
# create another user that passes the requirements
|
||
|
|
another_user = register_new_user(f"{same_org}{role}")
|
||
|
|
org_id = (
|
||
|
|
task["organization"]
|
||
|
|
if same_org
|
||
|
|
else next(o for o in organizations if o["id"] != task["organization"])["id"]
|
||
|
|
)
|
||
|
|
invite_user_to_org(another_user["email"], org_id, role)
|
||
|
|
|
||
|
|
rq_id = self._initialize_report_creation(task["id"], task_staff["username"])
|
||
|
|
self._test_check_status_of_report_creation(
|
||
|
|
rq_id, task_staff=task_staff["username"], another_user=another_user["username"]
|
||
|
|
)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("is_sandbox", (True, False))
|
||
|
|
def test_admin_can_check_status_of_report_creation(
|
||
|
|
self,
|
||
|
|
is_sandbox: bool,
|
||
|
|
users: Iterable,
|
||
|
|
admin_user: str,
|
||
|
|
find_org_task_without_gt: Callable[[bool, str], tuple[dict[str, Any], dict[str, Any]]],
|
||
|
|
find_sandbox_task_without_gt: Callable[[bool], tuple[dict[str, Any], dict[str, Any]]],
|
||
|
|
):
|
||
|
|
if is_sandbox:
|
||
|
|
task, task_staff = find_sandbox_task_without_gt(is_staff=True)
|
||
|
|
else:
|
||
|
|
task, task_staff = find_org_task_without_gt(is_staff=True, user_org_role="owner")
|
||
|
|
|
||
|
|
admin = next(
|
||
|
|
u
|
||
|
|
for u in users
|
||
|
|
if (
|
||
|
|
u["is_superuser"]
|
||
|
|
and u["id"] != task_staff["id"]
|
||
|
|
and u["id"] != task["owner"]["id"]
|
||
|
|
and u["id"] != (task["assignee"] or {}).get("id")
|
||
|
|
)
|
||
|
|
)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
|
||
|
|
rq_id = self._initialize_report_creation(task["id"], task_staff["username"])
|
||
|
|
|
||
|
|
with make_api_client(admin["username"]) as api_client:
|
||
|
|
wait_background_request(api_client, rq_id)
|
||
|
|
|
||
|
|
|
||
|
|
class TestSimpleQualityReportsFilters(CollectionSimpleFilterTestBase):
|
||
|
|
@pytest.fixture(autouse=True)
|
||
|
|
def setup(self, restore_db_per_class, admin_user, quality_reports, jobs, tasks, projects):
|
||
|
|
self.user = admin_user
|
||
|
|
self.samples = quality_reports
|
||
|
|
self.job_samples = jobs
|
||
|
|
self.task_samples = tasks
|
||
|
|
self.project_samples = projects
|
||
|
|
|
||
|
|
def _get_endpoint(self, api_client: ApiClient) -> Endpoint:
|
||
|
|
return api_client.quality_api.list_reports_endpoint
|
||
|
|
|
||
|
|
def _get_field_samples(self, field: str) -> tuple[Any, list[dict[str, Any]]]:
|
||
|
|
def _get_job_reports(task_ids: Collection[int]) -> list[dict[str, Any]]:
|
||
|
|
job_ids = set(j["id"] for j in self.job_samples if j["task_id"] in task_ids)
|
||
|
|
job_reports = [
|
||
|
|
r for r in self.samples if self._get_field(r, self._map_field("job_id")) in job_ids
|
||
|
|
]
|
||
|
|
return job_reports
|
||
|
|
|
||
|
|
if field == "project_id":
|
||
|
|
# This filter includes both the project, task and nested job reports
|
||
|
|
project_id, project_reports = super()._get_field_samples(field)
|
||
|
|
return project_id, list(project_reports) + _get_job_reports(
|
||
|
|
[r["task_id"] for r in project_reports]
|
||
|
|
)
|
||
|
|
elif field == "task_id":
|
||
|
|
# This filter includes both the task and nested job reports
|
||
|
|
task_id, task_reports = super()._get_field_samples(field)
|
||
|
|
task_reports = list(task_reports) + _get_job_reports([task_id])
|
||
|
|
return task_id, task_reports
|
||
|
|
elif field == "org_id":
|
||
|
|
org_id = self.task_samples[
|
||
|
|
next(
|
||
|
|
s
|
||
|
|
for s in self.samples
|
||
|
|
if s["task_id"] and self.task_samples[s["task_id"]]["organization"]
|
||
|
|
)["task_id"]
|
||
|
|
]["organization"]
|
||
|
|
return org_id, [
|
||
|
|
s
|
||
|
|
for s in self.samples
|
||
|
|
if s["job_id"]
|
||
|
|
and self.job_samples[s["job_id"]]["organization"] == org_id
|
||
|
|
or s["task_id"]
|
||
|
|
and self.task_samples[s["task_id"]]["organization"] == org_id
|
||
|
|
or s["project_id"]
|
||
|
|
and self.project_samples[s["project_id"]]["organization"] == org_id
|
||
|
|
]
|
||
|
|
else:
|
||
|
|
return super()._get_field_samples(field)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(
|
||
|
|
"field",
|
||
|
|
("project_id", "task_id", "job_id", "parent_id", "target", "org_id"),
|
||
|
|
)
|
||
|
|
def test_can_use_simple_filter_for_object_list(self, field):
|
||
|
|
return super()._test_can_use_simple_filter_for_object_list(field)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_class")
|
||
|
|
class TestListQualityConflicts(_PermissionTestBase):
|
||
|
|
def _test_list_conflicts_200(self, user, report_id, *, expected_data=None, **kwargs):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
results = get_paginated_collection(
|
||
|
|
api_client.quality_api.list_conflicts_endpoint,
|
||
|
|
return_json=True,
|
||
|
|
report_id=report_id,
|
||
|
|
**kwargs,
|
||
|
|
)
|
||
|
|
|
||
|
|
if expected_data is not None:
|
||
|
|
assert DeepDiff(expected_data, results) == {}
|
||
|
|
|
||
|
|
return results
|
||
|
|
|
||
|
|
def _test_list_conflicts_403(self, user, report_id, **kwargs):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.list_conflicts(
|
||
|
|
report_id=report_id, **kwargs, _parse_response=False, _check_status=False
|
||
|
|
)
|
||
|
|
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
def test_can_list_job_report_conflicts(self, admin_user, quality_reports, quality_conflicts):
|
||
|
|
report = next(r for r in quality_reports if r["job_id"])
|
||
|
|
conflicts = [c for c in quality_conflicts if c["report_id"] == report["id"]]
|
||
|
|
|
||
|
|
self._test_list_conflicts_200(admin_user, report["id"], expected_data=conflicts)
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_sandbox_cases)
|
||
|
|
def test_user_list_conflicts_in_sandbox_task(
|
||
|
|
self, is_staff, allow, admin_user, find_sandbox_task_without_gt
|
||
|
|
):
|
||
|
|
task, user = find_sandbox_task_without_gt(is_staff)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task["id"])
|
||
|
|
conflicts = self._test_list_conflicts_200(admin_user, report_id=report["id"])
|
||
|
|
assert conflicts
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_list_conflicts_200(user["username"], report["id"], expected_data=conflicts)
|
||
|
|
else:
|
||
|
|
self._test_list_conflicts_403(user["username"], report["id"])
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_org_cases)
|
||
|
|
def test_user_list_conflicts_in_org_task(
|
||
|
|
self,
|
||
|
|
find_org_task_without_gt,
|
||
|
|
org_role,
|
||
|
|
is_staff,
|
||
|
|
allow,
|
||
|
|
admin_user,
|
||
|
|
):
|
||
|
|
task, user = find_org_task_without_gt(is_staff, org_role)
|
||
|
|
user = user["username"]
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task["id"])
|
||
|
|
conflicts = self._test_list_conflicts_200(admin_user, report_id=report["id"])
|
||
|
|
assert conflicts
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_list_conflicts_200(user, report["id"], expected_data=conflicts)
|
||
|
|
else:
|
||
|
|
self._test_list_conflicts_403(user, report["id"])
|
||
|
|
|
||
|
|
|
||
|
|
class TestSimpleQualityConflictsFilters(CollectionSimpleFilterTestBase):
|
||
|
|
@pytest.fixture(autouse=True)
|
||
|
|
def setup(
|
||
|
|
self,
|
||
|
|
restore_db_per_class,
|
||
|
|
admin_user,
|
||
|
|
quality_conflicts,
|
||
|
|
quality_reports,
|
||
|
|
jobs,
|
||
|
|
tasks,
|
||
|
|
projects,
|
||
|
|
):
|
||
|
|
self.user = admin_user
|
||
|
|
self.samples = quality_conflicts
|
||
|
|
self.report_samples = quality_reports
|
||
|
|
self.job_samples = jobs
|
||
|
|
self.task_samples = tasks
|
||
|
|
self.project_samples = projects
|
||
|
|
|
||
|
|
def _get_endpoint(self, api_client: ApiClient) -> Endpoint:
|
||
|
|
return api_client.quality_api.list_conflicts_endpoint
|
||
|
|
|
||
|
|
def _get_field_samples(self, field: str) -> tuple[Any, list[dict[str, Any]]]:
|
||
|
|
def _get_job_reports(task_ids: Collection[int]) -> list[dict[str, Any]]:
|
||
|
|
job_ids = set(j["id"] for j in self.job_samples if j["task_id"] in task_ids)
|
||
|
|
job_reports = [
|
||
|
|
r
|
||
|
|
for r in self.report_samples
|
||
|
|
if self._get_field(r, self._map_field("job_id")) in job_ids
|
||
|
|
]
|
||
|
|
return job_reports
|
||
|
|
|
||
|
|
if field == "job_id":
|
||
|
|
# This field is not included in the response
|
||
|
|
job_id = self._find_valid_field_value(self.report_samples, field_path=["job_id"])
|
||
|
|
job_reports = set(r["id"] for r in self.report_samples if r["job_id"] == job_id)
|
||
|
|
job_conflicts = [
|
||
|
|
c
|
||
|
|
for c in self.samples
|
||
|
|
if self._get_field(c, self._map_field("report_id")) in job_reports
|
||
|
|
]
|
||
|
|
return job_id, job_conflicts
|
||
|
|
elif field == "task_id":
|
||
|
|
# This field is not included in the response
|
||
|
|
task_id = self._find_valid_field_value(self.report_samples, field_path=["task_id"])
|
||
|
|
task_reports = [r for r in self.report_samples if r["task_id"] == task_id]
|
||
|
|
task_report_ids = {r["id"] for r in task_reports}
|
||
|
|
task_report_ids |= {r["id"] for r in _get_job_reports([task_id])}
|
||
|
|
task_conflicts = [
|
||
|
|
c
|
||
|
|
for c in self.samples
|
||
|
|
if self._get_field(c, self._map_field("report_id")) in task_report_ids
|
||
|
|
]
|
||
|
|
return task_id, task_conflicts
|
||
|
|
elif field == "project_id":
|
||
|
|
# This field is not included in the response
|
||
|
|
project_id = self._find_valid_field_value(
|
||
|
|
self.report_samples, field_path=["project_id"]
|
||
|
|
)
|
||
|
|
project_reports = [r for r in self.report_samples if r["project_id"] == project_id]
|
||
|
|
project_report_ids = {r["id"] for r in project_reports}
|
||
|
|
project_report_ids |= {
|
||
|
|
r["id"] for r in _get_job_reports([r["task_id"] for r in project_reports])
|
||
|
|
}
|
||
|
|
project_conflicts = [
|
||
|
|
c
|
||
|
|
for c in self.samples
|
||
|
|
if self._get_field(c, self._map_field("report_id")) in project_report_ids
|
||
|
|
]
|
||
|
|
return project_id, project_conflicts
|
||
|
|
elif field == "org_id":
|
||
|
|
org_id = self.task_samples[
|
||
|
|
next(
|
||
|
|
s
|
||
|
|
for s in self.report_samples
|
||
|
|
if s["task_id"] and self.task_samples[s["task_id"]]["organization"]
|
||
|
|
)["task_id"]
|
||
|
|
]["organization"]
|
||
|
|
report_ids = set(
|
||
|
|
s["id"]
|
||
|
|
for s in self.report_samples
|
||
|
|
if s["job_id"]
|
||
|
|
and self.job_samples[s["job_id"]]["organization"] == org_id
|
||
|
|
or s["task_id"]
|
||
|
|
and self.task_samples[s["task_id"]]["organization"] == org_id
|
||
|
|
)
|
||
|
|
return org_id, [c for c in self.samples if c["report_id"] in report_ids]
|
||
|
|
else:
|
||
|
|
return super()._get_field_samples(field)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(
|
||
|
|
"field",
|
||
|
|
("report_id", "severity", "type", "frame", "job_id", "task_id", "project_id", "org_id"),
|
||
|
|
)
|
||
|
|
def test_can_use_simple_filter_for_object_list(self, field):
|
||
|
|
return super()._test_can_use_simple_filter_for_object_list(field)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("filter_name", ["project_id", "task_id", "job_id"])
|
||
|
|
def test_cannot_use_object_id_filters_without_permissions(
|
||
|
|
self, is_project_staff, is_task_staff, is_job_staff, users, filter_name
|
||
|
|
):
|
||
|
|
# Find a project where the user doesn't have permissions
|
||
|
|
non_admin_user = next(
|
||
|
|
u["username"] for u in users if not u["is_superuser"] and u["username"] != self.user
|
||
|
|
)
|
||
|
|
|
||
|
|
if filter_name == "project_id":
|
||
|
|
samples = self.project_samples
|
||
|
|
is_staff = is_project_staff
|
||
|
|
elif filter_name == "task_id":
|
||
|
|
samples = self.task_samples
|
||
|
|
is_staff = is_task_staff
|
||
|
|
elif filter_name == "job_id":
|
||
|
|
samples = self.job_samples
|
||
|
|
is_staff = is_job_staff
|
||
|
|
else:
|
||
|
|
assert False
|
||
|
|
|
||
|
|
obj_id = next(obj["id"] for obj in samples if not is_staff(non_admin_user, obj["id"]))
|
||
|
|
|
||
|
|
with make_api_client(non_admin_user) as api_client:
|
||
|
|
response = api_client.quality_api.list_reports(
|
||
|
|
**{filter_name: obj_id}, _parse_response=False, _check_status=False
|
||
|
|
)[1]
|
||
|
|
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
|
||
|
|
class TestSimpleQualitySettingsFilters(CollectionSimpleFilterTestBase):
|
||
|
|
@pytest.fixture(autouse=True)
|
||
|
|
def setup(self, restore_db_per_class, admin_user, quality_settings, tasks, projects):
|
||
|
|
self.user = admin_user
|
||
|
|
self.samples = quality_settings
|
||
|
|
self.task_samples = tasks
|
||
|
|
self.project_samples = projects
|
||
|
|
|
||
|
|
def _get_endpoint(self, api_client: ApiClient) -> Endpoint:
|
||
|
|
return api_client.quality_api.list_settings_endpoint
|
||
|
|
|
||
|
|
def _get_field_samples(self, field):
|
||
|
|
if field == "parent_type":
|
||
|
|
# This field is not included in the response
|
||
|
|
parent_type = "project"
|
||
|
|
parent_type_reports = [s for s in self.samples if s["project_id"]]
|
||
|
|
return parent_type, parent_type_reports
|
||
|
|
elif field == "project_id":
|
||
|
|
# Nested task settings are also included
|
||
|
|
project_id = self._find_valid_field_value(self.samples, field_path=["project_id"])
|
||
|
|
project_task_ids = set(
|
||
|
|
t["id"] for t in self.task_samples if t["project_id"] == project_id
|
||
|
|
)
|
||
|
|
return project_id, [
|
||
|
|
s
|
||
|
|
for s in self.samples
|
||
|
|
if s["project_id"] == project_id or s["task_id"] in project_task_ids
|
||
|
|
]
|
||
|
|
elif field == "org_id":
|
||
|
|
# This field is not included in the response
|
||
|
|
org_id = self.task_samples[
|
||
|
|
next(
|
||
|
|
s
|
||
|
|
for s in self.samples
|
||
|
|
if s["task_id"] and self.task_samples[s["task_id"]]["organization"]
|
||
|
|
)["task_id"]
|
||
|
|
]["organization"]
|
||
|
|
return org_id, [
|
||
|
|
s
|
||
|
|
for s in self.samples
|
||
|
|
if s["task_id"]
|
||
|
|
and self.task_samples[s["task_id"]]["organization"] == org_id
|
||
|
|
or s["project_id"]
|
||
|
|
and self.project_samples[s["project_id"]]["organization"] == org_id
|
||
|
|
]
|
||
|
|
else:
|
||
|
|
return super()._get_field_samples(field)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(
|
||
|
|
"field",
|
||
|
|
(
|
||
|
|
"task_id",
|
||
|
|
"project_id",
|
||
|
|
"parent_type",
|
||
|
|
"inherit",
|
||
|
|
"org_id",
|
||
|
|
),
|
||
|
|
)
|
||
|
|
def test_can_use_simple_filter_for_object_list(self, field):
|
||
|
|
return super()._test_can_use_simple_filter_for_object_list(field)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("filter_name", ["project_id", "task_id"])
|
||
|
|
def test_cannot_use_object_id_filters_without_permissions(
|
||
|
|
self, is_project_staff, is_task_staff, projects, tasks, users, filter_name
|
||
|
|
):
|
||
|
|
# Find a project where the user doesn't have permissions
|
||
|
|
non_admin_user = next(
|
||
|
|
u["username"] for u in users if not u["is_superuser"] and u["username"] != self.user
|
||
|
|
)
|
||
|
|
|
||
|
|
if filter_name == "project_id":
|
||
|
|
samples = projects
|
||
|
|
is_staff = is_project_staff
|
||
|
|
elif filter_name == "task_id":
|
||
|
|
samples = tasks
|
||
|
|
is_staff = is_task_staff
|
||
|
|
else:
|
||
|
|
assert False
|
||
|
|
|
||
|
|
obj_id = next(obj["id"] for obj in samples if not is_staff(non_admin_user, obj["id"]))
|
||
|
|
|
||
|
|
with make_api_client(non_admin_user) as api_client:
|
||
|
|
response = api_client.quality_api.list_reports(
|
||
|
|
**{filter_name: obj_id}, _parse_response=False, _check_status=False
|
||
|
|
)[1]
|
||
|
|
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_class")
|
||
|
|
class TestListSettings(_PermissionTestBase):
|
||
|
|
def _test_list_settings_200(
|
||
|
|
self, user: str, task_id: int, *, expected_data: Optional[dict[str, Any]] = None, **kwargs
|
||
|
|
):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
actual = get_paginated_collection(
|
||
|
|
api_client.quality_api.list_settings_endpoint,
|
||
|
|
task_id=task_id,
|
||
|
|
**kwargs,
|
||
|
|
return_json=True,
|
||
|
|
)
|
||
|
|
|
||
|
|
if expected_data is not None:
|
||
|
|
assert DeepDiff(expected_data, actual, ignore_order=True) == {}
|
||
|
|
|
||
|
|
def _test_list_settings_403(self, user: str, task_id: int, **kwargs):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.list_settings(
|
||
|
|
task_id=task_id, **kwargs, _parse_response=False, _check_status=False
|
||
|
|
)
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_sandbox_cases)
|
||
|
|
def test_user_list_settings_in_sandbox(
|
||
|
|
self, quality_settings, find_sandbox_task, is_staff, allow
|
||
|
|
):
|
||
|
|
task, user = find_sandbox_task(is_staff)
|
||
|
|
|
||
|
|
settings = [s for s in quality_settings if s["task_id"] == task["id"]]
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_list_settings_200(
|
||
|
|
user["username"], task_id=task["id"], expected_data=settings
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
self._test_list_settings_403(user["username"], task_id=task["id"])
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_org_cases)
|
||
|
|
def test_user_list_settings_in_org_task(
|
||
|
|
self,
|
||
|
|
find_org_task,
|
||
|
|
org_role,
|
||
|
|
is_staff,
|
||
|
|
allow,
|
||
|
|
quality_settings,
|
||
|
|
):
|
||
|
|
task, user = find_org_task(is_staff, org_role)
|
||
|
|
|
||
|
|
settings = [s for s in quality_settings if s["task_id"] == task["id"]]
|
||
|
|
org_id = task["organization"]
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_list_settings_200(
|
||
|
|
user["username"], task_id=task["id"], expected_data=settings, org_id=org_id
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
self._test_list_settings_403(user["username"], task_id=task["id"], org_id=org_id)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_class")
|
||
|
|
class TestGetSettings(_PermissionTestBase):
|
||
|
|
def _test_get_settings_200(
|
||
|
|
self, user: str, obj_id: int, *, expected_data: Optional[dict[str, Any]] = None, **kwargs
|
||
|
|
):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.retrieve_settings(obj_id, **kwargs)
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
|
||
|
|
if expected_data is not None:
|
||
|
|
assert DeepDiff(expected_data, json.loads(response.data), ignore_order=True) == {}
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
def _test_get_settings_403(self, user: str, obj_id: int, **kwargs):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.retrieve_settings(
|
||
|
|
obj_id, **kwargs, _parse_response=False, _check_status=False
|
||
|
|
)
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
def test_can_get_settings(self, admin_user, quality_settings):
|
||
|
|
settings = next(iter(quality_settings))
|
||
|
|
settings_id = settings["id"]
|
||
|
|
self._test_get_settings_200(admin_user, settings_id, expected_data=settings)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_sandbox_cases)
|
||
|
|
def test_user_get_settings_in_sandbox_task(
|
||
|
|
self, quality_settings, find_sandbox_task, is_staff, allow
|
||
|
|
):
|
||
|
|
task, user = find_sandbox_task(is_staff)
|
||
|
|
|
||
|
|
settings = next(s for s in quality_settings if s["task_id"] == task["id"])
|
||
|
|
settings_id = settings["id"]
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_get_settings_200(user["username"], settings_id, expected_data=settings)
|
||
|
|
else:
|
||
|
|
self._test_get_settings_403(user["username"], settings_id)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_org_cases)
|
||
|
|
def test_user_get_settings_in_org_task(
|
||
|
|
self,
|
||
|
|
find_org_task,
|
||
|
|
org_role,
|
||
|
|
is_staff,
|
||
|
|
allow,
|
||
|
|
quality_settings,
|
||
|
|
):
|
||
|
|
task, user = find_org_task(is_staff, org_role)
|
||
|
|
|
||
|
|
settings = next(s for s in quality_settings if s["task_id"] == task["id"])
|
||
|
|
settings_id = settings["id"]
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_get_settings_200(user["username"], settings_id, expected_data=settings)
|
||
|
|
else:
|
||
|
|
self._test_get_settings_403(user["username"], settings_id)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
class TestPatchSettings(_PermissionTestBase):
|
||
|
|
def _test_patch_settings_200(
|
||
|
|
self,
|
||
|
|
user: str,
|
||
|
|
obj_id: int,
|
||
|
|
data: dict[str, Any],
|
||
|
|
*,
|
||
|
|
expected_data: Optional[dict[str, Any]] = None,
|
||
|
|
**kwargs,
|
||
|
|
):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.partial_update_settings(
|
||
|
|
obj_id, patched_quality_settings_request=data, **kwargs
|
||
|
|
)
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
|
||
|
|
if expected_data is not None:
|
||
|
|
assert (
|
||
|
|
DeepDiff(
|
||
|
|
expected_data,
|
||
|
|
json.loads(response.data),
|
||
|
|
exclude_paths=["root['updated_date']"],
|
||
|
|
ignore_order=True,
|
||
|
|
)
|
||
|
|
== {}
|
||
|
|
)
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
def _test_patch_settings_403(self, user: str, obj_id: int, data: dict[str, Any], **kwargs):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.partial_update_settings(
|
||
|
|
obj_id,
|
||
|
|
patched_quality_settings_request=data,
|
||
|
|
**kwargs,
|
||
|
|
_parse_response=False,
|
||
|
|
_check_status=False,
|
||
|
|
)
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
def _get_request_data(self, data: dict[str, Any]) -> tuple[dict[str, Any], dict[str, Any]]:
|
||
|
|
patched_data = deepcopy(data)
|
||
|
|
|
||
|
|
for field, value in data.items():
|
||
|
|
if isinstance(value, bool):
|
||
|
|
patched_data[field] = not value
|
||
|
|
elif isinstance(value, float):
|
||
|
|
patched_data[field] = 1 - value
|
||
|
|
|
||
|
|
expected_data = deepcopy(patched_data)
|
||
|
|
|
||
|
|
return patched_data, expected_data
|
||
|
|
|
||
|
|
def test_can_patch_settings(self, admin_user, quality_settings):
|
||
|
|
settings = next(iter(quality_settings))
|
||
|
|
settings_id = settings["id"]
|
||
|
|
data, expected_data = self._get_request_data(settings)
|
||
|
|
self._test_patch_settings_200(admin_user, settings_id, data, expected_data=expected_data)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_sandbox_cases)
|
||
|
|
def test_user_patch_settings_in_sandbox_task(
|
||
|
|
self, quality_settings, find_sandbox_task, is_staff, allow
|
||
|
|
):
|
||
|
|
task, user = find_sandbox_task(is_staff)
|
||
|
|
|
||
|
|
settings = next(s for s in quality_settings if s["task_id"] == task["id"])
|
||
|
|
settings_id = settings["id"]
|
||
|
|
data, expected_data = self._get_request_data(settings)
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_patch_settings_200(
|
||
|
|
user["username"], settings_id, data, expected_data=expected_data
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
self._test_patch_settings_403(user["username"], settings_id, data)
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_org_cases)
|
||
|
|
def test_user_patch_settings_in_org_task(
|
||
|
|
self,
|
||
|
|
find_org_task,
|
||
|
|
org_role,
|
||
|
|
is_staff,
|
||
|
|
allow,
|
||
|
|
quality_settings,
|
||
|
|
):
|
||
|
|
task, user = find_org_task(is_staff, org_role)
|
||
|
|
|
||
|
|
settings = next(s for s in quality_settings if s["task_id"] == task["id"])
|
||
|
|
settings_id = settings["id"]
|
||
|
|
data, expected_data = self._get_request_data(settings)
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_patch_settings_200(
|
||
|
|
user["username"], settings_id, data, expected_data=expected_data
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
self._test_patch_settings_403(user["username"], settings_id, data)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
class TestQualityReportMetrics(_PermissionTestBase):
|
||
|
|
demo_task_id = 22 # this task reproduces all the checkable cases
|
||
|
|
demo_task_id_multiple_jobs = 23 # this task reproduces cases for multiple jobs
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("task_id", [demo_task_id])
|
||
|
|
def test_report_summary(self, task_id, tasks, jobs, quality_reports):
|
||
|
|
gt_job = next(j for j in jobs if j["task_id"] == task_id and j["type"] == "ground_truth")
|
||
|
|
task = tasks[task_id]
|
||
|
|
report = next(r for r in quality_reports if r["task_id"] == task_id)
|
||
|
|
|
||
|
|
summary = report["summary"]
|
||
|
|
assert 0 < summary["conflict_count"]
|
||
|
|
assert all(summary["conflicts_by_type"].values())
|
||
|
|
assert summary["conflict_count"] == sum(summary["conflicts_by_type"].values())
|
||
|
|
assert summary["conflict_count"] == summary["warning_count"] + summary["error_count"]
|
||
|
|
assert 0 < summary["valid_count"]
|
||
|
|
assert summary["valid_count"] < summary["ds_count"]
|
||
|
|
assert summary["valid_count"] < summary["gt_count"]
|
||
|
|
assert summary["frame_count"] == gt_job["frame_count"]
|
||
|
|
assert summary["frame_share"] == summary["frame_count"] / task["size"]
|
||
|
|
|
||
|
|
def test_unmodified_task_produces_the_same_metrics(self, admin_user, quality_reports):
|
||
|
|
old_report = max(
|
||
|
|
(
|
||
|
|
r
|
||
|
|
for r in quality_reports
|
||
|
|
if r["task_id"] == self.demo_task_id
|
||
|
|
if r["target"] == "task"
|
||
|
|
),
|
||
|
|
key=lambda r: r["id"],
|
||
|
|
)
|
||
|
|
task_id = old_report["task_id"]
|
||
|
|
|
||
|
|
new_report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
(old_report_data, _) = api_client.quality_api.retrieve_report_data(old_report["id"])
|
||
|
|
(new_report_data, _) = api_client.quality_api.retrieve_report_data(new_report["id"])
|
||
|
|
|
||
|
|
assert (
|
||
|
|
DeepDiff(
|
||
|
|
new_report,
|
||
|
|
old_report,
|
||
|
|
ignore_order=True,
|
||
|
|
exclude_paths=["root['created_date']", "root['id']"],
|
||
|
|
)
|
||
|
|
== {}
|
||
|
|
)
|
||
|
|
assert (
|
||
|
|
DeepDiff(
|
||
|
|
new_report_data,
|
||
|
|
old_report_data,
|
||
|
|
ignore_order=True,
|
||
|
|
exclude_paths=[
|
||
|
|
"root['created_date']",
|
||
|
|
"root['id']",
|
||
|
|
"root['parameters']['included_annotation_types']",
|
||
|
|
],
|
||
|
|
)
|
||
|
|
== {}
|
||
|
|
)
|
||
|
|
|
||
|
|
def test_modified_task_produces_different_metrics(
|
||
|
|
self, admin_user, quality_reports, jobs, labels
|
||
|
|
):
|
||
|
|
gt_job = next(
|
||
|
|
j for j in jobs if j["type"] == "ground_truth" and j["task_id"] == self.demo_task_id
|
||
|
|
)
|
||
|
|
task_id = gt_job["task_id"]
|
||
|
|
old_report = max(
|
||
|
|
(r for r in quality_reports if r["task_id"] == task_id), key=lambda r: r["id"]
|
||
|
|
)
|
||
|
|
job_labels = [
|
||
|
|
l
|
||
|
|
for l in labels
|
||
|
|
if l.get("task_id") == task_id
|
||
|
|
or gt_job.get("project_id")
|
||
|
|
and l.get("project_id") == gt_job.get("project_id")
|
||
|
|
if not l["parent_id"]
|
||
|
|
]
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
api_client.jobs_api.partial_update_annotations(
|
||
|
|
"create",
|
||
|
|
gt_job["id"],
|
||
|
|
patched_labeled_data_request=dict(
|
||
|
|
shapes=[
|
||
|
|
dict(
|
||
|
|
frame=gt_job["start_frame"],
|
||
|
|
label_id=job_labels[0]["id"],
|
||
|
|
type="rectangle",
|
||
|
|
points=[1, 1, 2, 2],
|
||
|
|
),
|
||
|
|
],
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
new_report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
assert new_report["summary"]["conflict_count"] > old_report["summary"]["conflict_count"]
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("task_id", [demo_task_id])
|
||
|
|
@pytest.mark.parametrize(
|
||
|
|
"parameter",
|
||
|
|
[
|
||
|
|
"check_covered_annotations",
|
||
|
|
"compare_attributes",
|
||
|
|
"compare_groups",
|
||
|
|
"group_match_threshold",
|
||
|
|
"iou_threshold",
|
||
|
|
"line_orientation_threshold",
|
||
|
|
"line_thickness",
|
||
|
|
"low_overlap_threshold",
|
||
|
|
"object_visibility_threshold",
|
||
|
|
"oks_sigma",
|
||
|
|
"compare_line_orientation",
|
||
|
|
"panoptic_comparison",
|
||
|
|
"point_size_base",
|
||
|
|
"empty_is_annotated",
|
||
|
|
],
|
||
|
|
)
|
||
|
|
def test_settings_affect_metrics(
|
||
|
|
self, admin_user, quality_reports, quality_settings, task_id, parameter
|
||
|
|
):
|
||
|
|
old_report = max(
|
||
|
|
(r for r in quality_reports if r["task_id"] == task_id), key=lambda r: r["id"]
|
||
|
|
)
|
||
|
|
task_id = old_report["task_id"]
|
||
|
|
|
||
|
|
settings = deepcopy(next(s for s in quality_settings if s["task_id"] == task_id))
|
||
|
|
if isinstance(settings[parameter], bool):
|
||
|
|
settings[parameter] = not settings[parameter]
|
||
|
|
elif isinstance(settings[parameter], float):
|
||
|
|
settings[parameter] = 1 - settings[parameter]
|
||
|
|
if parameter == "group_match_threshold":
|
||
|
|
settings[parameter] = 0.9
|
||
|
|
elif parameter == "point_size_base":
|
||
|
|
settings[parameter] = next(
|
||
|
|
v
|
||
|
|
for v in models.QualityPointSizeBase.allowed_values[("value",)].values()
|
||
|
|
if v != settings[parameter]
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
assert False
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
api_client.quality_api.partial_update_settings(
|
||
|
|
settings["id"], patched_quality_settings_request=settings
|
||
|
|
)
|
||
|
|
|
||
|
|
new_report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
if parameter == "empty_is_annotated":
|
||
|
|
assert new_report["summary"]["valid_count"] != old_report["summary"]["valid_count"]
|
||
|
|
assert new_report["summary"]["total_count"] != old_report["summary"]["total_count"]
|
||
|
|
assert new_report["summary"]["ds_count"] != old_report["summary"]["ds_count"]
|
||
|
|
assert new_report["summary"]["gt_count"] != old_report["summary"]["gt_count"]
|
||
|
|
else:
|
||
|
|
assert (
|
||
|
|
new_report["summary"]["conflict_count"] != old_report["summary"]["conflict_count"]
|
||
|
|
)
|
||
|
|
|
||
|
|
def test_old_report_can_be_loaded(self, admin_user, quality_reports):
|
||
|
|
report = min((r for r in quality_reports if r["task_id"]), key=lambda r: r["id"])
|
||
|
|
assert report["created_date"] < "2024"
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
(report_data, _) = api_client.quality_api.retrieve_report_data(report["id"])
|
||
|
|
|
||
|
|
# This report should have been created before the Jaccard index was included.
|
||
|
|
for d in [report_data["comparison_summary"], *report_data["frame_results"].values()]:
|
||
|
|
assert d["annotations"]["confusion_matrix"]["jaccard_index"] is None
|
||
|
|
|
||
|
|
def test_accumulation_annotation_conflicts_multiple_jobs(self, admin_user):
|
||
|
|
report = self.create_quality_report(
|
||
|
|
user=admin_user, task_id=self.demo_task_id_multiple_jobs
|
||
|
|
)
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report_data(report["id"])
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
report_data = json.loads(response.data)
|
||
|
|
task_confusion_matrix = report_data["comparison_summary"]["annotations"][
|
||
|
|
"confusion_matrix"
|
||
|
|
]["rows"]
|
||
|
|
|
||
|
|
expected_frame_confusion_matrix = {
|
||
|
|
"5": [[1, 0, 0], [0, 0, 0], [0, 0, 0]],
|
||
|
|
"7": [[1, 0, 0], [0, 0, 0], [0, 0, 0]],
|
||
|
|
"4": [[0, 0, 1], [0, 0, 0], [1, 0, 0]],
|
||
|
|
}
|
||
|
|
for frame_id in report_data["frame_results"].keys():
|
||
|
|
assert (
|
||
|
|
report_data["frame_results"][frame_id]["annotations"]["confusion_matrix"]["rows"]
|
||
|
|
== expected_frame_confusion_matrix[frame_id]
|
||
|
|
)
|
||
|
|
|
||
|
|
assert task_confusion_matrix == [[2, 0, 1], [0, 0, 0], [1, 0, 0]]
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("task_id", [8])
|
||
|
|
def test_can_compute_quality_if_non_skeleton_label_follows_skeleton_label(
|
||
|
|
self, admin_user, labels, task_id
|
||
|
|
):
|
||
|
|
new_label_name = "non_skeleton"
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
task_labels = [label for label in labels if label.get("task_id") == task_id]
|
||
|
|
assert any(label["type"] == "skeleton" for label in task_labels)
|
||
|
|
task_labels += [{"name": new_label_name, "type": "any"}]
|
||
|
|
api_client.tasks_api.partial_update(
|
||
|
|
task_id,
|
||
|
|
patched_task_write_request=models.PatchedTaskWriteRequest(labels=task_labels),
|
||
|
|
)
|
||
|
|
|
||
|
|
new_label_obj, _ = api_client.labels_api.list(task_id=task_id, name=new_label_name)
|
||
|
|
new_label_id = new_label_obj.results[0].id
|
||
|
|
api_client.tasks_api.update_annotations(
|
||
|
|
task_id,
|
||
|
|
labeled_data_request={
|
||
|
|
"shapes": [
|
||
|
|
models.LabeledShapeRequest(
|
||
|
|
type="rectangle",
|
||
|
|
frame=0,
|
||
|
|
label_id=new_label_id,
|
||
|
|
points=[0, 0, 1, 1],
|
||
|
|
)
|
||
|
|
]
|
||
|
|
},
|
||
|
|
)
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task_id)
|
||
|
|
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report_data(report["id"])
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
|
||
|
|
def test_excluded_gt_job_frames_are_not_included_in_honeypot_task_quality_report(
|
||
|
|
self, admin_user, tasks, jobs
|
||
|
|
):
|
||
|
|
task_id = next(t["id"] for t in tasks if t["validation_mode"] == "gt_pool")
|
||
|
|
gt_job = next(j for j in jobs if j["task_id"] == task_id if j["type"] == "ground_truth")
|
||
|
|
gt_job_frames = range(gt_job["start_frame"], gt_job["stop_frame"] + 1)
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
gt_job_meta, _ = api_client.jobs_api.retrieve_data_meta(gt_job["id"])
|
||
|
|
gt_frame_names = [f.name for f in gt_job_meta.frames]
|
||
|
|
|
||
|
|
task_meta, _ = api_client.tasks_api.retrieve_data_meta(task_id)
|
||
|
|
honeypot_frames = [
|
||
|
|
i
|
||
|
|
for i, f in enumerate(task_meta.frames)
|
||
|
|
if f.name in gt_frame_names and i not in gt_job_frames
|
||
|
|
]
|
||
|
|
gt_frame_uses = {
|
||
|
|
name: (gt_job["start_frame"] + gt_frame_names.index(name), list(ids))
|
||
|
|
for name, ids in groupby(
|
||
|
|
sorted(
|
||
|
|
[
|
||
|
|
i
|
||
|
|
for i in range(task_meta.size)
|
||
|
|
if task_meta.frames[i].name in gt_frame_names
|
||
|
|
],
|
||
|
|
key=lambda i: task_meta.frames[i].name,
|
||
|
|
),
|
||
|
|
key=lambda i: task_meta.frames[i].name,
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
api_client.jobs_api.partial_update(
|
||
|
|
gt_job["id"],
|
||
|
|
patched_job_write_request=models.PatchedJobWriteRequest(
|
||
|
|
stage="acceptance", state="completed"
|
||
|
|
),
|
||
|
|
)
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report_data(report["id"])
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
assert honeypot_frames == json.loads(response.data)["comparison_summary"]["frames"]
|
||
|
|
|
||
|
|
excluded_gt_frame, excluded_gt_frame_honeypots = next(
|
||
|
|
(i, honeypots) for i, honeypots in gt_frame_uses.values() if len(honeypots) > 1
|
||
|
|
)
|
||
|
|
api_client.jobs_api.partial_update_data_meta(
|
||
|
|
gt_job["id"],
|
||
|
|
patched_job_data_meta_write_request=models.PatchedJobDataMetaWriteRequest(
|
||
|
|
deleted_frames=[excluded_gt_frame]
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report_data(report["id"])
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
assert [
|
||
|
|
v for v in honeypot_frames if v not in excluded_gt_frame_honeypots
|
||
|
|
] == json.loads(response.data)["comparison_summary"]["frames"]
|
||
|
|
|
||
|
|
@pytest.mark.parametrize("task_id", [23])
|
||
|
|
def test_excluded_gt_job_frames_are_not_included_in_simple_gt_job_task_quality_report(
|
||
|
|
self, admin_user, task_id: int, jobs
|
||
|
|
):
|
||
|
|
gt_job = next(j for j in jobs if j["task_id"] == task_id if j["type"] == "ground_truth")
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
gt_job_meta, _ = api_client.jobs_api.retrieve_data_meta(gt_job["id"])
|
||
|
|
gt_frames = [
|
||
|
|
(f - gt_job_meta.start_frame) // parse_frame_step(gt_job_meta.frame_filter)
|
||
|
|
for f in gt_job_meta.included_frames
|
||
|
|
]
|
||
|
|
|
||
|
|
api_client.jobs_api.partial_update(
|
||
|
|
gt_job["id"],
|
||
|
|
patched_job_write_request=models.PatchedJobWriteRequest(
|
||
|
|
stage="acceptance", state="completed"
|
||
|
|
),
|
||
|
|
)
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report_data(report["id"])
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
assert gt_frames == json.loads(response.data)["comparison_summary"]["frames"]
|
||
|
|
|
||
|
|
excluded_gt_frame = gt_frames[0]
|
||
|
|
api_client.jobs_api.partial_update_data_meta(
|
||
|
|
gt_job["id"],
|
||
|
|
patched_job_data_meta_write_request=models.PatchedJobDataMetaWriteRequest(
|
||
|
|
deleted_frames=[excluded_gt_frame]
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
(_, response) = api_client.quality_api.retrieve_report_data(report["id"])
|
||
|
|
assert response.status == HTTPStatus.OK
|
||
|
|
assert [f for f in gt_frames if f != excluded_gt_frame] == json.loads(response.data)[
|
||
|
|
"comparison_summary"
|
||
|
|
]["frames"]
|
||
|
|
|
||
|
|
def test_quality_metrics_in_task_with_gt_and_tracks(
|
||
|
|
self,
|
||
|
|
admin_user,
|
||
|
|
tasks,
|
||
|
|
labels,
|
||
|
|
):
|
||
|
|
task_id = next(
|
||
|
|
t["id"]
|
||
|
|
for t in tasks
|
||
|
|
if not t["validation_mode"] and t["size"] >= 5 and not t["project_id"]
|
||
|
|
)
|
||
|
|
label_id = next(l["id"] for l in labels if l.get("task_id") == task_id)
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
gt_frames = [1, 3]
|
||
|
|
gt_job = api_client.jobs_api.create(
|
||
|
|
job_write_request=models.JobWriteRequest(
|
||
|
|
type="ground_truth",
|
||
|
|
task_id=task_id,
|
||
|
|
frame_selection_method="manual",
|
||
|
|
frames=gt_frames,
|
||
|
|
)
|
||
|
|
)[0]
|
||
|
|
|
||
|
|
gt_annotations = {
|
||
|
|
"shapes": [
|
||
|
|
{
|
||
|
|
"frame": 1,
|
||
|
|
"label_id": label_id,
|
||
|
|
"points": [0.5, 1.5, 2.5, 3.5],
|
||
|
|
"rotation": 0,
|
||
|
|
"type": "rectangle",
|
||
|
|
"occluded": False,
|
||
|
|
"outside": False,
|
||
|
|
"attributes": [],
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"frame": 3,
|
||
|
|
"label_id": label_id,
|
||
|
|
"points": [3.0, 4.0, 5.0, 6.0],
|
||
|
|
"rotation": 0,
|
||
|
|
"type": "rectangle",
|
||
|
|
"occluded": False,
|
||
|
|
"outside": False,
|
||
|
|
"attributes": [],
|
||
|
|
},
|
||
|
|
]
|
||
|
|
}
|
||
|
|
|
||
|
|
normal_annotations = {
|
||
|
|
"tracks": [
|
||
|
|
{
|
||
|
|
"type": "rectangle",
|
||
|
|
"frame": 0,
|
||
|
|
"label_id": label_id,
|
||
|
|
"shapes": [
|
||
|
|
{
|
||
|
|
"frame": 0,
|
||
|
|
"points": [1.0, 2.0, 3.0, 4.0],
|
||
|
|
"rotation": 0,
|
||
|
|
"type": "rectangle",
|
||
|
|
"occluded": False,
|
||
|
|
"outside": False,
|
||
|
|
"attributes": [],
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"frame": 2, # not included, but must affect interpolation
|
||
|
|
"points": [0.0, 1.0, 2.0, 3.0],
|
||
|
|
"rotation": 0,
|
||
|
|
"type": "rectangle",
|
||
|
|
"occluded": False,
|
||
|
|
"outside": False,
|
||
|
|
"attributes": [],
|
||
|
|
},
|
||
|
|
{
|
||
|
|
"frame": 4,
|
||
|
|
"points": [6.0, 7.0, 8.0, 9.0],
|
||
|
|
"rotation": 0,
|
||
|
|
"type": "rectangle",
|
||
|
|
"occluded": False,
|
||
|
|
"outside": False,
|
||
|
|
"attributes": [],
|
||
|
|
},
|
||
|
|
],
|
||
|
|
}
|
||
|
|
]
|
||
|
|
}
|
||
|
|
|
||
|
|
api_client.jobs_api.update_annotations(gt_job.id, labeled_data_request=gt_annotations)
|
||
|
|
|
||
|
|
api_client.tasks_api.update_annotations(
|
||
|
|
task_id, labeled_data_request=normal_annotations
|
||
|
|
)
|
||
|
|
|
||
|
|
api_client.jobs_api.partial_update(
|
||
|
|
gt_job.id,
|
||
|
|
patched_job_write_request=models.PatchedJobWriteRequest(
|
||
|
|
stage="acceptance", state="completed"
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
assert report["summary"]["conflict_count"] == 0
|
||
|
|
assert report["summary"]["valid_count"] == 2
|
||
|
|
assert report["summary"]["total_count"] == 2
|
||
|
|
|
||
|
|
def test_project_report_aggregates_nested_task_reports(self, quality_reports, tasks, jobs):
|
||
|
|
project_report = max(
|
||
|
|
(r for r in quality_reports if r["target"] == "project"), key=lambda r: r["id"]
|
||
|
|
)
|
||
|
|
task_reports = [r for r in quality_reports if r["parent_id"] == project_report["id"]]
|
||
|
|
|
||
|
|
tasks_with_configured_validation = [
|
||
|
|
t
|
||
|
|
for t in tasks
|
||
|
|
if t["project_id"] == project_report["project_id"]
|
||
|
|
if any(
|
||
|
|
j
|
||
|
|
for j in jobs
|
||
|
|
if j["task_id"] == t["id"]
|
||
|
|
and j["type"] == "ground_truth"
|
||
|
|
and j["state"] == "completed"
|
||
|
|
and j["stage"] == "acceptance"
|
||
|
|
)
|
||
|
|
]
|
||
|
|
assert len(task_reports) == len(tasks_with_configured_validation)
|
||
|
|
|
||
|
|
# Base fields of confusion matrix are scaled by inverse task validation frame share.
|
||
|
|
# This is needed to make the derived averaged metrics (accuracy etc.)
|
||
|
|
# for projects consistent with the averaged metrics in tasks in cases
|
||
|
|
# with different validation frame shares in project tasks
|
||
|
|
task_weights = {
|
||
|
|
r["task_id"]: 1 / r["summary"]["validation_frame_share"] for r in task_reports
|
||
|
|
}
|
||
|
|
summary = project_report["summary"]
|
||
|
|
for confusion_field in ["valid_count", "total_count", "ds_count", "gt_count"]:
|
||
|
|
assert summary[confusion_field] == sum(
|
||
|
|
math.ceil(r["summary"][confusion_field] * task_weights[r["task_id"]])
|
||
|
|
for r in task_reports
|
||
|
|
)
|
||
|
|
|
||
|
|
assert summary["accuracy"] == summary["valid_count"] / summary["total_count"]
|
||
|
|
assert summary["precision"] == summary["valid_count"] / summary["ds_count"]
|
||
|
|
assert summary["recall"] == summary["valid_count"] / summary["gt_count"]
|
||
|
|
|
||
|
|
# other summary fields are simply aggregated
|
||
|
|
for summary_field in [
|
||
|
|
"conflict_count",
|
||
|
|
"error_count",
|
||
|
|
"warning_count",
|
||
|
|
"total_frames",
|
||
|
|
"validation_frames",
|
||
|
|
]:
|
||
|
|
assert summary[summary_field] == sum(r["summary"][summary_field] for r in task_reports)
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
class TestPostProjectQualityReports(_PermissionTestBase):
|
||
|
|
def _test_create_report_200(self, user: str, project_id: int):
|
||
|
|
return self.create_quality_report(user=user, project_id=project_id)
|
||
|
|
|
||
|
|
def _test_create_report_403(self, user: str, project_id: int):
|
||
|
|
with make_api_client(user) as api_client:
|
||
|
|
(_, response) = api_client.quality_api.create_report(
|
||
|
|
quality_report_create_request=models.QualityReportCreateRequest(
|
||
|
|
project_id=project_id
|
||
|
|
),
|
||
|
|
_parse_response=False,
|
||
|
|
_check_status=False,
|
||
|
|
)
|
||
|
|
assert response.status == HTTPStatus.FORBIDDEN
|
||
|
|
|
||
|
|
return response
|
||
|
|
|
||
|
|
def test_can_create_project_report(self, admin_user, projects, tasks, labels):
|
||
|
|
project = next(
|
||
|
|
p
|
||
|
|
for p in projects
|
||
|
|
if p["tasks"]["count"] > 0
|
||
|
|
and any(l for l in labels if l.get("project_id") == p["id"])
|
||
|
|
and not any(t["validation_mode"] for t in tasks if t.get("project_id") == p["id"])
|
||
|
|
)
|
||
|
|
project_id = project["id"]
|
||
|
|
|
||
|
|
# Create GT jobs for all tasks in the project
|
||
|
|
tasks = [t for t in tasks if t.get("project_id") == project_id]
|
||
|
|
for task in tasks:
|
||
|
|
self.create_gt_job(admin_user, task["id"])
|
||
|
|
|
||
|
|
# Create project report
|
||
|
|
report = self.create_quality_report(user=admin_user, project_id=project_id)
|
||
|
|
|
||
|
|
# Check report data
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
report_data, _ = api_client.quality_api.retrieve_report_data(report["id"])
|
||
|
|
|
||
|
|
for r in [report, report_data]:
|
||
|
|
# Verify report was created
|
||
|
|
assert r["project_id"] == project_id
|
||
|
|
assert r.get("task_id") is None
|
||
|
|
assert r.get("job_id") is None
|
||
|
|
assert r.get("parent_id") is None
|
||
|
|
|
||
|
|
# Verify child reports were created
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
child_reports = get_paginated_collection(
|
||
|
|
api_client.quality_api.list_reports_endpoint,
|
||
|
|
parent_id=report["id"],
|
||
|
|
target="task",
|
||
|
|
return_json=True,
|
||
|
|
)
|
||
|
|
|
||
|
|
assert len(child_reports) == len(tasks)
|
||
|
|
for child in child_reports:
|
||
|
|
assert child["parent_id"] == report["id"]
|
||
|
|
|
||
|
|
def test_can_create_project_report_in_empty_project(self, admin_user, projects):
|
||
|
|
project = next(p for p in projects if p["tasks"]["count"] == 0)
|
||
|
|
project_id = project["id"]
|
||
|
|
|
||
|
|
report = self.create_quality_report(user=admin_user, project_id=project_id)
|
||
|
|
|
||
|
|
assert report["project_id"] == project_id
|
||
|
|
assert report["summary"]["total_count"] == 0
|
||
|
|
assert report["summary"]["tasks"]["total"] == 0
|
||
|
|
|
||
|
|
def test_can_create_project_report_when_there_are_tasks_without_validation(
|
||
|
|
self, admin_user, projects, tasks, labels
|
||
|
|
):
|
||
|
|
project = next(
|
||
|
|
p
|
||
|
|
for p in projects
|
||
|
|
if p["tasks"]["count"] > 0
|
||
|
|
and any(l for l in labels if l.get("project_id") == p["id"])
|
||
|
|
and not any(t["validation_mode"] for t in tasks if t.get("project_id") == p["id"])
|
||
|
|
)
|
||
|
|
project_id = project["id"]
|
||
|
|
|
||
|
|
self.create_quality_report(user=admin_user, project_id=project_id)
|
||
|
|
|
||
|
|
def test_can_create_project_report_when_there_are_tasks_without_configured_gt(
|
||
|
|
self, admin_user, projects, tasks, jobs, labels
|
||
|
|
):
|
||
|
|
project = next(
|
||
|
|
p
|
||
|
|
for p in projects
|
||
|
|
if p["tasks"]["count"] > 1
|
||
|
|
and any(l for l in labels if l.get("project_id") == p["id"])
|
||
|
|
and any(t["validation_mode"] for t in tasks if t.get("project_id") == p["id"])
|
||
|
|
)
|
||
|
|
project_id = project["id"]
|
||
|
|
|
||
|
|
# Create GT jobs for 1 task in the project
|
||
|
|
task = next(t for t in tasks if t.get("project_id") == project_id and t["validation_mode"])
|
||
|
|
gt_job = next(j for j in jobs if j["type"] == "ground_truth" and j["task_id"] == task["id"])
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
api_client.jobs_api.partial_update(
|
||
|
|
gt_job["id"],
|
||
|
|
patched_job_write_request=models.PatchedJobWriteRequest(
|
||
|
|
stage="annotation", state="new"
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
# Create project report
|
||
|
|
self.create_quality_report(user=admin_user, project_id=project_id)
|
||
|
|
|
||
|
|
def test_can_reuse_relevant_task_reports_in_project_report(
|
||
|
|
self, admin_user, projects, tasks, labels, quality_settings, quality_reports
|
||
|
|
):
|
||
|
|
project = next(
|
||
|
|
p
|
||
|
|
for p in projects
|
||
|
|
if any(r["project_id"] == p["id"] for r in quality_reports)
|
||
|
|
if p["tasks"]["count"] >= 2
|
||
|
|
if all(t["size"] > 0 for t in tasks if t["project_id"] == p["id"])
|
||
|
|
if any(l for l in labels if l.get("project_id") == p["id"])
|
||
|
|
if any(t["validation_mode"] for t in tasks if t.get("project_id") == p["id"])
|
||
|
|
if all(
|
||
|
|
s["inherit"]
|
||
|
|
for s in quality_settings
|
||
|
|
if s["task_id"]
|
||
|
|
if tasks[s["task_id"]]["project_id"] == p["id"]
|
||
|
|
)
|
||
|
|
)
|
||
|
|
project_id = project["id"]
|
||
|
|
|
||
|
|
project_tasks = sorted(
|
||
|
|
[t for t in tasks if t.get("project_id") == project_id], key=lambda t: t["id"]
|
||
|
|
)
|
||
|
|
|
||
|
|
latest_project_reports = sorted(
|
||
|
|
[r for r in quality_reports if r["project_id"] == project_id], key=lambda r: -r["id"]
|
||
|
|
)
|
||
|
|
latest_project_report = next(r for r in latest_project_reports if r["target"] == "project")
|
||
|
|
latest_project_reports = [
|
||
|
|
r for r in latest_project_reports if r["parent_id"] == latest_project_report["id"]
|
||
|
|
]
|
||
|
|
latest_task_reports = {
|
||
|
|
task_id: next(task_reports)
|
||
|
|
for task_id, task_reports in groupby(
|
||
|
|
sorted(latest_project_reports, key=lambda r: (r["task_id"], -r["id"])),
|
||
|
|
key=lambda r: r["task_id"],
|
||
|
|
)
|
||
|
|
}
|
||
|
|
|
||
|
|
# Create project report before task changes
|
||
|
|
new_report_before_task_changes = self.create_quality_report(
|
||
|
|
user=admin_user, project_id=project_id
|
||
|
|
)
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
task_reports_in_new_report_before_task_changes = {
|
||
|
|
r["id"]
|
||
|
|
for r in get_paginated_collection(
|
||
|
|
api_client.quality_api.list_reports_endpoint,
|
||
|
|
parent_id=new_report_before_task_changes["id"],
|
||
|
|
target="task",
|
||
|
|
)
|
||
|
|
}
|
||
|
|
assert task_reports_in_new_report_before_task_changes == set(
|
||
|
|
r["id"] for r in latest_task_reports.values()
|
||
|
|
)
|
||
|
|
|
||
|
|
# Modify one of the tasks
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
modified_task_id = project_tasks[0]["id"]
|
||
|
|
api_client.tasks_api.update_annotations(
|
||
|
|
modified_task_id, labeled_data_request={"shapes": []}
|
||
|
|
)
|
||
|
|
|
||
|
|
# Create new project report after task changes
|
||
|
|
new_report_after_task_changes = self.create_quality_report(
|
||
|
|
user=admin_user, project_id=project_id
|
||
|
|
)
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
task_reports_in_new_report_after_task_changes = {
|
||
|
|
(r["id"], r["task_id"])
|
||
|
|
for r in get_paginated_collection(
|
||
|
|
api_client.quality_api.list_reports_endpoint,
|
||
|
|
parent_id=new_report_after_task_changes["id"],
|
||
|
|
target="task",
|
||
|
|
)
|
||
|
|
}
|
||
|
|
assert set(
|
||
|
|
r for r in task_reports_in_new_report_after_task_changes if r[1] != modified_task_id
|
||
|
|
) == set(
|
||
|
|
(r["id"], r["task_id"])
|
||
|
|
for task_id, r in latest_task_reports.items()
|
||
|
|
if task_id != modified_task_id
|
||
|
|
)
|
||
|
|
assert (
|
||
|
|
latest_task_reports[modified_task_id]["id"],
|
||
|
|
modified_task_id,
|
||
|
|
) not in task_reports_in_new_report_after_task_changes
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_sandbox_cases)
|
||
|
|
def test_user_create_project_report_in_sandbox(self, is_staff, allow, find_sandbox_project):
|
||
|
|
project, user = find_sandbox_project(is_staff)
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_create_report_200(user=user["username"], project_id=project["id"])
|
||
|
|
else:
|
||
|
|
self._test_create_report_403(user=user["username"], project_id=project["id"])
|
||
|
|
|
||
|
|
@pytest.mark.parametrize(*_PermissionTestBase._default_org_cases)
|
||
|
|
def test_user_create_project_report_in_org(self, org_role, is_staff, allow, find_org_project):
|
||
|
|
project, user = find_org_project(is_staff, org_role)
|
||
|
|
|
||
|
|
if allow:
|
||
|
|
self._test_create_report_200(user=user["username"], project_id=project["id"])
|
||
|
|
else:
|
||
|
|
self._test_create_report_403(user=user["username"], project_id=project["id"])
|
||
|
|
|
||
|
|
def test_cannot_create_report_with_both_task_and_project_id(self, admin_user, tasks):
|
||
|
|
task = next(t for t in tasks if t.get("project_id") is not None)
|
||
|
|
task_id = task["id"]
|
||
|
|
project_id = task["project_id"]
|
||
|
|
|
||
|
|
with pytest.raises(exceptions.ApiException) as e:
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
api_client.quality_api.create_report(
|
||
|
|
quality_report_create_request=models.QualityReportCreateRequest(
|
||
|
|
task_id=task_id, project_id=project_id
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
assert HTTPStatus.BAD_REQUEST == e.value.status
|
||
|
|
assert "Only 1 of the fields" in e.value.body
|
||
|
|
|
||
|
|
|
||
|
|
@pytest.mark.usefixtures("restore_db_per_function")
|
||
|
|
class TestProjectQualitySettingsBehavior(_PermissionTestBase):
|
||
|
|
@pytest.mark.parametrize("inherit", [True, False])
|
||
|
|
def test_can_inherit_project_settings_in_task_report(
|
||
|
|
self, admin_user, tasks, quality_settings, inherit: bool
|
||
|
|
):
|
||
|
|
task = next(
|
||
|
|
t for t in tasks if t.get("project_id") is not None and t.get("validation_mode") is None
|
||
|
|
)
|
||
|
|
task_id = task["id"]
|
||
|
|
project_id = task["project_id"]
|
||
|
|
|
||
|
|
self.create_gt_job(admin_user, task_id)
|
||
|
|
|
||
|
|
project_settings = next(s for s in quality_settings if s["project_id"] == project_id)
|
||
|
|
task_settings = next(s for s in quality_settings if s["task_id"] == task_id)
|
||
|
|
|
||
|
|
with make_api_client(admin_user) as api_client:
|
||
|
|
api_client.quality_api.partial_update_settings(
|
||
|
|
task_settings["id"],
|
||
|
|
patched_quality_settings_request={
|
||
|
|
"inherit": inherit,
|
||
|
|
"empty_is_annotated": inherit,
|
||
|
|
},
|
||
|
|
)
|
||
|
|
|
||
|
|
api_client.quality_api.partial_update_settings(
|
||
|
|
project_settings["id"],
|
||
|
|
patched_quality_settings_request={
|
||
|
|
"empty_is_annotated": inherit,
|
||
|
|
},
|
||
|
|
)
|
||
|
|
|
||
|
|
# Create task report
|
||
|
|
task_report = self.create_quality_report(user=admin_user, task_id=task_id)
|
||
|
|
|
||
|
|
# Get report data to verify settings were inherited
|
||
|
|
task_report_data = api_client.quality_api.retrieve_report_data(task_report["id"])[0]
|
||
|
|
assert task_report_data["parameters"]["empty_is_annotated"] == inherit
|
||
|
|
assert task_report_data["parameters"]["inherited"] == inherit
|