import torch
import torch.nn as nn
import timm
from ultralytics.nn.modules.head import Detect, Segment
from deepvisiontools.models.basemodel import BaseModel
import deepvisiontools.models.timmyolo.utils as ut
from ultralytics.utils.loss import (
v8DetectionLoss,
TaskAlignedAssigner,
BboxLoss,
)
from deepvisiontools.config import Configuration
from deepvisiontools.formats import (
BatchedFormat,
BboxFormat,
InstanceMaskFormat,
BboxData,
)
from torchvision.transforms.v2 import Pad
from math import ceil
from torch import Tensor
import deepvisiontools.models.timmyolo.errors as er
from dataclasses import dataclass
from typing import Tuple, List, Dict
@dataclass
class LossHyperParams:
    """Loss gains handed to ``YoloLoss`` through the model's ``args`` attribute.

    The default values are the standard ultralytics YOLO gains.
    """

    box: float = 7.5  # gain of the box-regression term
    dfl: float = 1.5  # gain of the distribution-focal-loss term
    cls: float = 0.5  # gain of the classification term
class YoloLoss(v8DetectionLoss):
    def __init__(self, model, tal_topk=10):
        """Adapting Yolo Loss to TimmYolo.

        The parent initializer is intentionally not called. Every attribute the
        ultralytics detection loss relies on is populated here instead, from the
        TimmYolo model's ``detect_head`` (a ``Detect`` module) and its ``args``
        (loss gains).

        Args:
            model: TimmYolo-like model exposing ``detect_head`` and ``args``.
            tal_topk (int, optional): top-k used by the task-aligned assigner.
                Defaults to 10.
        """
        head = model.detect_head
        dev = Configuration().device  # tensors below are created on this device
        n_cls = head.nc
        reg_max = head.reg_max

        self.device = dev
        self.hyp = model.args
        self.bce = nn.BCEWithLogitsLoss(reduction="none")
        self.stride = head.stride
        self.nc = n_cls
        self.reg_max = reg_max
        # outputs per anchor: class scores + 4 box sides x reg_max bins
        self.no = n_cls + 4 * reg_max
        self.use_dfl = reg_max > 1
        self.assigner = TaskAlignedAssigner(
            topk=tal_topk, num_classes=n_cls, alpha=0.5, beta=6.0
        )
        self.bbox_loss = BboxLoss(reg_max).to(dev)
        self.proj = torch.arange(reg_max, dtype=torch.float, device=dev)
