#30 Add 'Cite.py'

Merged
Sunhy merged 1 commit from sunhy-patch-24 into master 1 year ago
  1. Cite.py  +163  −0

@@ -0,0 +1,163 @@
# Copyright 2022 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

""" TasNet """
import mindspore
import mindspore.nn as nn
import mindspore.ops as ops
from mindspore.common.initializer import initializer

EPS = 1e-8

class TasNet(nn.Cell):
    """ TasNet """
    def __init__(self, L, N, hidden_size, num_layers, bidirectional=False, nspk=2):
        super(TasNet, self).__init__()
        # hyper-parameters
        self.L = L
        self.N = N
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.nspk = nspk
        # components
        self.encoder = Encoder(L, N)
        self.separator = Separator(N, hidden_size, num_layers, bidirectional=bidirectional, nspk=nspk)
        self.decoder = Decoder(N, L)
        # re-initialize every weight matrix (ndim > 1) from a uniform distribution
        for p in self.get_parameters():
            if p.ndim > 1:
                p.set_data(initializer('uniform', p.shape, p.dtype))

    def construct(self, mixture):
        """
        Args:
            mixture: [B, K, L]
        Returns:
            est_source: [B, nspk, K, L]
        """
        mixture_w, norm_coef = self.encoder(mixture)
        est_mask = self.separator(mixture_w)
        est_source = self.decoder(mixture_w, est_mask, norm_coef)
        return est_source


class Encoder(nn.Cell):
    """ Encoder """
    def __init__(self, L, N):
        super(Encoder, self).__init__()
        # hyper-parameters
        self.L = L
        self.N = N
        # components
        self.conv1d_U = nn.Conv1d(L, N, kernel_size=1, stride=1, pad_mode="pad",
                                  has_bias=True, weight_init="XavierUniform")
        self.conv1d_V = nn.Conv1d(L, N, kernel_size=1, stride=1, pad_mode="pad",
                                  has_bias=True, weight_init="XavierUniform")
        self.relu = ops.ReLU()
        self.sigmoid = ops.Sigmoid()
        self.expand_dims = ops.ExpandDims()
        self.norm = nn.Norm(axis=2, keep_dims=True)

    def construct(self, mixture):
        """
        Args:
            mixture: [B, K, L]
        Returns:
            mixture_w: [B, K, N]
            norm_coef: [B, K, 1]
        """
        B, K, L = mixture.shape
        # L2 norm along the L axis
        norm_coef = self.norm(mixture)  # B x K x 1
        norm_mixture = mixture / (norm_coef + EPS)  # B x K x L
        # 1-D gated conv
        norm_mixture = self.expand_dims(norm_mixture.view(-1, L), 2)  # B*K x L x 1
        conv = self.relu(self.conv1d_U(norm_mixture))  # B*K x N x 1
        gate = self.sigmoid(self.conv1d_V(norm_mixture))  # B*K x N x 1
        mixture_w = conv * gate  # B*K x N x 1
        mixture_w = mixture_w.view(B, K, self.N)  # B x K x N
        return mixture_w, norm_coef

class Separator(nn.Cell):
    """ Estimation of source masks """
    def __init__(self, N, hidden_size, num_layers, bidirectional=False, nspk=2):
        super(Separator, self).__init__()
        # hyper-parameters
        self.N = N
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1
        self.nspk = nspk
        # components
        self.layer_norm = nn.LayerNorm([N])
        self.lstm = nn.LSTM(N, hidden_size, num_layers,
                            batch_first=True,
                            bidirectional=bidirectional)
        # a bidirectional LSTM doubles the feature size of its outputs
        self.fc = nn.Dense(hidden_size * self.num_directions, nspk * N, weight_init="XavierUniform")
        self.zeros = ops.Zeros()
        self.softmax = ops.Softmax(axis=2)

    def construct(self, mixture_w):
        """
        Args:
            mixture_w: [B, K, N], padded
        Returns:
            est_mask: [B, K, nspk, N]
        """
        B, K, N = mixture_w.shape
        # layer norm
        norm_mixture_w = self.layer_norm(mixture_w)
        # LSTM over the K frames, starting from zero hidden states
        h0 = self.zeros((self.num_layers * self.num_directions, B, self.hidden_size), mindspore.float32)
        c0 = self.zeros((self.num_layers * self.num_directions, B, self.hidden_size), mindspore.float32)
        output, _ = self.lstm(norm_mixture_w, (h0, c0))
        # fc
        score = self.fc(output)  # B x K x nspk*N
        score = score.view(B, K, self.nspk, N)
        # softmax along the speaker axis
        est_mask = self.softmax(score)
        return est_mask

class Decoder(nn.Cell):
    """ Decoder """
    def __init__(self, N, L):
        super(Decoder, self).__init__()
        # hyper-parameters
        self.N, self.L = N, L
        # components
        # PyTorch original: nn.Linear(N, L, bias=False)
        self.basis_signals = nn.Dense(N, L, weight_init="XavierUniform")
        self.expand_dims = ops.ExpandDims()
        self.transpose = ops.Transpose()

    def construct(self, mixture_w, est_mask, norm_coef):
        """
        Args:
            mixture_w: [B, K, N]
            est_mask: [B, K, nspk, N]
            norm_coef: [B, K, 1]
        Returns:
            est_source: [B, nspk, K, L]
        """
        # D = W * M
        source_w = self.expand_dims(mixture_w, 2) * est_mask  # B x K x nspk x N
        # S = DB
        est_source = self.basis_signals(source_w)  # B x K x nspk x L
        # reverse the L2 norm applied in the encoder
        norm_coef = self.expand_dims(norm_coef, 2)  # B x K x 1 x 1
        est_source = est_source * norm_coef  # B x K x nspk x L
        est_source = self.transpose(est_source, (0, 2, 1, 3))  # B x nspk x K x L
        return est_source
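
A quick way to sanity-check the added model is a small forward-pass shape test. The hyper-parameter values below (L=40, N=500, hidden_size=500, num_layers=4) are illustrative assumptions only, not values fixed by this PR:

# Minimal forward-pass shape check (hypothetical configuration)
import numpy as np
import mindspore

net = TasNet(L=40, N=500, hidden_size=500, num_layers=4, nspk=2)
# B=2 mixtures, each split into K=100 frames of L=40 samples
mixture = mindspore.Tensor(np.random.randn(2, 100, 40).astype(np.float32))
est_source = net(mixture)
print(est_source.shape)  # (2, 2, 100, 40), i.e. [B, nspk, K, L]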
