from deepvisiontools.models.basemodel import BaseModel
import torch.amp
from torch.optim import Optimizer
from deepvisiontools import Configuration
from typing import Literal, Dict, Tuple, List, Union
from torch import Tensor
from deepvisiontools.formats import BatchedFormat
from deepvisiontools.data import DeepVisionLoader
from tqdm import tqdm
import torch
from deepvisiontools.metrics.base_metric import (
DetectMetric,
ClassWiseDetectMetric,
SemanticSegmentationMetric,
ClassifMetric,
)
from torch.utils.tensorboard import SummaryWriter
from pathlib import Path
import shutil
import warnings
[docs]
class Trainer:
"""Class that handles training in deepvisiontools. Handles train / valid epochs, monitoring (via tensorboard) and metrics computation.
Args:
model (``BaseModel``): deepvisiontools model.
optimizer (``Optimizer``): torch optimizer (Ex: Adam())
metrics (``List[Union[DetectMetric, ClassWiseDetectMetric, SemanticSegmentationMetric, ClassifMetric]]``, **optional**): List of deepvisiontools metrics. Check available metrics in deepvisiontools.metrics.available_metrics Defaults to [].
log_dir (``str``, **optional**): tensorboard output directory. If "" no monitoring is provided. Defaults to "".
Example:
----------
.. highlight:: python
.. code-block:: python
>>> from deepvisiontools import DeepVisionDataset, DeepVisionLoader, Trainer
>>> from torch.optim import Adam
>>> model = Yolo()
>>> optim = Adam(model.parameters(), 1e-4)
>>> train_set, valid_set, _ = DeepVisionDataset(dataset_path=data_path).split((0.8, 0.2, 0))
>>> trainer = Trainer(model, optim, metrics=[DetectF1score()], log_dir="test_dir")
>>> train_loader = DeepVisionLoader(train_set, batch_size=6)
>>> valid_loader = DeepVisionLoader(valid_set, batch_size=6)
>>> for e in range(N_epoch):
>>> trainer.train_epoch(train_loader, e)
>>> trainer.valid_epoch(valid_loader, e)
Attributes
----------
Attributes:
- model (``BaseModel``): deepvisiontools model.
- optimizer (``Optimizer``): torch optimizer (Ex: Adam())
- metrics (``List[Union[DetectMetric, ClassWiseDetectMetric, SemanticSegmentationMetric, ClassifMetric]]``, **optional**): List of deepvisiontools metrics. Check available metrics in deepvisiontools.metrics.available_metrics Defaults to [].
- board (``SummaryWriter``): tensorboard output directory.
Attributes
----------
Properties:
- device (``Literal["cpu", "cuda"]``) : the setter move evrything that's needed to desired device.
**Methods**
"""
def __init__(
self,
model: BaseModel,
optimizer: Optimizer,
metrics: List[
Union[
DetectMetric,
ClassWiseDetectMetric,
SemanticSegmentationMetric,
ClassifMetric,
]
] = [],
log_dir="",
):
self.model = model
self.optimizer = optimizer
self.log_dir = log_dir
self.device = Configuration().device
self.metrics: List[
Union[
DetectMetric,
ClassWiseDetectMetric,
SemanticSegmentationMetric,
ClassifMetric,
]
] = [m.to(self.device) for m in metrics]
# create log dir and board for tensorboard
if log_dir:
# if log dir exist remove it
if Path(log_dir).exists():
shutil.rmtree(log_dir)
Path(log_dir).mkdir(parents=True)
self.board = SummaryWriter(log_dir)
else:
self.board = False
@property
def device(self):
return self._device
@device.setter
def device(self, val: Literal["cpu", "cuda"]):
self.model.device = val
self._device = val
[docs]
def train_step(
self, images: Tensor, targets: BatchedFormat, scaler: torch.amp.GradScaler
) -> Dict[str, Tensor]:
"""Run forward pass, loss computation and backward pass.
Args:
images (``Tensor``): Batch images
targets (``BatchedFormat``): Batch targets.
Returns:
``Dict[str, Tensor]``:
- Dict of losses containing (total loss at key 'loss').
"""
assert self.model.training, "model is not in train mode for train_step"
with torch.autocast(
device_type=Configuration().device,
dtype=torch.float16,
enabled=Configuration().optimize,
):
loss_dict = self.model.run_forward(images, targets)
loss = loss_dict["loss"]
if loss > 5 and Configuration().optimize:
warnings.warn(
"Loss value being large while using torch optimization can cause problems in your training. If your loss does not decrease across epoch consider 1) reducing the loss value by a factor (some deepvisiontools models have this option) 2) switch Configuration().optimize to False"
)
scaler.scale(loss).backward()
scaler.step(self.optimizer)
scaler.update()
self.optimizer.zero_grad()
return loss_dict
[docs]
def valid_step(
self, images: Tensor, targets: BatchedFormat, scaler: torch.amp.GradScaler
) -> Tuple[Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]:
"""Run forward, compute metrics, return loss dict and metrics.
Args:
images (``Tensor``): Batch images.
targets (``BatchedFormat``): Targets.
Returns:
``Tuple[Dict[str, Tensor], Dict[str, Dict[str, Tensor]]]``:
- Losses and metrics values.
"""
assert not (self.model.training), "model is not in valid mode for valid_step"
with torch.autocast(
device_type=Configuration().device,
dtype=torch.float16,
enabled=Configuration().optimize,
):
loss_dict, predictions = self.model.run_forward(images, targets)
metrics = self.compute_metrics(predictions, targets)
return loss_dict, metrics
[docs]
def epoch(
self,
loader: DeepVisionLoader,
ep_number: int,
tag: str = "",
) -> Dict[str, Tensor]:
"""Run trainning epoch.
Args:
loader (``DeepVisionLoader``): DeepVisionLoader.
ep_number (``int``): Epoch number.
tag (``str``, **optional**): Tag to link to epoch. Defaults to "".
Returns:
``Dict[str, Tensor]``:
- Epochs values (Losses & metrics).
"""
# create aggregator for loss averged accros samples
loss_aggregator = Aggregator()
iterator = tqdm(loader, total=len(loader), desc=f"Epoch {ep_number}/{tag}")
scaler = torch.amp.GradScaler(
Configuration().device, enabled=Configuration().optimize
)
# iterate over batches
for images, targets, _ in iterator:
batch_size = images.shape[0]
# send to device
images = images.to(self.device)
targets: BatchedFormat
# gather loss & metrics (if valid)
if self.model.training:
loss_dict = self.train_step(images, targets, scaler)
loss_aggregator(loss_dict, batch_size)
epoch_dict = loss_aggregator.compute()
else:
loss_dict, metric_dict = self.valid_step(images, targets, scaler)
loss_aggregator(loss_dict, batch_size)
epoch_dict = loss_aggregator.compute()
epoch_dict.update(metric_dict)
# extract str from log to display in terminal
log_str = self.log_string(epoch_dict)
iterator.set_postfix_str(f"{log_str}")
for metric in self.metrics:
# TODO understand why metric.reset() leads to increasing computation time ... .__init__ seems to solve the issue
metric.__init__()
metric.to(Configuration().device)
if self.log_dir:
for key, value in epoch_dict.items():
if isinstance(value, dict):
self.board.add_scalars(key, value, ep_number)
else:
self.board.add_scalars(key, {tag: value}, ep_number)
return epoch_dict
[docs]
def train_epoch(
self, loader: DeepVisionLoader, ep_number: int, tag: str = "Train"
) -> Dict[str, Tensor]:
"""Run train epoch.
Args:
loader (``DetectionLoader``): DetectionLoader.
ep_number (``int``): Epoch number.
tag (``str``, **optional**): Tag to link to epoch. Defaults to "Train".
Returns:
``Dict[str, Tensor]``:
- Epochs values (Losses).
"""
self.model.train()
epoch_dict = self.epoch(loader, ep_number, tag=tag)
return epoch_dict
[docs]
def valid_epoch(
self, loader: DeepVisionLoader, ep_number: int, tag: str = "Valid"
) -> Dict[str, Tensor]:
"""Run train epoch.
Args:
loader (``DetectionLoader``): DetectionLoader.
ep_number (``int``): Epoch number.
tag (``str``, **optional**): Tag to link to epoch. Defaults to "Valid".
Returns:
``Dict[str, Tensor]``:
- Epochs values (Losses & metrics).
"""
self.model.eval()
with torch.no_grad():
epoch_dict = self.epoch(loader, ep_number, tag=tag)
return epoch_dict
[docs]
def log_string(self, epoch_dict: Dict[str, Tensor]) -> str:
"""Transform epoch dict in string.
Args:
epoch_dict (``Dict[str, Tensor]``): Dict of epoch values to display.
Returns:
``str``:
- String to print with epoch values.
"""
flattened_dict = epoch_dict.copy()
for key, value in flattened_dict.items():
if isinstance(value, dict):
flattened_dict[key] = value[list(value.keys())[0]]
log = ""
for key, value in flattened_dict.items():
log += f"{key} : {str(round(value.item(), 4))} "
return log
def compute_metrics(self, predictions: BatchedFormat, targets: BatchedFormat):
for metric in self.metrics:
metric.update(predictions, targets)
# after all updates recompute to get averaged values of metric
metric_dict = {}
for metric in self.metrics:
results = metric.compute()
metric_dict.update({metric.name: results})
return metric_dict
[docs]
class Aggregator:
"""Aggregator aggregate losses across batchs.
Attributes:
-----------
Attributes:
iterations (``int``): Number of iterations.
losses (``Dict[str, Tensor]``): Dictionnary of epoch losses (over iterations).
Methods:
----------
"""
iterations: int
losses: dict
def __init__(self):
self.iterations = 0
self.losses: dict = None
[docs]
def update(self, batch_losses: Dict[str, Tensor]):
"""Update internal loss dict with new losses.
Args:
batch_losses (``Dict[str, Tensor]``): Dict of losses.
"""
for key, value in self.losses.items():
self.losses[key] = value + batch_losses[key]
def __call__(self, batch_losses: Dict[str, Tensor], batch_size: int):
"""Update internal loss dict.
Args:
batch_losses (``Dict[str, Tensor]``): Dict of losses.
batch_size (``int``): Batch size.
"""
if self.losses:
self.update(batch_losses)
else:
self.losses = batch_losses
self.iterations += 1 * batch_size
[docs]
def compute(self) -> Dict[str, Tensor]:
"""Return loss dict with values divided by iterations (Mean accross samples).
Returns:
``Dict[str, Tensor]``:
- Losses over iterations.
"""
out_dict = {
key: (value / self.iterations) for key, value in self.losses.items()
}
return out_dict