367 lines
12 KiB
Python
367 lines
12 KiB
Python
# Copyright (C) CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import urllib.parse
|
|
from collections.abc import Generator, Sequence
|
|
from contextlib import contextmanager, suppress
|
|
from pathlib import Path
|
|
from time import sleep
|
|
from typing import Any, Optional, TypeVar
|
|
|
|
import attrs
|
|
import packaging.specifiers as specifiers
|
|
import packaging.version as pv
|
|
import platformdirs
|
|
import urllib3
|
|
|
|
from cvat_sdk.api_client import ApiClient, Configuration, exceptions, models
|
|
from cvat_sdk.core.exceptions import (
|
|
BackgroundRequestException,
|
|
IncompatibleVersionException,
|
|
InvalidHostException,
|
|
)
|
|
from cvat_sdk.core.proxies.issues import CommentsRepo, IssuesRepo
|
|
from cvat_sdk.core.proxies.jobs import JobsRepo
|
|
from cvat_sdk.core.proxies.model_proxy import Repo
|
|
from cvat_sdk.core.proxies.organizations import OrganizationsRepo
|
|
from cvat_sdk.core.proxies.projects import ProjectsRepo
|
|
from cvat_sdk.core.proxies.tasks import TasksRepo
|
|
from cvat_sdk.core.proxies.users import UsersRepo
|
|
from cvat_sdk.version import VERSION
|
|
|
|
# Default value of Config.cache_dir: the per-user cache directory that platformdirs
# computes for the "cvat-sdk" application of the "CVAT.ai" author.
_DEFAULT_CACHE_DIR = platformdirs.user_cache_path("cvat-sdk", "CVAT.ai")
|
|
|
|
# Type variable for the Repo subclasses that Client._get_repo instantiates and caches.
_RepoType = TypeVar("_RepoType", bound=Repo)
|
|
|
|
|
|
@attrs.define
class Config:
    """
    Settings that tune the behavior of Client instances.
    """

    status_check_period: float = attrs.field(default=5)
    """How long to wait between status checks of a background operation, in seconds"""

    allow_unsupported_server: bool = attrs.field(default=True)
    """
    Whether the SDK may be used with a server version it does not support.
    When disabled, an unsupported server makes the client raise an exception.
    """

    verify_ssl: Optional[bool] = attrs.field(default=None)
    """Whether the host SSL certificate is verified"""

    cache_dir: Path = attrs.field(default=_DEFAULT_CACHE_DIR, converter=Path)
    """Directory where cached server data is stored (any path-like value is converted to Path)"""
