from ultralytics.cfg import get_cfg
from ultralytics.nn.tasks import SegmentationModel, attempt_load_one_weight
from ultralytics.utils import DEFAULT_CFG
from deepvisiontools.models.basemodel import BaseModel
from typing import Literal
from deepvisiontools import Configuration
import deepvisiontools.models.yoloseg.errors as er
import deepvisiontools.models.yoloseg.utils as ut
from deepvisiontools.formats import (
BatchedFormat,
BboxData,
BboxFormat,
InstanceMaskFormat,
InstanceMaskData,
)
import torch
from typing import Dict, Tuple, Union, List
from torch import Tensor
import torchvision.transforms.v2.functional as F
[docs]
class YoloSeg(SegmentationModel, BaseModel):
    """Yolo segmentation model. data_type must be either bbox or instance_mask to use this model.

    Args:
        architecture (``Literal["yolo11n-seg", "yolo11m-seg", "yolo11l-seg", "yolo11x-seg", "yolov8n-seg", "yolov8m-seg", "yolov8l-seg", "yolov8x-seg"]``, **optional**): Yolo
            segmentation architecture name, used to locate ``<architecture>.yaml`` and
            ``<architecture>.pt``. Must contain "-seg". Defaults to "yolov8n-seg".
        pretrained (``bool``, **optional**): Use pretrained weights. Defaults to True.
        reg_max (``int``, **optional**): reg_max argument of yolo models (impacts object size detection). See ultralytics for more information. Defaults to 16.
        loss_factor (``float``, **optional**): divide yolo loss value (important for mixed precision to keep it below a certain range). Defaults to 1.

    Attributes:
        - criterion : Yolo loss from ultralytics, as returned by ``init_criterion()``.
        - args (``Any``) : ultralytics Yolo's configuration params.
        - pad_requirements (``int``) : pad requirements for yoloseg (image shape must be a multiple of 32).
        - mask_logit_threshold (``float``) : mask threshold to consider if a pixel is object or background. Default is 0.5 but can be changed.
        - loss_factor (``float``) : divisor applied to the loss values in ``compute_loss``.

    Properties:
        - device (``Literal["cuda", "cpu"]``): model's device
    """

    def __init__(
        self,
        architecture: Literal[
            "yolo11n-seg",
            "yolo11m-seg",
            "yolo11l-seg",
            "yolo11x-seg",
            "yolov8n-seg",
            "yolov8m-seg",
            "yolov8l-seg",
            "yolov8x-seg",
        ] = "yolov8n-seg",
        pretrained: bool = True,
        reg_max: int = 16,
        loss_factor: float = 1.0,
        *args,
        **kwargs,
    ):
        er.check_config()
        assert (
            "-seg" in architecture
        ), f"architecture must be one of [yolon-seg, yolom-seg, yolol-seg, yolox-seg] to use YoloSeg. Got {architecture} (probably forgot the -seg)"
        config = Configuration()
        # Extra positional args are forwarded right after the yaml path, exactly
        # as before; nc always comes from the global Configuration.
        super().__init__(
            f"{architecture}.yaml", *args, nc=config.num_classes, **kwargs
        )
        self.args = get_cfg(DEFAULT_CFG)
        # NOTE(review): this only overrides the attribute on the already-built
        # head; confirm that non-default reg_max values are actually supported.
        self.model[-1].reg_max = reg_max
        if pretrained:
            # attempt_load_one_weight returns a sequence whose first item is the
            # loaded weights object expected by self.load().
            loaded = attempt_load_one_weight(
                f"{architecture}.pt",
            )
            self.load(loaded[0])
        self.criterion = self.init_criterion()
        # Goes through the device setter below: moves the model and rebuilds
        # the criterion.
        self.device = config.device
        self.pad_requirements = 32
        self.mask_logit_threshold = 0.5
        self.loss_factor = loss_factor

    # overwrite
    @property
    def device(self):
        """Device the model was last moved to through the ``device`` setter."""
        return self._device

    # overwrite
    @device.setter
    def device(self, val):
        """Move the model to ``val`` and rebuild the criterion afterwards."""
        self.to(val)
        # Store the value so the getter has something to return; the visible
        # code never assigned ``_device`` otherwise.
        self._device = val
        # The criterion is re-created after every move, as in the original code.
        self.criterion = self.init_criterion()

    def prepare_target(self, targets: BatchedFormat) -> Dict[str, Tensor]:
        """Transform SegmentationFormat targets into yolo-seg targets format.

        Args:
            targets (``BatchedFormats``): Batch targets.

        Returns:
            ``Dict[str, Tensor]``:
                - Targets in YOLO format with keys "masks", "bboxes", "cls" and "batch_idx".

        Raises:
            AssertionError: If the number of boxes, instances and labels disagree.
        """
        # Create batched stacked mask : (N_batch, H, W)
        targets = BatchedFormat([t.sanitize()[0] for t in targets])
        masks = torch.stack([t.data.value for t in targets])
        # Create bboxes from the instance masks
        boxes_batched = BatchedFormat(
            [BboxFormat.from_instance_mask(t) for t in targets]
        )
        boxes_batched.set_bboxes_format("CXCYWH")
        # normalize boxes with the (H, W) of the stacked masks
        spatial_size = masks.shape[-2:]
        boxes = torch.cat(
            [ut.normalize_boxes(b.data.value, spatial_size) for b in boxes_batched]
        )
        # extract labels
        labels = torch.cat([t.labels.long() for t in targets])
        # one image index per object, so every object can be mapped to its image
        images_indices = torch.cat(
            [torch.full((t.nb_object,), i) for i, t in enumerate(targets)]
        ).to(targets.device)
        # check if all values are compatible
        n_boxes = boxes.shape[0]
        # The highest value of each mask is used as its instance count
        # (presumably sanitize() reindexes instances as 1..N; verify).
        n_instances = int(sum(torch.max(m) for m in masks))
        n_labels = labels.shape[0]
        assert (n_boxes == n_instances) and (
            n_boxes == n_labels
        ), "Error in preparing target for YoloSeg : one or multiple of (N_boxes, N_instances, N_labels) is different. You may need to increase mask_min_size threshold in Configuration()"
        # put labels and batch_idx in yolo format : Tensor (N, 1)
        return {
            "masks": masks,
            "bboxes": boxes,
            "cls": labels[:, None],
            "batch_idx": images_indices[:, None],
        }

    def prepare(
        self, images: Tensor, targets: Union[BatchedFormat, None] = None
    ) -> Union[Tuple[Tensor, Dict], Tensor]:
        """Pad image / targets to fit yolo divisibility by 32 criterium and move targets to yolo format.

        If no targets passed simply returns images.

        Args:
            images (``Tensor``): batched images [N, 3, H, W]
            targets (``Union[BatchedFormat, None]``): Batch targets, or None.

        Returns:
            ``Union[Tuple[Tensor, Dict], Tensor]``:
                - Either : images_padded, yolo_targets OR images_padded
        """
        height, width = images.shape[-2], images.shape[-1]
        top, left, right, bottom = ut.yolo_pad_requirements(
            height, width, required=self.pad_requirements
        )
        # Note the inversion for torchvision pad coordinates: it expects
        # [left, top, right, bottom].
        images = F.pad(images, [left, top, right, bottom])
        # Identity check: a comparison operator could be overloaded by the
        # targets container.
        if targets is None:
            return images
        padded_targets = BatchedFormat(
            [targ.pad(top, left, bottom, right)[0] for targ in targets]
        )
        return images, self.prepare_target(padded_targets)

    def prebuild_output(self, raw_outputs: Tuple[Tensor, ...]) -> Tuple[Tensor, ...]:
        """Unpack Yolo-seg (eval mode) raw results.

        Args:
            raw_outputs (``Tuple[Tensor, ...]``): Yolo raw eval mode results.

        Returns:
            ``Tuple[Tensor, ...]``:
                - boxes (N_batch, N_obj, cxcywh).
                - cls_scores (N_batch, N_obj, N_cls).
                - mask_weights (N_batch, N_obj, 32).
                - protos : third element of the second raw output.
        """
        output0, output1 = raw_outputs
        output0 = output0.permute(0, 2, 1)  # permute in N_batch, N_obj, obj_length
        # Layout of the last axis: 4 box coords, then class scores, and the
        # last 32 values are the mask weights.
        cls_end = 4 + Configuration().num_classes
        boxes = output0[:, :, 0:4]
        cls_scores = output0[:, :, 4:cls_end]
        mask_weights = output0[:, :, -32:]
        protos = output1[2]
        return boxes, cls_scores, mask_weights, protos

    def build_results(
        self, raw_outputs: Tuple[Tensor, ...], get_logit: bool = False
    ) -> BatchedFormat:
        """Transform model outputs into Batch InstanceMaskFormat for results.

        Per image: keep the best class per candidate, filter by confidence, apply
        box NMS, keep at most ``model_max_detection`` objects, build the masks from
        the prototypes, threshold them and drop empty ones.

        Args:
            raw_outputs (``Tuple[Tensor, ...]``): Model outputs (eval mode).
            get_logit (``bool``, **optional**): Currently unused. Defaults to False.

        Returns:
            ``BatchedFormats``:
                - Batched predictions.
        """
        # extract info from raw results
        boxes, cls_scores, mask_weights, protos = self.prebuild_output(raw_outputs)
        spatial_size = self.retrieve_spatial_size(raw_outputs)
        # read the thresholds once instead of once per image
        config = Configuration()
        confidence_threshold = config.model_confidence_threshold
        model_max_detection = config.model_max_detection
        results = []
        for image_boxes, image_cls_scores, image_mask_weights, image_protos in zip(
            boxes, cls_scores, mask_weights, protos
        ):
            # get best class and corresponding score
            image_cls_scores, best_class = torch.max(image_cls_scores, dim=1)
            # filter by confidence thr
            conf_filter = ut.confidence_filter(image_cls_scores, confidence_threshold)
            image_cls_scores = image_cls_scores[conf_filter]
            image_labels = best_class[conf_filter]
            image_boxes = image_boxes[conf_filter]
            image_mask_weights = image_mask_weights[conf_filter]
            # if no objects with good confidence move to next image
            if image_labels.nelement() == 0:
                results.append(InstanceMaskFormat.empty(spatial_size))
                continue
            # Apply nms from boxes, then propagate the selection to all values
            nms_indexes = ut.box_nms_filter(
                BboxData(image_boxes, "CXCYWH", spatial_size), image_cls_scores
            )
            image_boxes = image_boxes[nms_indexes]
            image_cls_scores = image_cls_scores[nms_indexes]
            image_mask_weights = image_mask_weights[nms_indexes]
            image_labels = image_labels[nms_indexes]
            # Keep only the model_max_detection best scored elements. Compare
            # the number of boxes (rows), not the number of tensor elements,
            # which is 4x larger.
            if image_boxes.shape[0] > model_max_detection:
                keep = torch.argsort(image_cls_scores)[-model_max_detection:]
                image_boxes = image_boxes[keep]
                image_cls_scores = image_cls_scores[keep]
                image_mask_weights = image_mask_weights[keep]
                image_labels = image_labels[keep]
            # compute binary masks per remaining obj
            xyxy_boxes = BboxData(image_boxes, "CXCYWH", spatial_size)
            xyxy_boxes.format = "XYXY"  # change box format
            image_masks = ut.proto2mask(
                image_protos, image_mask_weights, xyxy_boxes.value, spatial_size
            )
            # apply thresholding to mask (value > threshold belongs to object)
            image_masks = image_masks.gt_(self.mask_logit_threshold)
            # Filter out masks left without any pixel after thresholding
            non_empty = image_masks.flatten(1).any(dim=1)
            image_cls_scores = image_cls_scores[non_empty]
            image_labels = image_labels[non_empty]
            image_masks = image_masks[non_empty]
            image_instance_mask = InstanceMaskData.from_binary_masks(image_masks)
            mask_format = InstanceMaskFormat(
                image_instance_mask, image_labels, image_cls_scores
            )
            # Sanitize will reindex objects thus removing empty masks
            mask_format, _ = mask_format.sanitize()
            results.append(mask_format)
        if len(results) == 0:
            results = [InstanceMaskFormat.empty(spatial_size)]
        return BatchedFormat(results)

    def compute_loss(self, predictions: Tuple, target: Dict) -> Dict[str, Tensor]:
        """Compute loss with predictions & targets.

        Args:
            predictions (``Any``): Raw output of model.
            target (``Dict[Any, Any]``): Targets in YOLO format.

        Returns:
            ``Dict[str, Tensor]``:
                - Loss dict with total loss (key: "loss") & sublosses
                  ("loss_box", "loss_seg", "loss_cls", "loss_dfl").
        """
        # Single criterion call: the previous implementation evaluated it twice
        # and discarded the first result.
        loss, loss_detail = self.criterion(predictions, target)
        # yolo scale loss with batch size -> normalize it here and apply loss factor to keep it in the unit range
        # (for mixed precision optim it's important)
        # The normalisation uses the number of distinct image indices in the
        # targets, floored at 1 so a batch without any object does not divide
        # by zero.
        batch_factor = max(target["batch_idx"].unique().shape[0], 1)
        scale = self.loss_factor * batch_factor
        loss = loss / scale
        loss_detail = loss_detail / scale
        # Sub-loss order: box, seg, cls, dfl.
        return {
            "loss": loss,
            "loss_box": loss_detail[0],
            "loss_seg": loss_detail[1],
            "loss_cls": loss_detail[2],
            "loss_dfl": loss_detail[3],
        }

    def retrieve_spatial_size(self, raw_outputs: List[Tensor]) -> Tuple[int, int]:
        """Retrieve image shape from raw_outputs and stride values.

        Args:
            raw_outputs (``List[Tensor]``): Raw outputs from YOLO model.

        Returns:
            ``Tuple[int, int]``:
                - Size of input image (H, W).
        """
        # The raw outputs are nested differently in train and eval mode; in both
        # cases the selected tensor is scaled back with the first stride.
        if self.training:
            feature_map = raw_outputs[0][0]
        else:
            feature_map = raw_outputs[1][0][0]
        stride = self.stride[0]
        h = int(feature_map.shape[-2] * stride)
        w = int(feature_map.shape[-1] * stride)
        return (h, w)

    def run_forward(
        self,
        images: Tensor,
        targets: BatchedFormat,
    ) -> Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], BatchedFormat]]:
        """Run a forward pass and compute the loss; in eval mode also return predictions.

        Args:
            images (``Tensor``): Batch RGB images.
            targets (``BatchedFormat``): Batch targets.

        Returns:
            ``Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], BatchedFormat]]``:
                - Loss dict when the model is in training mode.
                - (Loss dict, predictions) when the model is in eval mode.
        """
        # prepare inputs
        prepared_images, prepared_targets = self.prepare(images, targets=targets)
        # run forward pass
        raw_outputs = self(prepared_images)
        # compute loss
        loss_dict = self.compute_loss(raw_outputs, prepared_targets)
        if self.training:
            return loss_dict
        # eval mode: also return predictions
        predictions = self.build_results(raw_outputs)
        # retrieve the padding from original img
        ori_h, ori_w = images.shape[-2:]
        top, left, _, _ = ut.yolo_pad_requirements(
            ori_h, ori_w, required=self.pad_requirements
        )
        # crop to original size
        predictions = BatchedFormat(
            [pred.crop(top, left, ori_h, ori_w)[0] for pred in predictions]
        )
        return loss_dict, predictions

    def get_predictions(self, images: Tensor) -> BatchedFormat:
        """Prepare images, Apply YOLO forward pass and build results.

        Switches the model to eval mode as a side effect.

        Args:
            images (``Tensor``): RGB images Tensor.

        Returns:
            ``BatchedFormats``:
                - Predictions for images as BatchedFormats.
        """
        self.eval()
        # get original spatial size
        ori_h, ori_w = images.shape[-2:]
        # pad coord to return back to the original (non padded) size afterward
        top, left, _, _ = ut.yolo_pad_requirements(
            ori_h, ori_w, required=self.pad_requirements
        )
        # pad images
        padded_images = self.prepare(images)
        # predict
        raw_outputs = self(padded_images)
        results = self.build_results(raw_outputs)
        # crop back to original spatial size
        return BatchedFormat(
            [pred.crop(top, left, ori_h, ori_w)[0] for pred in results]
        )