@@ -16,6 +16,7 @@ from Loss_final1 import loss
from model_rnn import Dual_RNN_model
from mindspore.nn.dynamic_lr import piecewise_constant_lr
from train_wrapper import TrainingWrapper
import time
# Command-line parser for the training script; the add_argument calls that
# follow register their options on this object.
parser = argparse.ArgumentParser(
description='Parameters for training Dual-Path-RNN')
@@ -44,9 +45,9 @@ parser.add_argument('--sample_rate', default=8000, type=int,
help='Sample rate')
# NOTE(review): this region reads like a flattened diff hunk: each changed
# argument appears twice with different defaults (presumably old value first,
# then new value). Confirm which line should survive before running this file.
parser.add_argument('--segment', default=4, type=float, # Length of audio to take (original note says 2 s, but the default here is 4). Clips in the dataset must all have the same length.
help='Segment length (seconds)')
parser.add_argument('--data_batch_size', default=3 , type=int, # Original note: "length of audio to discard". NOTE(review): does not match the 'Batch size' help text; confirm.
parser.add_argument('--data_batch_size', default=2 , type=int, # Original note: "length of audio to discard". NOTE(review): does not match the 'Batch size' help text; confirm.
help='Batch size')
# NOTE(review): the help text below says 'Sample rate of audio file' although
# the flag is --batch_size; looks like a copy-paste leftover. Confirm.
parser.add_argument('--batch_size', type=int, default=1 ,
parser.add_argument('--batch_size', type=int, default=2 ,
help='Sample rate of audio file')
# Network architecture
@@ -64,7 +65,7 @@ parser.add_argument('--norm', default='gln', type=str,
help='gln = "Global Norm", cln = "Cumulative Norm", ln = "Layer Norm"')
parser.add_argument('--dropout', default=0.0, type=float,
help='dropout')
# NOTE(review): --num_layers is listed twice (defaults 4 and 6), which looks
# like the before/after lines of a diff; only one should remain.
parser.add_argument('--num_layers', default=4 , type=int,
parser.add_argument('--num_layers', default=6 , type=int,
help='Number of Dual-Path-Block')
parser.add_argument('--K', default=250, type=int,
help='The length of chunk')
@@ -76,7 +77,7 @@ parser.add_argument('--lr', default=1e-3, type=float,
help='Init learning rate')
# NOTE(review): --lr1 and --lr2 reuse the same 'Init learning rate' help text
# as --lr; how the three rates are used is not visible in this fragment.
parser.add_argument('--lr1', default=5e-4, type=float,
help='Init learning rate')
# NOTE(review): '2.5 e-4' below contains a stray space and is not a valid
# float literal as written (presumably 2.5e-4). --lr2 is also listed twice
# (2.5e-4 vs 2e-4), which looks like the before/after lines of a diff.
parser.add_argument('--lr2', default=2.5 e-4, type=float,
parser.add_argument('--lr2', default=2e-4, type=float,
help='Init learning rate')
parser.add_argument('--l2', default=1e-5, type=float,
help='weight decay (L2 penalty)')
@@ -84,7 +85,7 @@ parser.add_argument('--l2', default=1e-5, type=float,
# save and load model
parser.add_argument('--save_folder', default=r"/home/work/user-job-dir/model/",
help='Location to save epoch models')
# NOTE(review): '3 0' and '6 0' below contain a stray space and are not valid
# integer literals as written (presumably 30 and 60). --nEpochs is listed
# twice, which looks like the before/after lines of a diff. The model.train(...)
# call visible later in this file passes a literal epoch count rather than
# reading this flag; confirm whether that is intended.
parser.add_argument('--nEpochs', type=int, default=3 0, help='number of epochs to train for')
parser.add_argument('--nEpochs', type=int, default=6 0, help='number of epochs to train for')
# NOTE(review): the help text below says 'Sample rate of audio file' although
# the flag is --device_num; looks like a copy-paste leftover. Confirm.
parser.add_argument('--device_num', type=int, default=2,
help='Sample rate of audio file')
parser.add_argument('--device_id', type=int, default=0,
@@ -117,7 +118,8 @@ def preprocess(args):
print("preprocess done")
def main(args):
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True)
# context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, save_graphs=True)
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.run_distribute:
print("distribute")
@@ -195,7 +197,7 @@ def main(args):
loss_cb = LossMonitor(1)
cb = [time_cb, loss_cb]
config_ck = CheckpointConfig(save_checkpoint_steps=num_steps, keep_checkpoint_max=4 )
config_ck = CheckpointConfig(save_checkpoint_steps=num_steps, keep_checkpoint_max=5 )
ckpt_cb = ModelCheckpoint(prefix='commit',
directory=save_ckpt,
config=config_ck)
@@ -203,7 +205,7 @@ def main(args):
# Start training
print("============== Starting Training ==============")
model.train(epoch=3 0, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False)
model.train(epoch=6 0, train_dataset=tr_loader, callbacks=cb, dataset_sink_mode=False)
try:
mox.file.copy_parallel(save_folder, obs_train_url)