|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Base segmentation dataset"""
- import random
- import numpy as np
-
- from PIL import Image, ImageOps, ImageFilter
-
- __all__ = ['SegmentationDataset']
-
-
class SegmentationDataset():
    """Base class for segmentation datasets with paired image/mask transforms.

    The class stores the dataset location and split, and provides the
    synchronized transforms that are applied to an image and its mask
    together.  ``crop_size`` is kept as a ``(height, width)`` tuple: index 0
    is used along the vertical axis and index 1 along the horizontal axis.

    ``num_class`` reads ``self.NUM_CLASS``, which is not defined in this
    class and has to be supplied elsewhere (e.g. by a subclass).
    """

    def __init__(self, root, split, mode, base_size=520, crop_size=480):
        """Store the dataset configuration.

        Args:
            root: Dataset root, stored unchanged.
            split: Name of the dataset split, stored unchanged.
            mode: Transform mode; falls back to ``split`` when ``None``.
            base_size: Reference short-edge length used by the random
                rescale in ``_sync_transform``.
            crop_size: Output size, either a single number (square crop) or
                a ``(height, width)`` list/tuple.
        """
        super(SegmentationDataset, self).__init__()
        self.root = root
        self.split = split
        self.mode = mode if mode is not None else split
        self.base_size = base_size
        self.crop_size = self.to_tuple(crop_size)

    def to_tuple(self, size):
        """Normalize ``size`` to a tuple.

        Args:
            size: A list/tuple (returned as a tuple) or a single int/float
                (duplicated into a pair).

        Returns:
            tuple: The normalized size.

        Raises:
            ValueError: If ``size`` is neither a list/tuple nor a number.
        """
        if isinstance(size, (list, tuple)):
            return tuple(size)
        if isinstance(size, (int, float)):
            return (size, size)
        raise ValueError('Unsupport datatype: {}'.format(type(size)))

    def _crop_padding(self, width, height):
        """Return ``(padw, padh)`` needed for the image to cover ``crop_size``.

        Each value is the number of pixels missing along that axis, or 0 when
        the image is already at least as large as the crop window.
        """
        padw = max(self.crop_size[1] - width, 0)
        padh = max(self.crop_size[0] - height, 0)
        return padw, padh

    def _val_sync_transform(self, img, mask):
        """Deterministic transform: resize the short edge, then center crop.

        Args:
            img: Input image (PIL image).
            mask: Label mask (PIL image) aligned with ``img``.

        Returns:
            tuple: ``(img, mask)`` after ``_img_transform`` and
            ``_mask_transform``.
        """
        outsize = self.crop_size
        short_size = min(outsize)
        w, h = img.size
        # Scale so that the shorter image edge equals short_size while
        # keeping the aspect ratio.
        if w > h:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        else:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        # Nearest-neighbour for the mask so label values are not blended.
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # center crop
        w, h = img.size
        x1 = int(round((w - outsize[1]) / 2.))
        y1 = int(round((h - outsize[0]) / 2.))
        box = (x1, y1, x1 + outsize[1], y1 + outsize[0])
        img = img.crop(box)
        mask = mask.crop(box)

        # final transform
        return self._img_transform(img), self._mask_transform(mask)

    def _sync_transform(self, img, mask):
        """Random augmentation applied identically to image and mask.

        Steps: random horizontal mirror, random rescale of the short edge to
        between 0.5x and 2.0x ``base_size``, bottom/right padding up to
        ``crop_size`` where needed, random crop to ``crop_size`` and a random
        Gaussian blur on the image only.

        Args:
            img: Input image (PIL image).
            mask: Label mask (PIL image) aligned with ``img``.

        Returns:
            tuple: ``(img, mask)`` after ``_img_transform`` and
            ``_mask_transform``.
        """
        # random mirror
        if random.random() < 0.5:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
            mask = mask.transpose(Image.FLIP_LEFT_RIGHT)
        crop_size = self.crop_size
        # random scale (short edge)
        short_size = random.randint(int(self.base_size * 0.5), int(self.base_size * 2.0))
        w, h = img.size
        if h > w:
            ow = short_size
            oh = int(1.0 * h * ow / w)
        else:
            oh = short_size
            ow = int(1.0 * w * oh / h)
        img = img.resize((ow, oh), Image.BILINEAR)
        mask = mask.resize((ow, oh), Image.NEAREST)
        # pad crop: check each axis on its own.  Comparing only the short
        # edge against min(crop_size) misses non-square crop sizes, where one
        # axis can still be smaller than the crop window and the random crop
        # below would then be asked for an empty range.
        padw, padh = self._crop_padding(ow, oh)
        if padw or padh:
            img = ImageOps.expand(img, border=(0, 0, padw, padh), fill=0)
            # NOTE(review): how fill=-1 is stored depends on the mask's image
            # mode and the Pillow version -- confirm it yields the intended
            # ignore label.
            mask = ImageOps.expand(mask, border=(0, 0, padw, padh), fill=-1)
        # random crop crop_size
        w, h = img.size
        x1 = random.randint(0, w - crop_size[1])
        y1 = random.randint(0, h - crop_size[0])
        box = (x1, y1, x1 + crop_size[1], y1 + crop_size[0])
        img = img.crop(box)
        mask = mask.crop(box)
        # gaussian blur as in PSP
        if random.random() < 0.5:
            img = img.filter(ImageFilter.GaussianBlur(
                radius=random.random()))

        # final transform
        return self._img_transform(img), self._mask_transform(mask)

    def _img_transform(self, img):
        """Convert the image to a numpy array."""
        return np.array(img)

    def _mask_transform(self, mask):
        """Convert the mask to an ``int32`` numpy array."""
        return np.array(mask).astype('int32')

    @property
    def num_class(self):
        """Number of categories."""
        return self.NUM_CLASS

    @property
    def pred_offset(self):
        """Offset applied to predictions; always 0 in this base class."""
        return 0
|