#62 add gemma model inference

Open
zoulq wants to merge 1 commit from zoulq/MSAdapterModelZoo-zlq:master into master
  1. README.md (+2 -1)
  2. research/nlp/Gemma/gemma_pytorch/README.md (+53 -0)
  3. research/nlp/Gemma/gemma_pytorch/gemma/__init__.py (+14 -0)
  4. research/nlp/Gemma/gemma_pytorch/gemma/config.py (+83 -0)
  5. research/nlp/Gemma/gemma_pytorch/gemma/model.py (+567 -0)
  6. research/nlp/Gemma/gemma_pytorch/gemma/tokenizer.py (+46 -0)
  7. research/nlp/Gemma/gemma_pytorch/requirements.txt (+2 -0)
  8. research/nlp/Gemma/gemma_pytorch/scripts/run.py (+79 -0)
  9. research/nlp/Gemma/gemma_pytorch/scripts/run_adapter.py (+84 -0)
  10. research/nlp/Gemma/gemma_pytorch/setup.py (+60 -0)
  11. research/nlp/Gemma/gemma_pytorch/tokenizer/tokenizer.model (BIN)

README.md (+2 -1)

@@ -8,13 +8,13 @@
```
MSAdapterModelZoo
├── official # MSAdapter-adapted Torch-code-style model code that can be trained directly;
│ ├── audio # audio models;
│ ├── cv # computer-vision models;
│ ├── nlp # natural-language-processing models;
│ ├── recommend # recommendation models;
├── research # MSAdapter-adapted Torch-code-style model code for inference;
│ ├── cv # computer-vision models;
│ ├── nlp # natural-language-processing models;
│ ├── audio # audio models;
├── example # MSAdapter adaptation samples for different functional scenarios;
│ ├── distributed # distributed parallel training samples;
└── README.md
@@ -124,6 +124,7 @@ MSAdapterModelZoo
| Image | Image classification | vit |
| Image | Image classification | xception |
| Text | Natural language understanding | transformer |
| Text | Natural language understanding | Gemma |
| Audio | Intelligent outbound calling | conformer |

## Release Notes


research/nlp/Gemma/gemma_pytorch/README.md (+53 -0)

@@ -0,0 +1,53 @@
# Gemma in PyTorch

**Gemma** is a family of lightweight, state-of-the-art open models built from the research and technology used to create Google Gemini models. They are text-to-text, decoder-only large language models, available in English, with open weights, pre-trained variants, and instruction-tuned variants. For more details, please check out the following links:

* [Gemma on Google AI](https://ai.google.dev/gemma)
* [Gemma on Kaggle](https://www.kaggle.com/models/google/gemma)
* [Gemma on Vertex AI Model Garden](https://console.cloud.google.com/vertex-ai/publishers/google/model-garden/335)

This is the official PyTorch implementation of Gemma models. We provide model and inference implementations using both PyTorch and PyTorch/XLA, and support running inference on CPU, GPU and TPU.

## Download Gemma model checkpoint

You can find the model checkpoints on Kaggle
[here](https://www.kaggle.com/models/google/gemma/frameworks/pyTorch).

Alternatively, you can find the model checkpoints on the Hugging Face Hub [here](https://huggingface.co/models?other=gemma_torch). To download the models, go to the model repository of the model of interest, click the `Files and versions` tab, and download the model and tokenizer files. For programmatic downloading, if you have `huggingface_hub`
installed, you can also run:

```
huggingface-cli download google/gemma-7b-it-pytorch
```
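
If you prefer the Python API, here is a minimal sketch using `huggingface_hub` directly (the repo id comes from the link above; the local directory name is just an example):

```python
from huggingface_hub import snapshot_download

# Download the model repository (tokenizer + checkpoint) to a local folder.
ckpt_dir = snapshot_download(
    repo_id="google/gemma-7b-it-pytorch",
    local_dir="./gemma-7b-it-pytorch",
)
print(ckpt_dir)
```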

Note that you can choose among the 2B, 7B, and 7B int8 quantized variants.

```
VARIANT=<2b or 7b>
CKPT_PATH=<Insert ckpt path here>
```

## Environment Requirements

Only CPU is supported for now; distributed parallel inference on GPU/Ascend will be supported in the future.


## Model Source

In this case, the PyTorch implementation is sourced from the following repository:

https://github.com/google/gemma_pytorch.git


## Quick Start
This case provides the Gemma model implemented in PyTorch along with a version converted to run via MindTorch. You can start evaluation as follows:
```bash
# Install
python setup.py install

# PyTorch
python scripts/run.py --ckpt [--save_path]

# MindTorch
python scripts/run_adapter.py --ckpt [--save_path]
```
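
For reference, inference can also be driven programmatically; the sketch below mirrors what `scripts/run.py` does for the 2B variant on CPU (the checkpoint path is a placeholder, and it assumes you run from the package root so the bundled tokenizer path resolves):

```python
import torch

from gemma import config
from gemma import model as gemma_model

# Build the 2B config and force float32 for CPU inference, as run.py does.
model_config = config.get_model_config("2b")
model_config.dtype = "float32"

device = torch.device("cpu")
torch.set_default_dtype(model_config.get_dtype())
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights("<path/to/gemma-2b.ckpt>")  # placeholder checkpoint path
model = model.to(device).eval()

print(model.generate("The meaning of life is", device, output_len=50))
```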

research/nlp/Gemma/gemma_pytorch/gemma/__init__.py (+14 -0)

@@ -0,0 +1,14 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


research/nlp/Gemma/gemma_pytorch/gemma/config.py (+83 -0)

@@ -0,0 +1,83 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Gemma model config."""

import dataclasses
import torch
from typing import Optional


# Keep a mapping from dtype strings to the supported torch dtypes.
_STR_DTYPE_TO_TORCH_DTYPE = {
    'float16': torch.float16,
    'float': torch.float32,
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
}


@dataclasses.dataclass
class GemmaConfig:
# The number of tokens in the vocabulary.
vocab_size: int = 256000
# The maximum sequence length that this model might ever be used with.
max_position_embeddings: int = 8192
# The number of blocks in the model.
num_hidden_layers: int = 28
# The number of attention heads used in the attention layers of the model.
num_attention_heads: int = 16
# The number of key-value heads for implementing attention.
num_key_value_heads: int = 16
# The hidden size of the model.
hidden_size: int = 3072
# The dimension of the MLP representations.
intermediate_size: int = 24576
# The number of head dimensions.
head_dim: int = 256
# The epsilon used by the rms normalization layers.
rms_norm_eps: float = 1e-6
# The dtype of the weights.
dtype: str = 'bfloat16'
# Whether a quantized version of the model is used.
quant: bool = False
# The path to the model tokenizer.
tokenizer: Optional[str] = 'tokenizer/tokenizer.model'

def get_dtype(self) -> Optional[torch.dtype]:
"""Gets the torch dtype from the config dtype string."""
return _STR_DTYPE_TO_TORCH_DTYPE.get(self.dtype, None)


def get_config_for_7b() -> GemmaConfig:
return GemmaConfig()


def get_config_for_2b() -> GemmaConfig:
return GemmaConfig(
num_hidden_layers=18,
num_attention_heads=8,
num_key_value_heads=1,
hidden_size=2048,
intermediate_size=16384
)


def get_model_config(variant: str) -> GemmaConfig:
    if variant == '7b':
        return get_config_for_7b()
    elif variant == '2b':
        return get_config_for_2b()
    raise ValueError(f'Invalid variant {variant}. Supported variants are '
                     f'"2b" and "7b".')
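
As a quick sanity check of the helpers above (the values follow directly from the dataclass defaults):

```python
from gemma import config

cfg = config.get_model_config('2b')
print(cfg.num_hidden_layers)    # 18
print(cfg.num_key_value_heads)  # 1 -- multi-query attention for the 2B variant

cfg.dtype = 'float32'
print(cfg.get_dtype())          # torch.float32
```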

research/nlp/Gemma/gemma_pytorch/gemma/model.py (+567 -0)

@@ -0,0 +1,567 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only Gemma model implementation."""

import re
import torch
from torch import nn
import torch.nn.functional as F
from typing import Any, List, Optional, Sequence, Tuple, Union

from gemma import config as gemma_config
from gemma import tokenizer


class Sampler(nn.Module):

def __init__(self, vocab_size: int):
super().__init__()
self.vocab_size = vocab_size

@torch.no_grad()
def forward(
self,
embedding: torch.Tensor,
hidden_states: torch.Tensor,
output_positions: torch.Tensor,
temperatures: Union[torch.Tensor, None],
top_ps: torch.Tensor,
top_ks: torch.Tensor,
embedding_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
# Select the last element for each sequence.
# (batch_size, input_len, hidden_size) -> (batch_size, hidden_size)
hidden_states = hidden_states.index_select(
1, output_positions).squeeze(dim=1)
logits = torch.matmul(hidden_states, embedding.t())
if embedding_bias is not None:
logits += embedding_bias

if temperatures is None:
return torch.argmax(logits, dim=-1).squeeze(dim=-1)

# Apply temperature scaling.
logits.div_(temperatures.unsqueeze(dim=1))

# Calculate probabilities with softmax.
probs = torch.softmax(logits, dim=-1, dtype=torch.float)
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)

# Apply top-p, top-k.
probs_sum = torch.cumsum(probs_sort, dim=-1)
top_ps_mask = (probs_sum - probs_sort) > top_ps.unsqueeze(dim=1)
probs_sort = torch.where(top_ps_mask, 0, probs_sort)

top_ks_mask = torch.arange(probs_idx.shape[-1],
device=probs_idx.device)
top_ks_mask = top_ks_mask.expand(probs_idx.shape[0], -1)
top_ks_mask = top_ks_mask >= top_ks.unsqueeze(dim=1)
probs_sort = torch.where(top_ks_mask, 0, probs_sort)

# Re-normalization.
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
probs = torch.gather(probs_sort,
dim=-1,
index=torch.argsort(probs_idx, dim=-1))

next_token_ids = torch.multinomial(probs,
num_samples=1,
replacement=True).squeeze(dim=-1)
return next_token_ids
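
To see the top-p/top-k filtering in isolation, here is a standalone toy sketch of the same masking scheme applied to a tiny pre-sorted distribution:

```python
import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])  # already sorted descending
top_p, top_k = 0.7, 3

probs_sum = torch.cumsum(probs, dim=-1)
top_p_mask = (probs_sum - probs) > top_p             # nucleus (top-p) cut-off
top_k_mask = torch.arange(probs.shape[-1]) >= top_k  # keep only the k best
filtered = torch.where(top_p_mask | top_k_mask, 0.0, probs)
filtered = filtered / filtered.sum(dim=-1, keepdim=True)
print(filtered)  # tensor([[0.6250, 0.3750, 0.0000, 0.0000]])
```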


def precompute_freqs_cis(dim: int,
end: int,
theta: float = 10000.0) -> torch.Tensor:
"""Precomputes the frequency cis."""
freqs = 1.0 / (theta**(torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
return freqs_cis


def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
"""Applies the rotary embedding to the query and key tensors."""
x_ = torch.view_as_complex(
torch.stack(torch.chunk(x.transpose(1, 2).float(), 2, dim=-1),
dim=-1))
x_out = torch.view_as_real(x_ * freqs_cis).type_as(x)
x_out = torch.cat(torch.chunk(x_out, 2, dim=-1), dim=-2)
x_out = x_out.reshape(x_out.shape[0], x_out.shape[1], x_out.shape[2],
-1).transpose(1, 2)
return x_out
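
A minimal shape check for the two rotary-embedding helpers above (the dimensions are arbitrary):

```python
import torch

head_dim, seq_len = 8, 4
freqs_cis = precompute_freqs_cis(head_dim, seq_len)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4, 4]) torch.complex64

x = torch.randn(1, seq_len, 2, head_dim)     # [batch, seq, num_heads, head_dim]
print(apply_rotary_emb(x, freqs_cis).shape)  # torch.Size([1, 4, 2, 8])
```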


class Linear(nn.Module):

def __init__(self, in_features: int, out_features: int, quant: bool):
super().__init__()
if quant:
self.weight = nn.Parameter(
torch.empty((out_features, in_features), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = nn.Parameter(torch.Tensor(out_features))
else:
self.weight = nn.Parameter(
torch.empty((out_features, in_features)),
requires_grad=False,
)
self.quant = quant

def forward(self, x):
weight = self.weight
if self.quant:
weight = weight * self.weight_scaler.unsqueeze(-1)
output = F.linear(x, weight)
return output


class Embedding(nn.Module):

def __init__(self, num_embeddings: int, embedding_dim: int, quant: bool):
super().__init__()
if quant:
self.weight = nn.Parameter(
torch.empty((num_embeddings, embedding_dim), dtype=torch.int8),
requires_grad=False,
)
self.weight_scaler = nn.Parameter(torch.Tensor(num_embeddings))
else:
self.weight = nn.Parameter(
torch.empty((num_embeddings, embedding_dim)),
requires_grad=False,
)
self.quant = quant

def forward(self, x):
weight = self.weight
if self.quant:
weight = weight * self.weight_scaler.unsqueeze(-1)
output = F.embedding(x, weight)
return output


class RMSNorm(torch.nn.Module):

def __init__(
self,
dim: int,
eps: float = 1e-6,
add_unit_offset: bool = True,
):
super().__init__()
self.eps = eps
self.add_unit_offset = add_unit_offset
self.weight = nn.Parameter(torch.zeros(dim))

def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

def forward(self, x):
x = self._norm(x.float()).type_as(x)
if self.add_unit_offset:
output = x * (1 + self.weight)
else:
output = x * self.weight
return output
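
Note that `weight` is zero-initialized, so with the unit offset a freshly constructed layer applies only the normalization itself; a quick numeric check:

```python
import torch

norm = RMSNorm(dim=4)
x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
print(norm(x).detach())  # tensor([[0.3651, 0.7303, 1.0954, 1.4606]])
# Same as x * rsqrt(mean(x**2) + eps), since (1 + weight) == 1 at init.
print(x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6))
```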


class GemmaMLP(nn.Module):

def __init__(
self,
hidden_size: int,
intermediate_size: int,
quant: bool,
):
super().__init__()
self.gate_proj = Linear(hidden_size, intermediate_size, quant)
self.up_proj = Linear(hidden_size, intermediate_size, quant)
self.down_proj = Linear(intermediate_size, hidden_size, quant)

def forward(self, x):
gate = self.gate_proj(x)
gate = F.gelu(gate, approximate="tanh")
up = self.up_proj(x)
fuse = gate * up
outputs = self.down_proj(fuse)
return outputs
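
The MLP above is a GeGLU-style feed-forward: the gate branch passes through a tanh-approximated GELU and multiplies the up branch elementwise before the down projection. A toy instantiation of the class itself (weights randomly filled, since the parameters are created with `torch.empty`):

```python
import torch

mlp = GemmaMLP(hidden_size=16, intermediate_size=64, quant=False)
for p in mlp.parameters():
    torch.nn.init.normal_(p, std=0.02)  # fill uninitialized weights

x = torch.randn(2, 5, 16)  # [batch, seq, hidden_size]
print(mlp(x).shape)        # torch.Size([2, 5, 16])
```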


class GemmaAttention(nn.Module):

def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
quant: bool,
):
super().__init__()

self.num_heads = num_heads
self.num_kv_heads = num_kv_heads

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.hidden_size = hidden_size
self.head_dim = head_dim

self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim

self.scaling = self.head_dim**-0.5

self.qkv_proj = Linear(
self.hidden_size,
(self.num_heads + 2 * self.num_kv_heads) * self.head_dim,
quant=quant)
self.o_proj = Linear(
self.num_heads * self.head_dim,
self.hidden_size,
quant=quant)

def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
hidden_states_shape = hidden_states.shape
assert len(hidden_states_shape) == 3

batch_size, input_len, _ = hidden_states_shape

qkv = self.qkv_proj(hidden_states)
xq, xk, xv = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)

xq = xq.view(batch_size, -1, self.num_heads, self.head_dim)
xk = xk.view(batch_size, -1, self.num_kv_heads, self.head_dim)
xv = xv.view(batch_size, -1, self.num_kv_heads, self.head_dim)

# Positional embedding.
xq = apply_rotary_emb(xq, freqs_cis=freqs_cis)
xk = apply_rotary_emb(xk, freqs_cis=freqs_cis)

# Write new kv cache.
# [batch_size, input_len, n_local_kv_heads, head_dim]
k_cache, v_cache = kv_cache
k_cache.index_copy_(1, kv_write_indices, xk)
v_cache.index_copy_(1, kv_write_indices, xv)

key = k_cache
value = v_cache
if self.num_kv_heads != self.num_heads:
# [batch_size, max_seq_len, n_local_heads, head_dim]
key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=2)
value = torch.repeat_interleave(value,
self.num_queries_per_kv,
dim=2)

# [batch_size, n_local_heads, input_len, head_dim]
q = xq.transpose(1, 2)
# [batch_size, n_local_heads, max_seq_len, head_dim]
k = key.transpose(1, 2)
v = value.transpose(1, 2)

# [batch_size, n_local_heads, input_len, max_seq_len]
scores = torch.matmul(q, k.transpose(2, 3)) * self.scaling
scores = scores + mask
scores = F.softmax(scores.float(), dim=-1).type_as(q)

# [batch_size, n_local_heads, input_len, head_dim]
output = torch.matmul(scores, v)

# [batch_size, input_len, hidden_dim]
output = (output.transpose(1, 2).contiguous().view(
batch_size, input_len, -1))
output = self.o_proj(output)
return output
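
For the 2B config (8 query heads sharing a single KV head), the grouped-query expansion above behaves as follows, sketched in isolation:

```python
import torch

batch, max_seq_len, num_kv_heads, head_dim = 1, 3, 1, 4
num_queries_per_kv = 8 // num_kv_heads  # num_heads // num_kv_heads

key = torch.randn(batch, max_seq_len, num_kv_heads, head_dim)
key = torch.repeat_interleave(key, num_queries_per_kv, dim=2)
print(key.shape)  # torch.Size([1, 3, 8, 4]) -- one copy per query head
```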


class GemmaDecoderLayer(nn.Module):

def __init__(
self,
config: gemma_config.GemmaConfig,
):
super().__init__()
self.self_attn = GemmaAttention(
hidden_size=config.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
head_dim=config.head_dim,
quant=config.quant,
)
self.mlp = GemmaMLP(
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant=config.quant,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)

def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_cache: Tuple[torch.Tensor, torch.Tensor],
mask: torch.Tensor,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
hidden_states = self.self_attn(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_cache,
mask=mask,
)
hidden_states = residual + hidden_states

# MLP
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states

return hidden_states


class GemmaModel(nn.Module):

def __init__(self, config: gemma_config.GemmaConfig):
super().__init__()
self.config = config
self.vocab_size = config.vocab_size

self.layers = nn.ModuleList()
for _ in range(config.num_hidden_layers):
self.layers.append(GemmaDecoderLayer(config))
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

def forward(
self,
hidden_states: torch.Tensor,
freqs_cis: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
) -> torch.Tensor:
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states = layer(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_cache=kv_caches[i],
mask=mask,
)
hidden_states = self.norm(hidden_states)
return hidden_states


class GemmaForCausalLM(nn.Module):

def __init__(
self,
config: gemma_config.GemmaConfig,
):
super().__init__()
self.config = config
assert config.hidden_size % config.num_attention_heads == 0

max_seq_len = config.max_position_embeddings
head_dim = config.head_dim
vocab_size = config.vocab_size

self.tokenizer = tokenizer.Tokenizer(config.tokenizer)
self.embedder = Embedding(vocab_size, config.hidden_size, config.quant)
self.model = GemmaModel(config)
self.sampler = Sampler(vocab_size)

# Pre-compute rotary embedding table.
rope_theta = getattr(config, 'rope_theta', 10000)
freqs_cis = precompute_freqs_cis(head_dim,
max_seq_len * 2,
theta=rope_theta)
self.register_buffer('freqs_cis', freqs_cis)

@torch.no_grad()
def forward(
self,
input_token_ids: torch.Tensor,
input_positions: torch.Tensor,
kv_write_indices: torch.Tensor,
kv_caches: List[Tuple[torch.Tensor, torch.Tensor]],
mask: torch.Tensor,
output_positions: torch.Tensor,
temperatures: Union[torch.Tensor, None],
top_ps: torch.Tensor,
top_ks: torch.Tensor,
**kwargs,
) -> torch.Tensor:
freqs_cis = self.freqs_cis.index_select(0, input_positions)
kv_write_indices = input_positions

# [batch_size, input_len, hidden_size]
hidden_states = self.embedder(input_token_ids)
# Gemma normalizes the embedding by sqrt(hidden_size).
hidden_states = hidden_states * (self.config.hidden_size**0.5)

hidden_states = self.model(
hidden_states=hidden_states,
freqs_cis=freqs_cis,
kv_write_indices=kv_write_indices,
kv_caches=kv_caches,
mask=mask,
)
embedder_weight = self.embedder.weight
if self.config.quant:
embedder_weight = (
embedder_weight * self.embedder.weight_scaler.unsqueeze(-1))
next_tokens = self.sampler(
embedding=embedder_weight,
hidden_states=hidden_states,
output_positions=output_positions,
temperatures=temperatures,
top_ps=top_ps,
top_ks=top_ks,
)
return next_tokens

def generate(
self,
prompts: Union[str, Sequence[str]],
device: Any,
output_len: int = 100,
temperature: Union[float, None] = 0.95,
top_p: float = 1.0,
top_k: int = 100,
) -> Union[str, Sequence[str]]:
"""Generates responses for given prompts using Gemma model."""
# If a single prompt is provided, treat it as a batch of 1.
is_str_prompt = isinstance(prompts, str)
if is_str_prompt:
prompts = [prompts]

batch_size = len(prompts)
prompt_tokens = [self.tokenizer.encode(prompt) for prompt in prompts]
min_prompt_len = min(len(p) for p in prompt_tokens)
max_prompt_len = max(len(p) for p in prompt_tokens)
max_seq_len = max_prompt_len + output_len
assert max_seq_len <= self.config.max_position_embeddings

# build KV caches
kv_caches = []
for _ in range(self.config.num_hidden_layers):
size = (batch_size, max_seq_len, self.config.num_key_value_heads,
self.config.head_dim)
dtype = self.config.get_dtype()
k_cache = torch.zeros(size=size, dtype=dtype, device=device)
v_cache = torch.zeros(size=size, dtype=dtype, device=device)
kv_caches.append((k_cache, v_cache))

# prepare inputs
token_ids_tensor = torch.full((batch_size, max_seq_len),
self.tokenizer.pad_id, dtype=torch.int64)
input_token_ids_tensor = torch.full((batch_size, min_prompt_len),
self.tokenizer.pad_id,
dtype=torch.int64)
for i, p in enumerate(prompt_tokens):
token_ids_tensor[i, :len(p)] = torch.tensor(p)
input_token_ids_tensor[i, :min_prompt_len] = torch.tensor(
p[:min_prompt_len])
token_ids_tensor = token_ids_tensor.to(device)
input_token_ids_tensor = input_token_ids_tensor.to(device)
prompt_mask_tensor = token_ids_tensor != self.tokenizer.pad_id
input_positions_tensor = torch.arange(0, min_prompt_len,
dtype=torch.int64).to(device)
mask_tensor = torch.full((1, 1, max_seq_len, max_seq_len),
-2.3819763e38).to(torch.float)
mask_tensor = torch.triu(mask_tensor, diagonal=1).to(device)
curr_mask_tensor = mask_tensor.index_select(2, input_positions_tensor)
output_positions_tensor = torch.LongTensor([min_prompt_len - 1]).to(
device)
temperatures_tensor = None if not temperature else torch.FloatTensor(
[temperature] * batch_size).to(device)
top_ps_tensor = torch.FloatTensor([top_p] * batch_size).to(device)
top_ks_tensor = torch.LongTensor([top_k] * batch_size).to(device)
output_index = torch.tensor(min_prompt_len, dtype=torch.int64).to(
device)

# Prefill up to min_prompt_len tokens, then treat other prefill as
# decode and ignore output.
for i in range(max_seq_len - min_prompt_len):
next_token_ids = self(
input_token_ids=input_token_ids_tensor,
input_positions=input_positions_tensor,
kv_write_indices=None,
kv_caches=kv_caches,
mask=curr_mask_tensor,
output_positions=output_positions_tensor,
temperatures=temperatures_tensor,
top_ps=top_ps_tensor,
top_ks=top_ks_tensor,
)

curr_prompt_mask = prompt_mask_tensor.index_select(
1, output_index).squeeze(dim=1)
curr_token_ids = token_ids_tensor.index_select(
1, output_index).squeeze(dim=1)
output_token_ids = torch.where(curr_prompt_mask, curr_token_ids,
next_token_ids).unsqueeze(dim=1)
token_ids_tensor.index_copy_(1, output_index, output_token_ids)

input_token_ids_tensor = output_token_ids
input_positions_tensor = output_index.unsqueeze(dim=-1)
curr_mask_tensor = mask_tensor.index_select(2,
input_positions_tensor)
output_positions_tensor = torch.tensor(0, dtype=torch.int64).to(
device)
output_index = output_index + 1

# Detokenization.
token_ids = token_ids_tensor.tolist()
results = []
for i, tokens in enumerate(token_ids):
trimmed_output = tokens[len(prompt_tokens[i]):len(prompt_tokens[i])
+ output_len]
if self.tokenizer.eos_id in trimmed_output:
eos_index = trimmed_output.index(self.tokenizer.eos_id)
trimmed_output = trimmed_output[:eos_index]
results.append(self.tokenizer.decode(trimmed_output))

# If a string was provided as input, return a string as output.
return results[0] if is_str_prompt else results

def load_weights(self, model_path: str):
self.load_state_dict(
torch.load(
model_path,
# model_path, mmap=True, weights_only=True, # Only for PyTorch new version
)['model_state_dict'],
strict=False,
)
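
The loader above simply reads `torch.load(path)['model_state_dict']`, so a compatible checkpoint is any dict saved under that key; a toy demonstration of the format (the file path here is hypothetical):

```python
import torch

torch.save({'model_state_dict': {'norm.weight': torch.zeros(4)}},
           '/tmp/toy.ckpt')
state = torch.load('/tmp/toy.ckpt')['model_state_dict']
print(list(state.keys()))  # ['norm.weight']
```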

research/nlp/Gemma/gemma_pytorch/gemma/tokenizer.py (+46 -0)

@@ -0,0 +1,46 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import List, Optional

from sentencepiece import SentencePieceProcessor


class Tokenizer:

def __init__(self, model_path: Optional[str]):
# Reload tokenizer.
assert os.path.isfile(model_path), model_path
self.sp_model = SentencePieceProcessor(model_file=model_path)

# BOS / EOS token IDs.
self.n_words: int = self.sp_model.vocab_size()
self.bos_id: int = self.sp_model.bos_id()
self.eos_id: int = self.sp_model.eos_id()
self.pad_id: int = self.sp_model.pad_id()
assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()

def encode(self, s: str, bos: bool = True, eos: bool = False) -> List[int]:
"""Converts a string into a list of tokens."""
assert isinstance(s, str)
t = self.sp_model.encode(s)
if bos:
t = [self.bos_id] + t
if eos:
t = t + [self.eos_id]
return t

def decode(self, t: List[int]) -> str:
"""Converts a list of tokens into a string."""
return self.sp_model.decode(t)
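
A quick round trip with the tokenizer above, assuming the bundled `tokenizer/tokenizer.model` is on the relative path (i.e. run from the package root):

```python
from gemma.tokenizer import Tokenizer

tok = Tokenizer('tokenizer/tokenizer.model')
ids = tok.encode('Hello Gemma')  # BOS id is prepended by default
print(ids[0] == tok.bos_id)      # True
print(tok.decode(ids[1:]))       # 'Hello Gemma'
```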

research/nlp/Gemma/gemma_pytorch/requirements.txt (+2 -0)

@@ -0,0 +1,2 @@
numpy == 1.24.4
sentencepiece == 0.1.99

research/nlp/Gemma/gemma_pytorch/scripts/run.py (+79 -0)

@@ -0,0 +1,79 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import random

import numpy as np
import torch

from gemma import config
from gemma import model as gemma_model


@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype, restoring float even if the body raises."""
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(torch.float)


def main(args):
# Construct the model config.
model_config = config.get_model_config(args.variant)
model_config.dtype = "float32" if args.device == "cpu" else "float16"
model_config.quant = args.quant

# Seed random.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Create the model and load the weights.
device = torch.device(args.device)
with _set_default_tensor_type(model_config.get_dtype()):
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights(args.ckpt)
model = model.to(device).eval()
print("Model loading done")

# Generate the response.
result = model.generate(args.prompt, device, output_len=args.output_len)

# Print the prompts and results.
print('======================================')
print(f'PROMPT: {args.prompt}')
print(f'RESULT: {result}')
print('======================================')


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--variant",
type=str,
default="2b",
choices=["2b", "7b"])
parser.add_argument("--device",
type=str,
default="cpu",
choices=["cpu", "cuda"])
parser.add_argument("--output_len", type=int, default=100)
parser.add_argument("--seed", type=int, default=12345)
parser.add_argument("--quant", action='store_true')
parser.add_argument("--prompt", type=str, default="The meaning of life is")
args = parser.parse_args()

main(args)

research/nlp/Gemma/gemma_pytorch/scripts/run_adapter.py (+84 -0)

@@ -0,0 +1,84 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import contextlib
import random

import numpy as np

# Importing mstorch_enable before `import torch` lets MindTorch take over the
# torch namespace, so the unmodified Gemma code below runs on MindSpore.
from mindtorch.tools import mstorch_enable
import mindspore as ms
ms.set_context(device_target="CPU")

import torch

from gemma import config
from gemma import model as gemma_model


@contextlib.contextmanager
def _set_default_tensor_type(dtype: torch.dtype):
    """Sets the default torch dtype, restoring float even if the body raises."""
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        torch.set_default_dtype(torch.float)


def main(args):
# Construct the model config.
model_config = config.get_model_config(args.variant)
model_config.dtype = "float32" if args.device == "cpu" else "float16"
model_config.quant = args.quant

# Seed random.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)

# Create the model and load the weights.
device = torch.device(args.device)
with _set_default_tensor_type(model_config.get_dtype()):
model = gemma_model.GemmaForCausalLM(model_config)
model.load_weights(args.ckpt)
model = model.to(device).eval()
print("Model loading done")

# Generate the response.
result = model.generate(args.prompt, device, output_len=args.output_len)

# Print the prompts and results.
print('======================================')
print(f'PROMPT: {args.prompt}')
print(f'RESULT: {result}')
print('======================================')


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, required=True)
parser.add_argument("--variant",
type=str,
default="2b",
choices=["2b", "7b"])
parser.add_argument("--device",
type=str,
default="cpu",
choices=["cpu", "cuda"])
parser.add_argument("--output_len", type=int, default=100)
parser.add_argument("--seed", type=int, default=12345)
parser.add_argument("--quant", action='store_true')
parser.add_argument("--prompt", type=str, default="The meaning of life is")
args = parser.parse_args()

main(args)

research/nlp/Gemma/gemma_pytorch/setup.py (+60 -0)

@@ -0,0 +1,60 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import io
import os
from typing import List

import setuptools

ROOT_DIR = os.path.dirname(__file__)


def get_path(*filepath) -> str:
return os.path.join(ROOT_DIR, *filepath)


def read_readme() -> str:
    """Read the README file."""
    with io.open(get_path("README.md"), "r", encoding="utf-8") as f:
        return f.read()


def get_requirements() -> List[str]:
"""Get Python package dependencies from requirements.txt."""
with open(get_path("requirements.txt")) as f:
requirements = f.read().strip().split("\n")
return requirements


setuptools.setup(
name="gemma",
version="0.1",
author="Gemma contributors",
license="Apache 2.0",
description=("Gemma model implementation"),
long_description=read_readme(),
long_description_content_type="text/markdown",
classifiers=[
"Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"License :: OSI Approved :: Apache Software License",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
],
packages=setuptools.find_packages(exclude=("benchmarks", "docs",
"examples", "tests")),
python_requires=">=3.8",
install_requires=get_requirements(),
)

research/nlp/Gemma/gemma_pytorch/tokenizer/tokenizer.model (BIN)

