|
class ConvTasNet(nn.Cell):
    """Conv-TasNet: end-to-end time-domain single-channel source separation.

    Pipeline: Encoder (1-D conv front end) -> TemporalConvNet separator
    (per-speaker mask estimation) -> Decoder (mask application plus
    overlap-and-add resynthesis).
    """

    def __init__(self, N, L, B, H, P, X, R, C, norm_type="gLN", causal=False,
                 mask_nonlinear='relu'):
        """
        Args:
            N: Number of filters in autoencoder
            L: Length of the filters (in samples)
            B: Number of channels in bottleneck 1 x 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: use which non-linear function to generate mask
        """
        super(ConvTasNet, self).__init__()
        # Hyper-parameters
        self.N = N
        self.L = L
        self.B = B
        self.H = H
        self.P = P
        self.X = X
        self.R = R
        self.C = C
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear
        # Components
        self.encoder = Encoder(L, N)
        self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type,
                                         causal, mask_nonlinear)
        self.decoder = Decoder(N, L)
        # NOTE(review): the pad width of 10 on the time axis is hard-coded
        # (graph mode cannot pad by a dynamic T_origin - T_conv); it is only
        # correct for one specific input length — confirm against the data
        # pipeline.
        self.pad = nn.Pad(paddings=((0, 0), (0, 0), (0, 10)), mode="CONSTANT")
        self.print = ops.Print()
        # Weight init. The original code built
        # mindspore.common.initializer.HeNormal(p) and discarded the result,
        # which initializes nothing (and passes a Parameter where a float
        # negative_slope is expected). Actually write He-normal values into
        # every multi-dimensional parameter.
        for p in self.get_parameters():
            if p.ndim > 1:
                p.set_data(mindspore.common.initializer.initializer(
                    mindspore.common.initializer.HeNormal(), p.shape, p.dtype))

    def construct(self, mixture):
        """
        Args:
            mixture: [M, T], M is batch size, T is #samples
        Returns:
            est_source: [M, C, T]
        """
        mixture_w = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask)
        # T shrank through the encoder's strided conv1d; restore the original
        # length by zero-padding the time axis (fixed width, see __init__).
        est_source = self.pad(est_source)
        return est_source
-
-
class Encoder(nn.Cell):
    """1-D conv front end estimating nonnegative mixture weights.

    A single bias-free Conv1d with 50% frame overlap followed by ReLU,
    mapping a raw waveform to its framed filter-bank representation.
    """

    def __init__(self, L, N):
        super(Encoder, self).__init__()
        # Hyper-parameters: L = filter length (samples), N = filter count.
        self.L = L
        self.N = N
        # stride = L // 2 gives 50% overlap between adjacent frames.
        self.conv1d_U = nn.Conv1d(1, N, kernel_size=L, stride=L // 2,
                                  has_bias=False, pad_mode="pad",
                                  weight_init="HeUniform")
        self.expanddims = ops.ExpandDims()
        self.relu = nn.ReLU()

    def construct(self, mixture):
        """
        Args:
            mixture: [M, T] batch of waveforms (M = batch size, T = #samples).
        Returns:
            mixture_w: [M, N, K], where K = (T - L) / (L / 2) + 1 = 2T/L - 1.
        """
        # Insert a singleton channel axis so Conv1d sees [M, 1, T].
        with_channel = self.expanddims(mixture, 1)
        # Filter then rectify: nonnegative weights of shape [M, N, K].
        return self.relu(self.conv1d_U(with_channel))
-
def big_matrix(num_frames=3199):
    """Build a constant frame-duplication matrix for overlap-and-add.

    Column ``i`` of the result has ones at rows ``2*i`` and ``2*i + 1`` and
    zeros elsewhere, so left-multiplying by it maps ``num_frames`` columns
    onto ``2 * num_frames`` rows in overlapping pairs.

    Args:
        num_frames: number of columns (default 3199, the original
            hard-coded size, so existing ``big_matrix()`` callers are
            unchanged).

    Returns:
        A (2 * num_frames, num_frames) float16 mindspore Tensor.
    """
    x = np.zeros((2 * num_frames, num_frames), np.float16)
    # Vectorized form of the original per-column loop.
    cols = np.arange(num_frames)
    x[2 * cols, cols] = 1
    x[2 * cols + 1, cols] = 1
    return Tensor.from_numpy(x)
