Shortcuts

Source code for pytorchvideo.models.net

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

from typing import List, Optional

import torch
import torch.nn as nn
from pytorchvideo.layers.utils import set_attributes
from pytorchvideo.models.weight_init import init_net_weights


class Net(nn.Module):
    """
    Generic video recognition network assembled from an ordered list of blocks.

    The input is fed to the first block and each block's output becomes the
    next block's input::

        Input
          |
        Block 1
          |
          .
          .
          .
          |
        Block N

    The ResNet builder can be found in `create_resnet`.
    """

    def __init__(self, *, blocks: nn.ModuleList) -> None:
        """
        Args:
            blocks (torch.nn.ModuleList): the block modules, in the order in
                which they are applied. Must not be None.
        """
        super().__init__()
        assert blocks is not None
        self.blocks = blocks
        # Weight initialization is delegated to the shared project helper and
        # runs once, over the fully assembled network.
        init_net_weights(self)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the input through every block in sequence.

        Args:
            x (torch.Tensor): input handed to the first block.

        Returns:
            The output of the last block, or ``x`` itself when there are no
            blocks.
        """
        out = x
        for block in self.blocks:
            out = block(out)
        return out
class DetectionBBoxNetwork(nn.Module):
    """
    A general-purpose model that takes bounding boxes as part of its input.

    A backbone ``model`` produces features from the input tensor, and a
    ``detection_head`` consumes those features together with the boxes.
    """

    def __init__(self, model: nn.Module, detection_head: nn.Module):
        """
        Args:
            model (nn.Module): the model that precedes the head,
                e.g. stem + stages.
            detection_head (nn.Module): a network head that takes the outputs
                of ``model`` and the input bounding boxes.
        """
        super().__init__()
        self.model = model
        self.detection_head = detection_head

    def forward(self, x: torch.Tensor, bboxes: torch.Tensor):
        """
        Args:
            x (torch.Tensor): input tensor.
            bboxes (torch.Tensor): associated bounding boxes.
                The format is N*5 (Index, X_1, Y_1, X_2, Y_2) if using RoIAlign
                and N*6 (Index, x_ctr, y_ctr, width, height, angle_degrees) if
                using RoIAlignRotated.

        Returns:
            The detection head output reshaped to two dimensions: the leading
            dimension is kept and all remaining dimensions are flattened.
        """
        head_output = self.detection_head(self.model(x), bboxes)
        leading_dim = head_output.shape[0]
        # `view` (not `reshape`) is intentional: it never copies and raises if
        # the head output cannot be flattened without a copy.
        return head_output.view(leading_dim, -1)
class MultiPathWayWithFuse(nn.Module):
    """
    Build a multi-pathway block with fusion for video recognition. Each pathway
    contains its own blocks, and an optional fusion layer combines the outputs
    across pathways.

    ::

        Pathway 1  ...  Pathway N
            ↓              ↓
         Block 1        Block N
            ↓⭠ --Fusion----↓
    """

    def __init__(
        self,
        *,
        multipathway_blocks: nn.ModuleList,
        multipathway_fusion: Optional[nn.Module],
        inplace: Optional[bool] = True,
    ) -> None:
        """
        Args:
            multipathway_blocks (nn.ModuleList): list of models from all
                pathways. An entry may be None, in which case that pathway's
                input is passed through unchanged.
            multipathway_fusion (nn.Module): fusion model, or None to skip
                fusion.
            inplace (bool): If inplace, directly update the input list without
                making a copy.
        """
        super().__init__()
        # NOTE(review): set_attributes is a project helper; it presumably stores
        # every constructor argument as an attribute of the same name, since
        # forward() reads self.multipathway_blocks, self.multipathway_fusion and
        # self.inplace. No extra locals are created before this call so the set
        # of stored attributes stays exactly as before.
        set_attributes(self, locals())

    def forward(self, x: List[torch.Tensor]) -> torch.Tensor:
        """
        Apply each pathway's block to the matching input, then fuse.

        Args:
            x (List[torch.Tensor]): one input per pathway. Must be a ``list``.

        Returns:
            The list of per-pathway outputs when no fusion module is set,
            otherwise whatever the fusion module returns for that list.
            (The ``torch.Tensor`` annotation is kept as-is for interface
            compatibility.)

        Raises:
            AssertionError: if ``x`` is not a list.
        """
        assert isinstance(
            x, list
        ), "input for MultiPathWayWithFuse needs to be a list of tensors"
        # Non-inplace mode works on a shallow copy of the inputs. Starting from
        # a copy (rather than a list of None) keeps both modes consistent: a
        # pathway without a block, or an input beyond the number of blocks,
        # passes through unchanged instead of silently becoming None.
        x_out = x if self.inplace else list(x)
        for pathway_idx, block in enumerate(self.multipathway_blocks):
            if block is not None:
                x_out[pathway_idx] = block(x[pathway_idx])
        if self.multipathway_fusion is not None:
            x_out = self.multipathway_fusion(x_out)
        return x_out
Read the Docs v: latest
Versions
latest
stable
Downloads
pdf
html
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.