# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : maojiayuan@gmail.com
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']

def _sum_ft(tensor):
    """Sum over the first and last dimensions."""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """Add new dimensions at the front and the tail."""
    return tensor.unsqueeze(0).unsqueeze(-1)

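# Shape illustration (added note): for input reshaped to (B, C, L) in forward(),
# _sum_ft reduces over dims 0 and -1, giving per-channel statistics of shape (C,);
# _unsqueeze_ft maps (C,) back to (1, C, 1) so it broadcasts against (B, C, L).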

_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
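# Message protocol (added note): each replica sends a _ChildMessage with its local
# sum, sum of squares, and element count; the master replies with a _MasterMessage
# whose two fields hold the global mean and inverse standard deviation (the first
# field is named 'sum' but actually carries the mean).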


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

    def forward(self, input):
        # If this is not a parallel computation or the module is in evaluation mode,
        # fall back to PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Reshape the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it back to the original shape.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)
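    # Note (added): vanilla nn.DataParallel never calls this hook; it relies on a
    # replication callback (this repository ships a DataParallelWithCallback wrapper
    # for this purpose) invoking __data_parallel_replicate__ once per module copy.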

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""

        # Always using the same "device order" makes the ReduceAdd operation faster.
        # Thanks to:: Tete Xiao (http://tetexiao.com/)
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs
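    # Flow illustration (added note): with two replicas, `intermediates` is
    # [(id0, _ChildMessage(sum0, ssum0, n0)), (id1, _ChildMessage(sum1, ssum1, n1))].
    # ReduceAddCoalesced accumulates sum0+sum1 and ssum0+ssum1 on the master GPU;
    # Broadcast then yields [mean@gpu0, inv_std@gpu0, mean@gpu1, inv_std@gpu1],
    # which is sliced pairwise back into one _MasterMessage per replica.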

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation from the sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data

        return mean, bias_var.clamp(self.eps) ** -0.5
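    # Derivation (added note): with m = sum/n, the line `sumvar = ssum - sum_ * mean`
    # uses the identity sum((x - m)^2) = sum(x^2) - n * m^2 = ssum - sum * m; dividing
    # by n gives the biased variance and by (n - 1) the unbiased estimate. The biased
    # variance is clamped below at eps before the inverse square root, so the returned
    # inv_std is numerically stable.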


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. These running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it is common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.randn(20, 100)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs.

    .. math::

        y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. These running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it is common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs.

    .. math::

        y = \frac{x - \mathrm{mean}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly
    the same as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. These running estimates are kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it is common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, D, H, W)`
        - Output: :math:`(N, C, D, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm3d(100, affine=False)
        >>> input = torch.randn(20, 100, 35, 45, 10)
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 5:
            raise ValueError('expected 5D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
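

if __name__ == '__main__':
    # Minimal smoke test (added for illustration; not part of the original module).
    # On a single device the forward pass falls back to F.batch_norm, so this runs
    # on CPU. Actual cross-GPU synchronization additionally requires replicating
    # the model with a callback-aware wrapper (this repository provides one,
    # DataParallelWithCallback, in its replicate module), because vanilla
    # nn.DataParallel never invokes __data_parallel_replicate__.
    bn = SynchronizedBatchNorm2d(4)
    x = torch.randn(2, 4, 8, 8)
    y = bn(x)
    assert y.shape == x.shape
    print('SynchronizedBatchNorm2d forward OK:', tuple(y.shape))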