|
- import os
- import torch
- from tqdm import tqdm
- import torch.nn as nn
- from torch import optim
- from torch.utils.data.dataloader import DataLoader
- # from tensorboardX import SummaryWriter
-
- from load_data import MyDataSets
- from utils import classes_two_acc, classes_multiple_acc, weight_init
- from model import generate_resnet
-
- import datetime
-
# Training hyper-parameters. Batch size is ideally a power of two.
batch_size = 32
num_classes = 10
epochs = 100
lr = 1e-3        # initial learning rate for SGD
decay = 1e-7     # L2 weight decay passed to the optimizer
# Checkpoint directory is tagged with today's date, e.g. checkpoint/checkpoint_2024-01-01
checkpoint_date = 'checkpoint/checkpoint_' + str(datetime.date.today())
-
-
def validation(net, loader, criterion, device, valid_num):
    """Evaluate ``net`` on the validation loader.

    Args:
        net: model to evaluate; put into eval mode for the duration of
            this call and restored to train mode before returning.
        loader: DataLoader yielding (image, label) batches.
        criterion: loss taking raw logits (e.g. ``nn.CrossEntropyLoss``).
        device: device the batches are moved to.
        valid_num: number of validation samples the summed batch losses
            are divided by. NOTE(review): with a mean-reducing criterion
            this mixes per-batch means with a per-sample divisor; kept
            as-is to match the training loop's reporting.

    Returns:
        Tuple of (normalized loss, accuracy from ``classes_multiple_acc``).
    """
    net.eval()
    valid_loss = 0
    valid_record = []
    # Single no_grad context for the whole pass — no autograd state is
    # needed anywhere in evaluation.
    with torch.no_grad():
        for valid_img, valid_label in loader:
            valid_img = valid_img.to(device)
            valid_label = valid_label.to(device)
            pred = net(valid_img)
            valid_loss += criterion(pred, valid_label).item()
            # Bug fix: convert logits to one-hot (1 at the argmax, 0
            # elsewhere) so the record matches exactly what the training
            # loop feeds classes_multiple_acc; previously raw logits were
            # appended here while train passed one-hot tensors.
            pred = (pred == pred.max(dim=1, keepdim=True)[0]).to(torch.float32)
            valid_record.append([pred, valid_label])
    net.train()
    valid_acc = classes_multiple_acc(valid_record)

    return valid_loss / valid_num, valid_acc
-
-
def train(net, device, train_txt, valid_txt, net_name):
    """Train ``net`` and checkpoint whenever both accuracies improve.

    Args:
        net: model already moved to ``device``.
        device: torch device used for batches and the loss.
        train_txt: path to the training sample list consumed by MyDataSets.
        valid_txt: path to the validation sample list.
        net_name: tag appended to the dated checkpoint directory name.
    """
    # 1. datasets
    train_sample = MyDataSets(train_txt)
    valid_sample = MyDataSets(valid_txt)

    train_loader = DataLoader(train_sample, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=True)
    valid_loader = DataLoader(valid_sample, batch_size=batch_size, num_workers=8, pin_memory=True, shuffle=True)

    # 2. model setup
    # criterion = nn.BCELoss()
    criterion = nn.CrossEntropyLoss().to(device)
    # optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=decay)
    optimizer = optim.SGD(net.parameters(), lr=lr, weight_decay=decay, momentum=0.9)

    # Shrink the LR when the monitored metric plateaus for `patience` epochs.
    # Bug fix: scheduler.step() below receives an accuracy blend (higher is
    # better), so mode must be 'max' — with 'min' the scheduler treats every
    # accuracy gain as "no improvement" and cuts the LR while the model is
    # still improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=40, mode='max', factor=0.1, min_lr=1e-7)

    # 3. training record (tensorboard logging kept disabled as in original)
    # log_dir = "runs/" + str(datetime.datetime.now().strftime('%m-%d_%H-%M-%S')) + f"_{net_name}"
    # writer = SummaryWriter(log_dir)

    # 4. training state: best accuracies seen so far (checkpoint gate)
    train_acc_best, valid_acc_best = 0, 0
    train_sample_num, valid_sample_num = len(train_sample), len(valid_sample)

    for epoch in range(epochs):
        net.train()
        epoch_loss, epoch_record = 0, []
        pbar = tqdm(enumerate(train_loader), total=len(train_sample) // batch_size,
                    desc=f'epoch {epoch + 1}/{epochs}', unit='patient')

        for idx, (train_img, train_label) in pbar:
            train_img, train_label = train_img.to(device), train_label.to(device)

            optimizer.zero_grad()

            prediction = net(train_img)
            loss = criterion(prediction, train_label)

            loss.backward()
            # nn.utils.clip_grad_value_(net.parameters(), 0.1)
            optimizer.step()

            epoch_loss += loss.item()

            # Convert logits to one-hot (1 at the per-row maximum, 0
            # elsewhere) before recording, as classes_multiple_acc expects.
            prediction = (prediction == prediction.max(dim=1, keepdim=True)[0]).to(torch.float32)
            epoch_record.append([prediction, train_label])
            pbar.set_postfix(train_loss=loss.item())

        # NOTE(review): epoch_loss sums per-batch mean losses but is divided
        # by the sample count — scale is off by ~batch_size, kept as-is so
        # reported numbers stay comparable with earlier runs.
        epoch_loss, epoch_acc = epoch_loss / train_sample_num, classes_multiple_acc(epoch_record)
        valid_loss, valid_acc = validation(net, valid_loader, criterion, device, valid_sample_num)
        print(epoch_loss, epoch_acc)
        print(valid_loss, valid_acc)
        # Plateau metric: weighted accuracy blend favouring validation.
        scheduler.step(0.3 * epoch_acc + 0.7 * valid_acc)

        # writer.add_scalar('Loss/train', epoch_loss, epoch)
        # writer.add_scalar('Loss/validation', valid_loss, epoch)
        # writer.add_scalar('Accuracy/train', epoch_acc, epoch)
        # writer.add_scalar('Accuracy/validation', valid_acc, epoch)

        # Checkpoint only when BOTH accuracies are at least as good as the
        # best seen so far.
        if epoch_acc >= train_acc_best and valid_acc >= valid_acc_best:
            train_acc_best = epoch_acc
            valid_acc_best = valid_acc
            checkpoint_dir = checkpoint_date + f'_{net_name}'
            # Bug fix: os.mkdir raises FileNotFoundError when the parent
            # 'checkpoint/' directory does not exist; makedirs creates the
            # whole path and exist_ok avoids a race on re-entry.
            os.makedirs(checkpoint_dir, exist_ok=True)
            save_dir = os.path.join(checkpoint_dir,
                                    f'train_{epoch_acc:.3f}-valid_{valid_acc:.3f}.pth')
            torch.save(net.state_dict(), save_dir)
-
-
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # server-side dataset split lists
    train_txt = '/code/datasets/train.txt'
    valid_txt = '/code/datasets/valid.txt'
    test_txt = '/code/datasets/test.txt'

    # not pre-trained / pre-trained (normal) / pre-trained (feature extractor)
    net = generate_resnet(50, n_input_channels=3, n_classes=10)
    net_name = "resnet50"
    net.to(device)
    # Bug fix: the original passed test_txt as the validation list, which
    # leaks the test set into LR scheduling and checkpoint selection; use
    # the held-out validation split instead.
    train(net, device, train_txt, valid_txt, net_name)
-
-
|