# Source code for deepvisiontools.models.mask2former.mask2former
from __future__ import annotations
from typing import List, Tuple, Literal, Dict, Union, Any
from deepvisiontools.models.basemodel import BaseModel
from deepvisiontools import Configuration
from deepvisiontools.formats import InstanceMaskFormat, BatchedFormat, InstanceMaskData
from deepvisiontools import Configuration
from transformers import (
Mask2FormerConfig,
Mask2FormerForUniversalSegmentation,
Mask2FormerImageProcessor,
)
from operator import itemgetter
from transformers.models.mask2former.modeling_mask2former import (
Mask2FormerForUniversalSegmentationOutput,
)
import torch
from torch import Tensor
# TODO: solve resolution issues (create a configurable feature map; check whether the model can be adapted to large images in inference mode; have a look at the prediction head). Use a property.
class Mask2Former(Mask2FormerForUniversalSegmentation, BaseModel):
    """Mask2Former instance segmentation model built on the Hugging Face implementation.

    Child class of ``Mask2FormerForUniversalSegmentation``. To use it,
    ``Configuration().data_type`` must be set to ``"instance_mask"``.

    Args:
        pretrain (Literal["large", "medium", "small", "tiny", ""], optional): Pretrained
            architecture to load, looked up in ``size_configs``. An empty string builds the
            model from a fresh ``Mask2FormerConfig`` instead of loading a checkpoint.
            Defaults to "tiny".
        overlap_mask_thr (float, optional): Value forwarded as
            ``overlap_mask_area_threshold`` when post-processing predictions.
            Defaults to 0.8.

    Attributes:
        processor (``Mask2FormerImageProcessor``): Processor used to build model inputs
            and to post-process raw outputs.
        overlap_mask_thr (``float``): See ``overlap_mask_thr`` argument.

    Properties:
        queries (``torch.nn.Embedding``): Query embedding of the transformer module.
            The setter accepts an ``int`` (number of queries, embedding dimension set
            to 256) or a ``Tuple[int, int]`` (number of queries, embedding dimension).

    Notes:
        When used for large image inference, Mask2Former performs worse if it was
        trained on smaller patches. One workaround is to increase the number of
        queries (see the ``queries`` property). Further improvement on this matter
        is under development.
    """

    # Hugging Face checkpoint identifiers for each supported model size.
    size_configs = {
        "large": "facebook/mask2former-swin-large-coco-instance",
        "medium": "facebook/mask2former-swin-base-coco-instance",
        "small": "facebook/mask2former-swin-small-coco-instance",
        "tiny": "facebook/mask2former-swin-tiny-coco-instance",
    }

    def __init__(
        self,
        pretrain: Literal["large", "medium", "small", "tiny", ""] = "tiny",
        overlap_mask_thr: float = 0.8,
    ):
        # The model only handles instance masks: refuse any other data type.
        assert (
            Configuration().data_type == "instance_mask"
        ), f"Configuration().data_type should be 'instance_mask' to construct Mask2Former object, got {Configuration().data_type}"
        # Mask2Former labels include the background, hence one extra label
        # (prepare_target shifts every object label by one and uses 0 for background).
        num_labels = Configuration().num_classes + 1
        if pretrain:
            pretrain_config = Mask2Former.size_configs[pretrain]
            pretrain_model = Mask2FormerForUniversalSegmentation.from_pretrained(
                pretrain_config,
                num_labels=num_labels,
                ignore_mismatched_sizes=True,
            )
            # Adopt the whole state of the loaded model (modules, config, ...)
            # instead of running the parent constructor.
            self.__dict__ = pretrain_model.__dict__
        else:
            # num_labels must be passed by keyword: the first positional parameter
            # of Mask2FormerConfig is the backbone configuration, not a class count.
            super().__init__(Mask2FormerConfig(num_labels=num_labels))
        # Inputs are expected to be already resized / normalized / rescaled.
        self.processor = Mask2FormerImageProcessor(
            do_resize=False, do_normalize=False, do_rescale=False, ignore_index=255
        )
        self.overlap_mask_thr = overlap_mask_thr
        # Original queries of the model (can be changed with the queries property).
        self._queries = self.model.transformer_module.queries_embedder

    @property
    def queries(self):
        """Query embedding currently referenced by this wrapper."""
        return self._queries

    @queries.setter
    def queries(self, embedding: Union[int, Tuple]):
        """Replace the query embeddings of the transformer module with new ones.

        Args:
            embedding (``Union[int, Tuple]``): Number of queries (dimension defaults
                to 256) or a ``(number of queries, dimension)`` tuple.
        """
        if isinstance(embedding, int):
            nb, dim = embedding, 256
        else:
            nb, dim = embedding
        transformer = self.model.transformer_module
        # Create the new embeddings where the replaced one lives so that a model
        # already moved to an accelerator stays on a single device.
        device = transformer.queries_embedder.weight.device
        transformer.queries_embedder = torch.nn.Embedding(nb, dim).to(device)
        transformer.queries_features = torch.nn.Embedding(nb, dim).to(device)
        # Keep the getter pointing at the embedding actually used by the model.
        self._queries = transformer.queries_embedder

    def prepare_target(
        self, target: InstanceMaskFormat
    ) -> Tuple[Tensor, Dict[int, int]]:
        """Prepare a single target in Mask2Former format.

        Args:
            target (``InstanceMaskFormat``): Target to convert.

        Returns:
            ``Tuple[Tensor, Dict[int, int]]``:
                - ``target.data.value``, used as segmentation map.
                - Mapping from instance id (0 to ``target.nb_object``) to label,
                  where id 0 gets label 0 and every object label is shifted by one.
        """
        labels = target.labels
        # Build the background label next to the labels themselves (same device and
        # dtype) so the concatenation works wherever the target is stored.
        background = labels.new_zeros(1)
        shifted_labels = torch.cat([background, labels + 1])
        instance_labels_dict = dict(
            zip(range(target.nb_object + 1), shifted_labels.tolist())
        )
        masks = target.data.value
        return masks, instance_labels_dict

    def prepare(
        self, images: Tensor, targets: Union[BatchedFormat, None] = None
    ) -> Dict[str, Union[Tensor, Dict[Any, Any]]]:
        """Transform images and targets into Mask2Former specific format for prediction & loss computation.

        Args:
            images (``Tensor``): Batch images.
            targets (``BatchedFormat``, **optional**): Batched targets from DetectionDataset.

        Returns:
            ``Dict[str, Union[Tensor, Dict[Any, Any]]]``:
                - Images data prepared for Mask2Former.
                - If targets: images + targets prepared for Mask2Former.
        """
        image_list = list(images.unbind())
        if targets is None:
            return self.processor(images=image_list, return_tensors="pt")
        instance_labels = []
        segmentation_maps = []
        for target in targets:
            assert isinstance(
                target, InstanceMaskFormat
            ), "Target should be instance mask format for Mask2former"
            target_masks, target_dict = self.prepare_target(target)
            instance_labels.append(target_dict)
            segmentation_maps.append(target_masks)
        return self.processor(
            images=image_list,
            segmentation_maps=segmentation_maps,
            instance_id_to_semantic_id=instance_labels,
            return_tensors="pt",
        )

    # override
    def build_results(
        self,
        raw_outputs: Mask2FormerForUniversalSegmentationOutput,
        spatial_size: Tuple[int, int],
    ) -> BatchedFormat:
        """Transform model outputs into BatchedFormat for results.

        Args:
            raw_outputs (``Mask2FormerForUniversalSegmentationOutput``): Mask2Former output.
            spatial_size (``Tuple[int, int]``): Size of original image (H, W).

        Returns:
            ``BatchedFormat``:
                - Model output as BatchedFormat.
        """
        # Process raw output with Mask2Former processor.
        batch_size = raw_outputs.masks_queries_logits.shape[0]
        predictions = self.processor.post_process_instance_segmentation(
            raw_outputs,
            overlap_mask_area_threshold=self.overlap_mask_thr,
            threshold=Configuration().model_confidence_threshold,
            target_sizes=[spatial_size] * batch_size,
        )
        results = []
        for prediction in predictions:
            pred_size = prediction["segmentation"].shape[-2:]
            # Here -1 = non segmented, then 0 - N objects (background included).
            mask: Tensor = prediction["segmentation"].long()
            segments = prediction["segments_info"]
            # Retrieve labels and scores. The explicit dtype keeps labels usable
            # when no segment was returned (an empty list would give a float tensor).
            labels = torch.tensor(
                [segment["label_id"] for segment in segments], dtype=torch.long
            )
            scores = torch.tensor([segment["score"] for segment in segments])
            # Remove objects that own no pixel in the mask. The filter is forced to
            # bool so it stays a valid index even when there is no segment at all.
            present = torch.tensor(
                [bool(torch.any(mask == seg_id)) for seg_id in range(labels.shape[0])],
                dtype=torch.bool,
            )
            labels = labels[present]
            scores = scores[present]
            # Merge background (label 0) with non predicted pixels (-1). The sorted
            # ids found in the mask line up with the filtered labels; every segment
            # predicted as background is merged, whether there are zero, one or many.
            present_ids = mask.unique()
            present_ids = present_ids[present_ids != -1]
            for background_id in present_ids[labels == 0].tolist():
                mask[mask == background_id] = -1
            # Shift labels back to the original classes and drop the background.
            labels = labels - 1
            keep = labels != -1
            labels = labels[keep].to(self.device)
            scores = scores[keep].to(self.device)
            # Create InstanceMaskData instance and handle empty detection.
            mask += 1
            mask, _ = InstanceMaskData(mask)._reindex()
            if mask.nb_object != 0:
                result = InstanceMaskFormat(mask, labels=labels, scores=scores)
                result, _ = result.sanitize()
                results.append(result)
            else:
                results.append(InstanceMaskFormat.empty(pred_size))
        if not results:
            results.append(InstanceMaskFormat.empty(spatial_size))
        return BatchedFormat(results)

    def inputs_to_device(self, input: Any, device: Literal["cpu", "cuda"]):
        """Send Mask2Former inputs to device.

        Tensors and lists of tensors are moved; other values are left untouched.
        ``input`` is modified in place and returned.

        Args:
            input (``Any``): Mapping produced by ``prepare``.
            device (``Literal["cpu", "cuda"]``): Target device.
        """
        for k, v in input.items():
            if isinstance(v, list):
                input[k] = [t.to(device) for t in v]
            elif isinstance(v, Tensor):
                input[k] = v.to(device)
        return input

    def run_forward(
        self, images: Tensor, targets: BatchedFormat
    ) -> Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], BatchedFormat]]:
        """Run a forward pass with targets and return the loss (and predictions in eval mode).

        Args:
            images (``Tensor``): Batch RGB images.
            targets (``BatchedFormat``): Batch targets.

        Returns:
            ``Union[Dict[str, Tensor], Tuple[Dict[str, Tensor], BatchedFormat]]``:
                - Loss dict if model is in training mode.
                - Loss dict and predictions if model is in eval mode.
        """
        # prepare inputs
        spatial_size = images.shape[-2:]
        model_input = self.prepare(images, targets=targets)
        model_input = self.inputs_to_device(model_input, self.device)
        # run forward pass
        output: Mask2FormerForUniversalSegmentationOutput = self(
            pixel_values=model_input["pixel_values"],
            mask_labels=model_input["mask_labels"],
            class_labels=model_input["class_labels"],
        )
        loss_dict = {"loss": output.loss}
        if self.training:
            return loss_dict
        # predictions are only built outside of training mode
        predictions = self.build_results(output, spatial_size)
        return loss_dict, predictions

    def get_predictions(self, images: Tensor) -> BatchedFormat:
        """Prepare images, apply model forward pass and build results.

        Switches the model to eval mode before predicting.

        Args:
            images (``Tensor``): RGB images Tensor.

        Returns:
            ``BatchedFormat``:
                - Predictions for images as BatchedFormat.
        """
        self.eval()
        spatial_size = images.shape[-2:]
        model_input = self.prepare(images)
        model_input = self.inputs_to_device(model_input, self.device)
        # predict
        output: Mask2FormerForUniversalSegmentationOutput = self(
            pixel_values=model_input["pixel_values"]
        )
        return self.build_results(output, spatial_size=spatial_size)