|
- self.output_size = output_size
- self.hidden_size = hidden_size
- self.batch_size = batch_size
-
- # dual-path transformer
- self.row_transformer = nn.CellList([])
- self.col_transformer = nn.CellList([])
- for i in range(num_layers):
- self.row_transformer.append(SingleTransformer(input_size, hidden_size, batch_size))
- self.col_transformer.append(SingleTransformer(input_size, hidden_size, batch_size))
- self.prelu = nn.PReLU()
- # self.conv2d = nn.Conv2d(input_size, output_size, 1, weight_init="HeUniform")
- self.init_tensor = initializer(HeUniform(), [output_size, input_size, 1, 1], mindspore.float16)
- self.conv2d = nn.Conv2d(input_size, output_size, 1, weight_init=self.init_tensor)
-
def construct(self, input):
    """Dual-path pass over a segmented feature tensor.

    Each layer first runs a transformer along dim1 (rows, with a residual
    connection), then along dim2 (columns), by folding the other axis into
    the batch dimension.

    input:  (B, N, dim1, dim2)
    return: (B, output_size, dim1, dim2)
    """
    n_batch, _, d1, d2 = input.shape
    feat = input
    for layer_idx in range(len(self.row_transformer)):
        # Row path: fold dim2 into the batch -> (B*dim2, dim1, N).
        rows = feat.transpose((0, 3, 2, 1)).view((n_batch * d2, d1, -1))
        rows = self.row_transformer[layer_idx](rows)  # (B*dim2, dim1, H)
        # Restore (B, N, dim1, dim2) and add the residual.
        rows = rows.view((n_batch, d2, d1, -1)).transpose((0, 3, 2, 1))
        feat = rows + feat

        # Column path: fold dim1 into the batch -> (B*dim1, dim2, N).
        cols = feat.transpose((0, 2, 3, 1)).view((n_batch * d1, d2, -1))
        cols = self.col_transformer[layer_idx](cols)  # (B*dim1, dim2, H)
        # NOTE(review): unlike the row path there is no residual here — the
        # column output replaces the running feature. Confirm this asymmetry
        # is intended.
        feat = cols.view((n_batch, d1, d2, -1)).transpose((0, 3, 1, 2))

    feat = self.prelu(feat)
    return self.conv2d(feat)