|
|
|
|
|
|
# The SDK version, parsed once; Client derives SUPPORTED_SERVER_VERSIONS from it.
_VERSION_OBJ = pv.Version(VERSION)
|
|
|
|
|
|
class Client:
    """
    Provides session management, implements authentication operations
    and simplifies access to server APIs.
    """

    # Server versions (epoch!major.minor) this SDK targets:
    # the SDK's own major.minor release series and the next minor one.
    SUPPORTED_SERVER_VERSIONS = (
        pv.Version(f"{_VERSION_OBJ.epoch}!{_VERSION_OBJ.major}.{_VERSION_OBJ.minor}"),
        pv.Version(f"{_VERSION_OBJ.epoch}!{_VERSION_OBJ.major}.{_VERSION_OBJ.minor+1}"),
    )

    def __init__(
        self,
        url: str,
        *,
        logger: Optional[logging.Logger] = None,
        config: Optional[Config] = None,
        check_server_version: bool = True,
    ) -> None:
        """
        :param url: the server URL. If it has no schema, HTTPS and then HTTP are probed.
        :param logger: the logger to use; defaults to this module's logger
        :param config: client settings; defaults to Config()
        :param check_server_version: whether to check the server version right away
            (see check_server_version())
        :raises InvalidHostException: if the URL schema is invalid or cannot be detected
        :raises IncompatibleVersionException: if the version check fails and
            the config does not allow unsupported servers
        """
        self.logger = logger or logging.getLogger(__name__)
        """The root logger"""

        # The config has to be available before the URL is validated,
        # because schema detection honors config.verify_ssl.
        self.config = config or Config()
        """Configuration for this object"""

        url = self._validate_and_prepare_url(url)

        self.api_map = CVAT_API_V2(url)
        """Handles server API URL interaction logic"""

        self.api_client = ApiClient(
            Configuration(host=self.api_map.host, verify_ssl=self.config.verify_ssl)
        )
        """Provides low-level access to the CVAT server"""

        self._repos: dict[type[Repo], Repo] = {}
        """A cache for created Repository instances"""

        if check_server_version:
            try:
                self.check_server_version()
            except BaseException:
                # The caller never receives this object and so could not close it;
                # release the API client here before propagating the error.
                self.close()
                raise

    _ORG_SLUG_HEADER = "X-Organization"

    @property
    def organization_slug(self) -> Optional[str]:
        """
        If this is set to a slug for an organization,
        all requests will be made in the context of that organization.

        If it's set to an empty string, requests will be made in the context
        of the user's personal workspace.

        If set to None (the default), no organization context will be used.
        """
        return self.api_client.default_headers.get(self._ORG_SLUG_HEADER)

    @organization_slug.setter
    def organization_slug(self, org_slug: Optional[str]):
        if org_slug is None:
            self.api_client.default_headers.pop(self._ORG_SLUG_HEADER, None)
        else:
            self.api_client.default_headers[self._ORG_SLUG_HEADER] = org_slug

    @contextmanager
    def organization_context(self, slug: Optional[str]) -> Generator[None, None, None]:
        """
        Temporarily sets organization_slug to `slug` (see organization_slug for the
        meaning of the values) and restores the previous value on exit,
        including when the body raises.
        """
        prev_slug = self.organization_slug
        self.organization_slug = slug
        try:
            yield
        finally:
            self.organization_slug = prev_slug

    ALLOWED_SCHEMAS = ("https", "http")

    def _validate_and_prepare_url(self, url: str) -> str:
        """
        Normalizes a server URL: strips trailing slashes, lowercases the schema
        and detects the schema with _detect_schema() if the URL has none.

        :raises InvalidHostException: if the schema is not one of ALLOWED_SCHEMAS
            or cannot be detected
        """
        schema, separator, base_url = url.partition("://")
        if not separator:
            schema, base_url = "", url

        # URL schemes are case-insensitive, so "HTTPS://host" is accepted as well
        schema = schema.lower()
        base_url = base_url.rstrip("/")

        if schema and schema not in self.ALLOWED_SCHEMAS:
            raise InvalidHostException(
                f"Invalid url schema '{schema}', expected "
                f"one of <none>, {', '.join(self.ALLOWED_SCHEMAS)}"
            )

        if not schema:
            schema = self._detect_schema(base_url)

        return f"{schema}://{base_url}"

    def _detect_schema(self, base_url: str) -> str:
        """
        Finds a schema under which the server at `base_url` responds:
        HTTPS is tried first, then HTTP (with a deprecation warning).

        Requires self.config to be set.

        :raises InvalidHostException: if the server responds under neither schema
        """

        def attempt(schema: str) -> bool:
            # Use the configured certificate verification setting, so that an HTTPS
            # server whose certificate fails verification can still be detected
            # when the user has disabled verification.
            probe_config = Configuration(
                host=f"{schema}://{base_url}", verify_ssl=self.config.verify_ssl
            )
            with ApiClient(probe_config) as api_client:
                with suppress(urllib3.exceptions.RequestError):
                    (_, response) = api_client.server_api.retrieve_about(
                        _request_timeout=5, _parse_response=False, _check_status=False
                    )

                    if response.status in [200, 401]:
                        # Server versions prior to 2.3.0 respond with unauthorized
                        # 2.3.0 allows unauthorized access
                        return True

            return False

        if attempt("https"):
            return "https"

        self.logger.warning(
            "Failed to connect to the server using HTTPS; will attempt HTTP instead"
        )
        self.logger.warning(
            "This fallback will be removed in a future version of the SDK;"
            " to avoid breakage, explicitly add 'https://' or 'http://' to the URL"
        )

        if attempt("http"):
            return "http"

        raise InvalidHostException(
            "Failed to detect host schema automatically, please check "
            "the server url and try to specify 'https://' or 'http://' explicitly"
        )

    def __enter__(self):
        self.api_client.__enter__()
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        return self.api_client.__exit__(exc_type, exc_value, traceback)

    def close(self) -> None:
        """Releases the underlying API client; equivalent to leaving a `with` block."""
        return self.__exit__(None, None, None)

    def login(self, credentials: tuple[str, str]) -> None:
        """
        Logs in with a (username, password) pair.

        After a successful login, the Origin and X-CSRFToken headers are added
        to all further requests made through this client.
        """
        self.api_client.auth_api.create_login(
            models.LoginSerializerExRequest(username=credentials[0], password=credentials[1])
        )

        # Sanity checks: the server is expected to set both session cookies
        assert "sessionid" in self.api_client.cookies
        assert "csrftoken" in self.api_client.cookies

        self.api_client.set_default_header("Origin", self.api_client.build_origin_header())
        self.api_client.set_default_header(
            "X-CSRFToken", self.api_client.cookies["csrftoken"].value
        )

    def has_credentials(self) -> bool:
        """Returns True if any of the session cookies set by login() are present."""
        return ("sessionid" in self.api_client.cookies) or ("csrftoken" in self.api_client.cookies)

    def logout(self) -> None:
        """Logs out on the server (if logged in) and forgets the session cookies and headers."""
        if self.has_credentials():
            self.api_client.auth_api.create_logout()
            self.api_client.cookies.pop("sessionid", None)
            self.api_client.cookies.pop("csrftoken", None)
            self.api_client.default_headers.pop("Origin", None)
            self.api_client.default_headers.pop("X-CSRFToken", None)

    def wait_for_completion(
        self: Client,
        rq_id: str,
        *,
        status_check_period: Optional[float] = None,
        log_prefix: Optional[str] = None,
    ) -> tuple[models.Request, urllib3.HTTPResponse]:
        """
        Polls a background request until it finishes.

        :param rq_id: the id of the background request
        :param status_check_period: the polling period, in seconds;
            defaults to config.status_check_period
        :param log_prefix: the prefix of status log messages;
            defaults to "<operation type> operation"
        :returns: the final request object and the response it was read from
        :raises BackgroundRequestException: if the request fails
        """
        if status_check_period is None:
            status_check_period = self.config.status_check_period

        # These lookups do not change between iterations
        status_values = models.RequestStatus.allowed_values[("value",)]
        finished_status = status_values["FINISHED"]
        failed_status = status_values["FAILED"]

        while True:
            request, response = self.api_client.requests_api.retrieve(rq_id)
            status, message = request.status, request.message

            log_prefix = log_prefix or f"{request.operation.type} operation"
            self.logger.info(
                "%s status: %s (message=%s)",
                log_prefix,
                status,
                message,
            )

            if status.value == finished_status:
                return request, response
            if status.value == failed_status:
                raise BackgroundRequestException(message)

            sleep(status_check_period)

    def check_server_version(self, fail_if_unsupported: Optional[bool] = None) -> None:
        """
        Checks that the server version matches one of SUPPORTED_SERVER_VERSIONS.
        A warning is logged if it does not, or if the version cannot be retrieved.

        :param fail_if_unsupported: raise an exception in addition to the warning;
            defaults to "not config.allow_unsupported_server"
        :raises IncompatibleVersionException: on failure, if fail_if_unsupported is set
        """
        if fail_if_unsupported is None:
            fail_if_unsupported = not self.config.allow_unsupported_server

        try:
            server_version = self.get_server_version()
        except exceptions.ApiException as e:
            msg = (
                "Failed to retrieve server API version: %s. "
                "Some SDK functions may not work properly with this server."
            ) % (e,)
            self.logger.warning(msg)
            if fail_if_unsupported:
                raise IncompatibleVersionException(msg) from e
            return

        if not any(
            self._is_version_compatible(server_version, supported_version)
            for supported_version in self.SUPPORTED_SERVER_VERSIONS
        ):
            msg = (
                "Server version '%s' is not compatible with SDK version '%s'. "
                "Some SDK functions may not work properly with this server. "
                "You can continue using this SDK, or you can "
                "try to update with 'pip install cvat-sdk'."
            ) % (server_version, _VERSION_OBJ)
            self.logger.warning(msg)
            if fail_if_unsupported:
                raise IncompatibleVersionException(msg)

    def _is_version_compatible(self, current: pv.Version, target: pv.Version) -> bool:
        """
        Checks that `current` satisfies the compatible-release specifier
        "~= <target epoch>!<major>.<minor>.<micro>", i.e. it is in the same
        major.minor series as `target` and not older than it.

        NOTE: with packaging's default settings, pre-release versions
        (e.g. dev builds) are not matched by the specifier.
        """
        # Check for (major, minor) compatibility.
        # Micro releases and fixes do not affect API compatibility in general.
        epoch = f"{target.epoch}!" if target.epoch else ""  # 1.0 ~= 0!1.0 is false
        return current in specifiers.Specifier(
            f"~= {epoch}{target.major}.{target.minor}.{target.micro}"
        )

    def get_server_version(self) -> pv.Version:
        """Requests the server's "about" information and returns its parsed version."""
        (about, _) = self.api_client.server_api.retrieve_about()
        return pv.Version(about.version)

    def _get_repo(self, repo_type: type[_RepoType]) -> _RepoType:
        """Returns the cached instance of `repo_type`, creating it on first use."""
        repo = self._repos.get(repo_type)
        if repo is None:
            repo = repo_type(self)
            self._repos[repo_type] = repo
        return repo

    @property
    def tasks(self) -> TasksRepo:
        return self._get_repo(TasksRepo)

    @property
    def projects(self) -> ProjectsRepo:
        return self._get_repo(ProjectsRepo)

    @property
    def jobs(self) -> JobsRepo:
        return self._get_repo(JobsRepo)

    @property
    def users(self) -> UsersRepo:
        return self._get_repo(UsersRepo)

    @property
    def organizations(self) -> OrganizationsRepo:
        return self._get_repo(OrganizationsRepo)

    @property
    def issues(self) -> IssuesRepo:
        return self._get_repo(IssuesRepo)

    @property
    def comments(self) -> CommentsRepo:
        return self._get_repo(CommentsRepo)
