Source code for deepvisiontools.inference.patchifier

from typing import Tuple, List
from torch import Tensor
from math import ceil
import torch
from torchvision.transforms.v2 import Pad
from itertools import product
from deepvisiontools.formats import (
    BatchedFormat,
    BaseFormat,
    BboxFormat,
    InstanceMaskFormat,
)
import copy
from torchvision.ops import nms
from abc import ABC, abstractmethod
from deepvisiontools import visualization


[docs] class BasePatchifier(ABC): """Abstract class for patchifier. If you want to implement a custom one you need to implement unpatchify method"""
[docs] def pad_to( self, image: Tensor, new_size: Tuple[int, int] ) -> Tuple[Tensor, Tuple[int, int, int, int]]: """Pad image to given size Args: image (Tensor). new_size (Tuple[int, int]). Returns: Tuple[Tensor, Tuple[int, int, int, int]]: padded image, (t, l, r, b) """ t = int((new_size[0] - image.shape[-2]) / 2) l = int((new_size[1] - image.shape[-1]) / 2) b = int((new_size[0] - image.shape[-2]) - t) r = int((new_size[1] - image.shape[-1]) - l) padder = Pad((l, t, r, b)) pad_coord = (t, l, r, b) return padder(image), pad_coord
[docs] def patchify( self, image: Tensor ) -> Tuple[Tensor, List[Tuple[int, int]], Tensor, Tuple[int, int, int, int]]: """Create patches for image prediction : 1) Pad image to fit all patches, 2) create patches Args: image (``Tensor``) Returns: ``Tuple[ Tensor, List[Tuple[int, int]], Tuple[int, int], Tensor, Tuple[int, int]]``: - patches stacked (N_patch, c, h, w), List of (top, left) pad coordinates, padded image, image pad coordinates """ h, w = image.shape[-2:] assert ( self.patch_size[0] < h and self.patch_size[1] < w ), f"Can't patchify image of shape {(h, w)} with patch size {self.patch_size}." c, h, w = image.shape[-3:] h_patch, w_patch = self.patch_size # compute strides values stride_h = h_patch - round(h_patch * self.overlap) stride_w = w_patch - round(w_patch * self.overlap) # get number of pacthes on axis nb_h_patches = ceil(h / stride_h) nb_w_patches = ceil(w / stride_w) # padded image shape (H,W) h_padded = (nb_h_patches - 1) * stride_h + h_patch w_padded = (nb_w_patches - 1) * stride_w + w_patch # pad image padded_image, pad_coord = self.pad_to(image, (h_padded, w_padded)) # get coordinates top_corners = range(0, nb_h_patches * stride_h, stride_h) left_corners = range(0, nb_w_patches * stride_w, stride_w) origins = list(product(top_corners, left_corners)) # Create patches tensors patches = torch.zeros( (nb_h_patches * nb_w_patches, c, h_patch, w_patch), device=image.device ).to(image.device) # Fill the patches tensors for idx, (y, x) in enumerate(origins): patches[idx] = padded_image[:, y : y + h_patch, x : x + w_patch] return patches, origins, padded_image, pad_coord
@abstractmethod def unpatchify(): pass
[docs] class DetectPatchifier(BasePatchifier): """Handle patchification and unpatchification Args: patch_size (``Tuple[int, int]``): size of patches to create overlap (``float``): overlap between patches border_penalty (``float``, **optional**): penalty to apply on patch border objects before nms. Defaults to 0.5. nms_iou_threshold (``float``, **optional**): nms threshold. Defaults to 0.45. final_score_threshold (``float``, **optional**): final score (after penalty) threshold. Defaults to 0.4. Attributes ---------- Attributes: patch_size (``Tuple[int, int]``) overlap (``float``): overlap between patches border_penalty (``float``) postprocess (```PostProcesser```) **Methods** """ def __init__( self, patch_size: Tuple[int, int], overlap: float, border_penalty: float = 0.5, nms_iou_threshold: float = 0.45, final_score_threshold: float = 0.4, ): self.patch_size = patch_size self.overlap = overlap self.border_effect_penality = border_penalty self.postprocess = PostProcesser(nms_iou_threshold, final_score_threshold) def _merge_prediction_patches( self, pred_patches: BatchedFormat, origins: List[Tuple[int, int]], image_padded_size: Tuple[int, int], ) -> BaseFormat: """combine the patchs predictions together""" # add patch to full sized format full_pred: BaseFormat = type(pred_patches.formats[0]).empty( canvas_size=image_padded_size ) scores = torch.tensor([]).to(pred_patches.device) for pred, ori in zip(pred_patches.formats, origins): r_pad = image_padded_size[1] - ori[1] - pred.canvas_size[1] b_pad = image_padded_size[0] - ori[0] - pred.canvas_size[0] full_pred += pred.pad(ori[0], ori[1], r_pad, b_pad)[0] if pred.scores != None: scores = torch.cat([scores, pred.scores]) assert ( scores.nelement() == full_pred.nb_object ), "Inconsistency between scores and object number." full_pred._scores = scores return full_pred
[docs] def unpatchify( self, pred_patches: BatchedFormat, origins: List[Tuple[int, int]], image_padded_size: Tuple[int, int], padded_image_coords: Tuple[int, int, int, int], original_image_size: Tuple[int, int], ): """merge patchs predictions together while applying penalty, postprocess etc. Note that since InstanceMasks do not handle overlapping objects, you need to treat directly the patches in the postprocess. To do so you derive boxes, then use boxes to filter the masks.""" # apply penalty pred_patches = BatchedFormat( [self._scores_penalty(pred) for pred in pred_patches] ) if pred_patches.formats == []: # handles empty preds return BboxFormat.empty((original_image_size[0], original_image_size[1])) # if instance_masks, need to filter patchs objects before merging (otherwise will override objects ...) if isinstance(pred_patches.formats[0], InstanceMaskFormat): box_patches = [ BboxFormat.from_instance_mask(patch) for patch in pred_patches ] associated_boxes = self._merge_prediction_patches( BatchedFormat(box_patches), origins, image_padded_size ) pred_patches = self.postprocess.handle_patches_mask( pred_patches, associated_boxes ) preds = self._merge_prediction_patches( pred_patches, origins, image_padded_size ) else: # merge patchs preds = self._merge_prediction_patches( pred_patches, origins, image_padded_size ) # Duplicates handling preds = self.postprocess.handle_box_duplicates(preds) # crop to original image t, l, _, _ = padded_image_coords h = original_image_size[0] w = original_image_size[1] preds, _ = preds.crop(t, l, h, w) return preds
def _scores_penalty(self, patch: BaseFormat) -> BaseFormat: """apply score penalty on border. Method is linear""" if isinstance(patch, BboxFormat): return self._box_penalty(patch) elif isinstance(patch, InstanceMaskFormat): # if InstanceMaskFormat create boxes, apply penalty on objects boxes scores, return InstanceMaskFormat with replaced scores boxes = BboxFormat.from_instance_mask(patch) new_scores = self._box_penalty(boxes) new_patch = InstanceMaskFormat( data=patch.data, labels=patch.labels, scores=new_scores.scores ) return new_patch def _box_penalty(self, patch: BaseFormat) -> BaseFormat: """Apply box scores penalty to reduce border effect. Helps patchification multiple objects removal. Args: patch (``BaseFormat``) Returns: ``BaseFormat``: - patch with reduced scores on border. """ if patch.nb_object == 0: return patch assert ( patch.scores != None ), "Can't apply penalty on scores since scores = None." patch: BaseFormat = copy.copy(patch) patch.format = "XYXY" patch_xcenter = patch.canvas_size[0] / 2 patch_ycenter = patch.canvas_size[1] / 2 max_devs = [] for box in patch.data.value: centro_x = int((box[0] + box[2]) / 2) centro_y = int((box[1] + box[3]) / 2) relative_xdeviation = abs(centro_x - patch_xcenter) / patch.canvas_size[1] relative_ydeviation = abs(centro_y - patch_ycenter) / patch.canvas_size[0] # max and *2 because relative_dev in [0, 0.5] range (but we want in [0, 1]) max_devs.append(max(2 * relative_xdeviation, 2 * relative_ydeviation)) # Apply linear penalty relu = torch.nn.ReLU(inplace=True) penalty = relu(1 - self.border_effect_penality * torch.tensor(max_devs)).to( patch.device ) new_scores = patch.scores * penalty patch._scores = new_scores return patch
[docs] class SemanticPatchifier(BasePatchifier): """Semantic segmentation patchifier/unpatchifier Args: patch_size (``Tuple[int, int]``) overlap (``float``) """ def __init__(self, patch_size: Tuple[int, int], overlap: float): self.patch_size = patch_size self.overlap = overlap def unpatchify( self, pred_patches, origins, image_padded_size, padded_image_coords, original_image_size, ): full_pred: BaseFormat = type(pred_patches.formats[0]).empty( canvas_size=image_padded_size ) full_pred = full_pred.generate_scores_from_mask() for i, item in enumerate(zip(pred_patches, origins)): pred, ori = item r_pad = image_padded_size[1] - ori[1] - pred.canvas_size[1] b_pad = image_padded_size[0] - ori[0] - pred.canvas_size[0] pred = pred.pad(ori[0], ori[1], r_pad, b_pad)[0] full_pred += pred return full_pred
[docs] class PostProcesser: """Handles postprocessing Args: nms_iou_th (``float``): nms threshold final_score_threshold (``float``): final score thresholding """ def __init__(self, nms_iou_th: float, final_score_threshold: float): self.nms_iou_th = nms_iou_th self.final_score_threshold = final_score_threshold def handle_patches_mask( self, pred_batch: BatchedFormat, associated_boxes: BboxFormat ): # nms suppression nms_filter = self._box_nms_suppress(associated_boxes) pred_batch, nms_bool_filter = self._apply_index(pred_batch, nms_filter) associated_boxes, _ = associated_boxes[nms_bool_filter] pred_batch = BatchedFormat( [pred[pred.scores > self.final_score_threshold][0] for pred in pred_batch] ) pred_batch.sanitize() return pred_batch def _apply_index(self, pred_batch: BatchedFormat, index: Tensor): index_slicer = [patch.nb_object for patch in pred_batch] # convert index to bool tensor with shape [N_objs] if not index.dtype == torch.bool: idx = range(sum(index_slicer)) index = torch.tensor([i in index for i in idx]) # create index parser for batch index_slicer.insert(0, 0) slicer = [] s = 0 for i in index_slicer: s += i slicer.append(s) # filter patches patch_list = [] for i, patch in enumerate(pred_batch): patch_list.append(patch[index[slicer[i] : slicer[i + 1]]][0]) return BatchedFormat(patch_list), index
[docs] def handle_box_duplicates(self, pred: BaseFormat) -> BaseFormat: """Main function that call further function depending on handling mode Args: pred (``BaseFormat``): data to handle duplicates Returns: ``BaseFormat``: - handled duplicate new data """ # nms suppression nms_filter = self._box_nms_suppress(pred) new_pred, _ = pred[nms_filter] new_pred, _ = new_pred.sanitize() new_pred, _ = new_pred[new_pred.scores > self.final_score_threshold] return new_pred
def _box_nms_suppress(self, to_handle: BboxFormat) -> Tensor: """return a filter tensor of objects to keep after nms""" indexes_filter = nms(to_handle.data.value, to_handle.scores, self.nms_iou_th) return indexes_filter