# Copyright (c) OpenMMLab. All rights reserved.
import logging
import warnings
from typing import List, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from mmengine.logging import print_log
from mmengine.utils.dl_utils import mmcv_full_available


def stack_batch(tensor_list: List[torch.Tensor],
                pad_size_divisor: int = 1,
                pad_value: Union[int, float] = 0) -> torch.Tensor:
- """Stack multiple tensors to form a batch and pad the tensor to the max
- shape use the right bottom padding mode in these images. If
- ``pad_size_divisor > 0``, add padding to ensure the shape of each dim is
- divisible by ``pad_size_divisor``.
-
- Args:
- tensor_list (List[Tensor]): A list of tensors with the same dim.
- pad_size_divisor (int): If ``pad_size_divisor > 0``, add padding
- to ensure the shape of each dim is divisible by
- ``pad_size_divisor``. This depends on the model, and many
- models need to be divisible by 32. Defaults to 1
- pad_value (int, float): The padding value. Defaults to 0.
-
- Returns:
- Tensor: The n dim tensor.
- """
    assert isinstance(
        tensor_list,
        list), (f'Expected input type to be list, but got {type(tensor_list)}')
    assert tensor_list, '`tensor_list` must not be an empty list'
    assert len({
        tensor.ndim
        for tensor in tensor_list
    }) == 1, (f'Expected the dimensions of all tensors to be the same, '
              f'but got {[tensor.ndim for tensor in tensor_list]}')
    dim = tensor_list[0].dim()
    num_img = len(tensor_list)
    all_sizes: torch.Tensor = torch.Tensor(
        [tensor.shape for tensor in tensor_list])
    max_sizes = torch.ceil(
        torch.max(all_sizes, dim=0)[0] / pad_size_divisor) * pad_size_divisor
    padded_sizes = max_sizes - all_sizes
    # The first dim normally means channel, which should not be padded.
    padded_sizes[:, 0] = 0
    if padded_sizes.sum() == 0:
        return torch.stack(tensor_list)
    # `pad` is the second argument of `F.pad`. If pad is (1, 2, 3, 4),
    # it means padding the last dim with 1 (left) and 2 (right), and padding
    # the penultimate dim with 3 (top) and 4 (bottom). The order of `pad` is
    # opposite to that of `padded_sizes`, so `padded_sizes` needs to be
    # reversed, and only the odd indices of `pad` should be assigned to keep
    # the padding "right" and "bottom".
    pad = torch.zeros(num_img, 2 * dim, dtype=torch.int)
    pad[:, 1::2] = padded_sizes[:, range(dim - 1, -1, -1)]
    batch_tensor = []
    for idx, tensor in enumerate(tensor_list):
        batch_tensor.append(
            F.pad(tensor, tuple(pad[idx].tolist()), value=pad_value))
    return torch.stack(batch_tensor)
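
# A minimal usage sketch (illustrative only, not part of the public API):
# two CHW tensors with different spatial sizes are padded on the right and
# bottom to a common shape divisible by 32, then stacked:
#
#   >>> imgs = [torch.rand(3, 20, 28), torch.rand(3, 24, 24)]
#   >>> batch = stack_batch(imgs, pad_size_divisor=32, pad_value=0)
#   >>> batch.shape
#   torch.Size([2, 3, 32, 32])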


def detect_anomalous_params(loss: torch.Tensor, model) -> None:
    """Traverse the autograd graph of ``loss`` and log every model parameter
    that requires gradients but does not appear in the graph.

    Args:
        loss (torch.Tensor): The loss whose autograd graph is traversed.
        model (nn.Module): The model whose parameters are checked.
    """
    parameters_in_graph = set()
    visited = set()

    def traverse(grad_fn):
        if grad_fn is None:
            return
        if grad_fn not in visited:
            visited.add(grad_fn)
            if hasattr(grad_fn, 'variable'):
                parameters_in_graph.add(grad_fn.variable)
            parents = grad_fn.next_functions
            if parents is not None:
                for parent in parents:
                    grad_fn = parent[0]
                    traverse(grad_fn)

    traverse(loss.grad_fn)
    for n, p in model.named_parameters():
        if p not in parameters_in_graph and p.requires_grad:
            print_log(
                f'{n} with shape {p.size()} is not '
                f'in the computational graph \n',
                logger='current',
                level=logging.ERROR)
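
# Illustrative sketch (the toy module below is hypothetical): a parameter
# that never contributes to the loss is reported as anomalous:
#
#   >>> class ToyModel(nn.Module):
#   ...     def __init__(self):
#   ...         super().__init__()
#   ...         self.used = nn.Linear(2, 2)
#   ...         self.unused = nn.Linear(2, 2)
#   ...     def forward(self, x):
#   ...         return self.used(x)
#   >>> model = ToyModel()
#   >>> loss = model(torch.rand(1, 2)).sum()
#   >>> detect_anomalous_params(loss, model)  # logs `unused.*` as missing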


def merge_dict(*args):
    """Merge all dictionaries into one dictionary.

    If the PyTorch version >= 1.8, ``merge_dict`` will be wrapped by
    ``torch.fx.wrap``, which makes ``torch.fx.symbolic_trace`` skip tracing
    ``merge_dict``.

    Note:
        If a function needs to be traced by ``torch.fx.symbolic_trace`` but
        inevitably needs to call the ``update`` method of ``dict`` (which is
        not traceable), it should use ``merge_dict`` instead of
        ``xxx.update``.

    Args:
        *args: The dictionaries to be merged.

    Returns:
        dict: The merged dict.
    """
    output = dict()
    for item in args:
        assert isinstance(
            item,
            dict), (f'All arguments of merge_dict should be dicts, but got '
                    f'{type(item)}')
        output.update(item)
    return output
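
# A minimal sketch of ``merge_dict``: later dicts override earlier keys,
# just like successive ``dict.update`` calls:
#
#   >>> merge_dict(dict(loss_cls=0.2), dict(loss_bbox=0.1))
#   {'loss_cls': 0.2, 'loss_bbox': 0.1}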


# torch.fx is only available when the PyTorch version >= 1.8.
# If a subclass of `BaseModel` has multiple submodules and each submodule
# returns a loss dict during training, e.g., `TwoStageDetector` in mmdet,
# it should use `merge_dict` rather than `loss.update` to get the total
# loss, so that the model stays traceable.
try:
    import torch.fx

    # make torch.fx skip tracing `merge_dict`.
    merge_dict = torch.fx.wrap(merge_dict)

except ImportError:
    warnings.warn('Cannot import torch.fx, `merge_dict` is a simple function '
                  'to merge multiple dicts')
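
# A minimal sketch (the toy module is hypothetical) of why the wrapping
# above matters: `merge_dict` is kept as a single leaf call in the traced
# graph instead of being traced through, so the dict merging does not break
# symbolic tracing:
#
#   >>> class ToyLosses(nn.Module):
#   ...     def forward(self, x):
#   ...         return merge_dict(dict(loss_a=x.sum()), dict(loss_b=x.mean()))
#   >>> traced = torch.fx.symbolic_trace(ToyLosses())
#   >>> # `merge_dict` appears as a single `call_function` node in the graph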


class _BatchNormXd(nn.modules.batchnorm._BatchNorm):
    """A general BatchNorm layer without input dimension check.

    Reproduced from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)
    The only difference between BatchNorm1d, BatchNorm2d, BatchNorm3d, etc.
    is ``_check_input_dim``, which is designed for tensor sanity checks.
    The check has been bypassed in this class for the convenience of
    converting SyncBatchNorm layers.
    """

    def _check_input_dim(self, input: torch.Tensor):
        return


def revert_sync_batchnorm(module: nn.Module) -> nn.Module:
    """Helper function to convert all `SyncBatchNorm` (SyncBN) and
    `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN) layers in the model to
    `BatchNormXd` layers.

    Adapted from @kapily's work:
    (https://github.com/pytorch/pytorch/issues/41081#issuecomment-783961547)

    Args:
        module (nn.Module): The module containing `SyncBatchNorm` layers.

    Returns:
        module_output: The converted module with `BatchNormXd` layers.
    """
    module_output = module
    module_checklist = [torch.nn.modules.batchnorm.SyncBatchNorm]

    if mmcv_full_available():
        from mmcv.ops import SyncBatchNorm
        module_checklist.append(SyncBatchNorm)

    if isinstance(module, tuple(module_checklist)):
        module_output = _BatchNormXd(module.num_features, module.eps,
                                     module.momentum, module.affine,
                                     module.track_running_stats)
        if module.affine:
            # no_grad() may not be needed here, but it is kept for
            # consistency with `convert_sync_batchnorm()`.
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        module_output.training = module.training
        # qconfig exists in quantized models
        if hasattr(module, 'qconfig'):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        # Some custom or third-party modules may raise an error when calling
        # `add_module`. Therefore, catch the error without raising it. See
        # https://github.com/open-mmlab/mmengine/issues/638 for more details.
        try:
            module_output.add_module(name, revert_sync_batchnorm(child))
        except Exception:
            print_log(
                f'Failed to convert {child} from SyncBN to BN!',
                logger='current',
                level=logging.WARNING)
    del module
    return module_output
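
# A minimal usage sketch: reverting SyncBN lets a model built with
# `SyncBatchNorm` run without a distributed process group (e.g. on CPU):
#
#   >>> model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.SyncBatchNorm(8))
#   >>> model = revert_sync_batchnorm(model)
#   >>> out = model(torch.rand(1, 3, 8, 8))  # no process group required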


def convert_sync_batchnorm(module: nn.Module,
                           implementation='torch') -> nn.Module:
    """Helper function to convert all `BatchNorm` layers in the model to
    `SyncBatchNorm` (SyncBN) or `mmcv.ops.sync_bn.SyncBatchNorm` (MMSyncBN)
    layers. Adapted from `PyTorch convert sync batchnorm`_.

    Args:
        module (nn.Module): The module containing `BatchNorm` layers.
        implementation (str): The type of `SyncBatchNorm` to convert to.

            - 'torch': convert to `torch.nn.modules.batchnorm.SyncBatchNorm`.
            - 'mmcv': convert to `mmcv.ops.sync_bn.SyncBatchNorm`.

    Returns:
        nn.Module: The converted module with `SyncBatchNorm` layers.

    .. _PyTorch convert sync batchnorm:
       https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#torch.nn.SyncBatchNorm.convert_sync_batchnorm
    """  # noqa: E501
    module_output = module

    if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
        if implementation == 'torch':
            SyncBatchNorm = torch.nn.modules.batchnorm.SyncBatchNorm
        elif implementation == 'mmcv':
            from mmcv.ops import SyncBatchNorm  # type: ignore
        else:
            raise ValueError(
                '`implementation` should be "torch" or "mmcv", but got '
                f'{implementation}')

        module_output = SyncBatchNorm(module.num_features, module.eps,
                                      module.momentum, module.affine,
                                      module.track_running_stats)

        if module.affine:
            with torch.no_grad():
                module_output.weight = module.weight
                module_output.bias = module.bias
        module_output.running_mean = module.running_mean
        module_output.running_var = module.running_var
        module_output.num_batches_tracked = module.num_batches_tracked
        if hasattr(module, 'qconfig'):
            module_output.qconfig = module.qconfig
    for name, child in module.named_children():
        module_output.add_module(
            name, convert_sync_batchnorm(child, implementation))
    del module
    return module_output
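
# A minimal usage sketch: convert a model to SyncBN before distributed
# training (running the converted model requires an initialized process
# group):
#
#   >>> model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
#   >>> model = convert_sync_batchnorm(model, implementation='torch')
#   >>> type(model[1])
#   <class 'torch.nn.modules.batchnorm.SyncBatchNorm'>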