|
|
|
|
|
|
class CVAT_API_V2:
    """
    Builds parameterized URLs for endpoints of a CVAT server.

    :param host: the server root URL; trailing slashes are stripped
    """

    def __init__(self, host: str):
        self.host = host.rstrip("/")
        # URL of the server's API root
        self.base = self.host + "/api/"

    def make_endpoint_url(
        self,
        path: str,
        *,
        psub: Optional[Sequence[Any]] = None,
        kwsub: Optional[dict[str, Any]] = None,
        query_params: Optional[dict[str, Any]] = None,
    ) -> str:
        """
        Builds an absolute endpoint URL.

        :param path: the endpoint path, appended to the host as is, e.g. "/api/tasks/{}".
            It is treated as a str.format() template only when substitutions are given,
            so literal braces are kept when psub and kwsub are both empty.
        :param psub: positional substitutions for the path template
        :param kwsub: keyword substitutions for the path template
        :param query_params: query parameters to URL-encode and append
        :returns: the resulting URL
        """
        url = self.host + path
        if psub or kwsub:
            url = url.format(*(psub or []), **(kwsub or {}))
        if query_params:
            # Extend a query string already present in the path instead of starting a second one
            separator = "&" if "?" in url else "?"
            url += separator + urllib.parse.urlencode(query_params)
        return url
|
|
|
|
|
|
def make_client(
    host: str, *, port: Optional[int] = None, credentials: Optional[tuple[str, str]] = None
) -> Client:
    """
    Creates a Client for a server and optionally logs in.

    :param host: the server address, with or without a schema
    :param port: if set (and non-zero), appended to the host as ":<port>"
    :param credentials: a (username, password) pair to log in with;
        if None, no login is performed
    :returns: the new client; the caller is responsible for closing it
    :raises: whatever Client construction or login raises. If the login fails,
        the client is closed before the exception propagates.
    """
    url = host.rstrip("/")
    if port:
        url = f"{url}:{port}"

    client = Client(url=url)
    try:
        if credentials is not None:
            client.login(credentials)
    except BaseException:
        # The caller never receives the client, so release it here
        client.close()
        raise

    return client
|