cvat/cvat-sdk/cvat_sdk/auto_annotation/functions/torchvision_classification.py

32 lines
787 B
Python
Raw Permalink Normal View History

2025-09-16 01:19:40 +00:00
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
import PIL.Image
import cvat_sdk.auto_annotation as cvataa
import cvat_sdk.models as models
from ._torchvision import TorchvisionFunction
class _TorchvisionClassificationFunction(TorchvisionFunction):
    """Auto-annotation function that emits at most one tag per image.

    The tag corresponds to the highest-scoring output index of ``self._model``.
    ``self._model`` and ``self._transforms`` are not defined in this class;
    presumably they are set up by ``TorchvisionFunction`` -- confirm there.
    """

    # NOTE(review): presumably read by the base class to declare what kind of
    # labels this function produces -- confirm in ``_torchvision``.
    _label_type = "tag"

    def detect(
        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
    ) -> list[models.LabeledImageRequest]:
        """Classify *image* and return a tag for its best-scoring label.

        Args:
            context: Supplies ``conf_threshold``; a falsy value (e.g. ``None``)
                is treated as a threshold of 0.
            image: The image to classify.

        Returns:
            A one-element list holding ``cvataa.tag(<best index>)`` when the
            best softmax score is at least the threshold, otherwise an empty
            list.
        """
        min_confidence = context.conf_threshold or 0

        # unsqueeze(0) adds a leading dimension before the model call
        # (presumably the batch dimension -- confirm against the model).
        model_input = self._transforms(image).unsqueeze(0)
        first_output = self._model(model_input)[0]

        # Normalize the first output along dim 0 and pick its largest entry.
        probabilities = first_output.softmax(0)
        best_index = probabilities.argmax().item()

        return (
            [cvataa.tag(best_index)]
            if probabilities[best_index] >= min_confidence
            else []
        )
# Module-level alias: calling ``create(...)`` instantiates
# ``_TorchvisionClassificationFunction`` with the given arguments.
# NOTE(review): presumably the name the auto-annotation loader looks up in
# this module -- confirm against the loader.
create = _TorchvisionClassificationFunction