Source code for deepvisiontools.data.batch_augmentations

from deepvisiontools.formats import BaseFormat, BatchedFormat, SemanticMaskFormat
import torch
from torch import Tensor
from abc import ABC, abstractmethod
from typing import Tuple, Union, Literal, List
from random import shuffle, uniform
from deepvisiontools import Configuration
import itertools
from torchvision.transforms.v2 import Resize
import copy


class AbstractBatchAugmenter(ABC):
    """Interface of the augmentations applied inside the DataLoader.

    Such augmentations combine several elements of a batch together
    (mosaic-type augmentations for instance).

    Note:
        These augmentations always come after the regular augmentations,
        which are implemented in the Dataset instead of the DataLoader.
    """

    @abstractmethod
    def get_new_batch(
        self, images_batch: Tensor, targets_batch: BatchedFormat
    ) -> Tuple[Tensor, BatchedFormat]:
        """Return the augmented batch built from the given batch.

        Args:
            images_batch (``Tensor``): Images of the batch.
            targets_batch (``BatchedFormat``): Targets of the batch.

        Returns:
            ``Tuple[Tensor, BatchedFormat]``:
                - New images and new targets.
        """
# Maps the number of images mixed in one mosaic to the grid used to lay them
# out. Each grid is stored smallest dimension first and its two factors
# multiply back to the key.
MOSAIC_CONVERTER = {
    1: (1, 1),
    2: (1, 2),
    4: (2, 2),
    6: (2, 3),
    8: (2, 4),
    9: (3, 3),
    12: (3, 4),
}
class MosaicBatchAugmenter(AbstractBatchAugmenter):
    """Batch augmentation generating mosaics that mix n images of a batch.

    The image plane is split into a grid of n cells (see ``MOSAIC_CONVERTER``).
    Each cell of a mosaic is filled with the pixels (and the targets) found at
    that very location in one of the n mixed images: no resizing is involved.
    A group of n images produces n mosaics, one per rotation of the
    image -> cell assignment, so the batch size is preserved.

    If n is larger than the batch size, the largest smaller option is used
    instead (e.g. n = 4, batch_size = 3 -> n becomes 2). If n is smaller than
    the batch size, the shuffled batch is split into as many groups of n images
    as possible (e.g. batch_size = 5, n = 2 -> two groups of 2 images are
    mixed and 1 image remains). The remaining images are left untouched and are
    appended after the mosaics in the returned batch.

    Args:
        mixed_img_numb (``Literal[1, 2, 4, 6, 8, 9, 12]``, **optional**): Number of img per mosaic. Defaults to 2.
        probability (``float``, **optional**): Probability to apply the mixing to a given batch. Defaults to 0.5.
    """

    def __init__(
        self,
        mixed_img_numb: Literal[1, 2, 4, 6, 8, 9, 12] = 2,
        probability: float = 0.5,
    ):
        assert mixed_img_numb in list(
            MOSAIC_CONVERTER.keys()
        ), f"Wrong argument in MosaicBatchAugmenter, mixed_img_num must be one of [1, 2, 4, 6, 8, 9, 12], got {mixed_img_numb}"
        # Value requested by the user; never modified afterwards.
        self.mixed_img_number = mixed_img_numb
        # Value effectively used for the current batch (<= batch size).
        self._used_mixed_img_nb = mixed_img_numb
        self.probability = probability
        self._device = Configuration().device
        self._imgs_shape = None
        self._mosaic_shape = None
        self._resizer = None

    def _adapt_to_batch(self, images_batch: Tensor):
        """Set the per-batch state (number of mixed images, grid, image shape).

        Args:
            images_batch (``Tensor``): Batch of images, must have 4 dimensions.
        """
        assert (
            images_batch.dim() == 4
        ), f"To use MosaicBatchAugmenter, batches of images must have dim=4 (N, 3, H, W), got {images_batch.shape}"
        num_img = images_batch.shape[0]
        h, w = images_batch.shape[-2:]
        # Always restart from the requested value: the state of a previous
        # (possibly smaller) batch must not leak into this one. Pick the
        # largest supported option fitting in the batch; `default=1` covers an
        # empty batch, which then simply skips the mixing.
        upper_bound = min(self.mixed_img_number, num_img)
        self._used_mixed_img_nb = max(
            (option for option in MOSAIC_CONVERTER if option <= upper_bound),
            default=1,
        )
        self._imgs_shape = images_batch.shape[-2:]
        # Put the largest number of splits along the largest image dimension.
        m_shape = MOSAIC_CONVERTER[self._used_mixed_img_nb]
        if h > w:
            m_shape = (max(m_shape), min(m_shape))
        self._mosaic_shape = m_shape
        self._resizer = Resize((h, w))

    def _reset(self):
        """Restore the per-batch state to its initial values."""
        self._used_mixed_img_nb = self.mixed_img_number
        self._mosaic_shape = None
        self._resizer = None

    def _get_mixing_index(
        self, images_batch: Tensor
    ) -> Tuple[List[List[int]], List[int]]:
        """Randomly split the batch indexes into groups to mix and leftovers.

        Args:
            images_batch (``Tensor``): Batch of images.

        Returns:
            ``Tuple[List[List[int]], List[int]]``:
                - Groups of ``_used_mixed_img_nb`` indexes to mix, indexes left untouched.
        """
        order = list(range(images_batch.shape[0]))
        shuffle(order)
        group_size = self._used_mixed_img_nb
        n_groups = len(order) // group_size
        mixing_index = [
            order[k * group_size : (k + 1) * group_size] for k in range(n_groups)
        ]
        kept_index = order[n_groups * group_size :]
        return mixing_index, kept_index

    def _generate_batchs_lists(
        self, images_batch: Tensor, targets_batch: BatchedFormat
    ) -> Tuple[List[Tensor], List[BatchedFormat], Tensor, BatchedFormat]:
        """Split the batch into the sub-batches to mix and the part to keep.

        Args:
            images_batch (``Tensor``): Batch of images.
            targets_batch (``BatchedFormat``): Batch of targets.

        Returns:
            ``Tuple[List[Tensor], List[BatchedFormat], Tensor, BatchedFormat]``:
                - Images to mix, targets to mix, images kept, targets kept.
        """
        mixing_index, kept_indexes = self._get_mixing_index(images_batch)
        # split batch into the ones that are kept and the ones to be mixed.
        if kept_indexes:
            images2keep: Tensor = images_batch[kept_indexes]
            targets2keep: BatchedFormat = targets_batch[kept_indexes]
        else:
            images2keep = torch.empty([0])
            targets2keep = BatchedFormat([])
        # one sub-batch of images per group of indexes
        images_batches_to_mix: list = [
            images_batch[mix_batch] for mix_batch in mixing_index
        ]
        # list of corresponding BatchedFormats
        targets_batches_to_mix: list = [
            targets_batch[mix_batch] for mix_batch in mixing_index
        ]
        return images_batches_to_mix, targets_batches_to_mix, images2keep, targets2keep

    def _get_crop_coords_lists(self) -> List[Tuple[int, int, int, int]]:
        """Compute the (top, left, height, width) of every cell of the grid.

        The cell edges are rounded once and the sizes are derived from them, so
        that the cells tile the image exactly (no gap, no overflow) even when
        the image size is not a multiple of the grid size.

        Returns:
            ``List[Tuple[int, int, int, int]]``:
                - Cells in row-major order.
        """
        h, w = self._imgs_shape
        h_divider, w_divider = self._mosaic_shape
        h_edges = [round(i * h / h_divider) for i in range(h_divider + 1)]
        w_edges = [round(j * w / w_divider) for j in range(w_divider + 1)]
        return [
            (
                h_edges[i],
                w_edges[j],
                h_edges[i + 1] - h_edges[i],
                w_edges[j + 1] - w_edges[j],
            )
            for i, j in itertools.product(range(h_divider), range(w_divider))
        ]

    def _mix_batch(
        self,
        batch_img: Tensor,
        batch_targ: BatchedFormat,
        crops_coords: List[Tuple[int, int, int, int]],
    ):
        """Build one mosaic per rotation of the images of a group.

        Args:
            batch_img (``Tensor``): Images of the group to mix.
            batch_targ (``BatchedFormat``): Targets of the group to mix.
            crops_coords (``List[Tuple[int, int, int, int]]``): Cells of the grid.

        Returns:
            ``Tuple[Tensor, BatchedFormat]``:
                - Stacked mosaics and their targets.
        """
        num_imgs = batch_img.shape[0]
        base = list(range(num_imgs))
        # All rotations of the indexes, shift 1 first and identity last, so
        # every image visits every cell exactly once over the group.
        cycles = [base[shift:] + base[:shift] for shift in range(1, num_imgs + 1)]
        new_batch_img = []
        new_batch_targs = []
        for cy in cycles:
            new_img, new_tar = self._get_img_targ_from_cycle_of_crops(
                batch_img, batch_targ, cy, crops_coords
            )
            # new_img is a freshly allocated canvas and torch.stack copies it
            # anyway: no extra copy needed.
            new_batch_img.append(new_img)
            # NOTE(review): copy kept for targets, as the formats may share
            # storage with the source targets - confirm before removing.
            new_batch_targs.append(copy.deepcopy(new_tar))
        return torch.stack(new_batch_img), BatchedFormat(new_batch_targs)

    def _get_img_targ_from_cycle_of_crops(
        self,
        batch_img: Tensor,
        batch_targ: BatchedFormat,
        cycle: list,
        crops_coords: list,
    ):
        """Assemble one mosaic: the k-th image of ``cycle`` fills the k-th cell.

        Args:
            batch_img (``Tensor``): Images of the group to mix.
            batch_targ (``BatchedFormat``): Targets of the group to mix.
            cycle (``list``): Order in which images are assigned to the cells.
            crops_coords (``list``): Cells of the grid as (top, left, height, width).

        Returns:
            ``Tuple[Tensor, BaseFormat]``:
                - Mosaic image and its target.
        """
        list_imgs = [batch_img[i] for i in cycle]
        list_targs = BatchedFormat([batch_targ.formats[i] for i in cycle])
        # Canvas follows the channel count, dtype and device of the inputs
        # instead of assuming 3-channel float32 CPU images.
        new_img = batch_img.new_zeros((batch_img.shape[1], *self._imgs_shape))
        new_targ: BaseFormat = type(list_targs.formats[0]).empty(
            canvas_size=list_targs.formats[0].canvas_size
        )
        ori_h, ori_w = new_targ.canvas_size
        if isinstance(new_targ, SemanticMaskFormat):
            new_targ = new_targ.generate_scores_from_mask()
        for im, tar, crop_coord in zip(list_imgs, list_targs, crops_coords):
            top, left, h, w = crop_coord
            new_img[:, top : top + h, left : left + w] = im[
                :, top : top + h, left : left + w
            ]
            crop_targ = tar.crop(top, left, h, w)[0]
            # Pad the cropped target back to the original canvas so that it
            # stays at its original location.
            t_pad, l_pad, r_pad, b_pad = top, left, ori_w - left - w, ori_h - top - h
            crop_targ = crop_targ.pad(t_pad, l_pad, r_pad, b_pad)[0]
            if isinstance(crop_targ, SemanticMaskFormat):
                crop_targ = crop_targ.generate_scores_from_mask()
            new_targ += crop_targ
        return new_img, new_targ

    def get_new_batch(
        self, images_batch: Tensor, targets_batch: BatchedFormat
    ) -> Tuple[Tensor, BatchedFormat]:
        """Return the batch with mosaics replacing the mixed images.

        With probability ``1 - probability``, or when only one image can be
        mixed, the batch is returned unchanged.

        Args:
            images_batch (``Tensor``): Batch of images, must have 4 dimensions.
            targets_batch (``BatchedFormat``): Batch of targets.

        Returns:
            ``Tuple[Tensor, BatchedFormat]``:
                - Mosaics followed by the untouched images, and matching targets.
        """
        # ======= Prepare
        self._adapt_to_batch(images_batch)
        try:
            # if proba fails or if image nb to mix is 1, don't do mixing
            if uniform(0, 1) > self.probability or self._used_mixed_img_nb == 1:
                return images_batch, targets_batch
            # get to be mixed imgs and targs as well as to be kept images and targs
            images_batches_to_mix, targets_batches_to_mix, images2keep, targets2keep = (
                self._generate_batchs_lists(images_batch, targets_batch)
            )
            # ======= Handle mixing
            crops_coords = self._get_crop_coords_lists()
            final_batchs_img_mixed = []
            final_batchs_tars_mixed = []
            for batch_img, batch_targ in zip(
                images_batches_to_mix, targets_batches_to_mix
            ):
                new_img, new_targ = self._mix_batch(batch_img, batch_targ, crops_coords)
                final_batchs_img_mixed.append(new_img)
                final_batchs_tars_mixed.append(new_targ)
            final_imgs_mixed = torch.cat(final_batchs_img_mixed).to(
                Configuration().device
            )
            final_targs_mixed = BatchedFormat.cat(final_batchs_tars_mixed)
            # aggregate untouched and mixed
            if targets2keep.formats != []:
                new_img_fullbatch = torch.cat([final_imgs_mixed, images2keep])
                new_tar_fullbatch = BatchedFormat.cat([final_targs_mixed, targets2keep])
            else:
                new_img_fullbatch = final_imgs_mixed
                new_tar_fullbatch = final_targs_mixed
            return new_img_fullbatch, new_tar_fullbatch
        finally:
            # ======= reset for future uses, including on the early return
            self._reset()