|
- import os
- import csv
- from tqdm import tqdm
-
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader
- from torchvision.datasets import CIFAR10
- import torchvision.transforms as T
-
-
- from utils import AverageMeter
- from model import Model
-
-
- DISABLE_TQDM = True
-
@torch.no_grad()
def validate(model, dataloader, criterion, args):
    """Evaluate `model` over `dataloader`.

    Returns a tuple (mean loss, top-1 accuracy, top-5 accuracy); also prints
    a one-line summary tagged with the current epoch from `args.epoch`.
    """
    loss_meter = AverageMeter()
    correct1 = correct5 = seen = 0
    model.eval()
    for images, labels in dataloader:
        n = images.size(0)
        images = images.to(args.device, non_blocking=True)
        labels = labels.to(args.device, non_blocking=True)

        logits = model(images)
        loss_meter.update(criterion(logits, labels).item(), n)

        # Classes sorted by descending score; compare ranks against the label.
        ranked = torch.argsort(logits, dim=-1, descending=True)
        hits = ranked == labels.unsqueeze(dim=-1)
        correct1 += hits[:, 0:1].any(dim=-1).float().sum().item()
        correct5 += hits[:, 0:5].any(dim=-1).float().sum().item()
        seen += n

    print(f'Val ep{args.epoch+1}: Loss={loss_meter.item():6.4f}, Acc@top1={correct1/seen:.2%}, Acc@top5={correct5/seen:.2%}')
    return loss_meter.item(), correct1/seen, correct5/seen
-
def train(model, dataloader, criterion, optimizer, args):
    """Run one optimization epoch over `dataloader`.

    Returns a tuple (mean loss, top-1 accuracy, top-5 accuracy). Progress is
    shown via tqdm unless DISABLE_TQDM is set, in which case a single summary
    line is printed at the end of the epoch.
    """
    loss_meter = AverageMeter()
    correct1 = correct5 = seen = 0
    model.train()
    # lr = scheduler.get_last_lr()[0]

    with tqdm(dataloader, total=len(dataloader), desc=f'Epoch {args.epoch+1:2d}', ncols=100, disable=DISABLE_TQDM) as bar:
        for images, labels in bar:
            n = images.size(0)
            images = images.to(args.device, non_blocking=True)
            labels = labels.to(args.device, non_blocking=True)

            logits = model(images)
            loss = criterion(logits, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_meter.update(loss.item(), n)
            # Classes sorted by descending score; compare ranks to the label.
            ranked = torch.argsort(logits, dim=-1, descending=True)
            hits = ranked == labels.unsqueeze(dim=-1)
            correct1 += hits[:, 0:1].any(dim=-1).float().sum().item()
            correct5 += hits[:, 0:5].any(dim=-1).float().sum().item()
            seen += n

            bar.set_postfix_str(f'Loss={loss_meter.item():6.4f}, Acc@top1={correct1/seen:.2%}, Acc@top5={correct5/seen:.2%}')

    # scheduler.step()
    if DISABLE_TQDM:
        print(f'Train ep{args.epoch+1}: Loss={loss_meter.item():6.4f}, Acc@top1={correct1/seen:.2%}, Acc@top5={correct5/seen:.2%}')

    return loss_meter.item(), correct1/seen, correct5/seen
-
-
def main(args):
    """Linear-probe a frozen pretrained backbone on CIFAR-10.

    Loads backbone weights from `args.weights`, freezes them, trains only the
    classification head, logs per-epoch statistics to a CSV in `args.out`,
    and keeps the checkpoint with the best validation top-1 accuracy.
    """
    # Set up the random seed for reproducibility.
    # NOTE(review): setup_seed is not imported anywhere in this file — it
    # presumably lives in `utils`; confirm the import, otherwise this branch
    # raises NameError when --seed is given.
    if args.seed and isinstance(args.seed, int):
        setup_seed(args.seed)

    device = args.device

    # Dataset & dataloaders. The mean/std are the standard CIFAR-10
    # per-channel statistics; training adds padded random crop + flip.
    train_transform = T.Compose([
        T.RandomCrop(32, 4),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])
    val_transform = T.Compose([
        T.ToTensor(),
        T.Normalize([0.4914, 0.4822, 0.4465], [0.2023, 0.1994, 0.2010])
    ])

    train_set = CIFAR10(args.datafolder, train=True, transform=train_transform, download=True)
    val_set = CIFAR10(args.datafolder, train=False, transform=val_transform, download=False)
    # FIX: corrected typo in log message ("sampels" -> "samples").
    print(f"Dataset: {len(train_set)} samples for train, {len(val_set)} samples for valid!")

    dataloader_config = {
        "batch_size": args.batch_size,
        "num_workers": args.num_workers,
        "pin_memory": True
    }
    train_loader = DataLoader(train_set, shuffle=True, drop_last=True, **dataloader_config)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=False, **dataloader_config)
    print(f"Dataloader: {len(train_loader)} batches, batch size {args.batch_size} per device, num workers is {args.num_workers}.")

    # Model: load pretrained backbone weights and freeze them — only the
    # classification head remains trainable (linear probing).
    model = Model(args.arch, "clf", args.num_classes).to(device)
    model.backbone.load_state_dict(torch.load(args.weights))
    model.backbone.requires_grad_(False)
    try:
        from torchinfo import summary
        summary(model, input_data=torch.randn(4, 3, 32, 32).to(device), col_names=("output_size", "num_params", "params_percent", "mult_adds", "trainable"), depth=3)
    except Exception:
        # torchinfo is optional; fall back to a plain parameter count.
        print(f"Model: ResNet. Total params: {sum(p.numel() for p in model.parameters())/1e6:.2f} M")

    # Criterion & optimizer over the trainable (head) parameters only.
    criterion = nn.CrossEntropyLoss(reduction='mean')
    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad], lr=args.learning_rate, weight_decay=args.weight_decay)
    scheduler = None

    # FIX: the header was written with fieldname 'train_loss' while rows were
    # written with 'train loss', so the CSV columns never lined up. Define the
    # fieldnames once and reuse them for both header and rows.
    fieldnames = ['train loss', 'train acc@1', 'train acc@5', 'val loss', 'val acc@1', 'val acc@5']
    csv_path = os.path.join(args.out, 'linear_statistics.csv')
    with open(csv_path, 'w', encoding='utf8', newline='') as f:
        csv.DictWriter(f, fieldnames=fieldnames).writeheader()

    best_val_acc = 0
    for epoch in range(args.epochs):
        args.epoch = epoch
        train_loss, train_acc_1, train_acc_5 = train(model, train_loader, criterion, optimizer, args)
        val_loss, val_acc_1, val_acc_5 = validate(model, val_loader, criterion, args)

        # Append this epoch's statistics.
        with open(csv_path, 'a', encoding='utf8', newline='') as f:
            csv.DictWriter(f, fieldnames=fieldnames).writerow({
                'train loss': train_loss,
                'train acc@1': train_acc_1,
                'train acc@5': train_acc_5,
                'val loss': val_loss,
                'val acc@1': val_acc_1,
                'val acc@5': val_acc_5
            })

        # FIX: best_val_acc was initialized but never used, and the checkpoint
        # was overwritten every epoch regardless of quality. Keep only the
        # checkpoint with the best validation top-1 accuracy.
        if val_acc_1 > best_val_acc:
            best_val_acc = val_acc_1
            torch.save(model.state_dict(), os.path.join(args.out, 'linear_model.pth'))
