Source code for pytorchvideo.models.simclr
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from fvcore.nn.distributed import differentiable_all_gather
from pytorchvideo.layers.utils import set_attributes
[docs]class SimCLR(nn.Module):
"""
A Simple Framework for Contrastive Learning of Visual Representations
Details can be found from:
https://arxiv.org/abs/2002.05709
"""
def __init__(
self,
mlp: nn.Module,
backbone: Optional[nn.Module] = None,
temperature: float = 0.07,
) -> None:
super().__init__()
torch._C._log_api_usage_once("PYTORCHVIDEO.model.SimCLR.__init__")
set_attributes(self, locals())
[docs] def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
"""
Args:
x1 (torch.tensor): a batch of image with augmentation. The input tensor
shape should able to be feed into the backbone.
x2 (torch.tensor): the size batch of image with different augmentation. The
input tensor shape should able to be feed into the backbone.
"""
if self.backbone is not None:
x1 = self.backbone(x1)
x1 = self.mlp(x1)
x1 = F.normalize(x1, p=2, dim=1)
if self.backbone is not None:
x2 = self.backbone(x2)
x2 = self.mlp(x2)
x2 = F.normalize(x2, p=2, dim=1)
x2 = torch.cat(differentiable_all_gather(x2), dim=0)
prod = torch.einsum("nc,kc->nk", [x1, x2])
prod = prod.div(self.temperature)
batch_size = x1.size(0)
if dist.is_available() and dist.is_initialized():
device_ind = dist.get_rank()
else:
device_ind = 0
gt = (
torch.tensor(
list(range(device_ind * batch_size, (device_ind + 1) * batch_size))
)
.long()
.to(x1.device)
)
loss = torch.nn.functional.cross_entropy(prod, gt)
return loss