Source code for deepvisiontools.inference.predictor

from deepvisiontools import Configuration
from deepvisiontools.models.basemodel import BaseModel
from typing import Tuple, Union, Callable, Dict
from deepvisiontools.preprocessing.preprocessing import build_preprocessing
from pathlib import Path
import torch
from torch import Tensor
from deepvisiontools.formats import BatchedFormat, BaseFormat
from typing import List, Union
from deepvisiontools.utils import visualization
from deepvisiontools.preprocessing.image import load_image
from deepvisiontools.inference.patchifier import (
    DetectPatchifier,
    SemanticPatchifier,
    BasePatchifier,
)
from torchvision.transforms.v2 import Pad, CenterCrop
from tqdm import tqdm
from deepvisiontools.formats import SemanticMaskData, SemanticMaskFormat


def build_patchifier(*args, **kwargs):
    """Instantiate the default patchifier for the configured data type.

    A :class:`SemanticPatchifier` is returned when ``Configuration().data_type`` is
    ``"semantic_mask"``; any other data type gets a :class:`DetectPatchifier`.
    All positional and keyword arguments are forwarded unchanged to the constructor.
    """
    is_semantic = Configuration().data_type == "semantic_mask"
    patchifier_cls = SemanticPatchifier if is_semantic else DetectPatchifier
    return patchifier_cls(*args, **kwargs)