class TimmYolo(BaseModel):
    """Combine a timm encoder (any model supporting ``features_only=True``) with a YOLO detection head.

    This leverages complex encoders, potentially with attention layers, while remaining
    flexible on the input image size. Every image that runs through the model is split
    into patches of the backbone's native input size; features are predicted per patch,
    merged back into full-image feature maps and fed to the fully convolutional YOLO
    detection head.

    **Note:** This model does not have a usable ``forward`` method! Use ``run_forward()``,
    ``run()`` or ``get_predictions()`` instead.

    Args:
        backbone_name (str, optional): timm backbone. Defaults to "swin_small_patch4_window7_224".
            Has been tested with "vit_large_patch14_dinov2" and "resnet50.a1_in1k" as well.
        num_classes (int, optional): Defaults to 1.
        pretrained (bool, optional): Defaults to True.
        overlap (float | int | Tuple[int, int] | None, optional): If not None, the overlap in
            pixels between neighbouring patches (careful: it must be compatible with the
            reduction level). If None, twice the maximum reduction factor is used.
            Defaults to None.
        internal_batch_size (int, optional): Number of patches to run simultaneously.
            Defaults to 1. NOTE(review): stored but currently not read anywhere in this
            class; all patches of one image are run through the backbone at once.
        loss_factor (int, optional): Extra divisor applied to the loss to keep it roughly
            in the unit range (important for mixed precision). Defaults to 100.
    """

    def __init__(
        self,
        backbone_name="swin_small_patch4_window7_224",
        num_classes=1,
        pretrained: bool = True,
        overlap: float | int | Tuple[int, int] | None = None,
        internal_batch_size: int = 1,
        loss_factor=100,
    ):
        super().__init__()
        # Load the encoder from timm as a multi-scale feature extractor.
        self.backbone = timm.create_model(
            backbone_name, pretrained=pretrained, features_only=True
        )
        # Patches are cut at the backbone's pretrained (H, W) input size.
        self.patch_size = self.backbone.pretrained_cfg["input_size"][-2:]
        # Channels of each feature map returned by the backbone.
        feature_channels = self.backbone.feature_info.channels()
        self.feature_channels = feature_channels
        # Spatial reduction factor of each feature map.
        self.reds = [it["reduction"] for it in self.backbone.feature_info.info]
        # Some timm transformers report more reduction entries than feature maps:
        # keep exactly one reduction per returned feature map.
        if len(self.reds) > len(self.feature_channels):
            self.reds = self.reds[: len(self.feature_channels)]
        if overlap is None:
            overlap = max(self.reds) * 2
        # YOLO-like detection head (fully convolutional).
        detect_head = Detect(nc=num_classes, ch=feature_channels)
        detect_head.dynamic = True
        self.detect_head = detect_head
        self.detect_head.stride = torch.tensor(self.reds, dtype=torch.float)
        self.overlap = overlap
        self.patcher = ut.Patcher(self.patch_size, self.overlap)
        # YoloLoss reads self.args and self.detect_head, so both must exist first.
        self.args = LossHyperParams()
        self.loss = YoloLoss(self)
        self.internal_batch_size = internal_batch_size
        self.loss_factor = loss_factor

    def prepare_target(
        self, targets: BatchedFormat, img_size: Tuple[int, int]
    ) -> Dict[str, Tensor]:
        """Convert targets from BatchedFormat to the ultralytics YOLO target format.

        Instance masks are first converted to bounding boxes. Boxes are then switched
        to CXCYWH and normalized with ``img_size``.

        Args:
            targets (BatchedFormat): Targets of the batch. When it contains no instance
                masks, its bbox format is switched to CXCYWH in place.
            img_size (Tuple[int, int]): Image size used to normalize the boxes.

        Returns:
            Dict[str, Tensor]: target as per ultralytics Yolo format, with keys
            ``"batch_idx"`` (image index of each object, shape (N, 1)), ``"cls"``
            (labels with a trailing singleton dim) and ``"bboxes"`` (normalized boxes).
        """
        if any(isinstance(targ, InstanceMaskFormat) for targ in targets.formats):
            targets = BatchedFormat(
                [
                    (
                        BboxFormat.from_instance_mask(targ)
                        if isinstance(targ, InstanceMaskFormat)
                        else targ
                    )
                    for targ in targets
                ]
            )
        targets.set_bboxes_format("CXCYWH")
        boxes = ut.normalize_boxes(
            torch.cat([targ.data.value for targ in targets]), img_size
        )
        labels = torch.cat([targ.labels for targ in targets])[..., None]
        # One float image index per object, as expected by the ultralytics loss.
        batch_idx = torch.cat(
            [torch.ones(targ.nb_object) * i for i, targ in enumerate(targets)]
        )[..., None]
        batch_idx = batch_idx.to(Configuration().device)
        return {"batch_idx": batch_idx, "cls": labels, "bboxes": boxes}

    def prepare(
        self, images, targets=None
    ) -> Tuple[Tensor, Dict[str, Tensor] | None, torch.Size]:
        """Pad images (and targets) so the last patch matches the image border exactly.

        Args:
            images (Tensor): Batch of images; the last two dims are (H, W).
            targets (BatchedFormat, optional): Targets to be prepared. Defaults to None.

        Returns:
            Tuple[Tensor, Dict[str, Tensor] | None, torch.Size]: padded images, targets
            in ultralytics format (or None when no targets were given), original (H, W).
        """
        original_size = images.shape[-2:]
        l, t, r, b = self.patcher.gen_pad_requirements(original_size)
        # torchvision Pad takes (left, top, right, bottom).
        prepared_images = Pad((l, t, r, b))(images)
        if targets is None:
            return prepared_images, None, original_size
        padded_targets = BatchedFormat([ta.pad(l, t, r, b)[0] for ta in targets])
        prepared_targets = self.prepare_target(
            padded_targets, prepared_images.shape[-2:]
        )
        return prepared_images, prepared_targets, original_size

    def build_results(
        self,
        raw_outputs: List[Tensor],
        prebuild_outputs: Tensor,
        original_img_size: torch.Size,
    ) -> BatchedFormat:
        """Transform model outputs into a batch of BboxFormat predictions.

        Args:
            raw_outputs (List[Tensor]): Raw per-level outputs of the detection head.
            prebuild_outputs (Tensor): Decoded boxes from the head in eval mode, one entry
                per image along dim 0; each entry is (4 + classes, N) and is transposed here.
            original_img_size (torch.Size): (H, W) of the images before padding.

        Returns:
            BatchedFormat: Batched predictions, cropped back to the original image size.
        """
        h, w = self.retrieve_spatial_size(raw_outputs)
        max_detection = Configuration().model_max_detection
        results = []
        for prediction in prebuild_outputs.unbind():
            # (4 + classes, N) -> (N, 4 + classes)
            prediction = prediction.permute(1, 0)
            # Best class index and its score (first maximum on ties).
            confidence, best_class = torch.max(prediction[:, 4:], dim=1)
            boxes = BboxData(prediction[:, :4], "CXCYWH", (h, w))
            result = BboxFormat(boxes, best_class, scores=confidence)
            # Object selection: confidence threshold, NMS, then cap the count.
            result = ut.confidence_filter(result)
            result = ut.box_nms_filter(result)
            result, _ = result[:max_detection]
            results.append(result)
        results = BatchedFormat(results)
        # Crop the padded canvas back to the original image size.
        # NOTE(review): assumes prepare() split the padding evenly with the floor on
        # top/left — confirm against Patcher.gen_pad_requirements.
        cv_size = results.formats[0].canvas_size
        t = int((float(cv_size[0]) - original_img_size[0]) / 2.0)
        l = int((float(cv_size[1]) - original_img_size[1]) / 2.0)
        h, w = original_img_size
        return BatchedFormat([p.crop(t, l, h, w)[0] for p in results])

    def _patchify(
        self, prepared_images: Tensor, original_img_size: torch.Size
    ) -> List[Tuple[Tensor, List[Tuple[int, int, int, int]]]]:
        """Split every padded image into stacked patches plus their absolute positions."""
        patched_imgs = []
        for prep_img in prepared_images:
            patches, absolute_positions = self.patcher(prep_img, original_img_size)
            patched_imgs.append((torch.stack(patches), absolute_positions))
        return patched_imgs

    def _feature_map_shapes(self, prepared_images: Tensor) -> List[Tuple[int, ...]]:
        """(H, W) of each merged full-image feature map: ceil(padded size / reduction)."""
        return [
            tuple(ceil(float(s) / r) for s in prepared_images.shape[-2:])
            for r in self.reds
        ]

    def get_predictions(self, images: Tensor) -> BatchedFormat:
        """Run inference on a batch of images and return the predictions.

        Side effect: switches the model to eval mode (``self.eval()``).

        Args:
            images (Tensor): Batch of images; the last two dims are (H, W).

        Returns:
            BatchedFormat: Predictions at the original image size.
        """
        self.eval()
        prepared_images, _, original_img_size = self.prepare(images, targets=None)
        patched_imgs = self._patchify(prepared_images, original_img_size)
        weights_shapes = self._feature_map_shapes(prepared_images)
        prebuild_output, raw_outputs = self.run(patched_imgs, weights_shapes)
        return self.build_results(raw_outputs, prebuild_output, original_img_size)

    def run_forward(self, images, targets):
        """Compute the loss, and in eval mode also the predictions, for a batch.

        Args:
            images (Tensor): Batch of images; the last two dims are (H, W).
            targets (BatchedFormat): Ground truth of the batch.

        Returns:
            The loss dict in training mode; ``(loss_dict, predictions)`` otherwise.
        """
        prepared_images, prepared_targets, original_img_size = self.prepare(
            images, targets=targets
        )
        patched_imgs = self._patchify(prepared_images, original_img_size)
        weights_shapes = self._feature_map_shapes(prepared_images)
        if self.training:
            # In training mode the head returns only the raw per-level outputs.
            raw_outputs = self.run(patched_imgs, weights_shapes)
            return self.compute_loss(raw_outputs, prepared_targets)
        # In eval mode the head returns (decoded boxes, raw per-level outputs).
        prebuild_output, raw_outputs = self.run(patched_imgs, weights_shapes)
        loss_dict = self.compute_loss(raw_outputs, prepared_targets)
        predictions = self.build_results(
            raw_outputs, prebuild_output, original_img_size
        )
        return loss_dict, predictions

    @staticmethod
    def _to_nchw(feat: Tensor, channels: int) -> Tensor:
        """Permute a 4D backbone feature map so its channel dim is dim 1.

        Only dims 1..3 are searched: dim 0 is the patch batch, whose size may
        coincidentally equal the channel count.
        """
        matches = [d for d in (1, 2, 3) if feat.shape[d] == channels]
        if len(matches) != 1:
            raise RuntimeError(
                f"Cannot locate the channel dim ({channels}) in feature map of shape "
                f"{tuple(feat.shape)}: {len(matches)} candidate dims."
            )
        ch_dim = matches[0]
        others = [d for d in (1, 2, 3) if d != ch_dim]
        return feat.permute(0, ch_dim, *others)

    def run(
        self,
        patched_imgs: List[Tuple[Tensor, List[Tuple[int, int, int, int]]]],
        weights_shapes: List[Tuple],
    ):
        """Run backbone + feature merging on every image, then the detection head.

        Args:
            patched_imgs: per image, (stacked patches, absolute patch positions).
            weights_shapes: (H, W) of each merged feature map.

        Returns:
            Output of ``self.detect_head`` on the batched merged feature maps.
        """
        per_level: List[List[Tensor]] = [[] for _ in range(len(self.reds))]
        for patches, abs_pos in patched_imgs:
            # Per-level coverage counters, filled in by merge_features.
            weights: List[Tensor] = [torch.zeros(ws) for ws in weights_shapes]
            feats = self.backbone(patches)
            # Some backbones return channels-last maps: bring them to (N, C, h, w).
            feats = [
                self._to_nchw(f, ch) for f, ch in zip(feats, self.feature_channels)
            ]
            merged = self.merge_features(feats, abs_pos, weights)
            for k, level in enumerate(per_level):
                level.append(merged[k])
        # Stack images along a new batch dim for each level.
        full_feats = [torch.stack(level) for level in per_level]
        return self.detect_head(full_feats)

    def prepare_weights(self, prepared_images: Tensor):
        """Return zero-filled weight maps, one per reduction level.

        NOTE(review): not called in this class; it floors (int) the sizes whereas
        the feature-map shapes used by run() are ceiled.
        """
        dims = [
            tuple((int(i / r) for i in prepared_images.shape[-2:])) for r in self.reds
        ]
        feat_weights = [torch.zeros(shape) for shape in dims]
        return feat_weights

    def merge_features(
        self,
        feats: List[Tensor],
        absolute_positions: List[Tuple[int, int, int, int]],
        weights: List[Tensor],
    ):
        """Stitch per-patch feature maps into full-image feature maps.

        Overlapping regions are averaged over the number of patches covering them.

        Args:
            feats: one (N_patches, C, h, w) tensor per level, in the same order as
                ``self.reds``.
            absolute_positions: per patch, (top, left, height, width) in pixels.
            weights: per level, a zero (H, W) coverage counter; incremented in place.

        Returns:
            List[Tensor]: one (C, H, W) merged feature map per level.
        """
        merged_feats = []
        for feat, weight, r in zip(feats, weights, self.reds):
            merged = torch.zeros(
                (feat.shape[1], *weight.shape[-2:]), device=feat.device
            )
            for patch_feat, abs_pos in zip(feat, absolute_positions):
                y1 = int(abs_pos[0] / r)
                x1 = int(abs_pos[1] / r)
                y2 = int((abs_pos[0] + abs_pos[2]) / r)
                x2 = int((abs_pos[1] + abs_pos[3]) / r)
                weight[y1:y2, x1:x2] += 1
                merged[:, y1:y2, x1:x2] += patch_feat
            # Cells no patch covers stay 0 instead of becoming 0/0 = NaN.
            merged /= weight.clamp(min=1).to(merged.device)
            merged_feats.append(merged)
        return merged_feats

    def forward(self, x):
        raise er.NonImplementedForward()

    def retrieve_spatial_size(self, raw_outputs: List[Tensor]) -> Tuple[int, int]:
        """Retrieve image shape from raw_outputs and stride values.

        Args:
            raw_outputs (List[Tensor]): Raw outputs from the YOLO head.

        Returns:
            Tuple[int, int]: Size of the (padded) input image (H, W), computed from the
            first level's spatial size times its stride.
        """
        h = int(raw_outputs[0].shape[-2] * self.detect_head.stride[0])
        w = int(raw_outputs[0].shape[-1] * self.detect_head.stride[0])
        return (h, w)

    def compute_loss(
        self, raw_outputs: List[Tensor], targets: Dict[str, Tensor]
    ) -> Dict[str, Tensor]:
        """Compute loss with predictions & targets.

        The ultralytics loss multiplies the total by the batch size but returns the
        (box, cls, dfl) items unscaled. Here the total is divided by the batch size
        and ``loss_factor``, and the items by ``loss_factor``. As a result,
        ``loss == loss_box + loss_cls + loss_dfl`` and the values stay near the unit
        range (important for mixed precision).

        Args:
            raw_outputs (List[Tensor]): Raw per-level outputs of the head; dim 0 is the
                number of images.
            targets (Dict[str, Tensor]): Targets in YOLO format (see ``prepare_target``).

        Returns:
            Dict[str, Tensor]: Loss dict with total loss (key: "loss") & sublosses.
        """
        # Real batch size, including images without objects (never 0).
        batch_size = max(int(raw_outputs[0].shape[0]), 1)
        loss, loss_detail = self.loss(raw_outputs, targets)
        # Depending on the ultralytics version, the returned loss is either already
        # summed or a (box, cls, dfl) vector; .sum() yields a scalar in both cases.
        # Out-of-place ops: loss_detail may share storage with graph tensors.
        loss = loss.sum() / (self.loss_factor * batch_size)
        loss_detail = loss_detail / self.loss_factor
        return {
            "loss": loss,
            "loss_box": loss_detail[0],
            "loss_cls": loss_detail[1],
            "loss_dfl": loss_detail[2],
        }