|
- import argparse
-
- import csv
-
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader
- from torchvision.datasets import CIFAR10
- from tqdm import tqdm
-
- import utils
- from model import Model
-
-
class Net(nn.Module):
    """Classifier made of the encoder ``f`` of ``Model`` plus one linear layer.

    Args:
        num_class: Number of outputs of the linear classifier.
        pretrained_path: Checkpoint path handed to ``torch.load``.
    """

    def __init__(self, num_class, pretrained_path):
        super().__init__()

        # Reuse only the encoder part of the pretraining model.
        self.f = Model().f
        # Linear classifier applied to the flattened encoder output
        # (the layer is built for 2048 input features).
        self.fc = nn.Linear(2048, num_class, bias=True)

        # The checkpoint is deserialized onto the CPU. strict=False lets the
        # load succeed even when checkpoint keys and module keys do not match
        # one-to-one; only matching keys are copied.
        checkpoint = torch.load(pretrained_path, map_location='cpu')
        self.load_state_dict(checkpoint, strict=False)

    def forward(self, x):
        """Encode ``x``, flatten everything after dim 0, and apply ``fc``."""
        encoded = self.f(x)
        flat = torch.flatten(encoded, start_dim=1)
        return self.fc(flat)
-
-
# train or test for one epoch
def train_val(net, data_loader, train_optimizer):
    """Run one pass over ``data_loader`` and return (loss, acc@1, acc@5).

    The pass is a training pass when ``train_optimizer`` is not ``None``
    (the optimizer is stepped on every batch) and an evaluation pass
    otherwise (gradients disabled, no parameter update).

    Relies on the module-level names ``loss_criterion``, ``epoch`` and
    ``epochs`` being defined by the calling script.

    Args:
        net: Model to run; switched to train/eval mode accordingly.
        data_loader: Iterable yielding ``(data, target)`` batches.
        train_optimizer: Optimizer to step, or ``None`` for evaluation.

    Returns:
        Tuple of sample-weighted mean loss, top-1 accuracy in percent and
        top-5 accuracy in percent over all samples seen.
    """
    is_train = train_optimizer is not None
    if is_train:
        net.train()
    else:
        net.eval()

    loss_sum, top1_hits, top5_hits, seen = 0.0, 0.0, 0.0, 0
    progress = tqdm(data_loader)
    phase = 'Train' if is_train else 'Test'
    grad_mode = torch.enable_grad() if is_train else torch.no_grad()

    with grad_mode:
        for data, target in progress:
            data = data.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)
            out = net(data)
            loss = loss_criterion(out, target)

            if is_train:
                train_optimizer.zero_grad()
                loss.backward()
                train_optimizer.step()

            batch_size_seen = data.size(0)
            seen += batch_size_seen
            # loss is a batch mean, so weight it by the batch size.
            loss_sum += loss.item() * batch_size_seen

            # Class indices ordered from highest to lowest score per sample.
            ranked = torch.argsort(out, dim=-1, descending=True)
            labels = target.unsqueeze(dim=-1)
            top1_hits += torch.sum((ranked[:, 0:1] == labels).any(dim=-1).float()).item()
            top5_hits += torch.sum((ranked[:, 0:5] == labels).any(dim=-1).float()).item()

            progress.set_description(
                '{} Epoch: [{}/{}] Loss: {:.4f} ACC@1: {:.2f}% ACC@5: {:.2f}%'.format(
                    phase, epoch, epochs, loss_sum / seen,
                    top1_hits / seen * 100, top5_hits / seen * 100))

    return loss_sum / seen, top1_hits / seen * 100, top5_hits / seen * 100
-
-
if __name__ == '__main__':
    # Linear evaluation script: trains only the linear classifier of Net on
    # CIFAR10, logs per-epoch statistics to a CSV file and keeps the
    # checkpoint with the best test top-1 accuracy.
    parser = argparse.ArgumentParser(description='Linear Evaluation')
    parser.add_argument('--model_path', type=str, default='/model/128_4096_0.5_0.999_200_256_500_model.pth',
                        help='The pretrained model path')
    parser.add_argument('--batch_size', type=int, default=256, help='Number of images in each mini-batch')
    parser.add_argument('--epochs', type=int, default=100, help='Number of sweeps over the dataset to train')

    args = parser.parse_args()
    # NOTE: `epochs`, `epoch` and `loss_criterion` must remain module-level
    # names because train_val() reads them as globals.
    model_path, batch_size, epochs = args.model_path, args.batch_size, args.epochs
    train_data = CIFAR10(root='/dataset', train=True, transform=utils.train_transform, download=True)
    test_data = CIFAR10(root='/dataset', train=False, transform=utils.test_transform, download=True)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=16, pin_memory=True)
    test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=16, pin_memory=True)

    model = Net(num_class=len(train_data.classes), pretrained_path=model_path).cuda()
    # Freeze the encoder; only the linear classifier is optimized.
    for param in model.f.parameters():
        param.requires_grad = False

    optimizer = optim.Adam(model.fc.parameters(), lr=1e-3, weight_decay=1e-6)
    loss_criterion = nn.CrossEntropyLoss()
    results = {'train_loss': [], 'train_acc@1': [], 'train_acc@5': [],
               'test_loss': [], 'test_acc@1': [], 'test_acc@5': []}
    fieldnames = list(results)

    best_acc = 0.0

    # Open the statistics file once: the header is written up front and one
    # row is appended per epoch. (Previously the file was re-opened in 'w'
    # mode every epoch, which dropped the header, and the dict's key strings
    # were passed to DictWriter.writerow, which raises AttributeError.)
    with open('/model/linear_statistics.csv', 'w', encoding='utf8', newline='') as f:
        writer = csv.DictWriter(f, fieldnames=fieldnames)
        writer.writeheader()  # write the field names as the first row of the CSV file
        f.flush()

        for epoch in range(1, epochs + 1):
            train_loss, train_acc_1, train_acc_5 = train_val(model, train_loader, optimizer)
            results['train_loss'].append(train_loss)
            results['train_acc@1'].append(train_acc_1)
            results['train_acc@5'].append(train_acc_5)
            test_loss, test_acc_1, test_acc_5 = train_val(model, test_loader, None)
            results['test_loss'].append(test_loss)
            results['test_acc@1'].append(test_acc_1)
            results['test_acc@5'].append(test_acc_5)

            # save statistics: append this epoch's values and flush so the
            # file on disk is current even if a later epoch fails
            writer.writerow({name: values[-1] for name, values in results.items()})
            f.flush()

            if test_acc_1 > best_acc:
                best_acc = test_acc_1
                torch.save(model.state_dict(), '/model/linear_model.pth')
|