-
-
- class Decoder(nn.Cell):
    def __init__(self, N, L):
        """N: number of encoder filters; L: filter length in samples."""
        super(Decoder, self).__init__()
        # Hyper-parameter
        self.N = N
        self.L = L
        # Components
        # Bias-free linear synthesis: maps N masked filter activations back
        # to an L-sample time-domain frame.
        self.basis_signals = nn.Dense(N, L, has_bias=False)
        self.expanddims = ops.ExpandDims()
        self.transpose = ops.Transpose()
        self.zero = ops.Zeros()
        self.conc = ops.Concat(2)
        # Precomputed constant matrix used by overlap_and_add. NOTE(review):
        # big_matrix() hard-codes its shape, so this decoder presumably
        # assumes a fixed number of frames — confirm against input length.
        self.big_matrix = big_matrix()
-
- def construct(self, mixture_w, est_mask):
- """
- Args:
- mixture_w: [M, N, K]
- est_mask: [M, C, N, K] K = (T-L)/(L/2)+1 = 2T/L-1
- Returns:
- est_source: [M, C, T] #输出的【batch size,说话人数,T is #samples】
- """
- # D = W * M
- #print("decoder")
- source_w = self.expanddims(mixture_w, 1) * est_mask # [M, C, N, K]
- source_w = self.transpose(source_w, (0, 1, 3, 2))
- # S = DV
- est_source = self.basis_signals(source_w) # [M, C, K, L]
- est_source = self.overlap_and_add(est_source, self.L//2) # M x C x T
- #print(est_source.shape)
- return est_source
-
- def overlap_and_add(self, signal, frame_step):
- """Reconstructs a signal from a framed representation.
-
- Adds potentially overlapping frames of a signal with shape
- `[..., frames, frame_length]`, offsetting subsequent frames by `frame_step`.
- The resulting tensor has shape `[..., output_size]` where
-
- output_size = (frames - 1) * frame_step + frame_length
-
- Args:
- signal: A [..., frames, frame_length] Tensor. All dimensions may be unknown, and rank must be at least 2.
- frame_step: An integer denoting overlap offsets. Must be less than or equal to frame_length.
-
- Returns:
- A Tensor with shape [..., output_size] containing the overlap-added frames of signal's inner-most two dimensions.
- output_size = (frames - 1) * frame_step + frame_length
-
- Based on https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/contrib/signal/python/ops/reconstruction_ops.py
- """
-
- outer_dimensions = signal.shape[:-2]
- frames, frame_length = signal.shape[-2:]
-
- subframe_length = math.gcd(frame_length, frame_step) # gcd=Greatest Common Divisor
- subframe_step = frame_step // subframe_length
- subframes_per_frame = frame_length // subframe_length
- output_size = frame_step * (frames - 1) + frame_length
- output_subframes = output_size // subframe_length
-
- subframe_signal = signal.view(*outer_dimensions, -1, subframe_length)
-
- # frame = mindspore.numpy.arange(0, output_subframes).unfold(0, subframes_per_frame, subframe_step)
- frame = mindspore.numpy.arange(0, output_subframes)
- frame = ops.Concat(-1)(
- (ops.expand_dims(frame[0:-1:subframe_step], 1), ops.expand_dims(frame[1::subframe_step], 1)))
- # frame = signal.new_tensor(frame).long() # signal may in GPU or CPU
- frame = frame.view(-1)
-
- # zeros = ops.Zeros()
- # [2, 2, 32000,1]
- # a, b = outer_dimensions
- # result = self.zero((a, b, output_subframes, subframe_length), mindspore.float32)
- # i = 0
- # j = 0
- # while (j < subframe_signal.shape[2]) and (i < len(frame)):
- # if i == (len(frame) - 1) or frame[i] != frame[i + 1]:
- # result[:, :, frame[i], :] += subframe_signal[:, :, j, :]
- # j = j + 1
- # i = i + 1
- # else:
- # result[:, :, frame[i], :] += subframe_signal[:, :, j, :]
- # result[:, :, frame[i], :] += subframe_signal[:, :, j + 1, :]
- # j = j + 2
- # i = i + 2
- # transpose = ops.Transpose()
-
- # result = transpose(result, (2, 1, 0, 3))
- #
- # frame = torch.unsqueeze(frame, 1).repeat(1, subframe_signal)
- # indice = frame
- # result = ops.ScatterNdAdd(result, indice, subframe_signal)
- # # indices = frame
- # # inplaceAdd = ops.InplaceAdd(indices)
- # # result = inplaceAdd(result, subframe_signal)
- # result = transpose(result, (2, 1, 0, 3))
-
- #pad = self.zero((subframe_signal.shape[0], subframe_signal.shape[1], 1, subframe_signal.shape[3]),
- #mindspore.float32)
- #subframe_signal = self.conc((pad, subframe_signal, pad))
- subframe_signal_ = subframe_signal.transpose((0, 1, 3, 2))
|