|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # @Time : 2023/2/6 下午4:02
- # @File : predict.py
- # ----------------------------------------------
- # ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆
- # >>> Author : Kevin Chang
- # >>> QQ : 565479588
- # >>> Mail : lovecode@gmail.com
- # >>> Github : https://github.com/lovecode100
- # >>> Blog : https://www.cnblogs.com/lovecode
- # ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆
- # 导入包
- import torch
- import torchvision.transforms as transforms
- from PIL import Image
- from model import LeNet
-
- # 数据预处理
- transform = transforms.Compose(
- [transforms.Resize((32, 32)), # 首先需resize成跟训练集图像一样的大小
- transforms.ToTensor(),
- transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
-
- # 导入要测试的图像(自己找的,不在数据集中),放在源文件目录下
- im = Image.open('/code/01LeNet/horse.jpg')
- im = transform(im) # [C, H, W]
- im = torch.unsqueeze(im, dim=0) # 对数据增加一个新维度,因为tensor的参数是[batch, channel, height, width]
-
- # 实例化网络,加载训练好的模型参数
- net = LeNet()
- net.load_state_dict(torch.load('/model/Lenet.pth'))
-
- # 预测
- classes = ('plane', 'car', 'bird', 'cat',
- 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
- with torch.no_grad():
- outputs = net(im)
- predict = torch.max(outputs, dim=1)[1].data.numpy() # 输出预测的标签
- #predict = torch.softmax(outputs, dim=1) #也可用softmax输出10个概率
- print(classes[int(predict)])
|