|
- import os
- import random
- import torch
- import logging
- import numpy as np
- import torch.nn as nn
- from models.FullNet import MultiTaskFullNet
-
-
- from data_folder import DataFolder
- from torch.utils.data import DataLoader
- from my_transforms import get_transforms
- from models.utils import mmseg_acc
- from tensorboardX import SummaryWriter
- from sklearn.metrics import accuracy_score
- import shutil
- import argparse
-
# Global TensorBoard writer; logs to the default ./runs/<timestamp> directory.
writer = SummaryWriter()
-
-
def save_checkpoint(state, epoch, is_best, save_dir, cp_flag):
    """Save a training checkpoint and, if best so far, copy it to checkpoint_best.pth.

    Args:
        state: dict of serializable training state (epoch, state_dicts, ...).
        epoch: current epoch number, embedded in the checkpoint filename.
        is_best: True if this checkpoint is the best seen so far.
        save_dir: root directory; checkpoints go under '<save_dir>/checkpoints'.
        cp_flag: unused legacy flag, kept for interface compatibility.
    """
    cp_dir = '{:s}/checkpoints'.format(save_dir)
    # makedirs(exist_ok=True) also creates missing parents and tolerates a
    # concurrent run creating the dir (os.mkdir raised if save_dir was absent
    # or if the exists-check raced another process).
    os.makedirs(cp_dir, exist_ok=True)
    filename = '{:s}/checkpoint_{}.pth'.format(cp_dir, epoch)
    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, '{:s}/checkpoint_best.pth'.format(cp_dir))
-
-
def get_one_hot(gt, num_classes):
    """One-hot encode an integer label map and move the result to CUDA.

    Assumes gt is an integer (long) tensor of shape (N, H, W) — TODO confirm
    against callers.  Returns a float tensor of shape (N, num_classes, H, W)
    on the GPU.
    """
    original_shape = list(gt.size())
    # Flatten on the CPU so each label can index a row of the identity matrix.
    flat_labels = gt.cpu().view(-1)
    identity = torch.eye(num_classes)
    one_hot_rows = identity.index_select(0, flat_labels)
    # Restore the spatial layout with the class axis last, then move the
    # class axis next to the batch axis (NHWC -> NCHW) before pushing to GPU.
    target_shape = original_shape + [num_classes]
    return one_hot_rows.view(*target_shape).permute(0, 3, 1, 2).cuda()
