|
- import argparse
- import logging
- import os
- import random
- import sys
- import time
- import numpy as np
- import torch
- import torch.nn as nn
- import torch.optim as optim
- from tensorboardX import SummaryWriter
- from torch.nn.modules.loss import CrossEntropyLoss
- from torch.utils.data import DataLoader
- from tqdm import tqdm
- from utils import DiceLoss
- from torchvision import transforms
- from utils import test_single_volume
-
- def trainer_synapse(args, model, snapshot_path):
- from datasets.dataset_synapse import Synapse_dataset, RandomGenerator
- logging.basicConfig(filename=snapshot_path + "/log.txt", level=logging.INFO,
- format='[%(asctime)s.%(msecs)03d] %(message)s', datefmt='%H:%M:%S')
- logging.getLogger().addHandler(logging.StreamHandler(sys.stdout))
- logging.info(str(args))
- base_lr = args.base_lr
- num_classes = args.num_classes
- batch_size = args.batch_size * args.n_gpu
- # max_iterations = args.max_iterations
- db_train = Synapse_dataset(base_dir=args.root_path, list_dir=args.list_dir, split="train",
- transform=transforms.Compose(
- [RandomGenerator(output_size=[args.img_size, args.img_size])]))
- print("The length of train set is: {}".format(len(db_train)))
-
- def worker_init_fn(worker_id):
- random.seed(args.seed + worker_id)
-
- trainloader = DataLoader(db_train, batch_size=batch_size, shuffle=True, num_workers=8, pin_memory=True,
- worker_init_fn=worker_init_fn)
- if args.n_gpu > 1:
- model = nn.DataParallel(model)
- model.train()
- ce_loss = CrossEntropyLoss()
- dice_loss = DiceLoss(num_classes)
- optimizer = optim.SGD(model.parameters(), lr=base_lr, momentum=0.9, weight_decay=0.0001)
- writer = SummaryWriter(snapshot_path + '/log')
- iter_num = 0
- max_epoch = args.max_epochs
- max_iterations = args.max_epochs * len(trainloader) # max_epoch = max_iterations // len(trainloader) + 1
- logging.info("{} iterations per epoch. {} max iterations ".format(len(trainloader), max_iterations))
- best_performance = 0.0
- iterator = tqdm(range(max_epoch), ncols=70)
- for epoch_num in iterator:
- for i_batch, sampled_batch in enumerate(trainloader):
- image_batch, label_batch = sampled_batch['image'], sampled_batch['label']
- image_batch, label_batch = image_batch.cuda(), label_batch.cuda()
- outputs = model(image_batch)
- loss_ce = ce_loss(outputs, label_batch[:].long())
- loss_dice = dice_loss(outputs, label_batch, softmax=True)
- loss = 0.4 * loss_ce + 0.6 * loss_dice
- optimizer.zero_grad()
- loss.backward()
- optimizer.step()
- lr_ = base_lr * (1.0 - iter_num / max_iterations) ** 0.9
- for param_group in optimizer.param_groups:
- param_group['lr'] = lr_
-
- iter_num = iter_num + 1
- writer.add_scalar('info/lr', lr_, iter_num)
- writer.add_scalar('info/total_loss', loss, iter_num)
- writer.add_scalar('info/loss_ce', loss_ce, iter_num)
-
- logging.info('iteration %d : loss : %f, loss_ce: %f' % (iter_num, loss.item(), loss_ce.item()))
-
- if iter_num % 20 == 0:
- image = image_batch[1, 0:1, :, :]
- image = (image - image.min()) / (image.max() - image.min())
- writer.add_image('train/Image', image, iter_num)
- outputs = torch.argmax(torch.softmax(outputs, dim=1), dim=1, keepdim=True)
- writer.add_image('train/Prediction', outputs[1, ...] * 50, iter_num)
- labs = label_batch[1, ...].unsqueeze(0) * 50
- writer.add_image('train/GroundTruth', labs, iter_num)
-
- save_interval = 50 # int(max_epoch/6)
- if epoch_num > int(max_epoch / 2) and (epoch_num + 1) % save_interval == 0:
- save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
- torch.save(model.state_dict(), save_mode_path)
- logging.info("save model to {}".format(save_mode_path))
-
- if epoch_num >= max_epoch - 1:
- save_mode_path = os.path.join(snapshot_path, 'epoch_' + str(epoch_num) + '.pth')
- torch.save(model.state_dict(), save_mode_path)
- logging.info("save model to {}".format(save_mode_path))
- iterator.close()
- break
-
- writer.close()
- return "Training Finished!"
|