pytorchvideo.models.x3d
-
pytorchvideo.models.x3d.
create_x3d_stem
(*, in_channels, out_channels, conv_kernel_size=(5, 3, 3), conv_stride=(1, 2, 2), conv_padding=(2, 1, 1), norm=<class 'torch.nn.modules.batchnorm.BatchNorm3d'>, norm_eps=1e-05, norm_momentum=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>)
Creates the stem layer for X3D. It performs spatial Conv, temporal Conv, BN, and ReLU.
Conv_xy ↓ Conv_t ↓ Normalization ↓ Activation
- Parameters
in_channels (int) – input channel size of the convolution.
out_channels (int) – output channel size of the convolution.
conv_kernel_size (tuple) – convolutional kernel size(s).
conv_stride (tuple) – convolutional stride size(s).
conv_padding (tuple) – convolutional padding size(s).
norm (callable) – a callable that constructs normalization layer, options include nn.BatchNorm3d, None (not performing normalization).
norm_eps (float) – normalization epsilon.
norm_momentum (float) – normalization momentum.
activation (callable) – a callable that constructs activation layer, options include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing activation).
- Returns
(nn.Module) – X3D stem layer.
- Return type
torch.nn.modules.module.Module
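The effect of the default stem parameters on the clip size can be worked out with the standard convolution output-size formula. The sketch below is plain arithmetic, not the library's code; the helper names `conv_output_size` and `stem_output_shape` are hypothetical, and it assumes the spatial and temporal convolutions each act only along their own dimensions, so the formula can be applied per dimension.

```python
def conv_output_size(size, kernel, stride, padding):
    """Standard convolution output-size formula along one dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def stem_output_shape(
    thw,
    conv_kernel_size=(5, 3, 3),
    conv_stride=(1, 2, 2),
    conv_padding=(2, 1, 1),
):
    """Return (T, H, W) after the stem for an input of size thw.

    Defaults mirror the create_x3d_stem signature documented above.
    """
    return tuple(
        conv_output_size(n, k, s, p)
        for n, k, s, p in zip(thw, conv_kernel_size, conv_stride, conv_padding)
    )


# A 13-frame, 160x160 clip keeps its length and halves its spatial size.
print(stem_output_shape((13, 160, 160)))  # (13, 80, 80)
```

With the defaults, the temporal length is preserved (stride 1, padding 2 for a kernel of 5) while height and width are halved.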
-
pytorchvideo.models.x3d.
create_x3d_bottleneck_block
(*, dim_in, dim_inner, dim_out, conv_kernel_size=(3, 3, 3), conv_stride=(1, 2, 2), norm=<class 'torch.nn.modules.batchnorm.BatchNorm3d'>, norm_eps=1e-05, norm_momentum=0.1, se_ratio=0.0625, activation=<class 'torch.nn.modules.activation.ReLU'>, inner_act=<class 'pytorchvideo.layers.swish.Swish'>)
Bottleneck block for X3D: a sequence of Conv, Normalization with optional SE block, and Activations repeated in the following order:
Conv3d (conv_a) ↓ Normalization (norm_a) ↓ Activation (act_a) ↓ Conv3d (conv_b) ↓ Normalization (norm_b) ↓ Squeeze-and-Excitation ↓ Activation (act_b) ↓ Conv3d (conv_c) ↓ Normalization (norm_c)
- Parameters
dim_in (int) – input channel size to the bottleneck block.
dim_inner (int) – intermediate channel size of the bottleneck.
dim_out (int) – output channel size of the bottleneck.
conv_kernel_size (tuple) – convolutional kernel size(s) for conv_b.
conv_stride (tuple) – convolutional stride size(s) for conv_b.
norm (callable) – a callable that constructs normalization layer, examples include nn.BatchNorm3d, None (not performing normalization).
norm_eps (float) – normalization epsilon.
norm_momentum (float) – normalization momentum.
se_ratio (float) – if > 0, apply SE to the 3x3x3 conv, with the SE channel dimensionality being se_ratio times the 3x3x3 conv dim.
activation (callable) – a callable that constructs activation layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing activation).
inner_act (callable) – a callable that constructs the activation layer used for act_b (Swish by default).
- Returns
(nn.Module) – X3D bottleneck block.
- Return type
torch.nn.modules.module.Module
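To make the optional parts of the bottleneck concrete, the following sketch lists the sub-layers in execution order for a given configuration. It only manipulates names; `bottleneck_layer_order` and the label `"se"` are hypothetical, and the flags stand in for passing `None` as `norm`, `activation`, or `inner_act`, or a non-positive `se_ratio`.

```python
def bottleneck_layer_order(
    use_norm=True, se_ratio=0.0625, use_activation=True, use_inner_act=True
):
    """List the bottleneck sub-layer names in execution order."""
    layers = ["conv_a"]
    if use_norm:
        layers.append("norm_a")
    if use_activation:
        layers.append("act_a")
    layers.append("conv_b")
    if use_norm:
        layers.append("norm_b")
    if se_ratio > 0:
        # Squeeze-and-Excitation is only applied for a positive ratio.
        layers.append("se")
    if use_inner_act:
        layers.append("act_b")
    layers.append("conv_c")
    if use_norm:
        layers.append("norm_c")
    return layers


print(bottleneck_layer_order())
print(bottleneck_layer_order(use_norm=False, se_ratio=0.0))
```

The first call reproduces the full order shown in the diagram above; the second shows what remains when normalization and SE are switched off.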
-
pytorchvideo.models.x3d.
create_x3d_res_block
(*, dim_in, dim_inner, dim_out, bottleneck=<function create_x3d_bottleneck_block>, use_shortcut=True, conv_kernel_size=(3, 3, 3), conv_stride=(1, 2, 2), norm=<class 'torch.nn.modules.batchnorm.BatchNorm3d'>, norm_eps=1e-05, norm_momentum=0.1, se_ratio=0.0625, activation=<class 'torch.nn.modules.activation.ReLU'>, inner_act=<class 'pytorchvideo.layers.swish.Swish'>)
Residual block for X3D. Performs a summation between an identity shortcut in branch1 and a main block in branch2. When the input and output dimensions differ, a convolution followed by a normalization is performed on the shortcut instead of the identity.
  Input
    |-------+
    ↓       |
  Block     |
    ↓       |
Summation ←-+
    ↓
Activation
- Parameters
dim_in (int) – input channel size to the bottleneck block.
dim_inner (int) – intermediate channel size of the bottleneck.
dim_out (int) – output channel size of the bottleneck.
bottleneck (callable) – a callable for create_x3d_bottleneck_block.
conv_kernel_size (tuple) – convolutional kernel size(s) for conv_b.
conv_stride (tuple) – convolutional stride size(s) for conv_b.
norm (callable) – a callable that constructs normalization layer, examples include nn.BatchNorm3d, None (not performing normalization).
norm_eps (float) – normalization epsilon.
norm_momentum (float) – normalization momentum.
se_ratio (float) – if > 0, apply SE to the 3x3x3 conv, with the SE channel dimensionality being se_ratio times the 3x3x3 conv dim.
activation (callable) – a callable that constructs activation layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing activation).
inner_act (callable) – a callable that constructs the activation layer used for act_b (Swish by default).
use_shortcut (bool) –
- Returns
(nn.Module) – X3D block layer.
- Return type
torch.nn.modules.module.Module
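The data flow of the residual block can be illustrated without tensors. This is a minimal sketch on plain Python lists, assuming element-wise addition as the fusion; `res_block_forward` and `relu` are hypothetical helpers, not library functions.

```python
def relu(value):
    """Scalar ReLU used as a stand-in for the activation layer."""
    return max(0.0, value)


def res_block_forward(x, block, shortcut=None, activation=None):
    """Sum the shortcut branch (branch1) and the main branch (branch2)."""
    # Identity shortcut unless a projection is supplied.
    branch1 = x if shortcut is None else shortcut(x)
    branch2 = block(x)
    out = [a + b for a, b in zip(branch1, branch2)]
    if activation is not None:
        out = [activation(v) for v in out]
    return out


x = [1.0, -2.0, 3.0]
main_branch = lambda values: [-2.0 * v for v in values]

# Identity shortcut: x + block(x) = [-1.0, 2.0, -3.0], then ReLU.
print(res_block_forward(x, main_branch, activation=relu))  # [0.0, 2.0, 0.0]
```

Passing a `shortcut` callable plays the role of the convolution and normalization applied when the input and output dimensions differ.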
-
pytorchvideo.models.x3d.
create_x3d_res_stage
(*, depth, dim_in, dim_inner, dim_out, bottleneck=<function create_x3d_bottleneck_block>, conv_kernel_size=(3, 3, 3), conv_stride=(1, 2, 2), norm=<class 'torch.nn.modules.batchnorm.BatchNorm3d'>, norm_eps=1e-05, norm_momentum=0.1, se_ratio=0.0625, activation=<class 'torch.nn.modules.activation.ReLU'>, inner_act=<class 'pytorchvideo.layers.swish.Swish'>)
Create Residual Stage, which composes sequential blocks that make up X3D.
Input ↓ ResBlock ↓ . . . ↓ ResBlock
- Parameters
depth (int) – number of blocks to create.
dim_in (int) – input channel size to the bottleneck block.
dim_inner (int) – intermediate channel size of the bottleneck.
dim_out (int) – output channel size of the bottleneck.
bottleneck (callable) – a callable for create_x3d_bottleneck_block.
conv_kernel_size (tuple) – convolutional kernel size(s) for conv_b.
conv_stride (tuple) – convolutional stride size(s) for conv_b.
norm (callable) – a callable that constructs normalization layer, examples include nn.BatchNorm3d, None (not performing normalization).
norm_eps (float) – normalization epsilon.
norm_momentum (float) – normalization momentum.
se_ratio (float) – if > 0, apply SE to the 3x3x3 conv, with the SE channel dimensionality being se_ratio times the 3x3x3 conv dim.
activation (callable) – a callable that constructs activation layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not performing activation).
inner_act (callable) – a callable that constructs the activation layer used for act_b (Swish by default).
- Returns
(nn.Module) – X3D stage layer.
- Return type
torch.nn.modules.module.Module
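One way to picture how a stage chains its blocks is to list the per-block arguments. The sketch below assumes the usual ResNet convention, which this page does not state explicitly: only the first block changes the channel count and applies `conv_stride`, and later blocks map `dim_out` to `dim_out` with unit stride. `res_stage_plan` is a hypothetical helper.

```python
def res_stage_plan(depth, dim_in, dim_out, conv_stride=(1, 2, 2)):
    """Return one dict of block arguments per block in the stage.

    Assumption: only the first block changes channels and downsamples.
    """
    plan = []
    for idx in range(depth):
        plan.append(
            {
                "dim_in": dim_in if idx == 0 else dim_out,
                "dim_out": dim_out,
                "conv_stride": conv_stride if idx == 0 else (1, 1, 1),
            }
        )
    return plan


for block_args in res_stage_plan(depth=3, dim_in=24, dim_out=48):
    print(block_args)
```

Each printed entry corresponds to one ResBlock in the diagram above, in the order the input passes through them.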
-
pytorchvideo.models.x3d.
create_x3d_head
(*, dim_in, dim_inner, dim_out, num_classes, pool_act=<class 'torch.nn.modules.activation.ReLU'>, pool_kernel_size=(13, 5, 5), norm=<class 'torch.nn.modules.batchnorm.BatchNorm3d'>, norm_eps=1e-05, norm_momentum=0.1, bn_lin5_on=False, dropout_rate=0.5, activation=<class 'torch.nn.modules.activation.Softmax'>, output_with_global_average=True)
Creates X3D head. This layer performs a projected pooling operation followed by dropout, a fully-connected projection, an activation layer, and a global spatiotemporal averaging.
ProjectedPool ↓ Dropout ↓ Projection ↓ Activation ↓ Averaging
- Parameters
dim_in (int) – input channel size of the X3D head.
dim_inner (int) – intermediate channel size of the X3D head.
dim_out (int) – output channel size of the X3D head.
num_classes (int) – the number of classes for the video dataset.
pool_act (callable) – a callable that constructs resnet pool activation layer such as nn.ReLU.
pool_kernel_size (tuple) – pooling kernel size(s) when not using adaptive pooling.
norm (callable) – a callable that constructs normalization layer, examples include nn.BatchNorm3d, None (not performing normalization).
norm_eps (float) – normalization epsilon.
norm_momentum (float) – normalization momentum.
bn_lin5_on (bool) – if True, perform normalization on the features before the classifier.
dropout_rate (float) – dropout rate.
activation (callable) – a callable that constructs resnet head activation layer, examples include: nn.ReLU, nn.Softmax, nn.Sigmoid, and None (not applying activation).
output_with_global_average (bool) – if True, perform global averaging on temporal and spatial dimensions and reshape output to batch_size x out_features.
- Returns
(nn.Module) – X3D head layer.
- Return type
torch.nn.modules.module.Module
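The effect of `output_with_global_average` on the output shape can be sketched as shape arithmetic. This is an illustration under stated assumptions, not the library's code: `head_output_shape` is a hypothetical helper, the pooling is assumed to use stride 1 with no padding, and the non-averaged output is assumed to keep a channel-first layout.

```python
def head_output_shape(
    batch_size,
    num_classes,
    feature_thw,
    pool_kernel_size=(13, 5, 5),
    output_with_global_average=True,
):
    """Shape of the head output for a feature map of size feature_thw."""
    if any(n < k for n, k in zip(feature_thw, pool_kernel_size)):
        raise ValueError("feature map is smaller than the pooling kernel")
    if output_with_global_average:
        # Averaged over T, H, W and reshaped to batch_size x out_features.
        return (batch_size, num_classes)
    # Assumption: stride-1, unpadded pooling leaves n - k + 1 positions.
    pooled = tuple(n - k + 1 for n, k in zip(feature_thw, pool_kernel_size))
    return (batch_size, num_classes) + pooled


print(head_output_shape(2, 400, (13, 5, 5)))  # (2, 400)
print(head_output_shape(2, 400, (13, 7, 7), output_with_global_average=False))  # (2, 400, 1, 3, 3)
```

With averaging enabled the result is one score vector per clip; without it, one score vector remains per pooled position.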
-
pytorchvideo.models.x3d.
create_x3d
(*, input_channel=3, input_clip_length=13, input_crop_size=160, model_num_class=400, dropout_rate=0.5, width_factor=2.0, depth_factor=2.2, norm=<class 'torch.nn.modules.batchnorm.BatchNorm3d'>, norm_eps=1e-05, norm_momentum=0.1, activation=<class 'torch.nn.modules.activation.ReLU'>, stem_dim_in=12, stem_conv_kernel_size=(5, 3, 3), stem_conv_stride=(1, 2, 2), stage_conv_kernel_size=((3, 3, 3), (3, 3, 3), (3, 3, 3), (3, 3, 3)), stage_spatial_stride=(2, 2, 2, 2), stage_temporal_stride=(1, 1, 1, 1), bottleneck=<function create_x3d_bottleneck_block>, bottleneck_factor=2.25, se_ratio=0.0625, inner_act=<class 'pytorchvideo.layers.swish.Swish'>, head_dim_out=2048, head_pool_act=<class 'torch.nn.modules.activation.ReLU'>, head_bn_lin5_on=False, head_activation=<class 'torch.nn.modules.activation.Softmax'>, head_output_with_global_average=True)
X3D model builder. It builds an X3D network backbone, which is a ResNet.
Christoph Feichtenhofer. “X3D: Expanding Architectures for Efficient Video Recognition.” https://arxiv.org/abs/2004.04730
Input ↓ Stem ↓ Stage 1 ↓ . . . ↓ Stage N ↓ Head
- Parameters
input_channel (int) – number of channels for the input video clip.
input_clip_length (int) – length of the input video clip. Value for different models: X3D-XS: 4; X3D-S: 13; X3D-M: 16; X3D-L: 16.
input_crop_size (int) – spatial resolution of the input video clip. Value for different models: X3D-XS: 160; X3D-S: 160; X3D-M: 224; X3D-L: 312.
model_num_class (int) – the number of classes for the video dataset.
dropout_rate (float) – dropout rate.
width_factor (float) – width expansion factor.
depth_factor (float) – depth expansion factor. Value for different models: X3D-XS: 2.2; X3D-S: 2.2; X3D-M: 2.2; X3D-L: 5.0.
norm (callable) – a callable that constructs normalization layer.
norm_eps (float) – normalization epsilon.
norm_momentum (float) – normalization momentum.
activation (callable) – a callable that constructs activation layer.
stem_dim_in (int) – input channel size for stem before expansion.
stem_conv_kernel_size (tuple) – convolutional kernel size(s) of stem.
stem_conv_stride (tuple) – convolutional stride size(s) of stem.
stage_conv_kernel_size (tuple) – convolutional kernel size(s) for conv_b.
stage_spatial_stride (tuple) – the spatial stride for each stage.
stage_temporal_stride (tuple) – the temporal stride for each stage.
bottleneck_factor (float) – bottleneck expansion factor for the 3x3x3 conv.
se_ratio (float) – if > 0, apply SE to the 3x3x3 conv, with the SE channel dimensionality being se_ratio times the 3x3x3 conv dim.
inner_act (callable) – a callable that constructs the activation layer used for act_b (Swish by default).
head_dim_out (int) – output channel size of the X3D head.
head_pool_act (callable) – a callable that constructs resnet pool activation layer such as nn.ReLU.
head_bn_lin5_on (bool) – if True, perform normalization on the features before the classifier.
head_activation (callable) – a callable that constructs activation layer.
head_output_with_global_average (bool) – if True, perform global averaging on the head output.
bottleneck (callable) – a callable that constructs the bottleneck block, such as create_x3d_bottleneck_block.
- Returns
(nn.Module) – the X3D network.
- Return type
torch.nn.modules.module.Module
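The per-variant values listed in the parameter descriptions can be collected into keyword dictionaries, and the default strides determine the feature-map size that reaches the head. The sketch below is stdlib-only arithmetic; `X3D_PRESETS`, `clip_shape`, `ceil_div`, and `final_feature_thw` are hypothetical names, and it assumes every strided convolution is padded so that a stride of s maps a length n to ceil(n / s).

```python
# Per-variant values copied from the parameter descriptions above.
X3D_PRESETS = {
    "X3D-XS": {"input_clip_length": 4, "input_crop_size": 160, "depth_factor": 2.2},
    "X3D-S": {"input_clip_length": 13, "input_crop_size": 160, "depth_factor": 2.2},
    "X3D-M": {"input_clip_length": 16, "input_crop_size": 224, "depth_factor": 2.2},
    "X3D-L": {"input_clip_length": 16, "input_crop_size": 312, "depth_factor": 5.0},
}


def clip_shape(variant, batch_size=1, input_channel=3):
    """Input shape (batch, channel, time, height, width) for a variant."""
    preset = X3D_PRESETS[variant]
    size = preset["input_crop_size"]
    return (batch_size, input_channel, preset["input_clip_length"], size, size)


def ceil_div(n, d):
    """Integer ceiling division."""
    return -(-n // d)


def final_feature_thw(
    variant,
    stem_conv_stride=(1, 2, 2),
    stage_spatial_stride=(2, 2, 2, 2),
    stage_temporal_stride=(1, 1, 1, 1),
):
    """Feature-map size (T, H, W) reaching the head (padding assumption above)."""
    preset = X3D_PRESETS[variant]
    t = ceil_div(preset["input_clip_length"], stem_conv_stride[0])
    hw = ceil_div(preset["input_crop_size"], stem_conv_stride[1])
    for s_stride, t_stride in zip(stage_spatial_stride, stage_temporal_stride):
        t = ceil_div(t, t_stride)
        hw = ceil_div(hw, s_stride)
    return (t, hw, hw)


print(clip_shape("X3D-M"))         # (1, 3, 16, 224, 224)
print(final_feature_thw("X3D-S"))  # (13, 5, 5)
```

Each preset can be unpacked into the builder, for example `create_x3d(**X3D_PRESETS["X3D-M"])`. The X3D-S result of (13, 5, 5) matches the default `pool_kernel_size` of `create_x3d_head`.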
-
class
pytorchvideo.models.x3d.
ProjectedPool
(*, pre_conv=None, pre_norm=None, pre_act=None, pool=None, post_conv=None, post_norm=None, post_act=None)
A pooling module augmented with Conv, Normalization and Activation both before and after pooling for the head layer of X3D.
Conv3d (pre_conv) ↓ Normalization (pre_norm) ↓ Activation (pre_act) ↓ Pool3d ↓ Conv3d (post_conv) ↓ Normalization (post_norm) ↓ Activation (post_act)
-
__init__
(*, pre_conv=None, pre_norm=None, pre_act=None, pool=None, post_conv=None, post_norm=None, post_act=None)
- Parameters
pre_conv (torch.nn.modules) – convolutional module.
pre_norm (torch.nn.modules) – normalization module.
pre_act (torch.nn.modules) – activation module.
pool (torch.nn.modules) – pooling module.
post_conv (torch.nn.modules) – convolutional module.
post_norm (torch.nn.modules) – normalization module.
post_act (torch.nn.modules) – activation module.
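The order in which ProjectedPool applies its sub-modules can be sketched with plain callables. This is not the library's forward method: `projected_pool_forward` and `named` are hypothetical, and the sketch assumes the convolution and pooling stages are always supplied while normalization and activation stages left as `None` are skipped.

```python
def projected_pool_forward(
    x,
    *,
    pre_conv,
    pool,
    post_conv,
    pre_norm=None,
    pre_act=None,
    post_norm=None,
    post_act=None,
):
    """Apply the stages in the documented order, skipping any set to None."""
    stages = [pre_conv, pre_norm, pre_act, pool, post_conv, post_norm, post_act]
    for stage in stages:
        if stage is not None:
            x = stage(x)
    return x


trace = []


def named(name):
    """Build a pass-through stage that records when it runs."""
    def stage(value):
        trace.append(name)
        return value
    return stage


projected_pool_forward(
    0,
    pre_conv=named("pre_conv"),
    pool=named("pool"),
    post_conv=named("post_conv"),
    post_act=named("post_act"),
)
print(trace)  # ['pre_conv', 'pool', 'post_conv', 'post_act']
```

The recorded trace follows the diagram above, with the omitted normalization and pre-pooling activation stages simply absent.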