@@ -43,7 +43,7 @@ parser.add_argument('--sample_rate', default=8000, type=int, | |||
help='Sample rate') | |||
parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 | |||
help='Segment length (seconds)') | |||
parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度 | |||
parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度 | |||
help='Batch size') | |||
# Network architecture | |||
@@ -69,7 +69,7 @@ parser.add_argument('--num_spks', default=2, type=int, | |||
help='The number of speakers') | |||
# optimizer | |||
parser.add_argument('--lr', default=0.001, type=float, | |||
parser.add_argument('--lr', default=1e-3, type=float, | |||
help='Init learning rate') | |||
parser.add_argument('--l2', default=1e-5, type=float, | |||
help='weight decay (L2 penalty)') | |||
@@ -109,7 +109,8 @@ def preprocess(args): | |||
print("preprocess done") | |||
def main(args): | |||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
if args.run_distribute: | |||
print("distribute") | |||
@@ -156,8 +157,8 @@ def main(args): | |||
tr_dataset = DatasetGenerator(args.train_dir, args.batch_size, | |||
sample_rate=args.sample_rate, segment=args.segment) | |||
tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], | |||
shuffle=False, num_shards=rank_size, shard_id=rank_id) | |||
tr_loader = tr_loader.batch(1) | |||
shuffle=True, num_shards=rank_size, shard_id=rank_id) | |||
tr_loader = tr_loader.batch(2) | |||
num_steps = tr_loader.get_dataset_size() | |||
# build model | |||
net = Dual_RNN_model(args.in_channels, args.out_channels, args.hidden_channels, | |||
@@ -175,7 +176,7 @@ def main(args): | |||
time_cb = TimeMonitor() | |||
loss_cb = LossMonitor(1) | |||
cb = [time_cb, loss_cb] | |||
config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5) | |||
config_ck = CheckpointConfig(save_checkpoint_steps=200, keep_checkpoint_max=5) | |||
ckpt_cb = ModelCheckpoint(prefix='dual', | |||
directory=save_ckpt, | |||
config=config_ck) | |||
@@ -42,7 +42,7 @@ parser.add_argument('--sample_rate', default=8000, type=int, | |||
help='Sample rate') | |||
parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 | |||
help='Segment length (seconds)') | |||
parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度 | |||
parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度 | |||
help='Batch size') | |||
# Network architecture | |||
@@ -62,7 +62,7 @@ parser.add_argument('--norm', default='gln', type=str, | |||
help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"') | |||
parser.add_argument('--dropout', default=0.0, type=float, | |||
help='dropout') | |||
parser.add_argument('--num_layers', default=6, type=int, | |||
parser.add_argument('--num_layers', default=4, type=int, | |||
help='Number of Dual-Path-Block') | |||
parser.add_argument('--K', default=250, type=int, | |||
help='The length of chunk') | |||
@@ -70,7 +70,7 @@ parser.add_argument('--num_spks', default=2, type=int, | |||
help='The number of speakers') | |||
# optimizer | |||
parser.add_argument('--lr', default=1e-3, type=float, | |||
parser.add_argument('--lr', default=0.001, type=float, | |||
help='Init learning rate') | |||
parser.add_argument('--l2', default=1e-5, type=float, | |||
help='weight decay (L2 penalty)') | |||
@@ -110,7 +110,8 @@ def preprocess(args): | |||
print("preprocess done") | |||
def main(args): | |||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||
if args.run_distribute: | |||
print("distribute") | |||
@@ -157,8 +158,8 @@ def main(args): | |||
tr_dataset = DatasetGenerator(args.train_dir, args.batch_size, | |||
sample_rate=args.sample_rate, segment=args.segment) | |||
tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], | |||
shuffle=False, num_shards=rank_size, shard_id=rank_id) | |||
tr_loader = tr_loader.batch(2) | |||
shuffle=True, num_shards=rank_size, shard_id=rank_id) | |||
tr_loader = tr_loader.batch(1) | |||
num_steps = tr_loader.get_dataset_size() | |||
end_time = time.perf_counter() | |||
print("preparing data use: {}min".format((end_time - start_time) / 60)) | |||
@@ -178,7 +179,7 @@ def main(args): | |||
loss_cb = LossMonitor(1) | |||
cb = [time_cb, loss_cb] | |||
config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5) | |||
config_ck = CheckpointConfig(save_checkpoint_steps=200, keep_checkpoint_max=5) | |||
ckpt_cb = ModelCheckpoint(prefix='dual', | |||
directory=save_ckpt, | |||
config=config_ck) | |||
Dear OpenI User,
Thank you for your continued support of the OpenI Qizhi Community AI Collaboration Platform. In order to protect your usage rights and ensure network security, we updated the OpenI Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For the full agreement, please refer to the 《OpenI Qizhi Community AI Collaboration Platform Usage Agreement》.