94 lines
2.8 KiB
Python
94 lines
2.8 KiB
Python
|
|
# Copyright (C) CVAT.ai Corporation
|
||
|
|
#
|
||
|
|
# SPDX-License-Identifier: MIT
|
||
|
|
|
||
|
|
from typing import TypedDict
|
||
|
|
|
||
|
|
import attrs
|
||
|
|
import attrs.validators
|
||
|
|
import torch
|
||
|
|
import torch.utils.data
|
||
|
|
|
||
|
|
from cvat_sdk.datasets.common import UnsupportedDatasetError
|
||
|
|
from cvat_sdk.pytorch.common import Target
|
||
|
|
|
||
|
|
|
||
|
|
@attrs.frozen
class ExtractSingleLabelIndex:
    """
    A target transform that takes a `Target` object and produces a single label index
    based on the tag in that object, as a 0-dimensional tensor.

    This makes the dataset samples compatible with the image classification networks
    in torchvision.

    If the annotations contain no tags, or multiple tags, raises a `ValueError`.
    """

    def __call__(self, target: Target) -> torch.Tensor:
        """
        Return the label index of the sample's only tag as a 0-dimensional
        `torch.long` tensor.

        The tag's label ID is mapped to an index through `target.label_id_to_index`.

        Raises `ValueError` if the sample has no tags or more than one tag.
        """
        tags = target.annotations.tags
        if not tags:
            raise ValueError("sample has no tags")

        if len(tags) > 1:
            raise ValueError("sample has multiple tags")

        # A 0-dim long tensor (not a Python int) is what the annotation now states.
        return torch.tensor(target.label_id_to_index[tags[0].label_id], dtype=torch.long)
|
||
|
|
|
||
|
|
|
||
|
|
class LabeledBoxes(TypedDict):
    """
    Detection target returned by `ExtractBoundingBoxes`: a tensor of boxes and a
    tensor of the corresponding label indices.
    """

    # bounding boxes, one per row
    boxes: torch.Tensor
    # label index for each row of `boxes`
    labels: torch.Tensor
|
||
|
|
|
||
|
|
|
||
|
|
_SUPPORTED_SHAPE_TYPES = frozenset(["rectangle", "polygon", "polyline", "points", "ellipse"])
|
||
|
|
|
||
|
|
|
||
|
|
@attrs.frozen
class ExtractBoundingBoxes:
    """
    A target transform that takes a `Target` object and returns a dictionary compatible
    with the object detection networks in torchvision.

    The dictionary contains the following entries:

    "boxes": a tensor with shape [N, 4], where each row represents a bounding box of a shape
    in the annotations in the (xmin, ymin, xmax, ymax) format.
    "labels": a tensor with shape [N] containing corresponding label indices.

    When no shapes are selected, N is 0 and "boxes" still has shape [0, 4].

    Each box is the axis-aligned extent (minimum and maximum) of the shape's points,
    which are read as a flat sequence of alternating x and y coordinates.

    Limitations:

    * Only the following shape types are supported: rectangle, polygon, polyline,
      points, ellipse.
    * Rotated shapes are not supported.

    Raises `UnsupportedDatasetError` if a selected shape has a non-zero rotation.
    """

    include_shape_types: frozenset[str] = attrs.field(
        converter=frozenset,
        validator=attrs.validators.deep_iterable(attrs.validators.in_(_SUPPORTED_SHAPE_TYPES)),
        kw_only=True,
    )
    """Shapes whose type is not in this set will be ignored."""

    def __call__(self, target: Target) -> LabeledBoxes:
        boxes = []
        labels = []

        for shape in target.annotations.shapes:
            if shape.type.value not in self.include_shape_types:
                continue

            if shape.rotation != 0:
                raise UnsupportedDatasetError("Rotated shapes are not supported")

            # NOTE(review): the box is the min/max over the stored points. For ellipses,
            # confirm the stored points span the shape's full extent (if an ellipse is
            # stored as a center plus one corner, this box covers only part of it).
            x_coords = shape.points[0::2]
            y_coords = shape.points[1::2]

            boxes.append((min(x_coords), min(y_coords), max(x_coords), max(y_coords)))
            labels.append(target.label_id_to_index[shape.label_id])

        return LabeledBoxes(
            # torch.tensor([]) is 1-D with shape [0]; reshape keeps the documented
            # [N, 4] layout for the empty case and is a no-op otherwise.
            boxes=torch.tensor(boxes, dtype=torch.float).reshape(-1, 4),
            labels=torch.tensor(labels, dtype=torch.long),
        )
|