# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
PanguAlpha train script
"""
import datetime
import glob
import os
import math
import time
import random

import mindspore
import moxing as mox
from pathlib2 import Path
from mindspore import context
from mindspore.train.model import Model
import mindspore.communication.management as D
from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import TimeMonitor, Callback
from mindspore.nn.wrap.loss_scale import DynamicLossScaleUpdateCell
import mindspore.common.dtype as mstype
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.nn.wrap.cell_wrapper import _VirtualDatasetCell  # PipelineCell,
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig

from dataset_restore_data0 import create_dataset2 as create_dataset
from src.pangu_alpha_tiny import (PanguAlpha, PanguAlphaWithLoss, CrossEntropyLoss, EvalNet_p,
                                  generate_samples_cftpd, GPTWithLoss_gd, GPTWithLoss_gd_withCrEN,
                                  PanguAlphaWithLoss_tiny)
from src.pangu_alpha_wrapcell import (PanguAlphaTrainOneStepWithLossScaleCell,
                                      PanguAlphaTrainOneStepWithLossScaleCell_Print2Loss,
                                      VirtualDatasetOneInputCell)
from src.pangu_alpha_config import PANGUALPHAConfig, set_parse
from src.utils import LearningRate, get_args, FP32StateAdamWeightDecay

from mindspore.train.serialization import (load_checkpoint, load_param_into_net,
                                           build_searched_strategy, merge_sliced_parameter)
import numpy as np
from mindspore import Tensor

from utils_fix import (LossSummaryCallback_Print2Loss, StrategySaveCallback,
                       Ckpt2ObsSummaryCallback, LossSummaryCallback)
from download_dataset import DatasetDownloader

BUCKET_DIR = 'obs://datasets/V1-sample300-bpe-1024/'
LOCAL_PATH = "/cache/V1-sample300-bpe-1024/"

# BUCKET_DIR = 'obs://pcl-verify/yizx/distilPangu/datasets/pd_noBlank_0918/'
# LOCAL_PATH = "/cache/pd_noBlank_0918/"

def ckpt_copy_tar(obs_path, target_path="/cache/ckpt"):
    """
    Copy the four tar shards <obs_path>_0.tar .. <obs_path>_3.tar from OBS
    to target_path and unpack them. obs_path must be the complete object
    name without the `_N.tar` suffix.
    """
    sub_name_list = ['_0.tar', '_1.tar', '_2.tar', '_3.tar']
    for item in sub_name_list:
        sub_name = obs_path + item
        tmp_name = 'model.tar'
        mox.file.copy(sub_name, os.path.join(target_path, tmp_name))
        os.system('cd {}; tar -xvf {}'.format(target_path, tmp_name))
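
# Usage sketch (hypothetical OBS prefix; the shards <prefix>_0.tar ..
# <prefix>_3.tar must all exist on OBS):
#   ckpt_copy_tar('obs://some-bucket/ckpt/pangu26b', target_path='/cache/ckpt_files')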

def get_ckpt_file_list(ckpt_path):
    """Build the list of sliced checkpoint files and print their sizes."""
    returned_list = []
    for i in range(0, 16):  # 512 shards in the full-scale setup
        # 'filerted' is kept verbatim: it matches the shard file names on disk
        returned_list.append('filerted_{}.ckpt'.format(i))
    returned_list = [os.path.join(ckpt_path, item) for item in returned_list if 'embedding' not in item]
    print("Checkpoint file list:", returned_list)
    for item in returned_list:
        fsize = os.path.getsize(item)
        f_gb = fsize / float(1024) / 1024 / 1024
        print(item, ": {:.2f} GB".format(f_gb))
    return returned_list
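
# The shard list produced here is meant for MindSpore's distributed
# checkpoint loading, as in the commented-out call inside run_train below:
#   ckpt_file_list = get_ckpt_file_list('/cache/ckpt_files')
#   load_distributed_checkpoint(teacher_net, ckpt_file_list)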

class LossCallBack(Callback):
    """
    Monitor the loss during training.
    If the loss is NAN or INF, training should be terminated.
    """

    def __init__(self, dataset_size=-1, local_rank=0, has_trained_epoch=0, has_trained_step=0, micro_size=1):
        super(LossCallBack, self).__init__()
        self._dataset_size = dataset_size
        self.local_rank = local_rank
        self.has_trained_epoch = has_trained_epoch
        self.has_trained_step = has_trained_step
        self.micro_size = micro_size
        print("load has trained epoch: {} and step: {}".format(has_trained_epoch, has_trained_step), flush=True)

    def step_end(self, run_context):
        """
        Print the loss after each step.
        """
        cb_params = run_context.original_args()
        if self._dataset_size > 0 and self.local_rank % 8 == 0:
            percent, epoch_num = math.modf(cb_params.cur_step_num /
                                           self._dataset_size)
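            # math.modf splits the ratio into (fractional, integral) parts,
            # e.g. math.modf(7 / 2) == (0.5, 3.0): step 7 of a 2-step-per-epoch
            # dataset is halfway through the third epoch.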
            if percent == 0:
                epoch_num -= 1
            date = time.asctime(time.localtime(time.time()))
            total_value = cb_params.net_outputs[0].asnumpy()
            print("time: {} local_rank: {}, epoch: {}, step: {}, total_loss is {}, overflow is {}, scale is {}, lr is {}".
                  format(date, int(self.local_rank), int(epoch_num) + int(self.has_trained_epoch),
                         cb_params.cur_step_num + int(self.has_trained_step), total_value,
                         cb_params.net_outputs[1].asnumpy(), cb_params.net_outputs[2].asnumpy(),
                         cb_params.net_outputs[3].asnumpy()))


project_root = os.path.abspath(
    os.path.dirname(os.path.realpath(__file__)) + os.path.sep + "..")
print('project_root:', project_root)


def process_example(example, k, keeplinesRatio=1):
    """Build a k-shot prompt by sampling k examples and keeping the last
    `keeplinesRatio` fraction of each sampled context's sentences."""
    example_sent = ""
    example_lens = len(example)
    random_indexes = [random.randint(0, example_lens - 1) for i in range(k)]

    for num_example in random_indexes:
        # example_sent = example_sent + example[num_example]["context"] + example[num_example][
        #     "question"] + "请在___处填上内容:" + example[num_example]["answer"] + "\n"

        tmp_lines_list = example[num_example]["context"].split('。')
        all_lines = len(tmp_lines_list)
        keep_lines_num = math.floor(all_lines * keeplinesRatio)

        # random_line_indexes = [random.randint(0, all_lines - 1) for i in range(keep_lines_num)]
        # for thisRandomLineIndex in random_line_indexes:
        #     example_sent = example_sent + tmp_lines_list[thisRandomLineIndex]

        # keep the last keep_lines_num sentences of the context
        for i in range(keep_lines_num):
            example_sent = example_sent + tmp_lines_list[-keep_lines_num + i]

        example_sent = example_sent + example[num_example]["question"] + example[num_example]["answer"] + "\n"
        # if len(example_sent) >= MAX_STR:
        #     break

    return example_sent
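
# Illustrative output shape (hypothetical data) for k == 2 and
# keeplinesRatio == 1: "<ctx A><q A><ans A>\n<ctx B><q B><ans B>\n",
# i.e. one demonstration per sampled example, separated by newlines.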

def process_one_sent_eval(tokenizer, sent, example):
    """Tokenize one evaluation sample prefixed with the few-shot examples."""
    # tokenizer.tokenize returns a list of tokens
    input_sent = example + sent["context"] + sent["question"]  # + "请在___处填上内容:"
    # print(input_sent)
    # print(sent["answer"])
    # input_sent = example + sent["context"] + sent["question"]
    # sent_ids = tokenizer.encode(input_sent)
    # truth_label = tokenizer.encode(sent["answer"])

    sent_ids = tokenizer.tokenize(input_sent)  # changed from encode to tokenize
    truth_label = tokenizer.tokenize(sent["answer"])

    L = {"prompt": sent_ids,
         "truth": truth_label
         }

    return L
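
# The returned dict pairs the tokenized few-shot prompt with the tokenized
# gold answer, e.g. (shape only): {"prompt": [tok, tok, ...], "truth": [tok, ...]}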

def get_cftpd_data(tokenizer, shot, k_num=3):
    """Load and preprocess the CFT-PD cloze data.

    Every raw line carries a `|||` separator and the text after the first
    `|||` is the sentence; the question line contains the `XXXXX` blank, and
    the record's final line holds a second `|||` followed by the gold answer.
    """
    if shot == "zero_shot":
        k = 0
    elif shot == "one_shot":
        k = 1
    elif shot == "few_shot":
        k = k_num
    else:
        raise ValueError("unknown shot setting: {}".format(shot))

    all_data = {"contents": [],
                "labels": []}
    # load the evaluation data
    with open("./tasks/cftpd/cft_test_auto.txt", "r", encoding="utf-8") as f:
        data = []
        lines = f.readlines()
        sent_dict = {"context": "", "question": "", "answer": ""}

        stop_flag = False

        for line in lines:
            line = line.replace(" ", "")
            line = line.replace("\n", "")
            if line.count("|||") == 1:
                if not stop_flag:
                    if not line.find("XXXXX") == -1:
                        # line = line.replace("XXXXX", "___")
                        # line = line.replace("XXXXX", "__")

                        stop_flag = True  # drop the question line from the context
                    else:
                        sent_dict["context"] = sent_dict["context"] + line[line.find("|||") + 3:]

                # keep the context lines that follow the question
                else:
                    sent_dict["context"] = sent_dict["context"] + line[line.find("|||") + 3:]
            else:
                stop_pos_question = line.find("XXXXX")
                first_pos = line.find("|||") + 3
                second_pos = line[first_pos:].find("|||") + first_pos + 3
                # the question stops at the blank
                sent_dict["question"] = sent_dict["question"] + line[first_pos:stop_pos_question]
                sent_dict["answer"] = sent_dict["answer"] + line[second_pos:]
                data.append(sent_dict)

                sent_dict = {"context": "", "question": "", "answer": ""}
                stop_flag = False  # reset per record (mirrors the example loop below)

    # load the few-shot example data
    with open("./tasks/cftpd/cft.test.human", "r", encoding="utf-8") as f:
        examples_data = []
        lines = f.readlines()
        sent_dict_examples = {"context": "", "question": "", "answer": ""}

        stop_flag = False

        for line in lines:
            line = line.replace(" ", "")
            line = line.replace("\n", "")

            if line.count("|||") == 1:
                if not stop_flag:
                    if not line.find("XXXXX") == -1:
                        # line = line.replace("XXXXX", "")
                        # delete the question line from the context
                        # line = ""
                        stop_flag = True
                    else:
                        sent_dict_examples["context"] = sent_dict_examples["context"] + line[line.find("|||") + 3:]
                # else:
                #     sent_dict_examples["context"] = sent_dict_examples["context"] + line[line.find("|||") + 3:]
            else:
                stop_pos_question = line.find("XXXXX")

                first_pos = line.find("|||") + 3
                second_pos = line[first_pos:].find("|||") + first_pos + 3
                # the question stops at the blank

                # sent_dict_examples["answer"] = sent_dict_examples["answer"] + line[second_pos:]
                # sent_dict_examples["question"] = sent_dict_examples["question"] + line[first_pos:stop_pos_question] + sent_dict_examples["answer"] + line[stop_pos_question + 5: second_pos - 3]

                sent_dict_examples["question"] = sent_dict_examples["question"] + line[first_pos:stop_pos_question]
                sent_dict_examples["answer"] = sent_dict_examples["answer"] + line[second_pos:]

                # sent_dict_examples["answer"] = ""
                examples_data.append(sent_dict_examples)

                sent_dict_examples = {"context": "", "question": "", "answer": ""}
                stop_flag = False

    # examples_i = process_example(examples_data, k)

    for line in data:  # , desc="Preprocessing data"
        # sample fresh few-shot examples for every evaluation record
        examples_i = process_example(examples_data, k)
        # print(examples_i)
        processed_sent = process_one_sent_eval(tokenizer, line, examples_i)
        all_data["contents"].extend([processed_sent["prompt"]])
        all_data["labels"].append(processed_sent["truth"])

    return all_data["contents"], all_data["labels"]

def count_params(net):
    """Count the number of parameters in the network.
    Args:
        net (mindspore.nn.Cell): MindSpore network instance
    Returns:
        total_params (int): Total number of trainable params
    """
    total_params = 0
    for param in net.trainable_params():
        total_params += np.prod(param.shape)
    return int(total_params)
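
# Worked example: a Dense layer with a 4 x 8 weight matrix and an 8-element
# bias contributes 4 * 8 + 8 == 40 to the returned total.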

def run_train(args_opt):
    r"""
    The main training process.
    """
    # Set the HCCL connect timeout
    os.environ['HCCL_CONNECT_TIMEOUT'] = "6000"
    EXEC_PATH = os.path.join(project_root, 'tiny_pangu')
    device_id = int(os.getenv("DEVICE_ID"))
    rank_id_str = os.getenv('RANK_ID', '0')
    # keep only the trailing rank index, e.g. 'RANK_ID': 'job24535502-job-facereidtome-hn-0/1'
    rank_id = int(rank_id_str[rank_id_str.rfind('-') + 1:])
    print('rank_id:{}'.format(rank_id), "rank_id str:{}".format(rank_id_str))
    local_rank = rank_id
    print('local_rank:{}, device id:{}'.format(local_rank, device_id))

    # copy the strategy ckpt of the pretrained model
    pretrained_strategy_ckpt_path = "/cache/strategy/ckpt_strategy_{}.ckpt".format(local_rank)
    mox.file.copy(src_url="obs://pcl-verify/yizx/distilPangu/strategy_ckpt/pangu26b_finetuneOnPD_PET_saveCKPT_3e6LR_bs30_bt98cktp_strategy.ckpt",
                  dst_url=pretrained_strategy_ckpt_path)
    # mox.file.copy(src_url="obs://mindspore-file/strategy_ckpt/gpt_1024_13b_exp65cktp_strategy.ckpt", dst_url=pretrained_strategy_ckpt_path)

    # download the dataset (only the first process on each node)
    if local_rank % 8 == 0:
        print('MindSpore path:', mindspore)
        print("Modify the time out from 300 to 30000")
        tbe_path = "/usr/local/ma/python3.7/lib/python3.7/site-packages/mindspore" \
                   "/_extends/parallel_compile/tbe_compiler/tbe_process.py"
        os.system(
            "sed -i 's/300/30000/g' " + tbe_path
        )
        os.system(
            "sed -i 's/330/33000/g' " + tbe_path
        )
        print("begin download dataset", flush=True)

        cache_url = LOCAL_PATH
        if not os.path.exists(LOCAL_PATH):
            Path(LOCAL_PATH).mkdir(parents=True, exist_ok=True)

        files = os.listdir(LOCAL_PATH)
        data = [
            os.path.join(LOCAL_PATH, name) for name in files
            if not name.endswith(".db")
        ]
        if len(data) == 0:
            print("Start to copy the dataset", flush=True)
            Path(cache_url).mkdir(parents=True, exist_ok=True)
            mox.file.copy_parallel(src_url=BUCKET_DIR, dst_url=LOCAL_PATH)
            print("@@@@@@ Dataset download succeed! @@@@@@@", flush=True)

        os.environ['HCCL_CONNECT_TIMEOUT'] = "6000"
        os.system('ulimit -s 102400')  # note: this only affects the spawned subshell
        print(args_opt.ckpt_path)
        # ckpt_copy_tar(args_opt.ckpt_path, target_path="/cache/ckpt_files")
        mox.file.copy('obs://pcl-verify/yizx/distilPangu/merged_ckpt/Newexp65_GPT3_2-3494_2.ckpt',
                      '/cache/Newexp65_GPT3_2-3494_2.ckpt')
        mox.file.copy(args_opt.word_embedding_path, '/cache/word_embedding.npy')
        mox.file.copy(args_opt.position_embedding_path, '/cache/position_embedding.npy')
        mox.file.copy(args_opt.top_query_embedding_path, '/cache/top_query_embedding.npy')
        print("setting env success.")
        # create a sentinel file to signal that setup on this node is done
        with open("%s/install.txt" % EXEC_PATH, 'w'):
            pass

    # block the other processes until the package patch and the dataset
    # download have finished
    while not os.path.exists("%s/install.txt" % EXEC_PATH):
        time.sleep(1)
    print('local_rank:{}, device id:{} start to run...'.format(
        local_rank, device_id),
        flush=True)

    # Set execution mode
    context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.device_target)
    context.set_context(variable_memory_max_size="31GB")
    strategy_file = '/tmp/cktp_strategy.ckpt'
    # Set parallel context
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("rank_id is {}, device_num is {}".format(rank, device_num))

        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL, gradients_mean=False,
            device_num=device_num,
            full_batch=False,
            strategy_ckpt_load_file=pretrained_strategy_ckpt_path,
            enable_parallel_optimizer=False,
            strategy_ckpt_save_file=strategy_file)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()
    else:
        rank = 0
        device_num = 1
    context.set_context(save_graphs=True, save_graphs_path="/cache/" + str(rank))
    # copy data from the cloud to /cache/Data
    cache_url = '/cache/Data/'

    # Set model properties
    model_parallel_num = 1  # args_opt.op_level_model_parallel_num
    data_parallel_num = int(device_num / model_parallel_num)
    args_opt.per_batch_size = 3  # hard-coded override of the CLI value
    batch_size = args_opt.per_batch_size * data_parallel_num
    print("@@@@@ global batch_size is: {} @@@@@".format(batch_size))

    teacher_config = PANGUALPHAConfig(
        data_parallel_num=data_parallel_num, model_parallel_num=model_parallel_num, batch_size=batch_size,
        seq_length=args_opt.seq_length, vocab_size=args_opt.vocab_size, embedding_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers, num_heads=args_opt.num_heads, expand_ratio=4, dropout_rate=0.1,
        compute_dtype=mstype.float16, stage_num=args_opt.stage_num, micro_size=args_opt.micro_size,
        eod_reset=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
        param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
        word_emb_dp=bool(args_opt.word_emb_dp))

    student_config = PANGUALPHAConfig(
        data_parallel_num=data_parallel_num, model_parallel_num=model_parallel_num, batch_size=batch_size,
        seq_length=args_opt.seq_length, vocab_size=args_opt.vocab_size, embedding_size=1280,
        num_layers=int(args_opt.num_layers / 2), num_heads=args_opt.num_heads, expand_ratio=4, dropout_rate=0.1,
        compute_dtype=mstype.float16, stage_num=args_opt.stage_num, micro_size=args_opt.micro_size,
        eod_reset=bool(args_opt.eod_reset), load_ckpt_path=args_opt.load_ckpt_path,
        param_init_type=mstype.float32 if args_opt.param_init_type == 'fp32' else mstype.float16,
        word_emb_dp=bool(args_opt.word_emb_dp))
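
    # The student keeps the teacher's sequence length, vocabulary and head
    # count but halves the depth and shrinks the hidden size to 1280; e.g.
    # assuming num_layers=32 as in the 2.6B teacher, the student gets 16 layers.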

    print("===teacher_config is: ", teacher_config, flush=True)
    print("===student_config is: ", student_config, flush=True)

    # Define the networks
    student_net = PanguAlpha(student_config, is_teacher=False)
    teacher_net = PanguAlpha(teacher_config, is_teacher=True)

    ##############################################################################################
    # load the pretrained PanguAlpha checkpoint into the teacher
    print("##### start to load pangu 2.6B pretrained-ckpt #####", flush=True)
    # from mindspore.train.serialization import load_distributed_checkpoint
    # # tmp_input = Tensor(np.ones(shape=(1, config.seq_length)), mstype.int32)
    # # strategy = model.infer_train_layout(train_dataset=ds, sink_size=callback_size)
    # ckpt_file_list = get_ckpt_file_list('/cache/ckpt_files')
    # load_distributed_checkpoint(teacher_net, ckpt_file_list)  # , predict_layout)
    params_dict = load_checkpoint("/cache/Newexp65_GPT3_2-3494_2.ckpt")
    load_param_into_net(teacher_net, params_dict)
    print('##### PANGU-2.6B partial parameter size is: {} #####'.format(count_params(teacher_net)))

    # freeze the teacher
    teacher_net.set_train(False)
    params = teacher_net.trainable_params()
    for param in params:
        param.requires_grad = False
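
    # set_train(False) turns off training-time behaviour such as dropout, and
    # clearing requires_grad removes the teacher's weights from
    # trainable_params(), so the optimizer below only updates the student.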

    tiny_pangu = GPTWithLoss_gd_withCrEN(teacher_net, teacher_config, student_net, student_config)
    print('##### PANGU teacher trainable parameter size is: {}, student is: {} #####\n'.format(
        count_params(tiny_pangu.teacher), count_params(tiny_pangu.student)))
    pangu_alpha_with_loss = tiny_pangu

    # net_params = pangu_alpha_with_loss.parameters_dict()
    # params = [{'name': k, "data": v} for k, v in net_params.items()]
    # for each_param in params:
    #     print(each_param)
    # exit()
    ##############################################################################################
    # loss = CrossEntropyLoss(student_config)
    # pangu_alpha_with_loss = PanguAlphaWithLoss_tiny(student_config, student_net, loss, eos_token=args_opt.eod_id)
    # pangu_alpha_with_loss = VirtualDatasetOneInputCell(pangu_alpha_with_loss)

    # loss = CrossEntropyLoss(student_config)
    # pangu_alpha_with_loss = PanguAlphaWithLoss(student_config, tiny_pangu, loss)
    # pangu_alpha_with_loss = _VirtualDatasetCell(pangu_alpha_with_loss)

    print("=====args_opt is: ", args_opt, flush=True)

    # Warm-up and cosine decay learning rate
    lr = LearningRate(learning_rate=args_opt.start_lr, end_learning_rate=args_opt.end_lr,
                      warmup_steps=args_opt.warmup_step, decay_steps=10000)
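
    # A sketch of the assumed schedule (LearningRate lives in src.utils):
    # linear warm-up from 0 to start_lr over warmup_step steps, then a decay
    # towards end_lr across decay_steps, cosine-shaped per the comment above.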

    # Set the weight decay coefficient: zero for bias and layernorm, 1e-1 for the rest
    decay_filter = lambda x: 'layernorm' not in x.name.lower() and "bias" not in x.name.lower()
    params = pangu_alpha_with_loss.trainable_params()
    decay_params = list(filter(decay_filter, params))
    other_params = list(filter(lambda x: not decay_filter(x), params))
    group_params = [{
        'params': decay_params,
        'weight_decay': 1e-1
    }, {
        'params': other_params,
        'weight_decay': 0.0
    }, {
        'order_params': params
    }]
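
    # Example grouping: 'backbone.dense.weight' falls into the decay group,
    # while 'backbone.layernorm.gamma' or '...dense.bias' gets weight_decay=0.0;
    # the 'order_params' entry keeps MindSpore's parameter ordering intact.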
    if args_opt.optimizer == "lamb":
        optimizer = nn.Lamb(group_params, learning_rate=lr)
    else:
        optimizer = FP32StateAdamWeightDecay(group_params, learning_rate=lr, eps=1e-8, beta1=0.9, beta2=0.95)
    # Initial loss-scale value
    loss_scale_value = math.pow(2, 32)
    epoch_num = args_opt.epoch_size

    # Dataset loading mindrecord files
    # ds = create_dataset(student_config.batch_size, data_path=LOCAL_PATH,
    #                     data_start_index=0, eod_reset=student_config.eod_reset, full_batch=bool(args_opt.full_batch),
    #                     eod_id=args_opt.eod_id, device_num=device_num, rank=rank,
    #                     column_name=args_opt.data_column_name, epoch=epoch_num)
    ds = create_dataset(student_config.batch_size, data_path=LOCAL_PATH, data_start_index=0,
                        eod_reset=student_config.eod_reset, eod_id=args_opt.eod_id,
                        device_num=device_num, rank=rank, hash_check=True)

    # ###################### test shape ##############################################################
    # for item in ds:  # [input_ids, position_id, attention_mask]
    #     for d in item:
    #         print(d, d.shape)
    #     break
    # test_inputs, test_position_id, test_attention_mask = item
    # teacher_seq_output, student_seq_output = pangu_alpha_with_loss(test_inputs, test_position_id, test_attention_mask)
    # print(teacher_seq_output, student_seq_output)
    # print(len(teacher_seq_output), len(student_seq_output))  # 31, 15
    # for i in range(len(teacher_seq_output)):
    #     print(teacher_seq_output[i].shape, student_seq_output[i].shape)
    # exit()
    # ################################################################################################

    ckpt_dir = os.path.join("/cache/ckpt/", f"rank_{str(local_rank)}")
    # create the dir for ckpt
    if not os.path.exists(ckpt_dir):
        Path(ckpt_dir).mkdir(parents=True, exist_ok=True)

    step_per_epoch = ds.get_dataset_size()
    callback_size = args_opt.sink_size
    actual_epoch_num = int(epoch_num * step_per_epoch / callback_size)
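    # With dataset sinking, model.train counts "epochs" in units of sink_size
    # steps, e.g. epoch_size=1, 1000 steps per epoch and sink_size=100 give
    # actual_epoch_num = 10 sink iterations of 100 steps each.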
    callback = [
        TimeMonitor(callback_size),
        LossCallBack(callback_size, rank, 0, 0)
    ]

    config_ck = CheckpointConfig(save_checkpoint_steps=args_opt.save_step,
                                 keep_checkpoint_max=1,
                                 integrated_save=False)
    ckpoint_cb = ModelCheckpoint(prefix="tinyPangu-yizx",
                                 directory=ckpt_dir,
                                 config=config_ck)
    ckpt2obs_cb = Ckpt2ObsSummaryCallback(local_ckpt_dir=ckpt_dir,
                                          local_rank=0,
                                          has_trained_epoch=0,
                                          has_trained_step=0,
                                          bucket=args_opt.bucket_dir + '/' + str(local_rank),
                                          syn_obs_steps=args_opt.save_step)
    callback.append(ckpoint_cb)
    callback.append(ckpt2obs_cb)

    if local_rank == 0:
        sub_dir = args_opt.bucket_dir.split('/')[-1]
        callback.append(LossSummaryCallback_Print2Loss(summary_dir="summary",
                                                       local_rank=0,
                                                       has_trained_epoch=0,
                                                       has_trained_step=0,
                                                       bucket='obs://pcl-verify/yizx/distilPangu/summary/' + sub_dir,
                                                       syn_times=40))
        callback.append(StrategySaveCallback(strategy_path=strategy_file,
                                             local_rank=0,
                                             has_trained_epoch=0,
                                             has_trained_step=0,
                                             bucket='obs://pcl-verify/yizx/distilPangu/strategy_ckpt/' + sub_dir))

    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=loss_scale_value, scale_factor=2, scale_window=1000)
    pangu_alpha_with_grads = PanguAlphaTrainOneStepWithLossScaleCell_Print2Loss(
        pangu_alpha_with_loss, optimizer=optimizer, scale_update_cell=update_cell, enable_global_norm=True,
        config=student_config)
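
    # The loss scale starts at 2**32; DynamicLossScaleUpdateCell halves it on
    # overflow and doubles it again after 1000 consecutive overflow-free steps
    # (scale_factor=2, scale_window=1000).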
    model = Model(pangu_alpha_with_grads)
    print("Dataset size: {}, actual_epoch_num: {}".format(ds.get_dataset_size(), actual_epoch_num), flush=True)
    model.train(actual_epoch_num, ds, callbacks=callback, sink_size=callback_size, dataset_sink_mode=True)


if __name__ == "__main__":
    opt = get_args()
    set_parse(opt)
    if opt.per_batch_size == 0:
        raise ValueError("The per_batch_size has not been configured.")
    run_train(opt)