|
- # from .basics import *
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore.common.initializer import Normal, initializer
-
-
- # import pickle
- # import os
- # import codecs
-
class Bitparm(nn.Cell):
    """One learnable per-channel transform layer, stacked by ``BitEstimator``.

    The layer holds per-channel parameters ``h`` and ``b``, plus ``a`` unless
    ``final`` is True. Each has shape ``[1, channel, 1, 1]`` and is initialised
    from ``Normal(mean=0, sigma=0.01)``.

    Args:
        channel: Number of channels, i.e. the size of axis 1 of every parameter.
        final: If True, the layer ends with a sigmoid and has no ``a`` parameter.
    """

    def __init__(self, channel, final=False):
        super(Bitparm, self).__init__()
        self.final = final
        self.h = ms.Parameter(initializer(Normal(mean=0, sigma=0.01), [1, channel, 1, 1]))
        self.b = ms.Parameter(initializer(Normal(mean=0, sigma=0.01), [1, channel, 1, 1]))
        if not final:
            self.a = ms.Parameter(initializer(Normal(mean=0, sigma=0.01), [1, channel, 1, 1]))
        else:
            # The final layer has no residual tanh term, so it needs no ``a``.
            self.a = None
        # Create the primitive ops once here, not on every forward call.
        # They hold no parameters, so checkpoint contents are unaffected.
        self.softplus = ms.ops.Softplus()
        self.sigmoid = ms.ops.Sigmoid()
        self.tanh = ms.ops.Tanh()

    def construct(self, x):
        """Apply the per-channel transform to ``x``.

        The parameters have shape ``[1, channel, 1, 1]`` and broadcast against
        ``x``. NOTE(review): this presumably expects ``x`` shaped
        (N, channel, H, W); confirm with callers.

        Returns:
            With ``y = x * softplus(h) + b``: ``sigmoid(y)`` when ``final`` is
            True, otherwise ``y + tanh(y) * tanh(a)``.
        """
        # softplus keeps the per-channel scale strictly positive.
        y = x * self.softplus(self.h) + self.b
        if self.final:
            return self.sigmoid(y)
        return y + self.tanh(y) * self.tanh(self.a)
-
-
class BitEstimator(nn.Cell):
    """Stack of four ``Bitparm`` layers applied in sequence.

    The first three layers are non-final and the fourth is final, so it ends
    with a sigmoid. The output therefore lies in (0, 1) elementwise.
    NOTE(review): this is presumably used as a learned per-channel cumulative
    distribution for bit-rate estimation; confirm with callers.

    Args:
        channel: Number of channels passed to each ``Bitparm`` layer.
    """

    def __init__(self, channel):
        super(BitEstimator, self).__init__()
        # Keep the explicit f1..f4 attribute names: they determine the
        # parameter names stored in checkpoints.
        self.f1 = Bitparm(channel)
        self.f2 = Bitparm(channel)
        self.f3 = Bitparm(channel)
        self.f4 = Bitparm(channel, True)

    def construct(self, x):
        """Apply f1, f2, f3 and then the final sigmoid layer f4 to ``x``."""
        x = self.f1(x)
        x = self.f2(x)
        x = self.f3(x)
        return self.f4(x)
|