|
- import os
- import argparse
- import torch
- import torchio
-
- from tqdm import tqdm
- from torch.utils.data import DataLoader
- from torchio.transforms import ZNormalization
- from torch.utils.tensorboard import SummaryWriter
- from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR, CosineAnnealingLR
-
- from utils.metric import metric
- from loss_function import BinaryDiceLoss
- from hparam import hparams as hp
-
# GPU ids handed to torch.nn.DataParallel in train()/test().
devicess = [0]
# Device used for inference in test(); falls back to CPU when CUDA is absent.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# BUG FIX: only pin the CUDA device when CUDA is actually available —
# the unconditional call crashed at import time on CPU-only machines.
if torch.cuda.is_available():
    torch.cuda.set_device(0)

# Dataset and output locations come from the project hyper-parameter module.
source_train_dir = hp.source_train_dir
label_train_dir = hp.label_train_dir

source_test_dir = hp.source_test_dir
label_test_dir = hp.label_test_dir

output_dir_test = hp.output_dir_test
-
-
def parse_training_args(parser):
    """
    Add training/testing command-line options to *parser*.

    Defaults are taken from the project hyper-parameter object ``hp``.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, for chaining.
    """

    parser.add_argument('-o', '--output_dir', type=str, default=hp.output_dir, required=False, help='Directory to save checkpoints')
    parser.add_argument('--latest-checkpoint-file', type=str, default=hp.latest_checkpoint_file, help='Store the latest checkpoint in each epoch')
    parser.add_argument('--model', type=str, default=hp.model, help='Select the model')
    parser.add_argument('--scheduler', type=str, default=hp.scheduler, help='Select the scheduler')

    # training
    training = parser.add_argument_group('training setup')
    training.add_argument('--epochs', type=int, default=hp.total_epochs, help='Number of total epochs to run')
    training.add_argument('--epochs-per-checkpoint', type=int, default=hp.epochs_per_checkpoint, help='Number of epochs per checkpoint')
    # BUG FIX: train()/test() read ``args.batch``, but this option used to be
    # stored under ``args.bs``, which raised AttributeError at runtime. Keep
    # the ``--bs`` flag for CLI compatibility and map it onto ``batch``.
    training.add_argument('--bs', dest='batch', type=int, default=hp.batch_size, help='batch-size')
    parser.add_argument(
        '-k',
        "--ckpt",
        type=str,
        default=hp.ckpt,
        help="path to the checkpoints to resume training",
    )
    parser.add_argument("--init-lr", type=float, default=hp.init_lr, help="learning rate")
    # TODO: wire up real distributed training; --local_rank is currently unused.
    parser.add_argument(
        "--local_rank", type=int, default=0, help="local rank for distributed training"
    )

    training.add_argument('--amp-run', action='store_true', help='Enable AMP')
    training.add_argument('--cudnn-enabled', default=True, help='Enable cudnn')
    training.add_argument('--cudnn-benchmark', default=True, help='Run cudnn benchmark')
    training.add_argument('--disable-uniform-initialize-bn-weight', action='store_true', help='disable uniform initialization of batchnorm layer weight')

    return parser
-
-
-
def _build_model():
    """Instantiate the segmentation network selected by ``hp.mode`` / ``hp.model``.

    Network classes are imported lazily so only the chosen architecture's
    module is loaded. Returns the model on CPU, or ``None`` (after printing
    a message) when the configuration names an unknown model/mode.
    """
    n_out = hp.out_class + 1  # +1 for the background class
    if hp.mode == '2d':
        if hp.model == 'Unet':
            from models.two_d.unet import Unet
            return Unet(in_channels=hp.in_class, classes=n_out)
        if hp.model == 'MiniSeg':
            from models.two_d.miniseg import MiniSeg
            return MiniSeg(in_input=hp.in_class, classes=n_out)
        if hp.model == 'fcn':
            from models.two_d.fcn import FCN32s as fcn
            return fcn(in_class=hp.in_class, n_class=n_out)
        if hp.model == 'SegNet':
            from models.two_d.segnet import SegNet
            return SegNet(input_nbr=hp.in_class, label_nbr=n_out)
        if hp.model == 'DeepLabV3':
            from models.two_d.deeplab import DeepLabV3
            return DeepLabV3(in_class=hp.in_class, class_num=n_out)
        if hp.model == 'ResNet34UnetPlus':
            from models.two_d.unetpp import ResNet34UnetPlus
            return ResNet34UnetPlus(num_channels=hp.in_class, num_class=n_out)
        if hp.model == 'PSPNet':
            from models.two_d.pspnet import PSPNet
            return PSPNet(in_class=hp.in_class, n_classes=n_out)
    elif hp.mode == '3d':
        if hp.model == 'UNet3D':
            from models.three_d.unet3d import UNet3D
            return UNet3D(in_channels=hp.in_class, out_channels=n_out, init_features=32)
        if hp.model == 'UNet':
            from models.three_d.residual_unet3d import UNet
            return UNet(in_channels=hp.in_class, n_classes=n_out, base_n_filter=2)
        if hp.model == 'FCN_Net':
            from models.three_d.fcn3d import FCN_Net
            return FCN_Net(in_channels=hp.in_class, n_class=n_out)
        if hp.model == 'HighRes3DNet':
            from models.three_d.highresnet import HighRes3DNet
            return HighRes3DNet(in_channels=hp.in_class, out_channels=n_out)
        if hp.model == 'SkipDenseNet3D':
            from models.three_d.densenet3d import SkipDenseNet3D
            return SkipDenseNet3D(in_channels=hp.in_class, classes=n_out)
        if hp.model == 'DenseVoxelNet':
            from models.three_d.densevoxelnet3d import DenseVoxelNet
            return DenseVoxelNet(in_channels=hp.in_class, classes=n_out)
        if hp.model == 'VNet':
            from models.three_d.vnet3d import VNet
            return VNet(in_channels=hp.in_class, classes=n_out)
        if hp.model == 'UNETR':
            from models.three_d.unetr import UNETR
            return UNETR(img_shape=(hp.crop_or_pad_size), input_dim=hp.in_class, output_dim=n_out)
    print('Can not find the model.')
    return None


