Shortcuts

Source code for pytorchvideo.transforms.mix

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Tuple

import torch
from pytorchvideo.transforms.functional import convert_to_one_hot


def _mix_labels(
    labels: torch.Tensor,
    num_classes: int,
    lam: float = 1.0,
    label_smoothing: float = 0.0,
):
    """
    Convert class indices to one-hot targets and blend each sample's target with
    the target of its batch-mirrored partner.

    Sample ``i`` is paired with sample ``B - 1 - i`` (``labels`` reversed along
    dim 0). This is the same pairing produced by ``x.flip(0)`` in the MixUp and
    CutMix transforms of this module.

    Args:
        labels (torch.Tensor): Class labels.
        num_classes (int): Total number of classes.
        lam (float): Lambda value for mixing labels. It is the weight given to
            each sample's own target; ``1 - lam`` goes to its partner's target.
        label_smoothing (float): Label smoothing value, forwarded to
            ``convert_to_one_hot``.

    Returns:
        ``lam * one_hot(labels) + (1 - lam) * one_hot(labels.flip(0))``, where
        ``one_hot`` is ``convert_to_one_hot`` with the given smoothing.
    """
    partner_labels = labels.flip(0)
    own_targets = convert_to_one_hot(labels, num_classes, label_smoothing)
    partner_targets = convert_to_one_hot(partner_labels, num_classes, label_smoothing)
    own_weight, partner_weight = lam, 1.0 - lam
    return own_targets * own_weight + partner_targets * partner_weight


