|
- """get the dataset"""
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- from __future__ import print_function, absolute_import
-
- import glob
- import re
- import os.path as osp
-
class ReadDtaset():
    """Read a Market-1501-style person re-identification dataset.

    Supports Market-1501, DukeMTMC-reID and self-made datasets that share the
    same layout: ``<root>/<dataset_name>/`` must contain the
    ``bounding_box_train``, ``query`` and ``bounding_box_test`` directories,
    each holding ``.jpg`` images whose file names contain a
    ``<pid>_c<camid>`` prefix (for example ``0002_c1s1_000451_03.jpg``).
    Images with pid ``-1`` are junk and are skipped.

    After construction the instance exposes:
        train, query, gallery: lists of ``(img_path, pid, camid)`` tuples
            sorted by image path; ``camid`` is 0-based and the training pids
            are relabelled to ``0..num_train_pids - 1`` in ascending pid order.
        num_train_pids, num_query_pids, num_gallery_pids: number of distinct
            non-junk identities in each split.
    """

    # "<pid>_c<camid>": pid may be negative (-1 marks junk); camid is one digit.
    _NAME_PATTERN = re.compile(r'([-\d]+)_c(\d)')

    def __init__(self, root='./data/', dataset_name='market1501', **kwargs):
        """Locate the three splits under ``root/dataset_name`` and load them.

        Args:
            root: directory containing the dataset directory.
            dataset_name: name of the dataset directory inside ``root``.
            **kwargs: accepted for call compatibility; not used.

        Raises:
            RuntimeError: if the dataset or one of its split directories is
                missing, or if an image file name cannot be parsed.
        """
        self.dataset_dir = osp.join(root, dataset_name)
        self.train_dir = osp.join(self.dataset_dir, 'bounding_box_train')
        self.query_dir = osp.join(self.dataset_dir, 'query')
        self.gallery_dir = osp.join(self.dataset_dir, 'bounding_box_test')

        self._check_before_run()
        train, num_train_pids, num_train_imgs = self._process_dir(self.train_dir, relabel=True)
        query, num_query_pids, num_query_imgs = self._process_dir(self.query_dir, relabel=False)
        gallery, num_gallery_pids, num_gallery_imgs = self._process_dir(self.gallery_dir, relabel=False)

        print("=> dataset loaded")
        print("Dataset statistics:")
        print(" ------------------------------")
        print(" subset | # ids | # images")
        print(" ------------------------------")
        print(" train | {:5d} | {:8d}".format(num_train_pids, num_train_imgs))
        print(" query | {:5d} | {:8d}".format(num_query_pids, num_query_imgs))
        print(" gallery | {:5d} | {:8d}".format(num_gallery_pids, num_gallery_imgs))
        print(" ------------------------------")

        self.train = train
        self.query = query
        self.gallery = gallery

        self.num_train_pids = num_train_pids
        self.num_query_pids = num_query_pids
        self.num_gallery_pids = num_gallery_pids

    def _check_before_run(self):
        """Raise RuntimeError if the dataset directory or a split directory is missing."""
        for required_dir in (self.dataset_dir, self.train_dir, self.query_dir, self.gallery_dir):
            if not osp.exists(required_dir):
                raise RuntimeError("'{}' is not available".format(required_dir))

    def _process_dir(self, dir_path, relabel=False):
        """Collect the labelled images of one split.

        Args:
            dir_path: directory holding the split's ``.jpg`` images.
            relabel: if True, map the original pids to contiguous labels
                ``0..N-1`` in ascending pid order (needed for training).

        Returns:
            ``(dataset, num_pids, num_imgs)`` where ``dataset`` is a list of
            ``(img_path, pid, camid)`` tuples sorted by image path, with a
            0-based ``camid`` and junk images (pid ``-1``) removed.

        Raises:
            RuntimeError: if a ``.jpg`` file name does not contain a
                ``<pid>_c<camid>`` prefix.
        """
        # glob returns files in arbitrary, filesystem-dependent order; sort so
        # the sample order is reproducible across machines.
        img_paths = sorted(glob.glob(osp.join(dir_path, '*.jpg')))

        records = []
        for img_path in img_paths:
            # Parse the file name only: directory components such as
            # "set12_c3" must not be mistaken for the pid/camera prefix.
            match = self._NAME_PATTERN.search(osp.basename(img_path))
            if match is None:
                raise RuntimeError(
                    "'{}' does not follow the '<pid>_c<camid>' naming scheme".format(img_path))
            pid, camid = map(int, match.groups())
            if pid == -1:
                continue  # junk images are just ignored
            records.append((img_path, pid, camid - 1))  # camera index starts from 0

        pid_container = {pid for _, pid, _ in records}
        if relabel:
            # Sorting makes the pid -> label mapping independent of set iteration order.
            pid2label = {pid: label for label, pid in enumerate(sorted(pid_container))}
            records = [(path, pid2label[pid], camid) for path, pid, camid in records]

        return records, len(pid_container), len(records)
-
def init_img_dataset(root, name):
    """Build and return a ReadDtaset for the dataset directory ``root/name``."""
    loaded = ReadDtaset(root=root, dataset_name=name)
    return loaded
|