|
- import cv2
- import numpy as np
- from matplotlib import pyplot as plt
- import transforms as T
- from PIL import Image
-
-
class SegmentationPresetTrain:
    """Joint transform pipeline for a bi-temporal image pair and its mask.

    The same pipeline is applied to ``image1``, ``image2`` and ``target`` in a
    single call, so paired inputs receive identical geometric changes
    (assuming the project's ``transforms`` module transforms all three jointly).

    Only ``T.CenterCrop(0.5, base_size)`` is currently enabled. Flip, rotation,
    blur, colour-jitter, sharpness, equalize, AugMix, ToTensor and Normalize
    steps are disabled. Their parameters are still accepted so the constructor
    signature stays stable for callers.
    """

    def __init__(self, base_size, crop_size, hflip_prob=0.5, vflip_prob=0.5,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        """Build the transform pipeline.

        Args:
            base_size: Size forwarded to ``T.CenterCrop``. Previously this
                argument was overwritten with 256 and had no effect.
            crop_size: Currently unused.
            hflip_prob: Currently unused (horizontal flip is disabled).
            vflip_prob: Currently unused (vertical flip is disabled).
            mean: Currently unused (normalisation is disabled).
            std: Currently unused (normalisation is disabled).
        """
        # 0.5 is presumably an application probability, by analogy with the
        # other project transforms; TODO confirm in transforms.py.
        trans = [T.CenterCrop(0.5, base_size)]
        self.transforms = T.Compose(trans)

    def __call__(self, image1, image2, target):
        """Apply the pipeline to both images and the target together.

        Returns whatever ``T.Compose`` returns. Presumably this is the
        transformed ``(image1, image2, target)`` triple; confirm against
        transforms.py.
        """
        return self.transforms(image1, image2, target)
-
-
def get_transform(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225),
                  base_size=256, crop_size=256):
    """Return the training-time joint transform.

    Args:
        mean: Normalisation mean forwarded to ``SegmentationPresetTrain``.
        std: Normalisation std forwarded to ``SegmentationPresetTrain``.
        base_size: Forwarded as ``SegmentationPresetTrain``'s ``base_size``.
            Defaults to 256, the value previously hard-coded here.
        crop_size: Forwarded as ``SegmentationPresetTrain``'s ``crop_size``.
            Defaults to 256, the value previously hard-coded here.

    Returns:
        A ``SegmentationPresetTrain`` instance.
    """
    return SegmentationPresetTrain(base_size, crop_size, mean=mean, std=std)
-
-
def visualize(image, image2, mask, original_image=None, original_image2=None, original_mask=None):
    """Plot a transformed image pair and mask, optionally next to the originals.

    Layouts:
        - No originals given: a 3x1 grid of the transformed ``image``,
          ``image2`` and ``mask``, without titles.
        - At least one original given: a 2x3 grid. Originals are on the top
          row and transformed data on the bottom row, with one column each for
          image1, image2 and mask.

    Any original left as ``None`` in the second layout gets an empty,
    axis-off panel. Previously ``imshow(None)`` raised, and an
    ``original_image2`` supplied on its own was silently ignored.

    Args:
        image: Transformed first image (anything ``plt.imshow`` accepts).
        image2: Transformed second image.
        mask: Transformed mask.
        original_image: Optional first image before transformation.
        original_image2: Optional second image before transformation.
        original_mask: Optional mask before transformation.

    Returns:
        The created matplotlib Figure. Callers that ignore the return value
        are unaffected.
    """
    fontsize = 12
    originals = (original_image, original_image2, original_mask)

    if all(o is None for o in originals):
        f, ax = plt.subplots(3, 1, figsize=(8, 8))
        for axis, data in zip(ax, (image, image2, mask)):
            axis.imshow(data)
        return f

    f, ax = plt.subplots(2, 3, figsize=(10, 10))
    columns = (
        ('image1', original_image, image),
        ('image2', original_image2, image2),
        ('mask', original_mask, mask),
    )
    for col, (label, original, transformed) in enumerate(columns):
        top = ax[0, col]
        if original is None:
            # Nothing to compare against; leave the panel blank instead of
            # letting imshow(None) raise.
            top.axis('off')
        else:
            top.imshow(original)
            top.set_title(f'Original {label}', fontsize=fontsize)

        ax[1, col].imshow(transformed)
        ax[1, col].set_title(f'Transformed {label}', fontsize=fontsize)

    return f
-
-
if __name__ == '__main__':
    # Demo: load one training sample, apply the training transform and show
    # the originals next to the transformed results.
    # Guarded so importing this module does no file I/O and opens no window.
    from pathlib import Path

    # NOTE: machine-specific dataset location; adjust before running.
    data_root = Path(r'D:\Datasets\Data_CD\LEVIR-CD\LEVIR-CD_1024\train')
    sample_name = 'train_2.png'

    pil_image = Image.open(data_root / 'A' / sample_name)
    pil_image2 = Image.open(data_root / 'B' / sample_name)
    with Image.open(data_root / 'label' / sample_name) as label_file:
        # Divide by 255 (presumably mapping 0/255 labels to 0/1). Wrap back
        # into a PIL image so the mask goes through the same joint transforms
        # as the two images.
        pil_mask = Image.fromarray(np.array(label_file) / 255)

    train_transform = get_transform()
    trans_image, trans_image2, trans_mask = train_transform(pil_image, pil_image2, pil_mask)

    visualize(
        np.array(trans_image),
        np.array(trans_image2),
        np.array(trans_mask),
        original_image=np.array(pil_image),
        original_image2=np.array(pil_image2),
        original_mask=np.array(pil_mask),
    )
    plt.show()
|