Source code for deepvisiontools.metrics.base_metric

from typing import Any, Callable, Dict, List, Tuple, Union
from deepvisiontools.formats import (
    BatchedFormat,
    BaseFormat,
    InstanceMaskFormat,
    SemanticMaskFormat,
)
import torch
from torch import Tensor
from torchmetrics import Metric
from deepvisiontools.metrics.matcher import Matcher
from deepvisiontools import Configuration
from torchmetrics.classification import StatScores


class DetectMetric(Metric):
    """Base class for custom single-class detection metrics built on the torchmetrics engine.

    Every call to ``update`` appends one ``(1, 1, 5)`` tensor per sample to ``self.stats``,
    laid out as ``[tp, fp, tn, fn, support]``. ``tn`` and ``support`` are always NaN because
    true negatives do not exist in detection.

    Args:
        func (Callable): function applied to (tp, fp, tn, fn) to produce the metric value.
        name (str, optional): metric's name (useful for tensorboard monitoring).
            Defaults to "DetectionMetric".
    """

    is_differentiable = None
    higher_is_better = True
    full_state_update: bool = False

    def __init__(
        self,
        func: Callable,
        name: str = "DetectionMetric",
        **kwargs,
    ):
        assert (
            Configuration().data_type != "semantic_mask"
        ), "Can't use detection metrics with Configuration().data_type = semantic_mask. Must be instance_mask or bbox"
        super().__init__(**kwargs)
        self.func = func
        self.name = name
        self.add_state("stats", default=[], dist_reduce_fx="cat")
        # detection statistics are class-agnostic here: a single class row
        self.nc = 1

    def update(
        self,
        prediction: Union[BaseFormat, BatchedFormat],
        target: Union[BaseFormat, BatchedFormat],
    ):
        """Append the (tp, fp, tn, fn, support) statistics of each prediction/target pair.

        Samples where both prediction and target contain objects are matched with
        ``Matcher``. A sample with only predictions counts them all as false positives.
        A sample with only targets counts them all as false negatives. A sample with
        neither yields NaN stats, which the ``nansum``/``nanmean`` reductions ignore.

        Args:
            prediction (Union[BaseFormat, BatchedFormat])
            target (Union[BaseFormat, BatchedFormat])
        """
        if not isinstance(prediction, BatchedFormat):
            prediction = BatchedFormat([prediction])
        if not isinstance(target, BatchedFormat):
            target = BatchedFormat([target])

        matcher = Matcher()
        for pred, targ in zip(prediction, target):
            n_pred, n_targ = pred.nb_object, targ.nb_object
            if n_pred != 0 and n_targ != 0:
                tp, fp, fn, _ = matcher.match_pred_target(pred, targ)
            elif n_pred != 0:
                # predictions without any ground truth: all false positives
                tp, fp, fn = 0, n_pred, 0
            elif n_targ != 0:
                # ground truth without any prediction: all false negatives
                tp, fp, fn = 0, 0, n_targ
            else:
                # nothing to evaluate (and no tn in detection): neutral NaN stats
                tp = fp = fn = torch.nan
            # tn and support are not defined for detection -> NaN
            self.stats.append(torch.tensor([[[tp, fp, torch.nan, fn, torch.nan]]]))

    def global_micro_compute(self) -> Tensor:
        """Compute the metric with global/micro averaging (all samples and classes pooled)."""
        per_sample = torch.cat(self.stats).view(len(self.stats), self.nc, 5)  # (N, NC, 5)
        per_class = per_sample.nansum(dim=0)  # (NC, 5)
        pooled = torch.nansum(per_class, dim=0)  # (5,)
        tp, fp, tn, fn, _ = pooled.unbind(0)
        return self.func(tp, fp, tn, fn)

    def samplewise_micro(self) -> Tensor:
        """Compute the metric per sample (classes pooled), then average over samples."""
        per_sample = torch.cat(self.stats).view(len(self.stats), self.nc, 5)  # (N, NC, 5)
        summed = per_sample.nansum(dim=1)  # (N, 5)
        tp, fp, tn, fn, _ = summed.unbind(dim=1)
        return torch.nanmean(self.func(tp, fp, tn, fn))

    def compute_last_sample(self):
        """Return the metric values of the most recently appended sample in ``self.stats``.

        Meant to be called right after ``update``.

        Returns:
            Dict[str, Float]: ``"all_cls"`` holds the value with all class rows summed;
            when the sample has more than one class row, ``"cls_<i>"`` holds each
            class's value. Values are rounded to 3 decimals.
        """
        if not self.stats:
            return {"all_cls": torch.tensor(torch.nan)}

        last = self.stats[-1].squeeze(0)  # (NC, 5): one row of stats per class
        tp, fp, tn, fn, _ = last.nansum(dim=0).unbind(dim=0)
        result = {"all_cls": round(self.func(tp, fp, tn, fn).item(), 3)}

        if last.size(dim=0) > 1:
            for cls_idx, cls_row in enumerate(last):
                tp, fp, tn, fn, _ = cls_row.unbind(dim=0)
                result[f"cls_{cls_idx}"] = round(self.func(tp, fp, tn, fn).item(), 3)
        return result

    def compute(self) -> Dict[str, Tensor]:
        """Return the metric computed from the accumulated internal state.

        Returns:
            Dict[str, Tensor]: ``{"global": ..., "samplewise": ...}``, or
            ``{self.name: nan}`` when nothing has been accumulated yet.
        """
        if not self.stats:
            return {self.name: torch.tensor(torch.nan)}
        return {
            "global": self.global_micro_compute(),
            "samplewise": self.samplewise_micro(),
        }
