cvat/cvat-sdk/cvat_sdk/pytorch/transforms.py

94 lines
2.8 KiB
Python
Raw Normal View History

2025-09-16 01:19:40 +00:00
# Copyright (C) CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT
from typing import TypedDict
import attrs
import attrs.validators
import torch
import torch.utils.data
from cvat_sdk.datasets.common import UnsupportedDatasetError
from cvat_sdk.pytorch.common import Target
@attrs.frozen
class ExtractSingleLabelIndex:
    """
    A target transform that takes a `Target` object and produces a single label index
    based on the tag in that object, as a 0-dimensional tensor.

    This makes the dataset samples compatible with the image classification networks
    in torchvision.

    If the annotations contain no tags, or multiple tags, raises a `ValueError`.
    """

    def __call__(self, target: Target) -> torch.Tensor:
        """
        Return the label index of the sample's only tag.

        The result is a 0-dimensional tensor of dtype `torch.long`.

        Raises `ValueError` unless `target.annotations.tags` contains exactly one tag.
        """
        tags = target.annotations.tags

        if not tags:
            raise ValueError("sample has no tags")

        if len(tags) > 1:
            raise ValueError("sample has multiple tags")

        # Translate the tag's label ID into an index using the mapping carried by the target.
        return torch.tensor(target.label_id_to_index[tags[0].label_id], dtype=torch.long)
class LabeledBoxes(TypedDict):
    """
    The dictionary type returned by `ExtractBoundingBoxes`: bounding boxes
    together with the label index of each box.
    """

    # One row per box, in the (xmin, ymin, xmax, ymax) format.
    boxes: torch.Tensor
    # One label index per box, in the same order as the rows of `boxes`.
    labels: torch.Tensor
# Shape type names that may be passed to `ExtractBoundingBoxes` via `include_shape_types`;
# the field validator there rejects anything outside this set.
_SUPPORTED_SHAPE_TYPES = frozenset(
    (
        "rectangle",
        "polygon",
        "polyline",
        "points",
        "ellipse",
    )
)
@attrs.frozen
class ExtractBoundingBoxes:
    """
    A target transform that takes a `Target` object and returns a dictionary compatible
    with the object detection networks in torchvision.

    The dictionary contains the following entries:

    "boxes": a tensor with shape [N, 4], where each row represents a bounding box of a shape
    in the annotations in the (xmin, ymin, xmax, ymax) format.
    "labels": a tensor with shape [N] containing corresponding label indices.

    If no shape is selected (N = 0), "boxes" still has shape [0, 4] and "labels" has
    shape [0].

    Limitations:

    * Only the following shape types are supported: rectangle, polygon, polyline,
      points, ellipse.
    * Rotated shapes are not supported.
    """

    include_shape_types: frozenset[str] = attrs.field(
        converter=frozenset,
        validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)),
        kw_only=True,
    )
    """Shapes whose type is not in this set will be ignored."""

    def __call__(self, target: Target) -> LabeledBoxes:
        """
        Build the boxes/labels dictionary for the shapes in `target.annotations`.

        Shapes whose type is not in `include_shape_types` are skipped.

        Raises `UnsupportedDatasetError` if a selected shape has a non-zero rotation.
        """
        boxes = []
        labels = []

        for shape in target.annotations.shapes:
            if shape.type.value not in self.include_shape_types:
                continue

            if shape.rotation != 0:
                raise UnsupportedDatasetError("Rotated shapes are not supported")

            # `points` is treated as an interleaved sequence: x0, y0, x1, y1, ...
            x_coords = shape.points[0::2]
            y_coords = shape.points[1::2]

            # The box is the axis-aligned extent of the shape's points.
            # NOTE(review): confirm that the min/max of the stored points is the
            # intended extent for every supported type (ellipses in particular).
            boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords)))
            labels.append(target.label_id_to_index[shape.label_id])

        return LabeledBoxes(
            # An empty list would otherwise become a tensor of shape [0]; reshaping keeps
            # the documented [N, 4] layout for N = 0 and is a no-op for N > 0.
            boxes=torch.tensor(boxes, dtype=torch.float).reshape(-1, 4),
            labels=torch.tensor(labels, dtype=torch.long),
        )