|
"""Run one PanGu-Alpha text-generation request on GPU via ``pcl_pangu``.

When executed as a script this file:

1. installs ``pcl_pangu`` with pip and prints some diagnostics
   (contents of ``/tmp/dataset`` and the ``nvidia-smi`` report),
2. parses the command-line options,
3. builds a GPU model config from ``--model`` / ``--load`` and runs
   ``alpha.inference`` on the ``--input_user`` prompt.
"""
import argparse
import os
import subprocess
import sys


def build_parser():
    """Return the argument parser describing the script's options.

    ``--data_url`` and ``--train_url`` are accepted but not used by the
    inference run below.
    NOTE(review): they look like arguments the job launcher always passes
    (training-data path and model output path) -- confirm before removing.
    """
    parser = argparse.ArgumentParser()
    # Model size: 350M, 2B6 or 13B.
    parser.add_argument('--model', default='2B6',
                        type=str, choices=['350M', '2B6', '13B'],
                        help="setting model size from ['350M', '2B6', '13B']")
    # Training text path: where text_document.bin is stored.
    # NOTE(review): the original comments place the dataset under
    # /tmp/dataset/ while the defaults point at /dataset/ -- confirm which
    # mount point the runtime actually uses.
    parser.add_argument('--data_url', default='/dataset/text_document',
                        type=str,
                        help="setting bin dataset text_document from: '/dataset'.")
    # Initial (pretrained) model checkpoint path.
    parser.add_argument('--load', default='/dataset/Pangu-alpha_2.6B_fp16_mgt/Pangu-alpha_2.6B_mgt',
                        type=str,
                        help="loading pretrained model ckpt, from: '/dataset'.")
    # Directory in which a trained model would be saved.
    parser.add_argument('--train_url', default='/output/',
                        type=str,
                        help="save your model to: '/output'.")
    # Prompt text handed to the model.
    parser.add_argument('--input_user', default='四川的省会是?',
                        type=str,
                        help="input your answer")
    return parser


def prepare_environment():
    """Install ``pcl_pangu`` and print dataset / GPU diagnostics.

    Every step is best-effort: a non-zero exit status or a missing
    executable is reported but never aborts the script, matching the
    original ``os.system`` behaviour.
    """
    commands = (
        # sys.executable guarantees the package lands in the interpreter
        # that is running this script, not whichever `python` is on PATH.
        [sys.executable, '-m', 'pip', 'install', 'pcl_pangu'],
        ['ls', '/tmp/dataset'],
        ['nvidia-smi'],
    )
    for command in commands:
        try:
            # List form: no shell is involved, so nothing is re-parsed.
            subprocess.run(command, check=False)
        except OSError as err:
            # e.g. the executable does not exist on this machine.
            print('skipping', command[0], '-', err, file=sys.stderr)


def main(argv=None):
    """Set up the environment, parse *argv* and run a single inference.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    prepare_environment()

    # Imported here rather than at the top of the file: pcl_pangu is only
    # guaranteed to be importable after prepare_environment() installed it.
    from pcl_pangu.context import set_context
    from pcl_pangu.model import alpha

    args = build_parser().parse_args(argv)
    set_context(backend='pytorch')
    print(args)

    config = alpha.model_config_gpu(model=args.model, load=args.load)
    alpha.inference(config, input=args.input_user)

    print()


if __name__ == '__main__':
    main()
|