from torchvision.ops import box_iou
from deepvisiontools.formats import BaseFormat, BboxFormat, InstanceMaskFormat
import torch
from deepvisiontools import Configuration
from typing import Literal, Tuple
from torch import Tensor
import deepvisiontools.metrics.utils as ut
from torch.nn.functional import one_hot
[docs]
class Matcher:
"""Class that handles the matching of prediction and targets to get tp, fp, fn"""
def __init__(self):
self.iou_th = Configuration().metrics_match_iou_threshold
self._mode: Literal["bbox", "instance_mask"] = (
Configuration().metrics_matcher_type
)
[docs]
def match_pred_target(
self, pred: BaseFormat, targ: BaseFormat
) -> Tuple[int, int, int, Tuple[Tensor, Tensor]]:
"""Matches predictions and targets
Args:
pred (Format)
targ (Format)
Returns:
Tuple[int, int, int, Tuple[Tensor, Tensor]]: tp, fp, fn, (matched_predictions indices, matched_targets_indices)
"""
# compute cross ious
if self._mode == "bbox":
cross_ious = self.match_boxes(pred, targ)
elif self._mode == "instance_mask":
cross_ious = self.match_instance_masks(pred, targ)
#### derive tp, fp, fn and matching indexes
matched_candidates = (
torch.max(cross_ious, dim=1)[0][..., None]
== torch.max(cross_ious, dim=0)[0][None, ...]
).view(cross_ious.shape)
# true positive if iou of max_matchs > iou threshold
tp = torch.sum((matched_candidates > 0) & (cross_ious > self.iou_th))
# false positive: all boxes with no match with targets
fp = torch.sum(pred.nb_object - torch.sum(tp))
# false negative if target has no pred box with iou > threshold
fn = torch.sum(torch.max(cross_ious, dim=0)[0] < 0.5)
pred_idxs, target_idxs = torch.where(
torch.logical_and((matched_candidates > 0), (cross_ious > self.iou_th))
)
# extract indexes
pred_idxs = pred_idxs.tolist() if pred_idxs.nelement() > 0 else []
target_idxs = target_idxs.tolist() if target_idxs.nelement() > 0 else []
match_idxs = (torch.tensor(pred_idxs).long(), torch.tensor(target_idxs).long())
# send back box format to original format
return tp, fp, fn, match_idxs
[docs]
def match_boxes(
self, pred: BaseFormat, targ: BaseFormat
) -> Tuple[int, int, int, Tuple[Tensor, Tensor]]:
"""compute box cross ious for matching"""
assert isinstance(pred, BboxFormat) or isinstance(
pred, InstanceMaskFormat
), "Prediction must be BboxFormat or InstanceMaskFormat to use match_boxes"
assert isinstance(targ, BboxFormat) or isinstance(
targ, InstanceMaskFormat
), "target must be BboxFormat or InstanceMaskFormat to use match_boxes"
# Convert to BboxFormat if needed
if isinstance(pred, InstanceMaskFormat):
pred: BboxFormat = BboxFormat.from_instance_mask(pred)
if isinstance(targ, InstanceMaskFormat):
targ: BboxFormat = BboxFormat.from_instance_mask(targ)
pred.data.format = "XYXY"
targ.data.format = "XYXY"
cross_ious = box_iou(pred.data.value, targ.data.value)
return cross_ious
[docs]
def match_instance_masks(self, pred: InstanceMaskFormat, targ: InstanceMaskFormat):
"""compute instance_mask cross ious for matching"""
assert isinstance(
pred, InstanceMaskFormat
), "Prediction must be InstanceMaskFormat to use match_instance_masks"
assert isinstance(
targ, InstanceMaskFormat
), "target must be InstanceMaskFormat to use match_instance_masks"
# convert pred / target to one hots to use mask_iou func. Remove class 0 encoded and move to cpu for memory usage.
original_pred_device = pred.device
original_targ_device = targ.device
pred.device = "cpu"
targ.device = "cpu"
one_hot_preds = one_hot(pred.data.value).permute(2, 0, 1)[1:]
one_hot_targs = one_hot(targ.data.value).permute(2, 0, 1)[1:]
# apply cross iou and return object to devices
cross_ious = ut.mask_iou(one_hot_preds, one_hot_targs)
pred.device = original_pred_device
targ.device = original_targ_device
cross_ious = cross_ious.to(original_pred_device)
return cross_ious