|
- from typing import Optional, Callable
- import math
- import warnings
- from itertools import repeat
- import collections.abc
- import numpy as np
- from scipy import special
-
- import mindspore as ms
- import mindspore.nn as nn
- import mindspore.ops as ops
- from mindspore.ops import functional as F
- from mindspore.common import initializer as init
- from mindspore.common.initializer import initializer, HeNormal
-
- __all__ = ['DropPath', 'trunc_normal_']
-
- # ## Auxiliary Modules
class DropPath(nn.Cell):
    """
    Drop paths (Stochastic Depth) per sample (when applied in the main path
    of residual blocks).

    During training, each sample in the batch is zeroed out with probability
    ``1 - keep_prob``; surviving samples are rescaled by ``1 / keep_prob`` so
    the expected activation is unchanged. In eval mode, or when dropping is
    disabled, the input passes through untouched.

    Args:
        keep_prob (float, optional): probability of keeping a sample's path.
            ``None`` or ``1`` makes this layer an identity.
        seed (int): seed for the uniform sampler. Negative values are clamped
            to 0 because MindSpore's UniformReal requires a non-negative seed.
    """

    def __init__(self, keep_prob=None, seed=0):
        super().__init__()
        self.keep_prob = keep_prob
        # BUGFIX: was `min(seed, 0)`, which discarded every positive seed
        # (always yielding 0) and let illegal negative seeds through to
        # ops.UniformReal. Clamp from below instead.
        seed = max(seed, 0)
        self.rand = ops.UniformReal(seed=seed)
        self.shape = ops.Shape()
        self.floor = ops.Floor()

    def construct(self, x):
        # Identity in eval mode or when dropping is disabled.
        # BUGFIX: also guard keep_prob is None — the default constructor value
        # previously crashed here (`random_tensor + None`) during training.
        if not self.training or self.keep_prob is None or self.keep_prob == 1:
            return x

        x_shape = self.shape(x)
        # One Bernoulli draw per sample: mask shape is (batch, 1, 1, ...)
        # so it broadcasts over all non-batch dimensions.
        mask_shape = (x_shape[0],) + (1,) * (len(x_shape) - 1)
        random_tensor = self.rand(mask_shape)
        # floor(uniform[0,1) + keep_prob) is 1 with probability keep_prob.
        random_tensor = self.floor(random_tensor + self.keep_prob)
        # Rescale survivors to preserve the activation's expectation.
        x = x / self.keep_prob
        x = x * random_tensor

        return x
-
-
def _trunc_normal_(tensor, mean, std, a, b, seed=2022):
    """Sample a tensor-shaped array from a normal distribution truncated to
    ``[a, b]``, using inverse-CDF transform sampling.

    Adapted from the PyTorch reference implementation; method based on
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    NOTE(review): the ``seed`` parameter is currently unused — sampling draws
    from NumPy's global RNG state. Only ``tensor.shape`` is read from the
    input; its values are ignored.
    """

    def norm_cdf(v):
        # CDF of the standard normal distribution.
        return (1. + math.erf(v / math.sqrt(2.))) / 2.

    if mean < a - 2 * std or mean > b + 2 * std:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    # Map the truncation bounds through the standard-normal CDF.
    lower = norm_cdf((a - mean) / std)
    upper = norm_cdf((b - mean) / std)

    # Draw uniformly on [2*lower - 1, 2*upper - 1]; erfinv maps that interval
    # back onto the truncated standard-normal range.
    samples = np.random.uniform(2 * lower - 1, 2 * upper - 1, tensor.shape)
    samples = special.erfinv(samples)

    # Scale and shift to N(mean, std); convert to a float32 MindSpore tensor.
    out = ms.Tensor(samples * std * math.sqrt(2.) + mean, ms.float32)

    # Clamp to [a, b] — also guards against erfinv's +/-inf at the interval
    # endpoints.
    low = ms.Tensor(a, ms.float32)
    high = ms.Tensor(b, ms.float32)

    return ops.clip_by_value(out, low, high)
-
-
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    r"""Fill ``tensor`` with values drawn from a truncated normal distribution.

    Values follow :math:`\mathcal{N}(\text{mean}, \text{std}^2)` restricted
    to :math:`[a, b]` via inverse-CDF sampling (see ``_trunc_normal_``).

    Args:
        tensor: an n-dimensional MindSpore tensor
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value

    Returns:
        A new float32 tensor with the sampled values.
    """
    return _trunc_normal_(tensor, mean, std, a, b)
-
if __name__ == "__main__":
    # Smoke test: draw a truncated-normal tensor shaped like a random 3x5
    # input and show both side by side.
    sample = np.random.rand(3, 5)
    sample = ms.Tensor(sample, ms.float32)
    filled = trunc_normal_(sample)
    print(sample, filled)
|