
Source code for pytorchvideo.layers.batch_norm

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.

import torch
import torch.distributed as dist
from fvcore.nn.distributed import differentiable_all_reduce
from pytorchvideo.layers.distributed import get_world_size
from torch import nn


class NaiveSyncBatchNorm1d(nn.BatchNorm1d):
    """
    An implementation of 1D naive sync batch normalization. See details in
    NaiveSyncBatchNorm2d below.
    """

    def forward(self, input):
        if get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        mean = torch.mean(input, dim=[0, 2])
        meansqr = torch.mean(input * input, dim=[0, 2])

        assert B > 0, "SyncBatchNorm does not support zero batch size."

        vec = torch.cat([mean, meansqr], dim=0)
        vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1)
        bias = bias.reshape(1, -1, 1)

        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)
        return input * scale + bias

class NaiveSyncBatchNorm2d(nn.BatchNorm2d):
    """
    An implementation of 2D naive sync batch normalization. In PyTorch<=1.5,
    ``nn.SyncBatchNorm`` has incorrect gradients when the batch size on each
    worker is different (e.g., when scale augmentation is used, or when it is
    applied to a mask head).

    This is a slower but correct alternative to ``nn.SyncBatchNorm``.

    Note:
        This module computes overall statistics by averaging the statistics of
        each worker with equal weight. The result matches the true statistics
        of all samples (as if they were all on one worker) only when all
        workers have the same (N, H, W). This mode does not support inputs
        with zero batch size.
    """

    def forward(self, input):
        if get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        assert B > 0, "SyncBatchNorm does not support zero batch size."

        vec = torch.cat([mean, meansqr], dim=0)
        vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)
        return input * scale + bias

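The note in the docstring above can be made concrete with a small single-process sketch: averaging per-worker channel means with equal weight reproduces the pooled mean only when every worker contributes the same number of elements. The shapes below are made up for illustration and no distributed setup is required.

# Hypothetical illustration of the equal-weight averaging caveat; the tensors
# and shapes are invented for this example and are not part of the module.
import torch

worker_a = torch.randn(4, 3, 8, 8)  # N=4 samples on worker A
worker_b = torch.randn(2, 3, 8, 8)  # N=2 samples on worker B

# Equal-weight average of per-worker channel means (what the naive sync computes).
naive_mean = (worker_a.mean(dim=[0, 2, 3]) + worker_b.mean(dim=[0, 2, 3])) / 2

# True channel mean over all samples pooled together.
true_mean = torch.cat([worker_a, worker_b], dim=0).mean(dim=[0, 2, 3])

# The two generally differ because the per-worker batch sizes differ.
print(torch.allclose(naive_mean, true_mean))  # typically False here
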
class NaiveSyncBatchNorm3d(nn.BatchNorm3d):
    """
    An implementation of 3D naive sync batch normalization. See details in
    NaiveSyncBatchNorm2d above.
    """

    def forward(self, input):
        if get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        mean = torch.mean(input, dim=[0, 2, 3, 4])
        meansqr = torch.mean(input * input, dim=[0, 2, 3, 4])

        assert B > 0, "SyncBatchNorm does not support zero batch size."

        vec = torch.cat([mean, meansqr], dim=0)
        vec = differentiable_all_reduce(vec) * (1.0 / dist.get_world_size())

        mean, meansqr = torch.split(vec, C)
        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1, 1)

        self.running_mean += self.momentum * (mean.detach() - self.running_mean)
        self.running_var += self.momentum * (var.detach() - self.running_var)
        return input * scale + bias
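
A minimal usage sketch follows, assuming only the module path shown in the page title; the surrounding layers and input shape are made up. The classes take the same constructor arguments as their torch.nn counterparts, and when the world size is 1 (or in eval mode) forward falls back to the standard BatchNorm path.

# Hypothetical usage: drop-in replacement for nn.BatchNorm3d in a video model.
import torch
from torch import nn
from pytorchvideo.layers.batch_norm import NaiveSyncBatchNorm3d

model = nn.Sequential(
    nn.Conv3d(3, 16, kernel_size=3, padding=1),
    NaiveSyncBatchNorm3d(16),  # constructed exactly like nn.BatchNorm3d(16)
    nn.ReLU(),
)

clip = torch.randn(2, 3, 8, 32, 32)  # (N, C, T, H, W), made-up shape
out = model(clip)
print(out.shape)  # torch.Size([2, 16, 8, 32, 32])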