|
- import argparse
-
- import moxing as mox
-
- from model_final import DPTNet_base
- from mindspore import Model
- from data_loader import DatasetGenerator
-
- import mindspore.dataset as ds
- from mindspore import nn, context
- from mindspore.train.callback import LossMonitor, TimeMonitor, ModelCheckpoint, CheckpointConfig
- from network_define import WithLossCell
- from Loss_final import Loss
- from lr_sch import dynamic_lr
- from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
- from train_wrapper import TrainingWrapper
- from mindspore.communication.management import init, get_rank, get_group_size
- from mindspore.context import ParallelMode
-
- import json
- import os
-
- import librosa
-
-
# Command-line interface.  NOTE: the original passed
# "Dual-path transformer" "with Permutation Invariant Training" as two
# adjacent literals with no separating space, yielding the garbled program
# name "Dual-path transformerwith ..."; a space has been added.
parser = argparse.ArgumentParser(
    "Dual-path transformer "
    "with Permutation Invariant Training")
# General config
# Task related
parser.add_argument('--train_dir', type=str, default="/home/work/user-job-dir/inputs/data_json/tr",
                    help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--valid_dir', type=str, default='/mass_data/dataset/LS-2mix/Libri2Mix/cv',
                    help='directory including mix.json, s1.json and s2.json')
parser.add_argument('--sample_rate', default=8000, type=int,
                    help='Sample rate')
parser.add_argument('--segment', default=4, type=float,
                    help='Segment length (seconds)')
parser.add_argument('--cv_maxlen', default=8, type=float,
                    help='max audio length (seconds) in cv, to avoid OOM issue.')
# Network architecture
parser.add_argument('--N', default=64, type=int,
                    help='Number of filters in autoencoder')
parser.add_argument('--C', default=2, type=int,
                    help='Maximum number of speakers')
parser.add_argument('--L', default=4, type=int,
                    help='Length of window in autoencoder')  # L=2 in paper
parser.add_argument('--H', default=4, type=int,
                    help='Number of head in Multi-head attention')
parser.add_argument('--K', default=250, type=int,
                    help='segment size')
parser.add_argument('--B', default=6, type=int,
                    help='Number of repeats')

parser.add_argument('--enc_dim', default=256, type=int,
                    help='...')
parser.add_argument('--feature_dim', default=64, type=int,
                    help='Number of filters in autoencoder')
parser.add_argument('--hidden_dim', default=128, type=int,
                    help='...')
parser.add_argument('--layer', default=6, type=int,
                    help='Number of repeats')
parser.add_argument('--segment_size', default=250, type=int,
                    help='segment size')
parser.add_argument('--nspk', default=2, type=int,
                    help='Maximum number of speakers')
parser.add_argument('--win_len', default=1, type=int,
                    help='...')

# Training config
parser.add_argument('--use_cuda', type=int, default=1,
                    help='Whether use GPU')
parser.add_argument('--epochs', default=100, type=int,
                    help='Number of maximum epochs')
parser.add_argument('--half_lr', dest='half_lr', default=0, type=int,
                    help='Halving learning rate when get small improvement')
parser.add_argument('--early_stop', dest='early_stop', default=0, type=int,
                    help='Early stop training when no improvement for 10 epochs')
parser.add_argument('--max_norm', default=5, type=float,
                    help='Gradient norm threshold to clip')
# minibatch
parser.add_argument('--shuffle', default=0, type=int,
                    help='reshuffle the data at every epoch')
parser.add_argument('--batch_size', default=3, type=int,
                    help='Batch size')
parser.add_argument('--num_workers', default=4, type=int,
                    help='Number of workers to generate minibatch')
# optimizer
parser.add_argument('--optimizer', default='adam', type=str,
                    choices=['sgd', 'adam'],
                    help='Optimizer (support sgd and adam now)')
parser.add_argument('--lr', default=4e-4, type=float,
                    help='Init learning rate')
parser.add_argument('--momentum', default=0.0, type=float,
                    help='Momentum for optimizer')
parser.add_argument('--l2', default=0.0, type=float,
                    help='weight decay (L2 penalty)')
