# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import abc
import math
from collections.abc import Iterable
from typing import ClassVar, Optional

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
import PIL.Image
from ultralytics import YOLO
from ultralytics.engine.results import Results


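# Base class for all YOLO-backed auto-annotation functions: holds the model and
# target device, and builds the CVAT label spec from the model's class names.
# Subclasses decide what kind of label each class name becomes.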
class YoloFunction(abc.ABC):
    def __init__(self, model: YOLO, device: str) -> None:
        self._model = model
        self._device = device

        self.spec = cvataa.DetectionFunctionSpec(
            labels=[self._label_spec(name, id) for id, name in self._model.names.items()],
        )

    @abc.abstractmethod
    def _label_spec(self, name: str, id_: int) -> models.PatchedLabelRequest: ...


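# Mixin for tasks whose labels are all plain (non-skeleton) labels of a single type.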
class YoloFunctionWithSimpleLabel(YoloFunction):
    LABEL_TYPE: ClassVar[str]

    def _label_spec(self, name: str, id_: int) -> models.PatchedLabelRequest:
        return cvataa.label_spec(name, id_, type=self.LABEL_TYPE)


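# `classify` task: emits at most one tag per image, for the model's top-1 class.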
class YoloClassificationFunction(YoloFunctionWithSimpleLabel):
    LABEL_TYPE = "tag"

    def detect(
        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
    ) -> list[cvataa.DetectionAnnotation]:
        # Unlike the other models, the `predict` method of the classification models does not
        # take a confidence threshold. Therefore, we apply one manually.
        # We also use 0 as the default threshold on the assumption that by default the user
        # wants to get exactly one tag per image.
        conf_threshold = context.conf_threshold or 0.0

        return [
            cvataa.tag(results.probs.top1)
            for results in self._model.predict(source=image, device=self._device, verbose=False)
            if results.probs.top1conf >= conf_threshold
        ]


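# Base class for tasks that output shapes. Passes the user's confidence threshold
# through to the model (falling back to the model's built-in default when none is
# given) and defers the results-to-annotations conversion to subclasses.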
class YoloFunctionWithShapes(YoloFunction):
    def detect(
        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
    ) -> list[cvataa.DetectionAnnotation]:
        kwargs = {}
        if context.conf_threshold is not None:
            kwargs["conf"] = context.conf_threshold

        return [
            annotation
            for results in self._model.predict(
                source=image, device=self._device, verbose=False, **kwargs
            )
            if len(results) > 0
            for annotation in self._annotations_from_results(results)
        ]

    @abc.abstractmethod
    def _annotations_from_results(
        self, results: Results
    ) -> Iterable[cvataa.DetectionAnnotation]: ...


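# `detect` task: one axis-aligned bounding box per detected object.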
class YoloDetectionFunction(YoloFunctionWithSimpleLabel, YoloFunctionWithShapes):
    LABEL_TYPE = "rectangle"

    def _annotations_from_results(self, results: Results) -> Iterable[cvataa.DetectionAnnotation]:
        return (
            cvataa.rectangle(int(label.item()), points.tolist())
            for label, points in zip(results.boxes.cls, results.boxes.xyxy)
        )


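# `obb` task: oriented bounding boxes. YOLO reports each box as xywhr
# (center x/y, width, height, rotation in radians), while CVAT expects the
# corner points of the unrotated rectangle plus a rotation in degrees, so
# each box is converted accordingly.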
class YoloOrientedDetectionFunction(YoloFunctionWithSimpleLabel, YoloFunctionWithShapes):
    LABEL_TYPE = "rectangle"

    def _annotations_from_results(self, results: Results) -> Iterable[cvataa.DetectionAnnotation]:
        return (
            cvataa.rectangle(
                int(label.item()),
                [x - 0.5 * w, y - 0.5 * h, x + 0.5 * w, y + 0.5 * h],
                rotation=math.degrees(r),
            )
            for label, xywhr in zip(results.obb.cls, results.obb.xywhr)
            for x, y, w, h, r in [xywhr.tolist()]
        )


DEFAULT_KEYPOINT_NAMES = [
    # The keypoint names are not recorded in the model file, so we have to ask the user to
    # supply them separately (see the keypoint_names_path option).
    # But to make using the default models easier, we hardcode the usual COCO keypoint names.
    "nose",
    "left_eye",
    "right_eye",
    "left_ear",
    "right_ear",
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_wrist",
    "right_wrist",
    "left_hip",
    "right_hip",
    "left_knee",
    "right_knee",
    "left_ankle",
    "right_ankle",
]


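# `pose` task: one skeleton per detected object. Keypoints whose confidence is
# below 0.5 are marked as "outside" (hidden) in CVAT.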
class YoloPoseEstimationFunction(YoloFunctionWithShapes):
    def __init__(
        self, model: YOLO, device: str, *, keypoint_names_path: Optional[str] = None
    ) -> None:
        if keypoint_names_path is None:
            self._keypoint_names = DEFAULT_KEYPOINT_NAMES
        else:
            self._keypoint_names = self._load_names(keypoint_names_path)

        super().__init__(model, device)

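    # Reads one keypoint name per line, ignoring blank lines.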
    def _load_names(self, path: str) -> list[str]:
        with open(path, "r") as f:
            return [
                stripped_line
                for line in f
                for stripped_line in [line.strip()]
                if stripped_line
            ]

    def _label_spec(self, name: str, id_: int) -> models.PatchedLabelRequest:
        return cvataa.skeleton_label_spec(
            name,
            id_,
            [
                cvataa.keypoint_spec(kp_name, kp_id)
                for kp_id, kp_name in enumerate(self._keypoint_names)
            ],
        )

    def _annotations_from_results(self, results: Results) -> Iterable[cvataa.DetectionAnnotation]:
        return (
            cvataa.skeleton(
                int(label.item()),
                [
                    cvataa.keypoint(kp_index, kp.tolist(), outside=kp_conf.item() < 0.5)
                    for kp_index, (kp, kp_conf) in enumerate(zip(kps, kp_confs))
                ],
            )
            for label, kps, kp_confs in zip(
                results.boxes.cls, results.keypoints.xy, results.keypoints.conf
            )
        )


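# `segment` task: one polygon per instance, taken from the mask contours
# (results.masks.xy, in pixel coordinates) and flattened into CVAT's
# [x0, y0, x1, y1, ...] point list.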
class YoloSegmentationFunction(YoloFunctionWithSimpleLabel, YoloFunctionWithShapes):
    LABEL_TYPE = "polygon"

    def _annotations_from_results(self, results: Results) -> Iterable[cvataa.DetectionAnnotation]:
        return (
            cvataa.polygon(int(label.item()), [c for p in poly_points.tolist() for c in p])
            for label, poly_points in zip(results.boxes.cls, results.masks.xy)
        )


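# Maps an Ultralytics task name (as reported by `model.task`) to the matching
# function class.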
FUNCTION_CLASS_BY_TASK = {
    "classify": YoloClassificationFunction,
    "detect": YoloDetectionFunction,
    "pose": YoloPoseEstimationFunction,
    "obb": YoloOrientedDetectionFunction,
    "segment": YoloSegmentationFunction,
}


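# Entry point: loads the given model, then instantiates the function class that
# matches its task type. Extra keyword arguments (such as `device` or
# `keypoint_names_path`) are forwarded to that class's constructor.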
def create(model: str, **kwargs) -> cvataa.DetectionFunction:
    model = YOLO(model=model, verbose=False)
    return FUNCTION_CLASS_BY_TASK[model.task](model, **kwargs)
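

# Example usage (an illustrative sketch, not part of this module; the host,
# credentials, weights file, and task ID below are placeholders, and the
# `conf_threshold` argument assumes a recent cvat_sdk release):
#
#     import cvat_sdk
#     import cvat_sdk.auto_annotation as cvataa
#
#     with cvat_sdk.make_client("https://app.cvat.ai", credentials=("user", "pass")) as client:
#         function = create("yolo11n.pt", device="cpu")
#         cvataa.annotate_task(client, 123, function, conf_threshold=0.5)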