class MixUp(torch.nn.Module):
    """
    Mixup: Beyond Empirical Risk Minimization (https://arxiv.org/abs/1710.09412)
    """

    def __init__(
        self,
        alpha: float = 1.0,
        label_smoothing: float = 0.0,
        num_classes: int = 400,
    ) -> None:
        """
        This implements MixUp for videos.

        Args:
            alpha (float): Mixup alpha value. It is used as both concentration
                parameters of the Beta distribution that lambda is drawn from.
            label_smoothing (float): Label smoothing value.
            num_classes (int): Number of total classes.
        """
        super().__init__()
        self.mixup_beta_sampler = torch.distributions.beta.Beta(alpha, alpha)
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes

    def forward(
        self, x: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Blend every sample with its batch-mirrored partner (index ``B - 1 - i``).
        One lambda is drawn per call: ``x <- lam * x + (1 - lam) * x.flip(0)``.

        Note: ``x`` is modified in place, and the same tensor object is returned.

        Args:
            x (torch.Tensor): Input tensor. The input should be a batch of videos with
                shape (B, C, T, H, W).
            labels (torch.Tensor): Labels for input with shape (B).

        Returns:
            A tuple of the mixed ``x`` and the mixed soft labels.
        """
        assert x.size(0) > 1, "MixUp cannot be applied to a single instance."
        lam = self.mixup_beta_sampler.sample()
        # flip() returns a new tensor, so scaling it in place does not touch x.
        partner = x.flip(0)
        partner.mul_(1.0 - lam)
        x.mul_(lam)
        x.add_(partner)
        mixed_labels = _mix_labels(labels, self.num_classes, lam, self.label_smoothing)
        return x, mixed_labels
class CutMix(torch.nn.Module):
    """
    CutMix: Regularization Strategy to Train Strong Classifiers with Localizable Features
    (https://arxiv.org/abs/1905.04899)
    """

    def __init__(
        self,
        alpha: float = 1.0,
        label_smoothing: float = 0.0,
        num_classes: int = 400,
    ) -> None:
        """
        This implements CutMix for videos.

        Args:
            alpha (float): CutMix alpha value. It is used as both concentration
                parameters of the Beta distribution that lambda is drawn from.
            label_smoothing (float): Label smoothing value.
            num_classes (int): Number of total classes.
        """
        super().__init__()
        self.cutmix_beta_sampler = torch.distributions.beta.Beta(alpha, alpha)
        self.label_smoothing = label_smoothing
        self.num_classes = num_classes

    def _clip(self, value: int, min_value: int, max_value: int) -> int:
        """
        Clip value based on minimum value and maximum value.
        """
        return min(max(value, min_value), max_value)

    def _get_rand_box(
        self, input_shape: Tuple[int, ...], cutmix_lamda: float
    ) -> Tuple[int, int, int, int]:
        """
        Get a random box given a lambda value.

        Each nominal side length is ``sqrt(1 - cutmix_lamda)`` times the
        matching spatial size, so the box keeps the frame's aspect ratio. It is
        only square when H == W. The centre is drawn uniformly over the spatial
        grid. The edges extend ``cut // 2`` on either side (an odd cut size loses
        one pixel) and are clipped to the frame, so the realised box can be
        smaller than nominal.

        Args:
            input_shape: Shape of the input. Only the last two entries, read as
                (H, W), are used.
            cutmix_lamda: Lambda value (presumably in [0, 1]; not checked here).

        Returns:
            ``(yl, yh, xl, xh)``: half-open row and column bounds of the box.
        """
        ratio = (1 - cutmix_lamda) ** 0.5
        input_h, input_w = input_shape[-2:]
        cut_h, cut_w = int(input_h * ratio), int(input_w * ratio)
        cy = torch.randint(input_h, (1,)).item()
        cx = torch.randint(input_w, (1,)).item()
        yl = self._clip(cy - cut_h // 2, 0, input_h)
        yh = self._clip(cy + cut_h // 2, 0, input_h)
        xl = self._clip(cx - cut_w // 2, 0, input_w)
        xh = self._clip(cx + cut_w // 2, 0, input_w)
        return yl, yh, xl, xh

    def _cutmix(
        self, x: torch.Tensor, cutmix_lamda: float
    ) -> Tuple[torch.Tensor, float]:
        """
        Perform CutMix and return corrected lambda value.

        The box region of every sample is overwritten, in place, with the same
        region of its batch-mirrored partner ``x.flip(0)``. Clipping can shrink
        the box, so lambda is recomputed from the realised box as the fraction
        of spatial positions that were left untouched.
        """
        yl, yh, xl, xh = self._get_rand_box(x.size(), cutmix_lamda)
        box_area = float((yh - yl) * (xh - xl))
        cutmix_lamda_corrected = 1.0 - box_area / (x.size(-2) * x.size(-1))
        # flip() returns a copy, so the source patch is fully read before x is written.
        x[..., yl:yh, xl:xh] = x.flip(0)[..., yl:yh, xl:xh]
        return x, cutmix_lamda_corrected

    def forward(
        self, x: torch.Tensor, labels: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        The input is a batch of samples and their corresponding labels.

        Args:
            x (torch.Tensor): Input tensor. The input should be a batch of videos with
                shape (B, C, T, H, W). 4-dimensional input is also accepted. The
                last two dimensions are treated as (H, W). ``x`` is modified in
                place and returned.
            labels (torch.Tensor): Labels for input with shape (B).

        Returns:
            A tuple of the mixed ``x`` and the mixed soft labels. The labels are
            weighted by the area-corrected lambda.
        """
        assert x.size(0) > 1, "Cutmix cannot be applied to a single instance."
        assert x.dim() == 4 or x.dim() == 5, "Please correct input shape."
        cutmix_lamda = self.cutmix_beta_sampler.sample()
        x, cutmix_lamda_corrected = self._cutmix(x, cutmix_lamda)
        new_labels = _mix_labels(
            labels,
            self.num_classes,
            cutmix_lamda_corrected,
            self.label_smoothing,
        )
        return x, new_labels
class MixVideo(torch.nn.Module):
    """
    Stochastically applies either MixUp or CutMix to the input video.
    """

    def __init__(
        self,
        cutmix_prob: float = 0.5,
        mixup_alpha: float = 1.0,
        cutmix_alpha: float = 1.0,
        label_smoothing: float = 0.0,
        num_classes: int = 400,
    ):
        """
        Args:
            cutmix_prob (float): Probability of using CutMix. MixUp is used with
                probability 1 - cutmix_prob. If cutmix_prob is 0, MixUp is always
                used. If cutmix_prob is 1, CutMix is always used.
            mixup_alpha (float): MixUp alpha value.
            cutmix_alpha (float): CutMix alpha value.
            label_smoothing (float): Label smoothing value, shared by both transforms.
            num_classes (int): Number of total classes, shared by both transforms.
        """
        assert 0.0 <= cutmix_prob <= 1.0, "cutmix_prob should be between 0.0 and 1.0"
        super().__init__()
        self.cutmix_prob = cutmix_prob
        common_kwargs = {"label_smoothing": label_smoothing, "num_classes": num_classes}
        self.mixup = MixUp(alpha=mixup_alpha, **common_kwargs)
        self.cutmix = CutMix(alpha=cutmix_alpha, **common_kwargs)

    def forward(self, x: torch.Tensor, labels: torch.Tensor):
        """
        The input is a batch of samples and their corresponding labels.

        One uniform draw in [0, 1) per call picks the transform: CutMix when the
        draw is below ``cutmix_prob``, MixUp otherwise.

        Args:
            x (torch.Tensor): Input tensor. The input should be a batch of videos with
                shape (B, C, T, H, W).
            labels (torch.Tensor): Labels for input with shape (B).
        """
        use_cutmix = torch.rand(1).item() < self.cutmix_prob
        chosen = self.cutmix if use_cutmix else self.mixup
        return chosen(x, labels)
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.