# Source code for deepvisiontools.models.smp.smp

import segmentation_models_pytorch as smp
from deepvisiontools.models.basemodel import BaseModel
from typing import Literal, Union, Tuple
import torch
from torch import Tensor
from deepvisiontools.config import Configuration
import deepvisiontools.models.smp.errors as er
import inspect
from segmentation_models_pytorch.base.model import SegmentationModel
from deepvisiontools.formats import SemanticMaskFormat, SemanticMaskData, BatchedFormat
import copy
from deepvisiontools.formats.utils import logit2pred


class _ConcreteSegmentationModel(BaseModel):
    """Concrete implementation of deepvisiontools BaseModel used for SMP models.

    This is used as one of the two parent classes of the dynamically created
    class ``_SMP``.
    """

    def __init__(self, *args, **kwargs):
        """Initialise the model and attach its loss.

        Args:
            *args: Positional arguments forwarded to the parent ``__init__``.
            **kwargs: Keyword arguments forwarded to the parent ``__init__``.
                The optional ``loss`` entry is consumed here; when it is
                missing (or ``None``) a ``smp.losses.FocalLoss`` matching the
                configured number of classes is used.

        Raises:
            AssertionError: If ``loss.mode`` does not match the mode derived
                from ``Configuration().num_classes``.
        """
        mode = "binary" if Configuration().num_classes == 1 else "multiclass"
        loss = kwargs.pop("loss", None)
        if loss is None:
            # Build the default lazily: no throwaway FocalLoss when the caller
            # supplies its own loss.
            loss = smp.losses.FocalLoss(mode=mode)
        super().__init__(*args, **kwargs)
        assert (
            loss.mode == mode
        ), f"Loss has inconsistent mode (multiclass or binary) compared to num_classes. Got {loss.mode} and {Configuration().num_classes}"
        # Assigned after super().__init__() so the parent is initialised first.
        self.loss = loss

    def prepare(self, images, targets=None):
        """Stack per-sample targets into a single batched tensor.

        Args:
            images: Batch of images, returned untouched.
            targets: Iterable of target objects exposing ``.data.value``,
                or ``None`` when no targets are available.

        Returns:
            Tuple ``(images, batched_targets)``; ``batched_targets`` is
            ``None`` when ``targets`` is ``None``.
        """
        if targets is None:
            # The signature advertises targets=None, so honour it instead of
            # failing while iterating over None.
            return images, None
        batched_targets = torch.stack([target.data.value for target in targets])
        return images, batched_targets

    def run_forward(self, images, targets):
        """Run the model on a batch and compute the loss.

        Args:
            images: Batch of images fed to the model.
            targets: Iterable of target objects (see ``prepare``).

        Returns:
            ``{"loss": ...}`` in training mode, otherwise a tuple
            ``(loss_dict, results)`` where ``results`` comes from
            ``build_results``.
        """
        images, batched_targets = self.prepare(images, targets)
        activated = self.activation(self(images))
        # NOTE(review): the loss receives the *activated* output (sigmoid /
        # softmax), not the raw model output. smp losses such as FocalLoss
        # appear to expect raw logits - confirm this is intended.
        loss_dict = {"loss": self.loss(activated, batched_targets)}
        if self.training:
            return loss_dict
        return loss_dict, self.build_results(activated)

    def build_results(self, raw_outputs):
        """Convert model outputs into a ``BatchedFormat`` of semantic masks.

        Args:
            raw_outputs (Tensor): Model outputs, either batched (4 dims) or a
                single sample (a leading batch dimension is then added).

        Returns:
            ``BatchedFormat`` holding one ``SemanticMaskFormat`` per sample,
            each carrying its per-sample output tensor as ``scores``.
        """
        batched = raw_outputs if raw_outputs.ndim == 4 else raw_outputs[None, :]
        # Detached copy: copy.deepcopy raises on non-leaf tensors that require
        # grad, which is what the network returns whenever grad is enabled.
        scores = batched.detach().clone()
        # Transform outputs to preds (binary vs multiclass is handled by
        # logit2pred), one sample at a time.
        pred_masks = torch.stack([logit2pred(sample) for sample in scores])
        # Convert preds to BatchedFormat.
        mask_data = [SemanticMaskData(mask) for mask in pred_masks]
        formats = [
            SemanticMaskFormat(data, scores=sample)
            for data, sample in zip(mask_data, scores)
        ]
        return BatchedFormat(formats)

    def get_predictions(self, images):
        """Predict on a batch of images in evaluation mode.

        Switches the model to ``eval()`` (it is not switched back) and runs
        the forward pass without tracking gradients.

        Args:
            images: Batch of images fed to the model.

        Returns:
            ``BatchedFormat`` of predictions (see ``build_results``).
        """
        self.eval()
        with torch.no_grad():  # inference only: do not build an autograd graph
            activated = self.activation(self(images))
        return self.build_results(activated)

    def activation(self, prediction: Tensor):
        """Return predictions activated by sigmoid (single class) or softmax (multi class).

        Args:
            prediction (Tensor): Raw model output, with 3 dims (single sample)
                or 4 dims (batch).

        Returns:
            Tensor: Activated prediction. Returned unchanged when the
            configured number of classes is lower than 1.

        Raises:
            AssertionError: If ``prediction`` has neither 3 nor 4 dimensions.
        """
        num_classes = Configuration().num_classes
        prediction_dims = prediction.ndim
        assert prediction_dims in [
            3,
            4,
        ], f"Number of dimension in prediction should be 3 or 4 (batch), got {prediction_dims}"
        # Channel axis is 0 without a batch dimension, 1 with one.
        channel = 0 if prediction_dims == 3 else 1
        if num_classes == 1:
            prediction = prediction.sigmoid()
        elif num_classes > 1:
            prediction = prediction.softmax(dim=channel)

        return prediction


