|
- import os
- import sys
- import json
-
- import torch
- import torch.nn as nn
- from torchvision import transforms, datasets, utils
- from torch.utils.data import DataLoader
- import numpy as np
- import torch.optim as optim
- from tqdm import tqdm
-
- from model import AlexNet
-
-
def _train_one_epoch(model, loader, loss_function, optimizer, device, epoch, log_interval=100):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Args:
        model: network to train; it is switched to training mode here.
        loader: iterable yielding ``(images, labels)`` batches.
        loss_function: criterion called as ``loss_function(outputs, labels)``.
        optimizer: optimizer over ``model``'s parameters.
        device: device each batch is moved to before the forward pass.
        epoch: zero-based epoch index, used only in progress messages.
        log_interval: print a progress line every ``log_interval`` batches.

    Returns:
        The mean of the per-batch losses, or 0.0 if ``loader`` yielded nothing.
    """
    model.train()
    running_loss = 0.0
    num_batches = 0
    for index, (images, labels) in enumerate(loader):
        # With the transforms built in main(), images are (batch, 3, 224, 224)
        # and labels are a 1-D tensor of class indices; the last batch may be smaller.
        optimizer.zero_grad()
        outputs = model(images.to(device))  # (batch, num_classes) logits
        loss = loss_function(outputs, labels.to(device))
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1
        # Report after every full group of ``log_interval`` batches.
        if (index + 1) % log_interval == 0:
            print(f'[epoch {epoch + 1}] batch index {index + 1} is done')
    return running_loss / num_batches if num_batches else 0.0


def _evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` over ``loader``.

    Args:
        model: network to evaluate; it is switched to eval mode here.
        loader: iterable yielding ``(images, labels)`` batches.
        device: device each batch is moved to before the forward pass.

    Returns:
        The fraction of correctly classified samples, or 0.0 if ``loader``
        yielded no samples.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            labels = labels.to(device)
            outputs = model(images.to(device))
            # Index of the highest logit per sample = predicted class.
            predictions = outputs.argmax(dim=1)
            correct += (predictions == labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0


def main(image_path='/tmp/dataset', save_path='/tmp/output/AlexNet.pth', epochs=30,
         num_classes=None):
    """Train AlexNet on an ImageFolder dataset and keep the best checkpoint.

    Expects ``<image_path>/train`` and ``<image_path>/val`` directories laid
    out for ``torchvision.datasets.ImageFolder`` (one sub-directory per class).
    After each epoch the model is evaluated on the val split. Its
    ``state_dict`` is written to ``save_path`` whenever val accuracy is the
    best seen so far.

    Args:
        image_path: dataset root containing ``train`` and ``val``.
        save_path: destination of the best ``state_dict``; its parent
            directory is created if it does not exist.
        epochs: number of training epochs.
        num_classes: number of output classes. Defaults to the number of
            class folders found under ``train``.

    Raises:
        ValueError: if the train and val splits contain different class folders.
    """
    device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")

    normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        "val": transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            normalize,
        ]),
    }

    # os.path.join inserts the platform path separator (os.sep) between components.
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'train'),
                                         transform=data_transform['train'])
    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, 'val'),
                                            transform=data_transform['val'])
    # Each split derives its label indices from its own folder names. A mismatch
    # would make val labels mean something different from train labels.
    if train_dataset.class_to_idx != validate_dataset.class_to_idx:
        raise ValueError('train and val splits have different class folders: '
                         f'{sorted(train_dataset.class_to_idx)} vs '
                         f'{sorted(validate_dataset.class_to_idx)}')

    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True)
    validate_loader = DataLoader(validate_dataset, batch_size=16, shuffle=False)

    if num_classes is None:
        num_classes = len(train_dataset.classes)
    alex = AlexNet(num_classes=num_classes, init_weights=False)
    alex.to(device)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(alex.parameters(), lr=0.0001)

    # torch.save does not create missing directories.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # Start below any reachable accuracy so the first epoch always writes a checkpoint.
    best_acc = -1.0

    print("---------------------------start--------------------------")
    for epoch in range(epochs):
        print('========================train=======================')
        train_loss = _train_one_epoch(alex, train_loader, loss_function, optimizer,
                                      device, epoch)

        print('=======================val========================')
        accurate = _evaluate(alex, validate_loader, device)
        print(f'[epoch {epoch + 1}] train_loss:{train_loss:.3f} accuracy:{accurate:.3f}')

        if accurate > best_acc:
            best_acc = accurate
            torch.save(alex.state_dict(), save_path)
-
-
if __name__ == '__main__':
    # Start training only when run as a script, not when this module is imported.
    main()
|