161 lines
4.7 KiB
Python
161 lines
4.7 KiB
Python
|
|
# Copyright (C) CVAT.ai Corporation
|
||
|
|
#
|
||
|
|
# SPDX-License-Identifier: MIT
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import io
|
||
|
|
import mimetypes
|
||
|
|
from collections.abc import Sequence
|
||
|
|
from pathlib import Path
|
||
|
|
from typing import TYPE_CHECKING, Optional
|
||
|
|
|
||
|
|
from PIL import Image
|
||
|
|
|
||
|
|
from cvat_sdk.api_client import apis, models
|
||
|
|
from cvat_sdk.core.helpers import get_paginated_collection
|
||
|
|
from cvat_sdk.core.progress import ProgressReporter
|
||
|
|
from cvat_sdk.core.proxies.annotations import AnnotationCrudMixin
|
||
|
|
from cvat_sdk.core.proxies.issues import Issue
|
||
|
|
from cvat_sdk.core.proxies.model_proxy import (
|
||
|
|
ExportDatasetMixin,
|
||
|
|
ModelListMixin,
|
||
|
|
ModelRetrieveMixin,
|
||
|
|
ModelUpdateMixin,
|
||
|
|
build_model_bases,
|
||
|
|
)
|
||
|
|
from cvat_sdk.core.uploading import AnnotationUploader
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from _typeshed import StrPath
|
||
|
|
|
||
|
|
_JobEntityBase, _JobRepoBase = build_model_bases(
|
||
|
|
models.JobRead, apis.JobsApi, api_member_name="jobs_api"
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class Job(
|
||
|
|
models.IJobRead,
|
||
|
|
_JobEntityBase,
|
||
|
|
ModelUpdateMixin[models.IPatchedJobWriteRequest],
|
||
|
|
AnnotationCrudMixin,
|
||
|
|
ExportDatasetMixin,
|
||
|
|
):
|
||
|
|
_model_partial_update_arg = "patched_job_write_request"
|
||
|
|
|
||
|
|
def import_annotations(
|
||
|
|
self,
|
||
|
|
format_name: str,
|
||
|
|
filename: StrPath,
|
||
|
|
*,
|
||
|
|
conv_mask_to_poly: Optional[bool] = None,
|
||
|
|
status_check_period: Optional[int] = None,
|
||
|
|
pbar: Optional[ProgressReporter] = None,
|
||
|
|
):
|
||
|
|
"""
|
||
|
|
Upload annotations for a job in the specified format (e.g. 'YOLO 1.1').
|
||
|
|
"""
|
||
|
|
|
||
|
|
filename = Path(filename)
|
||
|
|
|
||
|
|
AnnotationUploader(self._client).upload_file_and_wait(
|
||
|
|
self.api.create_annotations_endpoint,
|
||
|
|
filename,
|
||
|
|
format_name,
|
||
|
|
conv_mask_to_poly=conv_mask_to_poly,
|
||
|
|
url_params={"id": self.id},
|
||
|
|
pbar=pbar,
|
||
|
|
status_check_period=status_check_period,
|
||
|
|
)
|
||
|
|
|
||
|
|
self._client.logger.info(f"Annotation file '{filename}' for job #{self.id} uploaded")
|
||
|
|
|
||
|
|
def get_frame(
|
||
|
|
self,
|
||
|
|
frame_id: int,
|
||
|
|
*,
|
||
|
|
quality: Optional[str] = None,
|
||
|
|
) -> io.RawIOBase:
|
||
|
|
(_, response) = self.api.retrieve_data(
|
||
|
|
self.id, number=frame_id, quality=quality, type="frame"
|
||
|
|
)
|
||
|
|
return io.BytesIO(response.data)
|
||
|
|
|
||
|
|
def get_preview(
|
||
|
|
self,
|
||
|
|
) -> io.RawIOBase:
|
||
|
|
(_, response) = self.api.retrieve_preview(self.id)
|
||
|
|
return io.BytesIO(response.data)
|
||
|
|
|
||
|
|
def download_frames(
|
||
|
|
self,
|
||
|
|
frame_ids: Sequence[int],
|
||
|
|
*,
|
||
|
|
image_extension: Optional[str] = None,
|
||
|
|
outdir: StrPath = ".",
|
||
|
|
quality: str = "original",
|
||
|
|
filename_pattern: str = "frame_{frame_id:06d}{frame_ext}",
|
||
|
|
) -> Optional[list[Image.Image]]:
|
||
|
|
"""
|
||
|
|
Download the requested frame numbers for a job and save images as outdir/filename_pattern
|
||
|
|
"""
|
||
|
|
# TODO: add arg descriptions in schema
|
||
|
|
|
||
|
|
outdir = Path(outdir)
|
||
|
|
outdir.mkdir(parents=True, exist_ok=True)
|
||
|
|
|
||
|
|
for frame_id in frame_ids:
|
||
|
|
frame_bytes = self.get_frame(frame_id, quality=quality)
|
||
|
|
|
||
|
|
im = Image.open(frame_bytes)
|
||
|
|
if image_extension is None:
|
||
|
|
mime_type = im.get_format_mimetype() or "image/jpg"
|
||
|
|
im_ext = mimetypes.guess_extension(mime_type)
|
||
|
|
|
||
|
|
# FIXME It is better to use meta information from the server
|
||
|
|
# to determine the extension
|
||
|
|
# replace '.jpe' or '.jpeg' with a more used '.jpg'
|
||
|
|
if im_ext in (".jpe", ".jpeg", None):
|
||
|
|
im_ext = ".jpg"
|
||
|
|
else:
|
||
|
|
im_ext = f".{image_extension.strip('.')}"
|
||
|
|
|
||
|
|
outfile = filename_pattern.format(frame_id=frame_id, frame_ext=im_ext)
|
||
|
|
im.save(outdir / outfile)
|
||
|
|
|
||
|
|
def get_meta(self) -> models.IDataMetaRead:
|
||
|
|
(meta, _) = self.api.retrieve_data_meta(self.id)
|
||
|
|
return meta
|
||
|
|
|
||
|
|
def get_labels(self) -> list[models.ILabel]:
|
||
|
|
return get_paginated_collection(
|
||
|
|
self._client.api_client.labels_api.list_endpoint, job_id=self.id
|
||
|
|
)
|
||
|
|
|
||
|
|
def get_frames_info(self) -> list[models.IFrameMeta]:
|
||
|
|
return self.get_meta().frames
|
||
|
|
|
||
|
|
def remove_frames_by_ids(self, ids: Sequence[int]) -> None:
|
||
|
|
self.api.partial_update_data_meta(
|
||
|
|
self.id,
|
||
|
|
patched_job_data_meta_write_request=models.PatchedJobDataMetaWriteRequest(
|
||
|
|
deleted_frames=ids
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
def get_issues(self) -> list[Issue]:
|
||
|
|
return [
|
||
|
|
Issue(self._client, m)
|
||
|
|
for m in get_paginated_collection(
|
||
|
|
self._client.api_client.issues_api.list_endpoint, job_id=self.id
|
||
|
|
)
|
||
|
|
]
|
||
|
|
|
||
|
|
|
||
|
|
class JobsRepo(
|
||
|
|
_JobRepoBase,
|
||
|
|
ModelListMixin[Job],
|
||
|
|
ModelRetrieveMixin[Job],
|
||
|
|
):
|
||
|
|
_entity_type = Job
|