|
- import matplotlib.pyplot as plt
- import glob
- import os
- import torch
- import sys
- import numpy as np
- from numpy import *
- import SimpleITK as sitk
- import scipy.ndimage as ndimage
-
- from monai.transforms import (
- AsDiscrete,
- EnsureChannelFirstd,
- Compose,
- LoadImaged,
- Orientationd,
- ScaleIntensityRanged,
- ScaleIntensityRangePercentilesd,
- Resized,
- RandFlipd,
- RandScaleIntensityd,
- RandShiftIntensityd,
- Spacingd,
- ResizeWithPadOrCropd,
- CropForegroundd,
- EnsureTyped,
- EnsureType
- )
- from monai.data import CacheDataset, DataLoader, decollate_batch
- from monai.networks.layers import Norm
- from monai.inferers import sliding_window_inference
- sys.path.insert(0, '/tmp/code/e2emria/networks')
- from vitseg import ViTSeg
-
-
def load_model(args, device, model_path):
    """Build a ViTSeg model, wrap it in DataParallel, and load a checkpoint.

    Args:
        args: namespace providing ``num_classes`` and ``image_size`` (plus any
            extra fields ViTSeg itself reads).
        device: torch device (or device string) the model is moved to.
        model_path: path to a checkpoint holding ``model_state_dict``.

    Returns:
        The model in eval mode, wrapped in ``torch.nn.DataParallel``.
    """
    model = ViTSeg(
        in_channels = 1,
        out_channels = args.num_classes,
        img_size = args.image_size,
        feature_size = 16,  # NOTE(review): hard-coded; confirm it matches the checkpoint
        hidden_size = 768,
        mlp_dim = 3072,
        num_heads = 12,
        norm_name = "batch",
        res_block = True,
        dropout_rate = 0.0,
        args = args
    )
    model = torch.nn.DataParallel(model).to(device)
    # map_location keeps CPU-only hosts working when the checkpoint was
    # saved on a CUDA device (the original call crashed there).
    checkpoint = torch.load(model_path, map_location=device)
    # strict=False tolerates missing/unexpected keys; NOTE(review): this also
    # silently ignores real mismatches -- consider logging the returned keys.
    model.load_state_dict(checkpoint['model_state_dict'], strict=False)
    model.eval()
    return model
-
-
class MMWHS2CHD:
    """Remap raw MM-WHS label values onto CHD-seg label indices.

    mm-whs raw: 205., [420., 421.], 500., 550., 600., 820., 850., TO
    correspond: 1, 2, 3, 4, 5, 6, 7, 0
    chd-seg: 5, 3, 1, 4, 2, 6, 7, 0
    """

    # (raw MM-WHS value, CHD-seg value) pairs; every other value maps to 0.
    _LABEL_PAIRS = (
        (205, 5),
        (420, 3),
        (421, 3),
        (500, 1),
        (550, 4),
        (600, 2),
        (820, 6),
        (850, 7),
    )

    def operation(self, data):
        """Return a new array with each raw value replaced by its CHD index."""
        remapped = np.zeros(data.shape)
        for raw_value, chd_value in self._LABEL_PAIRS:
            remapped = np.where(data == raw_value, chd_value, remapped)
        return remapped

    def __call__(self, data):
        """Dict-transform entry point: remaps ``data['label']`` in place."""
        data['label'] = self.operation(data['label'])
        return data
-
-
# Deterministic preprocessing for evaluation: load, standardize channel layout,
# spacing and orientation, intensity-normalize, then remap MM-WHS labels to
# CHD-seg indices.
test_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Resample to isotropic 1 mm spacing; nearest-neighbour for the label map.
        Spacingd(keys=["image", "label"], pixdim=(1, 1, 1), mode=("bilinear", "nearest")),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        # Clip to the [0th, 98th] intensity percentiles and rescale to [0, 1].
        ScaleIntensityRangePercentilesd(
            keys=["image"], lower=0, upper=98,
            b_min=0.0, b_max=1.0, clip=True, relative=False
        ),
        MMWHS2CHD(),
        EnsureTyped(keys=["image", "label"]),
    ])
