|
- import mindspore
- import mindspore.nn as nn
- import mindspore.ops as ops
- import mindspore.common.initializer
- from mindspore import Tensor
- # from utils import overlap_and_add
- import argparse
- import numpy as np
- from mindspore import context
- import math
EPS = 1.0e-8  # small epsilon constant; not referenced in this chunk (presumably used for numerical stability elsewhere — confirm)
-
-
class ConvTasNet(nn.Cell):
    """Conv-TasNet separation network: Encoder -> TemporalConvNet -> Decoder."""

    def __init__(self, N, L, B, H, P, X, R, C, norm_type="gLN", causal=False,
                 mask_nonlinear='relu'):
        """
        Args:
            N: Number of filters in autoencoder
            L: Length of the filters (in samples)
            B: Number of channels in bottleneck 1 × 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN
            causal: causal or non-causal
            mask_nonlinear: use which non-linear function to generate mask
        """
        super(ConvTasNet, self).__init__()
        # Hyper-parameters
        self.N = N
        self.L = L
        self.B = B
        self.H = H
        self.P = P
        self.X = X
        self.R = R
        self.C = C
        self.norm_type = norm_type
        self.causal = causal
        self.mask_nonlinear = mask_nonlinear
        # Components
        self.encoder = Encoder(L, N)
        self.separator = TemporalConvNet(N, B, H, P, X, R, C, norm_type, causal, mask_nonlinear)
        self.decoder = Decoder(N, L)
        # Used to right-pad the decoder output back to the input length.
        self.zeros = ops.Zeros()
        self.concat = ops.Concat(2)
        # Re-initialize every parameter with more than one dimension using
        # He-normal. Merely constructing HeNormal(p) (as before) creates an
        # initializer object and discards it, leaving parameters untouched;
        # the initializer must be materialized and written with set_data.
        # NOTE(review): this also overrides per-layer weight_init choices
        # (e.g. the Encoder's HeUniform) and any >1-D normalization params.
        for p in self.get_parameters():
            if p.ndim > 1:
                p.set_data(mindspore.common.initializer.initializer(
                    mindspore.common.initializer.HeNormal(), p.shape, p.dtype))

    def construct(self, mixture):
        """
        Args:
            mixture: [M, T], M is batch size, T is #samples
        Returns:
            est_source: [M, C, T]
        """
        mixture_w = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask)

        # The strided encoder conv drops trailing samples that do not fill a
        # whole frame, so the decoder output may be shorter than the input.
        # Pad with zeros (or trim) so the time axis matches T exactly.
        t_origin = mixture.shape[-1]
        t_conv = est_source.shape[-1]
        pad_len = t_origin - t_conv
        # Shapes are static, so this branch is resolved at compile time.
        if pad_len > 0:
            tail = self.zeros((est_source.shape[0], est_source.shape[1], pad_len),
                              est_source.dtype)
            est_source = self.concat((est_source, tail))
        elif pad_len < 0:
            est_source = est_source[:, :, :t_origin]
        return est_source
-
-
class Encoder(nn.Cell):
    """Estimate the nonnegative mixture weights with a single 1-D convolution.

    The convolution maps 1 input channel to N channels with kernel size L and
    stride L // 2 (50% overlap between consecutive frames), has no bias, and
    is followed by a ReLU.
    """

    def __init__(self, L, N):
        super(Encoder, self).__init__()
        # Hyper-parameters
        self.L = L
        self.N = N
        # Components: hop of half a frame gives 50% overlap.
        hop = L // 2
        self.conv1d_U = nn.Conv1d(
            1, N,
            kernel_size=L,
            stride=hop,
            has_bias=False,
            pad_mode="pad",
            weight_init="HeUniform",
        )
        self.expanddims = ops.ExpandDims()
        self.relu = nn.ReLU()

    def construct(self, mixture):
        """
        Args:
            mixture: [M, T], M is batch size, T is #samples
        Returns:
            mixture_w: [M, N, K], N is the number of channels (filters) and
                K = floor((T - L) / (L // 2)) + 1 (= 2T/L - 1 when it divides evenly)
        """
        single_channel = self.expanddims(mixture, 1)  # [M, 1, T]
        frames = self.conv1d_U(single_channel)  # [M, N, K]
        return self.relu(frames)
-
def big_matrix(num_frames=3199):
    """Build a float16 0/1 Tensor of shape (2 * num_frames, num_frames).

    Rows ``2*i`` and ``2*i + 1`` each hold a single 1 in column ``i``; every
    other entry is 0.

    Args:
        num_frames: Number of columns. The default 3199 keeps the previous
            behavior; it is consistent with K = 2T/L - 1 for T = 32000 and
            L = 20 (presumably a fixed training length — confirm).

    Returns:
        A mindspore Tensor (float16) sharing memory with the NumPy buffer.
    """
    x = np.zeros((2 * num_frames, num_frames), np.float16)
    cols = np.arange(num_frames)
    # Vectorized equivalent of x[2*i, i] = x[2*i+1, i] = 1 for every i.
    x[2 * cols, cols] = 1
    x[2 * cols + 1, cols] = 1
    return Tensor.from_numpy(x)
-
-
class Decoder(nn.Cell):
    """Reconstruct per-speaker waveforms from masked encoder features.

    Each masked frame is projected from N features to L samples by a
    bias-free Dense layer, then frames are merged by ``self.overlap_and_add``
    with a hop of L // 2.
    """

    def __init__(self, N, L):
        super(Decoder, self).__init__()
        # Hyper-parameters
        self.N = N
        self.L = L
        # Components
        self.basis_signals = nn.Dense(N, L, has_bias=False)
        self.expanddims = ops.ExpandDims()
        self.transpose = ops.Transpose()
        # NOTE(review): zero, conc and big_matrix are not used in construct;
        # presumably they back overlap_and_add, which is not visible here.
        self.zero = ops.Zeros()
        self.conc = ops.Concat(2)
        self.big_matrix = big_matrix()

    def construct(self, mixture_w, est_mask):
        """
        Args:
            mixture_w: [M, N, K]
            est_mask: [M, C, N, K], K = (T-L)/(L/2)+1 = 2T/L-1
        Returns:
            est_source: [M, C, T] -- [batch size, number of speakers, #samples];
                the exact time length depends on overlap_and_add
        """
        hop = self.L // 2
        # D = W * M: broadcast the mixture weights over the speaker axis.
        masked = self.expanddims(mixture_w, 1) * est_mask  # [M, C, N, K]
        # Put features last so the Dense layer maps N -> L for each frame.
        frames_last = self.transpose(masked, (0, 1, 3, 2))  # [M, C, K, N]
        # S = DV
        segments = self.basis_signals(frames_last)  # [M, C, K, L]
        # NOTE(review): overlap_and_add is not defined in the visible class
        # body; it must be provided elsewhere — verify.
        return self.overlap_and_add(segments, hop)  # M x C x T
|