cvat/cvat-sdk/cvat_sdk/auto_annotation/functions/_torchvision.py

31 lines
957 B
Python

# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from functools import cached_property
import torchvision.models
import cvat_sdk.auto_annotation as cvataa
class TorchvisionFunction:
    """Common base for auto-annotation functions backed by a torchvision model.

    Construction loads a pretrained torchvision model, its weights and the
    weights' preprocessing transforms, and puts the model into evaluation
    mode. The ``spec`` property derives the label list from the weights'
    category metadata.
    """

    # ``type`` given to every label produced by ``spec``; subclasses may override.
    _label_type = "any"

    def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
        """Load the named torchvision model with the selected pretrained weights.

        Args:
            model_name: model name as understood by
                ``torchvision.models.get_model_weights`` and
                ``torchvision.models.get_model``.
            weights_name: member name looked up in the model's weights enum
                (``"DEFAULT"`` unless specified).
            **kwargs: extra arguments forwarded as-is to
                ``torchvision.models.get_model``.
        """
        # Look up the weights enum for this model, then select the requested member by name.
        self._weights = torchvision.models.get_model_weights(model_name)[weights_name]
        # Preprocessing transforms that match the selected weights.
        self._transforms = self._weights.transforms()

        model = torchvision.models.get_model(model_name, weights=self._weights, **kwargs)
        model.eval()
        self._model = model

    @cached_property
    def spec(self) -> cvataa.DetectionFunctionSpec:
        """Describe the labels this function can emit.

        One label is created for each entry of the weights' ``"categories"``
        metadata. The entry's position in that list is used as the label ID.
        Entries named ``"N/A"`` are omitted, but they still occupy a position,
        so the remaining labels keep their original indices as IDs.
        """
        categories = self._weights.meta["categories"]
        labels = [
            cvataa.label_spec(name, label_id, type=self._label_type)
            for label_id, name in enumerate(categories)
            if name != "N/A"
        ]
        return cvataa.DetectionFunctionSpec(labels=labels)