|
-
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import argparse
- import os
- import pprint
- import shutil
- import sys
-
- import torch
- import torch.nn.parallel
- import torch.backends.cudnn as cudnn
- import torch.optim
- import torch.utils.data
- import torch.utils.data.distributed
- from tensorboardX import SummaryWriter
-
- from src.models import ResNet, FTTransformer, MLP, UNet
- from src.dataset import WaveDataset
- from src.core.function import train
- from src.core.function import validate
- from src.utils.modelsummary import get_model_summary
-
- from experiments.default import task_config as config, update_config
- from src.utils.utils import get_optimizer
- from src.utils.utils import save_checkpoint
- from src.utils.utils import create_logger
-
-
def parse_args():
    """Parse command-line options and merge the YAML experiment config.

    Returns:
        argparse.Namespace: parsed arguments; if ``--cfg`` is non-empty the
        global ``config`` has already been updated from that file.
    """
    parser = argparse.ArgumentParser(description='Train classification network')

    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        default="experiments/1/unet_best.yaml",
                        type=str)

    # The remaining options are all plain string flags defaulting to ''.
    for flag, text in (('--modelDir', 'model directory'),
                       ('--logDir', 'log directory'),
                       ('--dataDir', 'data directory'),
                       ('--testModel', 'testModel')):
        parser.add_argument(flag, help=text, type=str, default='')

    args = parser.parse_args()
    if args.cfg:
        update_config(args.cfg, config)

    return args
-
-
def _build_model(cfg):
    """Construct the network selected by ``cfg.MODEL.NAME`` from its EXTRA hyper-parameters.

    Args:
        cfg: experiment configuration object (read-only here).

    Returns:
        An un-initialized (CPU) ``torch.nn.Module``.

    Raises:
        ValueError: if ``cfg.MODEL.NAME`` is not one of the supported models.
    """
    name = cfg.MODEL.NAME
    extra = cfg.MODEL.EXTRA
    if name == 'resnet':
        return ResNet.make_baseline(
            d_in=extra.d_in,
            n_blocks=extra.n_blocks,
            d_main=extra.d_main,
            d_hidden=extra.d_hidden,
            dropout_first=extra.dropout_first,
            dropout_second=extra.dropout_second,
            d_out=extra.d_out
        )
    if name == 'mlp':
        return MLP.make_baseline(
            extra.d_in,
            extra.d_layers,
            extra.dropout,
            extra.d_out
        )
    if name == 'transformer':
        return FTTransformer.make_baseline(
            n_num_features=extra.n_num_features,
            cat_cardinalities=[],
            d_token=extra.d_token,
            n_blocks=extra.n_blocks,
            attention_dropout=extra.attention_dropout,
            ffn_d_hidden=extra.ffn_d_hidden,
            ffn_dropout=extra.ffn_dropout,
            residual_dropout=extra.residual_dropout,
            d_out=extra.d_out,
        )
    if name == 'unet':
        return UNet.make_baseline(
            d_in=extra.d_in,
            d_embed=extra.d_embed,
            d_out=extra.d_out,
            n_layer=extra.n_layer,
            d_encode=extra.d_encode
        )
    raise ValueError("Model Name Error!! 模型名错误!!")


def main():
    """Train the configured model on WaveDataset, validating and checkpointing each epoch."""
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'train')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = _build_model(config)

    writer_dict = {
        'writer': SummaryWriter(log_dir=tb_log_dir),
        'train_global_steps': 0,
        'valid_global_steps': 0,
    }

    model = model.cuda()

    # define loss function (criterion) and optimizer
    criterion = torch.nn.MSELoss().cuda()

    optimizer = get_optimizer(config, model)

    best_mse = 1e3
    best_model = False
    last_epoch = config.TRAIN.BEGIN_EPOCH
    if config.TRAIN.RESUME:
        model_state_file = os.path.join(final_output_dir,
                                        'checkpoint.pth.tar')
        if os.path.isfile(model_state_file):
            # map_location='cpu' so a checkpoint saved on another GPU layout
            # still resumes; load_state_dict moves tensors onto the model's device.
            checkpoint = torch.load(model_state_file, map_location='cpu')
            last_epoch = checkpoint['epoch']
            best_mse = checkpoint['perf']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint (epoch {})"
                        .format(checkpoint['epoch']))
            best_model = True

    # last_epoch - 1 tells the scheduler how many epochs have already elapsed,
    # so the resumed LR matches where training left off.
    if isinstance(config.TRAIN.LR_STEP, list):
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1
        )
    else:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR,
            last_epoch - 1
        )

    # Data loading code
    train_dataset = WaveDataset(config.DATASET.ROOT, mode="train")
    valid_dataset = WaveDataset(config.DATASET.ROOT, mode="valid")
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.TRAIN.BATCH_SIZE_PER_GPU,
        shuffle=True,
        num_workers=config.WORKERS,
        pin_memory=True
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True
    )

    for epoch in range(last_epoch, config.TRAIN.END_EPOCH):
        # train for one epoch
        train(config, train_loader, model, criterion, optimizer, epoch,
              final_output_dir, tb_log_dir, writer_dict)
        # Since PyTorch 1.1 the scheduler must step AFTER the optimizer updates;
        # stepping at the top of the loop skipped the first LR of the schedule.
        lr_scheduler.step()
        # evaluate on validation set
        valid_mse = validate(config, valid_loader, model, criterion,
                             final_output_dir, tb_log_dir, writer_dict)

        # Track the best (lowest) validation MSE seen so far.
        if valid_mse < best_mse:
            best_mse = valid_mse
            best_model = True
        else:
            best_model = False

        logger.info('=> saving checkpoint to {}'.format(final_output_dir))
        save_checkpoint({
            'epoch': epoch + 1,
            'model': config.MODEL.NAME,
            'state_dict': model.state_dict(),
            'perf': valid_mse,
            'optimizer': optimizer.state_dict(),
        }, best_model, final_output_dir, filename='checkpoint.pth.tar')

    final_model_state_file = os.path.join(final_output_dir,
                                          'final_state.pth.tar')
    logger.info('saving final model state to {}'.format(
        final_model_state_file))
    torch.save(model.state_dict(), final_model_state_file)
    writer_dict['writer'].close()
-
-
# Script entry point: run training only when executed directly, not on import.
if __name__ == '__main__':
    main()
-
-
-
-
|