def train():
    """Train the configured segmentation model.

    Parses the command line, builds the model/optimizer/scheduler, optionally
    resumes from ``args.latest_checkpoint_file``, then runs the epoch loop:
    cross-entropy + binary dice loss, per-epoch TensorBoard logging, a
    "latest" checkpoint every epoch and a numbered checkpoint every
    ``args.epochs_per_checkpoint`` epochs.
    """
    parser = argparse.ArgumentParser(description='PyTorch Medical Segmentation Training')
    parser = parse_training_args(parser)
    args = parser.parse_args()

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # Imported here (not at module top) so the data pipeline is only loaded
    # when training actually runs.
    from data_function import MedData_train
    os.makedirs(args.output_dir, exist_ok=True)

    model = _build_model()
    if model is None:
        return

    model = torch.nn.DataParallel(model, device_ids=devicess)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.init_lr)

    if hp.scheduler == 'ReduceLROnPlateau':
        scheduler = ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=20, verbose=True)
    elif hp.scheduler == 'StepLR':
        scheduler = StepLR(optimizer, step_size=hp.scheduer_step_size, gamma=hp.scheduer_gamma)
    elif hp.scheduler == 'CosineAnnealingLR':
        scheduler = CosineAnnealingLR(optimizer, T_max=50, eta_min=5e-6)
    else:
        print('Can not find the scheduler.')
        return

    if args.ckpt is not None:
        # Resume: the checkpoint is loaded onto CPU first, then optimizer
        # state tensors are pushed back to the GPU.
        print("load model:", args.ckpt)
        print(os.path.join(args.output_dir, args.latest_checkpoint_file))
        ckpt = torch.load(os.path.join(args.output_dir, args.latest_checkpoint_file),
                          map_location=lambda storage, loc: storage)
        model.load_state_dict(ckpt["model"])
        optimizer.load_state_dict(ckpt["optim"])

        for state in optimizer.state.values():
            for k, v in state.items():
                if torch.is_tensor(v):
                    state[k] = v.cuda()

        # scheduler.load_state_dict(ckpt["scheduler"])
        elapsed_epochs = ckpt["epoch"]
    else:
        elapsed_epochs = 0

    model.cuda()

    from torch.nn.modules.loss import CrossEntropyLoss
    criterion_ce = CrossEntropyLoss().cuda()
    dice_loss = BinaryDiceLoss().cuda()

    writer = SummaryWriter(args.output_dir)

    train_dataset = MedData_train(source_train_dir, label_train_dir)
    train_loader = DataLoader(train_dataset.queue_dataset,
                              batch_size=args.batch,
                              shuffle=True,
                              pin_memory=True,
                              drop_last=True,
                              num_workers=16)
    print('batchs:', len(train_loader))
    model.train()

    epochs = args.epochs - elapsed_epochs

    for epoch in range(1, epochs + 1):

        loss_per_epoch = 0
        fp = 0
        fn = 0

        num_iters = 0
        epoch_dice = 0

        # Shift the reported epoch number when resuming from a checkpoint.
        epoch += elapsed_epochs

        for i, batch in enumerate(train_loader):

            optimizer.zero_grad()

            x = batch['source']['data']
            y = batch['label']['data']

            x = x.type(torch.FloatTensor).cuda()
            y = y.type(torch.FloatTensor).cuda()

            if hp.mode == '2d':
                # Drop the singleton depth axis (and the channel axis of the
                # label) that the 3d-oriented loader produces.
                x = x.squeeze(4)
                y = y.squeeze(4)
                y = y.squeeze(1)

            outputs = model(x)

            # Hard class predictions, used for both the dice loss and metrics.
            labels = outputs.argmax(dim=1)

            loss_ce = criterion_ce(outputs, y.long())
            loss_dice = dice_loss(labels, y)
            loss = loss_ce + loss_dice

            num_iters += 1
            loss.backward()

            optimizer.step()

            false_positive_rate, false_negtive_rate, dice = metric(y.cpu(), labels.cpu())

            epoch_dice += dice
            fp += false_positive_rate
            fn += false_negtive_rate

            loss_per_epoch += loss.item()

        mean_loss = loss_per_epoch / len(train_loader)

        # BUG FIX: ReduceLROnPlateau.step() requires the monitored metric;
        # calling it with no argument raised a TypeError. The other
        # schedulers keep their plain per-epoch step().
        if isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(mean_loss)
        else:
            scheduler.step()

        writer.add_scalar('Training/Loss', mean_loss, epoch)
        writer.add_scalar('Training/false_positive_rate', fp/len(train_loader), epoch)
        writer.add_scalar('Training/false_negtive_rate', fn/len(train_loader), epoch)
        writer.add_scalar('Training/dice', epoch_dice/len(train_loader), epoch)

        # BUG FIX: read the lr via the public optimizer API instead of the
        # private scheduler._last_lr (not present on ReduceLROnPlateau).
        print("epoch:" + str(epoch), "loss:" + str(mean_loss),
              'lr:' + str(optimizer.param_groups[0]['lr']),
              "dice:" + str(epoch_dice/len(train_loader)),
              "fp:" + str(fp/len(train_loader)),
              "fn:" + str(fn/len(train_loader)))

        # Store latest checkpoint in each epoch
        torch.save(
            {
                "model": model.state_dict(),
                "optim": optimizer.state_dict(),
                "scheduler": scheduler.state_dict(),
                "epoch": epoch
            },
            os.path.join(args.output_dir, args.latest_checkpoint_file),
        )

        # Save checkpoint
        if epoch % args.epochs_per_checkpoint == 0:
            torch.save(
                {
                    "model": model.state_dict(),
                    "optim": optimizer.state_dict(),
                    "epoch": epoch
                },
                os.path.join(args.output_dir, f"checkpoint_{epoch:04d}.pt"),
            )

    writer.close()
