102 lines
3.1 KiB
Python
102 lines
3.1 KiB
Python
|
|
# Copyright (C) CVAT.ai Corporation
|
||
|
|
#
|
||
|
|
# SPDX-License-Identifier: MIT
|
||
|
|
import io
|
||
|
|
import zipfile
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Optional, Union
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from cvat_sdk.core.proxies.jobs import Job
|
||
|
|
from cvat_sdk.core.proxies.projects import Project
|
||
|
|
from cvat_sdk.core.proxies.tasks import Task
|
||
|
|
from cvat_sdk.core.proxies.types import Location
|
||
|
|
|
||
|
|
from shared.fixtures.data import CloudStorageAssets
|
||
|
|
from shared.utils.config import IMPORT_EXPORT_BUCKET_ID
|
||
|
|
from shared.utils.s3 import S3Client
|
||
|
|
from shared.utils.s3 import make_client as make_s3_client
|
||
|
|
|
||
|
|
from .util import make_pbar
|
||
|
|
|
||
|
|
# Any SDK resource proxy (project, task or job) whose dataset the export
# helpers below operate on; each is expected to provide `export_dataset`.
ProjectOrTaskOrJob = Union[Project, Task, Job]
|
||
|
|
|
||
|
|
|
||
|
|
class TestDatasetExport:
    """Reusable helpers for SDK dataset-export tests.

    The helpers read ``self.stdout`` (something with ``getvalue()``), which the
    concrete test class is expected to set up, and assert that exporting
    prints nothing to it.
    """

    def _test_export_locally(
        self,
        resource: ProjectOrTaskOrJob,
        *,
        format_name: str,
        file_path: Path,
        **export_kwargs,
    ):
        """Export ``resource`` to ``file_path`` and check the local result.

        Verifies that nothing was written to ``self.stdout``, that the final
        progress-bar update reports 100%, and that ``file_path`` is a file.
        """
        progress_stream = io.StringIO()
        resource.export_dataset(
            format_name,
            file_path,
            pbar=make_pbar(file=progress_stream),
            **export_kwargs,
        )

        assert self.stdout.getvalue() == ""

        # The progress bar redraws its line using carriage returns, so only
        # the last "\r"-separated segment reflects its final state.
        final_update = progress_stream.getvalue().strip("\r").split("\r")[-1]
        assert "100%" in final_update
        assert file_path.is_file()

    def _test_export_to_cloud_storage(
        self,
        resource: ProjectOrTaskOrJob,
        *,
        format_name: str,
        file_path: Path,
        cs_client: S3Client,
        **export_kwargs,
    ):
        """Export ``resource`` without a progress bar and check the stored archive.

        After exporting, the object named ``str(file_path)`` is downloaded via
        ``cs_client`` and must be a valid ZIP archive.
        """
        resource.export_dataset(format_name, file_path, **export_kwargs)
        assert self.stdout.getvalue() == ""

        archive_bytes = cs_client.download_fileobj(str(file_path))
        assert zipfile.is_zipfile(io.BytesIO(archive_bytes))

    def _test_can_export_dataset(
        self,
        resource: ProjectOrTaskOrJob,
        *,
        format_name: str,
        file_path: Path,
        include_images: bool,
        location: Optional[Location],
        request: pytest.FixtureRequest,
        cloud_storages: CloudStorageAssets,
    ):
        """Export ``resource`` and verify the result wherever it should land.

        The export is treated as local when ``location`` is ``Location.LOCAL``.
        It is also treated as local when no ``location`` is given and the
        resource has no target storage, or its target storage is local.
        Otherwise, the result is checked in the bucket whose id is
        ``IMPORT_EXPORT_BUCKET_ID``, and the uploaded file is scheduled for
        removal at test teardown.
        """
        shared_kwargs = {
            "include_images": include_images,
            "location": location,
        }

        if location:
            goes_local = location == Location.LOCAL
        else:
            # No explicit location: fall back to the resource's target storage.
            goes_local = (
                not resource.target_storage
                or resource.target_storage.location.value == Location.LOCAL
            )

        if goes_local:
            self._test_export_locally(
                resource, format_name=format_name, file_path=file_path, **shared_kwargs
            )
            return

        bucket = next(
            storage["resource"]
            for storage in cloud_storages
            if storage["id"] == IMPORT_EXPORT_BUCKET_ID
        )
        s3_client = make_s3_client(bucket=bucket)

        def _remove_exported_file():
            s3_client.remove_file(filename=str(file_path))

        request.addfinalizer(_remove_exported_file)

        # The bucket id is passed only for an explicit CLOUD_STORAGE location;
        # presumably the resource's own target storage applies otherwise.
        storage_kwargs = {}
        if location == Location.CLOUD_STORAGE:
            storage_kwargs["cloud_storage_id"] = IMPORT_EXPORT_BUCKET_ID

        self._test_export_to_cloud_storage(
            resource,
            format_name=format_name,
            file_path=file_path,
            cs_client=s3_client,
            **storage_kwargs,
            **shared_kwargs,
        )
|