Shortcuts

Source code for pytorchvideo.transforms.augmix

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Any, Dict, Optional

import torch
from pytorchvideo.transforms.augmentations import (
    _AUGMENTATION_MAX_LEVEL,
    AugmentTransform,
    _decreasing_int_to_arg,
    _decreasing_to_arg,
    _increasing_magnitude_to_arg,
    _increasing_randomly_negate_to_arg,
)
from pytorchvideo.transforms.transforms import OpSampler


# Maps each AugMix op name to the helper (imported from
# pytorchvideo.transforms.augmentations) that AugmentTransform receives as
# `level_to_arg`. Ops that share a helper are grouped with dict.fromkeys; the
# resulting key order is identical to listing every op individually.
# Ops mapped to None have no level-to-argument helper (presumably because they
# take no magnitude argument; their entries in _TRANSFORM_AUGMIX_MAX_PARAMS are
# also None).
_AUGMIX_LEVEL_TO_ARG = {
    **dict.fromkeys(("AutoContrast", "Equalize")),
    "Rotate": _increasing_randomly_negate_to_arg,
    "Posterize": _decreasing_int_to_arg,
    "Solarize": _decreasing_to_arg,
    **dict.fromkeys(
        ("ShearX", "ShearY", "TranslateX", "TranslateY"),
        _increasing_randomly_negate_to_arg,
    ),
    **dict.fromkeys(
        (
            "AdjustSaturation",
            "AdjustContrast",
            "AdjustBrightness",
            "AdjustSharpness",
        ),
        _increasing_magnitude_to_arg,
    ),
}

# Per-op parameter pairs passed to AugmentTransform as `transform_max_paras`
# (presumably the (low, high) range a magnitude level is mapped into -- confirm
# against pytorchvideo.transforms.augmentations). None marks ops without such a
# range.
# Key order is significant: AugMix builds its list of candidate ops by iterating
# this dict, so entries must stay in this exact order.
_TRANSFORM_AUGMIX_MAX_PARAMS = {
    **dict.fromkeys(("AutoContrast", "Equalize")),
    "Rotate": (0, 30),
    "Posterize": (4, 4),
    "Solarize": (1, 1),
    **dict.fromkeys(("ShearX", "ShearY"), (0, 0.3)),
    **dict.fromkeys(("TranslateX", "TranslateY"), (0, 1.0 / 3.0)),
    **dict.fromkeys(
        (
            "AdjustSaturation",
            "AdjustContrast",
            "AdjustBrightness",
            "AdjustSharpness",
        ),
        (0.1, 1.8),
    ),
}

# Default hyperparameters for sampling augmentation magnitudes, used by AugMix
# whenever the caller does not supply `sampling_hparas`:
#   - "sampling_data_type": whether uniform sampling draws ints or floats.
#   - "sampling_min": the smallest value that uniform sampling over floats
#     can produce.
SAMPLING_AUGMIX_DEFAULT_HPARAS = {
    "sampling_data_type": "float",
    "sampling_min": 0.1,
}


