cvat/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_detection.py

31 lines
853 B
Python
Raw Normal View History

2025-09-16 01:19:40 +00:00
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import PIL.Image
import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
from ._torchvision import TorchvisionFunction
class _TorchvisionDetectionFunction(TorchvisionFunction):
    """Auto-annotation function that turns detection model output into rectangles.

    The model and the preprocessing transforms are read from ``self._model``
    and ``self._transforms``. Neither is assigned in this class, so they are
    presumably provided by ``TorchvisionFunction`` (confirm against the base
    class).
    """

    # Shape type associated with this function's labels.
    # NOTE(review): presumably consumed by TorchvisionFunction - confirm.
    _label_type = "rectangle"

    def detect(
        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
    ) -> list[models.LabeledShapeRequest]:
        """Run the model on one image and return the detections as rectangles.

        Args:
            context: Detection context; only ``conf_threshold`` is read here.
                A falsy threshold (e.g. ``None``) is treated as ``0``.
            image: The image to annotate. It is passed through
                ``self._transforms`` before being given to the model.

        Returns:
            One rectangle shape per detection whose score is greater than or
            equal to the threshold, in the order the model produced them.
        """
        threshold = context.conf_threshold
        if not threshold:
            # No usable threshold was supplied: fall back to 0.
            threshold = 0

        # The model is called with a one-element batch containing the
        # transformed image.
        predictions = self._model([self._transforms(image)])

        shapes = []
        for prediction in predictions:
            candidates = zip(prediction["boxes"], prediction["labels"], prediction["scores"])
            for box, label, score in candidates:
                # Keep the ">=" form (rather than skipping on "<") so that a
                # NaN score is rejected instead of slipping through.
                if score >= threshold:
                    shapes.append(
                        cvataa.rectangle(label.item(), [coord.item() for coord in box])
                    )
        return shapes
# Module-level alias that exposes the detection function class under the name `create`.
# NOTE(review): presumably the factory entry point that users of this module
# call to instantiate the function - confirm against the loader.
create = _TorchvisionDetectionFunction