|
- import argparse
-
- import torch
- from loguru import logger
- from torch.optim import lr_scheduler
-
- from diffusion.data_loaders import get_data_loaders
- from diffusion.logger import utils
- from diffusion.solver import train
- from diffusion.unit2mel import Unit2Mel
- from diffusion.vocoder import Vocoder
-
-
def parse_args(args=None, namespace=None):
    """Parse the command-line arguments of the training script.

    Args:
        args: Optional list of argument strings. When ``None``, argparse
            reads them from ``sys.argv[1:]``.
        namespace: Optional object to populate with the parsed attributes.
            When ``None``, argparse creates a fresh ``argparse.Namespace``.

    Returns:
        The populated namespace; its ``config`` attribute holds the path
        given with ``-c`` / ``--config``.

    Raises:
        SystemExit: raised by argparse (exit status 2) when the required
            ``--config`` option is missing or the arguments are malformed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        required=True,
        help="path to the config file",
    )
    return parser.parse_args(args=args, namespace=namespace)
-
-
if __name__ == '__main__':
    # Script entry point: read the config, build vocoder and model, restore
    # the optimizer/scheduler state for the resume step, then start training.

    # parse commands
    cmd = parse_args()

    # load config
    args = utils.load_config(cmd.config)
    logger.info(f' > config:{cmd.config}')
    logger.info(f' > exp:{args.env.expdir}')

    # device
    # Select the GPU *before* anything is placed on 'cuda': a bare 'cuda'
    # device resolves to the current CUDA device, so choosing it only after
    # the vocoder/checkpoint were loaded would leave them on the default GPU
    # whenever gpu_id is not the default one.
    if args.device == 'cuda':
        torch.cuda.set_device(args.env.gpu_id)

    # load vocoder
    vocoder = Vocoder(args.vocoder.type, args.vocoder.ckpt, device=args.device)

    # load model
    model = Unit2Mel(
        args.data.encoder_out_channels,
        args.model.n_spk,
        args.model.use_pitch_aug,
        vocoder.dimension,
        args.model.n_layers,
        args.model.n_chans,
        args.model.n_hidden,
        args.model.timesteps,
        args.model.k_step_max,
    )

    logger.info(f' > Now model timesteps is {model.timesteps}, and k_step_max is {model.k_step_max}')

    # load parameters
    optimizer = torch.optim.AdamW(model.parameters())
    # NOTE(review): load_model presumably restores the latest checkpoint found
    # in expdir and returns the step to resume from -- confirm in utils.
    initial_global_step, model, optimizer = utils.load_model(
        args.env.expdir, model, optimizer, device=args.device
    )

    # Re-derive the learning rate for the resume step instead of trusting the
    # value stored with the optimizer, so config changes take effect.
    # NOTE(review): the "-2" offset presumably compensates for the step() that
    # the scheduler performs on construction plus the training loop's step
    # numbering -- confirm against diffusion.solver.train.
    last_epoch = initial_global_step - 2
    decays_done = max(last_epoch // args.train.decay_step, 0)
    resumed_lr = args.train.lr * (args.train.gamma ** decays_done)
    for param_group in optimizer.param_groups:
        # StepLR needs 'initial_lr' in every group when last_epoch != -1.
        param_group['initial_lr'] = args.train.lr
        param_group['lr'] = resumed_lr
        param_group['weight_decay'] = args.train.weight_decay
    scheduler = lr_scheduler.StepLR(
        optimizer,
        step_size=args.train.decay_step,
        gamma=args.train.gamma,
        last_epoch=last_epoch,
    )

    model.to(args.device)

    # Move any tensors held in the optimizer state to the training device so
    # they match the parameters they belong to.
    for state in optimizer.state.values():
        for k, v in state.items():
            if torch.is_tensor(v):
                state[k] = v.to(args.device)

    # datas
    loader_train, loader_valid = get_data_loaders(args, whole_audio=False)

    # run
    train(args, initial_global_step, model, optimizer, scheduler, vocoder, loader_train, loader_valid)
-
|