# Source code for deepvisiontools.data.augmentation_class

from typing import List
from deepvisiontools.formats import (
    BaseFormat,
    BboxFormat,
    BboxData,
    InstanceMaskData,
    InstanceMaskFormat,
)
import torchvision.transforms.v2 as T
from torch import Tensor
from torchvision.tv_tensors import Image
import torch
import deepvisiontools.data.errors as er

# Tested with masks: OK for RandomResize, RandomCrop, RandomZoomOut, ScaleJitter, RandomHorizontalFlip / RandomVerticalFlip, RandomRotation, RandomAffine, RandomPerspective
#   Errors: RandomIoUCrop fails when the target has masks only (it requires boxes).
#   Tested with boxes: OK for the same transforms as masks, but be careful with crops / rotations: they can produce slightly inaccurate boxes, because cropping or rotating a box does not preserve its structural information.

# TODO: check rotation augmentations with boxes: an error was raised about mismatched input shape / canvas size between the image and the boxes.


class Augmentation:
    """Augmentation handler used by datasets.

    Wraps a list of transforms into a single composed pipeline and delegates
    the actual transformation to the target format's own
    ``apply_augmentation`` method. Each format (data_type) therefore decides
    how it is augmented together with the image.

    Args:
        augmentations (List[T.Transform]): Transforms from
            ``torchvision.transforms.v2`` (or from
            ``deepvisiontools.data.additional_augmentations``), combined with
            ``T.Compose``.
    """

    def __init__(self, augmentations: List[T.Transform]) -> None:
        pipeline = T.Compose(augmentations)
        self.transform = pipeline

    def __call__(self, image: Tensor, target: BaseFormat):
        """Augment an image and its target, dispatching on the format type.

        Args:
            image (Tensor): Input image. It is wrapped as a
                ``torchvision.tv_tensors.Image`` before being handed to the
                target.
            target (BaseFormat): Annotation object whose ``apply_augmentation``
                method runs the composed transform.

        Returns:
            tuple: ``(transformed_image, transformed_target)``.
        """
        wrapped_image = Image(image)
        outputs = target.apply_augmentation(wrapped_image, self.transform)
        # apply_augmentation yields three values: the target comes first and
        # the image last; the middle value is not used here.
        new_target, _unused, new_image = outputs
        return new_image, new_target