-
-
def inference(args, test_dir, model, output_dir):
    """Run sliding-window inference over every case in ``test_dir``.

    Expects ``test_dir/images/*`` and ``test_dir/labels/*`` to pair up when
    sorted. Prints per-class and per-case Dice, saves each prediction via
    ``save_result``, and finally prints the mean Dice over all cases.

    Args:
        args: namespace with ``num_classes`` and ``image_size`` (ROI size).
        test_dir: dataset root containing ``images/`` and ``labels/``.
        model: segmentation network already in eval mode.
        output_dir: directory predictions are written to.
    """
    test_images = sorted(glob.glob(f'{test_dir}/images/*'))
    test_labels = sorted(glob.glob(f'{test_dir}/labels/*'))
    files = [{"image": image_name, "label": label_name}
             for image_name, label_name in zip(test_images, test_labels)]
    print('Number of cases:', len(files))

    # cache_rate=0.0: nothing is cached, so this behaves like a plain Dataset.
    test_ds = CacheDataset(data=files, transform=test_transforms, cache_rate=0.0, num_workers=8,)
    test_loader = DataLoader(test_ds, batch_size=1, num_workers=8,)

    sw_batch_size = 4
    # Derive the device from the model instead of relying on a global that is
    # only defined when the file runs as a script.
    device = next(model.parameters()).device

    with torch.no_grad():
        dice_list_case = []
        for n, test_data in enumerate(test_loader):
            inputs = test_data['image'].to(device)
            labels = test_data['label'].to(device)
            outputs = sliding_window_inference(inputs, args.image_size, sw_batch_size, model, overlap=0.25)

            pred = torch.softmax(outputs, 1).cpu().numpy()
            pred = np.argmax(pred, axis=1).astype(np.uint8)[0]
            gt = labels.cpu().numpy()[0, 0, :, :, :]

            # Per-class Dice (class 0 = background is skipped).
            dice_list_sub = []
            for i in range(1, args.num_classes):
                sub_dice = dice(pred == i, gt == i)
                dice_list_sub.append(sub_dice)
                print("Dice of class {}: {}".format(i, sub_dice))
            mean_dice = np.mean(dice_list_sub)
            print("Mean Dice of this case: {}".format(mean_dice))
            dice_list_case.append(mean_dice)

            save_result(n, files, inputs, pred, output_dir)
            # BUG FIX: the original loop had a `break` here, so only the first
            # case was ever evaluated and the "overall" mean covered one volume.
    print("Overall Mean Dice: {}".format(np.mean(dice_list_case)))
    return
-
def dice(x, y):
    """Dice overlap between two binary masks.

    Args:
        x: predicted binary mask (bool or 0/1 array).
        y: reference binary mask of the same shape.

    Returns:
        ``2*|x∩y| / (|x| + |y|)`` as a float, or 0.0 when the reference mask
        is empty (kept from the original convention, even if x is also empty).
    """
    # np.sum already reduces over all axes; the original triple-nested
    # np.sum(np.sum(np.sum(...))) was redundant.
    y_sum = np.sum(y)
    if y_sum == 0:
        return 0.0
    intersect = np.sum(x * y)
    x_sum = np.sum(x)
    return 2 * intersect / (x_sum + y_sum)
-
def save_result(n, files, test_inputs, val_outputs, output_dir):
    """Write the predicted label map and its input image to ``output_dir``.

    Args:
        n: index of the current case in ``files`` (the loader is not
            shuffled, so loader order matches the sorted file list).
        files: list of {"image": path, "label": path} dicts.
        test_inputs: input batch tensor of shape (1, 1, D, H, W).
        val_outputs: predicted label volume, a (D, H, W) integer array.
        output_dir: destination directory (assumed to exist).
    """
    # os.path.basename is portable, unlike splitting on '/'.
    raw_name = os.path.basename(files[n]['image'])
    # NOTE(review): str.replace substitutes EVERY occurrence of 'image' in the
    # filename -- verify the naming scheme only contains it once.
    save_name = raw_name.replace('image', 'predlabel')

    print(val_outputs.shape)
    img_arr = test_inputs.cpu().numpy()[0, 0, :, :, :]
    print(img_arr.shape)
    out = sitk.GetImageFromArray(val_outputs.astype(np.uint8))
    img = sitk.GetImageFromArray(img_arr)
    # NOTE(review): `out` was just created from a bare array, so these copy
    # default (identity) geometry; the source image's affine/spacing is lost.
    img.SetDirection(out.GetDirection())
    img.SetOrigin(out.GetOrigin())
    img.SetSpacing(out.GetSpacing())
    save_pred = f'{output_dir}/{save_name}'
    save_img = f'{output_dir}/{raw_name}'
    sitk.WriteImage(out, save_pred)
    sitk.WriteImage(img, save_img)
    return
-
-
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--num_classes", default=8, type=int,
                        help="number of segmentation classes, including background")
    # BUG FIX: --model_path/--in_file/--out_file were parsed but then ignored
    # in favor of hard-coded paths. The flags are now honored; their defaults
    # equal the old hard-coded values, so a no-argument run is unchanged.
    parser.add_argument("--model_path", default='/pretrainmodel/mri_3dseg_model.pth', type=str,
                        help="path to the trained segmentation checkpoint")
    parser.add_argument("--in_file", default='/tmp/dataset/mr_train',
                        help="dataset root containing images/ and labels/ subfolders")
    parser.add_argument("--out_file", default='/tmp/output',
                        help="directory where predictions are written")
    parser.add_argument("--arch", default="vit_base", type=str)
    parser.add_argument("--finetune", action="store_true",
                        help="finetune a pretrained model, else train from scratch")
    args = parser.parse_args()

    # Fixed sliding-window ROI size the model was configured for.
    args.image_size = (128, 128, 128)

    # load the network, assigning it to the selected device
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    model = load_model(args, device, args.model_path)

    # inference
    inference(args, args.in_file, model, args.out_file)
|