import numpy as np

import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore import context, Tensor
from mindspore.common.initializer import XavierUniform, initializer
from mindspore.ops import Concat, ReLU, Zeros

from transformer_improved import TransformerEncoderLayer

EPS = 1e-8

class Encoder(nn.Cell):
    """Estimate the nonnegative mixture weights with a 1-D conv layer."""

    def __init__(self, W=2, N=64):
        super(Encoder, self).__init__()
        # Hyper-parameters: W is the window (kernel) size, N the number of filters
        self.conv1d_U = nn.Conv1d(1, N, kernel_size=W, stride=W // 2, has_bias=False,
                                  weight_init="HeUniform", pad_mode="pad")
        self.expand_dims = ops.ExpandDims()
        self.relu = ReLU()

    def construct(self, mixture):
        """
        Args:
            mixture: [B, T], B is the batch size, T the number of samples
        Returns:
            mixture_w: [B, N, L], where L = (T - W) / (W / 2) + 1 = 2T / W - 1
                is the number of time steps
        """
        mixture = self.expand_dims(mixture, 1)  # [B, 1, T]
        mixture_w = self.conv1d_U(mixture)      # [B, N, L]
        mixture_w = self.relu(mixture_w)        # keep the mixture weights nonnegative
        return mixture_w
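

# A minimal shape sketch (illustration only; the helper name is ours, not part
# of the original model): with the defaults W=2 and N=64, a T=32000-sample
# mixture yields L = 2 * 32000 / 2 - 1 = 31999 frames.
def _demo_encoder_shapes():
    enc = Encoder(W=2, N=64)
    mixture = Tensor(np.zeros((2, 32000), np.float32))
    print(enc(mixture).shape)  # expected: (2, 64, 31999)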


def big_matrix():
    """Build the 0/1 matrix that realizes overlap-and-add as a single matmul.

    Column i carries ones in rows 2i and 2i + 1, so right-multiplying a
    flattened, zero-padded subframe sequence sums every adjacent pair. The
    shape (64000, 32000) is hard-coded for W=2 and T=32000-sample inputs;
    float16 halves the memory footprint of this large constant.
    """
    x = np.zeros((64000, 32000), np.float16)
    for i in range(32000):
        x[2 * i, i] = 1
        x[2 * i + 1, i] = 1
    big_num = Tensor.from_numpy(x)
    return big_num
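

# A small numpy sketch (toy size of our choosing) of the same construction,
# showing that the matrix computes out[i] = sig[2i] + sig[2i + 1]:
def _demo_big_matrix(T=4):
    m = np.zeros((2 * T, T), np.float32)
    for i in range(T):
        m[2 * i, i] = 1
        m[2 * i + 1, i] = 1
    sig = np.arange(2 * T, dtype=np.float32)  # flattened subframes
    print(sig @ m)  # [ 1.  5.  9. 13.]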


class Decoder(nn.Cell):
    def __init__(self, E, W):
        super(Decoder, self).__init__()
        # Hyper-parameters: E is the encoder feature dimension, W the window size
        self.E, self.W = E, W
        # Components
        self.expand_dims = ops.ExpandDims()
        self.basis_signals = nn.Dense(E, W, has_bias=False, weight_init="HeUniform")
        self.zeros = ops.Zeros()
        self.concat = ops.Concat(2)
        self.big_num = big_matrix()

    def construct(self, mixture_w, est_mask):
        """
        Args:
            mixture_w: [B, E, L]
            est_mask: [B, C, E, L]
        Returns:
            est_source: [B, C, T]
        """
        source_w = self.expand_dims(mixture_w, 1) * est_mask  # [B, C, E, L]
        source_w = source_w.transpose((0, 1, 3, 2))           # [B, C, L, E]
        # S = DV: project each masked feature vector back to a W-sample frame
        est_source = self.basis_signals(source_w)             # [B, C, L, W]
        est_source = self.overlap_and_add(est_source, self.W // 2)  # [B, C, T]
        return est_source

    def overlap_and_add(self, signal, frame_step):
        """Overlap-and-add implemented as a matmul with the precomputed matrix.

        With W=2 the hop is frame_step = 1, so every sample is its own
        subframe: the frames are flattened, one zero subframe is padded at
        each end, and multiplying by big_num sums each adjacent pair, which
        reconstructs the waveform. frame_step is kept for interface
        compatibility, but big_num hard-codes the hop W // 2 = 1.
        """
        outer_dimensions = signal.shape[:-2]
        a, b = outer_dimensions

        # flatten [a, b, frames, frame_length] into subframes of length 1
        subframe_signal = signal.view((a, b, -1, 1))

        # pad one zero subframe at each end so every output sample sums two inputs
        pad = self.zeros((subframe_signal.shape[0], subframe_signal.shape[1], 1,
                          subframe_signal.shape[3]), mindspore.float32)
        subframe_signal = self.concat((pad, subframe_signal, pad))
        subframe_signal_ = subframe_signal.transpose((0, 1, 3, 2))

        # big_num is stored in float16 to save memory; cast it to the
        # activation dtype so the matmul operands match
        result = ops.matmul(subframe_signal_, self.big_num.astype(subframe_signal_.dtype))
        result = result.transpose((0, 1, 3, 2))
        result = result.view((a, b, -1))
        return result
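
    # A hedged numpy sketch (helper of our own, toy sizes): naive overlap-add
    # with hop 1 over frames of length 2, matching what the matmul path computes.
    @staticmethod
    def _demo_overlap_add_naive():
        frames = np.array([[0., 1.], [2., 3.], [4., 5.]], np.float32)
        out = np.zeros(frames.shape[0] + 1, np.float32)  # hop 1, frame length 2
        for t, (s0, s1) in enumerate(frames):
            out[t] += s0
            out[t + 1] += s1
        print(out)  # [0. 3. 7. 5.]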


class SingleTransformer(nn.Cell):
    """
    Container module for a single Transformer layer.

    args:
        input_size: int, dimension of the input feature. The input should have
            shape (batch, seq_len, input_size).
        hidden_size: int, dimension of the feed-forward hidden state.
        batch_size: int, batch size the layer is built for. Default is 2.
    """
    def __init__(self, input_size, hidden_size, batch_size=2):
        super(SingleTransformer, self).__init__()
        self.transformer = TransformerEncoderLayer(d_model=input_size, nhead=4,
                                                   hidden_size=hidden_size,
                                                   batch_size=batch_size)

    def construct(self, output):
        transformer_output = self.transformer(output)
        return transformer_output


# dual-path transformer
class DPT(nn.Cell):
    """
    Deep dual-path transformer.

    args:
        input_size: int, dimension of the input feature. The input should have
            shape (batch, seq_len, input_size).
        hidden_size: int, dimension of the hidden state.
        output_size: int, dimension of the output size.
        batch_size: int, batch size the transformer layers are built for.
            Default is 2.
        num_layers: int, number of stacked Transformer layers. Default is 1.
    """

    def __init__(self, input_size, hidden_size, output_size, batch_size=2, num_layers=1):
        super(DPT, self).__init__()

        self.input_size = input_size
        self.output_size = output_size
        self.hidden_size = hidden_size
        self.batch_size = batch_size

        # dual-path transformer: one stack over the intra-segment (row) axis
        # and one over the inter-segment (column) axis
        self.row_transformer = nn.CellList([])
        self.col_transformer = nn.CellList([])
        for _ in range(num_layers):
            self.row_transformer.append(SingleTransformer(input_size, hidden_size, batch_size))
            self.col_transformer.append(SingleTransformer(input_size, hidden_size, batch_size))
        self.prelu = nn.PReLU()
        self.conv2d = nn.Conv2d(input_size, output_size, 1, weight_init="HeUniform")

    def construct(self, output):
        # input shape: (batch, N, dim1, dim2)
        # apply the transformer on dim1 (intra-segment) first, then on dim2
        # (inter-segment); output shape: (B, output_size, dim1, dim2)
        batch_size, _, dim1, dim2 = output.shape
        for i in range(len(self.row_transformer)):
            row_input = output.transpose((0, 3, 2, 1)).view((batch_size * dim2, dim1, -1))  # B*dim2, dim1, N
            row_output = self.row_transformer[i](row_input)                                 # B*dim2, dim1, N
            output = row_output.view((batch_size, dim2, dim1, -1)).transpose((0, 3, 2, 1))  # B, N, dim1, dim2

            col_input = output.transpose((0, 2, 3, 1)).view((batch_size * dim1, dim2, -1))  # B*dim1, dim2, N
            col_output = self.col_transformer[i](col_input)                                 # B*dim1, dim2, N
            output = col_output.view((batch_size, dim1, dim2, -1)).transpose((0, 3, 1, 2))  # B, N, dim1, dim2

        output = self.prelu(output)
        output = self.conv2d(output)

        return output
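
    # A minimal shape sketch (helper name and toy sizes are ours; assumes the
    # external TransformerEncoderLayer accepts the folded B*dim batch): one
    # dual-path layer maps (B, N, dim1, dim2) to (B, output_size, dim1, dim2).
    @staticmethod
    def _demo_dpt_shapes():
        dpt = DPT(input_size=64, hidden_size=128, output_size=128, num_layers=1)
        x = Tensor(np.zeros((2, 64, 250, 10), np.float32))
        print(dpt(x).shape)  # expected: (2, 128, 250, 10)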


class BF_module(nn.Cell):
    def __init__(self, input_dim, feature_dim, hidden_dim, num_spk=2, layer=6, segment_size=250, batch_size=2):
        super(BF_module, self).__init__()

        self.input_dim = input_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim

        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk

        self.batch_size = batch_size

        self.eps = 1e-8

        self.zero = Zeros()
        self.concat = Concat(2)
        self.concat2 = Concat(3)

        # bottleneck
        self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, weight_init="HeUniform")

        # DPT model
        self.DPT = DPT(self.feature_dim, self.hidden_dim, self.feature_dim * self.num_spk,
                       self.batch_size, num_layers=self.layer)

        # gated output layer
        self.output = nn.SequentialCell(nn.Conv1d(self.feature_dim, self.feature_dim, 1, weight_init="HeUniform"),
                                        nn.Tanh())
        self.output_gate = nn.SequentialCell(nn.Conv1d(self.feature_dim, self.feature_dim, 1, weight_init="HeUniform"),
                                             nn.Sigmoid())

    def construct(self, input):
        # input: (B, E, T)
        batch_size, _, _ = input.shape

        enc_feature = self.BN(input)  # (B, E, L) --> (B, N, L)

        # split the encoder output into overlapped, longer segments
        enc_segments, enc_rest = self.split_feature(enc_feature, self.segment_size)  # B, N, L, K (L is the segment size)

        output = self.DPT(enc_segments).view((batch_size * self.num_spk, self.feature_dim,
                                              self.segment_size, -1))  # B*nspk, N, L, K

        # overlap-and-add of the outputs
        output = self.merge_feature(output, enc_rest)  # B*nspk, N, T

        # gated output layer for filter generation
        bf_filter = self.output(output) * self.output_gate(output)  # B*nspk, N, T
        bf_filter = bf_filter.transpose((0, 2, 1)).view((batch_size, self.num_spk, -1,
                                                         self.feature_dim))  # B, nspk, T, N

        return bf_filter

    def pad_segment(self, input, segment_size):
        # input is the feature tensor: (B, N, T)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        # trailing zeros needed so the sequence splits evenly into
        # half-overlapping segments
        rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
        if rest > 0:
            pad = self.zero((batch_size, dim, rest), mindspore.float32)
            input = self.concat((input, pad))

        # pad half a segment of zeros at both ends for the 50% overlap
        pad_aux = self.zero((batch_size, dim, segment_stride), mindspore.float32)
        input = self.concat((pad_aux, input, pad_aux))

        return input, rest

    def split_feature(self, input, segment_size):
        # split the feature into half-overlapping chunks of segment_size
        # input is the feature tensor: (B, N, T)

        input, rest = self.pad_segment(input, segment_size)
        batch_size, dim, _ = input.shape
        segment_stride = segment_size // 2

        segments1 = input[:, :, :-segment_stride].view((batch_size, dim, -1, segment_size))
        segments2 = input[:, :, segment_stride:].view((batch_size, dim, -1, segment_size))
        segments = self.concat2((segments1, segments2)).view((batch_size, dim, -1, segment_size)).transpose((0, 1, 3, 2))

        return segments, rest
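
    # A worked example of the padding arithmetic (numbers only): for T=31999,
    # segment_size=250 and stride 125, T % 250 = 249, so
    # rest = 250 - (125 + 249) % 250 = 126 trailing zeros are appended and 125
    # zeros are padded at each end; the padded length 32375 then splits into
    # 2 * (32375 - 125) / 250 = 258 half-overlapping segments of length 250.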

    def merge_feature(self, input, rest):
        # merge the split segments back into the full utterance (inverse of split_feature)
        # input is the feature tensor: (B, N, L, K)

        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = input.transpose((0, 1, 3, 2)).view((batch_size, dim, -1, segment_size * 2))  # B, N, K, L

        input1 = input[:, :, :, :segment_size].view((batch_size, dim, -1))[:, :, segment_stride:]
        input2 = input[:, :, :, segment_size:].view((batch_size, dim, -1))[:, :, :-segment_stride]

        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]

        return output  # B, N, T
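
    # A hedged round-trip sketch (helper of our own): with half-overlapping
    # segments, every sample is covered by exactly two segments, so splitting
    # a constant tensor and merging it back doubles the values while restoring
    # the original length.
    def _demo_split_merge_roundtrip(self):
        x = Tensor(np.ones((2, 64, 31999), np.float32))
        segments, rest = self.split_feature(x, 250)  # (2, 64, 250, 258), rest = 126
        merged = self.merge_feature(segments, rest)  # values ~2.0, shape (2, 64, 31999)
        print(merged.shape)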


# base module for DPTNet_base
class DPTNet_base(nn.Cell):
    def __init__(self, enc_dim, feature_dim, hidden_dim, layer, segment_size=250, nspk=2, win_len=2, batch_size=2):
        super(DPTNet_base, self).__init__()

        # parameters
        self.window = win_len
        self.stride = self.window // 2

        self.enc_dim = enc_dim
        self.feature_dim = feature_dim
        self.hidden_dim = hidden_dim
        self.segment_size = segment_size

        self.layer = layer
        self.num_spk = nspk
        self.eps = 1e-8
        self.batch_size = batch_size

        self.relu = ReLU()
        self.zeros = Zeros()
        self.concat = Concat(1)

        # waveform encoder
        self.encoder = Encoder(win_len, enc_dim)  # [B, T] --> [B, N, L]
        self.enc_LN = nn.GroupNorm(1, self.enc_dim, eps=self.eps, affine=True)  # [B, N, L] --> [B, N, L]
        self.separator = BF_module(self.enc_dim, self.feature_dim, self.hidden_dim,
                                   self.num_spk, self.layer, self.segment_size, self.batch_size)
        # [B, N, L] -> [B, E, L]
        self.mask_conv1x1 = nn.Conv1d(self.feature_dim, self.enc_dim, 1, has_bias=False, weight_init="HeUniform")
        self.decoder = Decoder(enc_dim, win_len)
        self.expand_dims = ops.ExpandDims()

        # re-initialize all weight matrices (ndim > 1) with Xavier uniform
        for p in self.get_parameters():
            if p.ndim > 1:
                p.set_data(initializer(XavierUniform(), p.shape, p.dtype))

    def construct(self, input):
        """
        input: shape (batch, T)
        """
        B, _ = input.shape
        mixture_w = self.encoder(input)  # B, E, L

        # GroupNorm expects its channels on axis 1, so lift the tensor to 4-D
        # with E on axis 1 and squeeze back afterwards
        mixture_w_t = self.expand_dims(mixture_w, 0).transpose((0, 2, 1, 3))
        score_ = self.enc_LN(mixture_w_t)
        score_ = score_.transpose((0, 2, 1, 3)).squeeze(axis=0)  # B, E, L

        # pass to the dual-path transformer separator
        score_ = self.separator(score_)  # B, nspk, T, N
        score_ = score_.view((B * self.num_spk, -1, self.feature_dim)).transpose((0, 2, 1))  # B*nspk, N, T

        score = self.mask_conv1x1(score_)  # [B*nspk, N, L] -> [B*nspk, E, L]
        score = score.view((B, self.num_spk, self.enc_dim, -1))  # [B*nspk, E, L] -> [B, nspk, E, L]
        est_mask = self.relu(score)

        est_source = self.decoder(mixture_w, est_mask)  # [B, E, L], [B, nspk, E, L] --> [B, nspk, T]

        return est_source


if __name__ == "__main__":
    # context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")  # for step-by-step debugging
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    model = DPTNet_base(enc_dim=256, feature_dim=64, hidden_dim=128, layer=1,
                        segment_size=250, nspk=2, win_len=2, batch_size=2)
    print(model)
    x1 = np.ones((2, 32000)).astype(np.float32)
    y = mindspore.Tensor.from_numpy(x1)

    output = model(y)
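    # The separated sources come back as (batch, nspk, T); for the toy batch
    # above that should be (2, 2, 32000).
    print(output.shape)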