-
-
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(description='Linear Probing')
    # NOTE: `default` is never used when required=True; kept for documentation.
    parser.add_argument('--weights', type=str, default='/model/128_4096_0.5_0.999_200_256_500_model.pth', required=True, help='The pretrained model weight path')
    parser.add_argument('--out', type=str, default='/model', help='output folder')
    parser.add_argument('--seed', default=None, type=int, help="random seed")
    parser.add_argument('--arch', type=str, default='resnet50', help='backbone model arch')
    parser.add_argument('--datafolder', type=str, default='../', help='data folder')
    parser.add_argument('--num_workers', type=int, default=None)
    parser.add_argument('--num_classes', type=int, default=10, help='num classes')

    parser.add_argument('--learning_rate', type=float, default=1e-3, help='optimizer base learning rate.')
    parser.add_argument('--weight_decay', type=float, default=1e-6, help='parameters weight decay.')
    parser.add_argument('--batch_size', type=int, default=256, help='Number of images in each mini-batch')
    parser.add_argument('--epochs', type=int, default=100, help='Number of sweeps over the dataset to train')
    parser.add_argument('--cuda', action='store_true', default=False, help='use cuda')

    args = parser.parse_args()
    print(args)

    os.makedirs(args.out, exist_ok=True)
    # FIX: `args.cuda` was unconditionally overwritten with
    # torch.cuda.is_available(), which made the --cuda flag meaningless.
    # Honor the flag, but still fall back to CPU when CUDA is unavailable.
    args.cuda = args.cuda and torch.cuda.is_available()
    args.device = torch.device('cuda') if args.cuda else torch.device('cpu')
    if args.num_workers is None:
        # Cap dataloader workers at 8 even on many-core machines.
        args.num_workers = min(os.cpu_count(), 8)

    main(args)
|