|
- import torchio as tio
- from torchio.transforms import (
- RandomFlip,
- RandomAffine,
- RandomElasticDeformation,
- RandomNoise,
- RandomMotion,
- RandomBiasField,
- RescaleIntensity,
- Resample,
- ToCanonical,
- ZNormalization,
- CropOrPad,
- HistogramStandardization,
- OneOf,
- Compose,
- )
- from pathlib import Path
- import torch
- import matplotlib.pyplot as plt
-
-
- source_train_dir = 'D:\\PI-CAI\\test_data\\test_images'
- label_train_dir = 'D:\\PI-CAI\\test_data\\test_labels'
- fold_arch = '*.nii.gz'
-
- max_queue_length = 16
- patches_per_volume = 10
-
- subjects = []
-
- images_dir = Path(source_train_dir)
- image_paths = sorted(images_dir.glob(fold_arch))
- labels_dir = Path(label_train_dir)
- label_paths = sorted(labels_dir.glob(fold_arch))
-
- for (image_path, label_path) in zip(image_paths, label_paths):
- subject = tio.Subject(
- source=tio.ScalarImage(image_path),
- label=tio.LabelMap(label_path),
- )
- subjects.append(subject)
-
-
- training_transform = Compose([
- ToCanonical(), # 对原图重排列使之转变为RAS方向轴排布(左->右,后->前,下->上)
- Resample('source'), # 重采样,改变图像像素的物理尺度
- CropOrPad((512,512,1), padding_mode='reflect'), # 给定目标维度,如果原图大则裁剪,若小则填充
- # # RandomMotion(), # 添加动态模糊,mri图像的噪声来源一部分是由于被测试者在采集时的动作导致的,通过当前增强来模拟该场景
- # RandomBiasField(), # 添加随机偏置场伪影,通常由于mri成像设备的磁场不均匀导致的低频强度变化(可以理解为一侧亮度偏低)
- # ZNormalization(), # 基于单张图像的归一化操作,即计算标准差和均值,将像素值强度的分布转化成高斯分布
- # RandomNoise(), # 添加随机高斯噪声
- # RandomFlip(axes=(0,)), # 给定中心轴(可以多个)翻转图像,可以使用数字(0,1,2)指定反转轴,也可以用字母(Left,Right,Height,Width)。允许设置翻转的概率,需要对每个反转轴单独指定。
- # OneOf({
- # RandomAffine(): 0.8, # 随机仿射变换,包括尺度(scale,需要指定缩放的比例,可以设置缩放时的差值策略),旋转(degrees,需要指定每个轴旋转的角度范围,可以设置旋转时pad的数值),平移(translation),还支持各向同性和设置以中心为基准进行变换。
- # RandomElasticDeformation(): 0.2, # 随机弹性形变,通过三次B样条插值实现位移,需要指定控制点的数量和最大位移距离,允许控制边缘是否形变
- # }) # OneOf:从给定增强变换序列中随机选择一项执行,允许设置每一项执行的概率
- ])
-
- dataset = tio.SubjectsDataset(subjects, transform=training_transform)
- patch_size = (128, 128, 1) # 2D slices
-
- def plot_batch(sampler):
- queue = tio.Queue(dataset, max_queue_length, patches_per_volume, sampler)
- loader = torch.utils.data.DataLoader(queue, batch_size=16)
- batch = tio.utils.get_first_item(loader)
-
- fig, axes = plt.subplots(4, 4, figsize=(12, 10))
- for ax, im in zip(axes.flatten(), batch['source']['data']):
- print(im.shape)
- ax.imshow(im.squeeze(), cmap='gray')
- plt.suptitle(sampler.__class__.__name__)
- plt.tight_layout()
-
- probabilities = {0: 0.5, 1: 0.5}
- sampler = tio.data.LabelSampler(
- patch_size=patch_size,
- label_name='label',
- label_probabilities=probabilities,
- )
- plot_batch(sampler)
- print('over')
|