- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- Create dataset for training and evaluating
- """
-
- import os
- import mindspore.dataset as ds
- import mindspore.dataset.transforms.c_transforms as C
- import mindspore.common.dtype as mstype
- import numpy as np
- import hashlib
-
-
- def get_input_data(input_ids, eod_id):
- """
- Generate position_id and attention_mask according to input_ids considering eod reset
-
- Inputs:
- input_ids: the input token ids
- eod_id: the id for <EOD>
-
- returns:
- input_ids: the input token ids
- position_id: the position ids considering eod reset
- attention_mask: the attention mask considering eod reset
- """
-
- seq_length = input_ids.shape[0] - 1
- attention_mask = np.tril(np.ones(shape=(seq_length, seq_length)))
- position_id = np.arange(seq_length)
-
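- # Find each <EOD> position, then zero out attention across the document
- # boundary and restart position ids after every <EOD>.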
- eod_index = position_id[input_ids[:-1] == eod_id]
- prev_index = 0
- for i in range(eod_index.size):
- index = eod_index[i]
- attention_mask[(index+1):, :(index+1)] = 0
- position_id[(index+1):] -= (index + 1 - prev_index)
- prev_index = index + 1
- return input_ids, position_id, attention_mask
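-
- # Illustrative example (an added sketch, not part of the original code): with
- # input_ids = np.array([3, 1, 9, 4, 1, 7]) and eod_id = 1 (so seq_length = 5),
- # get_input_data returns
- #   position_id    -> [0, 1, 0, 1, 2]
- #   attention_mask -> the 5x5 lower-triangular mask with rows 2..4 and
- #                     columns 0..1 zeroed, so tokens after the first <EOD>
- #                     cannot attend to tokens before it.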
-
- def get_input_data2(input_ids, eod_id, rank, dis):
- """
- Generate position_id and attention_mask according to input_ids considering eod reset
-
- Inputs:
- input_ids: the input token ids
- eod_id: the id for <EOD>
- rank: current rank id, used to select this rank's slice of the batch
- dis: number of samples handled by this rank (batch_size / device_num)
-
- returns:
- input_ids: the input token ids
- position_id: the position ids considering eod reset
- attention_mask: the attention mask considering eod reset
- """
- rank = int(rank)
- input_ids = input_ids[rank*dis: (rank+1)*dis]
- seq_length = 1024  # hard-coded sequence length, i.e. input_ids.shape[1] - 1
-
- batch_input_ids = input_ids
- batch_position_ids = np.ones((dis, seq_length))
- batch_attention_mask = np.ones((dis, seq_length, seq_length))
- for bs_i in range(0, len(input_ids)):
- local_ids = input_ids[bs_i]
- batch_attention_mask[bs_i] = np.tril(np.ones(shape=(seq_length, seq_length)))
- batch_position_ids[bs_i] = np.arange(seq_length)
- eod_index = batch_position_ids[bs_i, local_ids[:-1] == eod_id].astype(np.int32)
- prev_index = 0
- for i in range(eod_index.size):
- index = eod_index[i]
- batch_attention_mask[bs_i, (index+1):, :(index+1)] = 0
- batch_position_ids[bs_i, (index+1):] -= (index + 1 - prev_index)
- prev_index = index + 1
- return batch_input_ids, batch_position_ids, batch_attention_mask
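-
- # Illustrative example (an added sketch, not part of the original code): with a
- # global batch of 8 rows, device_num = 4 and batch_size = 8, each rank handles
- # dis = 8 / 4 = 2 rows; rank 3 slices input_ids[6:8] and returns position ids
- # of shape (2, 1024) and attention masks of shape (2, 1024, 1024).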
-
- def create_dataset(batch_size, data_path, device_num=1, rank=0, drop=True, data_start_index=0, eod_reset=False, eod_id=128297, hash_check=True):
- """
- Create dataset
-
- Inputs:
- batch_size: batch size
- data_path: path of your MindRecord files
- device_num: total device number
- rank: current rank id
- drop: whether drop remainder
- eod_reset: whether enable position reset and attention mask reset
- eod_id: the id for <EOD>
- data_start_index: index of the first MindRecord file to load
- hash_check: whether to verify the file list (the md5 check itself is commented out)
-
- Returns:
- dataset_restore: the dataset for training or evaluating
- """
- ds.config.set_seed(1)
- home_path = os.path.join(os.getcwd(), data_path)
- files = os.listdir(data_path)
- #if len(files) != 3600:
- # raise ValueError("read local dataset error!")
-
- data = [
- os.path.join(home_path, name) for name in files
- if not name.endswith(".db")
- ]
- # if len(data) != 1 and hash_check:
- # data.sort(key=lambda x: (int(x[x.find("mindrecord")+10:]), int(x.split('/')[-1][:1])))
- # tmp_str = " ".join(data)
- # res = hashlib.md5(tmp_str.encode())
- # print("hash value:", res.hexdigest())
- # expect = "a89dcf9dca4fe1d420588866e9c2ca54"
- # cur_md5 = res.hexdigest()
- # assert cur_md5 == expect, "Expect hash is {} but found :{}".format(expect, cur_md5)
-
- dataset = ds.MindDataset(data[data_start_index:], columns_list=["input_ids"], shuffle=True)
- type_cast_op = C.TypeCast(mstype.int32)
- type_cast_op_float = C.TypeCast(mstype.float32)
- if eod_reset:
- map_func = (lambda input_ids: get_input_data(input_ids, eod_id))
- dataset = dataset.map(operations=map_func, input_columns=["input_ids"], output_columns=["input_ids", "position_id", "attention_mask"], column_order=["input_ids", "position_id", "attention_mask"])
- dataset = dataset.map(input_columns="position_id", operations=type_cast_op)
- dataset = dataset.map(input_columns="attention_mask", operations=type_cast_op_float)
- dataset = dataset.map(input_columns="input_ids", operations=type_cast_op)
- dataset = dataset.batch(batch_size, drop_remainder=drop)
- dataset = dataset.repeat(1)
- return dataset
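-
- # Usage sketch (an added example; the data path below is hypothetical and the
- # defaults follow the signature above):
- #
- #   dataset = create_dataset(batch_size=8, data_path="./mindrecord_data",
- #                            eod_reset=True, eod_id=128297)
- #   for item in dataset.create_dict_iterator():
- #       ids = item["input_ids"]           # (8, seq_length + 1), int32
- #       pos = item["position_id"]         # (8, seq_length), int32
- #       mask = item["attention_mask"]     # (8, seq_length, seq_length), float32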
-
- def create_dataset2(batch_size, data_path, device_num=1, rank=0, drop=True, data_start_index=0, eod_reset=False, eod_id=128297, hash_check=True):
- """
- Create dataset with per-rank sharded loading
-
- Inputs:
- batch_size: batch size
- data_path: path of your MindRecord files
- device_num: total device number
- rank: current rank id
- drop: whether drop remainder
- eod_reset: whether enable position reset and attention mask reset
- eod_id: the id for <EOD>
- data_start_index: index of the first MindRecord file to load
- hash_check: whether to sort the file list for verification (the md5 check itself is commented out)
-
- Returns:
- dataset_restore: the dataset for training or evaluating
- """
- ds.config.set_seed(1)
- home_path = os.path.join(os.getcwd(), data_path)
- files = os.listdir(data_path)
- #if len(files) != 3600:
- # raise ValueError("read local dataset error!")
- # Each rank consumes dis samples out of every global batch.
- dis = batch_size // device_num
- print("per-rank batch size (dis):", dis, flush=True)
- assert dis >= 1, "batch_size must be at least device_num"
-
- data = [os.path.join(home_path, name) for name in files if not name.endswith(".db")]
-
- print("Data path:", data, flush=True)
- if len(data) != 1 and hash_check:
- # data.sort(key=lambda x: (int(x[x.find("mindrecord")+10:]), int(x.split('/')[-1][:1])))
- data.sort(key=lambda x: int(x[x.find("mindrecord")+10:]))
- tmp_str = " ".join(data)
- print("Sorted data:", data, flush=True)
- # res = hashlib.md5(tmp_str.encode())
- # print("hash value:", res.hexdigest())
- # #expect = "a89dcf9dca4fe1d420588866e9c2ca54"
- # expect = '775caaa81cac069ebfcc624085f9d024'
- # cur_md5 = res.hexdigest()
- # assert cur_md5 == expect, "Expect hash is {} but found :{}".format(expect, cur_md5)
-
- if data_start_index >= len(data):
- raise ValueError(f"data start index {data_start_index} is larger than dataset length {len(data)}")
- dataset = ds.MindDataset(data[data_start_index:], columns_list=["input_ids"], shuffle=False)
- type_cast_op = C.TypeCast(mstype.int32)
- type_cast_op_float = C.TypeCast(mstype.float16)
- if eod_reset:
- map_func = (lambda input_ids: get_input_data2(input_ids, eod_id, rank, dis))
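- # Note the order: the global batch is formed first, then the map slices out
- # this rank's dis rows inside get_input_data2.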
- dataset = dataset.batch(batch_size, drop_remainder=drop)
- dataset = dataset.map(operations=map_func, input_columns=["input_ids"], output_columns=["input_ids", "position_id", "attention_mask"], column_order=["input_ids", "position_id", "attention_mask"])
- dataset = dataset.map(input_columns="position_id", operations=type_cast_op)
- dataset = dataset.map(input_columns="attention_mask", operations=type_cast_op_float)
- else:
- raise ValueError("Not supported here")
- dataset = dataset.map(input_columns="input_ids", operations=type_cast_op)
- #dataset = dataset.batch(batch_size, drop_remainder=drop)
- dataset = dataset.repeat(1)
- return dataset
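-
- # Usage sketch for the sharded variant (an added example; the values below are
- # hypothetical). Every device reads the same global batch, and get_input_data2
- # keeps only this rank's rows:
- #
- #   dataset = create_dataset2(batch_size=32, data_path="./mindrecord_data",
- #                             device_num=8, rank=3, eod_reset=True)
- #   # each item then holds this rank's 32 / 8 = 4 samples per step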
-
-
- if __name__ == '__main__':
-
- def compute_data(trans_shape):
- # Build a dummy sequence of the given shape, filled with the eod id (6).
- data = np.ones(trans_shape) * 6
- return data
-
- def generate():
- trans_shape = [1025]
- for i in range(128):
- data = compute_data(trans_shape)
- yield (data, )
- import time
- dataset = ds.GeneratorDataset(generate, ["input_ids"])
- eod_reset = True
- eod_id = 6
- rank = 7
- dis = 1
- batch_size = 32
- drop = True
-
-
- type_cast_op = C.TypeCast(mstype.int32)
- type_cast_op_float = C.TypeCast(mstype.float32)
- if eod_reset:
- map_func = (lambda input_ids: get_input_data(input_ids, eod_id))
- dataset = dataset.map(operations=map_func, input_columns=["input_ids"], output_columns=["input_ids", "position_id", "attention_mask"], column_order=["input_ids", "position_id", "attention_mask"])
- dataset = dataset.map(input_columns="position_id", operations=type_cast_op)
- dataset = dataset.map(input_columns="attention_mask", operations=type_cast_op_float)
- dataset = dataset.map(input_columns="input_ids", operations=type_cast_op)
- dataset = dataset.batch(batch_size, drop_remainder=drop)
- dataset = dataset.repeat(1)
-
-
- start = time.time()
- for k in dataset.create_dict_iterator():
- res = k['input_ids']
- end = time.time()
- speed_time = end - start
- print("batch fetch time: {:.6f}s, input_ids shape: {}".format(speed_time, res.shape))
- start = end