cvat/cvat-sdk/cvat_sdk/core/proxies/jobs.py

# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

from __future__ import annotations

import io
import mimetypes
from collections.abc import Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Optional

from PIL import Image

from cvat_sdk.api_client import apis, models
from cvat_sdk.core.helpers import get_paginated_collection
from cvat_sdk.core.progress import ProgressReporter
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
from cvat_sdk.core.proxies.issues import Issue
from cvat_sdk.core.proxies.model_proxy import (
    ExportDatasetMixin,
    ModelListMixin,
    ModelRetrieveMixin,
    ModelUpdateMixin,
    build_model_bases,
)
from cvat_sdk.core.uploading import AnnotationUploader

if TYPE_CHECKING:
    from _typeshed import StrPath

_JobEntityBase, _JobRepoBase = build_model_bases(
    models.JobRead, apis.JobsApi, api_member_name="jobs_api"
)


class Job(
    models.IJobRead,
    _JobEntityBase,
    ModelUpdateMixin[models.IPatchedJobWriteRequest],
    AnnotationCrudMixin,
    ExportDatasetMixin,
):
    _model_partial_update_arg = "patched_job_write_request"

    def import_annotations(
        self,
        format_name: str,
        filename: StrPath,
        *,
        conv_mask_to_poly: Optional[bool] = None,
        status_check_period: Optional[int] = None,
        pbar: Optional[ProgressReporter] = None,
    ):
        """
        Upload annotations for a job in the specified format (e.g. 'YOLO 1.1').
        """

        filename = Path(filename)

        AnnotationUploader(self._client).upload_file_and_wait(
            self.api.create_annotations_endpoint,
            filename,
            format_name,
            conv_mask_to_poly=conv_mask_to_poly,
            url_params={"id": self.id},
            pbar=pbar,
            status_check_period=status_check_period,
        )

        self._client.logger.info(f"Annotation file '{filename}' for job #{self.id} uploaded")
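
    # Usage sketch (illustrative, not part of this module): assuming a configured
    # `client` obtained via cvat_sdk.make_client() and an existing job on the server;
    # the job id, file name, and format name below are placeholders.
    #
    #   job = client.jobs.retrieve(42)
    #   job.import_annotations("YOLO 1.1", "annotations.zip")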

    def get_frame(
        self,
        frame_id: int,
        *,
        quality: Optional[str] = None,
    ) -> io.RawIOBase:
        """Download a single frame of the job as an in-memory byte stream."""
        (_, response) = self.api.retrieve_data(
            self.id, number=frame_id, quality=quality, type="frame"
        )
        return io.BytesIO(response.data)

    def get_preview(
        self,
    ) -> io.RawIOBase:
        """Download the job preview image as an in-memory byte stream."""
        (_, response) = self.api.retrieve_preview(self.id)
        return io.BytesIO(response.data)
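
    # Usage sketch (illustrative): both get_frame() and get_preview() return
    # in-memory byte streams, so they can be opened directly with PIL; `job` is
    # assumed to be an already retrieved Job, and the quality value is a placeholder.
    #
    #   from PIL import Image
    #   preview = Image.open(job.get_preview())
    #   first_frame = Image.open(job.get_frame(0, quality="compressed"))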

    def download_frames(
        self,
        frame_ids: Sequence[int],
        *,
        image_extension: Optional[str] = None,
        outdir: StrPath = ".",
        quality: str = "original",
        filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
    ) -> Optional[list[Image.Image]]:
        """
        Download the requested frames of the job and save them as images
        under outdir, named according to filename_pattern.
        """
        # TODO: add arg descriptions in schema

        outdir = Path(outdir)
        outdir.mkdir(parents=True, exist_ok=True)

        for frame_id in frame_ids:
            frame_bytes = self.get_frame(frame_id, quality=quality)

            im = Image.open(frame_bytes)
            if image_extension is None:
                mime_type = im.get_format_mimetype() or "image/jpg"
                im_ext = mimetypes.guess_extension(mime_type)

                # FIXME: it would be better to use the meta information from the server
                # to determine the extension.
                # Replace '.jpe' or '.jpeg' with the more common '.jpg'.
                if im_ext in (".jpe", ".jpeg", None):
                    im_ext = ".jpg"
            else:
                im_ext = f".{image_extension.strip('.')}"

            outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
            im.save(outdir / outfile)
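
    # Usage sketch (illustrative): saves the listed frames under ./frames with the
    # default file name pattern; the frame ids and directory name are placeholders.
    #
    #   job.download_frames([0, 1, 2], outdir="frames", quality="compressed")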

    def get_meta(self) -> models.IDataMetaRead:
        """Retrieve the data metadata of the job."""
        (meta, _) = self.api.retrieve_data_meta(self.id)
        return meta

    def get_labels(self) -> list[models.ILabel]:
        """Retrieve the labels available in the job."""
        return get_paginated_collection(
            self._client.api_client.labels_api.list_endpoint, job_id=self.id
        )

    def get_frames_info(self) -> list[models.IFrameMeta]:
        """Retrieve per-frame metadata for the job."""
        return self.get_meta().frames

    def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
        """Mark the specified frames as deleted in the job data metadata."""
        self.api.partial_update_data_meta(
            self.id,
            patched_job_data_meta_write_request=models.PatchedJobDataMetaWriteRequest(
                deleted_frames=ids
            ),
        )
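
    # Usage sketch (illustrative): per-frame metadata can be inspected through
    # get_frames_info(); the attribute names follow models.IFrameMeta.
    #
    #   for frame in job.get_frames_info():
    #       print(frame.name, frame.width, frame.height)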

    def get_issues(self) -> list[Issue]:
        """Retrieve the issues created for this job."""
        return [
            Issue(self._client, m)
            for m in get_paginated_collection(
                self._client.api_client.issues_api.list_endpoint, job_id=self.id
            )
        ]


class JobsRepo(
    _JobRepoBase,
    ModelListMixin[Job],
    ModelRetrieveMixin[Job],
):
    """Repository providing list and retrieve operations for jobs."""

    _entity_type = Job
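

# Usage sketch (illustrative): a minimal end-to-end flow, assuming a reachable CVAT
# server and an existing job; the host, credentials, and job id are placeholders.
#
#   from cvat_sdk import make_client
#
#   with make_client(host="http://localhost:8080", credentials=("user", "password")) as client:
#       job = client.jobs.retrieve(42)
#       print(job.state, job.stage)
#       print([label.name for label in job.get_labels()])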