From e4655111dd590e1486aa8bc18b146f578cc0ab34 Mon Sep 17 00:00:00 2001 From: foreverYoung Date: Mon, 5 Dec 2022 22:04:24 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20'train=5Fln=5Fadam.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- train_ln_adam.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/train_ln_adam.py b/train_ln_adam.py index 9056e06..5639df0 100644 --- a/train_ln_adam.py +++ b/train_ln_adam.py @@ -42,7 +42,7 @@ parser.add_argument('--sample_rate', default=8000, type=int, help='Sample rate') parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 help='Segment length (seconds)') -parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度 +parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度 help='Batch size') # Network architecture @@ -62,7 +62,7 @@ parser.add_argument('--norm', default='gln', type=str, help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"') parser.add_argument('--dropout', default=0.0, type=float, help='dropout') -parser.add_argument('--num_layers', default=6, type=int, +parser.add_argument('--num_layers', default=4, type=int, help='Number of Dual-Path-Block') parser.add_argument('--K', default=250, type=int, help='The length of chunk') @@ -70,7 +70,7 @@ parser.add_argument('--num_spks', default=2, type=int, help='The number of speakers') # optimizer -parser.add_argument('--lr', default=1e-3, type=float, +parser.add_argument('--lr', default=0.001, type=float, help='Init learning rate') parser.add_argument('--l2', default=1e-5, type=float, help='weight decay (L2 penalty)') @@ -110,7 +110,8 @@ def preprocess(args): print("preprocess done") def main(args): - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) + # context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) if args.run_distribute: print("distribute") @@ -157,8 +158,8 @@ def main(args): tr_dataset = DatasetGenerator(args.train_dir, args.batch_size, sample_rate=args.sample_rate, segment=args.segment) tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], - shuffle=False, num_shards=rank_size, shard_id=rank_id) - tr_loader = tr_loader.batch(2) + shuffle=True, num_shards=rank_size, shard_id=rank_id) + tr_loader = tr_loader.batch(1) num_steps = tr_loader.get_dataset_size() end_time = time.perf_counter() print("preparing data use: {}min".format((end_time - start_time) / 60)) @@ -178,7 +179,7 @@ def main(args): loss_cb = LossMonitor(1) cb = [time_cb, loss_cb] - config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5) + config_ck = CheckpointConfig(save_checkpoint_steps=200, keep_checkpoint_max=5) ckpt_cb = ModelCheckpoint(prefix='dual', directory=save_ckpt, config=config_ck) -- 2.34.1 From 1131db51df63586dfcd69f131c719e221e82abcd Mon Sep 17 00:00:00 2001 From: foreverYoung Date: Mon, 5 Dec 2022 23:03:43 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E6=9B=B4=E6=96=B0=20'ckpt=5Ftest.py'?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- ckpt_test.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ckpt_test.py b/ckpt_test.py index 09b3b8e..5af313e 100644 --- a/ckpt_test.py +++ b/ckpt_test.py @@ -43,7 +43,7 @@ parser.add_argument('--sample_rate', default=8000, type=int, help='Sample rate') parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 help='Segment length (seconds)') -parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度 +parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度 help='Batch size') # Network architecture @@ -69,7 +69,7 @@ parser.add_argument('--num_spks', default=2, type=int, help='The number of speakers') # optimizer -parser.add_argument('--lr', default=0.001, type=float, +parser.add_argument('--lr', default=1e-3, type=float, help='Init learning rate') parser.add_argument('--l2', default=1e-5, type=float, help='weight decay (L2 penalty)') @@ -109,7 +109,8 @@ def preprocess(args): print("preprocess done") def main(args): - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) + # context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) if args.run_distribute: print("distribute") @@ -156,8 +157,8 @@ def main(args): tr_dataset = DatasetGenerator(args.train_dir, args.batch_size, sample_rate=args.sample_rate, segment=args.segment) tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], - shuffle=False, num_shards=rank_size, shard_id=rank_id) - tr_loader = tr_loader.batch(1) + shuffle=True, num_shards=rank_size, shard_id=rank_id) + tr_loader = tr_loader.batch(2) num_steps = tr_loader.get_dataset_size() # build model net = Dual_RNN_model(args.in_channels, args.out_channels, args.hidden_channels, @@ -175,7 +176,7 @@ def main(args): time_cb = TimeMonitor() loss_cb = LossMonitor(1) cb = [time_cb, loss_cb] - config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5) + config_ck = CheckpointConfig(save_checkpoint_steps=200, keep_checkpoint_max=5) ckpt_cb = ModelCheckpoint(prefix='dual', directory=save_ckpt, config=config_ck) -- 2.34.1