-
-
def main(args, logger):
    """Train MultiTaskFullNet jointly on segmentation and classification.

    Args:
        args: parsed command-line namespace (see the argparse setup in __main__).
        logger: configured logging.Logger used for progress reporting.

    Side effects: trains on GPU, logs scalars to the module-level TensorBoard
    ``writer`` and writes checkpoints under ``args.model_save_dir``.
    """
    model = MultiTaskFullNet(color_channels=3, output_channels=args.num_classes)
    model = nn.DataParallel(model)  # pass device_ids=... to restrict GPUs
    model = model.cuda()

    # Compare case-insensitively: argparse advertises choices ["SGD", "Adam"]
    # but the original code compared against 'adam', so the Adam branch was
    # unreachable and any mismatch later raised NameError on `optimizer`.
    opt_name = args.optimizer.lower()
    if opt_name == 'adam':
        optimizer = torch.optim.Adam(model.parameters(), args.lr, betas=(0.9, 0.99),
                                     weight_decay=args.weight_decay)
    elif opt_name == 'sgd':
        optimizer = torch.optim.SGD(model.parameters(), args.lr, momentum=0.9,
                                    weight_decay=args.weight_decay)
    else:
        raise ValueError('unsupported optimizer: {}'.format(args.optimizer))

    dsets = {}

    data_transforms = {
        'train': get_transforms({
            'scale': 240,
            'horizontal_flip': True,
            'random_affine': 0.3,
            'random_elastic': [6, 15],
            'random_rotation': 90,
            'random_crop': 240,
            'to_tensor': 1,
        }),
        'validation': get_transforms({
            'scale': 240,
            'to_tensor': 1,
        })
    }

    for x in ['train', 'validation']:
        img_dir = os.path.join(args.train_img_dir, x)
        target_dir = os.path.join(args.train_label_dir, x)
        dir_list = [img_dir, target_dir]
        post_fix = ['.png']
        num_channels = [3, 1]  # RGB image + single-channel label map
        dsets[x] = DataFolder(dir_list, post_fix, num_channels, data_transforms[x])

    train_loader = DataLoader(dsets['train'], batch_size=args.batch_size, shuffle=True,
                              pin_memory=True, num_workers=4)
    val_loader = DataLoader(dsets['validation'], batch_size=args.batch_size, shuffle=True,
                            pin_memory=True, num_workers=4)

    if args.seg_loss == "CE":
        criterion_seg = torch.nn.CrossEntropyLoss().cuda()
    elif args.seg_loss == "MSE":
        criterion_seg = torch.nn.MSELoss().cuda()
    else:
        raise ValueError('unsupported segmentation loss: {}'.format(args.seg_loss))

    criterion_cls = torch.nn.CrossEntropyLoss().cuda()

    best_cri = -1  # best (val mIoU + classification accuracy) seen so far

    for epoch in range(args.num_epoches):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1, args.num_epoches))

        model.train()
        for index, sample in enumerate(train_loader):
            images, target, category = sample

            images = images.cuda()
            # squeeze(1) drops only the channel axis of the (N, 1, H, W) label
            # map; a bare squeeze() also dropped the batch axis when N == 1.
            target = target.squeeze(1).cuda()
            # NOTE(review): category.squeeze() collapses to a 0-d tensor for a
            # batch of one, which CrossEntropyLoss rejects — confirm the
            # loader never yields a final batch of size 1.
            category = category.squeeze().cuda()

            segoutput, clsoutput = model(images)

            if args.seg_loss == "CE":
                segloss = criterion_seg(segoutput, target)
            else:  # MSE compares logits against a one-hot encoding of labels
                segloss = criterion_seg(segoutput, get_one_hot(target, num_classes=args.num_classes))

            clsloss = criterion_cls(clsoutput, category)

            # Equal-weighted multi-task loss.  Earlier experiments tried
            # weighting schedules (parabolic/cosine/linear decay, beta-sampled,
            # staged) as `loss = l * segloss + (2 - l) * clsloss`.
            loss = segloss + clsloss
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # segmentation metrics: per-class pixel accuracy and IoU
            pred = np.argmax(segoutput.detach().cpu().numpy(), axis=1)
            all_acc, acc, iou = mmseg_acc(pred, target.cpu().numpy(),
                                          num_classes=args.num_classes, ignore_index=0)

            # classification accuracy; sklearn cannot consume CUDA tensors,
            # so labels are moved to host memory first.
            pred_cls = np.argmax(clsoutput.detach().cpu().numpy(), axis=1)
            precision = accuracy_score(category.detach().cpu().numpy(), pred_cls)

            if index % 30 == 0:
                # The original printed epoch/num_epoches here under the
                # "Iteration" label; report the actual iteration instead.
                logger.info(
                    "Iteration {}/{},Training Loss : {}, Segmentation Loss : {}, Classification Loss:{}, Pixel accu : {}, mIoU : {}, Classification Precision: {}"
                    .format(index,
                            len(train_loader),
                            loss.item(),
                            segloss.item(),
                            clsloss.item(),
                            np.nanmean(acc),
                            np.nanmean(iou),
                            precision))

            niter = epoch * len(train_loader) + index
            writer.add_scalar('Train Total Loss', loss.item(), niter)
            writer.add_scalar('Train Segmentation Loss', segloss.item(), niter)
            writer.add_scalar('Train Classification Loss', clsloss.item(), niter)

        if (epoch + 1) % args.eval_per_epoch == 0:
            model.eval()

            # accumulated [mean pixel accuracy, mean IoU] over val batches
            # (np.float was removed in NumPy 1.24; use the builtin float)
            eval_results = np.zeros((2,), dtype=float)
            num_batches = 0
            t = 0  # total validation samples
            p = 0  # correctly classified validation samples

            with torch.no_grad():  # no gradients needed for evaluation
                for index, sample in enumerate(val_loader):
                    images, target, category = sample

                    images = images.cuda()
                    target = target.squeeze(1).cuda()
                    category = category.squeeze().cuda()

                    segoutput, clsoutput = model(images)
                    pred = np.argmax(segoutput.cpu().numpy(), axis=1)
                    all_acc, acc, iou = mmseg_acc(pred, target.cpu().numpy(),
                                                  num_classes=args.num_classes, ignore_index=0)
                    eval_results += np.array([np.nanmean(acc), np.nanmean(iou)])
                    num_batches += 1

                    pred_cls = np.argmax(clsoutput.cpu().numpy(), axis=1)
                    p += accuracy_score(category.cpu().numpy(), pred_cls, normalize=False)
                    t += images.shape[0]

            # Average over the actual number of batches.  The original divided
            # by the last batch index (off by one) and raised ZeroDivisionError
            # when the validation set fit in a single batch.
            num_batches = max(num_batches, 1)
            eval_results = [value / num_batches for value in eval_results.tolist()]
            cls_precision = p / t if t else 0.0

            logger.info("Eval Results : mAcc = {}, mIoU = {}, Classification Precision: {}".format(
                eval_results[0], eval_results[1], cls_precision))

            writer.add_scalar('Val Segmentation mAcc', eval_results[0], epoch)
            writer.add_scalar('Val Segmentation mIoU', eval_results[1], epoch)
            writer.add_scalar('Val Classification Precision', cls_precision, epoch)

            # Model-selection criterion: val mIoU + classification accuracy.
            cri = eval_results[1] + cls_precision
            is_best = cri > best_cri
            best_cri = max(cri, best_cri)
            save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'best_cri': best_cri,
                'optimizer': optimizer.state_dict(),
            }, epoch, is_best, args.model_save_dir, True)
