31 lines
957 B
Python
31 lines
957 B
Python
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
|
|
|
|
from functools import cached_property
|
|
|
|
import torchvision.models
|
|
|
|
import cvat_sdk.auto_annotation as cvataa
|
|
|
|
|
|
class TorchvisionFunction:
    """Wrap a torchvision model and expose its categories as a CVAT detection spec.

    The constructor resolves the requested weights, builds the weights'
    transforms, instantiates the model with those weights and switches it to
    evaluation mode. The ``spec`` property describes the labels the model can
    produce, derived from the weights' ``"categories"`` metadata.
    """

    # Value passed as ``type=`` to every label spec built by ``spec``.
    # NOTE(review): presumably overridden by subclasses - confirm.
    _label_type = "any"

    def __init__(self, model_name: str, weights_name: str = "DEFAULT", **kwargs) -> None:
        """Load the model and its weights.

        Args:
            model_name: Name understood by ``torchvision.models.get_model``.
            weights_name: Member name looked up in the model's weights enum.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``torchvision.models.get_model``.
        """
        self._weights = torchvision.models.get_model_weights(model_name)[weights_name]
        self._transforms = self._weights.transforms()

        self._model = torchvision.models.get_model(
            model_name,
            weights=self._weights,
            **kwargs,
        )
        self._model.eval()

    @cached_property
    def spec(self) -> cvataa.DetectionFunctionSpec:
        """Return the detection spec listing this model's labels.

        Each label's id is the position of its category in the weights'
        ``"categories"`` list. ``"N/A"`` entries are skipped without
        renumbering, so ids may have gaps. The result is computed once per
        instance and then cached.
        """
        label_specs = []
        for class_id, category in enumerate(self._weights.meta["categories"]):
            if category == "N/A":
                continue
            label_specs.append(cvataa.label_spec(category, class_id, type=self._label_type))

        return cvataa.DetectionFunctionSpec(labels=label_specs)
|