- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """
- GPT (PanGu-Alpha) teacher-student distillation training script
- """
-
- import os
- import argparse
- import json
- import math
- import time
- import logging
- #import numpy as np
- import moxing as mox
- from pathlib2 import Path
- from mindspore import context
- from mindspore.train.model import Model
- import mindspore.communication.management as D
- from mindspore.context import ParallelMode
- import mindspore.nn as nn
- from mindspore.train.callback import TimeMonitor, ModelCheckpoint, CheckpointConfig, Callback
- from mindspore.train.serialization import load_checkpoint, load_param_into_net, load_distributed_checkpoint
- from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
- import mindspore.common.dtype as mstype
- from mindspore.parallel._cost_model_context import _set_multi_subgraphs
- from mindspore.parallel import set_algo_parameters
- import mindspore
- from mindspore.common import set_seed
- import mindspore.dataset as de
- from dataset_restore_data0 import create_dataset2 as create_dataset
- from gpt_dropout_recompute_eos_tiny_c79 import GPT, GPTWithLoss, CrossEntropyLoss, GPTWithLoss_gd, EvalNet_p
- # from gpt_ms1_3_tiny import PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss, GPTWithLoss_gd
- from gpt_wrapcell_gradient_scale_eos_tiny import GPTTrainOneStepWithLossScaleCell, VirtualDatasetOneInputCell, AdamWeightDecay
- from utils_fix import GPTConfig, LearningRate, StrategySaveCallback
- from obs import ObsUploader, ObsRestorer, SOMARestorer
- from mindspore.parallel._auto_parallel_context import auto_parallel_context
- from mindspore import Tensor
- import numpy as np
- from utils_fix import LossSummaryCallback, Ckpt2ObsSummaryCallback
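- # Raise the HCCL connect timeout so large multi-node jobs do not time out while ranks
- # are still initializing (it is raised again to 6000 s later, before checkpoint download).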
- os.environ['HCCL_CONNECT_TIMEOUT'] = '1800'
-
- from download_dataset import DatasetDownloader#, BUCKET_DIR, LOCAL_PATH
-
- BUCKET_DIR = 'obs://datasets/V1-sample300-bpe-1024/'
- LOCAL_PATH = "/cache/V1-sample300-bpe-1024/"
-
-
- def ckpt_copy_tar(obs_path, target_path="/cache/ckpt"):
- """
- requires the obs_path to be a complete name
- Copy tar file from the obs to the /cache/
- """
- sub_name_list = ['_0.tar', '_1.tar', '_2.tar', '_3.tar']
- for item in sub_name_list:
- sub_name = obs_path + item
- tmp_name = 'model.tar'
- mox.file.copy(sub_name, os.path.join(target_path, tmp_name))
- os.system('cd {}; tar -xvf {}'.format(target_path, tmp_name))
-
- def get_ckpt_file_list(ckpt_path):
- returned_list = []
- for i in range(0, 16):#512):
- returned_list.append('filerted_{}.ckpt'.format(i))
- returned_list = [os.path.join(ckpt_path, item) for item in returned_list if 'embedding' not in item]
- print("Sorted list", returned_list)
- for item in returned_list:
- fsize = os.path.getsize(item)
- f_gb = fsize / float(1024) / 1024 / 1024
- print(item, " :{:.2f} GB".format(f_gb))
- return returned_list
-
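- # Callback that uploads the local SOMA dump directory to OBS shortly after training starts
- # (on steps where cur_step_num % 100 == 2). Ranks stagger their uploads via interval_num /
- # interval_time, retry a bounded number of times, and append failures to an OBS upload log.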
- class SOMACallback(Callback):
- def __init__(self, soma_dir, obs_dir="s3://mindspore-file/soma/gpt_1024", upload=True, retry=3,
- retry_time=30, interval_num=512, interval_time=90):
- super(SOMACallback, self).__init__()
- self.soma_dir = soma_dir
- self.obs_dir = obs_dir
- self.upload = upload
- self.retry = retry
- self.retry_time = retry_time
- self.interval_num = interval_num
- self.interval_time = interval_time
-
- def step_end(self, run_context):
- if not self.upload:
- return
- cb_params = run_context.original_args()
- if (cb_params.cur_step_num) % 100 == 2:
- rank_id = os.getenv("RANK_ID")
- sleep_time = int(rank_id) // self.interval_num * self.interval_time
- if sleep_time > 0:
- logging.info(f"rank_{rank_id} waits {sleep_time}s before uploading soma.")
- time.sleep(sleep_time)
- except_log_dir = os.path.join(self.obs_dir, "upload_log")
- if not mox.file.exists(except_log_dir):
- mox.file.mk_dir(except_log_dir)
- except_file_path = os.path.join(except_log_dir,
- f"except_upload_rank_{rank_id}_soma_{cb_params.cur_epoch_num+1}.log")
- except_info = ""
- soma_obs_dir = os.path.join(self.obs_dir, f"rank_{rank_id}")
- if not mox.file.exists(soma_obs_dir):
- mox.file.mk_dir(soma_obs_dir)
- success = False
- for i in range(self.retry + 1):
- try:
- start = time.time()
- mox.file.copy_parallel(self.soma_dir, soma_obs_dir)
- end = time.time()
- success = True
- logging.info(
- f"rank_{rank_id}: uploading {self.soma_dir} to {soma_obs_dir} cost {end - start}s.")
- break
- except Exception as e:
- if i < self.retry:
- logging.info(e.__str__() + f" rank_{rank_id}: uploading {self.soma_dir} to {soma_obs_dir}"
- f" failed: retry {i + 1}/{self.retry}.")
- time.sleep(self.retry_time)
- else:
- except_info = e.__str__()
-
- if not success:
- mox.file.append(except_file_path, f"{except_info}. rank_{rank_id}: uploading {self.soma_dir} to "
- f"{soma_obs_dir} failed.\n")
- else:
- self.upload = False
-
- class LossCallBack(Callback):
- """
- Monitor the loss in training.
- If the loss in NAN or INF terminating training.
- Note:
- if per_print_times is 0 do not print loss.
- Args:
- per_print_times (int): Print loss every times. Default: 1.
- """
- def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0):
- super(LossCallBack, self).__init__()
- self._dataset_size = dataset_size
- self.local_rank = local_rank
- self.has_trained_epoch = has_trained_epoch
- self.has_trained_step = has_trained_step
- print("load has trained epoch :{} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)
- def step_end(self, run_context):
- """
- Print loss after each step
- """
- cb_params = run_context.original_args()
- # de.config.set_sending_batches(cb_params.cur_step_num + 4) # send 4
- if self._dataset_size > 0 and self.local_rank % 8 == 0:
- percent, epoch_num = math.modf(cb_params.cur_step_num /
- self._dataset_size)
- if percent == 0:
- percent = 1
- epoch_num -= 1
- date = time.asctime(time.localtime(time.time()))
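- # net_outputs layout from the train cell: (loss, overflow flag, loss scale, global norm, learning rate).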
- print(
- "time: {} local_rank: {}, epoch: {}, step: {}, output is {}, overflow is {}, scale is {}, norm is {}, lr is {}"
- .format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch),
- cb_params.cur_step_num + int(self.has_trained_step),
- cb_params.net_outputs[0].asnumpy(),
- cb_params.net_outputs[1].asnumpy(),
- cb_params.net_outputs[2].asnumpy(),
- cb_params.net_outputs[3].asnumpy(),
- cb_params.net_outputs[4].asnumpy()))
-
-
-
- project_root = os.path.abspath(
- os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
- print('project_root:', project_root)
-
-
- def run_train():
- """train function for GPT"""
- parser = argparse.ArgumentParser(description="GPT training")
- parser.add_argument('--device_id',
- type=int,
- default=0,
- help="Device id, default is 0.")
- parser.add_argument("--device_num",
- type=int,
- default=128,
- help="Use device nums, default is 1.")
- parser.add_argument("--distribute",
- type=str,
- default="true",
- choices=["true", "false"],
- help="Run distribute, default is false.")
- parser.add_argument("--optimizer",
- type=str,
- default="adam",
- choices=["adam", "lamb"],
- help="select which optimizer to be used, default adam")
- parser.add_argument("--epoch_size",
- type=int,
- default=4,
- help="Epoch size, default is 10.")
- parser.add_argument("--warmup_step",
- type=int,
- default=14,
- help="Warmup step, default is 10000.")
- parser.add_argument("--start_lr",
- type=float,
- default="3e-6",# Fix me 2.5e-5
- help="Start learning rate, default is 5e-5.")
- parser.add_argument("--end_lr",
- type=float,
- default="1e-6",
- help="End learning rate, default is 1e-10.")
- parser.add_argument("--sink_size",
- type=int,
- default=2,
- help="Sink size for every iteration, default is 100")
- parser.add_argument('--data_url',
- required=True,
- default=None,
- help='Location of data.')
- parser.add_argument('--train_url',
- required=True,
- default=None,
- help='Location of training outputs.')
- parser.add_argument('--whl_pkg',
- type=str,
- default='',
- help='Location of mindspore whl.')
- parser.add_argument('--bucket_dir',
- type=str,
- default='s3://pcl-verify/yizx/distilPangu/ckpt/pangu26b_teach_368M_pretrain_first300G_oldCode_bs16',
- help='Obs ckpt dir')
- parser.add_argument('--soma_bucket_dir',
- type=str,
- default='s3://pcl-verify/yizx/distilPangu/soma/pangu26b_teach_368M_pretrain_first300G_oldCode_bs16',
- help='Obs soma dir')
- parser.add_argument("--sample_count",
- type=int,
- default=130000,
- help="sample_count, default is 1000000.")
- parser.add_argument("--eod_id",
- type=int,
- default=9,
- help="eod_id.")
- parser.add_argument("--eod_reset",
- type=int,
- default=1,
- help="eod_reset 0/1.")
- parser.add_argument("--full_batch",
- type=int,
- default=1,
- help="full_batch 0/1.")
- parser.add_argument("--save_step",
- type=int,
- default=1000,
- help="a large step")
- parser.add_argument("--reset_data_index",
- type=int,
- default=0,
- help="restart dataset training from index zero (0/1), default is 0.")
-
- # TEMP = ['exp61_GPT3_1-3000_2', 'exp61_GPT3_4-5500_2', 'exp61_GPT3_4-15500_2', 'exp61_GPT3_4-25500_2', 'Newexp65_GPT3-16000_2', 'Newexp65_GPT3_2-3494_2']
- # INDEX = 5
-
- # parser.add_argument("--ckpt_path", type=str, default='s3://mindspore-file/huangxinjing/filtered_ckpt/{}/{}part'.format(TEMP[INDEX], TEMP[INDEX]),
- # help="path for saved checkpoint ")
- # parser.add_argument("--word_embedding_path", type=str,
- # default='s3://mindspore-file/huangxinjing/filtered_ckpt/{}/{}_word_embedding.npy'.format(TEMP[INDEX], TEMP[INDEX]),
- # help="path for word embedding")
- # parser.add_argument("--position_embedding_path", type=str,
- # default='s3://mindspore-file/huangxinjing/filtered_ckpt/{}/{}_position_embedding.npy'.format(TEMP[INDEX], TEMP[INDEX]),
- # help="path for position embedding")
- # parser.add_argument("--top_query_embedding_path", type=str,
- # default='s3://mindspore-file/huangxinjing/filtered_ckpt/{}/{}_top_query_embedding.npy'.format(TEMP[INDEX], TEMP[INDEX]),
- # help="path for top_query embedding")
-
- TEMP = ['2_6b_finetuneOnPD_pangu-yizx-5400_2']
- INDEX = 0
-
- parser.add_argument("--ckpt_path", type=str, default='s3://pcl-verify/yizx/distilPangu/merged_ckpt/{}/{}part'.format(TEMP[INDEX], TEMP[INDEX]),
- help="path for saved checkpoint ")
- parser.add_argument("--word_embedding_path", type=str,
- default='s3://pcl-verify/yizx/distilPangu/merged_ckpt/{}/{}_word_embedding.npy'.format(TEMP[INDEX], TEMP[INDEX]),
- help="path for word embedding")
- parser.add_argument("--position_embedding_path", type=str,
- default='s3://pcl-verify/yizx/distilPangu/merged_ckpt/{}/{}_position_embedding.npy'.format(TEMP[INDEX], TEMP[INDEX]),
- help="path for position embedding")
- parser.add_argument("--top_query_embedding_path", type=str,
- default='s3://pcl-verify/yizx/distilPangu/merged_ckpt/{}/{}_top_query_embedding.npy'.format(TEMP[INDEX], TEMP[INDEX]),
- help="path for top_query embedding")
-
- args_opt = parser.parse_args()
- EXEC_PATH = os.path.join(project_root, 'tiny_pangu')
- device_id = int(os.getenv("DEVICE_ID"))
- rank_id_str = os.getenv('RANK_ID', '0')
- rank_id = int(
- rank_id_str[rank_id_str.rfind('-') +
- 1:]) # 'RANK_ID': 'job24535502-job-facereidtome-hn-0/1'
- print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
- local_rank = rank_id
- print('local_rank:{}, device id:{}'.format(local_rank, device_id))
- reset_data_index = bool(args_opt.reset_data_index)
-
- pretrained_strategy_ckpt_path = "/cache/strategy/ckpt_strategy_{}.ckpt".format(local_rank)
- mox.file.copy(src_url="obs://pcl-verify/yizx/distilPangu/strategy_ckpt/pangu26b_finetuneOnPD_PET_saveCKPT_3e6LR_bs30_bt98cktp_strategy.ckpt", dst_url=pretrained_strategy_ckpt_path)
- # mox.file.copy(src_url="obs://mindspore-file/strategy_ckpt/gpt_1024_13b_exp65cktp_strategy.ckpt", dst_url=pretrained_strategy_ckpt_path)
-
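- # Only one process per node (local_rank % 8 == 0) patches the TBE compiler timeouts and
- # pre-downloads the dataset from OBS; the remaining ranks block on the install.txt marker below.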
- if local_rank % 8 == 0:
- print('MindSpore path:', mindspore)
- print("Modify the time out from 300 to 30000")
- tbe_path = "/usr/local/ma/python3.7/lib/python3.7/site-packages/mindspore" \
- "/_extends/parallel_compile/tbe_compiler/tbe_process.py"
- os.system(
- "sed -i 's/300/30000/g' " + tbe_path
- )
- os.system(
- "sed -i 's/330/33000/g' " + tbe_path
- )
- print("begin download dataset", flush=True)
-
- cache_url = LOCAL_PATH
- if not os.path.exists(LOCAL_PATH):
- Path(LOCAL_PATH).mkdir(parents=True, exist_ok=True)
-
-
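- # Copy the dataset from OBS into the local cache only when the cache directory is still empty.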
- files = os.listdir(LOCAL_PATH)
- data = [
- os.path.join(LOCAL_PATH, name) for name in files
- if not name.endswith(".db")
- ]
- if len(data) == 0:
- print("Start to copy the dataset", flush=True)
- Path(cache_url).mkdir(parents=True, exist_ok=True)
- mox.file.copy_parallel(src_url=BUCKET_DIR, dst_url=LOCAL_PATH)
- print("@@@@@@ Dataset download succeed! @@@@@@@", flush=True)
-
- os.environ['HCCL_CONNECT_TIMEOUT'] = "6000"
- os.system('ulimit -s 102400')
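- # Fetch the sharded teacher checkpoint and the merged embedding tables from OBS into /cache.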
- print(args_opt.ckpt_path)
- ckpt_copy_tar(args_opt.ckpt_path, target_path="/cache/ckpt_files")
- mox.file.copy('obs://pcl-verify/yizx/distilPangu/merged_ckpt/Newexp65_GPT3_2-3494_2.ckpt', '/cache/Newexp65_GPT3_2-3494_2.ckpt')
- mox.file.copy(args_opt.word_embedding_path, '/cache/word_embedding.npy')
- mox.file.copy(args_opt.position_embedding_path, '/cache/position_embedding.npy')
- mox.file.copy(args_opt.top_query_embedding_path, '/cache/top_query_embedding.npy')
-
- print("setting env success.")
- #extra_proxy = '-i http://100.125.33.126:8888/repository/pypi/simple --trusted-host=100.125.33.126'
- #if args_opt.whl_pkg != '':
- # os.system('yes | pip uninstall mindspore-ascend')
- # me_whl = args_opt.whl_pkg
- # me_whl_cache = me_whl.replace('s3://', '/cache/')
- # mox.file.copy_parallel(src_url=me_whl, dst_url=me_whl_cache)
- # os.system('pip install {} {}'.format(me_whl_cache, extra_proxy))
- # print("install mindspore success.")
- # After the package refresh / dataset download finishes, write a marker file to signal success.
- f = open("%s/install.txt" % (EXEC_PATH), 'w')
- f.close()
- # Block here until the package refresh and dataset download have completed.
-
- while not os.path.exists("%s/install.txt" % (EXEC_PATH)):
- time.sleep(1)
-
- print('local_rank:{}, device id:{} start to run...'.format(
- local_rank, device_id),
- flush=True)
- save_graphs_path = "/cache/" + str(local_rank)
- context.set_context(save_graphs=True,
- save_graphs_path=save_graphs_path,
- mode=context.GRAPH_MODE,
- device_target="Ascend",
- device_id=device_id)
- context.set_context(variable_memory_max_size="31GB")
- full_batch = bool(args_opt.full_batch)
- strategy_file = '/tmp/cktp_strategy.ckpt'
- if args_opt.distribute == "true":
- D.init()
- device_num = D.get_group_size()
- rank = D.get_rank()
- print("device_id is {}, rank_id is {}, device_num is {}".format(
- device_id, rank, device_num))
-
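- # Semi-auto parallel setup: the operator-sharding strategy generated for the pretrained teacher
- # is loaded from pretrained_strategy_ckpt_path, and the strategy produced by this run is saved
- # to strategy_file (later uploaded by StrategySaveCallback).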
- context.reset_auto_parallel_context()
- context.set_auto_parallel_context(
- parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
- gradients_mean=False,
- device_num=device_num,
- full_batch=full_batch,
- strategy_ckpt_load_file=pretrained_strategy_ckpt_path,
- strategy_ckpt_save_file=strategy_file,
- enable_parallel_optimizer=False)
- # optimizer_weight_shard_size=64, # This will only be used by the 1024 card,
- auto_parallel_context().set_loss_repeated_mean(True)
- set_algo_parameters(elementwise_op_strategy_follow=True)
- _set_multi_subgraphs()
-
- else:
- rank = 0
- device_num = 1
-
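- # model_parallel_num = 1 means pure data parallelism; the global batch size is the
- # per-card batch size multiplied by the data-parallel dimension.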
- model_parallel_num = 1
- data_parallel_num = int(device_num / model_parallel_num)
- per_batch_size = 8
- batch_size = per_batch_size * data_parallel_num
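- # Teacher config: 2560 hidden units / 32 layers (the 2.6B-class PanGu model); student config:
- # 1280 hidden units / 16 layers (roughly the 368M model referenced in the OBS bucket names).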
- teacher_config = GPTConfig(
- data_parallel_num=data_parallel_num,
- model_parallel_num=model_parallel_num,
- batch_size=batch_size,
- seq_length=1024,
- vocab_size=40000,
- embedding_size=2560, #353M 8B 2560
- num_layers=32,# 32
- num_heads=32,# 32
- expand_ratio=4,
- post_layernorm_residual=False,
- dropout_rate=0.1,
- compute_dtype=mstype.float16,
- use_past=False,
- self_layernorm=True,
- use_recompute=True,
- forward_reduce_scatter=True,
- word_emb_dp=True,
- eod_reset=bool(args_opt.eod_reset))
- student_config = GPTConfig(
- data_parallel_num=data_parallel_num,
- model_parallel_num=model_parallel_num,
- batch_size=batch_size,
- seq_length=1024,
- vocab_size=40000,
- embedding_size=1280, #353M 8B 2560
- num_layers=16,# 32
- num_heads=32,# 32
- expand_ratio=4,
- post_layernorm_residual=False,
- dropout_rate=0.1,
- compute_dtype=mstype.float16,
- use_past=False,
- self_layernorm=True,
- use_recompute=True,
- forward_reduce_scatter=True,
- word_emb_dp=True,
- eod_reset=bool(args_opt.eod_reset))
-
- print("===config is: ", teacher_config, flush=True)
- # Define network
- teacher_net = GPT(teacher_config, is_teacher=True)
- student_net = GPT(student_config, is_teacher=False)
- teacher_net.set_train(False)
-
- # eval_teacher_net = EvalNet_p(teacher_net, generate=True)
- # model_tc = Model(eval_teacher_net)
- # fake_input = Tensor(np.ones(shape=(1, teacher_config.seq_length)), mstype.int32)
- # predict_layout = model_tc.infer_predict_layout(fake_input)
-
- ####################################################################
- # ckpt_path = "/cache/ckpt_files"
- # ckpt_file_list = get_ckpt_file_list(ckpt_path)
- # print("##### Start to load PANGU-2.6B distributed checkpoint! #####", flush=True)
- # load_distributed_checkpoint(teacher_net, ckpt_file_list)
- ####################################################################
-
- # save teacher_net ckpt onefile2obs
- # save_checkpoint(teacher_net, "/cache/yizx-{}.ckpt".format('2_6b_finetuneOnPD_pangu-yizx-5400_2'))
- # mox.file.copy("/cache/yizx-{}.ckpt".format('2_6b_finetuneOnPD_pangu-yizx-5400_2'), "obs://pcl-verify/yizx/distilPangu/merged_ckpt/yizx-fintuneOnPD-PET/yizx-{}.ckpt".format('2_6b_finetuneOnPD_pangu-yizx-5400_2'))
-
- # exit()
-
-
- # print("##### Start to load PANGU-teacher-2.6B distributed checkpoint! #####", flush=True)
- # params_dict = load_checkpoint("/cache/Newexp65_GPT3_2-3494_2.ckpt")
- # load_param_into_net(teacher_net, params_dict)
-
-
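- # Wrap teacher and student in the distillation loss cell; the teacher runs in inference mode
- # and its parameters are frozen below, so gradients only update the student.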
- tiny_pangu = GPTWithLoss_gd(teacher_net, teacher_config, student_net, student_config)
- gpt_with_loss = tiny_pangu
-
-
- # for item in gpt_with_loss.trainable_params():
- # print(item)
-
- teacher_params = gpt_with_loss.teacher.trainable_params()
- for param in teacher_params:
- param.requires_grad = False
-
-
-
- #args_opt.warmup_step = int(epoch_num * step_per_epoch * 0.01)
- print("=====args_opt is: ", args_opt, flush=True)
- decay_steps_num = int(80*1024*0.6/batch_size)
- print("=====decay_steps_num: ", decay_steps_num)
- lr = LearningRate(learning_rate=args_opt.start_lr,
- end_learning_rate=args_opt.end_lr,
- warmup_steps=args_opt.warmup_step,
- decay_steps=100, # Fix me 76000
- lr_scale=1)
-
- #decay_steps=epoch_num*step_per_epoch)
-
- decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
- params = gpt_with_loss.trainable_params()
- time.sleep(3)
- # exit()
- decay_params = list(filter(decay_filter, params))
- other_params = list(filter(lambda x: not decay_filter(x), params))
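- # Apply weight decay (1e-2) only to parameters that are neither LayerNorm weights nor biases;
- # 'order_params' keeps the optimizer's parameter order aligned with the network.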
- group_params = [{
- 'params': decay_params,
- 'weight_decay': 1e-2
- }, {
- 'params': other_params,
- 'weight_decay': 0.0
- }, {
- 'order_params': params
- }]
- if args_opt.optimizer == "lamb":
- optimizer = nn.Lamb(group_params, learning_rate=lr)
- else:
- optimizer = AdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.999)
-
- # all cards will save ckpt
- params_dict = None
- save_steps = args_opt.save_step
- ckpt_dir = os.path.join("/cache/ckpt/", f"rank_{str(local_rank)}")
- # create dir for ckpt
- if not os.path.exists(ckpt_dir):
- Path(ckpt_dir).mkdir(parents=True, exist_ok=True)
-
- # # whether restore from obs
- # has_trained_epoch = 0
- # has_trained_step = 0
- # data_start_index = 0
- loss_scale_value = math.pow(2, 32)
- bucket_dir = args_opt.bucket_dir
- if not mox.file.exists(bucket_dir):
- mox.file.make_dirs(bucket_dir)
-
- if not mox.file.exists(bucket_dir + '/' + str(local_rank)):
- mox.file.make_dirs(bucket_dir + '/' + str(local_rank))
-
- data_start_index = 0  # temporary: always restart from the beginning of the dataset
- has_trained_epoch = 0
- has_trained_step = 0
- ds = create_dataset(student_config.batch_size, data_path=LOCAL_PATH, data_start_index=data_start_index, eod_reset=student_config.eod_reset, eod_id=args_opt.eod_id, device_num=device_num, rank=rank, hash_check=True)
-
-
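- # With dataset sink mode, each call into the device processes sink_size steps, so the
- # epoch count passed to model.train() is epoch_size * steps_per_epoch / sink_size.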
- epoch_num = args_opt.epoch_size
- step_per_epoch = ds.get_dataset_size()
- callback_size = args_opt.sink_size
- actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
- callback = [
- TimeMonitor(callback_size),
- LossCallBack(callback_size, local_rank, has_trained_epoch, has_trained_step)
- ]
-
- obs_uploader = ObsUploader(bucket_dir, max_ckpt=4000, interval_num=128, interval_time=60)
- config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_step,
- keep_checkpoint_max=2,
- integrated_save=False)
- # post_callback_func=obs_uploader.upload_ckpt)
- print(f"data_start_index is {data_start_index}", flush=True)
- ckpoint_cb = ModelCheckpoint(prefix="tinyPangu-yizx",
- directory=ckpt_dir,
- config=config_ck)
- # has_trained_epoch=has_trained_epoch,
- # has_trained_step=has_trained_step,
- # data_start_index=data_start_index)
- ckpt2obs_cb = Ckpt2ObsSummaryCallback(local_ckpt_dir=ckpt_dir,
- local_rank=0,
- has_trained_epoch=has_trained_epoch,
- has_trained_step=has_trained_step,
- bucket=args_opt.bucket_dir + '/' + str(local_rank),
- syn_obs_steps=args_opt.save_step)
- callback.append(ckpoint_cb)
- callback.append(ckpt2obs_cb)
-
- if local_rank == 0:
- sub_dir = args_opt.bucket_dir.split('/')[-1]
- callback.append(LossSummaryCallback(summary_dir="summary",
- local_rank=0,
- has_trained_epoch=has_trained_epoch,
- has_trained_step=has_trained_step,
- bucket='obs://pcl-verify/yizx/distilPangu/summary/' + sub_dir,
- syn_times=40))
- callback.append(StrategySaveCallback(strategy_path=strategy_file,
- local_rank=0,
- has_trained_epoch=has_trained_epoch,
- has_trained_step=has_trained_step,
- bucket='obs://pcl-verify/yizx/distilPangu/strategy_ckpt/' + sub_dir))
-
-
-
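- # Dynamic loss scaling: start from 2**32, double the scale after every 1000 overflow-free
- # steps and halve it whenever an overflow is detected.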
- update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value,
- scale_factor=2,
- scale_window=1000)
-
- gpt_with_grads = GPTTrainOneStepWithLossScaleCell(
- gpt_with_loss, optimizer=optimizer, scale_update_cell=update_cell,enable_global_norm=True, config=student_config)
-
-
- model = Model(gpt_with_grads)
- print("\n\n=====dataset size: ", ds.get_dataset_size(), flush=True)
- print("=====actual_epoch_num: ", actual_epoch_num, flush=True)
- print("$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$\n")
- if params_dict:
- model._init(train_dataset=ds, sink_size=callback_size)
- load_param_into_net(gpt_with_loss, params_dict)
- load_param_into_net(optimizer, params_dict)
- else:
- pass
- # de.config.set_sending_batches(4)
-
- model.train(actual_epoch_num,
- ds,
- callbacks=callback,
- sink_size=callback_size,
- dataset_sink_mode=True)
-
-
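- # Example launch (illustrative only; the script name, OBS buckets and device count are placeholders):
- #   python <this_script>.py --data_url=obs://<bucket>/data --train_url=obs://<bucket>/output \
- #       --device_num=128 --distribute=true --epoch_size=4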
- if __name__ == "__main__":
- #set_seed(12315)
- run_train()