diff --git a/test.py b/test.py index 4f81fd4..525837c 100644 --- a/test.py +++ b/test.py @@ -39,7 +39,7 @@ parser.add_argument('--sample_rate', default=8000, type=int, help='Sample rate') parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 help='Segment length (seconds)') -parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度 +parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度 help='Batch size') # Network architecture @@ -57,7 +57,7 @@ parser.add_argument('--norm', default='gln', type=str, help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"') parser.add_argument('--dropout', default=0.0, type=float, help='dropout') -parser.add_argument('--num_layers', default=6, type=int, +parser.add_argument('--num_layers', default=4, type=int, help='Number of Dual-Path-Block') parser.add_argument('--K', default=250, type=int, help='The length of chunk') @@ -99,8 +99,8 @@ def preprocess(args): print("preprocess done") def main(args): - context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) - # context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) + # context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) + context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) # 在训练环境中定义data_url和train_url,并把数据从obs拷贝到相应的固定路径 obs_data_url = args.data_url diff --git a/wcl.py b/wcl.py index 1a63cbd..4253ffd 100644 --- a/wcl.py +++ b/wcl.py @@ -39,7 +39,7 @@ parser.add_argument('--sample_rate', default=8000, type=int, help='Sample rate') parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 help='Segment length (seconds)') -parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度 +parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度 help='Batch size') # Network architecture @@ -57,7 +57,7 @@ parser.add_argument('--norm', default='gln', type=str, help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"') parser.add_argument('--dropout', default=0.0, type=float, help='dropout') -parser.add_argument('--num_layers', default=6, type=int, +parser.add_argument('--num_layers', default=4, type=int, help='Number of Dual-Path-Block') parser.add_argument('--K', default=250, type=int, help='The length of chunk') @@ -65,7 +65,7 @@ parser.add_argument('--num_spks', default=2, type=int, help='The number of speakers') # optimizer -parser.add_argument('--lr', default=1e-3, type=float, +parser.add_argument('--lr', default=0.001, type=float, help='Init learning rate') parser.add_argument('--l2', default=1e-5, type=float, help='weight decay (L2 penalty)')