#1 master

Merged
zhangyh02 merged 6 commits from PCL-Platform.Intelligence/PanGu-Alpha-GPU:master into master 2 years ago
  1. +17
    -0
      README-en.md
  2. +2
    -2
      README.md
  3. +1
    -3
      inference_mindspore_gpu/3-minus-inference.md
  4. +362
    -0
      inference_mindspore_gpu/README-en.md
  5. +231
    -0
      panguAlpha_pytorch/README-en.md
  6. +1
    -1
      panguAlpha_pytorch/README.md
  7. +8
    -8
      panguAlpha_pytorch/examples/finetune_pangu_distributed.sh
  8. +1
    -1
      panguAlpha_pytorch/examples/generate_text.sh
  9. +52
    -0
      panguAlpha_pytorch/examples/pretrain_pangu_distributed_2.6B.sh
  10. +1
    -1
      panguAlpha_pytorch/megatron/arguments.py
  11. +2
    -3
      panguAlpha_pytorch/megatron/model/language_model.py
  12. +8
    -6
      panguAlpha_pytorch/megatron/training.py
  13. +2
    -4
      panguAlpha_pytorch/pretrain_gpt2.py
  14. +5
    -7
      panguAlpha_pytorch/tools/generate_samples_Pangu.py

+ 17
- 0
README-en.md View File

@@ -0,0 +1,17 @@
# PanGu-Alpha-GPU

English|[中文](README.md)

### Description

This project is the GPU version of [Pangu-alpha](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha), please check the original project for the details of Pangu-alpha. The main purpose of this project is to enable Pangu-alpha models to be inferred and trained on GPU, so that more people can experience the charm of big models. The purpose of openness is to gather ideas and explore the potential applications of large model , as well as to identify problems that can guide our future innovative research and breakthroughs.


