import gc
import json
import os
from typing import Dict, Union, Optional, Tuple

import torch
from torch.nn import Module
from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessor


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # transformer.word_embeddings occupies 1 layer
    # transformer.final_layernorm and lm_head occupy 1 layer
    # transformer.layers occupies 28 layers
    # A total of 30 layers are distributed across num_gpus GPUs
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    # Bugfix: on Linux, torch.embedding is called with weight and input on different
    # devices, which raises a RuntimeError.
    # On Windows, model.device is set to transformer.word_embeddings.device;
    # on Linux, model.device is set to lm_head.device.
    # When chat or stream_chat is called, input_ids are moved to model.device,
    # so if transformer.word_embeddings.device and model.device differ, a RuntimeError occurs.
    # Therefore transformer.word_embeddings, transformer.final_layernorm and lm_head
    # are all placed on the first GPU.
    # This file originates from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py,
    # with minor modifications to support ChatGLM3.
    device_map = {
        'transformer.embedding.word_embeddings': 0,
        'transformer.encoder.final_layernorm': 0,
        'transformer.output_layer': 0,
        'transformer.rotary_pos_emb': 0,
        'lm_head': 0
    }

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'transformer.encoder.layers.{i}'] = gpu_target
        used += 1

    return device_map


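# Illustrative note (not part of the original file): with num_gpus=2 the map above keeps
# the embedding, final layernorm, output layer and rotary embedding on GPU 0 together with
# encoder layers 0-12, while encoder layers 13-27 land on GPU 1, e.g.
#
#     auto_configure_device_map(2)
#     # {'transformer.embedding.word_embeddings': 0, ...,
#     #  'transformer.encoder.layers.12': 0, 'transformer.encoder.layers.13': 1, ...}

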
def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
                       device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
    if num_gpus < 2 and device_map is None:
        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
    else:
        from accelerate import dispatch_model

        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()

        if device_map is None:
            device_map = auto_configure_device_map(num_gpus)

        model = dispatch_model(model, device_map=device_map)

    return model


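# Usage sketch (an assumption, not part of the original file): the checkpoint path
# "THUDM/chatglm3-6b" is illustrative; any local or Hub ChatGLM3 checkpoint works.
#
#     from transformers import AutoTokenizer
#     tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
#     model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
#     model = model.eval()

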
class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        # If the logits contain NaN or Inf, reset them and force a single fixed token
        # so sampling does not crash on an invalid distribution
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


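# Illustrative note (an assumption, not part of the original file): besides being passed to
# stream_generate below, the same processor can be attached to a standard Hugging Face
# generate call, e.g.
#
#     from transformers.generation.logits_process import LogitsProcessorList
#     outputs = model.generate(
#         **inputs, logits_processor=LogitsProcessorList([InvalidScoreLogitsProcessor()])
#     )

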
def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
    content = ""
    for response in output.split("<|assistant|>"):
        metadata, content = response.split("\n", maxsplit=1)
        if not metadata.strip():
            content = content.strip()
            # Replace the "[[训练时间]]" ("training time") placeholder emitted by the model
            # with "2023年" ("the year 2023")
            content = content.replace("[[训练时间]]", "2023年")
        else:
            if use_tool:
                # Drop the first and last lines (the code fences around the tool-call expression)
                content = "\n".join(content.split("\n")[1:-1])

                def tool_call(**kwargs):
                    return kwargs

                # The model emits a Python-style `tool_call(...)` expression; evaluating it
                # against the local tool_call helper recovers the keyword arguments
                parameters = eval(content)
                content = {
                    "name": metadata.strip(),
                    "arguments": json.dumps(parameters, ensure_ascii=False)
                }
            else:
                content = {
                    "name": metadata.strip(),
                    "content": content
                }
    return content


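# Illustrative example (assumed model output format, not part of the original file):
# given a tool-call style completion such as
#
#     output = 'get_weather\n```python\ntool_call(city="Beijing")\n```'
#     process_response(output, use_tool=True)
#     # -> {"name": "get_weather", "arguments": '{"city": "Beijing"}'}
#
# while a plain completion (empty metadata line) is returned as a stripped string.

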
@torch.inference_mode()
def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    messages = params["messages"]
    functions = params["functions"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    max_length = params.get("max_length", None)
    echo = params.get("echo", True)

    messages = process_chatglm_messages(messages, functions=functions)
    query, role = messages[-1]["content"], messages[-1]["role"]

    inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
    inputs = inputs.to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    if input_echo_len >= model.config.seq_length:
        print(f"Input length larger than {model.config.seq_length}")

    if max_length is None:
        max_length = min(max_new_tokens + input_echo_len, model.config.seq_length)

    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
    ]

    gen_kwargs = {
        "max_length": max_length,
        "do_sample": True if temperature > 1e-5 else False,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [InvalidScoreLogitsProcessor()],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = 0
    for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        if echo:
            output_ids = total_ids[:-1]
        else:
            output_ids = total_ids[input_echo_len:-1]

        response = tokenizer.decode(output_ids)
        if response and response[-1] != "�":
            response, stop_found = apply_stopping_strings(response, ["<|observation|>"])

            yield {
                "text": response,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
                "finish_reason": "function_call" if stop_found else None,
            }

            if stop_found:
                break

    # Only the last streamed result carries a finish_reason; set it to "stop" here
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret

    gc.collect()
    torch.cuda.empty_cache()


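# Usage sketch (an assumption, not part of the original file): params["messages"] entries are
# accessed via attributes (m.role, m.content, m.function_call) in process_chatglm_messages,
# so plain dicts will not work; SimpleNamespace stands in for the request message objects here.
#
#     from types import SimpleNamespace
#     params = {
#         "messages": [SimpleNamespace(role="user", content="你好", function_call=None)],
#         "functions": None,
#         "temperature": 0.8,
#         "top_p": 0.8,
#         "max_tokens": 256,
#     }
#     for chunk in generate_stream_chatglm3(model, tokenizer, params):
#         print(chunk["text"])

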
def process_chatglm_messages(messages, functions=None):
    _messages = messages
    messages = []

    if functions:
        messages.append(
            {
                "role": "system",
                "content": "Answer the following questions as best as you can. You have access to the following tools:",
                "tools": functions
            }
        )

    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        if role == "function":
            messages.append(
                {
                    "role": "observation",
                    "content": content
                }
            )

        elif role == "assistant" and func_call is not None:
            for response in content.split("<|assistant|>"):
                metadata, sub_content = response.split("\n", maxsplit=1)
                messages.append(
                    {
                        "role": role,
                        "metadata": metadata,
                        "content": sub_content.strip()
                    }
                )
        else:
            messages.append({"role": role, "content": content})
    return messages


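# Illustrative note (not part of the original file): OpenAI-style "function" results become
# "observation" messages, and assistant messages that carried a function_call are split into
# (metadata, content) pairs so tokenizer.build_chat_input can rebuild the ChatGLM3 history, e.g.
#
#     from types import SimpleNamespace
#     msg = SimpleNamespace(role="function", content='{"temp": 20}', function_call=None)
#     process_chatglm_messages([msg])
#     # -> [{"role": "observation", "content": '{"temp": 20}'}]

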
def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    # Drain the streaming generator and return only the final chunk (finish_reason == "stop")
    for response in generate_stream_chatglm3(model, tokenizer, params):
        pass
    return response


def apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]:
    stop_found = False
    for string in stop_strings:
        idx = reply.find(string)
        if idx != -1:
            reply = reply[:idx]
            stop_found = True
            break

    if not stop_found:
        # If something like "\nYo" is generated just before "\nYou:" is completed, trim it
        for string in stop_strings:
            for j in range(len(string) - 1, 0, -1):
                if reply[-j:] == string[:j]:
                    reply = reply[:-j]
                    break
            else:
                continue

            break

    return reply, stop_found
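

# Illustrative examples (not part of the original file):
#
#     apply_stopping_strings("Hi<|observation|>ignored", ["<|observation|>"])
#     # -> ("Hi", True)       full stop string found and everything after it removed
#     apply_stopping_strings("Hello <|obs", ["<|observation|>"])
#     # -> ("Hello ", False)  trailing partial match trimmed, not treated as a stop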