cvat/tests/python/cli/util.py

135 lines
4.2 KiB
Python
Raw Normal View History

2025-09-16 01:19:40 +00:00
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import contextlib
import http.server
import io
import ssl
import threading
import unittest
from collections.abc import Generator
from pathlib import Path
from typing import Any, Optional, Union
import pytest
import requests
from cvat_sdk import make_client
from shared.utils.config import BASE_URL, USER_PASS
from shared.utils.helpers import generate_image_file
def run_cli(test: Union[unittest.TestCase, Any], *args: str, expected_code: int = 0) -> None:
    """Invoke the CVAT CLI entry point in-process and verify its return code.

    Args:
        test: The calling test object. A ``unittest.TestCase`` has the check
            reported through ``assertEqual`` (with the arguments as the failure
            message); any other object gets a plain ``assert``.
        *args: Command-line arguments, handed to the CLI ``main`` as one tuple.
        expected_code: The value ``main`` is required to return.

    Raises:
        AssertionError: If ``main`` returns something other than ``expected_code``.
    """
    # Function-scope import: the CLI package is loaded only when a test runs it.
    from cvat_cli.__main__ import main

    if not isinstance(test, unittest.TestCase):
        # Pytest-style test: a bare assert is enough
        assert expected_code == main(args)
        return

    # unittest-style test: go through the TestCase assertion helpers
    test.assertEqual(expected_code, main(args), str(args))
def generate_images(dst_dir: Path, count: int) -> list[Path]:
    """Create ``count`` generated JPEG files inside ``dst_dir``.

    The directory (and any missing parents) is created first. Files are named
    ``img_0.jpg`` ... ``img_<count-1>.jpg`` and their contents come from
    ``generate_image_file``.

    Args:
        dst_dir: Directory that receives the images.
        count: Number of images to write.

    Returns:
        Paths of the written files, in creation order.
    """
    dst_dir.mkdir(parents=True, exist_ok=True)

    image_paths = [dst_dir / f"img_{index}.jpg" for index in range(count)]
    for image_path in image_paths:
        image_data = generate_image_file(image_path.name)
        image_path.write_bytes(image_data.getvalue())

    return image_paths
@contextlib.contextmanager
def https_reverse_proxy() -> Generator[str, None, None]:
    """Run a local HTTPS server for the duration of the ``with`` block.

    The server listens on ``localhost``, speaks TLS 1.2 or newer using the
    ``self-signed.crt`` / ``self-signed.key`` pair stored next to this file,
    and handles requests with ``_ProxyHttpRequestHandler``. It is served from
    a separate thread, which is stopped and joined on exit.

    Yields:
        The ``https://localhost:<port>`` URL of the running server.
    """
    here = Path(__file__).parent
    certificate, private_key = here / "self-signed.crt", here / "self-signed.key"

    tls = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    tls.minimum_version = ssl.TLSVersion.TLSv1_2
    tls.load_cert_chain(certificate, private_key)

    # Port 0 asks the OS for any free port; the real one is read back below.
    with http.server.HTTPServer(("localhost", 0), _ProxyHttpRequestHandler) as server:
        # Swap the plain listening socket for its TLS-wrapped counterpart
        server.socket = tls.wrap_socket(server.socket, server_side=True)

        worker = threading.Thread(target=server.serve_forever)
        worker.start()
        try:
            yield f"https://localhost:{server.server_port}"
        finally:
            # Stop the serve loop first, then wait for the thread to finish
            server.shutdown()
            worker.join()
