@@ -16,6 +16,7 @@ from Loss_final1 import loss | |||||
from model_rnn import Dual_RNN_model | from model_rnn import Dual_RNN_model | ||||
from mindspore.nn.dynamic_lr import piecewise_constant_lr | from mindspore.nn.dynamic_lr import piecewise_constant_lr | ||||
from train_wrapper import TrainingWrapper | from train_wrapper import TrainingWrapper | ||||
import time | |||||
parser = argparse.ArgumentParser( | parser = argparse.ArgumentParser( | ||||
description='Parameters for training Dual-Path-RNN') | description='Parameters for training Dual-Path-RNN') | ||||
@@ -44,9 +45,9 @@ parser.add_argument('--sample_rate', default=8000, type=int, | |||||
help='Sample rate') | help='Sample rate') | ||||
parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 | parser.add_argument('--segment', default=4, type=float, # 取音频的长度,2s。#数据集语音长度要相同 | ||||
help='Segment length (seconds)') | help='Segment length (seconds)') | ||||
parser.add_argument('--data_batch_size', default=3, type=int, # 需要抛弃的音频长度 | |||||
parser.add_argument('--data_batch_size', default=2, type=int, # 需要抛弃的音频长度 | |||||
help='Batch size') | help='Batch size') | ||||
parser.add_argument('--batch_size', type=int, default=1, | |||||
parser.add_argument('--batch_size', type=int, default=2, | |||||
help='Sample rate of audio file') | help='Sample rate of audio file') | ||||
# Network architecture | # Network architecture | ||||
@@ -64,7 +65,7 @@ parser.add_argument('--norm', default='gln', type=str, | |||||
help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"') | help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"') | ||||
parser.add_argument('--dropout', default=0.0, type=float, | parser.add_argument('--dropout', default=0.0, type=float, | ||||
help='dropout') | help='dropout') | ||||
parser.add_argument('--num_layers', default=4, type=int, | |||||
parser.add_argument('--num_layers', default=6, type=int, | |||||
help='Number of Dual-Path-Block') | help='Number of Dual-Path-Block') | ||||
parser.add_argument('--K', default=250, type=int, | parser.add_argument('--K', default=250, type=int, | ||||
help='The length of chunk') | help='The length of chunk') | ||||
@@ -76,7 +77,7 @@ parser.add_argument('--lr', default=1e-3, type=float, | |||||
help='Init learning rate') | help='Init learning rate') | ||||
parser.add_argument('--lr1', default=5e-4, type=float, | parser.add_argument('--lr1', default=5e-4, type=float, | ||||
help='Init learning rate') | help='Init learning rate') | ||||
parser.add_argument('--lr2', default=2.5e-4, type=float, | |||||
parser.add_argument('--lr2', default=2e-4, type=float, | |||||
help='Init learning rate') | help='Init learning rate') | ||||
parser.add_argument('--l2', default=1e-5, type=float, | parser.add_argument('--l2', default=1e-5, type=float, | ||||
help='weight decay (L2 penalty)') | help='weight decay (L2 penalty)') | ||||
@@ -84,7 +85,7 @@ parser.add_argument('--l2', default=1e-5, type=float, | |||||
# save and load model | # save and load model | ||||
parser.add_argument('--save_folder', default=r"/home/work/user-job-dir/model/", | parser.add_argument('--save_folder', default=r"/home/work/user-job-dir/model/", | ||||
help='Location to save epoch models') | help='Location to save epoch models') | ||||
parser.add_argument('--nEpochs', type=int, default=30, help='number of epochs to train for') | |||||
parser.add_argument('--nEpochs', type=int, default=60, help='number of epochs to train for') | |||||
parser.add_argument('--device_num', type=int, default=2, | parser.add_argument('--device_num', type=int, default=2, | ||||
help='Sample rate of audio file') | help='Sample rate of audio file') | ||||
parser.add_argument('--device_id', type=int, default=0, | parser.add_argument('--device_id', type=int, default=0, | ||||
@@ -117,7 +118,8 @@ def preprocess(args): | |||||
print("preprocess done") | print("preprocess done") | ||||
def main(args): | def main(args): | ||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||||
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True) | |||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target) | |||||
if args.run_distribute: | if args.run_distribute: | ||||
print("distribute") | print("distribute") | ||||
@@ -195,7 +197,7 @@ def main(args): | |||||
loss_cb = LossMonitor(1) | loss_cb = LossMonitor(1) | ||||
cb = [time_cb, loss_cb] | cb = [time_cb, loss_cb] | ||||
config_ck = CheckpointConfig(save_checkpoint_steps=num_steps, keep_checkpoint_max=4) | |||||
config_ck = CheckpointConfig(save_checkpoint_steps=num_steps, keep_checkpoint_max=5) | |||||
ckpt_cb = ModelCheckpoint(prefix='commit', | ckpt_cb = ModelCheckpoint(prefix='commit', | ||||
directory=save_ckpt, | directory=save_ckpt, | ||||
config=config_ck) | config=config_ck) | ||||
@@ -203,7 +205,7 @@ def main(args): | |||||
#开始训练 | #开始训练 | ||||
print("============== Starting Training ==============") | print("============== Starting Training ==============") | ||||
model.train(epoch=30, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False) | |||||
model.train(epoch=60, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False) | |||||
try: | try: | ||||
mox.file.copy_parallel(save_folder, obs_train_url) | mox.file.copy_parallel(save_folder, obs_train_url) | ||||
Dear OpenI User,
Thank you for your continued support of the OpenI Qizhi Community AI Collaboration Platform. To protect your usage rights and ensure network security, we updated the OpenI Qizhi Community AI Collaboration Platform Usage Agreement in January 2024. The updated agreement specifies that users are prohibited from using intranet penetration tools. After you click "Agree and continue", you can continue to use our services. Thank you for your cooperation and understanding.
For the full agreement, please refer to the《OpenI Qizhi Community AI Collaboration Platform Usage Agreement》.