|
-
- import argparse
- import glob
- import os
-
- import cv2
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from torch.utils.data import DataLoader
- import numpy as np
-
- from torchvision import transforms
- from conf import settings
- from utils import *
-
- import matplotlib
- matplotlib.use('Agg')
- import matplotlib.pyplot as plt
-
-
- from torch.optim.lr_scheduler import _LRScheduler
-
-
class FindLR(_LRScheduler):
    """Exponentially increasing learning rate for a learning-rate range test.

    At scheduler step ``k`` every parameter group's learning rate is
    ``base_lr * (max_lr / base_lr) ** (k / num_iter)``. The rate therefore
    starts at the optimizer's initial lr (k = 0) and reaches ``max_lr`` at
    k = num_iter.

    Args:
        optimizer: wrapped optimizer (e.g. SGD). Every param group must start
            from a strictly positive learning rate.
        max_lr: learning rate reached after ``num_iter`` steps; must be > 0.
        num_iter: number of scheduler steps the sweep spans; must be >= 1.
        last_epoch: index of the last step, forwarded to ``_LRScheduler``
            (-1 starts a fresh schedule).

    Raises:
        ValueError: if ``num_iter`` < 1, ``max_lr`` <= 0, or a param group's
            starting learning rate is <= 0.
    """

    def __init__(self, optimizer, max_lr=10, num_iter=100, last_epoch=-1):
        # Validate before super().__init__, because the base class already
        # calls get_lr() once during construction.
        if num_iter < 1:
            raise ValueError('num_iter must be >= 1, got {}'.format(num_iter))
        if max_lr <= 0:
            raise ValueError('max_lr must be > 0, got {}'.format(max_lr))
        for group in optimizer.param_groups:
            # PyTorch's base scheduler takes base_lrs from 'initial_lr' when
            # present, otherwise from 'lr'.
            start_lr = group.get('initial_lr', group['lr'])
            if start_lr <= 0:
                raise ValueError(
                    'initial learning rate must be > 0 for an exponential '
                    'sweep, got {}'.format(start_lr))

        self.total_iters = num_iter
        self.max_lr = max_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the per-group learning rates for step ``self.last_epoch``."""
        progress = self.last_epoch / self.total_iters
        return [base_lr * (self.max_lr / base_lr) ** progress
                for base_lr in self.base_lrs]
-
def _str2bool(value):
    """Parse a boolean command-line value such as ``True``/``False``/``1``/``0``.

    ``argparse`` with ``type=bool`` turns every non-empty string (including
    ``"False"``) into True, so this converter is used for ``-gpu`` instead.
    Values that are already ``bool`` (e.g. the default) are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: if the string is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'expected a boolean value, got {!r}'.format(value))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-net', type=str, required=True, help='net type')
    parser.add_argument('-b', type=int, default=64, help='batch size for dataloader')
    parser.add_argument('-base_lr', type=float, default=1e-7, help='min learning rate')
    parser.add_argument('-max_lr', type=float, default=10, help='max learning rate')
    parser.add_argument('-num_iter', type=int, default=100, help='num of iteration')
    parser.add_argument('-gpu', type=_str2bool, default=True, help='use gpu or not')
    parser.add_argument('-gpus', nargs='+', type=int, default=0, help='gpu device')
    args = parser.parse_args()

    cifar100_training_loader = get_training_dataloader(
        settings.CIFAR100_TRAIN_MEAN,
        settings.CIFAR100_TRAIN_STD,
        num_workers=4,
        batch_size=args.b,
    )

    net = get_network(args)

    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=args.base_lr, momentum=0.9,
                          weight_decay=1e-4, nesterov=True)

    # Exponential lr sweep: iteration n (0..num_iter) trains with
    # base_lr * (max_lr / base_lr) ** (n / num_iter).
    lr_scheduler = FindLR(optimizer, max_lr=args.max_lr, num_iter=args.num_iter)
    # Enough epochs to run num_iter + 1 iterations.
    epoches = int(args.num_iter / len(cifar100_training_loader)) + 1

    n = 0
    diverged = False
    learning_rate = []
    losses = []
    for _ in range(epoches):
        net.train()

        for batch_index, (images, labels) in enumerate(cifar100_training_loader):
            if n > args.num_iter:
                break

            if args.gpu:
                images = images.cuda()
                labels = labels.cuda()

            # The lr this iteration's optimizer.step() will use.
            current_lr = optimizer.param_groups[0]['lr']

            optimizer.zero_grad()
            predicts = net(images)
            loss = loss_function(predicts, labels)
            if not torch.isfinite(loss).all():
                # Loss blew up (NaN/inf): larger lrs are meaningless, so end the sweep.
                diverged = True
                break
            loss.backward()
            optimizer.step()
            # Advance the schedule only after the optimizer step (PyTorch >= 1.1 order).
            lr_scheduler.step()

            loss_value = loss.item()
            print('Iterations: {iter_num} [{trained_samples}/{total_samples}]\tLoss: {:0.4f}\tLR: {:0.8f}'.format(
                loss_value,
                current_lr,
                iter_num=n,
                trained_samples=batch_index * args.b + len(images),
                total_samples=len(cifar100_training_loader.dataset),
            ))

            learning_rate.append(current_lr)
            losses.append(loss_value)
            n += 1

        if diverged or n > args.num_iter:
            break

    # Drop the first 10 and last 5 points (presumably the noisy start and the
    # divergent tail), but only when enough points remain to plot.
    skip_start, skip_end = 10, 5
    if len(losses) > skip_start + skip_end:
        learning_rate = learning_rate[skip_start:-skip_end]
        losses = losses[skip_start:-skip_end]

    fig, ax = plt.subplots(1, 1)
    ax.plot(learning_rate, losses)
    ax.set_xlabel('learning rate')
    ax.set_ylabel('losses')
    ax.set_xscale('log')
    ax.xaxis.set_major_formatter(plt.FormatStrFormatter('%.0e'))

    fig.savefig('result.jpg')
    plt.close(fig)
|