class _ProxyHttpRequestHandler(http.server.BaseHTTPRequestHandler):
    """Request handler that forwards GET and POST requests to ``BASE_URL``.

    Each incoming request is re-issued with ``requests`` against
    ``BASE_URL + self.path``, and the upstream status, headers and raw body
    are written back to the client.
    """

    def do_GET(self):
        """Forward a GET request upstream and relay the response."""
        response = requests.get(**self._shared_request_args())
        self._translate_response(response)

    def do_POST(self):
        """Forward a POST request with its body upstream and relay the response.

        A request without a ``Content-Length`` header is answered with
        411 Length Required, because the body size is needed to read it
        from the connection.
        """
        content_length = self.headers.get("Content-Length")
        if content_length is None:
            # Previously this crashed the handler with ``int(None)``
            self.send_error(http.HTTPStatus.LENGTH_REQUIRED)
            return

        body = self.rfile.read(int(content_length))
        response = requests.post(data=body, **self._shared_request_args())
        self._translate_response(response)

    def _shared_request_args(self) -> dict[str, Any]:
        """Build the keyword arguments common to every forwarded request.

        Returns:
            A dict with the upstream ``url``, the client's headers (names
            lower-cased, ``host`` removed), a 60 second ``timeout`` and
            ``stream=True``.
        """
        headers = {k.lower(): v for k, v in self.headers.items()}
        # The client's Host header names this proxy, not the upstream server.
        # ``pop`` tolerates clients that did not send the header at all.
        headers.pop("host", None)
        return {"url": BASE_URL + self.path, "headers": headers, "timeout": 60, "stream": True}

    def _translate_response(self, response: requests.Response) -> None:
        """Write the upstream response (status, headers, body) back to the client.

        The upstream response is always closed afterwards, so that the
        streamed connection is released even if writing to the client fails.
        """
        try:
            self.send_response(response.status_code)
            for key, value in response.headers.items():
                self.send_header(key, value)
            self.end_headers()
            # Need to use raw here to prevent requests from handling Content-Encoding.
            self.wfile.write(response.raw.read())
        finally:
            response.close()
class TestCliBase:
    """Base class for CLI tests that run against the server at ``BASE_URL``.

    The autouse ``setup`` fixture stores connection details, credentials, a
    temporary directory, the stdout buffer and an SDK client on the instance;
    ``run_cli`` then invokes CLI commands with those connection details.
    """

    @pytest.fixture(autouse=True)
    def setup(
        self,
        restore_db_per_function,  # force fixture call order to allow DB setup
        restore_redis_inmem_per_function,
        restore_redis_ondisk_per_function,
        fxt_stdout: io.StringIO,
        tmp_path: Path,
        admin_user: str,
    ):
        """Populate per-test state before each test method runs."""
        self.tmp_path, self.stdout = tmp_path, fxt_stdout

        # BASE_URL is split at its last colon into "<host>" and "<port>";
        # both halves are kept as strings.
        server_host, server_port = BASE_URL.rsplit(":", maxsplit=1)
        self.host, self.port = server_host, server_port

        self.user, self.password = admin_user, USER_PASS

        self.client = make_client(
            host=server_host,
            port=server_port,
            credentials=(admin_user, USER_PASS),
        )
        # NOTE(review): presumably shortens the SDK's status polling interval
        # to keep tests fast; confirm against the cvat_sdk configuration docs.
        self.client.config.status_check_period = 0.01

        yield

    def run_cli(
        self, cmd: str, *args: str, expected_code: int = 0, organization: Optional[str] = None
    ) -> str:
        """Run one CLI command with this test's credentials and server address.

        Args:
            cmd: The CLI command to run; placed after the common options.
            *args: Further arguments placed after ``cmd``.
            expected_code: Return code the CLI is required to produce.
            organization: When given, added as ``--organization=<value>``.

        Returns:
            The current contents of the ``fxt_stdout`` buffer (presumably the
            captured output of the CLI; depends on how that fixture is wired).
        """
        cli_args = [
            f"--auth={self.user}:{self.password}",
            f"--server-host={self.host}",
            f"--server-port={self.port}",
        ]
        if organization is not None:
            cli_args += [f"--organization={organization}"]
        cli_args += [cmd, *args]

        # Resolves to the module-level run_cli helper, not to this method.
        run_cli(self, *cli_args, expected_code=expected_code)
        return self.stdout.getvalue()