Shortcuts

Source code for pytorchvideo.layers.squeeze_excitation

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import Callable, Optional

import torch
import torch.nn as nn
from pytorchvideo.models.resnet import ResBlock


[docs]class SqueezeAndExcitationLayer2D(nn.Module): """2D Squeeze and excitation layer, as per https://arxiv.org/pdf/1709.01507.pdf"""
[docs] def __init__( self, in_planes: int, reduction_ratio: Optional[int] = 16, reduced_planes: Optional[int] = None, ): """ Args: in_planes (int): input channel dimension. reduction_ratio (int): factor by which in_planes should be reduced to get the output channel dimension. reduced_planes (int): Output channel dimension. Only one of reduction_ratio or reduced_planes should be defined. """ super().__init__() self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) # Either reduction_ratio is defined, or out_planes is defined assert bool(reduction_ratio) != bool( reduced_planes ), "Only of reduction_ratio or reduced_planes should be defined for SE layer" reduced_planes = ( in_planes // reduction_ratio if reduced_planes is None else reduced_planes ) self.excitation = nn.Sequential( nn.Conv2d(in_planes, reduced_planes, kernel_size=1, stride=1, bias=True), nn.ReLU(), nn.Conv2d(reduced_planes, in_planes, kernel_size=1, stride=1, bias=True), nn.Sigmoid(), )
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor: """ Args: x (tensor): 2D image of format C * H * W """ x_squeezed = self.avgpool(x) x_excited = self.excitation(x_squeezed) x_scaled = x * x_excited return x_scaled
[docs]def create_audio_2d_squeeze_excitation_block( dim_in: int, dim_out: int, use_se=False, se_reduction_ratio=16, branch_fusion: Callable = lambda x, y: x + y, # Conv configs. conv_a_kernel_size: int = 3, conv_a_stride: int = 1, conv_a_padding: int = 1, conv_b_kernel_size: int = 3, conv_b_stride: int = 1, conv_b_padding: int = 1, # Norm configs. norm: Callable = nn.BatchNorm2d, norm_eps: float = 1e-5, norm_momentum: float = 0.1, # Activation configs. activation: Callable = nn.ReLU, ) -> nn.Module: """ 2-D Residual block with squeeze excitation (SE2D) for 2d. Performs a summation between an identity shortcut in branch1 and a main block in branch2. When the input and output dimensions are different, a convolution followed by a normalization will be performed. :: Input |-------+ ↓ | conv2d | ↓ | Norm | ↓ | activation | ↓ | conv2d | ↓ | Norm | ↓ | SE2D | ↓ } Summation ←-+ Activation Normalization examples include: BatchNorm3d and None (no normalization). Activation examples include: ReLU, Softmax, Sigmoid, and None (no activation). Transform examples include: BottleneckBlock. Args: dim_in (int): input channel size to the bottleneck block. dim_out (int): output channel size of the bottleneck. use_se (bool): if true, use squeeze excitation layer in the bottleneck. se_reduction_ratio (int): factor by which input channels should be reduced to get the output channel dimension in SE layer. branch_fusion (callable): a callable that constructs summation layer. Examples include: lambda x, y: x + y, OctaveSum. conv_a_kernel_size (tuple): convolutional kernel size(s) for conv_a. conv_a_stride (tuple): convolutional stride size(s) for conv_a. conv_a_padding (tuple): convolutional padding(s) for conv_a. conv_b_kernel_size (tuple): convolutional kernel size(s) for conv_b. conv_b_stride (tuple): convolutional stride size(s) for conv_b. conv_b_padding (tuple): convolutional padding(s) for conv_b. norm (callable): a callable that constructs normalization layer. Examples include nn.BatchNorm3d, None (not performing normalization). norm_eps (float): normalization epsilon. norm_momentum (float): normalization momentum. activation (callable): a callable that constructs activation layer in bottleneck and block. Examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing activation). Returns: (nn.Module): resnet basic block layer. """ branch2 = [ nn.Conv2d( dim_in, dim_out, kernel_size=conv_a_kernel_size, stride=conv_a_stride, padding=conv_a_padding, bias=False, ), norm(dim_out, norm_eps, norm_momentum), activation() if activation else nn.Identity(), nn.Conv2d( dim_out, dim_out, kernel_size=conv_b_kernel_size, stride=conv_b_stride, padding=conv_b_padding, bias=False, ), norm(dim_out, norm_eps, norm_momentum), ] if use_se: branch2.append( SqueezeAndExcitationLayer2D(dim_out, reduction_ratio=se_reduction_ratio) ) branch2 = nn.Sequential(*branch2) branch1_conv, branch1_norm = None, None if conv_a_stride * conv_b_stride != 1 or dim_in != dim_out: branch1_conv = nn.Conv2d( dim_in, dim_out, kernel_size=1, stride=conv_a_stride * conv_b_stride, bias=False, ) branch1_norm = norm(dim_out, norm_eps, norm_momentum) return ResBlock( branch1_conv=branch1_conv, branch1_norm=branch1_norm, branch2=branch2, activation=activation() if activation else None, branch_fusion=branch_fusion, )
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.