@@ -43,7 +43,7 @@ parser.add_argument('--sample_rate', default=8000, type=int, | |||||
help='Sample rate') | help='Sample rate') | ||||
parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 | parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 | ||||
help='Segment length (seconds)') | help='Segment length (seconds)') | ||||
parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度 | |||||
parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度 | |||||
help='Batch size') | help='Batch size') | ||||
# Network architecture | # Network architecture | ||||
@@ -69,7 +69,7 @@ parser.add_argument('--num_spks', default=2, type=int, | |||||
help='The number of speakers') | help='The number of speakers') | ||||
# optimizer | # optimizer | ||||
parser.add_argument('--lr', default=0.001, type=float, | |||||
parser.add_argument('--lr', default=1e-3, type=float, | |||||
help='Init learning rate') | help='Init learning rate') | ||||
parser.add_argument('--l2', default=1e-5, type=float, | parser.add_argument('--l2', default=1e-5, type=float, | ||||
help='weight decay (L2 penalty)') | help='weight decay (L2 penalty)') | ||||
@@ -109,7 +109,8 @@ def preprocess(args): | |||||
print("preprocess done") | print("preprocess done") | ||||
def main(args): | def main(args): | ||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||||
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||||
if args.run_distribute: | if args.run_distribute: | ||||
print("distribute") | print("distribute") | ||||
@@ -156,8 +157,8 @@ def main(args): | |||||
tr_dataset = DatasetGenerator(args.train_dir, args.batch_size, | tr_dataset = DatasetGenerator(args.train_dir, args.batch_size, | ||||
sample_rate=args.sample_rate, segment=args.segment) | sample_rate=args.sample_rate, segment=args.segment) | ||||
tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], | tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], | ||||
shuffle=False, num_shards=rank_size, shard_id=rank_id) | |||||
tr_loader = tr_loader.batch(1) | |||||
shuffle=True, num_shards=rank_size, shard_id=rank_id) | |||||
tr_loader = tr_loader.batch(2) | |||||
num_steps = tr_loader.get_dataset_size() | num_steps = tr_loader.get_dataset_size() | ||||
# build model | # build model | ||||
net = Dual_RNN_model(args.in_channels, args.out_channels, args.hidden_channels, | net = Dual_RNN_model(args.in_channels, args.out_channels, args.hidden_channels, | ||||
@@ -175,7 +176,7 @@ def main(args): | |||||
time_cb = TimeMonitor() | time_cb = TimeMonitor() | ||||
loss_cb = LossMonitor(1) | loss_cb = LossMonitor(1) | ||||
cb = [time_cb, loss_cb] | cb = [time_cb, loss_cb] | ||||
config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5) | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=200, keep_checkpoint_max=5) | |||||
ckpt_cb = ModelCheckpoint(prefix='dual', | ckpt_cb = ModelCheckpoint(prefix='dual', | ||||
directory=save_ckpt, | directory=save_ckpt, | ||||
config=config_ck) | config=config_ck) | ||||
@@ -42,7 +42,7 @@ parser.add_argument('--sample_rate', default=8000, type=int, | |||||
help='Sample rate') | help='Sample rate') | ||||
parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 | parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 | ||||
help='Segment length (seconds)') | help='Segment length (seconds)') | ||||
parser.add_argument('--batch_size', default=2, type=int, # 需要抛弃的音频长度 | |||||
parser.add_argument('--batch_size', default=3, type=int, # 需要抛弃的音频长度 | |||||
help='Batch size') | help='Batch size') | ||||
# Network architecture | # Network architecture | ||||
@@ -62,7 +62,7 @@ parser.add_argument('--norm', default='gln', type=str, | |||||
help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"') | help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"') | ||||
parser.add_argument('--dropout', default=0.0, type=float, | parser.add_argument('--dropout', default=0.0, type=float, | ||||
help='dropout') | help='dropout') | ||||
parser.add_argument('--num_layers', default=6, type=int, | |||||
parser.add_argument('--num_layers', default=4, type=int, | |||||
help='Number of Dual-Path-Block') | help='Number of Dual-Path-Block') | ||||
parser.add_argument('--K', default=250, type=int, | parser.add_argument('--K', default=250, type=int, | ||||
help='The length of chunk') | help='The length of chunk') | ||||
@@ -70,7 +70,7 @@ parser.add_argument('--num_spks', default=2, type=int, | |||||
help='The number of speakers') | help='The number of speakers') | ||||
# optimizer | # optimizer | ||||
parser.add_argument('--lr', default=1e-3, type=float, | |||||
parser.add_argument('--lr', default=0.001, type=float, | |||||
help='Init learning rate') | help='Init learning rate') | ||||
parser.add_argument('--l2', default=1e-5, type=float, | parser.add_argument('--l2', default=1e-5, type=float, | ||||
help='weight decay (L2 penalty)') | help='weight decay (L2 penalty)') | ||||
@@ -110,7 +110,8 @@ def preprocess(args): | |||||
print("preprocess done") | print("preprocess done") | ||||
def main(args): | def main(args): | ||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||||
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||||
if args.run_distribute: | if args.run_distribute: | ||||
print("distribute") | print("distribute") | ||||
@@ -157,8 +158,8 @@ def main(args): | |||||
tr_dataset = DatasetGenerator(args.train_dir, args.batch_size, | tr_dataset = DatasetGenerator(args.train_dir, args.batch_size, | ||||
sample_rate=args.sample_rate, segment=args.segment) | sample_rate=args.sample_rate, segment=args.segment) | ||||
tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], | tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"], | ||||
shuffle=False, num_shards=rank_size, shard_id=rank_id) | |||||
tr_loader = tr_loader.batch(2) | |||||
shuffle=True, num_shards=rank_size, shard_id=rank_id) | |||||
tr_loader = tr_loader.batch(1) | |||||
num_steps = tr_loader.get_dataset_size() | num_steps = tr_loader.get_dataset_size() | ||||
end_time = time.perf_counter() | end_time = time.perf_counter() | ||||
print("preparing data use: {}min".format((end_time - start_time) / 60)) | print("preparing data use: {}min".format((end_time - start_time) / 60)) | ||||
@@ -178,7 +179,7 @@ def main(args): | |||||
loss_cb = LossMonitor(1) | loss_cb = LossMonitor(1) | ||||
cb = [time_cb, loss_cb] | cb = [time_cb, loss_cb] | ||||
config_ck = CheckpointConfig(save_checkpoint_steps=100, keep_checkpoint_max=5) | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=200, keep_checkpoint_max=5) | |||||
ckpt_cb = ModelCheckpoint(prefix='dual', | ckpt_cb = ModelCheckpoint(prefix='dual', | ||||
directory=save_ckpt, | directory=save_ckpt, | ||||
config=config_ck) | config=config_ck) | ||||
Dear OpenI User
Thank you for your continued support of the OpenI Qizhi Community AI Collaboration Platform. To protect your usage rights and ensure network security, we updated the OpenI Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For more details on the agreement, please refer to the《OpenI Qizhi Community AI Collaboration Platform Usage Agreement》