|
- #!/usr/bin/python
- #coding=utf-8
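'''
If the code contains Chinese comments, add these lines at the beginning of the file:
#!/usr/bin/python
#coding=utf-8

The example uses the dataset MnistDataset_torch.zip.
Dataset layout:
MnistDataset_torch.zip
├── test
└── train

Pretrained-model folder layout:
Torch_MNIST_Example_Model
├── mnist_epoch1_0.76.pkl

Note: the script loads mnist_epoch1_0.73.pkl; make sure the checkpoint name
listed here matches the file that is actually shipped.
'''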
- from model import Model
- import numpy as np
- import torch
- from torchvision.datasets import mnist
- from torch.nn import CrossEntropyLoss
- from torch.optim import SGD
- from torch.utils.data import DataLoader
- from torchvision.transforms import ToTensor
- import argparse
- import os
- #导入c2net包
- from c2net.context import prepare, upload_output
-
# Command-line options for this script.
# --epoch_size is accepted, but the evaluation code in this file does not use its value.
parser = argparse.ArgumentParser(
    description='PyTorch MNIST Example',
)
parser.add_argument(
    '--epoch_size',
    type=int,
    default=10,
    help='how much epoch to train',
)
parser.add_argument(
    '--batch_size',
    type=int,
    default=256,
    help='how much batch_size in epoch',
)
-
# Module-level runtime settings shared by the code below.
# NOTE(review): WORKERS is presumably meant as the DataLoader num_workers value;
# confirm before relying on it.
WORKERS = 0

# Prefer the first CUDA GPU when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

model = Model().to(device)
# The optimizer is built here, but the evaluation flow in this file never steps it.
optimizer = SGD(model.parameters(), lr=0.1)
# Loss function used by test(); default reduction is the batch mean.
cost = CrossEntropyLoss()
-
# Model evaluation
def test(model, test_loader, data_length, output_path, *, eval_device=None, criterion=None):
    """Evaluate ``model`` on ``test_loader`` and write a summary to ``result.txt``.

    The summary line has the form
    ``Test set: Average loss: L, Accuracy: C/N (P%)`` and is written to
    ``<output_path>/result.txt`` (the directory is created if missing).

    Args:
        model: Module to evaluate; it is switched to eval mode here.
        test_loader: Iterable of ``(inputs, targets)`` batches, e.g. a DataLoader.
        data_length: Number of samples used as the accuracy denominator
            (normally ``len(test_dataset)``). Must be positive.
        output_path: Directory that receives ``result.txt``.
        eval_device: Device each batch is moved to. Defaults to the
            module-level ``device``.
        criterion: Loss called as ``criterion(logits, targets)``. It must
            return the batch-mean loss, as ``CrossEntropyLoss()`` does by
            default. Defaults to the module-level ``cost``.

    Returns:
        ``(average_loss, correct)``: the per-sample average loss and the
        number of correctly classified samples.

    Raises:
        ValueError: if ``data_length`` is not positive, or if the loader
            yields no samples.
    """
    if data_length <= 0:
        raise ValueError("data_length must be positive, got {}".format(data_length))
    if eval_device is None:
        eval_device = device
    if criterion is None:
        criterion = cost

    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(eval_device)
            y = y.to(eval_device)
            y_hat = model(x)
            batch_len = y.size(0)
            # Weight each batch-mean loss by its batch size so that a short
            # final batch does not skew the average.
            loss_sum += criterion(y_hat, y).item() * batch_len
            pred = y_hat.max(1, keepdim=True)[1]
            correct += pred.eq(y.view_as(pred)).sum().item()
            seen += batch_len
    if seen == 0:
        # Previously this surfaced as a NameError on the unbound loop index.
        raise ValueError("test_loader yielded no samples")
    test_loss = loss_sum / seen

    # Write the result to the output folder.
    if output_path:
        os.makedirs(output_path, exist_ok=True)
    file_path = os.path.join(output_path, 'result.txt')
    with open(file_path, 'w', encoding='utf-8') as file:
        file.write('Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, data_length, 100. * correct / data_length))
    return test_loss, correct


# The name starts with "test", so tell pytest not to collect this evaluation
# helper as a test case when the module is imported from a test file.
test.__test__ = False
-
-
if __name__ == '__main__':
    # parse_known_args ignores unrecognized flags instead of exiting with an error.
    args, unknown = parser.parse_known_args()

    # Stage the dataset and the pretrained model into the container.
    c2net_context = prepare()
    dataset_path = os.path.join(c2net_context.dataset_path, "MnistDataset_torch")
    pretrain_model_dir = os.path.join(c2net_context.pretrain_model_path, "Torch_MNIST_Example_Model")
    output_path = c2net_context.output_path

    print('cuda is available:{}'.format(torch.cuda.is_available()))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    batch_size = args.batch_size

    test_dataset = mnist.MNIST(root=os.path.join(dataset_path, "test"), train=False,
                               transform=ToTensor(), download=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=WORKERS)

    model = Model().to(device)
    # NOTE(review): the module docstring lists mnist_epoch1_0.76.pkl, but this
    # loads mnist_epoch1_0.73.pkl. Confirm which checkpoint is actually shipped.
    checkpoint_path = os.path.join(pretrain_model_dir, "mnist_epoch1_0.73.pkl")
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # torch.load unpickles the file, so only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model'])
    test(model, test_loader, len(test_dataset), output_path)
    upload_output()
|