# Source code for pytorchvideo.models.byol
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import copy
from typing import Callable, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
class BYOL(nn.Module):
    """
    Bootstrap Your Own Latent: A New Approach to Self-Supervised Learning.
    Details can be found in:
    https://arxiv.org/pdf/2006.07733.pdf

    The module holds an online network (``backbone``, optionally followed by
    ``projector``, then ``predictor``) and a momentum (target) copy of the
    backbone/projector stack. The momentum copy is frozen for the optimizer and
    is instead updated as an exponential moving average of the online weights.
    """

    def __init__(
        self,
        backbone: nn.Module,
        projector: Optional[nn.Module] = None,
        predictor: Optional[nn.Module] = None,
        feature_dim: int = 2048,
        predictor_inner: int = 4096,
        mmt: float = 0.99,
        norm: Callable = nn.SyncBatchNorm,
    ) -> None:
        """
        Args:
            backbone (nn.Module): backbone for byol, input shape depends on the
                forward input size. Standard inputs include `B x C`,
                `B x C x H x W`, and `B x C x T x H x W`.
            projector (nn.Module): the standard projector is an MLP with 2 to 3
                hidden layers, with (synchronized) BatchNorm and ReLU
                activation. If given, it is appended after ``backbone`` and the
                combined stack is what gets momentum-copied.
            predictor (nn.Module): predictor MLP of BYOL, with a structure
                similar to the projector MLP. If None, a default
                ``Linear -> norm -> ReLU -> Linear`` MLP mapping
                ``feature_dim -> predictor_inner -> feature_dim`` is built.
            feature_dim (int): output feature dimension of the backbone
                (plus projector, if any); used as the input and output size of
                the default predictor.
            predictor_inner (int): inner channel size for the default predictor.
            mmt (float): momentum ratio for the momentum backbone; each update
                sets ``param_mmt = mmt * param_mmt + (1 - mmt) * param``.
            norm (callable): normalization layer factory used in the default
                predictor; default is synchronized batchnorm.
        """
        super().__init__()
        torch._C._log_api_usage_once("PYTORCHVIDEO.model.BYOL.__init__")
        self.mmt = mmt
        self.feature_dim = feature_dim
        if projector is not None:
            backbone = nn.Sequential(
                backbone,
                projector,
            )
        self.backbone = backbone
        # Target network: a deep copy frozen for autograd; it is only changed
        # by `_momentum_update_backbone`, never by the optimizer.
        self.backbone_mmt = copy.deepcopy(backbone)
        for p in self.backbone_mmt.parameters():
            p.requires_grad = False
        if predictor is None:
            self.predictor = nn.Sequential(
                nn.Linear(feature_dim, predictor_inner, bias=False),
                norm(predictor_inner),
                nn.ReLU(inplace=True),
                nn.Linear(predictor_inner, feature_dim, bias=True),
            )
        else:
            self.predictor = predictor

    def sim_loss(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        """
        Similarity loss for byol: the negative mean of the row-wise dot
        products of ``q`` and ``k``. When both inputs are L2-normalized along
        dim 1 (as returned by ``forward_backbone`` and
        ``forward_backbone_mmt``), this is the negative mean cosine similarity.

        Args:
            q (torch.Tensor): 2-D tensor of shape `N x C`.
            k (torch.Tensor): 2-D tensor with the same shape as ``q``.

        Returns:
            torch.Tensor: scalar loss.
        """
        # The explicit "nc,nc->n" subscript also rejects inputs that are not 2-D.
        similarity = torch.einsum("nc,nc->n", q, k)
        loss = -similarity.mean()
        return loss

    def update_mmt(self, mmt: float):
        """
        Update the momentum. This function can be used to perform momentum annealing.

        Args:
            mmt (float): the new momentum ratio.
        """
        self.mmt = mmt

    def get_mmt(self) -> float:
        """
        Get the momentum. This function can be used to perform momentum annealing.
        """
        return self.mmt

    @torch.no_grad()
    def _momentum_update_backbone(self) -> None:
        """
        Momentum (EMA) update of the momentum backbone, applied in place:
        ``param_mmt = mmt * param_mmt + (1 - mmt) * param``.
        """
        for param, param_mmt in zip(
            self.backbone.parameters(), self.backbone_mmt.parameters()
        ):
            # In-place update: reuses the existing storage instead of
            # allocating a new tensor and rebinding `.data` on every step.
            param_mmt.mul_(self.mmt).add_(param.detach(), alpha=1.0 - self.mmt)

    @torch.no_grad()
    def forward_backbone_mmt(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward momentum backbone (no autograd graph is recorded).

        Args:
            x (tensor): input to be forwarded.

        Returns:
            torch.Tensor: momentum projection, L2-normalized along dim 1.
        """
        proj = self.backbone_mmt(x)
        return F.normalize(proj, dim=1)

    def forward_backbone(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward backbone followed by the predictor.

        Args:
            x (tensor): input to be forwarded.

        Returns:
            torch.Tensor: prediction, L2-normalized along dim 1.
        """
        proj = self.backbone(x)
        pred = self.predictor(proj)
        return F.normalize(pred, dim=1)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x1 (torch.tensor): a batch of images with one augmentation. The
                input tensor shape should be able to be fed into the backbone.
            x2 (torch.tensor): the same batch of images with a different
                augmentation. The input tensor shape should be able to be fed
                into the backbone.

        Returns:
            torch.Tensor: scalar symmetric BYOL loss.
        """
        pred_1 = self.forward_backbone(x1)
        pred_2 = self.forward_backbone(x2)
        # The target network is EMA-updated before it produces the targets; no
        # graph is built for either step.
        with torch.no_grad():
            self._momentum_update_backbone()
            proj_mmt_1 = self.forward_backbone_mmt(x1)
            proj_mmt_2 = self.forward_backbone_mmt(x2)
        # Symmetrized: each view's prediction regresses the other view's target.
        loss = (
            self.sim_loss(pred_1, proj_mmt_2) + self.sim_loss(pred_2, proj_mmt_1)
        ) / 2
        return loss