# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
- """
- PanGu predict run
- """
import os
import resource
import time

import numpy as np

import mindspore.common.dtype as mstype
import mindspore.communication.management as D
from mindspore import context, Tensor
from mindspore.context import ParallelMode
from mindspore.parallel import set_algo_parameters
from mindspore.parallel._cost_model_context import _set_multi_subgraphs
from mindspore.train.model import Model
from mindspore.nn.transformer import TransformerOpParallelConfig, TransformerRecomputeConfig

import moxing as mox

from src.pangu_alpha import EvalNet, PanguAlphaModel
from src.pangu_alpha_config import set_parse, PanguAlphaConfig
from src.utils_competition_ar_exp22 import get_args


def load_model(args_opt):
    r"""
    Build the PanGu-Alpha network, set up the (parallel) execution context and
    load the checkpoint into it.
    """
    # Set execution mode
    context.set_context(save_graphs=False,
                        mode=context.GRAPH_MODE,
                        device_target=args_opt.device_target)
    context.set_context(variable_memory_max_size="30GB")
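    # Note (added): variable_memory_max_size enlarges the device-memory pool
    # reserved for parameters, which a multi-billion-parameter fp16 model
    # needs; 30GB assumes 32GB Ascend devices.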
    # Set parallel context
    local_ckpt_path = "/cache/ckpt.ckpt"
    if args_opt.distribute == "true":
        D.init()
        device_num = D.get_group_size()
        rank = D.get_rank()
        print("rank_id is {}, device_num is {}".format(rank, device_num))
        ############################################################
        if rank % 8 == 0:
            # Raise the stack-size soft limit for this process.
            # os.system('ulimit -s 102400') would only change the limit of a
            # short-lived subshell, not of the Python process itself.
            _, hard = resource.getrlimit(resource.RLIMIT_STACK)
            if hard == resource.RLIM_INFINITY or hard >= 102400 * 1024:
                resource.setrlimit(resource.RLIMIT_STACK, (102400 * 1024, hard))
            mox.file.copy(src_url="obs://research-my2/taoht-13b/filtered_ckpt/integrated_save/mPanGu_Alpha-53_exp22-30508-ar_fp16.ckpt",
                          dst_url=local_ckpt_path)
            print("setting env success.")
            # Once the checkpoint download has finished, create a sentinel
            # file to signal success to the other processes.
            open("/tmp/download_ckpt.txt", 'w').close()
        # Block every process until the checkpoint download has completed.
        while not os.path.exists("/tmp/download_ckpt.txt"):
            time.sleep(1)
        print("\n\n************Checkpoint download succeeded!*************\n\n", flush=True)
        if rank % 8 == 0:
            print(os.listdir(os.path.dirname(local_ckpt_path)), flush=True)
        ############################################################
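        # Note (added): "/tmp" and "/cache" are node-local on ModelArts-style
        # clusters, so "rank % 8 == 0" elects one downloader per 8-device node
        # and the sentinel file synchronizes the processes of that node.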
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL,
            gradients_mean=False,
            full_batch=True,
            loss_repeated_mean=True,
            enable_parallel_optimizer=False,
            pipeline_stages=args_opt.stage_num)
        set_algo_parameters(elementwise_op_strategy_follow=True)
        _set_multi_subgraphs()

    else:
        rank = 0
        device_num = 1
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(
            strategy_ckpt_load_file=args_opt.strategy_load_ckpt_path)

    use_past = (args_opt.use_past == "true")
    print('local_rank:{}, start to run...'.format(rank), flush=True)
    if args_opt.export:
        use_past = True
    # Set model properties
    model_parallel_num = args_opt.op_level_model_parallel_num
    data_parallel_num = device_num // model_parallel_num
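    # Sanity check (added): the op-level model-parallel size must divide the
    # device count, otherwise the (data_parallel, model_parallel) grid cannot
    # tile the available devices.
    assert device_num % model_parallel_num == 0, \
        "device_num ({}) is not divisible by op_level_model_parallel_num ({})".format(
            device_num, model_parallel_num)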

    recompute_config = TransformerRecomputeConfig(recompute=True)
    parallel_config = TransformerOpParallelConfig(data_parallel=data_parallel_num,
                                                  model_parallel=model_parallel_num,
                                                  pipeline_stage=args_opt.stage_num,
                                                  micro_batch_num=args_opt.micro_size,
                                                  recompute=recompute_config)
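    # Note (added): with full_batch=True every device receives the complete
    # batch, and data_parallel * model_parallel devices together serve each of
    # the pipeline_stage stages.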

    per_batch_size = args_opt.per_batch_size
    batch_size = per_batch_size * data_parallel_num
    # Prediction currently supports only a single batch
    if args_opt.run_type == "predict":
        batch_size = 1
    config = PanguAlphaConfig(
        batch_size=batch_size,
        seq_length=args_opt.seq_length,
        vocab_size=args_opt.vocab_size,
        hidden_size=args_opt.embedding_size,
        num_layers=args_opt.num_layers,
        num_heads=args_opt.num_heads,
        post_layernorm_residual=False,
        dropout_rate=0.0,
        ffn_hidden_size=args_opt.embedding_size * 4,
        use_past=use_past,
        eod_reset=False,
        parallel_config=parallel_config,
        load_ckpt_path=None,
        run_type=args_opt.run_type,
        param_init_type=mstype.float16)
    print("===config is: ", config, flush=True)
    print("=====args_opt is: ", args_opt, flush=True)

    # Define network
    pangu_alpha = PanguAlphaModel(config)
    eval_net = EvalNet(pangu_alpha, pad_token=args_opt.padding_id)
    eval_net.set_train(False)
    model_predict = Model(eval_net)
    # Compile network and obtain tensor layout for loading ckpt
    inputs_np = Tensor(np.ones(shape=(config.batch_size, config.seq_length)), mstype.int32)
    current_index = Tensor(np.array([0]), mstype.int32)
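    # Note (added): the all-ones dummy tensors only pin down the static shapes
    # (batch_size, seq_length) that the graphs below are compiled and
    # partitioned with; their values are never used.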

    if args_opt.distribute == "false":
        predict_layout = None
    elif config.use_past:
        batch_valid_length = Tensor(np.array([0]), mstype.int32)
        init_true = Tensor([True], mstype.bool_)
        inputs_np_1 = Tensor(np.ones(shape=(config.batch_size, 1)), mstype.int32)
        model_predict.predict_network.add_flags_recursive(is_first_iteration=True)
        predict_layout = model_predict.infer_predict_layout(inputs_np, current_index, init_true, batch_valid_length)
        model_predict.predict_network.add_flags_recursive(is_first_iteration=False)
        _ = model_predict.infer_predict_layout(inputs_np_1, current_index, init_true, batch_valid_length)
    else:
        predict_layout = model_predict.infer_predict_layout(inputs_np, current_index)
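    # Note (added): with use_past=True two graphs are compiled: a
    # first-iteration graph that consumes the whole prompt, and an incremental
    # graph that consumes one token per step while reusing cached key/value
    # states. Both must be compiled before the weights are loaded.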
    # ------------------------------------------------------------------------
    print("======start loading checkpoint", flush=True)
    # Load checkpoint files. The checkpoint was saved with integrated_save
    # (not sharded per rank), so the full parameter dict is loaded into every
    # rank with load_param_into_net; predict_layout is only needed to trigger
    # graph compilation above.
    from src.serialization2 import load_checkpoint, load_param_into_net
    param_dict = load_checkpoint(local_ckpt_path)
    load_param_into_net(eval_net, param_dict, strict_load=False)
    print("================load param ok=================", flush=True)
    return pangu_alpha, model_predict, config


def run_predict(network, model_predict, config, args_opt):
    """Run a sample zh->ar translation with the compiled model."""
    from src.generate import generate_increment_re as generate_func
    D.init()
    rank = D.get_rank()

    # Define tokenizer
    work_dir = '/home/work/user-job-dir/pangu_alpha-r1.7/tokenizer/spm_13w'
    from tokenizer.spm_13w.tokenizer import SpmTokenizer, langs_ID, translate_ID
    vocab_file = os.path.join(work_dir, 'spm.128k.model.1')
    tokenizer = SpmTokenizer(vocab_file)

    # Tokenize the input sentence to ids
    input_txt = "你今天中午吃的什么?"  # "What did you have for lunch today?"
    input_ids = tokenizer.tokenize(input_txt)
    src_langs = 'zh'
    tag_langs = 'ar'  # target language
    # Build the translation prompt: the source-language ID and the target-
    # language ID are each repeated three times around the translate marker
    src_input_ids = [langs_ID[src_langs], langs_ID[src_langs], langs_ID[src_langs]] + \
                    input_ids + \
                    [translate_ID, translate_ID, translate_ID] + \
                    [langs_ID[tag_langs], langs_ID[tag_langs], langs_ID[tag_langs]]

    # Cap generation at twice the source length plus 100 tokens, at most 512
    tag_output_max_length = min(len(input_ids) * 2 + 100, 512)
    output_ids = generate_func(model_predict, np.array([src_input_ids]), args_opt,
                               dynamic_generate_length=tag_output_max_length)
    # Decode the generated ids (everything after the prompt) to a sentence
    output_txt = tokenizer.detokenize(output_ids[len(src_input_ids):].tolist())
    if tag_langs == 'zh':
        # Chinese text is written without spaces; drop the token separators
        output_txt = output_txt.replace(" ", '')
    print('Output is:', output_txt, flush=True)


def main():
    """Main process for predicting with a trained model"""
    opt = get_args(True)
    set_parse(opt)
    network, model_predict, config = load_model(opt)

    run_predict(network, model_predict, config, opt)


if __name__ == "__main__":
    main()