31 lines
853 B
Python
31 lines
853 B
Python
# Copyright (C) CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import PIL.Image
|
|
|
|
import cvat_sdk.auto_annotation as cvataa
|
|
import cvat_sdk.models as models
|
|
|
|
from ._torchvision import TorchvisionFunction
|
|
|
|
|
|
class _TorchvisionDetectionFunction(TorchvisionFunction):
    """Detection function that reports the wrapped model's boxes as rectangles."""

    _label_type = "rectangle"

    def detect(
        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
    ) -> list[models.LabeledShapeRequest]:
        """Run the model on a single image and return its detections as rectangles.

        Detections whose score is below ``context.conf_threshold`` are dropped;
        an unset (falsy) threshold is treated as 0.
        """
        min_score = context.conf_threshold or 0

        # The model is called with a one-element batch holding the transformed image.
        batch = [self._transforms(image)]
        predictions = self._model(batch)

        shapes = []
        for prediction in predictions:
            candidates = zip(prediction["boxes"], prediction["labels"], prediction["scores"])
            for box, label, score in candidates:
                if score >= min_score:
                    shapes.append(cvataa.rectangle(label.item(), [coord.item() for coord in box]))
        return shapes
|
|
|
|
|
|
# Module-level alias: binds the name `create` to the detection function class.
# NOTE(review): presumably the entry point a function loader looks up by name — confirm.
create = _TorchvisionDetectionFunction
|