class SMP(torch.nn.Module):
    """Factory class that wraps segmentation-models-pytorch (smp) models into deepvisiontools.

    These models are used for semantic segmentation tasks. Using this class
    you can use all available models, encoders and whatever additional
    arguments segmentation-models-pytorch accepts. Please provide further
    parameters as keyword arguments (ex: ``arg=myadditionalarg``).

    Note that you can use any smp loss as well by simply providing an
    instance of an smp loss: ``loss=smp.losses.WantedLoss()``.

    smp: https://github.com/qubvel-org/segmentation_models.pytorch

    Args:
        architecture (``SegmentationModel``, **optional**): SMP model
            architecture: an smp class (type) is expected; if an instance is
            given, its type is used instead. Defaults to smp.Unet.

    Example:

    .. highlight:: python
    .. code-block:: python

        >>> from deepvisiontools.models import SMP
        >>> import segmentation_models_pytorch as smp
        >>> my_model = SMP(smp.Unet, encoder_name="vgg11", loss=smp.losses.FocalLoss(mode="binary"))
    """

    def __new__(cls, architecture: SegmentationModel = smp.Unet, *args, **kwargs):
        """Build and return an instance of a dynamically created model class.

        The returned object is an instance of a class inheriting from both
        ``_ConcreteSegmentationModel`` and ``architecture``; it is not an
        instance of ``SMP`` itself.

        Raises:
            AssertionError: If ``Configuration().data_type`` is not
                ``"semantic_mask"``.
        """
        assert (
            Configuration().data_type == "semantic_mask"
        ), f"Can't use SMP models (used for semantic segmentation) if Configuration().data_type is not semantic_mask. Got {Configuration().data_type}."
        num_cls = Configuration().num_classes
        # To avoid errors whether architecture is smp.Unet or smp.Unet(),
        # ensure that we have a type and not an instance.
        architecture = (
            architecture if isinstance(architecture, type) else type(architecture)
        )
        # The number of classes always comes from the Configuration.
        kwargs["classes"] = num_cls

        # Dynamically create the class combining deepvisiontools behaviour
        # with the requested smp architecture.
        class _SMP(_ConcreteSegmentationModel, architecture):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                # Keep the construction arguments so the model can be rebuilt
                # when unpickled (see __reduce__).
                self.init_args = args
                self.architecture = architecture
                self.init_kwargs = kwargs

            def __reduce__(self):
                # Used by pickle in torch.save: the class is created
                # dynamically, so serialise a recipe to rebuild it.
                return (
                    SMP._rebuild_model,
                    (
                        self.architecture,
                        self.init_args,
                        self.init_kwargs,
                        self.state_dict(),
                    ),
                )

        # Return the newly created instance of the dynamically built class.
        return _SMP(*args, **kwargs)

    @staticmethod
    def _rebuild_model(architecture, args, kwargs, state_dict):
        """Rebuild a model from the data produced by ``_SMP.__reduce__``.

        Args:
            architecture: smp architecture class of the saved model.
            args: Positional arguments used at construction.
            kwargs: Keyword arguments used at construction (left unmodified).
            state_dict: Saved weights, loaded into the rebuilt model.

        Returns:
            The rebuilt model with ``state_dict`` loaded.
        """
        # Work on a copy so the caller's dict is not mutated.
        rebuild_kwargs = dict(kwargs)
        # Not downloading pretrained weights when using torch.load: the saved
        # state_dict is loaded right after construction.
        rebuild_kwargs["encoder_weights"] = None
        model = SMP(architecture, *args, **rebuild_kwargs)
        model.load_state_dict(state_dict)
        return model