|
- # Copyright (c) OpenMMLab. All rights reserved.
- from mmengine.utils.dl_utils import TORCH_VERSION
- from mmengine.utils.version_utils import digit_version
- from .averaged_model import (BaseAveragedModel, ExponentialMovingAverage,
- MomentumAnnealingEMA, StochasticWeightAverage)
- from .base_model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
- from .base_module import BaseModule, ModuleDict, ModuleList, Sequential
- from .test_time_aug import BaseTTAModel
- from .utils import (convert_sync_batchnorm, detect_anomalous_params,
- merge_dict, revert_sync_batchnorm, stack_batch)
- from .weight_init import (BaseInit, Caffe2XavierInit, ConstantInit,
- KaimingInit, NormalInit, PretrainedInit,
- TruncNormalInit, UniformInit, XavierInit,
- bias_init_with_prob, caffe2_xavier_init,
- constant_init, initialize, kaiming_init, normal_init,
- trunc_normal_init, uniform_init, update_init_info,
- xavier_init)
- from .wrappers import (MMDistributedDataParallel,
- MMSeparateDistributedDataParallel, is_model_wrapper)
-
# Public API re-exported by this package, one name per line. The order is
# significant only in that it is preserved exactly; an extra name may be
# appended conditionally later in this module.
__all__ = [
    'MMDistributedDataParallel',
    'is_model_wrapper',
    'BaseAveragedModel',
    'StochasticWeightAverage',
    'ExponentialMovingAverage',
    'MomentumAnnealingEMA',
    'BaseModel',
    'BaseDataPreprocessor',
    'ImgDataPreprocessor',
    'MMSeparateDistributedDataParallel',
    'BaseModule',
    'stack_batch',
    'merge_dict',
    'detect_anomalous_params',
    'ModuleList',
    'ModuleDict',
    'Sequential',
    'revert_sync_batchnorm',
    'update_init_info',
    'constant_init',
    'xavier_init',
    'normal_init',
    'trunc_normal_init',
    'uniform_init',
    'kaiming_init',
    'caffe2_xavier_init',
    'bias_init_with_prob',
    'BaseInit',
    'ConstantInit',
    'XavierInit',
    'NormalInit',
    'TruncNormalInit',
    'UniformInit',
    'KaimingInit',
    'Caffe2XavierInit',
    'PretrainedInit',
    'initialize',
    'convert_sync_batchnorm',
    'BaseTTAModel',
]
-
# Export ``MMFullyShardedDataParallel`` only when the installed PyTorch is at
# least 2.0.0. On older versions the name is neither imported into this
# package nor listed in ``__all__``.
# NOTE(review): presumably the wrapper depends on FSDP APIs that need
# torch >= 2.0 -- confirm against ``.wrappers``.
if digit_version(TORCH_VERSION) >= digit_version('2.0.0'):
    # Both statements must stay inside the guard. If they were dedented, the
    # import would run unconditionally and fail on older torch.
    from .wrappers import MMFullyShardedDataParallel  # noqa:F401
    __all__.append('MMFullyShardedDataParallel')
|