class ClassWiseDetectMetric(Metric):
    """Aggregate one DetectMetric per class to obtain class-dependent detection scores.

    ``self.classmetrics[0]`` is fed the unfiltered prediction/target. ``self.classmetrics[i]``
    (for i >= 1) is fed only the objects whose label is ``i - 1``. After each update, the
    per-class stats of every sample are stacked into a ``(1, NC, 5)`` tensor and appended
    to ``self.stats``. Micro/macro and global/samplewise aggregations are computed from it.

    Args:
        func (Callable): function to apply to tp, fp, tn, fn
        name (str, optional): metric's name (useful for tensorboard monitoring).
            Defaults to "ClassWiseDetectionMetric".

    Attributes:
        classmetrics (``List[DetectMetric]``): one unfiltered DetectMetric followed by one
            DetectMetric specialized in each class.
    """

    def __init__(
        self,
        func: Callable,
        name: str = "ClassWiseDetectionMetric",
        **kwargs,
    ):
        assert (
            Configuration().data_type != "semantic_mask"
        ), "Can't use detection metrics with Configuration().data_type = semantic_mask. Must be instance_mask or bbox"
        num_classes = Configuration().num_classes
        assert (
            num_classes >= 2
        ), f"{num_classes} is an invalid number of classes for ClassWiseMetric (must be >= 2). If you have one class please use DetectMetric instead of ClasswiseMetric"
        super().__init__(**kwargs)
        self.nc = num_classes
        self.name = name
        # create instances of DetectMetrics : 1 for global and 1 for each class
        self.classmetrics = [DetectMetric(func, name=f"{name}/global").to(self.device)]
        self.classmetrics += [
            DetectMetric(func, name=f"{name}/cls_{i}").to(self.device)
            for i in range(num_classes)
        ]
        self.add_state("stats", default=[], dist_reduce_fx="cat")
        self.func = func

    def update(
        self,
        prediction: Union[BaseFormat, BatchedFormat],
        target: Union[BaseFormat, BatchedFormat],
    ):
        """Update every DetectMetric in self.classmetrics and store per-sample class stats.

        Args:
            prediction (Union[BaseFormat, BatchedFormat])
            target (Union[BaseFormat, BatchedFormat])
        """
        prediction = (
            prediction
            if isinstance(prediction, BatchedFormat)
            else BatchedFormat([prediction])
        )
        target = (
            target if isinstance(target, BatchedFormat) else BatchedFormat([target])
        )
        # assumes BatchedFormat.size is the number of samples, i.e. the number of
        # stats entries each DetectMetric.update appends
        batch_size = prediction.size

        # unfiltered metric (its stats are not stacked into self.stats)
        self.classmetrics[0].update(prediction, target)

        # per_class_stats[c][s] -> (1, 1, 5) stats of sample s restricted to class c
        per_class_stats = []
        for cls, met in enumerate(self.classmetrics[1:]):
            met.update(
                self._filter_classes(prediction, cls),
                self._filter_classes(target, cls),
            )
            # the stats of this batch are the last `batch_size` entries
            # (guard: a [-0:] slice would return the whole list)
            per_class_stats.append(met.stats[-batch_size:] if batch_size else [])

        # regroup by sample: one (1, NC, 5) tensor per sample
        for sample_stats in zip(*per_class_stats):
            self.stats.append(torch.cat(sample_stats).view(1, self.nc, 5))

    def _filter_classes(self, batchedformat: BatchedFormat, cls: int):
        """Return a BatchedFormat keeping, in each format, only objects labelled ``cls``."""
        new_formats = []
        for form in batchedformat:
            new_form, _ = form[form.labels == cls]
            new_formats.append(new_form)
        return BatchedFormat(new_formats)

    def global_micro_compute(self) -> Tensor:
        """Compute metric with global/micro averaging (all samples and classes pooled)."""
        n_samples = len(self.stats)
        samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5)  # (N, NC, 5)
        # sum stats across samples
        samples_stack = samples_stack.nansum(dim=0)  # (NC, 5)
        # sum stats across classes
        micro_stack = torch.nansum(samples_stack, dim=0)  # (5,)
        tp, fp, tn, fn, _ = micro_stack.unbind(0)
        return self.func(tp, fp, tn, fn)

    def global_macro_compute(self) -> Tuple[Tensor, Tensor]:
        """Compute metric with global/macro averaging. Also return the per-class metric tensor."""
        n_samples = len(self.stats)
        samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5)  # (N, NC, 5)
        # sum stats across samples
        samples_stack = samples_stack.nansum(dim=0)  # (NC, 5)
        tp, fp, tn, fn, _ = samples_stack.unbind(1)
        class_metrics = self.func(tp, fp, tn, fn)  # (NC,)
        return torch.nanmean(class_metrics), class_metrics

    def samplewise_micro(self) -> Tensor:
        """Compute metric per sample (classes pooled), then average across samples."""
        n_samples = len(self.stats)
        samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5)  # (N, NC, 5)
        # sum stats across classes
        samples_stack = samples_stack.nansum(dim=1)  # (N, 5)
        tp, fp, tn, fn, _ = samples_stack.unbind(dim=1)
        samples_metrics = self.func(tp, fp, tn, fn)  # (N,)
        return torch.nanmean(samples_metrics)

    def samplewise_macro(self) -> Tensor:
        """Compute metric per sample and class, average across classes, then across samples."""
        n_samples = len(self.stats)
        samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5)  # (N, NC, 5)
        tp, fp, tn, fn, _ = samples_stack.unbind(2)
        class_metrics = self.func(tp, fp, tn, fn)  # (N, NC)
        macro = torch.nanmean(class_metrics, dim=1)  # (N,)
        return torch.nanmean(macro, dim=0)

    def compute_last_sample(self):
        """Return metrics values of the last sample in self.stats. Used in combination with self.update.

        Returns:
            Dict[str, Float]: dictionary with the metric value for all classes combined
            (``"all_cls"``) and for each class (``"cls_<i>"``), rounded to 3 decimals.
        """
        if not self.stats:
            return {"all_cls": torch.tensor(torch.nan)}
        metric_sample_dict = {}
        # (NC, 5): one row of stats (tp, fp, tn, fn, support) per class
        matrix_stats_img = self.stats[-1].squeeze(0)
        nb_class = matrix_stats_img.size(dim=0)
        # sum over all classes
        tp, fp, tn, fn, _ = matrix_stats_img.nansum(dim=0).unbind(dim=0)
        metric_sample_dict["all_cls"] = round(self.func(tp, fp, tn, fn).item(), 3)
        if nb_class > 1:
            for index_class in range(nb_class):
                tp, fp, tn, fn, _ = matrix_stats_img[index_class, :].unbind(dim=0)
                metric_sample_dict[f"cls_{index_class}"] = round(
                    self.func(tp, fp, tn, fn).item(), 3
                )
        return metric_sample_dict

    def compute(self):
        """Return metric values computed from self.stats.

        Returns:
            Dict[str, Tensor]: ``_global_micro``, ``_global_macro``, one ``/cls_<i>`` entry
            per class, ``_samplewise_micro`` and ``_samplewise_macro``; or
            ``{self.name: nan}`` when nothing has been accumulated yet.
        """
        if not self.stats:
            return {self.name: torch.tensor(torch.nan)}
        metric_dict = {"_global_micro": self.global_micro_compute()}
        global_macro, class_metrics = self.global_macro_compute()
        metric_dict["_global_macro"] = global_macro
        metric_dict.update(
            {f"/cls_{i}": class_metrics[i] for i in range(class_metrics.nelement())}
        )
        metric_dict["_samplewise_micro"] = self.samplewise_micro()
        metric_dict["_samplewise_macro"] = self.samplewise_macro()
        return metric_dict

    # override
    def reset(self):
        """Reset this metric's own state and every metric in self.classmetrics.

        Override from torchmetrics Metric. ``Metric.reset`` must be called too, otherwise
        ``self.stats`` (and the compute cache / update count) would never be cleared.
        """
        super().reset()
        for met in self.classmetrics:
            met.reset()

    # override
    def to(self, device: Any):
        """Move this metric's own states and all metrics in self.classmetrics to device.

        Override from torchmetrics Metric; classmetrics is a plain list, so it is not
        moved by ``Metric.to`` and must be handled explicitly.

        Args:
            device (Any)
        """
        super().to(device)
        self.classmetrics = [metric.to(device) for metric in self.classmetrics]
        return self
