From e0d448ec57c9bb7ad2eaf8516007725a315f77a6 Mon Sep 17 00:00:00 2001 From: unicorn <15684175528@163.com> Date: Thu, 17 Nov 2022 15:59:22 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20'train=5Fasteroid.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_asteroid.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/train_asteroid.py b/train_asteroid.py index 963e241..474760b 100644 --- a/train_asteroid.py +++ b/train_asteroid.py @@ -70,7 +70,7 @@ parser.add_argument('--num_spks', default=2, type=int, help='The number of speakers') # optimizer -parser.add_argument('--lr', default=1e-3, type=float, +parser.add_argument('--lr', default=0.001, type=float, help='Init learning rate') parser.add_argument('--l2', default=1e-5, type=float, help='weight decay (L2 penalty)') @@ -163,16 +163,15 @@ def main(args): sample_rate=args.sample_rate, segment=args.segment) tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], shuffle=False, num_shards=rank_size, shard_id=rank_id) - tr_loader = tr_loader.batch(2) + tr_loader = tr_loader.batch(4) num_steps = tr_loader.get_dataset_size() end_time = time.perf_counter() print("preparing data use: {}min".format((end_time - start_time) / 60)) - # param_dict = load_checkpoint("/home/heu_MEDAI/zhangyu/project/checkpoint/DPRNN_ckpt_1-11_7120.ckpt") # build model net = Dual_RNN_model(args.in_channels, args.out_channels, args.hidden_channels, args.bn_channels, bidirectional=True, norm=args.norm, num_layers=args.num_layers, dropout=args.dropout, K=args.K) - # load_param_into_net(net, param_dict) + print(net) net.set_train() # build optimizer @@ -196,7 +195,7 @@ def main(args): #开始训练 print("============== Starting Training ==============") - model.train(epoch=10, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False) + model.train(epoch=1, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False) try: mox.file.copy_parallel(save_folder, obs_train_url) -- 2.34.1