32 lines
787 B
Python
32 lines
787 B
Python
# Copyright (C) CVAT.ai Corporation
|
|
#
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import PIL.Image
|
|
|
|
import cvat_sdk.auto_annotation as cvataa
|
|
import cvat_sdk.models as models
|
|
|
|
from ._torchvision import TorchvisionFunction
|
|
|
|
|
|
class _TorchvisionClassificationFunction(TorchvisionFunction):
    """Auto-annotation function that emits at most one tag per image.

    The model and preprocessing transforms are read from ``self._model`` and
    ``self._transforms``; neither is set in this class (presumably they are
    provided by ``TorchvisionFunction`` -- confirm against the base class).
    """

    # Label type declared for this function. It is not read anywhere in this
    # class; presumably the base class consumes it -- confirm.
    _label_type = "tag"

    def detect(
        self, context: cvataa.DetectionFunctionContext, image: PIL.Image.Image
    ) -> list[models.LabeledImageRequest]:
        """Classify ``image`` and return a tag for the top-scoring label.

        Args:
            context: Detection context; only ``conf_threshold`` is read here.
            image: The image to classify.

        Returns:
            A one-element list holding ``cvataa.tag(label)`` for the label
            with the highest softmax score, if that score is at least the
            confidence threshold; otherwise an empty list.
        """
        # A falsy threshold (e.g. None) becomes 0. Softmax scores are never
        # negative, so a threshold of 0 accepts any (non-NaN) top score.
        min_confidence = context.conf_threshold or 0

        # unsqueeze(0) adds a leading dimension before the model call
        # (presumably the batch dimension the model expects -- confirm).
        batch = self._transforms(image).unsqueeze(0)
        output = self._model(batch)

        # Only the first element of the model output is used; it is turned
        # into scores with a softmax over its dimension 0.
        probabilities = output[0].softmax(0)
        best_label = probabilities.argmax().item()

        return (
            [cvataa.tag(best_label)]
            if probabilities[best_label] >= min_confidence
            else []
        )
|
|
|
|
|
|
# Module-level alias: calling `create(...)` instantiates
# _TorchvisionClassificationFunction. Presumably this is the factory name the
# auto-annotation loader looks up in this module -- confirm against the caller.
create = _TorchvisionClassificationFunction
|