199 lines
6.1 KiB
Python
199 lines
6.1 KiB
Python
# Copyright (C) 2020-2022 Intel Corporation
|
|
# Copyright (C) CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import os
|
|
import re
|
|
from contextlib import closing
|
|
from pathlib import Path
|
|
from typing import TYPE_CHECKING, Any, Optional
|
|
|
|
import urllib3
|
|
|
|
from cvat_sdk.api_client.api_client import Endpoint
|
|
from cvat_sdk.core.exceptions import CvatSdkException
|
|
from cvat_sdk.core.helpers import expect_status
|
|
from cvat_sdk.core.progress import NullProgressReporter, ProgressReporter
|
|
from cvat_sdk.core.utils import atomic_writer
|
|
|
|
if TYPE_CHECKING:
|
|
from cvat_sdk.core.client import Client
|
|
|
|
|
|
class Downloader:
    """
    Implements common downloading protocols
    """

    def __init__(self, client: Client):
        """Remember the client whose API connection is used for all requests."""
        self._client = client

    @classmethod
    def _validate_filename(cls, filename: str) -> str | None:
        """
        Check that a server-provided filename is safe to use on the user's OS.

        A name is accepted only if it is at most 254 characters long, has a
        non-empty stem and an extension of at least one character after the dot,
        does not start with a dot, and consists solely of ASCII letters, digits,
        underscores, hyphens, dots and spaces.

        Returns: the filename unchanged if it is acceptable, otherwise None
        """

        # Allow only meaningful and valid filenames for the user OS.

        if len(filename) > 254:
            return None

        stem, ext = os.path.splitext(filename)
        # `ext` includes the leading dot, so "< 2" rejects both "" and ".".
        if not stem or len(ext) < 2:
            return None

        # Anything outside the whitelist (path separators included) is rejected.
        if filename.startswith(".") or re.search(r"[^A-Za-z0-9_\-\. ]", filename):
            return None

        return filename

    @classmethod
    def _get_server_filename(cls, response: urllib3.HTTPResponse) -> str:
        """
        Extract the filename suggested by the server in Content-Disposition.

        Only the first plain "filename" parameter is considered (the name is
        matched case-insensitively). Whitespace around the parameter name, the
        "=" sign and the value is tolerated, and surrounding double quotes are
        removed before the name is validated.

        Raises: CvatSdkException if the header has no "filename" parameter or
        its value does not pass validation.

        Returns: the validated filename
        """

        # Header format specification:
        # https://datatracker.ietf.org/doc/html/rfc2616#section-19.5.1
        header = response.headers.get("Content-Disposition", "")

        filename = None
        for part in header.split(";"):
            name, separator, value = part.partition("=")
            # Compare the bare parameter name so that forms such as
            # 'filename = "a.zip"' are recognized as well as 'filename="a.zip"'.
            if separator and name.strip().lower() == "filename":
                filename = cls._validate_filename(value.strip().strip('"'))
                break

        if not filename:
            raise CvatSdkException(
                "Can't find the output filename in the server response, "
                "please try to specify the output filename explicitly"
            )

        return filename

    def download_file(
        self,
        url: str,
        output_path: Path,
        *,
        timeout: int = 60,
        pbar: Optional[ProgressReporter] = None,
    ) -> Path:
        """
        Downloads the file from url into a temporary file, then renames it to the requested name.
        If output_path is a directory, saves the file into the directory with
        the server-defined name.

        Progress is reported through pbar; when pbar is None, a
        NullProgressReporter is used instead.

        Raises: FileExistsError if output_path is an existing file, or if the
        file with the server-defined name already exists in the directory.

        Returns: path to the downloaded file
        """

        CHUNK_SIZE = 10 * 2**20

        # Fail before issuing the request if the destination file is already there.
        if output_path.is_file():
            raise FileExistsError(output_path)

        if pbar is None:
            pbar = NullProgressReporter()

        response = self._client.api_client.rest_client.GET(
            url,
            _request_timeout=timeout,
            headers=self._client.api_client.get_common_headers(),
            _parse_response=False,
        )
        with closing(response):
            # A missing header yields 0; an unparsable value yields None.
            try:
                file_size = int(response.headers.get("Content-Length", 0))
            except ValueError:
                file_size = None

            if output_path.is_dir():
                output_path /= self._get_server_filename(response)
                if output_path.exists():
                    raise FileExistsError(output_path)

            with (
                atomic_writer(output_path, "wb") as fd,
                pbar.task(
                    total=file_size,
                    desc="Downloading",
                    unit_scale=True,
                    unit="B",
                    unit_divisor=1024,
                ),
            ):
                # An empty chunk signals the end of the response body.
                while chunk := response.read(amt=CHUNK_SIZE, decode_content=False):
                    pbar.advance(len(chunk))
                    fd.write(chunk)

        return output_path

    def prepare_file(
        self,
        endpoint: Endpoint,
        *,
        url_params: Optional[dict[str, Any]] = None,
        query_params: Optional[dict[str, Any]] = None,
        status_check_period: Optional[int] = None,
    ):
        """
        Asks the server to prepare a file and waits until the preparation ends.

        Sends a request to the endpoint (using its configured HTTP method),
        expects a 202 response carrying an "rq_id", and then waits for that
        request via the client's wait_for_completion(). When
        status_check_period is None, the value from the client config is used.

        Returns: the request object returned by wait_for_completion()
        """

        client = self._client
        if status_check_period is None:
            status_check_period = client.config.status_check_period

        client.logger.info("Waiting for the server to prepare the file...")

        url = client.api_map.make_endpoint_url(
            endpoint.path, kwsub=url_params, query_params=query_params
        )

        # initialize background process
        response = client.api_client.rest_client.request(
            method=endpoint.settings["http_method"],
            url=url,
            headers=client.api_client.get_common_headers(),
        )

        client.logger.debug("STATUS %s", response.status)
        expect_status(202, response)
        rq_id = json.loads(response.data).get("rq_id")
        # NOTE: `assert` is skipped under `python -O`; it is kept so that the
        # exception type seen by existing callers does not change.
        assert rq_id, "Request identifier was not found in server response"

        # wait until background process will be finished or failed
        request, _ = client.wait_for_completion(rq_id, status_check_period=status_check_period)

        return request

    def prepare_and_download_file_from_endpoint(
        self,
        endpoint: Endpoint,
        filename: Path,
        *,
        url_params: Optional[dict[str, Any]] = None,
        query_params: Optional[dict[str, Any]] = None,
        pbar: Optional[ProgressReporter] = None,
        status_check_period: Optional[int] = None,
    ) -> Path:
        """
        Prepares a file on the server, then downloads it.

        Combines prepare_file() and download_file(): the result URL of the
        finished request is downloaded to filename (which is passed to
        download_file() as output_path).

        Returns: path to the downloaded file
        """

        client = self._client

        if status_check_period is None:
            status_check_period = client.config.status_check_period

        export_request = self.prepare_file(
            endpoint,
            url_params=url_params,
            query_params=query_params,
            status_check_period=status_check_period,
        )

        assert export_request.result_url, "Result url was not found in server response"
        return self.download_file(export_request.result_url, output_path=filename, pbar=pbar)
|