# cvat/tests/python/rest_api/utils.py
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import json
from abc import ABCMeta, abstractmethod
from collections.abc import Hashable, Iterator, Sequence
from copy import deepcopy
from http import HTTPStatus
from io import BytesIO
from time import sleep
from typing import Any, Callable, Iterable, Optional, TypeVar, Union
import requests
from cvat_sdk.api_client import apis, models
from cvat_sdk.api_client.api.jobs_api import JobsApi
from cvat_sdk.api_client.api.projects_api import ProjectsApi
from cvat_sdk.api_client.api.tasks_api import TasksApi
from cvat_sdk.api_client.api_client import ApiClient, Endpoint
from cvat_sdk.api_client.exceptions import ForbiddenException
from cvat_sdk.core.helpers import get_paginated_collection
from deepdiff import DeepDiff
from urllib3 import HTTPResponse
from shared.utils.config import USER_PASS, make_api_client, post_method
DEFAULT_RETRIES = 50
DEFAULT_INTERVAL = 0.1
def initialize_export(endpoint: Endpoint, *, expect_forbidden: bool = False, **kwargs) -> str:
(_, response) = endpoint.call_with_http_info(
**kwargs, _parse_response=False, _check_status=False
)
if expect_forbidden:
assert (
response.status == HTTPStatus.FORBIDDEN
), f"Request should be forbidden, status: {response.status}"
raise ForbiddenException()
assert response.status == HTTPStatus.ACCEPTED, f"Status: {response.status}"
    # extract the background request ID returned in the server response
rq_id = json.loads(response.data).get("rq_id")
assert rq_id, "The rq_id parameter was not found in the server response"
return rq_id
def wait_background_request(
api_client: ApiClient,
rq_id: str,
*,
max_retries: int = DEFAULT_RETRIES,
interval: float = DEFAULT_INTERVAL,
) -> tuple[models.Request, HTTPResponse]:
for _ in range(max_retries):
(background_request, response) = api_client.requests_api.retrieve(rq_id)
assert response.status == HTTPStatus.OK
if (
background_request.status.value
== models.RequestStatus.allowed_values[("value",)]["FINISHED"]
):
return background_request, response
sleep(interval)
    assert False, (
        f"The background request was not finished within the allowed time "
        f"({interval * max_retries} sec). Last status was: {background_request.status.value}"
    )
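# Illustrative usage of the two helpers above (a sketch, not a real test; the
# "admin1" user and the task id 42 are hypothetical):
#
#   with make_api_client("admin1") as api_client:
#       rq_id = initialize_export(
#           api_client.tasks_api.create_dataset_export_endpoint,
#           id=42,
#           save_images=False,
#           format="CVAT for images 1.1",
#       )
#       request, _ = wait_background_request(api_client, rq_id)
#       assert request.result_url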
def wait_and_download_v2(
api_client: ApiClient,
rq_id: str,
*,
max_retries: int = DEFAULT_RETRIES,
interval: float = DEFAULT_INTERVAL,
) -> bytes:
background_request, _ = wait_background_request(
api_client, rq_id, max_retries=max_retries, interval=interval
)
    # the resulting file is downloaded locally, so the server must provide a result URL
assert background_request.result_url
headers = api_client.get_common_headers()
query_params = []
api_client.update_params_for_auth(headers=headers, queries=query_params)
assert not query_params # query auth is not expected
response = requests.get(background_request.result_url, headers=headers)
assert response.status_code == HTTPStatus.OK, f"Status: {response.status_code}"
return response.content
def export_v2(
endpoint: Endpoint,
*,
max_retries: int = DEFAULT_RETRIES,
interval: float = DEFAULT_INTERVAL,
expect_forbidden: bool = False,
wait_result: bool = True,
download_result: bool = True,
**kwargs,
) -> Union[bytes, str]:
"""Export datasets|annotations|backups using the second version of export API

    Args:
        endpoint (Endpoint): Export endpoint, called only to initialize the export process
        max_retries (int, optional): Number of retries when checking the process status. Defaults to 50.
        interval (float, optional): Interval in seconds between retries. Defaults to 0.1.
        expect_forbidden (bool, optional): Whether the export request should be forbidden. Defaults to False.
        wait_result (bool, optional): Wait for the background request to finish. Defaults to True.
        download_result (bool, optional): Download the exported file. Defaults to True.

    Returns:
        bytes: The content of the file, if it was downloaded locally.
        str: The request ID, if `wait_result` or `download_result` is False.
    """
    # initialize the background process and ensure that the first request returns a 403 code if the request should be forbidden
rq_id = initialize_export(endpoint, expect_forbidden=expect_forbidden, **kwargs)
if not wait_result:
return rq_id
# check status of background process
if download_result:
return wait_and_download_v2(
endpoint.api_client,
rq_id,
max_retries=max_retries,
interval=interval,
)
background_request, _ = wait_background_request(
endpoint.api_client, rq_id, max_retries=max_retries, interval=interval
)
return background_request.id
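# A minimal sketch of the three export_v2 modes (the project id 7 is hypothetical):
#
#   with make_api_client("admin1") as api_client:
#       endpoint = api_client.projects_api.create_dataset_export_endpoint
#       # wait for completion and download the file
#       file_bytes = export_v2(endpoint, id=7, save_images=True, format="CVAT for images 1.1")
#       # only start the export and return the request id
#       rq_id = export_v2(endpoint, id=7, save_images=True, format="CVAT for images 1.1", wait_result=False)
#       # wait for completion, but do not download
#       rq_id = export_v2(endpoint, id=7, save_images=True, format="CVAT for images 1.1", download_result=False)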
def export_dataset(
api: Union[ProjectsApi, TasksApi, JobsApi],
*,
save_images: bool,
max_retries: int = 300,
interval: float = DEFAULT_INTERVAL,
format: str = "CVAT for images 1.1", # pylint: disable=redefined-builtin
**kwargs,
) -> Optional[bytes]:
return export_v2(
api.create_dataset_export_endpoint,
max_retries=max_retries,
interval=interval,
save_images=save_images,
format=format,
**kwargs,
)
# FUTURE-TODO: support username: optional, api_client: optional
# TODO: make the function signature more user-friendly
def export_project_dataset(username: str, *args, **kwargs) -> Optional[bytes]:
with make_api_client(username) as api_client:
return export_dataset(api_client.projects_api, *args, **kwargs)
def export_task_dataset(username: str, *args, **kwargs) -> Optional[bytes]:
with make_api_client(username) as api_client:
return export_dataset(api_client.tasks_api, *args, **kwargs)
def export_job_dataset(username: str, *args, **kwargs) -> Optional[bytes]:
with make_api_client(username) as api_client:
return export_dataset(api_client.jobs_api, *args, **kwargs)
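# Example (an illustrative sketch; the task id 42 is hypothetical):
#
#   dataset_bytes = export_task_dataset("admin1", save_images=True, id=42)
#   with open("task_42_dataset.zip", "wb") as f:
#       f.write(dataset_bytes)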
def export_backup(
api: Union[ProjectsApi, TasksApi],
*,
max_retries: int = DEFAULT_RETRIES,
interval: float = DEFAULT_INTERVAL,
**kwargs,
) -> Optional[bytes]:
endpoint = api.create_backup_export_endpoint
return export_v2(endpoint, max_retries=max_retries, interval=interval, **kwargs)
def export_project_backup(username: str, *args, **kwargs) -> Optional[bytes]:
with make_api_client(username) as api_client:
return export_backup(api_client.projects_api, *args, **kwargs)
def export_task_backup(username: str, *args, **kwargs) -> Optional[bytes]:
with make_api_client(username) as api_client:
return export_backup(api_client.tasks_api, *args, **kwargs)
def import_resource(
endpoint: Endpoint,
*,
max_retries: int = DEFAULT_RETRIES,
interval: float = DEFAULT_INTERVAL,
expect_forbidden: bool = False,
wait_result: bool = True,
**kwargs,
) -> Optional[models.Request]:
    # initialize the background process and ensure that the first request returns a 403 code if the request should be forbidden
(_, response) = endpoint.call_with_http_info(
**kwargs,
_parse_response=False,
_check_status=False,
_content_type="multipart/form-data",
)
if expect_forbidden:
        assert (
            response.status == HTTPStatus.FORBIDDEN
        ), f"Request should be forbidden, status: {response.status}"
raise ForbiddenException()
assert response.status == HTTPStatus.ACCEPTED
if not wait_result:
return None
    # extract the background request ID returned in the server response
rq_id = json.loads(response.data).get("rq_id")
assert rq_id, "The rq_id parameter was not found in the server response"
# check status of background process
for _ in range(max_retries):
(background_request, response) = endpoint.api_client.requests_api.retrieve(rq_id)
assert response.status == HTTPStatus.OK
if background_request.status.value in (
models.RequestStatus.allowed_values[("value",)]["FINISHED"],
models.RequestStatus.allowed_values[("value",)]["FAILED"],
):
break
sleep(interval)
else:
        assert False, (
            f"The import process was not finished within the allowed time "
            f"({interval * max_retries} sec). Last status was: {background_request.status.value}"
        )
return background_request
def import_backup(
api: Union[ProjectsApi, TasksApi],
*,
max_retries: int = DEFAULT_RETRIES,
interval: float = DEFAULT_INTERVAL,
**kwargs,
):
endpoint = api.create_backup_endpoint
return import_resource(endpoint, max_retries=max_retries, interval=interval, **kwargs)
def import_project_backup(username: str, file_content: BytesIO, **kwargs):
with make_api_client(username) as api_client:
return import_backup(
api_client.projects_api, project_file_request={"project_file": file_content}, **kwargs
)
def import_task_backup(username: str, file_content: BytesIO, **kwargs):
with make_api_client(username) as api_client:
return import_backup(
api_client.tasks_api, task_file_request={"task_file": file_content}, **kwargs
)
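# Backup round-trip sketch (the project id 7 is hypothetical):
#
#   backup = export_project_backup("admin1", id=7)
#   import_project_backup("admin1", BytesIO(backup))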
def import_project_dataset(username: str, file_content: BytesIO, **kwargs):
with make_api_client(username) as api_client:
return import_resource(
api_client.projects_api.create_dataset_endpoint,
dataset_file_request={"dataset_file": file_content},
**kwargs,
)
def import_task_annotations(username: str, file_content: BytesIO, **kwargs):
with make_api_client(username) as api_client:
return import_resource(
api_client.tasks_api.create_annotations_endpoint,
annotation_file_request={"annotation_file": file_content},
**kwargs,
)
def import_job_annotations(username: str, file_content: BytesIO, **kwargs):
with make_api_client(username) as api_client:
return import_resource(
api_client.jobs_api.create_annotations_endpoint,
annotation_file_request={"annotation_file": file_content},
**kwargs,
)
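# Annotations round-trip sketch (the task id 42 is hypothetical; assumes the
# chosen format is importable back):
#
#   dataset = export_task_dataset(
#       "admin1", save_images=False, id=42, format="CVAT for images 1.1"
#   )
#   import_task_annotations(
#       "admin1", BytesIO(dataset), id=42, format="CVAT for images 1.1"
#   )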
FieldPath = Sequence[Union[str, Callable]]
class CollectionSimpleFilterTestBase(metaclass=ABCMeta):
# These fields need to be defined in the subclass
user: str
samples: list[dict[str, Any]]
    field_lookups: Optional[dict[str, FieldPath]] = None
cmp_ignore_keys: list[str] = ["updated_date"]
@abstractmethod
def _get_endpoint(self, api_client: ApiClient) -> Endpoint: ...
def _retrieve_collection(self, **kwargs) -> list:
kwargs["return_json"] = True
with make_api_client(self.user) as api_client:
return get_paginated_collection(self._get_endpoint(api_client), **kwargs)
@classmethod
def _get_field(cls, d: dict[str, Any], path: Union[str, FieldPath]) -> Optional[Any]:
assert path
for key in path:
if isinstance(d, dict):
assert isinstance(key, str)
d = d.get(key)
else:
if callable(key):
assert isinstance(d, str)
d = key(d)
else:
d = None
return d
def _map_field(self, name: str) -> FieldPath:
return (self.field_lookups or {}).get(name, [name])
@classmethod
def _find_valid_field_value(
cls, samples: Iterator[dict[str, Any]], field_path: FieldPath
) -> Any:
value = None
for sample in samples:
value = cls._get_field(sample, field_path)
if value:
break
        assert value, f"Failed to find a sample for the '{'.'.join(map(str, field_path))}' field"
return value
def _get_field_samples(self, field: str) -> tuple[Any, list[dict[str, Any]]]:
field_path = self._map_field(field)
field_value = self._find_valid_field_value(self.samples, field_path)
gt_objects = filter(lambda p: field_value == self._get_field(p, field_path), self.samples)
return field_value, gt_objects
def _compare_results(self, gt_objects, received_objects):
if self.cmp_ignore_keys:
ignore_keys = [f"root['{k}']" for k in self.cmp_ignore_keys]
else:
ignore_keys = None
diff = DeepDiff(
list(gt_objects),
received_objects,
exclude_paths=ignore_keys,
ignore_order=True,
)
assert diff == {}, diff
def _test_can_use_simple_filter_for_object_list(
self, field: str, field_values: Optional[list[Any]] = None
):
gt_objects = []
field_path = self._map_field(field)
if not field_values:
value, gt_objects = self._get_field_samples(field)
field_values = [value]
are_gt_objects_initialized = bool(gt_objects)
for value in field_values:
if not are_gt_objects_initialized:
gt_objects = [
sample
for sample in self.samples
if value == self._get_field(sample, field_path)
]
received_items = self._retrieve_collection(**{field: value})
self._compare_results(gt_objects, received_items)
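# A minimal subclass sketch (hypothetical; in real tests "samples" is usually
# filled from a pytest fixture with the serialized collection):
#
#   class TestTaskFilters(CollectionSimpleFilterTestBase):
#       user = "admin1"
#       samples = []  # e.g. the /api/tasks collection as a list of dicts
#
#       def _get_endpoint(self, api_client: ApiClient) -> Endpoint:
#           return api_client.tasks_api.list_endpoint
#
#       def test_can_filter_by_name(self):
#           self._test_can_use_simple_filter_for_object_list("name")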
def get_attrs(obj: Any, attributes: Sequence[str]) -> tuple[Any, ...]:
    """Returns one or more object attributes as a tuple"""
    return tuple(getattr(obj, attr) for attr in attributes)
def build_exclude_paths_expr(ignore_fields: Iterator[str]) -> list[str]:
exclude_expr_parts = []
for key in ignore_fields:
if "." in key:
key_parts = key.split(".")
expr = r"root\['{}'\]".format(key_parts[0])
expr += "".join(r"\[.*\]\['{}'\]".format(part) for part in key_parts[1:])
else:
expr = r"root\['{}'\]".format(key)
exclude_expr_parts.append(expr)
return exclude_expr_parts
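# For example, the produced DeepDiff exclude patterns look like this:
#   "id"       -> r"root\['id'\]"
#   "jobs.url" -> r"root\['jobs'\]\[.*\]\['url'\]"  (a list index is allowed between the keys)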
def wait_until_task_is_created(api: apis.RequestsApi, rq_id: str) -> models.Request:
for _ in range(100):
(request_details, _) = api.retrieve(rq_id)
if request_details.status.value in ("finished", "failed"):
return request_details
sleep(1)
    raise Exception(f"Cannot create task: request {rq_id} was not finished within the allowed time (100 sec)")
def create_task(username, spec, data, content_type="application/json", **kwargs):
with make_api_client(username) as api_client:
(task, response_) = api_client.tasks_api.create(spec, **kwargs)
assert response_.status == HTTPStatus.CREATED
sent_upload_start = False
data_kwargs = (kwargs or {}).copy()
data_kwargs.pop("org", None)
data_kwargs.pop("org_id", None)
if data.get("client_files") and "json" in content_type:
(_, response) = api_client.tasks_api.create_data(
task.id,
data_request=models.DataRequest(image_quality=data["image_quality"]),
upload_start=True,
_content_type=content_type,
**data_kwargs,
)
assert response.status == HTTPStatus.ACCEPTED
sent_upload_start = True
# Can't encode binary files in json
(_, response) = api_client.tasks_api.create_data(
task.id,
data_request=models.DataRequest(
client_files=data["client_files"],
image_quality=data["image_quality"],
),
upload_multiple=True,
_content_type="multipart/form-data",
**data_kwargs,
)
assert response.status == HTTPStatus.OK
data = data.copy()
del data["client_files"]
last_kwargs = {}
if sent_upload_start:
last_kwargs["upload_finish"] = True
(result, response) = api_client.tasks_api.create_data(
task.id,
data_request=deepcopy(data),
_content_type=content_type,
**data_kwargs,
**last_kwargs,
)
assert response.status == HTTPStatus.ACCEPTED
request_details = wait_until_task_is_created(api_client.requests_api, result.rq_id)
assert request_details.status.value == "finished", request_details.message
return task.id, response_.headers.get("X-Request-Id")
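# Illustrative call (a sketch with hypothetical spec/data dicts; assumes the
# referenced server file exists in the shared test data):
#
#   task_id, rq_id = create_task(
#       "admin1",
#       spec={"name": "example task", "labels": [{"name": "cat"}]},
#       data={"image_quality": 70, "server_files": ["images/image_1.jpg"]},
#   )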
def compare_annotations(a: dict, b: dict) -> dict:
def _exclude_cb(obj, path: str):
# ignoring track elements which do not have shapes
split_path = path.rsplit("['elements']", maxsplit=1)
if len(split_path) == 2:
if split_path[1].count("[") == 1 and not obj["shapes"]:
return True
return path.endswith("['elements']") and not obj
return DeepDiff(
a,
b,
ignore_order=True,
significant_digits=2, # annotations are stored with 2 decimal digit precision
exclude_obj_callback=_exclude_cb,
exclude_regex_paths=[
r"root\['version|updated_date'\]",
r"root(\['\w+'\]\[\d+\])+\['id'\]",
r"root(\['\w+'\]\[\d+\])+\['label_id'\]",
r"root(\['\w+'\]\[\d+\])+\['attributes'\]\[\d+\]\['spec_id'\]",
r"root(\['\w+'\]\[\d+\])+\['source'\]",
],
)
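# An empty diff ({}) from compare_annotations means the two annotation sets match
# up to ids, label/spec ids, sources, versions, and updated dates, with 2-decimal
# precision on coordinates.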
DATUMARO_FORMAT_FOR_DIMENSION = {
"2d": "Datumaro 1.0",
"3d": "Datumaro 3D 1.0",
}
def calc_end_frame(start_frame: int, stop_frame: int, frame_step: int) -> int:
return stop_frame - ((stop_frame - start_frame) % frame_step) + frame_step
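# E.g. start_frame=0, stop_frame=10, frame_step=3 covers frames 0, 3, 6, 9,
# so calc_end_frame(0, 10, 3) == 12 (an exclusive end aligned to the step).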
_T = TypeVar("_T")
def unique(
    it: Union[Iterator[_T], Iterable[_T]], *, key: Optional[Callable[[_T], Hashable]] = None
) -> Iterable[_T]:
    if key is None:
        key = lambda v: v  # identity by default; without this, the None default would crash
    return {key(v): v for v in it}.values()
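# E.g. list(unique([1, 2, 1, 3])) == [1, 2, 3]; with a key, the last value per
# key wins: list(unique(["a", "A"], key=str.lower)) == ["A"].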
def register_new_user(username: str) -> dict[str, Any]:
response = post_method(
"admin1",
"auth/register",
data={
"username": username,
"password1": USER_PASS,
"password2": USER_PASS,
"email": f"{username}@email.com",
},
)
assert response.status_code == HTTPStatus.CREATED
return response.json()
def invite_user_to_org(
user_email: str,
org_id: int,
role: str,
):
with make_api_client("admin1") as api_client:
invitation, _ = api_client.invitations_api.create(
models.InvitationWriteRequest(
role=role,
email=user_email,
),
org_id=org_id,
)
return invitation
def get_cloud_storage_content(username: str, cloud_storage_id: int, manifest: Optional[str] = None):
with make_api_client(username) as api_client:
kwargs = {"manifest_path": manifest} if manifest else {}
(data, _) = api_client.cloudstorages_api.retrieve_content_v2(cloud_storage_id, **kwargs)
return [f"{f['name']}{'/' if str(f['type']) == 'DIR' else ''}" for f in data["content"]]
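# Sketch of the returned values (hypothetical bucket content): files come back as
# plain names and directories get a trailing slash, e.g. ["image_1.jpg", "subdir/"].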