-
-
- if __name__ == "__main__":
- parser = argparse.ArgumentParser(description='Process some integers.')
-
- # dataloader
- parser.add_argument('--batch_size', default=4, type=int)
- parser.add_argument('--num_classes', default=8, type=int)
-
- # loss
- parser.add_argument('--seg_loss', default="CE", type=str, choices=["CE", "MSE"])
-
- # optimizer
- parser.add_argument('--optimizer', default="SGD", type=str, choices=["SGD", "Adam"])
- parser.add_argument('--lr', default=0.001, type=float)
- parser.add_argument('--weight_decay', default=1e-4, type=float)
-
- # training
- parser.add_argument('--num_epoches', default=300, type=int)
- parser.add_argument('--train_img_dir', default="endoscope400/ade20k/images", type=str)
- parser.add_argument('--train_label_dir', default="endoscope400/ade20k/annotations", type=str)
-
- # evaluation
- parser.add_argument('--eval_per_epoch', default=10, type=int)
-
- parser.add_argument('--model_save_dir', default="runs/", type=str)
-
- args = parser.parse_args()
-
- logging.basicConfig(filename='runs/logs_{}'.format(int(random.random()*10000000)),
- filemode='a',
- format='%(asctime)s,%(msecs)d %(name)s %(levelname)s %(message)s',
- datefmt='%H:%M:%S',
- level=logging.INFO)
- sh = logging.StreamHandler()#往屏幕上输出
- fh = logging.FileHandler('runs/logs_{}'.format(int(random.random()*10000000)))
-
- logging.info("Training and Validation Record.")
- logger = logging.getLogger("FullNet for Endoscope")
- logger.addHandler(sh)
- logger.addHandler(fh)
-
- os.environ["CUDA_VISIBLE_DEVICES"] = "2"
-
- main(args, logger)
|