-
-
class BF_module(nn.Cell):
    """Beamforming-filter estimation module.

    Pipeline: 1x1-conv bottleneck -> split into half-overlapping segments ->
    dual-path transformer (DPT) -> overlap-and-add merge -> gated 1x1-conv
    output layer.

    Args:
        input_dim:    encoder feature channels E.
        feature_dim:  bottleneck channels N fed to the DPT.
        hidden_dim:   hidden size inside each transformer layer.
        num_spk:      number of speakers to separate.
        layer:        number of dual-path layers in the DPT.
        segment_size: segment (chunk) length L used for splitting.
        batch_size:   batch size forwarded to the DPT blocks.
    """

    def __init__(self, input_dim, feature_dim, hidden_dim, num_spk=2, layer=6, segment_size=250, batch_size=2):
        super(BF_module, self).__init__()

        self.input_dim = input_dim      # E
        self.feature_dim = feature_dim  # N
        self.hidden_dim = hidden_dim

        self.layer = layer
        self.segment_size = segment_size
        self.num_spk = num_spk

        self.batch_size = batch_size

        self.eps = 1e-8

        self.zero = Zeros()
        self.concat = Concat(2)   # concatenate along the time axis of (B, N, T)
        self.concat2 = Concat(3)  # concatenate along the last axis of (B, N, K, L)

        self.print = ops.Print()  # debug helper; not used in construct

        # Bottleneck 1x1 conv: (B, E, T) -> (B, N, T).
        # NOTE(review): weights are initialized in float16 — presumably for
        # mixed-precision hardware; confirm on float32-only targets.
        self.init_tensor = initializer(HeUniform(), [self.feature_dim, self.input_dim, 1], mindspore.float16)
        self.BN = nn.Conv1d(self.input_dim, self.feature_dim, 1, weight_init=self.init_tensor)

        # Dual-path transformer producing num_spk sets of features.
        self.DPT = DPT(self.feature_dim, self.hidden_dim, self.feature_dim * self.num_spk, self.batch_size,
                       num_layers=self.layer)

        # Gated output layer: tanh branch modulated by a sigmoid gate.
        # NOTE(review): both convs share the same init tensor, so they start
        # with identical weights — confirm this is intended.
        self.init_tensor2 = initializer(HeUniform(), [self.feature_dim, self.feature_dim, 1], mindspore.float16)
        self.output = nn.SequentialCell(
            nn.Conv1d(self.feature_dim, self.feature_dim, 1, weight_init=self.init_tensor2),
            nn.Tanh(),
        )
        self.output_gate = nn.SequentialCell(
            nn.Conv1d(self.feature_dim, self.feature_dim, 1, weight_init=self.init_tensor2),
            nn.Sigmoid(),
        )

    def construct(self, input):
        """Estimate per-speaker filters from encoder features.

        input:  (B, E, T)
        return: (B, num_spk, T, N) with N == feature_dim.
        """
        batch_size, E, seq_length = input.shape

        # Bottleneck: (B, E, T) -> (B, N, T).
        enc_feature = self.BN(input)

        # Split into half-overlapping segments: (B, N, L, K); L = segment_size.
        enc_segments, enc_rest = self.split_feature(enc_feature, self.segment_size)

        # Dual-path transformer -> (B*nspk, N, L, K).
        output = self.DPT(enc_segments).view((batch_size * self.num_spk, self.feature_dim, self.segment_size, -1))

        # Overlap-and-add back to a full sequence: (B*nspk, N, T).
        output = self.merge_feature(output, enc_rest)

        # Gated output layer for filter generation.
        bf_filter = self.output(output) * self.output_gate(output)  # (B*nspk, N, T)
        bf_filter = bf_filter.transpose((0, 2, 1)).view((batch_size, self.num_spk, -1, self.feature_dim))

        return bf_filter

    def pad_segment(self, input, segment_size):
        """Zero-pad (B, N, T) so T divides into half-overlapping segments.

        Returns the padded tensor and the amount of trailing padding `rest`
        (needed later by merge_feature to trim the output).
        """
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        # Trailing padding so (stride + T + rest) is a whole number of segments.
        rest = segment_size - (segment_stride + seq_len % segment_size) % segment_size
        if rest > 0:
            pad = self.zero((batch_size, dim, rest), mindspore.float16)
            input = self.concat((input, pad))

        # Half-segment padding at both ends so every frame appears in two segments.
        pad_aux = self.zero((batch_size, dim, segment_stride), mindspore.float16)
        input = self.concat((pad_aux, input, pad_aux))

        return input, rest

    def split_feature(self, input, segment_size):
        """Split (B, N, T) into half-overlapping segments of length segment_size.

        Returns (B, N, L, K) segments plus the padding amount from pad_segment.
        """
        input, rest = self.pad_segment(input, segment_size)
        batch_size, dim, seq_len = input.shape
        segment_stride = segment_size // 2

        # Two interleaved views offset by half a segment give the overlap.
        segments1 = input[:, :, :-segment_stride].view((batch_size, dim, -1, segment_size))
        segments2 = input[:, :, segment_stride:].view((batch_size, dim, -1, segment_size))
        segments = self.concat2((segments1, segments2)).view((batch_size, dim, -1, segment_size)).transpose((0, 1, 3, 2))

        return segments, rest

    def merge_feature(self, input, rest):
        """Overlap-and-add segments (B, N, L, K) back into (B, N, T).

        `rest` is the trailing padding added by pad_segment, trimmed off here.
        """
        batch_size, dim, segment_size, _ = input.shape
        segment_stride = segment_size // 2
        input = input.transpose((0, 1, 3, 2)).view((batch_size, dim, -1, segment_size * 2))  # B, N, K, 2L

        # First/second halves of each pair, shifted by half a segment, sum to
        # the overlap-added sequence.
        input1 = input[:, :, :, :segment_size].view((batch_size, dim, -1))[:, :, segment_stride:]
        input2 = input[:, :, :, segment_size:].view((batch_size, dim, -1))[:, :, :-segment_stride]

        output = input1 + input2
        if rest > 0:
            output = output[:, :, :-rest]

        return output  # B, N, T
-
-
- # base module for DPTNet_base
- class DPTNet_base(nn.Cell):
- def __init__(self, enc_dim, feature_dim, hidden_dim, layer, segment_size=250, nspk=2, win_len=2, batch_size=2):
- super(DPTNet_base, self).__init__()
-
- # parameters
- self.window = win_len
- self.stride = self.window // 2
-
- self.enc_dim = enc_dim
- self.feature_dim = feature_dim
- self.hidden_dim = hidden_dim
- self.segment_size = segment_size
|