pytorchvideo.models.masked_multistream
-
class
pytorchvideo.models.masked_multistream.
MaskedTemporalPooling
(method)
Applies temporal pooling operations on masked inputs. For each pooling operation, all masked values are ignored.
-
__init__
(method)
- Parameters
method (str) – the pooling method to use. Options:
‘max’: reduces the temporal dimension to the maximum of the valid values.
‘avg’: averages the valid values along the temporal dimension.
‘sum’: sums the valid values along the temporal dimension.
Note: if every element in a batch row is invalid, that row is pooled to zeros.
-
forward
(x, mask=None)
- Parameters
x (torch.Tensor) – tensor with shape (batch_size, seq_len, feature_dim)
mask (torch.Tensor) – bool tensor with shape (batch_size, seq_len). Sequence elements that are False are invalid.
- Returns
Tensor with shape (batch_size, feature_dim)
- Return type
torch.Tensor
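The pooling behaviour described above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: the function name masked_temporal_pool is hypothetical, and the sketch assumes mask is a bool tensor in which True marks valid elements.

```python
import torch


def masked_temporal_pool(x, mask, method="avg"):
    """Pool x of shape (batch_size, seq_len, feature_dim) over seq_len.

    Hypothetical helper: positions where mask is False are ignored, and
    rows with no valid position are pooled to zeros.
    """
    valid = mask.unsqueeze(-1)  # (batch_size, seq_len, 1)
    if method == "max":
        # Invalid positions become -inf so they can never be the maximum.
        pooled = x.masked_fill(~valid, float("-inf")).max(dim=1).values
        # Rows without any valid position would be -inf; use 0 instead.
        all_invalid = ~mask.any(dim=1, keepdim=True)
        return pooled.masked_fill(all_invalid, 0.0)
    # Zero out invalid positions so they do not contribute to the sum.
    summed = x.masked_fill(~valid, 0.0).sum(dim=1)
    if method == "sum":
        return summed
    if method == "avg":
        # Clamp the count to 1 so fully invalid rows yield 0 rather than NaN.
        counts = mask.sum(dim=1, keepdim=True).clamp(min=1)
        return summed / counts
    raise ValueError(f"unknown pooling method: {method}")


x = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
                  [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]])
mask = torch.tensor([[True, True, False],
                     [False, False, False]])
avg = masked_temporal_pool(x, mask, "avg")
```

In this example the third element of the first row is masked out, so it does not affect any of the three pooling results, and the second row has no valid element and is pooled to zeros.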
-
-
class
pytorchvideo.models.masked_multistream.
TransposeMultiheadAttention
(feature_dim, num_heads=1)
Wrapper for nn.MultiheadAttention that first transposes the input tensor from (batch_size, seq_len, feature_dim) to (seq_len, batch_size, feature_dim), then applies the attention and transposes the attention outputs back to the input shape.
-
property
attention_weights
Contains the attention weights from the last forward call.
-
forward
(x, mask=None)
- Parameters
x (torch.Tensor) – tensor of shape (batch_size, seq_len, feature_dim)
mask (torch.Tensor) – bool tensor with shape (batch_size, seq_len). Sequence elements that are False are invalid.
- Returns
Tensor with shape (batch_size, seq_len, feature_dim)
- Return type
torch.Tensor
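The transpose-attend-transpose pattern can be illustrated directly with torch.nn.MultiheadAttention. The sketch below is an assumption about how such a wrapper could work, not the library's code. In particular, it assumes the mask is inverted before being passed as key_padding_mask, because PyTorch treats True there as "ignore this key".

```python
import torch
from torch import nn

batch_size, seq_len, feature_dim = 2, 4, 8

# By default nn.MultiheadAttention expects (seq_len, batch_size, feature_dim).
attention = nn.MultiheadAttention(embed_dim=feature_dim, num_heads=1)

x = torch.randn(batch_size, seq_len, feature_dim)
# True marks valid sequence elements, as in the forward() docs above.
mask = torch.tensor([[True, True, False, False],
                     [True, True, True, True]])

# (batch_size, seq_len, feature_dim) -> (seq_len, batch_size, feature_dim)
x_t = x.transpose(0, 1)

# key_padding_mask uses the opposite convention (True = ignore), hence ~mask.
out_t, weights = attention(x_t, x_t, x_t, key_padding_mask=~mask)

# Transpose back to the input layout.
out = out_t.transpose(0, 1)
```

Here weights has shape (batch_size, seq_len, seq_len), and the columns that correspond to invalid keys receive zero attention.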
-
-
class
pytorchvideo.models.masked_multistream.
LearnMaskedDefault
(feature_dim, init_method='gaussian', freeze=False)
Learns default values to fill invalid entries within input tensors. The invalid entries are represented by a mask which is passed into forward alongside the input tensor. Note that the default value is used only when all entries in a batch row are invalid, not when only some of the entries in that row are invalid.
-
forward
(x, mask)
- Parameters
x (torch.Tensor) – tensor of shape (batch_size, feature_dim).
mask (torch.Tensor) – bool tensor of shape (batch_size, seq_len). If all elements in a batch row are False, the learned default parameter is used for that batch element.
- Returns
Tensor with shape (batch_size, feature_dim)
- Return type
torch.Tensor
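A minimal sketch of the "learned default for fully invalid rows" idea follows. The class name DefaultFill and its Gaussian initialisation via torch.randn are illustrative assumptions, not the library's implementation.

```python
import torch
from torch import nn


class DefaultFill(nn.Module):
    """Hypothetical module: replace rows whose mask is all False with a learned vector."""

    def __init__(self, feature_dim):
        super().__init__()
        # One learnable default value per feature (Gaussian init is an assumption).
        self.default = nn.Parameter(torch.randn(feature_dim))

    def forward(self, x, mask):
        # A row is usable if at least one of its sequence elements is valid.
        row_valid = mask.any(dim=-1, keepdim=True)  # (batch_size, 1)
        # Keep x for usable rows; broadcast the learned default elsewhere.
        return torch.where(row_valid, x, self.default)


fill = DefaultFill(feature_dim=3)
x = torch.ones(2, 3)
mask = torch.tensor([[True, False],    # partially valid: x is kept
                     [False, False]])  # fully invalid: default is used
out = fill(x, mask)
```

The first row is returned unchanged even though one of its elements is invalid; only the second row, which has no valid element, is replaced by the learned default.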
-
-
class
pytorchvideo.models.masked_multistream.
LSTM
(dim_in, hidden_dim, dropout=0.0, bidirectional=False)
Wrapper for torch.nn.LSTM that handles masked inputs.
-
forward
(data, mask=None)
- Parameters
data (torch.Tensor) – tensor with shape (batch_size, seq_len, feature_dim)
mask (torch.Tensor) – bool tensor with shape (batch_size, seq_len). Sequence elements that are False are invalid.
- Returns
Tensor with shape (batch_size, output_dim), where output_dim is determined by hidden_dim and whether the LSTM is bidirectional.
- Return type
torch.Tensor
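One common way to make an LSTM ignore invalid sequence elements is to pack the input with pack_padded_sequence. The sketch below assumes that approach and also assumes the valid elements of each row form a prefix of the sequence. The helper name masked_lstm_last_hidden is hypothetical and the code is not taken from the library.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence


def masked_lstm_last_hidden(lstm, data, mask):
    """Hypothetical helper: final hidden state of lstm over the valid prefix of each row."""
    # Number of valid elements per row; clamp so empty rows still have length 1.
    lengths = mask.sum(dim=1).clamp(min=1)
    packed = pack_padded_sequence(
        data, lengths.cpu(), batch_first=True, enforce_sorted=False
    )
    _, (h, _) = lstm(packed)  # h: (num_directions, batch_size, hidden_dim)
    if lstm.bidirectional:
        # Concatenate the forward and backward final states.
        return torch.cat([h[0], h[1]], dim=-1)
    return h[-1]


batch_size, seq_len, dim_in, hidden_dim = 3, 5, 4, 6
lstm = nn.LSTM(dim_in, hidden_dim, batch_first=True, bidirectional=True)

data = torch.randn(batch_size, seq_len, dim_in)
mask = torch.tensor([[True, True, True, False, False],
                     [True, True, True, True, True],
                     [True, False, False, False, False]])
out = masked_lstm_last_hidden(lstm, data, mask)
```

With bidirectional=True the output has shape (batch_size, 2 * hidden_dim) in this sketch, and values at masked positions have no effect on the result.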
-
-
class
pytorchvideo.models.masked_multistream.
TransposeTransformerEncoder
(dim_in, num_heads=1, num_layers=1)
Wrapper for torch.nn.TransformerEncoder that handles masked inputs.
-
forward
(data, mask=None)
- Parameters
data (torch.Tensor) – tensor with shape (batch_size, seq_len, feature_dim)
mask (torch.Tensor) – bool tensor with shape (batch_size, seq_len). Sequence elements that are False are invalid.
- Returns
Tensor with shape (batch_size, feature_dim)
- Return type
torch.Tensor
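The documented shapes, (batch_size, seq_len, feature_dim) in and (batch_size, feature_dim) out, can be reproduced with torch.nn.TransformerEncoder as sketched below. Reducing the sequence by taking the first position, and forcing that position to be valid, are assumptions made for this illustration. The helper name encode_first_position is hypothetical.

```python
import torch
from torch import nn


def encode_first_position(encoder, data, mask):
    """Hypothetical helper: run a masked encoder and return the first position."""
    # Assumption: keep the first element valid so every row has at least one key.
    mask = mask.clone()
    mask[:, 0] = True
    # The encoder expects (seq_len, batch_size, feature_dim) by default, and
    # src_key_padding_mask uses True for positions to ignore, hence ~mask.
    out = encoder(data.transpose(0, 1), src_key_padding_mask=~mask)
    out = out.transpose(0, 1)  # back to (batch_size, seq_len, feature_dim)
    return out[:, 0, :]


feature_dim = 8
# dropout=0.0 keeps the sketch deterministic without switching to eval mode.
layer = nn.TransformerEncoderLayer(
    d_model=feature_dim, nhead=1, dim_feedforward=16, dropout=0.0
)
encoder = nn.TransformerEncoder(layer, num_layers=1)

data = torch.randn(2, 4, feature_dim)
mask = torch.tensor([[True, True, False, False],
                     [True, True, True, True]])
pooled = encode_first_position(encoder, data, mask)
```

Because invalid positions are excluded as attention keys, changing their values leaves the pooled output for that row unchanged.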
-
-
class
pytorchvideo.models.masked_multistream.
MaskedSequential
(*args)
A sequential container that overrides forward to take a mask as well as the usual input tensor. The mask is passed only to modules listed in _MASK_MODULES (those that accept a mask argument).
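The dispatch rule described above can be sketched as a small nn.Sequential subclass. The classes MaskedMean and MaskAwareSequential are hypothetical stand-ins written for this illustration and are not the library's classes.

```python
import torch
from torch import nn


class MaskedMean(nn.Module):
    """Hypothetical mask-aware module: mean over valid sequence elements."""

    def forward(self, x, mask):
        m = mask.unsqueeze(-1).to(x.dtype)  # (batch_size, seq_len, 1)
        return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)


class MaskAwareSequential(nn.Sequential):
    """Hypothetical container: forwards the mask only to mask-aware modules."""

    _MASK_MODULES = (MaskedMean,)

    def forward(self, x, mask):
        for module in self:
            if isinstance(module, self._MASK_MODULES):
                x = module(x, mask=mask)  # module accepts a mask argument
            else:
                x = module(x)  # ordinary module: input tensor only
        return x


model = MaskAwareSequential(nn.Linear(4, 8), MaskedMean(), nn.ReLU())
x = torch.randn(2, 3, 4)
mask = torch.tensor([[True, True, False],
                     [True, False, False]])
out = model(x, mask)
```

Only MaskedMean receives the mask here; nn.Linear and nn.ReLU are called with the input tensor alone.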
-
class
pytorchvideo.models.masked_multistream.
MaskedMultiPathWay
(*, multipathway_blocks, multipathway_fusion)
Masked multi-pathway is composed of a list of stream nn.Modules followed by a fusion nn.Module that reduces these streams. Each stream module takes a mask and input tensor.
Pathway 1  ...  Pathway N
    ↓               ↓
 Block 1         Block N
    ↓⭠ --Fusion----↓
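The streams-then-fusion structure can be sketched as follows. All class names are hypothetical, and the input format (a list with one (tensor, mask) pair per pathway) and the concatenation fusion are assumptions made for this illustration.

```python
import torch
from torch import nn


class MaskedMean(nn.Module):
    """Hypothetical stream block: mean over valid sequence elements."""

    def forward(self, x, mask):
        m = mask.unsqueeze(-1).to(x.dtype)
        return (x * m).sum(dim=1) / m.sum(dim=1).clamp(min=1)


class ConcatFusion(nn.Module):
    """Hypothetical fusion: concatenate stream outputs along the feature axis."""

    def forward(self, streams):
        return torch.cat(streams, dim=-1)


class MultiPathSketch(nn.Module):
    """Hypothetical container: one block per pathway, then a fusion module."""

    def __init__(self, *, blocks, fusion):
        super().__init__()
        self.blocks = blocks  # nn.ModuleList, one entry per pathway
        self.fusion = fusion

    def forward(self, x_and_mask):
        # Assumed input: a list with one (tensor, mask) pair per pathway.
        outs = [block(x, mask) for block, (x, mask) in zip(self.blocks, x_and_mask)]
        return self.fusion(outs)


model = MultiPathSketch(
    blocks=nn.ModuleList([MaskedMean(), MaskedMean()]),
    fusion=ConcatFusion(),
)
stream_a = (torch.randn(2, 5, 4), torch.ones(2, 5, dtype=torch.bool))
stream_b = (torch.randn(2, 3, 6), torch.ones(2, 3, dtype=torch.bool))
fused = model([stream_a, stream_b])
```

The two pathways may have different sequence lengths and feature sizes; in this sketch the fusion step only requires that the pooled outputs share the batch dimension.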