
Source code for pytorchvideo.models.byol

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import copy
from typing import Callable, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

class BYOL(nn.Module):
    """
    Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning.

    Details can be found in:
    https://arxiv.org/abs/2006.07733

    The module holds two copies of the (backbone + optional projector) stack:
    ``backbone``, which is trained by gradient descent together with
    ``predictor``, and ``backbone_mmt``, a frozen deep copy that is updated
    only as an exponential moving average of ``backbone``'s parameters.
    """

    def __init__(
        self,
        backbone: nn.Module,
        projector: Optional[nn.Module] = None,
        predictor: Optional[nn.Module] = None,
        feature_dim: int = 2048,
        predictor_inner: int = 4096,
        mmt: float = 0.99,
        norm: Callable = nn.SyncBatchNorm,
    ) -> None:
        """
        Args:
            backbone (nn.Module): backbone for byol, input shape depends on the
                forward input size. Standard inputs include `B x C`,
                `B x C x H x W`, and `B x C x T x H x W`.
            projector (nn.Module): standard projector is a mlp with 2 to 3
                hidden layers, with (synchronized) BatchNorm and ReLU
                activation. When given, it is chained after ``backbone``.
            predictor (nn.Module): predictor MLP of BYOL of similar structure
                as the projector MLP. When ``None``, a default
                Linear -> norm -> ReLU -> Linear predictor is built.
            feature_dim (int): output feature dimension.
            predictor_inner (int): inner channel size for the default
                predictor.
            mmt (float): momentum update ratio for the momentum backbone.
            norm (callable): normalization to be used in the default
                predictor, default is synchronized batchnorm.
        """
        super().__init__()
        torch._C._log_api_usage_once("PYTORCHVIDEO.model.BYOL.__init__")

        self.mmt = mmt
        self.feature_dim = feature_dim

        # The projector is folded into the backbone so that the online and
        # momentum branches share exactly the same structure.
        if projector is not None:
            backbone = nn.Sequential(
                backbone,
                projector,
            )
        self.backbone = backbone

        # The momentum branch starts as an exact copy and never receives
        # gradients; it is only changed by `_momentum_update_backbone`.
        self.backbone_mmt = copy.deepcopy(backbone)
        for p in self.backbone_mmt.parameters():
            p.requires_grad = False

        if predictor is None:
            self.predictor = nn.Sequential(
                nn.Linear(feature_dim, predictor_inner, bias=False),
                norm(predictor_inner),
                nn.ReLU(inplace=True),
                nn.Linear(predictor_inner, feature_dim, bias=True),
            )
        else:
            self.predictor = predictor

    def sim_loss(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """
        Similarity loss for byol.

        Args:
            q and k (torch.Tensor): inputs to calculate the similarity,
                expected to have the same shape of `N x C`.

        Returns:
            torch.Tensor: scalar, the negated mean over `N` of the row-wise
            dot products of ``q`` and ``k``.
        """
        similarity = torch.einsum("nc,nc->n", q, k)
        return -similarity.mean()

    def update_mmt(self, mmt: float) -> None:
        """
        Update the momentum. This function can be used to perform momentum
        annealing.

        Args:
            mmt (float): update the momentum.
        """
        self.mmt = mmt

    def get_mmt(self) -> float:
        """
        Get the momentum. This function can be used to perform momentum
        annealing.
        """
        return self.mmt

    @torch.no_grad()
    def _momentum_update_backbone(self) -> None:
        """
        Momentum update on the backbone.

        Each parameter of ``backbone_mmt`` becomes
        ``mmt * param_mmt + (1 - mmt) * param`` where ``param`` is the
        matching parameter of ``backbone``.
        """
        for param, param_mmt in zip(
            self.backbone.parameters(), self.backbone_mmt.parameters()
        ):
            param_mmt.data = param_mmt.data * self.mmt + param.data * (
                1.0 - self.mmt
            )

    @torch.no_grad()
    def forward_backbone_mmt(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward momentum backbone.

        Args:
            x (tensor): input to be forwarded.

        Returns:
            torch.Tensor: output of ``backbone_mmt`` L2-normalized along
            dim 1, computed without gradient tracking.
        """
        # The decorator already disables autograd for the whole method.
        proj = self.backbone_mmt(x)
        return F.normalize(proj, dim=1)

    def forward_backbone(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward backbone.

        Args:
            x (tensor): input to be forwarded.

        Returns:
            torch.Tensor: output of ``predictor(backbone(x))`` L2-normalized
            along dim 1.
        """
        proj = self.backbone(x)
        pred = self.predictor(proj)
        return F.normalize(pred, dim=1)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x1 (torch.tensor): a batch of image with augmentation. The input
                tensor shape should able to be feed into the backbone.
            x2 (torch.tensor): the same batch of image with different
                augmentation. The input tensor shape should able to be feed
                into the backbone.

        Returns:
            torch.Tensor: scalar loss, the average of the two cross-view
            similarity losses.
        """
        pred_1 = self.forward_backbone(x1)
        pred_2 = self.forward_backbone(x2)

        with torch.no_grad():
            # The momentum branch is refreshed before it produces targets.
            self._momentum_update_backbone()
            proj_mmt_1 = self.forward_backbone_mmt(x1)
            proj_mmt_2 = self.forward_backbone_mmt(x2)

        # Symmetrized loss: each view's prediction is matched against the
        # momentum projection of the other view.
        loss = (
            self.sim_loss(pred_1, proj_mmt_2) + self.sim_loss(pred_2, proj_mmt_1)
        ) / 2
        return loss
Read the Docs v: stable
On Read the Docs
Project Home

Free document hosting provided by Read the Docs.