|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import argparse
- import cv2
- import mxnet as mx
-
-
# Command-line interface. ``--include`` may hold several comma-separated
# dataset directories (main() splits it on ','); ``--output`` is the root
# directory that receives the decoded images. Arguments are parsed at module
# load time and exposed through the module-level ``args``.
parser = argparse.ArgumentParser(description='do dataset merge')
parser.add_argument(
    '--include',
    type=str,
    default='../ms1m/',
    help='this dir include train.idx ,train.rec',
)
parser.add_argument(
    '--output',
    type=str,
    default='../ms1m_img',
    help='MS1M images transfer path',
)
args = parser.parse_args()
-
def _export_identity(imgrec, identity, id_dir):
    """Write every image referenced by one identity record into ``id_dir``.

    The record at index ``identity`` is unpacked and its ``header.label[0]``
    and ``header.label[1]`` are used as the half-open range of record indices
    holding that identity's images. Images are saved as ``0.jpg``, ``1.jpg``,
    ... in the order they are read.

    Args:
        imgrec: an opened ``mx.recordio.MXIndexedRecordIO`` reader.
        identity: index of the identity record inside ``imgrec``.
        id_dir: existing directory that receives the JPEG files.
    """
    header, _ = mx.recordio.unpack(imgrec.read_idx(identity))
    first_idx = int(header.label[0])
    last_idx = int(header.label[1])
    for imgid, rec_idx in enumerate(range(first_idx, last_idx)):
        _, packed = mx.recordio.unpack(imgrec.read_idx(rec_idx))
        # Reverse the channel axis: to BGR, the order cv2.imwrite expects.
        bgr = mx.image.imdecode(packed).asnumpy()[:, :, ::-1]
        image_path = os.path.join(id_dir, "%d.jpg" % imgid)
        # cv2.imwrite reports failure through its return value, not an
        # exception; surface it instead of silently losing the image.
        if not cv2.imwrite(image_path, bgr):
            print('failed to write', image_path)


def main():
    """Unpack the RecordIO datasets named by ``args.include`` into image folders.

    ``args.include`` is split on ',' and each entry must contain ``train.rec``
    and ``train.idx``. For dataset number ``ds_id``, record 0 is read first:
    its ``header.flag`` must be positive and its ``header.label[0:2]`` give
    the range of identity record indices. Each identity is exported to
    ``<args.output>/<ds_id>_<identity>/<n>.jpg``.

    Existing output directories are reused, so an interrupted run can be
    restarted (files are overwritten).

    Raises:
        ValueError: if record 0 of a dataset has ``header.flag <= 0``.
    """
    include_datasets = args.include.split(',')
    rec_list = []
    try:
        # Open every dataset up front so a missing file is reported before
        # any image is written.
        for ds in include_datasets:
            path_imgrec = os.path.join(ds, 'train.rec')
            path_imgidx = os.path.join(ds, 'train.idx')
            rec_list.append(
                mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r'))

        if not os.path.isdir(args.output):
            os.makedirs(args.output)

        for ds_id, imgrec in enumerate(rec_list):
            header, _ = mx.recordio.unpack(imgrec.read_idx(0))
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not header.flag > 0:
                raise ValueError(
                    'record 0 of dataset %d has non-positive header flag: %r'
                    % (ds_id, header.flag))
            print('header0 label', header.label)
            seq_identity = range(int(header.label[0]), int(header.label[1]))
            for processed, identity in enumerate(seq_identity, 1):
                id_dir = os.path.join(args.output, "%d_%d" % (ds_id, identity))
                # Tolerate a pre-existing directory so re-runs do not crash.
                if not os.path.isdir(id_dir):
                    os.makedirs(id_dir)
                if processed % 10 == 0:
                    print('processing id', processed)
                _export_identity(imgrec, identity, id_dir)
    finally:
        # Release the underlying file handles even when the export fails.
        for imgrec in rec_list:
            imgrec.close()
-
# Script entry point: start the export only when this file is executed directly.
if __name__ == "__main__":
    main()
|