# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import json
import os
from contextlib import AbstractContextManager
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional

import requests
import urllib3

from cvat_sdk.api_client.api_client import ApiClient, Endpoint
from cvat_sdk.api_client.exceptions import ApiException
from cvat_sdk.api_client.rest import RESTClientObject
from cvat_sdk.core.helpers import StreamWithProgress, expect_status
from cvat_sdk.core.progress import NullProgressReporter, ProgressReporter

if TYPE_CHECKING:
    from cvat_sdk.core.client import Client

import tusclient.uploader as tus_uploader
from tusclient.client import TusClient as _TusClient
from tusclient.client import Uploader as _TusUploader
from tusclient.request import TusRequest as _TusRequest
from tusclient.request import TusUploadFailed as _TusUploadFailed

MAX_REQUEST_SIZE = 100 * 2**20
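# Upper bound on the payload of a single bulk upload request (100 MiB): DataUploader sends
# files larger than this individually via TUS and packs smaller files into multipart
# requests that stay under this limit (see DataUploader._split_files_by_requests below).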


class _RestClientAdapter:
    # Provides a requests.Session-like interface for the REST client.
    # Only patch() is called by the tus client.

    def __init__(self, rest_client: RESTClientObject):
        self.rest_client = rest_client

    def _request(self, method, url, data=None, json=None, **kwargs):
        raw = self.rest_client.request(
            method=method,
            url=url,
            headers=kwargs.get("headers"),
            query_params=kwargs.get("params"),
            post_params=json,
            body=data,
            _parse_response=False,
            _request_timeout=kwargs.get("timeout"),
            _check_status=False,
        )

        result = requests.Response()
        result._content = raw.data
        result.raw = raw
        result.headers.update(raw.headers)
        result.status_code = raw.status
        result.reason = raw.msg
        return result

    def patch(self, *args, **kwargs):
        return self._request("PATCH", *args, **kwargs)


class _MyTusUploader(_TusUploader):
    # Adjusts the library code to work with the CVAT server.
    # Allows the API client session to be reused.

    def __init__(self, *_args, api_client: ApiClient, **_kwargs):
        self._api_client = api_client
        super().__init__(*_args, **_kwargs)

    def _do_request(self):
        self.request = _TusRequest(self)
        self.request.handle = _RestClientAdapter(self._api_client.rest_client)
        try:
            self.request.perform()
            self.verify_upload()
        except _TusUploadFailed as error:
            self._retry_or_cry(error)

    @tus_uploader._catch_requests_error
    def create_url(self):
        """
        Return the upload URL.

        Makes a request to the tus server to create a new upload URL for the required file upload.
        """
        headers = self.headers
        headers["upload-length"] = str(self.file_size)
        headers["upload-metadata"] = ",".join(self.encode_metadata())
        resp = self._api_client.rest_client.POST(self.client.url, headers=headers)
        self.real_filename = resp.headers.get("Upload-Filename")
        url = resp.headers.get("location")
        if url is None:
            msg = "Attempt to create the upload URL failed with status {}".format(resp.status)
            raise tus_uploader.TusCommunicationError(msg, resp.status, resp.data)
        return tus_uploader.urljoin(self.client.url, url)
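    # The server reports the stored file name in the "Upload-Filename" response header;
    # Uploader.upload_file() later sends it back as the "filename" query param of the
    # finishing ("Upload-Finish") request.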

    @tus_uploader._catch_requests_error
    def get_offset(self):
        """
        Return the offset from the tus server.

        This is different from the instance attribute 'offset' because this makes an
        HTTP request to the tus server to retrieve the offset.
        """
        try:
            resp = self._api_client.rest_client.HEAD(self.url, headers=self.headers)
        except ApiException as ex:
            if ex.status == 405:  # Method Not Allowed
                # In CVAT up to version 2.2.0, HEAD requests were internally
                # converted to GET by mod_wsgi, and subsequently rejected by the server.
                # For compatibility with old servers, we'll handle such rejections by
                # restarting the upload from the beginning.
                return 0

            raise tus_uploader.TusCommunicationError(
                f"Attempt to retrieve offset failed with status {ex.status}",
                ex.status,
                ex.body,
            ) from ex

        offset = resp.headers.get("upload-offset")
        if offset is None:
            raise tus_uploader.TusCommunicationError(
                f"Attempt to retrieve offset failed with status {resp.status}",
                resp.status,
                resp.data,
            )

        return int(offset)
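    # Instances are created via Uploader._make_tus_uploader() below, which passes the
    # shared ApiClient so the CVAT authentication headers and connection settings are
    # reused for the tus requests instead of a separate requests.Session.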


class Uploader:
    """
    Implements common uploading protocols.
    """

    _CHUNK_SIZE = 10 * 2**20

    def __init__(self, client: Client):
        self._client = client

    def upload_file(
        self,
        url: str,
        filename: Path,
        *,
        meta: dict[str, Any],
        query_params: Optional[dict[str, Any]] = None,
        fields: Optional[dict[str, Any]] = None,
        pbar: Optional[ProgressReporter] = None,
        logger=None,
    ) -> urllib3.HTTPResponse:
        """
        Annotation uploads:
        - have a "filename" meta field in chunks
        - have "filename" and "format" query params in the "Upload-Finish" request

        Data (image, video, ...) uploads:
        - have a "filename" meta field in chunks
        - have a number of fields in the "Upload-Finish" request

        Backup uploads:
        - have a "filename" meta field in chunks
        - have a "filename" query param in the "Upload-Finish" request
        OR
        - have a "task_file" field in the POST request data (a file)

        meta['filename'] is always required. It must be set to the "visible" file name or path.

        Returns:
            response of the last request (the "Upload-Finish" one)
        """
        # The "CVAT-TUS" protocol adds 2 extra messages to plain TUS;
        # query params are used only in these extra messages
        assert meta["filename"]

        if pbar is None:
            pbar = NullProgressReporter()

        file_size = filename.stat().st_size

        self._tus_start_upload(url, query_params=query_params)
        with self._uploading_task(pbar, file_size):
            real_filename = self._upload_file_data_with_tus(
                url=url, filename=filename, meta=meta, pbar=pbar, logger=logger
            )
        query_params["filename"] = real_filename
        return self._tus_finish_upload(url, query_params=query_params, fields=fields)
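
    # The upload_file() flow above exchanges three kinds of messages with the server:
    #   1. a POST with an "Upload-Start" header (_tus_start_upload), expecting HTTP 202;
    #   2. a regular tus exchange (upload creation + PATCH chunks) driven by _MyTusUploader;
    #   3. a POST with an "Upload-Finish" header carrying the query params and form fields
    #      (_tus_finish_upload), also expecting HTTP 202.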

    @staticmethod
    def _uploading_task(pbar: ProgressReporter, total_size: int) -> AbstractContextManager[None]:
        return pbar.task(
            total=total_size, desc="Uploading data", unit_scale=True, unit="B", unit_divisor=1024
        )

    @staticmethod
    def _make_tus_uploader(api_client: ApiClient, url: str, **kwargs):
        # Add headers required by CVAT server
        headers = {}
        headers["Origin"] = api_client.configuration.host
        headers.update(api_client.get_common_headers())

        client = _TusClient(url, headers=headers)

        return _MyTusUploader(client=client, api_client=api_client, **kwargs)

    def _upload_file_data_with_tus(self, url, filename, *, meta=None, pbar, logger=None) -> str:
        with open(filename, "rb") as input_file:
            uploader = self._make_tus_uploader(
                self._client.api_client,
                url=url.rstrip("/") + "/",
                metadata=meta,
                file_stream=StreamWithProgress(input_file, pbar),
                chunk_size=Uploader._CHUNK_SIZE,
                log_func=logger,
            )
            uploader.upload()
            return uploader.real_filename

    def _tus_start_upload(self, url, *, query_params=None):
        response = self._client.api_client.rest_client.POST(
            url,
            query_params=query_params,
            headers={
                "Upload-Start": "",
                **self._client.api_client.get_common_headers(),
            },
        )
        expect_status(202, response)
        return response

    def _tus_finish_upload(self, url, *, query_params=None, fields=None):
        response = self._client.api_client.rest_client.POST(
            url,
            headers={
                "Upload-Finish": "",
                **self._client.api_client.get_common_headers(),
            },
            query_params=query_params,
            post_params=fields,
        )
        expect_status(202, response)
        return response


class AnnotationUploader(Uploader):
    def upload_file_and_wait(
        self,
        endpoint: Endpoint,
        filename: Path,
        format_name: str,
        *,
        conv_mask_to_poly: Optional[bool] = None,
        url_params: Optional[dict[str, Any]] = None,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
    ):
        url = self._client.api_map.make_endpoint_url(endpoint.path, kwsub=url_params)
        params = {"format": format_name, "filename": filename.name}

        # The query params must be prepared before the upload starts, because they are
        # sent with the "Upload-Finish" request inside upload_file()
        if conv_mask_to_poly is not None:
            params["conv_mask_to_poly"] = "true" if conv_mask_to_poly else "false"

        response = self.upload_file(
            url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
        )

        rq_id = json.loads(response.data).get("rq_id")
        assert rq_id, "The rq_id was not found in the response"

        self._client.wait_for_completion(rq_id, status_check_period=status_check_period)
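
    # A minimal usage sketch (illustrative only): given an authenticated
    # cvat_sdk.core.client.Client named `client` and an annotations import Endpoint
    # named `endpoint` taken from the generated API layer, an import could look like:
    #
    #   AnnotationUploader(client).upload_file_and_wait(
    #       endpoint,
    #       Path("annotations.xml"),
    #       "CVAT 1.1",  # hypothetical import format name
    #       url_params={"id": 42},  # hypothetical path parameters for the endpoint
    #   )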


class DatasetUploader(Uploader):
    def upload_file_and_wait(
        self,
        upload_endpoint: Endpoint,
        filename: Path,
        format_name: str,
        *,
        url_params: Optional[dict[str, Any]] = None,
        conv_mask_to_poly: Optional[bool] = None,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
    ):
        url = self._client.api_map.make_endpoint_url(upload_endpoint.path, kwsub=url_params)
        params = {"format": format_name, "filename": filename.name}

        if conv_mask_to_poly is not None:
            params["conv_mask_to_poly"] = "true" if conv_mask_to_poly else "false"

        response = self.upload_file(
            url, filename, pbar=pbar, query_params=params, meta={"filename": params["filename"]}
        )
        rq_id = json.loads(response.data).get("rq_id")
        assert rq_id, "The rq_id was not found in the response"

        self._client.wait_for_completion(rq_id, status_check_period=status_check_period)


class DataUploader(Uploader):
    def __init__(self, client: Client, *, max_request_size: int = MAX_REQUEST_SIZE):
        super().__init__(client)
        self.max_request_size = max_request_size

    def upload_files(
        self,
        url: str,
        resources: list[Path],
        *,
        pbar: Optional[ProgressReporter] = None,
        **kwargs,
    ):
        bulk_file_groups, separate_files, total_size = self._split_files_by_requests(resources)

        if pbar is None:
            pbar = NullProgressReporter()

        if str(kwargs.get("sorting_method")).lower() == "predefined":
            # Pass the desired file order explicitly, because we reorder files
            # to send them more efficiently
            kwargs.setdefault("upload_file_order", [p.name for p in resources])

        with self._uploading_task(pbar, total_size):
            self._tus_start_upload(url)

            for group, group_size in bulk_file_groups:
                files = {}
                for i, filename in enumerate(group):
                    files[f"client_files[{i}]"] = (
                        os.fspath(filename),
                        filename.read_bytes(),
                    )
                response = self._client.api_client.rest_client.POST(
                    url,
                    post_params={"image_quality": kwargs["image_quality"], **files},
                    headers={
                        "Content-Type": "multipart/form-data",
                        "Upload-Multiple": "",
                        **self._client.api_client.get_common_headers(),
                    },
                )
                expect_status(200, response)

                pbar.advance(group_size)

            for filename in separate_files:
                self._upload_file_data_with_tus(
                    url,
                    filename,
                    meta={"filename": filename.name},
                    pbar=pbar,
                    logger=self._client.logger.debug,
                )

        return self._tus_finish_upload(url, fields=kwargs)
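
    # A minimal usage sketch (illustrative only): given an authenticated
    # cvat_sdk.core.client.Client named `client` and a hypothetical task data URL,
    # a set of images could be uploaded like this:
    #
    #   DataUploader(client).upload_files(
    #       "https://app.cvat.ai/api/tasks/42/data",  # hypothetical URL
    #       [Path("frame_000.png"), Path("frame_001.png")],
    #       image_quality=70,  # expected by the bulk requests above
    #   )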

    def _split_files_by_requests(
        self, filenames: list[Path]
    ) -> tuple[list[tuple[list[Path], int]], dict[Path, int], int]:
        bulk_files: dict[Path, int] = {}
        separate_files: dict[Path, int] = {}
        max_request_size = self.max_request_size

        # split into bulk and separately uploaded files by size
        for filename in filenames:
            filename = filename.resolve()
            file_size = filename.stat().st_size
            if max_request_size < file_size:
                separate_files[filename] = file_size
            else:
                bulk_files[filename] = file_size

        total_size = sum(bulk_files.values()) + sum(separate_files.values())

        # group small files into per-request batches
        bulk_file_groups: list[tuple[list[Path], int]] = []
        current_group_size: int = 0
        current_group: list[Path] = []
        for filename, file_size in bulk_files.items():
            if max_request_size < current_group_size + file_size:
                bulk_file_groups.append((current_group, current_group_size))
                current_group_size = 0
                current_group = []

            current_group.append(filename)
            current_group_size += file_size
        if current_group:
            bulk_file_groups.append((current_group, current_group_size))

        return bulk_file_groups, separate_files, total_size
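
    # Worked example of the splitting logic above, with the default 100 MiB
    # max_request_size: for files of 60 MiB, 60 MiB and 150 MiB, the 150 MiB file is
    # uploaded separately via TUS, while the two 60 MiB files end up in two separate
    # bulk groups, because 60 + 60 = 120 MiB would exceed the request size limit.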