135 lines
4.2 KiB
Python
135 lines
4.2 KiB
Python
|
|
# Copyright (C) CVAT.ai Corporation
|
||
|
|
#
|
||
|
|
# SPDX-License-Identifier: MIT
|
||
|
|
|
||
|
|
|
||
|
|
import contextlib
|
||
|
|
import http.server
|
||
|
|
import io
|
||
|
|
import ssl
|
||
|
|
import threading
|
||
|
|
import unittest
|
||
|
|
from collections.abc import Generator
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import Any, Optional, Union
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import requests
|
||
|
|
from cvat_sdk import make_client
|
||
|
|
|
||
|
|
from shared.utils.config import BASE_URL, USER_PASS
|
||
|
|
from shared.utils.helpers import generate_image_file
|
||
|
|
|
||
|
|
|
||
|
|
def run_cli(test: Union[unittest.TestCase, Any], *args: str, expected_code: int = 0) -> None:
    """
    Invoke the CVAT CLI entry point in-process and check the code it returns.

    `args` are handed to `cvat_cli.__main__.main` as a single tuple. When `test` is a
    `unittest.TestCase`, the comparison goes through `assertEqual` with the stringified
    arguments as the failure message; for any other `test` value a plain `assert` is used.
    """
    # Imported at call time rather than at module import time.
    from cvat_cli.__main__ import main

    if not isinstance(test, unittest.TestCase):
        # Pytest case
        assert expected_code == main(args)
        return

    # Unittest
    test.assertEqual(expected_code, main(args), str(args))
|
||
|
|
|
||
|
|
|
||
|
|
def generate_images(dst_dir: Path, count: int) -> list[Path]:
    """
    Create `count` files named `img_0.jpg`, `img_1.jpg`, ... inside `dst_dir`.

    `dst_dir` is created first (including missing parents) if it does not exist yet.
    Each file receives the bytes produced by `generate_image_file` for that file name.
    Returns the paths of the written files, in index order.
    """
    dst_dir.mkdir(parents=True, exist_ok=True)

    image_paths = [dst_dir / f"img_{index}.jpg" for index in range(count)]
    for image_path in image_paths:
        image_bytes = generate_image_file(image_path.name).getvalue()
        image_path.write_bytes(image_bytes)

    return image_paths
|
||
|
|
|
||
|
|
|
||
|
|
@contextlib.contextmanager
def https_reverse_proxy() -> Generator[str, None, None]:
    """
    Run a local HTTPS server for the duration of the `with` block.

    The server listens on a localhost port, handles requests with
    `_ProxyHttpRequestHandler`, and presents the `self-signed.crt` / `self-signed.key`
    pair stored next to this file. Yields the server's `https://localhost:<port>` URL.
    """
    certs_location = Path(__file__).parent

    tls_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    tls_context.minimum_version = ssl.TLSVersion.TLSv1_2
    tls_context.load_cert_chain(
        certs_location / "self-signed.crt", certs_location / "self-signed.key"
    )

    # Port 0 is requested here; the port actually bound is read back from server_port below.
    with http.server.HTTPServer(("localhost", 0), _ProxyHttpRequestHandler) as server:
        server.socket = tls_context.wrap_socket(server.socket, server_side=True)

        worker = threading.Thread(target=server.serve_forever)
        worker.start()
        try:
            yield f"https://localhost:{server.server_port}"
        finally:
            # Stop serve_forever() first so that the worker thread can finish and be joined.
            server.shutdown()
            worker.join()
|
||
|
|
|
||
|
|
|
||
|
|
class _ProxyHttpRequestHandler(http.server.BaseHTTPRequestHandler):
    """
    Request handler that relays GET and POST requests to the server at `BASE_URL`.

    Each incoming request is re-issued with `requests` against `BASE_URL + self.path`,
    and the upstream status code, headers and raw body are written back to the client.
    """

    def do_GET(self):
        """Relay a GET request upstream and send the upstream response to the client."""
        response = requests.get(**self._shared_request_args())
        self._translate_response(response)

    def do_POST(self):
        """Relay a POST request, including its body, upstream."""
        # A POST without a Content-Length header is treated as having an empty body
        # instead of failing with a TypeError on int(None).
        body_length = int(self.headers.get("Content-Length", 0))

        response = requests.post(data=self.rfile.read(body_length), **self._shared_request_args())
        self._translate_response(response)

    def _shared_request_args(self) -> dict[str, Any]:
        """Build the `requests` keyword arguments common to all relayed methods."""
        headers = {k.lower(): v for k, v in self.headers.items()}
        # The client's Host header is not forwarded. Requests that lack one are tolerated
        # rather than rejected with a KeyError.
        headers.pop("host", None)

        return {"url": BASE_URL + self.path, "headers": headers, "timeout": 60, "stream": True}

    def _translate_response(self, response: requests.Response) -> None:
        """Write the upstream response back to the client, then close it."""
        try:
            self.send_response(response.status_code)
            # NOTE(review): all upstream headers are forwarded verbatim, including
            # hop-by-hop ones such as Transfer-Encoding - confirm this is intended.
            for key, value in response.headers.items():
                self.send_header(key, value)
            self.end_headers()
            # Need to use raw here to prevent requests from handling Content-Encoding.
            self.wfile.write(response.raw.read())
        finally:
            # The request was issued with stream=True; close explicitly so the upstream
            # connection is released even if writing to the client fails midway.
            response.close()
|
||
|
|
|
||
|
|
|
||
|
|
class TestCliBase:
    """
    Base class for CLI tests.

    An autouse fixture stores the connection details and an SDK client on the instance,
    and `run_cli` runs a CLI command with those connection details.
    """

    @pytest.fixture(autouse=True)
    def setup(
        self,
        restore_db_per_function,  # force fixture call order to allow DB setup
        restore_redis_inmem_per_function,
        restore_redis_ondisk_per_function,
        fxt_stdout: io.StringIO,
        tmp_path: Path,
        admin_user: str,
    ):
        """Store the fixture values and create an SDK client for the admin user."""
        self.tmp_path, self.stdout = tmp_path, fxt_stdout

        # BASE_URL is split at its last colon: everything before it is used as the host,
        # the remainder (kept as a string) as the port.
        self.host, self.port = BASE_URL.rsplit(":", maxsplit=1)
        self.user, self.password = admin_user, USER_PASS

        sdk_client = make_client(
            host=self.host, port=self.port, credentials=(self.user, self.password)
        )
        sdk_client.config.status_check_period = 0.01
        self.client = sdk_client

        yield

    def run_cli(
        self, cmd: str, *args: str, expected_code: int = 0, organization: Optional[str] = None
    ) -> str:
        """
        Run CLI command `cmd` with `args`, preceded by the auth and server options.

        `--organization` is added only when `organization` is not None. The returned exit
        code is checked against `expected_code` by the module-level `run_cli`.
        Returns the contents of `self.stdout`.
        """
        cli_args = [
            f"--auth={self.user}:{self.password}",
            f"--server-host={self.host}",
            f"--server-port={self.port}",
        ]
        if organization is not None:
            cli_args.append(f"--organization={organization}")
        # The common options go first, then the command and its own arguments.
        cli_args.extend((cmd, *args))

        run_cli(self, *cli_args, expected_code=expected_code)
        return self.stdout.getvalue()
|