From f512194e2cc28a08651e5f6f7ddfc89192646eb3 Mon Sep 17 00:00:00 2001 From: ZhangY Date: Sun, 26 Mar 2023 23:21:00 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20'test.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- test.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test.py b/test.py index 52e5b92..47c3359 100644 --- a/test.py +++ b/test.py @@ -42,7 +42,7 @@ parser.add_argument('--sample_rate', default=8000, type=int, help='Sample rate') parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 help='Segment length (seconds)') -parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度 +parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度 help='Batch size') # Network architecture @@ -62,7 +62,7 @@ parser.add_argument('--norm', default='gln', type=str, help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"') parser.add_argument('--dropout', default=0.0, type=float, help='dropout') -parser.add_argument('--num_layers', default=4, type=int, +parser.add_argument('--num_layers', default=6, type=int, help='Number of Dual-Path-Block') parser.add_argument('--K', default=250, type=int, help='The length of chunk') @@ -70,7 +70,7 @@ parser.add_argument('--num_spks', default=2, type=int, help='The number of speakers') # optimizer -parser.add_argument('--lr', default=0.001, type=float, +parser.add_argument('--lr', default=1e-3, type=float, help='Init learning rate') parser.add_argument('--l2', default=1e-5, type=float, help='weight decay (L2 penalty)') -- 2.34.1 From 155d1384394d18203db33767352d285764f25f9e Mon Sep 17 00:00:00 2001 From: ZhangY Date: Sun, 26 Mar 2023 23:53:36 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20'model=5Ftest.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model_test.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/model_test.py b/model_test.py index 8aa95da..88e590d 100644 --- a/model_test.py +++ b/model_test.py @@ -365,9 +365,9 @@ class Dual_RNN_model(nn.Cell): self.print = ops.Print() self.stack = ops.Stack() - # for p in self.get_parameters(): - # if p.ndim > 1: - # mindspore.common.initializer.HeNormal(p) + for p in self.get_parameters(): + if p.ndim > 1: + mindspore.common.initializer.HeNormal(p) def construct(self, x): """ forward """ ''' @@ -389,9 +389,9 @@ class Dual_RNN_model(nn.Cell): if __name__ == "__main__": context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=0) set_seed(42) - rnn = Dual_RNN_model(64, 64, 128, 128, bidirectional=True, norm='gln', num_layers=4, dropout=0.0) + rnn = Dual_RNN_model(64, 64, 128, 128, bidirectional=True, norm='gln', num_layers=6, dropout=0.0) #encoder = Encoder(16, 512) ones = ops.Ones() - x = ones((1, 32000), mindspore.float32) + x = ones((2, 32000), mindspore.float32) out = rnn(x) print(rnn) \ No newline at end of file -- 2.34.1