class ClassifMetric(Metric):
    """Child class of torchmetrics Metric for classification of matched objects.

    Accepts Format inputs and returns a dict of metric values. Each update appends one
    ``(1, NC, 5)`` tensor of ``[tp, fp, tn, fn, support]`` stats per evaluated sample.

    Args:
        func (Callable): metric functional applied to (tp, fp, tn, fn).
        name (str, optional): metric's name. Defaults to "ClassifMetric".
    """

    def __init__(
        self,
        func: Callable,  # metric functional
        name: str = "ClassifMetric",
        **kwargs: Any,
    ):
        super().__init__(**kwargs)
        num_classes = Configuration().num_classes
        self.func = func
        self.task = "binary" if num_classes == 1 else "multiclass"
        self.nc = num_classes
        # use tm engine to get statistics (tp, fp, tn, fn, sup)
        self.stat_score = StatScores(
            task=self.task,
            multidim_average="samplewise",
            average="none",
            num_classes=num_classes,
        )
        self.add_state("stats", default=[], dist_reduce_fx="cat")
        self.name = name

    def update(
        self,
        prediction: Union[BaseFormat, BatchedFormat],
        target: Union[BaseFormat, BatchedFormat],
    ):
        """Update internal states with the classification stats of matched objects.

        Each prediction/target pair is matched with ``Matcher``, and only matched objects
        are classified. A sample where either side has no object is skipped; the
        remaining samples of the batch are still processed.

        Args:
            prediction (Union[BaseFormat, BatchedFormat])
            target (Union[BaseFormat, BatchedFormat])
        """
        assert (
            Configuration().data_type != "semantic_mask"
        ), "Can't use ClassifMetric with Configuration().data_type = semantic_mask. Must be instance_mask or bbox"
        prediction = (
            BatchedFormat([prediction])
            if not isinstance(prediction, BatchedFormat)
            else prediction
        )
        target = (
            BatchedFormat([target]) if not isinstance(target, BatchedFormat) else target
        )
        matcher = Matcher()
        for pred, targ in zip(prediction, target):
            # no prediction or no target: no classification evaluation for this sample
            if pred.nb_object == 0 or targ.nb_object == 0:
                continue
            # keep the labels of all targets for the no-match case below
            all_target_labels = targ.labels
            # match objects and keep only matched pairs
            _, _, _, (pred_idxs, target_idxs) = matcher.match_pred_target(pred, targ)
            pred, _ = pred[pred_idxs.to(pred.device)]
            targ, _ = targ[target_idxs.to(targ.device)]

            if pred.nb_object == 0:
                # no match: record the per-class target counts in the support column
                # only (tp/fp/tn/fn stay 0).
                # NOTE(review): an older comment said targets were added to fn, but only
                # support is set — confirm the intended semantics.
                class_stats = torch.zeros((self.nc, 5)).to(pred.device)
                class_stats[:, 4] = torch.tensor(
                    [torch.sum(all_target_labels == i) for i in range(self.nc)]
                )
                self.stats.append(class_stats[None, ...])
                continue

            pred_labels = pred.labels
            target_labels = targ.labels
            # binary task: shift label 0 to 1 (out of place, so format tensors are not mutated)
            if self.task == "binary":
                pred_labels = pred_labels + 1
                target_labels = target_labels + 1
            # NOTE(review): calling the StatScores module presumably also accumulates its
            # own internal state, which is never read here — confirm it cannot grow unbounded.
            stats = self.stat_score(
                pred_labels[None, ...], target_labels[None, ...]
            ).view(1, self.nc, 5)
            self.stats.append(stats)

    def global_micro_compute(self) -> Tensor:
        """Compute metric with global/micro averaging (all samples and classes pooled)."""
        n_samples = len(self.stats)
        samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5)  # (N, NC, 5)
        # sum stats across samples
        samples_stack = samples_stack.nansum(dim=0)  # (NC, 5)
        # sum stats across classes
        micro_stack = torch.nansum(samples_stack, dim=0)  # (5,)
        tp, fp, tn, fn, _ = micro_stack.unbind(0)
        return self.func(tp, fp, tn, fn)

    def global_macro_compute(self) -> Tuple[Tensor, Tensor]:
        """Compute metric with global/macro averaging. Also return the per-class metric tensor."""
        n_samples = len(self.stats)
        samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5)  # (N, NC, 5)
        # sum stats across samples
        samples_stack = samples_stack.nansum(dim=0)  # (NC, 5)
        tp, fp, tn, fn, _ = samples_stack.unbind(1)
        class_metrics = self.func(tp, fp, tn, fn)  # (NC,)
        return torch.nanmean(class_metrics), class_metrics

    def samplewise_micro(self) -> Tensor:
        """Compute metric per sample (classes pooled), then average across samples."""
        n_samples = len(self.stats)
        samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5)  # (N, NC, 5)
        # sum stats across classes
        samples_stack = samples_stack.nansum(dim=1)  # (N, 5)
        tp, fp, tn, fn, _ = samples_stack.unbind(dim=1)
        samples_metrics = self.func(tp, fp, tn, fn)  # (N,)
        return torch.nanmean(samples_metrics)

    def samplewise_macro(self) -> Tensor:
        """Compute metric per sample and class, average across classes, then across samples."""
        n_samples = len(self.stats)
        samples_stack = torch.cat(self.stats).view(n_samples, self.nc, 5)  # (N, NC, 5)
        tp, fp, tn, fn, _ = samples_stack.unbind(2)
        class_metrics = self.func(tp, fp, tn, fn)  # (N, NC)
        macro = torch.nanmean(class_metrics, dim=1)  # (N,)
        return torch.nanmean(macro, dim=0)

    def compute_last_sample(self):
        """Return metrics values of the last sample in self.stats. Used in combination with self.update.

        Returns:
            Dict[str, Float]: dictionary with the metric value for all classes combined
            (``"all_cls"``) and for each class (``"cls_<i>"``), rounded to 3 decimals.
        """
        if not self.stats:
            return {"all_cls": torch.tensor(torch.nan)}
        metric_sample_dict = {}
        # (NC, 5): one row of stats (tp, fp, tn, fn, support) per class
        matrix_stats_img = self.stats[-1].squeeze(0)
        nb_class = matrix_stats_img.size(dim=0)
        # sum over all classes
        tp, fp, tn, fn, _ = matrix_stats_img.nansum(dim=0).unbind(dim=0)
        metric_sample_dict["all_cls"] = round(self.func(tp, fp, tn, fn).item(), 3)
        if nb_class > 1:
            for index_class in range(nb_class):
                tp, fp, tn, fn, _ = matrix_stats_img[index_class, :].unbind(dim=0)
                metric_sample_dict[f"cls_{index_class}"] = round(
                    self.func(tp, fp, tn, fn).item(), 3
                )
        return metric_sample_dict

    def compute(self):
        """Compute the metric with every averaging strategy and return a dict of values.

        Multiclass: ``_global_micro``, ``_global_macro``, one ``/cls_<i>`` per class,
        ``_samplewise_micro``, ``_samplewise_macro``. Binary (no macro aggregation
        needed): ``_global`` and ``_samplewise``. Returns ``{self.name: nan}`` when
        nothing has been accumulated yet.
        """
        if not self.stats:
            return {self.name: torch.tensor(torch.nan)}
        metric_dict = {}
        if self.task == "multiclass":
            metric_dict["_global_micro"] = self.global_micro_compute()
            global_macro, class_metrics = self.global_macro_compute()
            # semantic masks: class indices are shifted by one, presumably because
            # column 0 (background) is dropped by SemanticSegmentationMetric.update
            offset = 1 if Configuration().data_type == "semantic_mask" else 0
            metric_dict["_global_macro"] = global_macro
            metric_dict.update(
                {
                    f"/cls_{i + offset}": class_metrics[i]
                    for i in range(class_metrics.nelement())
                }
            )
            metric_dict["_samplewise_micro"] = self.samplewise_micro()
            metric_dict["_samplewise_macro"] = self.samplewise_macro()
        else:
            metric_dict["_global"] = self.global_micro_compute()
            metric_dict["_samplewise"] = self.samplewise_micro()
        return metric_dict
