|
- #!/usr/bin/env python
- # -*- coding: utf-8 -*-
- # @Time : 2023/2/6 下午3:03
- # @File : train.py
- # ----------------------------------------------
- # ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆
- # >>> Author : Kevin Chang
- # >>> QQ : 565479588
- # >>> Mail : lovecode@gmail.com
- # >>> Github : https://github.com/lovecode100
- # >>> Blog : https://www.cnblogs.com/lovecode
- # ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆ ☆
- import os
- import sys
- import json
-
- import torch
- import torch.nn as nn
- from torchvision import transforms, datasets, utils
- import torch.optim as optim
- from tqdm import tqdm
- from model import AlexNet
-
-
def main():
    """Train AlexNet on the flower-classification dataset and keep the best weights.

    Side effects:
        - Writes the index -> class-name mapping to
          /code/02AlexNet/class_indices.json (for later inference).
        - Saves the weights of the epoch with the highest validation
          accuracy to /model/AlexNet.pth.

    Raises:
        AssertionError: if the expected dataset directory does not exist.
    """
    # Prefer the first GPU when available, otherwise fall back to CPU.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    # Training preprocessing adds random resized crop + random horizontal flip
    # as data augmentation (enlarges the effective dataset, improves
    # generalization); validation preprocessing is deterministic.
    data_transform = {
        "train": transforms.Compose([transforms.RandomResizedCrop(224),          # random crop, rescaled to 224x224
                                     transforms.RandomHorizontalFlip(p=0.5),     # flip half the samples
                                     transforms.ToTensor(),
                                     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]),
        "val": transforms.Compose([transforms.Resize((224, 224)),                # must be (h, w) tuple to fix both sides
                                   transforms.ToTensor(),
                                   transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])}

    # Dataset root is resolved two levels above the current working directory.
    data_root = os.path.abspath(os.path.join(os.getcwd(), "../.."))  # get data root path
    image_path = os.path.join(data_root, "dataset", "flower_data")
    assert os.path.exists(image_path), "{} path does not exist.".format(image_path)

    # Training set with augmentation applied on the fly.
    train_dataset = datasets.ImageFolder(root=os.path.join(image_path, "train"),
                                         transform=data_transform["train"])
    train_num = len(train_dataset)

    # class_to_idx maps class name -> label index, e.g.
    # {'daisy': 0, 'dandelion': 1, 'roses': 2, 'sunflower': 3, 'tulips': 4}.
    # Invert it so predicted indices can be translated back to names later.
    flower_list = train_dataset.class_to_idx
    cla_dict = dict((val, key) for key, val in flower_list.items())
    json_str = json.dumps(cla_dict, indent=4)
    json_path = '/code/02AlexNet/class_indices.json'
    # BUGFIX: create the parent directory first so the write cannot fail with
    # FileNotFoundError on a machine where /code/02AlexNet does not yet exist.
    os.makedirs(os.path.dirname(json_path), exist_ok=True)
    with open(json_path, 'w') as json_file:
        json_file.write(json_str)

    batch_size = 32
    # Worker count is capped by CPU count, batch size and 8.
    # NOTE(review): on Windows this should typically be 0.
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])  # number of workers
    print('Using {} dataloader workers every process'.format(nw))

    # Shuffle the training set each epoch; load in batches of batch_size.
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               num_workers=nw)

    validate_dataset = datasets.ImageFolder(root=os.path.join(image_path, "val"),
                                            transform=data_transform["val"])
    val_num = len(validate_dataset)
    validate_loader = torch.utils.data.DataLoader(validate_dataset,
                                                  batch_size=4, shuffle=False,
                                                  num_workers=nw)

    print("using {} images for training, {} images for validation.".format(train_num,
                                                                           val_num))

    # 5 output classes; the model initializes its own weights.
    net = AlexNet(num_classes=5, init_weights=True)
    # Move the network to the selected device before building the optimizer.
    net.to(device)
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0002)

    epochs = 10
    save_path = '/model/AlexNet.pth'
    # BUGFIX: ensure the checkpoint directory exists before torch.save,
    # otherwise saving the first best model raises FileNotFoundError.
    os.makedirs(os.path.dirname(save_path), exist_ok=True)
    best_acc = 0.0
    train_steps = len(train_loader)
    for epoch in range(epochs):
        ########################################## train ###############################################
        net.train()          # enable Dropout during training
        running_loss = 0.0   # reset the accumulated loss every epoch
        train_bar = tqdm(train_loader, file=sys.stdout)

        for data in train_bar:
            images, labels = data
            optimizer.zero_grad()                 # clear gradients from the previous step
            outputs = net(images.to(device))      # forward pass
            loss = loss_function(outputs, labels.to(device))
            loss.backward()                       # backward pass
            optimizer.step()                      # parameter update

            # Accumulate loss for the per-epoch average and show progress.
            running_loss += loss.item()
            train_bar.desc = "train epoch[{}/{}] loss:{:.3f}".format(epoch + 1,
                                                                     epochs,
                                                                     loss)

        ########################################### validate ###########################################
        net.eval()   # disable Dropout for evaluation
        acc = 0.0    # number of correctly classified validation images
        # No gradient tracking is needed for evaluation.
        with torch.no_grad():
            val_bar = tqdm(validate_loader, file=sys.stdout)
            for val_data in val_bar:
                val_images, val_labels = val_data
                outputs = net(val_images.to(device))
                # The index of the largest logit is the predicted class.
                predict_y = torch.max(outputs, dim=1)[1]
                acc += torch.eq(predict_y, val_labels.to(device)).sum().item()

        val_accurate = acc / val_num
        print('[epoch %d] train_loss: %.3f val_accuracy: %.3f' %
              (epoch + 1, running_loss / train_steps, val_accurate))

        # Keep only the checkpoint with the best validation accuracy so far.
        if val_accurate > best_acc:
            best_acc = val_accurate
            torch.save(net.state_dict(), save_path)

    print('Finished Training')
-
-
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
|