cvat/cvat-sdk/cvat_sdk/core/downloading.py

199 lines
6.1 KiB
Python

# Copyright (C) 2020-2022 Intel Corporation
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from __future__ import annotations
import json
import os
import re
from contextlib import closing
from pathlib import Path
from typing import TYPE_CHECKING, Any, Optional
import urllib3
from cvat_sdk.api_client.api_client import Endpoint
from cvat_sdk.core.exceptions import CvatSdkException
from cvat_sdk.core.helpers import expect_status
from cvat_sdk.core.progress import NullProgressReporter, ProgressReporter
from cvat_sdk.core.utils import atomic_writer
if TYPE_CHECKING:
from cvat_sdk.core.client import Client
class Downloader:
    """
    Implements common downloading protocols
    """

    def __init__(self, client: Client):
        """
        Args:
            client: the SDK client whose REST client, logger, config and
                endpoint map are used for all requests made by this object.
        """
        self._client = client

    @classmethod
    def _validate_filename(cls, filename: str) -> str | None:
        """
        Checks that a server-provided file name is safe and meaningful to use
        as a local file name.

        A name is accepted only if all of the following hold:
        - it is at most 254 characters long;
        - it has a non-empty stem and an extension of at least one character
          after the dot;
        - it does not start with a dot;
        - it consists only of ASCII letters, digits, "_", "-", "." and spaces
          (so path separators are rejected as well).

        Returns: the unchanged filename if it is acceptable, None otherwise
        """
        # Allow only meaningful and valid filenames for the user OS.
        if len(filename) > 254:
            return None

        stem, ext = os.path.splitext(filename)
        if not stem or len(ext) < 2:
            return None

        if filename.startswith(".") or re.search(r"[^A-Za-z0-9_\-\. ]", filename):
            return None

        return filename

    @classmethod
    def _get_server_filename(cls, response: urllib3.HTTPResponse) -> str:
        """
        Extracts the file name suggested by the server from the
        Content-Disposition header of the response.

        Only the plain "filename=" parameter is recognized (matched
        case-insensitively); surrounding double quotes are removed and the
        result is checked with _validate_filename().

        Returns: the validated file name

        Raises:
            CvatSdkException: if the header is absent, has no "filename="
                parameter, or the name does not pass validation.
        """
        # Header format specification:
        # https://datatracker.ietf.org/doc/html/rfc2616#section-19.5.1

        # Pick the first ";"-separated parameter that looks like "filename=...".
        content_disposition = next(
            (
                parameter
                for part in response.headers.get("Content-Disposition", "").split(";")
                if (parameter := part.strip()) and parameter.lower().startswith("filename=")
            ),
            None,
        )

        filename = None
        if content_disposition:
            filename = content_disposition.split("=", maxsplit=1)[1].strip('"')
            filename = cls._validate_filename(filename)

        if not filename:
            raise CvatSdkException(
                "Can't find the output filename in the server response, "
                "please try to specify the output filename explicitly"
            )

        return filename

    def download_file(
        self,
        url: str,
        output_path: Path,
        *,
        timeout: int = 60,
        pbar: Optional[ProgressReporter] = None,
    ) -> Path:
        """
        Downloads the file from url into a temporary file, then renames it to the requested name.
        If output_path is a directory, saves the file into the directory with
        the server-defined name.

        Args:
            url: the address to send the GET request to.
            output_path: the target file path, or an existing directory.
            timeout: passed to the REST client as the request timeout.
            pbar: progress reporter to notify about received bytes;
                a NullProgressReporter is used when omitted.

        Returns: path to the downloaded file

        Raises:
            FileExistsError: if the target file already exists.
            CvatSdkException: if output_path is a directory and a usable file
                name can't be obtained from the server response.
        """

        CHUNK_SIZE = 10 * 2**20

        # Fail early, before any network traffic, if the target is already there.
        if output_path.is_file():
            raise FileExistsError(output_path)

        if pbar is None:
            pbar = NullProgressReporter()

        response = self._client.api_client.rest_client.GET(
            url,
            _request_timeout=timeout,
            headers=self._client.api_client.get_common_headers(),
            _parse_response=False,
        )
        with closing(response):
            # A missing header yields 0; an unparsable one yields None.
            try:
                file_size = int(response.headers.get("Content-Length", 0))
            except ValueError:
                file_size = None

            if output_path.is_dir():
                output_path /= self._get_server_filename(response)
                # The final name is only known now, so the existence check is repeated.
                if output_path.exists():
                    raise FileExistsError(output_path)

            with (
                atomic_writer(output_path, "wb") as fd,
                pbar.task(
                    total=file_size,
                    desc="Downloading",
                    unit_scale=True,
                    unit="B",
                    unit_divisor=1024,
                ),
            ):
                # The body is read undecoded (decode_content=False) and stored as received.
                while chunk := response.read(amt=CHUNK_SIZE, decode_content=False):
                    pbar.advance(len(chunk))
                    fd.write(chunk)

        return output_path

    def prepare_file(
        self,
        endpoint: Endpoint,
        *,
        url_params: Optional[dict[str, Any]] = None,
        query_params: Optional[dict[str, Any]] = None,
        status_check_period: Optional[int] = None,
    ):
        """
        Asks the server to prepare a file and waits until the corresponding
        background request is completed.

        Args:
            endpoint: the API endpoint that starts the preparation; its path
                and HTTP method are used to build the request.
            url_params: values substituted into the endpoint path.
            query_params: query string parameters for the request.
            status_check_period: forwarded to the client's
                wait_for_completion(); the client configuration value is used
                when omitted.

        Returns: the request object returned by the client's wait_for_completion()

        Raises:
            AssertionError: if the server response has no request identifier.
        """
        client = self._client
        if status_check_period is None:
            status_check_period = client.config.status_check_period

        client.logger.info("Waiting for the server to prepare the file...")

        url = client.api_map.make_endpoint_url(
            endpoint.path, kwsub=url_params, query_params=query_params
        )

        # initialize background process
        response = client.api_client.rest_client.request(
            method=endpoint.settings["http_method"],
            url=url,
            headers=client.api_client.get_common_headers(),
        )
        client.logger.debug("STATUS %s", response.status)
        expect_status(202, response)

        rq_id = json.loads(response.data).get("rq_id")
        # Raised explicitly (instead of using an assert statement) so that the check
        # is not stripped under "python -O"; the exception type is kept for callers.
        if not rq_id:
            raise AssertionError("Request identifier was not found in server response")

        # wait until background process will be finished or failed
        request, _ = client.wait_for_completion(rq_id, status_check_period=status_check_period)

        return request

    def prepare_and_download_file_from_endpoint(
        self,
        endpoint: Endpoint,
        filename: Path,
        *,
        url_params: Optional[dict[str, Any]] = None,
        query_params: Optional[dict[str, Any]] = None,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
    ) -> Path:
        """
        Asks the server to prepare a file via prepare_file(), then downloads
        the result with download_file().

        Args:
            endpoint: the API endpoint that starts the preparation.
            filename: the target file path, or an existing directory
                (passed to download_file() as output_path).
            url_params: values substituted into the endpoint path.
            query_params: query string parameters for the request.
            pbar: progress reporter for the download step.
            status_check_period: forwarded to prepare_file(); the client
                configuration value is used when omitted.

        Returns: path to the downloaded file

        Raises:
            AssertionError: if the completed request has no result URL.
        """
        client = self._client
        if status_check_period is None:
            status_check_period = client.config.status_check_period

        export_request = self.prepare_file(
            endpoint,
            url_params=url_params,
            query_params=query_params,
            status_check_period=status_check_period,
        )

        # Raised explicitly (instead of using an assert statement) so that the check
        # is not stripped under "python -O"; the exception type is kept for callers.
        if not export_request.result_url:
            raise AssertionError("Result url was not found in server response")

        return self.download_file(export_request.result_url, output_path=filename, pbar=pbar)