class Predictor:
    """Predictor class for deepvisiontools. Load a model and apply it on an image to get a prediction.
    Can handle patchification for large image prediction.

    Args:
        model (``Union[BaseModel, str, Path]``): model path / instance of BaseModel to be used.
        preprocessing (``Callable``, **optional**): preprocessing applied to the (batched) image
            before the forward pass. Pass ``None`` to disable preprocessing. Defaults to
            ``build_preprocessing()``, built when the Predictor is instantiated.
        patch_size (``Union[Tuple[int, int], None]``, **optional**): size of the patches to be used
            for large image inference. If None will run the full image. Defaults to None.
        overlap (``float``, **optional**): Overlap between patches used in case of patchification.
            Only used to build the default patchifier. Defaults to 0.4.
        border_padding (``int``, **optional**): padding width used to build the ``padder``
            attribute (note: ``padder`` is not applied inside :meth:`predict`). Defaults to 100.
        batch_size (``int``, **optional**): batch size for patchification. Defaults to 1.
        categories (``Dict[int, str]``, **optional**): To rename your categories in the
            visualization. Defaults to None.
        patchifier (``Union[BasePatchifier, None]``, **optional**): If None use default
            SemanticPatchifier or DetectPatchifier according to Configuration().data_type.
            Patch-merging options (e.g. NMS or score thresholds), if any, are not arguments of
            ``Predictor``; they must be configured on the patchifier passed here.
            Defaults to None.
        verbose (``bool``, **optional**): if set to True will display progress state in patch
            predictions. Defaults to True.

    Example:
    ----------
    .. highlight:: python
    .. code-block:: python

        >>> from deepvisiontools import Predictor
        >>> img = "path/to/img"
        >>> predictor = Predictor(model="path/to/model.pth")
        >>> results = predictor.predict(img)

    Attributes
    ----------

    Attributes:
        model (``BaseModel``)
        preprocessing (``Callable``)
        patch_size (``Union[Tuple[int, int], None]``)
        padder (``Transform``)
        batch_size (``int``)
        cropper (``Transform``)
        patchifier (``BasePatchifier``)
        categories (``Dict[int, str]``)
        verbose (``bool``)

    **Methods**
    """

    # Private sentinel: keeps ``preprocessing=None`` meaning "no preprocessing" while the default
    # preprocessing is built per instance instead of once, at import time, as a default argument.
    _DEFAULT_PREPROCESSING = object()

    def __init__(
        self,
        model: Union[BaseModel, str, Path],
        preprocessing: Union[Callable, None] = _DEFAULT_PREPROCESSING,
        patch_size: Union[Tuple[int, int], None] = None,
        overlap: float = 0.4,
        border_padding: int = 100,
        batch_size: int = 1,
        categories: Union[Dict[int, str], None] = None,
        patchifier: Union[BasePatchifier, None] = None,
        verbose: bool = True,
    ):
        self.model: BaseModel = self._load_model(model)
        self.patch_size = patch_size
        self.batch_size = batch_size
        if preprocessing is self._DEFAULT_PREPROCESSING:
            preprocessing = build_preprocessing()
        self.preprocessing = preprocessing
        self.padder = Pad(border_padding)
        self.cropper = None  # is updated in self.predict
        if patchifier is None:
            patchifier = build_patchifier(self.patch_size, overlap)
        self.patchifier = patchifier
        self.categories = categories
        self.verbose = verbose

    @staticmethod
    def _load_model(model: Union[BaseModel, str, Path]) -> BaseModel:
        """Return ``model`` on ``Configuration().device`` in eval mode, loading it from disk if a path is given."""
        assert isinstance(
            model, (BaseModel, Path, str)
        ), "model must be instance of BaseModel or path to model"
        if not isinstance(model, BaseModel):
            model = Path(model)
            assert model.exists(), f"model path does not exists, got {model.as_posix()}"
            # The file holds a whole pickled model object (it is used as a module right below),
            # not a state_dict, so ``weights_only=False`` is required: PyTorch >= 2.6 defaults
            # to ``weights_only=True``, which refuses arbitrary objects.
            # SECURITY: unpickling can execute arbitrary code -- only load trusted model files.
            model = torch.load(
                model, map_location=Configuration().device, weights_only=False
            )
        return model.to(Configuration().device).eval()

    def forward_pass(self, batch_patchs: Tensor) -> BatchedFormat:
        """Run predictions on image / batch of patches.

        Args:
            batch_patchs (``Tensor``): batched input; its first dimension is split into
                mini-batches of ``self.batch_size`` items.

        Returns:
            ``BatchedFormat``: predictions of all mini-batches, concatenated in input order.
        """
        # ``torch.autocast`` expects a device *type* ("cuda", "cpu", ...), not an indexed device
        # such as "cuda:0"; ``torch.device`` accepts both strings and ``torch.device`` objects.
        device_type = torch.device(Configuration().device).type
        with torch.no_grad():
            loader = PredictorDataLoader(batch_patchs, self.batch_size)
            # only show a progress bar when there is more than one mini-batch
            if self.verbose and len(loader) > 1:
                loader = tqdm(
                    loader, desc="Predict on batch of patches : ", total=len(loader)
                )
            predictions = BatchedFormat([])
            for patch in loader:
                with torch.autocast(
                    device_type=device_type,
                    dtype=torch.float16,
                    enabled=Configuration().optimize,
                ):
                    predictions += self.model.get_predictions(patch)
        return predictions

    def predict(
        self,
        image: Union[str, Path, Tensor],
        visu_path: Union[str, Path] = "",
    ) -> BaseFormat:
        """Main function of ``Predictor``: call everything needed for prediction.

        Args:
            image (``Union[str, Path, Tensor]``): image to predict on, either a path (loaded with
                ``load_image``) or an already loaded image tensor.
            visu_path (``Union[str, Path]``, **optional**): if non-empty, path where a
                visualization of the prediction is saved. Defaults to "".

        Returns:
            ``BaseFormat``: prediction as deepvisiontools format, after ``self.cropper``
            (a center crop to the input image size) has been applied.
        """
        if isinstance(image, (str, Path)):
            image = load_image(Path(image))
        h_original, w_original = image.shape[-2:]
        # used at the end to crop predictions back to the input image size
        self.cropper = CenterCrop(image.shape[-2:])
        image = image.to(Configuration().device)

        if self.preprocessing is not None:
            # preprocessing works on batches -> add then remove a dummy batch dimension
            preprocessed_image = self.preprocessing(image[None, :])[0]
        else:
            preprocessed_image = image

        # tuple() so that a list patch_size also compares equal to the image size
        if self.patch_size is not None and tuple(self.patch_size) != (
            h_original,
            w_original,
        ):
            preds = self._predict_patches(preprocessed_image, (h_original, w_original))
        else:
            # forward_pass needs a batch of images -> add a dummy batch dimension
            preds = self.forward_pass(preprocessed_image[None, :]).formats[0]

        # recover pre-border-padding shape
        preds, _, _ = preds.apply_augmentation(image, self.cropper)

        if visu_path:
            visualization(
                image, preds, categories=self.categories, save_path=Path(visu_path)
            )
        return preds

    def _predict_patches(
        self, preprocessed_image: Tensor, original_size: Tuple[int, int]
    ) -> BaseFormat:
        """Patchify the image, predict each patch and merge patch predictions with the patchifier."""
        preprocessed_image = self._pad_to_patch_size(preprocessed_image, original_size)
        batch_patch, pad_origins, padded_image, pad_coord = self.patchifier.patchify(
            preprocessed_image
        )
        h_pad, w_pad = padded_image.shape[-2:]
        preds = self.forward_pass(batch_patch)
        # filter empty patches and associated origins
        preds, pad_origins = self.filter_empty_patches(preds, pad_origins)
        return self.patchifier.unpatchify(
            preds, pad_origins, (h_pad, w_pad), pad_coord, original_size
        )

    def _pad_to_patch_size(self, img: Tensor, size: Tuple[int, int]) -> Tensor:
        """Center-pad ``img`` (of spatial ``size``) so each side is at least the patch size on that axis."""
        h, w = size
        h_pad = max(self.patch_size[0] - h, 0)
        w_pad = max(self.patch_size[1] - w, 0)
        if h_pad == 0 and w_pad == 0:
            return img
        top, left = h_pad // 2, w_pad // 2
        # torchvision Pad takes a 4-sequence as (left, top, right, bottom)
        return Pad((left, top, w_pad - left, h_pad - top))(img)

    def filter_empty_patches(
        self, preds_batch_patch: BatchedFormat, pad_origins: List[Tuple[int, int]]
    ):
        """Remove empty patches for unpatchification.

        Args:
            preds_batch_patch (``BatchedFormat``): per-patch predictions.
            pad_origins (``List[Tuple[int, int]]``): patch origins, aligned with the predictions.

        Returns:
            ``Tuple[BatchedFormat, List[Tuple[int, int]]]``: predictions with at least one object
            and their matching origins.
        """
        keep = [bool(pred.nb_object != 0) for pred in preds_batch_patch]
        pad_origins = [origin for origin, kept in zip(pad_origins, keep) if kept]
        # explicit bool dtype: ``torch.tensor([])`` would be a float tensor, not a valid mask
        preds_batch_patch = preds_batch_patch[torch.tensor(keep, dtype=torch.bool)]
        return preds_batch_patch, pad_origins
class PredictorDataLoader:
    """Wrap predictor patchification output as a loader of mini-batches for the forward pass.

    The first dimension of ``patches`` is split into consecutive chunks of ``batch_size``
    items (the last chunk may be smaller). Chunks are slices of ``patches``.

    Args:
        patches (``Tensor``): stacked patches / images, batched along the first dimension.
        batch_size (``int``, **optional**): maximum number of items per chunk. Defaults to 1.

    Raises:
        ValueError: if ``batch_size`` is lower than 1.
    """

    def __init__(self, patches: Tensor, batch_size: int = 1):
        if batch_size < 1:
            # 0 would divide by zero and a negative size would silently produce no batch at all
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        n_items = patches.shape[0]
        self.batchs = [
            patches[start : start + batch_size]
            for start in range(0, n_items, batch_size)
        ]

    def __iter__(self):
        return iter(self.batchs)

    def __len__(self):
        return len(self.batchs)