class AugMix:
    """
    This implements AugMix for video. AugMix generates several chains of
    augmentations on the original video, which are then mixed together with
    each other and with the original video to create an augmented video. The
    input video tensor should have shape (T, C, H, W).

    AugMix: A Simple Data Processing Method to Improve Robustness and
    Uncertainty (https://arxiv.org/pdf/1912.02781.pdf)
    """

    def __init__(
        self,
        magnitude: int = 3,
        alpha: float = 1.0,
        width: int = 3,
        depth: int = -1,
        transform_hparas: Optional[Dict[str, Any]] = None,
        sampling_hparas: Optional[Dict[str, Any]] = None,
    ) -> None:
        """
        Args:
            magnitude (int): Magnitude used for transform function. Must be
                between 1 and _AUGMENTATION_MAX_LEVEL inclusive. Default is 3.
            alpha (float): Parameter for choosing mixing weights from the beta
                and Dirichlet distributions. Must be > 0. Default is 1.0.
            width (int): The number of transformation chains. Must be > 0.
                Default is 3.
            depth (int): The number of transformations in each chain. If depth
                is not positive (e.g. the default -1), each chain will have a
                random length between 1 and 3 inclusive. Default is -1.
            transform_hparas (Optional[Dict[Any]]): Transform hyper parameters.
                Needs to have key fill. By default, the fill value is
                (0.5, 0.5, 0.5).
            sampling_hparas (Optional[Dict[Any]]): Hyper parameters for
                sampling. If gaussian sampling is used, it needs to have key
                sampling_std. By default, it uses SAMPLING_AUGMIX_DEFAULT_HPARAS.

        Raises:
            AssertionError: if any of the argument checks above fails.
        """
        assert isinstance(magnitude, int), "magnitude must be an int"
        assert (
            magnitude >= 1 and magnitude <= _AUGMENTATION_MAX_LEVEL
        ), f"magnitude must be between 1 and {_AUGMENTATION_MAX_LEVEL} inclusive"
        assert alpha > 0.0, "alpha must be greater than 0"
        assert width > 0, "width must be greater than 0"

        self._magnitude = magnitude
        # One weight per chain. Dirichlet samples sum to 1, so the chain
        # mixture is a convex combination of the chain outputs.
        self.dirichlet = torch.distributions.dirichlet.Dirichlet(
            torch.tensor([alpha] * width)
        )
        # Weight of the original video relative to the chain mixture.
        self.beta = torch.distributions.beta.Beta(alpha, alpha)

        # One always-applied (prob=1.0) transform per op; the op order follows
        # the key order of _TRANSFORM_AUGMIX_MAX_PARAMS.
        transforms_list = [
            AugmentTransform(
                transform_name=transform_name,
                magnitude=self._magnitude,
                prob=1.0,
                level_to_arg=_AUGMIX_LEVEL_TO_ARG,
                transform_max_paras=_TRANSFORM_AUGMIX_MAX_PARAMS,
                transform_hparas=transform_hparas,
                sampling_type="uniform",
                sampling_hparas=sampling_hparas or SAMPLING_AUGMIX_DEFAULT_HPARAS,
            )
            for transform_name in _TRANSFORM_AUGMIX_MAX_PARAMS
        ]
        if depth > 0:
            self.augmix_fn = OpSampler(
                transforms_list,
                num_sample_op=depth,
                replacement=True,
            )
        else:
            # Non-positive depth: let OpSampler pick a random chain length
            # (up to num_sample_op=3, per the documented 1..3 range).
            self.augmix_fn = OpSampler(
                transforms_list,
                num_sample_op=3,
                randomly_sample_depth=True,
                replacement=True,
            )

    def __call__(self, video: torch.Tensor) -> torch.Tensor:
        """
        Perform AugMix to the input video tensor.

        Args:
            video (torch.Tensor): Input video tensor with shape (T, C, H, W).

        Returns:
            torch.Tensor: ``m * video + (1 - m) * mixed``, where ``m`` is drawn
            from the Beta distribution and ``mixed`` is the Dirichlet-weighted
            sum of one ``augmix_fn`` output per chain. The result is on the same
            device as ``video``. A uint8 input yields a uint8 output, rounded to
            the nearest integer and clamped to [0, 255]. Any other input yields
            a floating-point tensor.
        """
        mixing_weights = self.dirichlet.sample()
        m = self.beta.sample().item()
        # Allocate the accumulator on the input's device (CPU-only allocation
        # would break `+=` for CUDA inputs); accumulate in float32.
        mixed = torch.zeros_like(video, dtype=torch.float32)
        for mw in mixing_weights:
            # augmix_fn is called once per mixing weight, i.e. once per chain.
            mixed += mw * self.augmix_fn(video)
        result = m * video + (1 - m) * mixed
        if video.dtype == torch.uint8:
            # Round instead of truncating: float round-off (e.g. 99.99999)
            # would otherwise turn into off-by-one darker pixels.
            return result.round().clamp(0, 255).to(torch.uint8)
        return result
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.