# cvat/ai-models/detector/yolo/func.py

# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
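
"""Auto-annotation functions wrapping Ultralytics YOLO models for the CVAT SDK.

One adapter class is defined per YOLO task type; ``create`` loads a checkpoint
and picks the matching class from FUNCTION_CLASS_BY_TASK based on the model's
task.
"""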
import abc
import math
from collections.abc import Iterable
from typing import ClassVar, Optional

import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
import PIL.Image
from ultralytics import YOLO
from ultralytics.engine.results import Results


class YoloFunction(abc.ABC):
    """Shared base class: builds the CVAT label spec from the model's class names."""

    def __init__(self, model: YOLO, device: str) -> None:
        self._model = model
        self._device = device
        self.spec = cvataa.DetectionFunctionSpec(
            labels=[self._label_spec(name, id) for id, name in self._model.names.items()],
        )

    @abc.abstractmethod
    def _label_spec(self, name: str, id_: int) -> models.PatchedLabelRequest: ...


class YoloFunctionWithSimpleLabel(YoloFunction):
    LABEL_TYPE: ClassVar[str]

    def _label_spec(self, name: str, id_: int) -> models.PatchedLabelRequest:
        return cvataa.label_spec(name, id_, type=self.LABEL_TYPE)


class YoloClassificationFunction(YoloFunctionWithSimpleLabel):
    LABEL_TYPE = "tag"

    def detect(
        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
    ) -> list[cvataa.DetectionAnnotation]:
        # Unlike the other models, the `predict` method of the classification models does not
        # take a confidence threshold. Therefore, we apply one manually.
        # We also use 0 as the default threshold on the assumption that by default the user
        # wants to get exactly one tag per image.
        conf_threshold = context.conf_threshold or 0.0

        return [
            cvataa.tag(results.probs.top1)
            for results in self._model.predict(source=image, device=self._device, verbose=False)
            if results.probs.top1conf >= conf_threshold
        ]


class YoloFunctionWithShapes(YoloFunction):
    def detect(
        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
    ) -> list[cvataa.DetectionAnnotation]:
        kwargs = {}
        if context.conf_threshold is not None:
            kwargs["conf"] = context.conf_threshold

        return [
            annotation
            for results in self._model.predict(
                source=image, device=self._device, verbose=False, **kwargs
            )
            if len(results) > 0
            for annotation in self._annotations_from_results(results)
        ]

    @abc.abstractmethod
    def _annotations_from_results(
        self, results: Results
    ) -> Iterable[cvataa.DetectionAnnotation]: ...


class YoloDetectionFunction(YoloFunctionWithSimpleLabel, YoloFunctionWithShapes):
    LABEL_TYPE = "rectangle"

    def _annotations_from_results(self, results: Results) -> Iterable[cvataa.DetectionAnnotation]:
        return (
            cvataa.rectangle(int(label.item()), points.tolist())
            for label, points in zip(results.boxes.cls, results.boxes.xyxy)
        )


class YoloOrientedDetectionFunction(YoloFunctionWithSimpleLabel, YoloFunctionWithShapes):
    LABEL_TYPE = "rectangle"

    def _annotations_from_results(self, results: Results) -> Iterable[cvataa.DetectionAnnotation]:
        # Convert each box from YOLO's (center x, center y, width, height, angle in radians)
        # format to the unrotated corner points plus a rotation in degrees that CVAT expects.
        return (
            cvataa.rectangle(
                int(label.item()),
                [x - 0.5 * w, y - 0.5 * h, x + 0.5 * w, y + 0.5 * h],
                rotation=math.degrees(r),
            )
            for label, xywhr in zip(results.obb.cls, results.obb.xywhr)
            for x, y, w, h, r in [xywhr.tolist()]
        )


DEFAULT_KEYPOINT_NAMES = [
    # The keypoint names are not recorded in the model file, so we have to ask the user to
    # supply them separately (see the keypoint_names_path option).
    # But to make using the default models easier, we hardcode the usual COCO keypoint names.
    "nose",
    "left_eye",
    "right_eye",
    "left_ear",
    "right_ear",
    "left_shoulder",
    "right_shoulder",
    "left_elbow",
    "right_elbow",
    "left_wrist",
    "right_wrist",
    "left_hip",
    "right_hip",
    "left_knee",
    "right_knee",
    "left_ankle",
    "right_ankle",
]
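
# A keypoint_names_path file is plain text with one keypoint name per line; blank
# lines are skipped (see _load_names below). A minimal sketch, with made-up names:
#
#   head
#   neck
#   tail_base
#   tail_tip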


class YoloPoseEstimationFunction(YoloFunctionWithShapes):
    def __init__(
        self, model: YOLO, device: str, *, keypoint_names_path: Optional[str] = None
    ) -> None:
        if keypoint_names_path is None:
            self._keypoint_names = DEFAULT_KEYPOINT_NAMES
        else:
            self._keypoint_names = self._load_names(keypoint_names_path)

        super().__init__(model, device)

    def _load_names(self, path: str) -> list[str]:
        with open(path, "r") as f:
            return [
                stripped_line
                for line in f.readlines()
                for stripped_line in [line.strip()]
                if stripped_line
            ]

    def _label_spec(self, name: str, id_: int) -> models.PatchedLabelRequest:
        return cvataa.skeleton_label_spec(
            name,
            id_,
            [
                cvataa.keypoint_spec(kp_name, kp_id)
                for kp_id, kp_name in enumerate(self._keypoint_names)
            ],
        )

    def _annotations_from_results(self, results: Results) -> Iterable[cvataa.DetectionAnnotation]:
        # Keypoints with confidence below 0.5 are marked as "outside" (not visible).
        return (
            cvataa.skeleton(
                int(label.item()),
                [
                    cvataa.keypoint(kp_index, kp.tolist(), outside=kp_conf.item() < 0.5)
                    for kp_index, (kp, kp_conf) in enumerate(zip(kps, kp_confs))
                ],
            )
            for label, kps, kp_confs in zip(
                results.boxes.cls, results.keypoints.xy, results.keypoints.conf
            )
        )


class YoloSegmentationFunction(YoloFunctionWithSimpleLabel, YoloFunctionWithShapes):
    LABEL_TYPE = "polygon"

    def _annotations_from_results(self, results: Results) -> Iterable[cvataa.DetectionAnnotation]:
        # Flatten each mask contour from [[x1, y1], [x2, y2], ...] to [x1, y1, x2, y2, ...].
        return (
            cvataa.polygon(int(label.item()), [c for p in poly_points.tolist() for c in p])
            for label, poly_points in zip(results.boxes.cls, results.masks.xy)
        )


FUNCTION_CLASS_BY_TASK = {
    "classify": YoloClassificationFunction,
    "detect": YoloDetectionFunction,
    "pose": YoloPoseEstimationFunction,
    "obb": YoloOrientedDetectionFunction,
    "segment": YoloSegmentationFunction,
}


def create(model: str, **kwargs) -> cvataa.DetectionFunction:
    """Load the given YOLO model and wrap it in the class matching its task type."""
    model = YOLO(model=model, verbose=False)
    return FUNCTION_CLASS_BY_TASK[model.task](model, **kwargs)
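

# Example usage (a minimal sketch, not part of this module): the host, credentials,
# task ID, and checkpoint name below are illustrative assumptions.
#
#   import cvat_sdk.auto_annotation as cvataa
#   from cvat_sdk import make_client
#
#   func = create("yolo11n.pt", device="cpu")
#   with make_client("app.cvat.ai", credentials=("user", "password")) as client:
#       cvataa.annotate_task(client, 42, func)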