#225 master

Merged
unicorn merged 3 commits from master into unicorn-patch-1 1 year ago
  1. +5
    -5
      model_asteroid.py
  2. +2
    -2
      train_asteroid.py
  3. +4
    -4
      train_ln_adam.py

+ 5
- 5
model_asteroid.py View File

@@ -369,9 +369,9 @@ class Dual_RNN_model(nn.Cell):
self.print = ops.Print()
self.stack = ops.Stack()

# for p in self.get_parameters():
# if p.ndim > 1:
# mindspore.common.initializer.HeNormal(p)
for p in self.get_parameters():
if p.ndim > 1:
mindspore.common.initializer.HeNormal(p)
def construct(self, x):
""" forward """
'''
@@ -393,9 +393,9 @@ class Dual_RNN_model(nn.Cell):
if __name__ == "__main__":
context.set_context(mode=context.GRAPH_MODE, device_target="GPU", device_id=0)
set_seed(42)
rnn = Dual_RNN_model(64, 64, 128, 128, bidirectional=True, norm='gln', num_layers=4, dropout=0.0)
rnn = Dual_RNN_model(64, 64, 128, 128, bidirectional=True, norm='gln', num_layers=6, dropout=0.0)
#encoder = Encoder(16, 512)
ones = ops.Ones()
x = ones((2, 32000), mindspore.float32)
x = ones((1, 32000), mindspore.float32)
out = rnn(x)
print(rnn)

+ 2
- 2
train_asteroid.py View File

@@ -42,7 +42,7 @@ parser.add_argument('--sample_rate', default=8000, type=int,
help='Sample rate')
parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同
help='Segment length (seconds)')
parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度
parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度
help='Batch size')

# Network architecture
@@ -62,7 +62,7 @@ parser.add_argument('--norm', default='gln', type=str,
help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"')
parser.add_argument('--dropout', default=0.0, type=float,
help='dropout')
parser.add_argument('--num_layers', default=6, type=int,
parser.add_argument('--num_layers', default=5, type=int,
help='Number of Dual-Path-Block')
parser.add_argument('--K', default=250, type=int,
help='The length of chunk')


+ 4
- 4
train_ln_adam.py View File

@@ -41,7 +41,7 @@ parser.add_argument('--sample_rate', default=8000, type=int,
help='Sample rate')
parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同
help='Segment length (seconds)')
parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度
parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度
help='Batch size')

# Network architecture
@@ -61,7 +61,7 @@ parser.add_argument('--norm', default='gln', type=str,
help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"')
parser.add_argument('--dropout', default=0.0, type=float,
help='dropout')
parser.add_argument('--num_layers', default=4, type=int,
parser.add_argument('--num_layers', default=6, type=int,
help='Number of Dual-Path-Block')
parser.add_argument('--K', default=250, type=int,
help='The length of chunk')
@@ -109,8 +109,8 @@ def preprocess(args):
print("preprocess done")

def main(args):
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True)
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

if args.run_distribute:
print("distribute")


Loading…
Cancel
Save