-
-
def test():
    """Run patch-based inference on the test set and save label volumes.

    Loads the model weights from ``args.latest_checkpoint_file`` in
    ``args.output_dir``, then for every test subject: z-normalizes it,
    samples overlapping patches with torchio's GridSampler, predicts each
    patch, reassembles the full volume with a GridAggregator, and writes
    the integer label map to ``output_dir_test``.
    """

    parser = argparse.ArgumentParser(description='PyTorch Medical Segmentation Testing')
    parser = parse_training_args(parser)
    # NOTE(review): parse_known_args() result is immediately overwritten by
    # parse_args() below; kept as-is to avoid behavior changes.
    args, _ = parser.parse_known_args()

    args = parser.parse_args()


    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = args.cudnn_enabled
    torch.backends.cudnn.benchmark = args.cudnn_benchmark

    # Imported lazily so the data pipeline is only loaded when testing runs.
    from data_function import MedData_test

    os.makedirs(output_dir_test, exist_ok=True)

    # Build the network selected by hp.mode / hp.model; hp.out_class+1
    # output channels account for the background class.
    if hp.mode == '2d':
        if hp.model == 'Unet':
            from models.two_d.unet import Unet
            model = Unet(in_channels=hp.in_class, classes=hp.out_class+1)

        elif hp.model == 'MiniSeg':
            from models.two_d.miniseg import MiniSeg
            model = MiniSeg(in_input=hp.in_class, classes=hp.out_class+1)

        elif hp.model == 'fcn':
            from models.two_d.fcn import FCN32s as fcn
            model = fcn(in_class =hp.in_class,n_class=hp.out_class+1)

        elif hp.model == 'SegNet':
            from models.two_d.segnet import SegNet
            model = SegNet(input_nbr=hp.in_class,label_nbr=hp.out_class+1)

        elif hp.model == 'DeepLabV3':
            from models.two_d.deeplab import DeepLabV3
            model = DeepLabV3(in_class=hp.in_class,class_num=hp.out_class+1)

        elif hp.model == 'ResNet34UnetPlus':
            from models.two_d.unetpp import ResNet34UnetPlus
            model = ResNet34UnetPlus(num_channels=hp.in_class,num_class=hp.out_class+1)

        elif hp.model == 'PSPNet':
            from models.two_d.pspnet import PSPNet
            model = PSPNet(in_class=hp.in_class,n_classes=hp.out_class+1)

        else:
            print('Can not find the model.')
            return

    elif hp.mode == '3d':
        if hp.model == 'UNet3D':
            from models.three_d.unet3d import UNet3D
            model = UNet3D(in_channels=hp.in_class, out_channels=hp.out_class+1, init_features=32)

        elif hp.model == 'UNet':
            from models.three_d.residual_unet3d import UNet
            model = UNet(in_channels=hp.in_class, n_classes=hp.out_class+1, base_n_filter=2)

        elif hp.model == 'FCN_Net':
            from models.three_d.fcn3d import FCN_Net
            model = FCN_Net(in_channels =hp.in_class,n_class =hp.out_class+1)

        elif hp.model == 'HighRes3DNet':
            from models.three_d.highresnet import HighRes3DNet
            model = HighRes3DNet(in_channels=hp.in_class,out_channels=hp.out_class+1)

        elif hp.model == 'SkipDenseNet3D':
            from models.three_d.densenet3d import SkipDenseNet3D
            model = SkipDenseNet3D(in_channels=hp.in_class, classes=hp.out_class+1)

        elif hp.model == 'DenseVoxelNet':
            from models.three_d.densevoxelnet3d import DenseVoxelNet
            model = DenseVoxelNet(in_channels=hp.in_class, classes=hp.out_class+1)

        elif hp.model == 'VNet':
            from models.three_d.vnet3d import VNet
            model = VNet(in_channels=hp.in_class, classes=hp.out_class+1)

        elif hp.model == 'UNETR':
            from models.three_d.unetr import UNETR
            model = UNETR(img_shape=(hp.crop_or_pad_size), input_dim=hp.in_class, output_dim=hp.out_class+1)

        else:
            print('Can not find the model.')
            return

    model = torch.nn.DataParallel(model, device_ids=devicess)

    # Weights are loaded from the "latest" checkpoint in output_dir, on CPU
    # first (map_location), then the model is moved to the GPU.
    print("load model:", args.ckpt)
    print(os.path.join(args.output_dir, args.latest_checkpoint_file))
    ckpt = torch.load(os.path.join(args.output_dir, args.latest_checkpoint_file), map_location=lambda storage, loc: storage)
    model.load_state_dict(ckpt["model"])
    model.cuda()

    test_dataset = MedData_test(source_test_dir, label_test_dir)
    # Same z-normalization applied per subject before patch sampling.
    znorm = ZNormalization()

    # NOTE(review): both branches are identical; kept as-is.
    if hp.mode == '3d':
        patch_overlap = hp.patch_overlap
        patch_size = hp.patch_size
    elif hp.mode == '2d':
        patch_overlap = hp.patch_overlap
        patch_size = hp.patch_size

    for i,subj in enumerate(test_dataset.subjects):
        subj = znorm(subj)
        # Sliding-window sampler over the whole volume; the aggregator
        # stitches per-patch predictions back together using the locations.
        grid_sampler = torchio.inference.GridSampler(
            subj,
            patch_size,
            patch_overlap,
        )

        patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=args.batch)
        # aggregator = torchio.inference.GridAggregator(grid_sampler)
        aggregator_1 = torchio.inference.GridAggregator(grid_sampler)
        model.eval()
        with torch.no_grad():
            for patches_batch in tqdm(patch_loader):
                input_tensor = patches_batch['source'][torchio.DATA].to(device)
                locations = patches_batch[torchio.LOCATION]

                if hp.mode == '2d':
                    # Drop the singleton depth axis for 2d networks...
                    input_tensor = input_tensor.squeeze(4)

                outputs = model(input_tensor)

                if hp.mode == '2d':
                    # ...and restore it so the aggregator sees 5-D patches.
                    outputs = outputs.unsqueeze(4)

                # Hard class prediction per voxel; re-add a channel axis
                # before handing the patch to the aggregator.
                labels = outputs.argmax(dim=1)

                aggregator_1.add_batch(labels.unsqueeze(1), locations)
        output_tensor_1 = aggregator_1.get_output_tensor()

        # Reuse the subject's affine so the output aligns with the input scan.
        affine = subj['source']['affine']

        output_image = torchio.ScalarImage(tensor=output_tensor_1.numpy(), affine=affine)
        output_image.save(os.path.join(output_dir_test,f"{i:04d}-result_int"+hp.save_arch))
-
-
if __name__ == '__main__':
    # Dispatch on the configured run mode; any other value is a no-op.
    run_mode = hp.train_or_test
    if run_mode == 'train':
        train()
    elif run_mode == 'test':
        test()
|