# save and load model
parser.add_argument('--save_folder', default='exp/temp',
                    help='Location to save epoch models')
parser.add_argument('--checkpoint', dest='checkpoint', default=0, type=int,
                    help='Enables checkpoint saving of model')
parser.add_argument('--continue_from', default='',
                    help='Continue from checkpoint model')
parser.add_argument('--model_path', default='final.pth.tar',
                    help='Location to save best validation model')
# logging
parser.add_argument('--print_freq', default=1000, type=int,
                    help='Frequency of printing training infomation')
parser.add_argument('--visdom', dest='visdom', type=int, default=0,
                    help='Turn on visdom graphing')
parser.add_argument('--visdom_epoch', dest='visdom_epoch', type=int, default=0,
                    help='Turn on visdom graphing each epoch')
parser.add_argument('--visdom_id', default='TasNet training',
                    help='Identifier for visdom run')
# Two parameters fixed by ModelArts: data_url / train_url are the OBS
# paths of the dataset and of the output model directory.
parser.add_argument('--data_url',
                    help='path to training/inference dataset folder',
                    default='./data')

parser.add_argument('--train_url',
                    help='model folder to save/load',
                    default='./model')
parser.add_argument('--in_dir', type=str, default=r"/home/work/user-job-dir/inputs/data/",
                    help='Directory path of wsj0 including tr, cv and tt')
parser.add_argument('--out_dir', type=str, default=r"/home/work/user-job-dir/inputs/data_json",
                    help='Directory path to put output files')

parser.add_argument('--step_per_epoch', default=820, type=int,
                    help='...')
parser.add_argument('--epoch', default=100, type=int,
                    help='total epoch')

parser.add_argument(
    '--device_target',
    type=str,
    default="Ascend",
    choices=['Ascend', 'GPU', 'CPU'],
    help='device where the code will be implemented (default: Ascend)')
-
def preprocess_one_dir(in_dir, out_dir, out_filename, sample_rate=8000):
    """Index every .wav file in *in_dir* into <out_dir>/<out_filename>.json.

    Each json entry is a ``[wav_path, num_samples]`` pair, where the sample
    count is measured after librosa resamples the file to *sample_rate*.

    Args:
        in_dir (str): directory containing the wav files.
        out_dir (str): directory the json index is written into (created
            on demand).
        out_filename (str): basename of the json file, without extension.
        sample_rate (int): target sample rate passed to ``librosa.load``.
    """
    file_infos = []
    in_dir = os.path.abspath(in_dir)
    # sorted() makes the index order deterministic across filesystems;
    # the training set is later sharded by rank, so every rank must see
    # the same ordering.
    wav_list = sorted(os.listdir(in_dir))
    for wav_file in wav_list:
        if not wav_file.endswith('.wav'):
            continue
        wav_path = os.path.join(in_dir, wav_file)
        samples, _ = librosa.load(wav_path, sr=sample_rate)
        file_infos.append((wav_path, len(samples)))
    # exist_ok avoids the exists()/makedirs() race: in distributed runs
    # several ranks preprocess concurrently and the original check-then-
    # create pattern could raise FileExistsError.
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, out_filename + '.json'), 'w') as f:
        json.dump(file_infos, f, indent=4)
-
-
def preprocess(args):
    """Build the json indexes for the training split (mix/s1/s2)."""
    for split in ('tr',):
        split_out = os.path.join(args.out_dir, split)
        for spk in ('mix', 's1', 's2'):
            spk_in = os.path.join(args.in_dir, split, spk)
            preprocess_one_dir(spk_in,
                               split_out,
                               spk,
                               sample_rate=args.sample_rate)
    print("preprocess done")
