# cvat/tests/python/sdk/common.py
# (captured from a web file view: 102 lines, 3.1 KiB, Python; timestamp 2025-09-16 01:19:40 +00:00)
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import io
import zipfile
from pathlib import Path
from typing import Optional, Union
import pytest
from cvat_sdk.core.proxies.jobs import Job
from cvat_sdk.core.proxies.projects import Project
from cvat_sdk.core.proxies.tasks import Task
from cvat_sdk.core.proxies.types import Location
from shared.fixtures.data import CloudStorageAssets
from shared.utils.config import IMPORT_EXPORT_BUCKET_ID
from shared.utils.s3 import S3Client
from shared.utils.s3 import make_client as make_s3_client
from .util import make_pbar
# Any of the SDK resource proxies (project, task or job); the export helpers in
# TestDatasetExport accept any of them and call ``export_dataset`` on it.
ProjectOrTaskOrJob = Union[Project, Task, Job]
class TestDatasetExport:
    """Reusable dataset-export checks for SDK project, task and job proxies.

    The helpers are named ``_test_*``, so under pytest's default collection
    rules they are not collected as tests themselves; concrete test classes
    call them from their own ``test_*`` methods.

    NOTE(review): every helper asserts ``self.stdout.getvalue() == ""``, but
    ``self.stdout`` is not defined in this class. It is presumably a
    captured-output buffer installed by a fixture in the subclass; confirm there.
    """

    def _test_export_locally(
        self,
        resource: ProjectOrTaskOrJob,
        *,
        format_name: str,
        file_path: Path,
        **export_kwargs,
    ):
        """Export ``resource`` into the local file ``file_path`` and verify it.

        Checks that:
        - the export wrote nothing to ``self.stdout``;
        - the last progress-bar refresh contains ``100%``;
        - ``file_path`` now exists as a regular file.

        Extra ``export_kwargs`` are forwarded unchanged to ``export_dataset``.
        """
        progress_stream = io.StringIO()
        resource.export_dataset(
            format_name, file_path, pbar=make_pbar(file=progress_stream), **export_kwargs
        )

        assert self.stdout.getvalue() == ""

        # Progress output is separated by carriage returns; the segment after the
        # last one is the most recent refresh of the bar.
        last_refresh = progress_stream.getvalue().strip("\r").rpartition("\r")[2]
        assert "100%" in last_refresh
        assert file_path.is_file()

    def _test_export_to_cloud_storage(
        self,
        resource: ProjectOrTaskOrJob,
        *,
        format_name: str,
        file_path: Path,
        cs_client: S3Client,
        **export_kwargs,
    ):
        """Export ``resource`` to cloud storage and verify the uploaded archive.

        The export must write nothing to ``self.stdout``. Afterwards
        ``str(file_path)`` is used as the object key to download the result
        through ``cs_client``, and the downloaded bytes must form a valid zip
        archive. Extra ``export_kwargs`` are forwarded to ``export_dataset``.
        """
        resource.export_dataset(format_name, file_path, **export_kwargs)
        assert self.stdout.getvalue() == ""

        archive = io.BytesIO(cs_client.download_fileobj(str(file_path)))
        assert zipfile.is_zipfile(archive)

    def _test_can_export_dataset(
        self,
        resource: ProjectOrTaskOrJob,
        *,
        format_name: str,
        file_path: Path,
        include_images: bool,
        location: Optional[Location],
        request: pytest.FixtureRequest,
        cloud_storages: CloudStorageAssets,
    ):
        """Export ``resource`` and run the check matching where the file should land.

        The export is treated as local when ``location`` is ``Location.LOCAL``.
        It is also local when no ``location`` is given and the resource either
        has no target storage or has a local one. Otherwise it goes through the
        cloud-storage path:
        - an S3 client is created for the bucket whose ``"id"`` equals
          ``IMPORT_EXPORT_BUCKET_ID`` in ``cloud_storages``;
        - a finalizer is registered on ``request`` to delete the exported object;
        - ``cloud_storage_id`` is passed to the export only when ``location`` is
          ``Location.CLOUD_STORAGE``.

        ``include_images`` and ``location`` are always forwarded to the export.
        """
        common_kwargs = dict(include_images=include_images, location=location)

        if location == Location.LOCAL:
            export_is_local = True
        elif location:
            # Any other explicitly requested location takes the cloud path.
            export_is_local = False
        else:
            # No location given: fall back to the resource's target storage,
            # treating "no target storage configured" as local.
            export_is_local = (
                not resource.target_storage
                or resource.target_storage.location.value == Location.LOCAL
            )

        if export_is_local:
            self._test_export_locally(
                resource, format_name=format_name, file_path=file_path, **common_kwargs
            )
            return

        bucket_name = next(
            storage["resource"]
            for storage in cloud_storages
            if storage["id"] == IMPORT_EXPORT_BUCKET_ID
        )
        s3_client = make_s3_client(bucket=bucket_name)

        def _remove_exported_file():
            s3_client.remove_file(filename=str(file_path))

        # Registered before exporting so the object is removed at teardown even
        # if the export or its checks fail.
        request.addfinalizer(_remove_exported_file)

        # Only name the cloud storage explicitly when it was requested explicitly;
        # otherwise the export relies on the resource's configured target storage.
        storage_kwargs = {}
        if location == Location.CLOUD_STORAGE:
            storage_kwargs["cloud_storage_id"] = IMPORT_EXPORT_BUCKET_ID

        self._test_export_to_cloud_storage(
            resource,
            format_name=format_name,
            file_path=file_path,
            cs_client=s3_client,
            **storage_kwargs,
            **common_kwargs,
        )