diff --git a/train_asteroid.py b/train_asteroid.py index 963e241..474760b 100644 --- a/train_asteroid.py +++ b/train_asteroid.py @@ -70,7 +70,7 @@ parser.add_argument('--num_spks', default=2, type=int, help='The number of speakers') # optimizer -parser.add_argument('--lr', default=1e-3, type=float, +parser.add_argument('--lr', default=0.001, type=float, help='Init learning rate') parser.add_argument('--l2', default=1e-5, type=float, help='weight decay (L2 penalty)') @@ -163,16 +163,15 @@ def main(args): sample_rate=args.sample_rate, segment=args.segment) tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], shuffle=False, num_shards=rank_size, shard_id=rank_id) - tr_loader = tr_loader.batch(2) + tr_loader = tr_loader.batch(4) num_steps = tr_loader.get_dataset_size() end_time = time.perf_counter() print("preparing data use: {}min".format((end_time - start_time) / 60)) - # param_dict = load_checkpoint("/home/heu_MEDAI/zhangyu/project/checkpoint/DPRNN_ckpt_1-11_7120.ckpt") # build model net = Dual_RNN_model(args.in_channels, args.out_channels, args.hidden_channels, args.bn_channels, bidirectional=True, norm=args.norm, num_layers=args.num_layers, dropout=args.dropout, K=args.K) - # load_param_into_net(net, param_dict) + print(net) net.set_train() # build optimizer @@ -196,7 +195,7 @@ def main(args): #开始训练 print("============== Starting Training ==============") - model.train(epoch=10, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False) + model.train(epoch=1, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False) try: mox.file.copy_parallel(save_folder, obs_train_url)