# mindspore Inference、Finetune、Pre-training
:
1. [Please check](inference_mindspore_gpu/README-en.md):This part of the code only supports inference, so if you just want to experience Pangu-Alpha, we recommend using the "Three minutes to implement inference tutorial" under this page.
2. [Please check](https://gitee.com/mindspore/models/tree/master/official/nlp/pangu_alpha ):If you want to develop on Pangu-Alpha, we recommend using the training and inference code provided by mindspore. Model_zoo on the official website of mindspore provides inference, Finetune, and pre-training full process.

# pytorch Inference、Finetune、Pre-training

[Please check](panguAlpha_pytorch/README-en.md):The full process of inference, Finetune, and pre-training of Pangu-Alpha developed based on Megatron-1.1.

+ 2
- 2
README.md View File

@@ -1,6 +1,6 @@
# PanGu-Alpha-GPU

[英文](README-en.md)|中文

### 描述

@@ -11,7 +11,7 @@
# mindspore 推理、Finetune、预训练

1. [请查看](inference_mindspore_gpu/README.md):该部分代码只支持推理,如果只想体验一下盘古α的话,推荐使用这个页面下的《三分钟实现推理教程》。
2. [请查看](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/nlp/pangu_alpha):如果想在盘古α上开发的话,推荐使用 mindspore 提供的训练和推理代码。mindspore 官网的 model_zoo 提供了推理、Finetune、预训练全流程。
2. [请查看](https://gitee.com/mindspore/models/tree/master/official/nlp/pangu_alpha ):如果想在盘古α上开发的话,推荐使用 mindspore 提供的训练和推理代码。mindspore 官网的 model_zoo 提供了推理、Finetune、预训练全流程。

# pytorch 推理、Finetune、预训练



+ 1
- 3
inference_mindspore_gpu/3-minus-inference.md View File

@@ -1,6 +1,4 @@
# 3分钟实现盘古α模型推理
《目前启智社区的服务器内存有限(只有16g),所以暂时推理不了。不过后续会开放内存更大的服务器,敬请关注!!》


启智社区中项目下的云脑页面可以申请一张 T4 GPU,没有服务器的个人或学生可以在启智社区上快速免费体验盘古α。本教程将手把手教你3分钟实现盘古α的推理流程。

@@ -28,7 +26,7 @@
- 在`数据集存放路径`那一栏选 'PanguAlpha_2.6B_fp16_mindspore.zip'

![img.png](images/choice-dataset.png)
- 选择内存大于等于 32g 的服务器规格
- 点击`新建任务`

- 等待一段时间就会出现 `调试` 按钮,点击进去就是熟悉的 jupyter 调试界面了


+ 362
- 0
inference_mindspore_gpu/README-en.md View File

@@ -0,0 +1,362 @@
# PanGu-Alpha-GPU



### Description

This project is a GPU inference version of [Pangu-alpha](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha), for information about [Pangu-alpha](https://git.openi.org.cn /Intelligence/Pangu-Alpha), please see the original project for information on the principles, datasets, etc. The current phase of the project focuses on enabling Pangu-alpha models to be inferred and trained on GPUs, so that more people can experience the appeal of large models. The purpose of openness is to gather ideas, draw inspiration, and explore the potential of large model applications, as well as to identify problems that can guide our future innovative research and breakthroughs.



### model ckpt

| model | MD5 | fp |
| ------------------------------------------------------------ | -------------------------------- | ---- |
| [Pangu-alpha_2.6B.ckpt](https://git.openi.org.cn/attachments/27234961-4d2c-463b-9052-0240cc7ff29b?type=0) | da404a985671f1b5ad913631a4e52219 | fp32 |
| [ PanguAlpha_13b_fp16.ckpt](https://git.openi.org.cn/attachments/650711d6-6310-4dc2-90f8-153552e59c7a?type=0) | f2734649b9b859ff4cf62d496291249a | fp16 |
| [PanguAlpha_2.6B_fp16.ckpt](https://git.openi.org.cn/attachments/7ff30c2f-e9e4-44be-8eaa-23c9d617b781?type=0) | 3a14e8bf50548a717160e89df7c14b63 | fp16 |

[Pangu-alpha_2.6B.ckpt](https://git.openi.org.cn/attachments/27234961-4d2c-463b-9052-0240cc7ff29b?type=0) Can be used for loading 2.6B models of `fp16` and `fp32`, since the precision conversion is performed during the model loading phase

[ PanguAlpha_13b_fp16.ckpt](https://git.openi.org.cn/attachments/650711d6-6310-4dc2-90f8-153552e59c7a?type=0) Can only be used for loading 13B models of `fp16`

[PanguAlpha_2.6B_fp16.ckpt](https://git.openi.org.cn/attachments/7ff30c2f-e9e4-44be-8eaa-23c9d617b781?type=0) can be used for loading 2.6B models of `fp16`, is the same as [Pangu-alpha_2.6B.ckpt](https://git.openi.org.cn/attachments/27234961-4d2c-463b-9052-0240cc7ff29b?type=0), but this ckpt consumes less memory, about 20g.

### Graphics memory usage

| model | Graphics memory |
| --------- | --------- |
| 2.6B_fp16 | 6728 MiB |
| 2.6B_fp32 | 17214 MiB |
| 13B_fp16 | 26430 MiB |

Different models can be run depending on the video memory size of the card

The `2.6B_fp16` model should work on most graphics cards

Already running `2.6B_fp16` model successfully on T4 and `2.6B_fp16`, `2.6B_fp32` and `13B_fp16` models on v100


### Reasoning

###** Want a quick experience? Check out the [3-Minute Tutorial on inference](3-minus-inference.md)! You can use the T4 server for free!!! **


##### Environment

You can use a ready-made docker image directly

```
docker pull yands/mindspore_pangu-alpha:1.2.0
```
python path `/usr/local/bin/python`

If you don't like to use this image, you can also use the `mindspore:1.2.0` version, which allows you to run the `2.6B_fp32` model directly. There are several mindspore source code changes needed to run the `fp16` model, please see [appendix](#source code changes).

##### dependencies

```
pip install jieba
pip install sentencepiece==0.1.94
```

##### run

```
python run_inference.py --model=13B_fp16 --load_ckpt_path=/xxx/PanguAlpha_13b_fp16.ckpt
python run_inference.py --model=2B6_fp16 --load_ckpt_path=/xxx/PanguAlpha_2_6b.ckpt
python run_inference.py --model=2B6 --load_ckpt_path=/xxx/PanguAlpha_2_6b.ckpt
```

##### result

Note: The result limits the output tokens length to 50, without post-processing for different tasks

```
Input is: 上联:瑞风播福泽,事业具昌盛千家乐

2.6B_fp16 : 求下联
上联:瑞风播福泽,事业具昌盛千家乐求对下联:瑞风播德泽,事业成伟业万家乐

2.6B_fp32 :;下联:瑞雪迎瑞雪,事业兴盛万家乐。
横批:瑞雪兆丰年。
【瑞雪】
上联:瑞雪兆丰年,事业具昌盛千家乐

13B_fp16 :下联怎么对
上联:瑞风播福泽,事业具昌盛千家乐下联:福泽播善政,福泽流天下万民欢
```
```
Input is: 四川的省会是?

2.6B_fp16 : 省会是成都,四川省的省会是成都市,简称:成 都,是中华人民共和国四川省的首府,位于四川盆地西南边缘,成都平原中部,成都市区东南部,东与德阳市、绵阳

2.6B_fp32 :成都。。。。。。
好的,谢谢啦

13B_fp16 :成都。
```
```
Input is: 上联:春雨润人间,社会和谐万象新

2.6B_fp16 : 下联:春风暖人间,社会和谐万象新。横批:春雨润大地

2.6B_fp32 :下联:春晖照世界,世界和平百花盛开。
春雨滋润着大地,社会和谐万象更新,春晖照耀着世界。

13B_fp16 :下联:春风吹大地,大地春意盎然
```
```
Input is: 书生:羌笛何须怨杨柳,春风不度玉门关。
飞云:(这诗怎么这么耳熟?且过去跟他聊聊如何。)
书生:小兄弟,要不要一起喝一杯?
飞云:你请我呀?你若是请我,我便和你喝一杯;你若不请我,我便一个人去喝。
书生:小兄弟,看你年纪轻轻,不至于这么势利吧?
飞云:

2.6B_fp16 : 你怎么知道我不势利,你若是不势利,为什么要请我喝酒?
书生:小兄弟,你这话问的好奇怪哦!
飞云:你说什么?
书生:

2.6B_fp32 :我这人最看重的便是信誉。
书生:那是自然。
(这是在下和小兄弟的对话,小兄弟可不要介意。)
飞云:你看我年纪轻轻,不

13B_fp16 :你不知道吗,我这可是为了我自己,才这样做的。
书生:你这不是为了我自己,是为了你自己的前途呀!
飞云:你这是什么话?
书生:
```
```
Input is: 张无忌拿出屠龙宝刀,手起刀落,周芷若掉了一颗门牙,身旁的赵敏喜极而泣,

2.6B_fp16 : 周芷若却只是微微一笑,并不言语。张无忌又拿出倚天剑和屠龙刀,倚天剑和屠龙刀都是他亲手打造,刀刀都锋利无比,倚天剑和屠龙刀的刀身都是

2.6B_fp32 :张无忌又道:“这刀名叫倚天,你拿去给人做鞘,这倚天剑我送给你了。”周芷若喜道:“谢大侠。”张无忌道:“那倚天剑

13B_fp16 :周芷若也流下了幸福的眼泪。 张无忌和周芷若回到了峨嵋派。峨嵋弟子正准备迎接新掌门,却发现掌门被一个蒙面人所杀。张无忌和赵敏急忙赶到峨嵋
```
```
Input is: 人工智能成为国际竞争的新焦点。人工智能是引领未来的战略性技术,世界主要发达国家把发展人工智能作为提升国家竞争力、维护国家安全的重大战略,加紧出台规划和政策,围绕核心技术、顶尖人才、标准规范等强化部署,力图在新一轮国际科技竞争中掌握主导权。当前,

2.6B_fp16 : 人工智能技术已经成为全球创新最活跃、应用最广泛、产业链最长的技术之一。
在我国,人工智能技术也已进入快速发展的新阶段,人工智能技术在教育、医疗、交通、金融

2.6B_fp32 :人工智能正从理论研究向产业应用加速推进,从基础研究向技术研发和产品开发加速推进,从产业应用向民生服务加速推进。人工智能已成为全球创新最活跃、产业发展最具活力、国际竞争

13B_fp16 :人工智能发展呈现出深度感知、人机协同、群智开放、自主操控等趋势和特点,正在深刻改变人类的生产生活方式,成为推动经济社会发展的新引擎。
“人工智能”的英文单词是
```
```
Input is: 中国和美国和日本和法国和加拿大和澳大利亚的首都分别是哪里?

2.6B_fp16 : 中国:北京、上海、广州、香港、重庆、成都。

2.6B_fp32 :北京 (Beijing);上海 (Shanghai);广州 (Guangzhou);重庆 (Chongqing);成都 ( Chengdu);

13B_fp16 :中国的首都是北京,美国的首都是纽约,日本的首都是东京,法国的首都是巴黎。
```

### error

- ​ Sometimes the following two errors are reported, if they occur, re-run

```
FileNotFoundError: [Errno 2] No such file or directory: '/userhome/pclproject/gpt/PanGu-Alpha-GPU/inference/cuda_meta_105367/NotEqual_9416851550248092999.json'
```

```
ModuleNotFoundError: No module named 'tvm'
```



### Appendix

#### Source Code Modification

1、 the function `load_param_into_net()` on mindspore/train/serialization.py

```
def load_param_into_net(net, parameter_dict, strict_load=False,):
"""
Loads parameters into network.

Args:
net (Cell): Cell network.
parameter_dict (dict): Parameter dictionary.
strict_load (bool): Whether to strict load the parameter into net. If False, it will load parameter
in the param_dict into net with the same suffix. Default: False

Raises:
TypeError: Argument is not a Cell, or parameter_dict is not a Parameter dictionary.

Examples:
>>> net = Net()
>>> ckpt_file_name = "./checkpoint/LeNet5-1_32.ckpt"
>>> param_dict = load_checkpoint(ckpt_file_name, filter_prefix="conv1")
>>> param_not_load = load_param_into_net(net, param_dict)
>>> print(param_not_load)
['conv1.weight']
"""
if not isinstance(net, nn.Cell):
logger.error("Failed to combine the net and the parameters.")
msg = ("Argument net should be a Cell, but got {}.".format(type(net)))
raise TypeError(msg)

if not isinstance(parameter_dict, dict):
logger.error("Failed to combine the net and the parameters.")
msg = ("Argument parameter_dict should be a dict, but got {}.".format(type(parameter_dict)))
raise TypeError(msg)

strict_load = Validator.check_bool(strict_load)
logger.info("Execute the process of loading parameters into net.")
net.init_parameters_data()
param_not_load = []
for _, param in net.parameters_and_names():
if param.name in parameter_dict:
new_param = parameter_dict[param.name]
new_param = Parameter(Tensor(new_param.asnumpy(), param.dtype), name=param.name)
if not isinstance(new_param, Parameter):
logger.error("Failed to combine the net and the parameters.")
msg = ("Argument parameter_dict element should be a Parameter, but got {}.".format(type(new_param)))
raise TypeError(msg)
_update_param(param, new_param)
else:
param_not_load.append(param.name)

if param_not_load and not strict_load:
_load_dismatch_prefix_params(net, parameter_dict, param_not_load)

logger.debug("Params not matched(in net but not in parameter_dict):")
for param_name in param_not_load:
logger.debug("%s", param_name)

logger.info("Loading parameters into net is finished.")
if param_not_load:
logger.warning("{} parameters in the net are not loaded.".format(len(param_not_load)))
return param_not_load

```



2、class Dense() of mindspore/nn/layer/basic.py

```
class Dense(Cell):
r"""
The dense connected layer.

Applies dense connected layer for the input. This layer implements the operation as:

.. math::
\text{outputs} = \text{activation}(\text{inputs} * \text{kernel} + \text{bias}),

where :math:`\text{activation}` is the activation function passed as the activation
argument (if passed in), :math:`\text{kernel}` is a weight matrix with the same
data type as the inputs created by the layer, and :math:`\text{bias}` is a bias vector
with the same data type as the inputs created by the layer (only if has_bias is True).

Args:
in_channels (int): The number of channels in the input space.
out_channels (int): The number of channels in the output space.
weight_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable weight_init parameter. The dtype
is same as input x. The values of str refer to the function `initializer`. Default: 'normal'.
bias_init (Union[Tensor, str, Initializer, numbers.Number]): The trainable bias_init parameter. The dtype is
same as input x. The values of str refer to the function `initializer`. Default: 'zeros'.
has_bias (bool): Specifies whether the layer uses a bias vector. Default: True.
activation (Union[str, Cell, Primitive]): activate function applied to the output of the fully connected layer,
eg. 'ReLU'.Default: None.

Inputs:
- **input** (Tensor) - Tensor of shape :math:`(*, in\_channels)`.

Outputs:
Tensor of shape :math:`(*, out\_channels)`.

Raises:
TypeError: If `in_channels` or `out_channels` is not an int.
TypeError: If `has_bias` is not a bool.
TypeError: If `activation` is not one of str, Cell, Primitive, None.
ValueError: If length of shape of `weight_init` is not equal to 2 or shape[0] of `weight_init`
is not equal to `out_channels` or shape[1] of `weight_init` is not equal to `in_channels`.
ValueError: If length of shape of `bias_init` is not equal to 1
or shape[0] of `bias_init` is not equal to `out_channels`.

Supported Platforms:
``Ascend`` ``GPU`` ``CPU``

Examples:
>>> input = Tensor(np.array([[180, 234, 154], [244, 48, 247]]), mindspore.float32)
>>> net = nn.Dense(3, 4)
>>> output = net(input)
>>> print(output.shape)
(2, 4)
"""

@cell_attr_register(attrs=['has_bias', 'activation'])
def __init__(self,
in_channels,
out_channels,
weight_init='normal',
bias_init='zeros',
has_bias=True,
activation=None,
dtype=mstype.float32):
super(Dense, self).__init__()
self.in_channels = Validator.check_positive_int(in_channels)
self.out_channels = Validator.check_positive_int(out_channels)
self.has_bias = Validator.check_bool(has_bias)
self.reshape = P.Reshape()
self.shape_op = P.Shape()


if isinstance(weight_init, Tensor):
if weight_init.ndim != 2 or weight_init.shape[0] != out_channels or \
weight_init.shape[1] != in_channels:
raise ValueError("Weight init shape error.")
self.weight = Parameter(initializer(weight_init, [out_channels, in_channels], dtype), name="weight")

self.bias = None
if self.has_bias:
if isinstance(bias_init, Tensor):
if bias_init.ndim != 1 or bias_init.shape[0] != out_channels:
raise ValueError("Bias init shape error.")
self.bias = Parameter(initializer(bias_init, [out_channels], dtype), name="bias")
self.bias_add = P.BiasAdd()

self.matmul = P.MatMul(transpose_b=True)
self.activation = get_activation(activation) if isinstance(activation, str) else activation
if activation is not None and not isinstance(self.activation, (Cell, Primitive)):
raise TypeError("The activation must be str or Cell or Primitive,"" but got {}.".format(activation))
self.activation_flag = self.activation is not None

def construct(self, x):
x_shape = self.shape_op(x)
check_dense_input_shape(x_shape)
if len(x_shape) != 2:
x = self.reshape(x, (-1, x_shape[-1]))
x = self.matmul(x, self.weight)
if self.has_bias:
x = self.bias_add(x, self.bias)
if self.activation_flag:
x = self.activation(x)
if len(x_shape) != 2:
out_shape = x_shape[:-1] + (-1,)
x = self.reshape(x, out_shape)
return x

def extend_repr(self):
s = 'input_channels={}, output_channels={}'.format(self.in_channels, self.out_channels)
if self.has_bias:
s += ', has_bias={}'.format(self.has_bias)
if self.activation_flag:
s += ', activation={}'.format(self.activation)
return s
```




+ 231
- 0
panguAlpha_pytorch/README-en.md View File

@@ -0,0 +1,231 @@
This is a pytorch implementation of the Pangu alpha model. It can be inferred, trained, and finetune on the pytorch framework.

Starting point: Mindspore is a new deep learning framework that many people have not used, so converting mindspore models to pytorch models will allow more people to use our Pangu models and allow users to not only experience our large models, but also finetune our models.

Megatron is a large, powerful transformer algorithm library developed by NVIDIA's deep learning applications research team. This port is based on Megatron, and the main work includes converting model files, adding query layer, and modifying model slicing strategy.

# Environments

Supports python >= 3.6, pytorch >= 1.5, cuda >= 10, and nccl >= 2.6 versions.

The official NVIDIA docker image `docker pull nvcr.io/nvidia/pytorch:20.03-py3` is recommended. You need to install [NLTK](https://www.nltk.org/install.html).

You can also download the paired image directly at

```bash
docker pull yands/pangu-alpha-megatron-lm-nvidia-pytorch:20.03.2
```
Using`/opt/conda/bin/python`。

# Model File Download

| Model File | Md5 | Size | Parameter Configuration |
| ------------------------------------------------------------ | -------------------------------- | ---- | ------------------------------------------------------------ |
| [Pangu-alpha_2.6B_fp16_mgt.zip](https://git.openi.org.cn/attachments/72aec03d-6bdb-4652-ac2a-8099db4b0bed) | 28f6dd2ec5d1df2fd22ec5f4a66f51e7 | 4.6G | num-layers : 31<br />hidden-size : 2560<br />num-attention-heads : 32 |
| [Pangu-alpha_13B_fp16_mgt.zip](https://git.openi.org.cn/attachments/937b3e2d-98fb-4871-9691-b32afb5a4d79?type=0) | e6f7a05cbdf8ba8d69e6786e48344f6f | 22G | num-layers : 39<br />hidden-size : 5120<br />num-attention-heads : 40 |

**Note:`num-layers` is equal to `num-layers - 1` in [Pangu](https://git.openi.org.cn/PCL-Platform.Intelligence/PanGu-Alpha)**

Model file directory structure.
```txt
Pangu-alpha_2.6B_fp16_mgt #model directory, --load parameter needs to fill in the path
-- iter_0001000 # iteration number directory
--mp_rank_00 # directory for each GPU when the model is parallel
--model_optim_rng.pt #model file
--latest_checkpointed_iteration.txt #file of iterations of ckpt
```
# Accuracy
There are some differences in the results of the `mean` operator between the two frameworks, resulting in inconsistent output of the `LayerNorm` layer, so the generated results are not exactly consistent. Not solved yet, looking for a solution :-).

On the iflytek task, the few-shot accuracy of pytorch's 2.6b_fp16 model is 0.78929, which is 2 points down from the paper's 0.81.
# Inference

###** Want a quick experience? Check out the [3-Minute Tutorial on inference](3-minus-inference.md)! You can white-knuckle the T4 server!!! **

Currently only the inference script for generating text is available, as follows.

Requires the following configuration parameters.

`-out-seq-length`: the maximum number of tokens to generate

`--top_k`: the larger the value of k, the higher the diversity of generated samples. Different k can be tried.

```bash
python tools/generate_samples_Pangu.py \
--model-parallel-size 1 \
--num-layers 31 \
--hidden-size 2560 \
--load /dataset/Pangu-alpha_2.6B_mgt/ \
--num-attention-heads 32 \
--max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer \
--fp16 \
--batch-size 1 \
--seq-length 1024 \
--out-seq-length 50 \
--temperature 1.0 \
--vocab-file megatron/tokenizer/bpe_4w_pcl/vocab \
--num-samples 0 \
--top_k 2 \
--finetune
```
Examples:

k=1
```txt
Input is: 书生:羌笛何须怨杨柳,春风不度玉门关。
飞云:(这诗怎么这么耳熟?且过去跟他聊聊如何。)
书生:小兄弟,要不要一起喝一杯?
飞云:你请我呀?你若是请我,我便和你喝一杯;你若不请我,我便一个人去喝。
书生:小兄弟,看你年纪轻轻,不至于这么势利吧?
飞云:
Output is: 我是个读书人,不势利,我只想和你喝酒。
书生:小兄弟,你是不是有什么心事?

Input is: 张无忌拿出屠龙宝刀,手起刀落,周芷若掉了一颗门牙,身旁的赵敏喜极而泣,
Output is: 张无忌道:“你这是做甚么?”赵敏道:“我要你做我的丈夫。”张无忌道:“我不配。”赵敏道:“你不配,我也不配

Input is: 四川的省会是?
Output is:
成都,简称蓉,别称“锦城”,是四川省省会,也是中国西部地区重要的中心城市,中国西部地区重要的经济中心、科技中心、金融中心和交通枢纽,中国西部地区重要的
```
k=2
```txt
Input is: 书生:羌笛何须怨杨柳,春风不度玉门关。
飞云:(这诗怎么这么耳熟?且过去跟他聊聊如何。)
书生:小兄弟,要不要一起喝一杯?
飞云:你请我呀?你若是请我,我便和你喝一杯;你若不请我,我便一个人去喝。
书生:小兄弟,看你年纪轻轻,不至于这么势利吧?
飞云:
Output is: 你这是在质疑本公子的人格?你若是不请我喝酒我就不去喝。你若不请我喝,那我就一个人走路。你要是不请我喝,那你便
```
k = 3
```txt
Input is: 中国和美国和日本和法国和加拿大和澳大利亚的首都分别是哪里?
Output is: 分别在哪个省的哪个市
中国的北京,美国纽约,加拿大多伦多,日本的大阪,法国的里昂,澳大利亚的墨尔本,新西兰的基督城,澳大利亚首都堪
```


# Finetune
Currently only finetune is provided without changing the model model structure and data format, i.e. continue pre-training.
##### 1. Preparing training data

Refer to [data](#data) section

##### 2. Model cutting

The model downloaded above is a single machine inference model, so you need to cut the model first when finetune is performed, and cut it into model parallel models.

Parameters.

`-model-parallel-size`: the number of slices of the original model, here is 1

`--num-mp-model`: the number of models after slicing

`--mp-model-save`: the path to save the model after slicing

```bash
python tools/split_full_model_into_mp_model.py \
--model-parallel-size 1 \
--num-mp-model 2 \
--num-layers 31 \
--hidden-size 2560 \
--load /**ful model path**/ \
--mp-model-save /**mp model save path**/ \
--num-attention-heads 32 \
--max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer \
--fp16 \
--batch-size 1 \
--seq-length 1024 \
--model-type Pangu \
--vocab-file megatron/tokenizer/bpe_4w_pcl/vocab \
--finetune
```
##### 3. Training

Run the script:

```examples/finetune_pangu_distributed.sh```

##### 4. Model merging

The finished model of finetune is fragmented, so if you want to do single card inference, you need to merge the model first.

Merge script.

`--mp-model-parallel-size`: the number of model slices

`--load`: model save directory

```bash
python tool/merge_mp_partitions.py \
--model-parallel-size 1 \
--mp-model-parallel-size 2 \
--num-layers 31 \
--hidden-size 2560 \
--load /full model ckpt dir/ \
--num-attention-heads 32 \
--max-position-embeddings 1024 \
--tokenizer-type GPT2BPETokenizer \
--fp16 \
--batch-size 1 \
--seq-length 1024 \
--model-type Pangu \
--vocab-file megatron/tokenizer/bpe_4w_pcl/vocab \
--reset-attention-mask \
--finetune \
```



# Training

Reference Script

```bash
examples/pretrain_pangu_distributed_2.6B.sh
```



# Data

##### Generate training data

Reference script: `/tools/preprocess_data_pangu.py`

Store multiple `xxx.txt` files in the train_dataset directory, if there are more training data, it is better to have a uniform file size for each `txt` and separate multiple `txt`s, the size can be 10M a file. If there is traditional text that needs to be converted to simplified, you can use `zhconv`.

The format of each `txt` text is (need blank lines to split different samples)
```txt
sample 1 ***
***
***

sample 2 ***
***
***

sample 2 ***
***
***
```
```bash
python /tools/preprocess_data_pangu.py \
--input /train_dataset/*.txt \
--output-prefix /megatron/dataset/ \
--vocab-file /megatron/tokenizer/bpe_4w_pcl/vocab \
--dataset-impl mmap \
--append-eod
```

The files /path/to/dataset/xxx.idx and /path/to/dataset/xxx.bin will be generated.

Finetune and pre-training require the parameter: `-data-path=/path/to/dataset/xxx`






+ 1
- 1
panguAlpha_pytorch/README.md View File

@@ -186,7 +186,7 @@ python tool/merge_mp_partitions.py \
参考脚本

```bash
examples/pretrain_gpt2_distributed_2.6B.sh
examples/pretrain_pangu_distributed_2.6B.sh
```




+ 8
- 8
panguAlpha_pytorch/examples/finetune_pangu_distributed.sh View File

@@ -1,9 +1,9 @@
#! /bin/bash

export CUDA_VISIBLE_DEVICES="0,1"
#export CUDA_VISIBLE_DEVICES="0,1"
# Runs the "345M" parameter model

GPUS_PER_NODE=2
GPUS_PER_NODE=8
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
@@ -11,17 +11,18 @@ NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA_PATH=/userhome/dataset/megatron/zhanghan/sample_100G_policy_3/Sample100GPolicy3_text_document
#DATA_PATH=/userhome/dataset/megatron/test_vocab4w/text_document
#DATA_PATH=/userhome/dataset/megatron/zhanghan/sample_100G_policy_3/Sample100GPolicy3_text_document
DATA_PATH=/ghome/yands/dataset/megatron/test_vocab4w/text_document
#CHECKPOINT_PATH=/ghome/yands/model/checkPoints/megatron-1.1-pangu
CHECKPOINT_PATH=/userhome/model/checkPoints/megatron-1.1-pangu-2.6B/merged_split
CHECKPOINT_PATH=/root/pangu/model/pangu_fp16_8mp_2b6
#CHECKPOINT_PATH=/userhome/model/checkPoints/megatron-1.1-pangu-2.6B/merged_split
#CHECKPOINT_PATH=/userhome/model/panguAlpha_2.6b_fp16_NumpyCkpt/merged/

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

/opt/conda/bin/python -u -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--model-parallel-size 2 \
--model-parallel-size 8 \
--num-layers 31 \
--hidden-size 2560 \
--num-attention-heads 32 \
@@ -33,7 +34,7 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--vocab-file /userhome/pclproject/gpt/Megatron-LM-1.1-Pangu/megatron/tokenizer/bpe_4w_pcl/vocab \
--vocab-file /userhome/pclproject/gpt/Megatron-LM-1.1/megatron/tokenizer/bpe_4w_pcl/vocab \
--merge-file gpt2-merges.txt \
--data-impl mmap \
--split 949,50,1 \
@@ -51,7 +52,6 @@ DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $
--eval-iters 10 \
--attention-dropout 0.1 \
--hidden-dropout 0.1 \
--fp16 \
--reset-attention-mask \
--finetune



+ 1
- 1
panguAlpha_pytorch/examples/generate_text.sh View File

@@ -1,6 +1,6 @@
#!/bin/bash

export CUDA_VISIBLE_DEVICES="4,5,6,7"
export CUDA_VISIBLE_DEVICES="4"

CHECKPOINT_PATH=/userhome/model/checkPoints/megatron-1.1-pangu-2.6B/merged/



+ 52
- 0
panguAlpha_pytorch/examples/pretrain_pangu_distributed_2.6B.sh View File

@@ -0,0 +1,52 @@
#! /bin/bash

export CUDA_VISIBLE_DEVICES="2,3,4,5"

GPUS_PER_NODE=1
# Change for multinode config
MASTER_ADDR=localhost
MASTER_PORT=6000
NNODES=1
NODE_RANK=0
WORLD_SIZE=$(($GPUS_PER_NODE*$NNODES))

DATA_PATH=/userhome/dataset/megatron/zhanghan/sample_100G_policy_3/Sample100GPolicy3_text_document
CHECKPOINT_PATH=/userhome/model/checkPoints/megatron-1.1-pangu-2.6B/merged

DISTRIBUTED_ARGS="--nproc_per_node $GPUS_PER_NODE --nnodes $NNODES --node_rank $NODE_RANK --master_addr $MASTER_ADDR --master_port $MASTER_PORT"

python -m torch.distributed.launch $DISTRIBUTED_ARGS \
pretrain_gpt2.py \
--model-parallel-size 1 \
--num-layers 31 \
--hidden-size 2560 \
--num-attention-heads 32 \
--batch-size 4 \
--seq-length 1024 \
--max-position-embeddings 1024 \
--train-iters 500000 \
--lr-decay-iters 320000 \
--save $CHECKPOINT_PATH \
--load $CHECKPOINT_PATH \
--data-path $DATA_PATH \
--vocab-file /userhome/pclproject/gpt/Megatron-LM-1.1/megatron/tokenizer/bpe_4w_pcl/vocab \
--merge-file gpt2-merges.txt \
--data-impl mmap \
--split 949,50,1 \
--distributed-backend nccl \
--lr 0.00015 \
--lr-decay-style cosine \
--min-lr 1.0e-5 \
--weight-decay 1e-2 \
--clip-grad 1.0 \
--warmup .01 \
--log-interval 100 \
--save-interval 300 \
--eval-interval 1000 \
--eval-iters 10 \
--reset-attention-mask \
--checkpoint-activations



set +x

+ 1
- 1
panguAlpha_pytorch/megatron/arguments.py View File

@@ -165,7 +165,7 @@ def _add_network_size_args(parser):
group.add_argument('--max-position-embeddings', type=int, default=None,
help='Maximum number of position embeddings to use. '
'This is the size of position embedding.')
group.add_argument('--make-vocab-size-divisible-by', type=int, default=128,
group.add_argument('--make-vocab-size-divisible-by', type=int, default=1,
help='Pad the vocab size to be divisible by this value.'
'This is added for computational efficieny reasons.')
group.add_argument('--layernorm-epsilon', type=float, default=1e-5,


+ 2
- 3
panguAlpha_pytorch/megatron/model/language_model.py View File

@@ -122,6 +122,7 @@ class Embedding(MegatronModule):
self.word_embeddings = mpu.VocabParallelEmbedding(
vocab_size, self.hidden_size, init_method=self.init_method)
self._word_embeddings_key = 'word_embeddings'
self.vocab_size = vocab_size

# Position embedding (serial).
self.position_embeddings = torch.nn.Embedding(
@@ -163,11 +164,9 @@ class Embedding(MegatronModule):
self.init_method(self.tokentype_embeddings.weight)

def forward(self, input_ids, position_ids, tokentype_ids=None):

# Embeddings.
words_embeddings = self.word_embeddings(input_ids)
position_embeddings = self.position_embeddings(position_ids)

embeddings = words_embeddings + position_embeddings
if tokentype_ids is not None:
assert self.tokentype_embeddings is not None
@@ -210,6 +209,7 @@ class Embedding(MegatronModule):
if 'word_embeddings' in key:
state_dict_[key.split('word_embeddings.')[1]] \
= state_dict[key]
state_dict_["weight"] = state_dict_["weight"][:self.vocab_size]
self.word_embeddings.load_state_dict(state_dict_, strict=strict)

# Position embedding.
@@ -447,7 +447,6 @@ class TransformerLanguageModel(MegatronModule):
queryEmbedding_out = self.topQueryEmbedding(position_ids,
tokentype_ids=tokentype_ids)


# Transformer.
transformer_output = self.transformer(embedding_output,
queryEmbedding_out,


+ 8
- 6
panguAlpha_pytorch/megatron/training.py View File

@@ -68,8 +68,7 @@ def pretrain(train_valid_test_dataset_provider, model_provider,
"""

# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron(extra_args_provider=extra_args_provider,
args_defaults=args_defaults)
initialize_megatron(extra_args_provider=extra_args_provider, args_defaults=args_defaults)

args = get_args()
timers = get_timers()
@@ -164,7 +163,7 @@ def get_optimizer(model):
param.model_parallel = False

# Use Adam.
optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay)
optimizer = Adam(param_groups, lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.95), eps=1e-8)

# Wrap into fp16 optimizer.
if args.fp16:
@@ -237,7 +236,10 @@ def backward_step(optimizer, model, loss):

# Backward pass.
timers('backward-backward').start()
optimizer.zero_grad(set_grads_to_None=True)
if args.fp16:
optimizer.zero_grad(set_grads_to_None=True)
else:
optimizer.zero_grad()
if args.fp16:
optimizer.backward(loss, update_master_grads=False)
else:
@@ -421,8 +423,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
torch.distributed.barrier()
time_str = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
rank = torch.distributed.get_rank()
print_rank_0('rank: {} | time: {} | exiting the program at '
'iteration {}'.format(rank, time_str, iteration))
print_rank_0('rank: {} | time: {} | exiting the program at iteration {}'.format(rank, time_str, iteration))
sys.exit()

return iteration, skipped_iters
@@ -532,6 +533,7 @@ def build_train_valid_test_data_iterators(
torch.distributed.broadcast(flags,
mpu.get_model_parallel_src_rank(),
group=mpu.get_model_parallel_group())
args.do_train = flags[0].item()
args.do_valid = flags[1].item()
args.do_test = flags[2].item()


+ 2
- 4
panguAlpha_pytorch/pretrain_gpt2.py View File

@@ -76,8 +76,7 @@ def forward_step(data_iterator, model):

# Get the batch.
timers('batch generator').start()
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(
data_iterator)
tokens, labels, loss_mask, attention_mask, position_ids = get_batch(data_iterator)
timers('batch generator').stop()
# Forward model.
losses = model(tokens, position_ids, attention_mask, labels=labels)
@@ -94,8 +93,7 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
args = get_args()

print_rank_0('> building train, validation, and test datasets '
'for GPT2 ...')
print_rank_0('> building train, validation, and test datasets for GPT2 ...')
train_ds, valid_ds, test_ds = build_train_valid_test_datasets(
data_prefix=args.data_path,
data_impl=args.data_impl,


+ 5
- 7
panguAlpha_pytorch/tools/generate_samples_Pangu.py View File

@@ -140,11 +140,11 @@ def main():
'四川的省会是?',
'上联:春雨润人间,社会和谐万象新',
'''书生:羌笛何须怨杨柳,春风不度玉门关。
飞云:(这诗怎么这么耳熟?且过去跟他聊聊如何。)
书生:小兄弟,要不要一起喝一杯?
飞云:你请我呀?你若是请我,我便和你喝一杯;你若不请我,我便一个人去喝。
书生:小兄弟,看你年纪轻轻,不至于这么势利吧?
飞云:''',
飞云:(这诗怎么这么耳熟?且过去跟他聊聊如何。)
书生:小兄弟,要不要一起喝一杯?
飞云:你请我呀?你若是请我,我便和你喝一杯;你若不请我,我便一个人去喝。
书生:小兄弟,看你年纪轻轻,不至于这么势利吧?
飞云:''',
'张无忌拿出屠龙宝刀,手起刀落,周芷若掉了一颗门牙,身旁的赵敏喜极而泣,',
'人工智能成为国际竞争的新焦点。人工智能是引领未来的战略性技术,世界主要发达国家把发展人工智能作为提升国家竞争力、维护国家安全的重大战略,加紧出台规划和政策,围绕核心技术、顶尖人才、标准规范等强化部署,力图在新一轮国际科技竞争中掌握主导权。当前,',
'中国和美国和日本和法国和加拿大和澳大利亚的首都分别是哪里?']
@@ -174,8 +174,6 @@ def main():
print('Output is:', output_samples[len(sample):], flush=True)
print('@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@@')

return


if __name__ == "__main__":



Loading…
Cancel
Save