632 lines
22 KiB
Python
632 lines
22 KiB
Python
|
|
# Copyright (C) CVAT.ai Corporation
|
||
|
|
#
|
||
|
|
# SPDX-License-Identifier: MIT
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from collections.abc import Mapping, Sequence
|
||
|
|
from typing import Callable, Optional, Union, cast
|
||
|
|
|
||
|
|
import attrs
|
||
|
|
from typing_extensions import TypeAlias
|
||
|
|
|
||
|
|
import cvat_sdk.models as models
|
||
|
|
from cvat_sdk.core import Client
|
||
|
|
from cvat_sdk.core.progress import NullProgressReporter, ProgressReporter
|
||
|
|
from cvat_sdk.datasets.task_dataset import TaskDataset
|
||
|
|
|
||
|
|
from ..attributes import attribute_value_validator
|
||
|
|
from .exceptions import BadFunctionError
|
||
|
|
from .interface import (
|
||
|
|
DetectionAnnotation,
|
||
|
|
DetectionFunction,
|
||
|
|
DetectionFunctionContext,
|
||
|
|
DetectionFunctionSpec,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
@attrs.frozen
|
||
|
|
class _AttributeNameMapping:
|
||
|
|
name: str
|
||
|
|
|
||
|
|
|
||
|
|
@attrs.frozen
|
||
|
|
class _SublabelNameMapping:
|
||
|
|
name: str
|
||
|
|
attributes: Optional[Mapping[str, _AttributeNameMapping]] = attrs.field(
|
||
|
|
kw_only=True, default=None
|
||
|
|
)
|
||
|
|
|
||
|
|
def map_attribute(self, name: str) -> Optional[_AttributeNameMapping]:
|
||
|
|
if self.attributes is None:
|
||
|
|
return _AttributeNameMapping(name)
|
||
|
|
|
||
|
|
return self.attributes.get(name)
|
||
|
|
|
||
|
|
@classmethod
|
||
|
|
def from_api(cls, raw: models.SublabelMappingEntryRequest, /) -> _SublabelNameMapping:
|
||
|
|
return _SublabelNameMapping(
|
||
|
|
name=raw.name,
|
||
|
|
attributes=(
|
||
|
|
{k: _AttributeNameMapping(v) for k, v in raw.attributes.items()}
|
||
|
|
if hasattr(raw, "attributes")
|
||
|
|
else None
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
@attrs.frozen
|
||
|
|
class _LabelNameMapping(_SublabelNameMapping):
|
||
|
|
sublabels: Optional[Mapping[str, _SublabelNameMapping]] = attrs.field(
|
||
|
|
kw_only=True, default=None
|
||
|
|
)
|
||
|
|
|
||
|
|
def map_sublabel(self, name: str) -> Optional[_SublabelNameMapping]:
|
||
|
|
if self.sublabels is None:
|
||
|
|
return _SublabelNameMapping(name)
|
||
|
|
|
||
|
|
return self.sublabels.get(name)
|
||
|
|
|
||
|
|
@classmethod
|
||
|
|
def from_api(cls, raw: models.LabelMappingEntryRequest, /) -> _LabelNameMapping:
|
||
|
|
return _LabelNameMapping(
|
||
|
|
**attrs.asdict(_SublabelNameMapping.from_api(raw), recurse=False),
|
||
|
|
sublabels=(
|
||
|
|
{k: _SublabelNameMapping.from_api(v) for k, v in raw.sublabels.items()}
|
||
|
|
if hasattr(raw, "sublabels")
|
||
|
|
else None
|
||
|
|
),
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
@attrs.frozen
|
||
|
|
class _SpecNameMapping:
|
||
|
|
labels: Optional[Mapping[str, _LabelNameMapping]] = attrs.field(kw_only=True, default=None)
|
||
|
|
|
||
|
|
def map_label(self, name: str) -> Optional[_LabelNameMapping]:
|
||
|
|
if self.labels is None:
|
||
|
|
return _LabelNameMapping(name)
|
||
|
|
|
||
|
|
return self.labels.get(name)
|
||
|
|
|
||
|
|
@classmethod
|
||
|
|
def from_api(cls, raw: dict[str, models.LabelMappingEntryRequest], /) -> _SpecNameMapping:
|
||
|
|
return cls(labels={k: _LabelNameMapping.from_api(v) for k, v in raw.items()})
|
||
|
|
|
||
|
|
|
||
|
|
class _AnnotationMapper:
|
||
|
|
@attrs.frozen
|
||
|
|
class _AttributeIdMapping:
|
||
|
|
id: int
|
||
|
|
value_validator: Callable[[str], bool]
|
||
|
|
|
||
|
|
@attrs.frozen
|
||
|
|
class _SublabelIdMapping:
|
||
|
|
id: int
|
||
|
|
attributes: Mapping[int, Optional[_AnnotationMapper._AttributeIdMapping]]
|
||
|
|
|
||
|
|
@attrs.frozen
|
||
|
|
class _LabelIdMapping(_SublabelIdMapping):
|
||
|
|
sublabels: Mapping[int, Optional[_AnnotationMapper._SublabelIdMapping]]
|
||
|
|
expected_num_elements: int
|
||
|
|
expected_type: str
|
||
|
|
|
||
|
|
_SpecIdMapping: TypeAlias = Mapping[int, Optional[_LabelIdMapping]]
|
||
|
|
|
||
|
|
_spec_id_mapping: _SpecIdMapping
|
||
|
|
|
||
|
|
def _get_expected_function_output_type(self, fun_label, ds_label):
|
||
|
|
fun_output_type = getattr(fun_label, "type", "any")
|
||
|
|
if fun_output_type == "any":
|
||
|
|
return ds_label.type
|
||
|
|
|
||
|
|
if self._conv_mask_to_poly and fun_output_type == "mask":
|
||
|
|
fun_output_type = "polygon"
|
||
|
|
|
||
|
|
if not self._are_label_types_compatible(fun_output_type, ds_label.type):
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"label {fun_label.name!r} has type {fun_output_type!r} in the function,"
|
||
|
|
f" but {ds_label.type!r} in the dataset"
|
||
|
|
)
|
||
|
|
return fun_output_type
|
||
|
|
|
||
|
|
def _build_attribute_id_mapping(
|
||
|
|
self, fun_attr: models.IAttribute, ds_attr: models.IAttribute, attr_desc: str
|
||
|
|
) -> _AttributeIdMapping:
|
||
|
|
# We could potentially be more lax with these checks. For example, we could permit
|
||
|
|
# fun_attr.values to be a subset of ds_attr.values. For simplicity though,
|
||
|
|
# we'll just use exact comparisons for now.
|
||
|
|
if ds_attr.input_type != fun_attr.input_type:
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"{attr_desc} has input type {fun_attr.input_type!r} in the function,"
|
||
|
|
f" but {ds_attr.input_type!r} in the dataset"
|
||
|
|
)
|
||
|
|
|
||
|
|
if ds_attr.input_type.value in {"text", "checkbox"}:
|
||
|
|
values_match = True
|
||
|
|
elif ds_attr.input_type.value in {"select", "radio"}:
|
||
|
|
values_match = sorted(ds_attr.values) == sorted(fun_attr.values)
|
||
|
|
else:
|
||
|
|
values_match = ds_attr.values == fun_attr.values
|
||
|
|
|
||
|
|
if not values_match:
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"{attr_desc} has values {fun_attr.values!r} in the function,"
|
||
|
|
f" but {ds_attr.values!r} in the dataset"
|
||
|
|
)
|
||
|
|
|
||
|
|
return self._AttributeIdMapping(
|
||
|
|
id=ds_attr.id,
|
||
|
|
value_validator=attribute_value_validator(fun_attr),
|
||
|
|
)
|
||
|
|
|
||
|
|
def _build_sublabel_id_mapping(
|
||
|
|
self,
|
||
|
|
fun_sl: models.ISublabel,
|
||
|
|
ds_sl: models.ISublabel,
|
||
|
|
sl_desc: str,
|
||
|
|
*,
|
||
|
|
sl_nm: _SublabelNameMapping,
|
||
|
|
allow_unmatched_labels: bool,
|
||
|
|
) -> _SublabelIdMapping:
|
||
|
|
ds_attrs_by_name = {ds_attr.name: ds_attr for ds_attr in ds_sl.attributes}
|
||
|
|
|
||
|
|
def attribute_mapping(
|
||
|
|
fun_attr: models.IAttribute,
|
||
|
|
) -> Optional[_AnnotationMapper._AttributeIdMapping]:
|
||
|
|
attr_desc = f"attribute {fun_attr.name!r} of {sl_desc}"
|
||
|
|
|
||
|
|
attr_nm = sl_nm.map_attribute(fun_attr.name)
|
||
|
|
if attr_nm is None:
|
||
|
|
return None
|
||
|
|
|
||
|
|
ds_attr = ds_attrs_by_name.get(attr_nm.name)
|
||
|
|
if not ds_attr:
|
||
|
|
if not allow_unmatched_labels:
|
||
|
|
raise BadFunctionError(f"{attr_desc} is not in dataset")
|
||
|
|
|
||
|
|
self._logger.info(
|
||
|
|
"%s is not in dataset; any annotations using it will be ignored", attr_desc
|
||
|
|
)
|
||
|
|
return None
|
||
|
|
|
||
|
|
return self._build_attribute_id_mapping(fun_attr, ds_attr, attr_desc)
|
||
|
|
|
||
|
|
return self._SublabelIdMapping(
|
||
|
|
ds_sl.id,
|
||
|
|
attributes={
|
||
|
|
attr.id: attribute_mapping(attr) for attr in getattr(fun_sl, "attributes", [])
|
||
|
|
},
|
||
|
|
)
|
||
|
|
|
||
|
|
def _build_label_id_mapping(
|
||
|
|
self,
|
||
|
|
fun_label: models.ILabel,
|
||
|
|
ds_label: models.ILabel,
|
||
|
|
label_desc: str,
|
||
|
|
*,
|
||
|
|
label_nm: _LabelNameMapping,
|
||
|
|
allow_unmatched_labels: bool,
|
||
|
|
) -> _LabelIdMapping:
|
||
|
|
base_mapping = self._build_sublabel_id_mapping(
|
||
|
|
fun_label,
|
||
|
|
ds_label,
|
||
|
|
label_desc,
|
||
|
|
sl_nm=label_nm,
|
||
|
|
allow_unmatched_labels=allow_unmatched_labels,
|
||
|
|
)
|
||
|
|
|
||
|
|
ds_sublabels_by_name = {ds_sl.name: ds_sl for ds_sl in ds_label.sublabels}
|
||
|
|
|
||
|
|
def sublabel_mapping(
|
||
|
|
fun_sl: models.ISublabel,
|
||
|
|
) -> Optional[_AnnotationMapper._SublabelIdMapping]:
|
||
|
|
sl_desc = f"sublabel {fun_sl.name!r} of {label_desc}"
|
||
|
|
|
||
|
|
sublabel_nm = label_nm.map_sublabel(fun_sl.name)
|
||
|
|
if sublabel_nm is None:
|
||
|
|
return None
|
||
|
|
|
||
|
|
ds_sl = ds_sublabels_by_name.get(sublabel_nm.name)
|
||
|
|
if not ds_sl:
|
||
|
|
if not allow_unmatched_labels:
|
||
|
|
raise BadFunctionError(f"{sl_desc} is not in dataset")
|
||
|
|
|
||
|
|
self._logger.info(
|
||
|
|
"%s is not in dataset; any annotations using it will be ignored", sl_desc
|
||
|
|
)
|
||
|
|
return None
|
||
|
|
|
||
|
|
return self._build_sublabel_id_mapping(
|
||
|
|
fun_sl,
|
||
|
|
ds_sl,
|
||
|
|
sl_desc,
|
||
|
|
sl_nm=sublabel_nm,
|
||
|
|
allow_unmatched_labels=allow_unmatched_labels,
|
||
|
|
)
|
||
|
|
|
||
|
|
return self._LabelIdMapping(
|
||
|
|
**attrs.asdict(base_mapping, recurse=False),
|
||
|
|
sublabels={
|
||
|
|
fun_sl.id: sublabel_mapping(fun_sl)
|
||
|
|
for fun_sl in getattr(fun_label, "sublabels", [])
|
||
|
|
},
|
||
|
|
expected_num_elements=len(ds_label.sublabels),
|
||
|
|
expected_type=self._get_expected_function_output_type(fun_label, ds_label),
|
||
|
|
)
|
||
|
|
|
||
|
|
def _build_spec_id_mapping(
|
||
|
|
self,
|
||
|
|
fun_labels: Sequence[models.ILabel],
|
||
|
|
ds_labels: Sequence[models.ILabel],
|
||
|
|
*,
|
||
|
|
spec_nm: _SpecNameMapping,
|
||
|
|
allow_unmatched_labels: bool,
|
||
|
|
) -> _SpecIdMapping:
|
||
|
|
ds_labels_by_name = {ds_label.name: ds_label for ds_label in ds_labels}
|
||
|
|
|
||
|
|
def label_id_mapping(
|
||
|
|
fun_label: models.ILabel,
|
||
|
|
) -> Optional[_AnnotationMapper._LabelIdMapping]:
|
||
|
|
label_desc = f"label {fun_label.name!r}"
|
||
|
|
|
||
|
|
label_nm = spec_nm.map_label(fun_label.name)
|
||
|
|
if label_nm is None:
|
||
|
|
return None
|
||
|
|
|
||
|
|
ds_label = ds_labels_by_name.get(label_nm.name)
|
||
|
|
if ds_label is None:
|
||
|
|
if not allow_unmatched_labels:
|
||
|
|
raise BadFunctionError(f"{label_desc} is not in dataset")
|
||
|
|
|
||
|
|
self._logger.info(
|
||
|
|
"%s is not in dataset; any annotations using it will be ignored", label_desc
|
||
|
|
)
|
||
|
|
return None
|
||
|
|
|
||
|
|
return self._build_label_id_mapping(
|
||
|
|
fun_label,
|
||
|
|
ds_label,
|
||
|
|
label_desc,
|
||
|
|
label_nm=label_nm,
|
||
|
|
allow_unmatched_labels=allow_unmatched_labels,
|
||
|
|
)
|
||
|
|
|
||
|
|
return {fun_label.id: label_id_mapping(fun_label) for fun_label in fun_labels}
|
||
|
|
|
||
|
|
def __init__(
|
||
|
|
self,
|
||
|
|
logger: logging.Logger,
|
||
|
|
fun_labels: Sequence[models.ILabel],
|
||
|
|
ds_labels: Sequence[models.ILabel],
|
||
|
|
*,
|
||
|
|
allow_unmatched_labels: bool,
|
||
|
|
conv_mask_to_poly: bool,
|
||
|
|
spec_nm: _SpecNameMapping = _SpecNameMapping(),
|
||
|
|
) -> None:
|
||
|
|
self._logger = logger
|
||
|
|
self._conv_mask_to_poly = conv_mask_to_poly
|
||
|
|
|
||
|
|
self._spec_id_mapping = self._build_spec_id_mapping(
|
||
|
|
fun_labels, ds_labels, spec_nm=spec_nm, allow_unmatched_labels=allow_unmatched_labels
|
||
|
|
)
|
||
|
|
|
||
|
|
def _remap_attribute(
|
||
|
|
self,
|
||
|
|
attribute: models.AttributeValRequest,
|
||
|
|
label_id_mapping: _SublabelIdMapping,
|
||
|
|
seen_attr_ids: set[int],
|
||
|
|
) -> bool:
|
||
|
|
try:
|
||
|
|
attr_id_mapping = label_id_mapping.attributes[attribute.spec_id]
|
||
|
|
except KeyError:
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output attribute with unknown ID ({attribute.spec_id})"
|
||
|
|
)
|
||
|
|
|
||
|
|
if not attr_id_mapping:
|
||
|
|
return False
|
||
|
|
|
||
|
|
if attr_id_mapping.id in seen_attr_ids:
|
||
|
|
raise BadFunctionError("function output shape with multiple attributes with same ID")
|
||
|
|
|
||
|
|
if not attr_id_mapping.value_validator(attribute.value):
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output attribute value ({attribute.value!r})"
|
||
|
|
f" that is unsuitable for its attribute ({attribute.spec_id})"
|
||
|
|
)
|
||
|
|
|
||
|
|
attribute.spec_id = attr_id_mapping.id
|
||
|
|
|
||
|
|
seen_attr_ids.add(attr_id_mapping.id)
|
||
|
|
|
||
|
|
return True
|
||
|
|
|
||
|
|
def _remap_attributes(
|
||
|
|
self,
|
||
|
|
annotation: Union[DetectionAnnotation, models.SubLabeledShapeRequest],
|
||
|
|
label_id_mapping: _SublabelIdMapping,
|
||
|
|
) -> None:
|
||
|
|
seen_attr_ids = set()
|
||
|
|
|
||
|
|
if hasattr(annotation, "attributes"):
|
||
|
|
annotation.attributes[:] = [
|
||
|
|
attribute
|
||
|
|
for attribute in annotation.attributes
|
||
|
|
if self._remap_attribute(attribute, label_id_mapping, seen_attr_ids)
|
||
|
|
]
|
||
|
|
|
||
|
|
def _remap_element(
|
||
|
|
self,
|
||
|
|
element: models.SubLabeledShapeRequest,
|
||
|
|
ds_frame: int,
|
||
|
|
label_id_mapping: _LabelIdMapping,
|
||
|
|
seen_sl_ids: set[int],
|
||
|
|
) -> bool:
|
||
|
|
if hasattr(element, "id"):
|
||
|
|
raise BadFunctionError("function output shape element with preset id")
|
||
|
|
|
||
|
|
if hasattr(element, "source"):
|
||
|
|
raise BadFunctionError("function output shape element with preset source")
|
||
|
|
element.source = "auto"
|
||
|
|
|
||
|
|
if element.frame != 0:
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output shape element with unexpected frame number ({element.frame})"
|
||
|
|
)
|
||
|
|
|
||
|
|
element.frame = ds_frame
|
||
|
|
|
||
|
|
if element.type.value != "points":
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output skeleton with element type other than 'points' ({element.type.value})"
|
||
|
|
)
|
||
|
|
|
||
|
|
try:
|
||
|
|
sl_id_mapping = label_id_mapping.sublabels[element.label_id]
|
||
|
|
except KeyError:
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output shape with unknown sublabel ID ({element.label_id})"
|
||
|
|
)
|
||
|
|
|
||
|
|
if not sl_id_mapping:
|
||
|
|
return False
|
||
|
|
|
||
|
|
if sl_id_mapping.id in seen_sl_ids:
|
||
|
|
raise BadFunctionError(
|
||
|
|
"function output skeleton with multiple elements with same sublabel"
|
||
|
|
)
|
||
|
|
|
||
|
|
element.label_id = sl_id_mapping.id
|
||
|
|
|
||
|
|
seen_sl_ids.add(sl_id_mapping.id)
|
||
|
|
|
||
|
|
self._remap_attributes(element, sl_id_mapping)
|
||
|
|
|
||
|
|
return True
|
||
|
|
|
||
|
|
def _remap_elements(
|
||
|
|
self, shape: models.LabeledShapeRequest, ds_frame: int, label_id_mapping: _LabelIdMapping
|
||
|
|
) -> None:
|
||
|
|
if shape.type.value == "skeleton":
|
||
|
|
seen_sl_ids = set()
|
||
|
|
|
||
|
|
shape.elements[:] = [
|
||
|
|
element
|
||
|
|
for element in shape.elements
|
||
|
|
if self._remap_element(element, ds_frame, label_id_mapping, seen_sl_ids)
|
||
|
|
]
|
||
|
|
|
||
|
|
if len(shape.elements) != label_id_mapping.expected_num_elements:
|
||
|
|
# There could only be fewer elements than expected,
|
||
|
|
# because the reverse would imply that there are more distinct sublabel IDs
|
||
|
|
# than are actually defined in the dataset.
|
||
|
|
assert len(shape.elements) < label_id_mapping.expected_num_elements
|
||
|
|
|
||
|
|
raise BadFunctionError(
|
||
|
|
"function output skeleton with fewer elements than expected"
|
||
|
|
f" ({len(shape.elements)} vs {label_id_mapping.expected_num_elements})"
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
if getattr(shape, "elements", None):
|
||
|
|
raise BadFunctionError("function output non-skeleton shape with elements")
|
||
|
|
|
||
|
|
def _remap_annotation(
|
||
|
|
self, annotation: DetectionAnnotation, ds_frame: int, object_type: str
|
||
|
|
) -> bool:
|
||
|
|
if hasattr(annotation, "id"):
|
||
|
|
raise BadFunctionError(f"function output {object_type} with preset id")
|
||
|
|
|
||
|
|
if hasattr(annotation, "source"):
|
||
|
|
raise BadFunctionError(f"function output {object_type} with preset source")
|
||
|
|
annotation.source = "auto"
|
||
|
|
|
||
|
|
if annotation.frame != 0:
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output {object_type} with unexpected frame number ({annotation.frame})"
|
||
|
|
)
|
||
|
|
|
||
|
|
annotation.frame = ds_frame
|
||
|
|
|
||
|
|
try:
|
||
|
|
label_id_mapping = self._spec_id_mapping[annotation.label_id]
|
||
|
|
except KeyError:
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output {object_type} with unknown label ID ({annotation.label_id})"
|
||
|
|
)
|
||
|
|
|
||
|
|
if not label_id_mapping:
|
||
|
|
return False
|
||
|
|
|
||
|
|
annotation.label_id = label_id_mapping.id
|
||
|
|
|
||
|
|
self._remap_attributes(annotation, label_id_mapping)
|
||
|
|
|
||
|
|
if object_type == "shape":
|
||
|
|
shape = cast(models.LabeledShapeRequest, annotation)
|
||
|
|
|
||
|
|
if not self._are_label_types_compatible(
|
||
|
|
shape.type.value, label_id_mapping.expected_type
|
||
|
|
):
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output shape of type {shape.type.value!r}"
|
||
|
|
f" (expected {label_id_mapping.expected_type!r})"
|
||
|
|
)
|
||
|
|
|
||
|
|
if annotation.type.value == "mask" and self._conv_mask_to_poly:
|
||
|
|
raise BadFunctionError("function output mask shape despite conv_mask_to_poly=True")
|
||
|
|
|
||
|
|
self._remap_elements(shape, ds_frame, label_id_mapping)
|
||
|
|
else:
|
||
|
|
if not self._are_label_types_compatible("tag", label_id_mapping.expected_type):
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output tag"
|
||
|
|
f" (expected shape of type {label_id_mapping.expected_type!r})"
|
||
|
|
)
|
||
|
|
|
||
|
|
return True
|
||
|
|
|
||
|
|
def validate_and_remap(
|
||
|
|
self,
|
||
|
|
annotations: Sequence[DetectionAnnotation],
|
||
|
|
ds_frame: int,
|
||
|
|
) -> tuple[list[models.LabeledImageRequest], list[models.LabeledShapeRequest]]:
|
||
|
|
tags = []
|
||
|
|
shapes = []
|
||
|
|
|
||
|
|
for annotation in annotations:
|
||
|
|
if isinstance(annotation, models.LabeledImageRequest):
|
||
|
|
if self._remap_annotation(annotation, ds_frame, "tag"):
|
||
|
|
tags.append(annotation)
|
||
|
|
elif isinstance(annotation, models.LabeledShapeRequest):
|
||
|
|
if self._remap_annotation(annotation, ds_frame, "shape"):
|
||
|
|
shapes.append(annotation)
|
||
|
|
else:
|
||
|
|
raise BadFunctionError(
|
||
|
|
f"function output an object of type {type(annotation).__name__!r} "
|
||
|
|
f"(expected {models.LabeledImageRequest.__name__!r} "
|
||
|
|
f"or {models.LabeledShapeRequest.__name__!r})"
|
||
|
|
)
|
||
|
|
|
||
|
|
return tags, shapes
|
||
|
|
|
||
|
|
@staticmethod
|
||
|
|
def _are_label_types_compatible(source_type: str, destination_type: str) -> bool:
|
||
|
|
assert source_type != "any"
|
||
|
|
return destination_type == "any" or destination_type == source_type
|
||
|
|
|
||
|
|
|
||
|
|
@attrs.frozen(kw_only=True)
|
||
|
|
class _DetectionFunctionContextImpl(DetectionFunctionContext):
|
||
|
|
frame_name: str
|
||
|
|
conf_threshold: Optional[float] = None
|
||
|
|
conv_mask_to_poly: bool = False
|
||
|
|
|
||
|
|
|
||
|
|
def annotate_task(
|
||
|
|
client: Client,
|
||
|
|
task_id: int,
|
||
|
|
function: DetectionFunction,
|
||
|
|
*,
|
||
|
|
pbar: Optional[ProgressReporter] = None,
|
||
|
|
clear_existing: bool = False,
|
||
|
|
allow_unmatched_labels: bool = False,
|
||
|
|
conf_threshold: Optional[float] = None,
|
||
|
|
conv_mask_to_poly: bool = False,
|
||
|
|
) -> None:
|
||
|
|
"""
|
||
|
|
Downloads data for the task with the given ID, applies the given function to it
|
||
|
|
and uploads the resulting annotations back to the task.
|
||
|
|
|
||
|
|
Only tasks with 2D image (not video) data are supported at the moment.
|
||
|
|
|
||
|
|
client is used to make all requests to the CVAT server.
|
||
|
|
|
||
|
|
Currently, the only type of auto-annotation function supported is the detection function.
|
||
|
|
A function of this type is applied independently to each image in the task.
|
||
|
|
The resulting annotations are then combined and modified as follows:
|
||
|
|
|
||
|
|
* The label IDs are replaced with the IDs of the corresponding labels in the task.
|
||
|
|
* The frame numbers are replaced with the frame number of the image.
|
||
|
|
* The sources are set to "auto".
|
||
|
|
|
||
|
|
See the documentation for DetectionFunction for more details.
|
||
|
|
|
||
|
|
If the function is found to violate any constraints set in its interface, BadFunctionError
|
||
|
|
is raised.
|
||
|
|
|
||
|
|
pbar, if supplied, is used to report progress information.
|
||
|
|
|
||
|
|
If clear_existing is true, any annotations already existing in the task are removed.
|
||
|
|
Otherwise, they are kept, and the new annotations are added to them.
|
||
|
|
|
||
|
|
The allow_unmatched_labels parameter controls the behavior in the case when a detection
|
||
|
|
function declares a label/sublabel/attribute in its spec
|
||
|
|
that has no corresponding label/sublabel/attribute in the task.
|
||
|
|
If it's set to True, any annotations/keypoints/attribute values
|
||
|
|
returned by the function that refer to such labels/sublabels/attributes are dropped.
|
||
|
|
If it's set to False, BadFunctionError is raised.
|
||
|
|
|
||
|
|
The conf_threshold parameter must be None or a number between 0 and 1. It will be passed
|
||
|
|
to the AA function as the conf_threshold attribute of the context object.
|
||
|
|
|
||
|
|
The conv_mask_to_poly parameter will be passed to the AA function as the conv_mask_to_poly
|
||
|
|
attribute of the context object. If it's true, and the AA function returns any mask shapes,
|
||
|
|
BadFunctionError will be raised.
|
||
|
|
"""
|
||
|
|
|
||
|
|
if pbar is None:
|
||
|
|
pbar = NullProgressReporter()
|
||
|
|
|
||
|
|
if conf_threshold is not None and not 0 <= conf_threshold <= 1:
|
||
|
|
raise ValueError("conf_threshold must be None or a number between 0 and 1")
|
||
|
|
|
||
|
|
dataset = TaskDataset(client, task_id, load_annotations=False)
|
||
|
|
|
||
|
|
assert isinstance(function.spec, DetectionFunctionSpec)
|
||
|
|
|
||
|
|
mapper = _AnnotationMapper(
|
||
|
|
client.logger,
|
||
|
|
function.spec.labels,
|
||
|
|
dataset.labels,
|
||
|
|
allow_unmatched_labels=allow_unmatched_labels,
|
||
|
|
conv_mask_to_poly=conv_mask_to_poly,
|
||
|
|
)
|
||
|
|
|
||
|
|
tags = []
|
||
|
|
shapes = []
|
||
|
|
|
||
|
|
with pbar.task(total=len(dataset.samples), unit="samples"):
|
||
|
|
for sample in pbar.iter(dataset.samples):
|
||
|
|
frame_annotations = function.detect(
|
||
|
|
# https://github.com/pylint-dev/pylint/issues/9013
|
||
|
|
# pylint: disable-next=abstract-class-instantiated
|
||
|
|
_DetectionFunctionContextImpl(
|
||
|
|
frame_name=sample.frame_name,
|
||
|
|
conf_threshold=conf_threshold,
|
||
|
|
conv_mask_to_poly=conv_mask_to_poly,
|
||
|
|
),
|
||
|
|
sample.media.load_image(),
|
||
|
|
)
|
||
|
|
frame_tags, frame_shapes = mapper.validate_and_remap(
|
||
|
|
frame_annotations, sample.frame_index
|
||
|
|
)
|
||
|
|
tags.extend(frame_tags)
|
||
|
|
shapes.extend(frame_shapes)
|
||
|
|
|
||
|
|
client.logger.info("Uploading annotations to task %d...", task_id)
|
||
|
|
|
||
|
|
if clear_existing:
|
||
|
|
client.tasks.api.update_annotations(
|
||
|
|
task_id, labeled_data_request=models.LabeledDataRequest(tags=tags, shapes=shapes)
|
||
|
|
)
|
||
|
|
else:
|
||
|
|
client.tasks.api.partial_update_annotations(
|
||
|
|
"create",
|
||
|
|
task_id,
|
||
|
|
patched_labeled_data_request=models.PatchedLabeledDataRequest(tags=tags, shapes=shapes),
|
||
|
|
)
|
||
|
|
|
||
|
|
client.logger.info("Upload complete")
|