-
def main(args):
    """Train DPTNet on ModelArts.

    Pipeline: download the dataset from OBS to the local node, build the
    json file indexes, construct the (optionally data-parallel) training
    graph, train, then upload the checkpoints back to OBS.

    Args:
        args: parsed command-line namespace from ``parser``.
    """
    context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)

    # Swap the OBS urls for the fixed local paths used on the compute node;
    # keep the originals for the copy_parallel transfers.
    obs_data_url = args.data_url
    args.data_url = '/home/work/user-job-dir/inputs/data/'
    obs_train_url = args.train_url
    args.train_url = '/home/work/user-job-dir/outputs/model/'
    try:
        mox.file.copy_parallel(obs_data_url, args.data_url)
        print("Successfully Download {} to {}".format(obs_data_url,
                                                      args.data_url))
    except Exception as e:
        # Best-effort: log and continue, the data may already be local.
        print('moxing download {} to {} failed: '.format(
            obs_data_url, args.data_url) + str(e))

    # Distributed setup.  rank_id/rank_size default to the single-device
    # values: the original bound them only inside the distributed branch,
    # which raised NameError at the GeneratorDataset call when RANK_SIZE
    # was 1 or unset.
    device_num = int(os.environ.get("RANK_SIZE", 1))
    rank_id = 0
    rank_size = 1
    if device_num > 1:
        print("parallel init", flush=True)
        init()
        rank_id = get_rank()
        context.reset_auto_parallel_context()
        rank_size = get_group_size()
        context.set_auto_parallel_context(parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True,
                                          device_num=rank_size)
        context.set_auto_parallel_context(parameter_broadcast=True)
        # One checkpoint directory per rank so ranks do not overwrite
        # each other's files.
        args.save_folder = os.path.join(args.save_folder, 'ckpt_' + str(rank_id) + '/')
        print("Starting traning on multiple devices...")

    print("start preprocess ....")
    preprocess(args)

    args.save_checkpoint_path = args.train_url

    # Dataset: sharded by rank for data-parallel training.
    tr_dataset = DatasetGenerator(args.train_dir, args.batch_size,
                                  sample_rate=args.sample_rate, segment=args.segment)
    tr_loader = ds.GeneratorDataset(tr_dataset, ["mixture", "lens", "sources"],
                                    shuffle=False, num_shards=rank_size,
                                    shard_id=rank_id)
    tr_loader = tr_loader.batch(2)

    # Model.  NOTE(review): layer=4 / win_len=2 here differ from the
    # --layer / --win_len CLI defaults (6 / 1) — confirm which is intended.
    net = DPTNet_base(enc_dim=256, feature_dim=64, hidden_dim=128, layer=4,
                      segment_size=250, nspk=2, win_len=2)
    print(net)
    net.set_train()

    # Scheduled learning rate (not adapted dynamically during training).
    lr = dynamic_lr(args.step_per_epoch, args.epoch)
    optimizier = nn.Adam(net.trainable_params(), learning_rate=lr,
                         beta1=0.9, beta2=0.98, eps=1e-9, weight_decay=args.l2)
    my_loss = Loss()
    net_with_loss = WithLossCell(net, my_loss)
    model = Model(net_with_loss, optimizer=optimizier)

    # Callbacks: timing, per-step loss, periodic checkpoints.
    time_cb = TimeMonitor()
    loss_cb = LossMonitor(1)
    config_ck = CheckpointConfig(save_checkpoint_steps=5,
                                 keep_checkpoint_max=5)
    ckpt_cb = ModelCheckpoint(prefix="DPTNet", directory=args.save_folder,
                              config=config_ck)
    cb = [time_cb, loss_cb, ckpt_cb]

    # Train one epoch per call, 100 times, so the callbacks fire at each
    # epoch boundary.
    for _ in range(100):
        model.train(epoch=1, train_dataset=tr_loader, callbacks=cb,
                    dataset_sink_mode=False)

    # Copy the trained model back to OBS so the platform can expose it
    # for download.
    try:
        mox.file.copy_parallel(args.train_url, obs_train_url)
        print("Successfully Upload {} to {}".format(args.train_url,
                                                    obs_train_url))
    except Exception as e:
        print('moxing upload {} to {} failed: '.format(args.train_url,
                                                       obs_train_url) + str(e))
-
if __name__ == '__main__':
    # Entry point: parse the CLI, echo the configuration, then train.
    cli_args = parser.parse_args()
    print(cli_args)
    main(cli_args)
|