#9 更新 'train_asteroid.py'

Merged
unicorn merged 1 commits from master into test 1 year ago
  1. +4
    -5
      train_asteroid.py

+ 4
- 5
train_asteroid.py View File

@@ -70,7 +70,7 @@ parser.add_argument('--num_spks', default=2, type=int,
help='The number of speakers')

# optimizer
parser.add_argument('--lr', default=1e-3, type=float,
parser.add_argument('--lr', default=0.001, type=float,
help='Init learning rate')
parser.add_argument('--l2', default=1e-5, type=float,
help='weight decay (L2 penalty)')
@@ -163,16 +163,15 @@ def main(args):
sample_rate=args.sample_rate, segment=args.segment)
tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"],
shuffle=False, num_shards=rank_size, shard_id=rank_id)
tr_loader = tr_loader.batch(2)
tr_loader = tr_loader.batch(4)
num_steps = tr_loader.get_dataset_size()
end_time = time.perf_counter()
print("preparing data use: {}min".format((end_time - start_time) / 60))

# param_dict = load_checkpoint("/home/heu_MEDAI/zhangyu/project/checkpoint/DPRNN_ckpt_1-11_7120.ckpt")
# build model
net = Dual_RNN_model(args.in_channels, args.out_channels, args.hidden_channels, args.bn_channels,
bidirectional=True, norm=args.norm, num_layers=args.num_layers, dropout=args.dropout, K=args.K)
# load_param_into_net(net, param_dict)
print(net)
net.set_train()
# build optimizer
@@ -196,7 +195,7 @@ def main(args):

#开始训练
print("============== Starting Training ==============")
model.train(epoch=10, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False)
model.train(epoch=1, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False)

try:
mox.file.copy_parallel(save_folder, obs_train_url)


Loading…
Cancel
Save