class SemanticSegmentationMetric(ClassifMetric):
    """ClassifMetric variant working in the semantic-segmentation paradigm.

    Predictions and targets are converted to semantic masks, and the stats are computed
    pixel-wise on class masks instead of per matched object. In the multiclass case,
    the background column (index 0) is dropped from each sample's stats.

    Args:
        func (Callable): metric functional applied to (tp, fp, tn, fn).
        name (str, optional): metric's name. Defaults to "SegmentationMetric".
    """

    def __init__(
        self,
        func: Callable,  # metric functional
        name: str = "SegmentationMetric",
        **kwargs: Any,
    ):
        assert Configuration().data_type in [
            "instance_mask",
            "semantic_mask",
        ], f"Configuration().data_type must be instance_mask or semantic_mask to use SemanticSegmentationMetrics. Got {Configuration().data_type}"
        num_classes = Configuration().num_classes
        super().__init__(func, name, **kwargs)
        data_type = Configuration().data_type

        if num_classes > 1:
            if data_type == "semantic_mask":
                # the semantic mask already contains the background class; it is
                # removed from the stats in update, so one less internal class
                self.nc -= 1
            else:
                # instance masks converted to semantic masks gain a background class:
                # stat scores must count it (it is removed again in update)
                self.stat_score.num_classes = num_classes + 1

    # override
    def update(
        self,
        prediction: Union[BaseFormat, BatchedFormat],
        target: Union[BaseFormat, BatchedFormat],
    ):
        """Convert prediction/target to semantic masks and append pixel-wise stats per sample.

        Only the type of the first format of each batch is checked and used to decide
        whether the whole batch is converted from instance masks.
        """
        if not isinstance(prediction, BatchedFormat):
            prediction = BatchedFormat([prediction])
        if not isinstance(target, BatchedFormat):
            target = BatchedFormat([target])

        mask_types = (InstanceMaskFormat, SemanticMaskFormat)
        assert (
            type(prediction.formats[0]) in mask_types
            and type(target.formats[0]) in mask_types
        ), "formats of BatchedFormat must be InstanceMaskFormat or SemanticMaskFormat to use SemanticSegmentationMetric"

        # instance masks -> semantic masks
        if isinstance(target.formats[0], InstanceMaskFormat):
            target = BatchedFormat(
                [SemanticMaskFormat.from_instance_mask(t) for t in target]
            )
        if isinstance(prediction.formats[0], InstanceMaskFormat):
            prediction = BatchedFormat(
                [SemanticMaskFormat.from_instance_mask(p) for p in prediction]
            )

        for pred_fmt, targ_fmt in zip(prediction, target):
            # flatten the masks and add a leading sample dim for samplewise stat scores
            flat_pred = pred_fmt.data.value.flatten()[None, :]
            flat_targ = targ_fmt.data.value.flatten()[None, :]
            sample_stats = self.stat_score(flat_pred, flat_targ)
            if self.task != "binary":
                # drop the background column to avoid artificially high scores
                sample_stats = sample_stats[:, 1:, :]
            self.stats.append(sample_stats)