|
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import argparse
- import os
- import sys
- import shutil
- import pprint
-
- import torch
- import torch.nn.parallel
- import torch.backends.cudnn as cudnn
- import torch.optim
- import torch.utils.data
- import torch.utils.data.distributed
-
- from src.models import ResNet, FTTransformer, MLP
- from src.dataset import WaveDataset
- from src.core.function import test
- from src.utils.modelsummary import get_model_summary
-
- from experiments.default import task_config as config, update_config
-
- from src.utils.utils import create_logger
-
-
def parse_args(argv=None):
    """Parse the command-line options of the test script.

    When ``--cfg`` is non-empty, its value is passed to ``update_config``
    together with the module-level ``config`` object (presumably merging the
    YAML file into ``config`` in place -- confirm in ``experiments.default``).

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, exactly as before; passing an
            explicit list allows calling this function programmatically.

    Returns:
        argparse.Namespace with the attributes ``cfg``, ``modelDir``,
        ``logDir``, ``dataDir`` and ``testModel`` (all strings).
    """
    # The original description ('Train keypoints network') was a copy-paste
    # leftover; this script evaluates a trained model.
    parser = argparse.ArgumentParser(description='Test a trained network')

    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        default="experiments/transformer_default.yaml",
                        type=str)

    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--testModel',
                        help='testModel',
                        type=str,
                        default='')

    args = parser.parse_args(argv)
    # An empty --cfg skips the config merge entirely.
    if args.cfg:
        update_config(args.cfg, config)

    return args
-
-
def _build_model(cfg):
    """Construct the network selected by ``cfg.MODEL.NAME``.

    Supported names are ``'resnet'``, ``'mlp'`` and ``'transformer'``; the
    constructor arguments are read from ``cfg.MODEL.EXTRA``.

    Raises:
        ValueError: if ``cfg.MODEL.NAME`` is not one of the supported names.
    """
    name = cfg.MODEL.NAME
    extra = cfg.MODEL.EXTRA
    if name == 'resnet':
        return ResNet.make_baseline(
            d_in=extra.d_in,
            n_blocks=extra.n_blocks,
            d_main=extra.d_main,
            d_hidden=extra.d_hidden,
            dropout_first=extra.dropout_first,
            dropout_second=extra.dropout_second,
            d_out=extra.d_out,
        )
    if name == 'mlp':
        # MLP.make_baseline is called positionally, as in the original code.
        return MLP.make_baseline(
            extra.d_in,
            extra.d_layers,
            extra.dropout,
            extra.d_out,
        )
    if name == 'transformer':
        return FTTransformer.make_baseline(
            n_num_features=extra.n_num_features,
            # Empty list: no categorical feature columns are declared.
            cat_cardinalities=[],
            d_token=extra.d_token,
            n_blocks=extra.n_blocks,
            attention_dropout=extra.attention_dropout,
            ffn_d_hidden=extra.ffn_d_hidden,
            ffn_dropout=extra.ffn_dropout,
            residual_dropout=extra.residual_dropout,
            d_out=extra.d_out,
        )
    # ValueError is a subclass of Exception, so existing handlers still match;
    # the offending name is appended to make the failure diagnosable.
    raise ValueError(
        "Model Name Error!! 模型名错误!! (got {!r})".format(name))


def _resolve_model_file(args, cfg, output_dir):
    """Return the path of the state dict to evaluate.

    Precedence: ``--testModel`` from the command line, then
    ``cfg.TEST.MODEL_FILE``, then ``final_state.pth.tar`` inside
    ``output_dir``.
    """
    if args.testModel:
        return args.testModel
    if cfg.TEST.MODEL_FILE:
        return cfg.TEST.MODEL_FILE
    return os.path.join(output_dir, 'final_state.pth.tar')


def main():
    """Build the configured model, load its weights and run ``test``.

    Steps:
      1. parse CLI options and set up the logger / output directories;
      2. apply the cuDNN flags from ``config.CUDNN``;
      3. build the model named by ``config.MODEL.NAME`` and load its
         state dict (``--testModel`` > ``TEST.MODEL_FILE`` >
         ``<output dir>/final_state.pth.tar``);
      4. evaluate it on ``WaveDataset(..., mode="test")`` with an MSE loss.

    NOTE(review): ``--modelDir``, ``--logDir`` and ``--dataDir`` are parsed
    but not used anywhere in this function -- confirm whether they should
    feed into the config.
    """
    args = parse_args()

    logger, final_output_dir, tb_log_dir = create_logger(
        config, args.cfg, 'test')

    logger.info(pprint.pformat(args))
    logger.info(pprint.pformat(config))

    # cudnn related setting, all driven by the experiment config
    cudnn.benchmark = config.CUDNN.BENCHMARK
    torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
    torch.backends.cudnn.enabled = config.CUDNN.ENABLED

    model = _build_model(config)

    model_file = _resolve_model_file(args, config, final_output_dir)
    logger.info('=> loading model from {}'.format(model_file))
    # Map tensors to CPU first so the checkpoint loads regardless of which
    # CUDA device it was saved from; the model is moved to the GPU below.
    state_dict = torch.load(model_file, map_location='cpu')
    model.load_state_dict(state_dict)

    model = model.cuda()

    # Loss function (criterion); no optimizer is needed for testing.
    criterion = torch.nn.MSELoss().cuda()

    # Data loading code
    test_dataset = WaveDataset(config.DATASET.ROOT, mode="test")
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        batch_size=config.TEST.BATCH_SIZE_PER_GPU,
        shuffle=False,
        num_workers=config.WORKERS,
        pin_memory=True,
    )

    # Evaluate on the test split.
    # NOTE(review): the meaning of the trailing positional None / True is
    # defined by src.core.function.test -- confirm before changing them.
    test(config, test_loader, model, criterion, final_output_dir,
         tb_log_dir, None, True)
-
-
-
-
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()
|