from json import dumps
from os import getenv
from typing import Dict, Union

from fastapi import FastAPI
from fastapi.responses import HTMLResponse, JSONResponse, StreamingResponse
from tiktoken import get_encoding
from uvicorn import run

from utils import *

# Initialize the chat / completion / embedding models; model paths are taken
# from the config file. Pass task="chat" for a chat model, task="completion"
# for a completion model, and task="embedding" for an embedding model.
chat_model: ChatModel = init_llm(task="chat")
completion_model: CompletionModel = init_llm(task="completion")
embedding_model: EmbeddingModel = init_llm(task="embedding")

# Initialize the API service and set the route prefix. On the AI collaboration
# platform the prefix comes from an environment variable; elsewhere it is empty.
prefix = getenv(key="OPENI_GRADIO_URL", default="")  # noqa
app = FastAPI()


def sse(line: Union[str, Dict]) -> str:
    """Format a payload as a Server-Sent Events (SSE) message."""
    return "data: {}\n\n".format(dumps(obj=line, ensure_ascii=False) if isinstance(line, dict) else line)
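
# A minimal sketch of the frames this helper emits (dicts are serialized to
# JSON; strings pass through unchanged):
#
#   sse({"choices": []})  ->  'data: {"choices": []}\n\n'
#   sse("[DONE]")         ->  'data: [DONE]\n\n'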


@app.get(path=prefix + "/", response_class=HTMLResponse)
def homepage():
    """Service homepage."""
    with open(file="templates/Infinity.html", mode="r", encoding="utf-8") as file:
        return file.read()


@app.post(path=prefix + "/v1/chat/completions", response_model=None)
def chat(args: Dict) -> Union[StreamingResponse, Dict]:
    """Chat endpoint."""
    if chat_model is None:
        return {}
    req = ChatRequestSchema().load(args)
    if req["stream"]:
        # Streaming response
        return StreamingResponse(content=chat_stream(req=req), media_type="text/event-stream")
    else:
        # Non-streaming response
        return chat_generate(req=req)
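
# Example request body for this endpoint (a sketch; field names follow the
# OpenAI-style ChatRequestSchema, and "stream" selects SSE vs. one JSON body):
#
#   POST {prefix}/v1/chat/completions
#   {"messages": [{"role": "user", "content": "Hello"}], "stream": true}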


@app.post(path=prefix + "/v1/completions", response_model=None)
def completions(args: Dict) -> Union[StreamingResponse, Dict]:
    """Completions endpoint."""
    if completion_model is None:
        return {}
    req = CompletionsRequestSchema().load(args)
    if req["stream"]:
        # Streaming response
        return StreamingResponse(content=completions_stream(req=req), media_type="text/event-stream")
    else:
        # Non-streaming response
        return completions_generate(req=req)
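
# Example request body (a sketch; "prompt" may be a single string or a list of
# strings, mirroring the OpenAI completions API):
#
#   POST {prefix}/v1/completions
#   {"prompt": "def fibonacci(n):", "stream": false}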


@app.post(path=prefix + "/v1/embeddings", response_class=JSONResponse)
def embeddings(args: Dict) -> Dict:
    """Embeddings endpoint."""
    if embedding_model is None:
        return {}
    req = EmbeddingsRequestSchema().load(args)
    inputs = req["input"] if isinstance(req["input"], list) else [req["input"]]
    encoding = get_encoding(encoding_name="cl100k_base")
    data = [{"index": index, "embedding": embedding_model.embedding(sentence=text)} for index, text in enumerate(inputs)]
    # Usage is approximated with the cl100k_base tokenizer; the serving model
    # may tokenize differently, so these counts are estimates.
    tokens = sum(len(encoding.encode(text=text)) for text in inputs)
    usage = {"prompt_tokens": tokens, "total_tokens": tokens}
    return EmbeddingsResponseSchema().dump({"model": embedding_model.name, "data": data, "usage": usage})
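
# Example exchange (a sketch; the response shape is whatever
# EmbeddingsResponseSchema serializes from the fields above):
#
#   POST {prefix}/v1/embeddings
#   {"input": ["hello", "world"]}
#   -> {"model": "...",
#       "data": [{"index": 0, "embedding": [...]}, {"index": 1, "embedding": [...]}],
#       "usage": {"prompt_tokens": 2, "total_tokens": 2}}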


def chat_generate(req: Dict) -> Dict:
    """Return the model's answer as one non-streaming response."""
    message = ChatMessageSchema().dump({"role": "assistant", "content": chat_model.generate(conversation=req["messages"])})
    choice = ChatChoiceSchema().dump({"index": 0, "message": message, "finish_reason": "stop"})
    return ChatResponseSchema().dump({"model": chat_model.name, "choices": [choice]})


def chat_stream(req: Dict):
    """Stream the model's answer as SSE chunks."""
    delta = ChatMessageSchema().dump({"role": "assistant", "content": ""})
    choice = ChatChoiceChunkSchema().dump({"index": 0, "delta": delta, "finish_reason": None})
    yield sse(line=ChatResponseChunkSchema().dump({"model": chat_model.name, "choices": [choice]}))
    # Multi-turn conversation, streamed character by character. "index" is the
    # choice index, so it stays 0 for the single streamed choice.
    for answer in chat_model.stream(conversation=req["messages"]):
        delta = ChatMessageSchema().dump({"role": "assistant", "content": answer})
        choice = ChatChoiceChunkSchema().dump({"index": 0, "delta": delta, "finish_reason": None})
        yield sse(line=ChatResponseChunkSchema().dump({"model": chat_model.name, "choices": [choice]}))
    choice = ChatChoiceChunkSchema().dump({"index": 0, "delta": {}, "finish_reason": "stop"})
    yield sse(line=ChatResponseChunkSchema().dump({"model": chat_model.name, "choices": [choice]}))
    yield sse(line="[DONE]")
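
# Frames emitted above, in order: an empty delta carrying the role, one delta
# per streamed fragment, an empty delta with finish_reason "stop", and the
# [DONE] sentinel that OpenAI-style clients use to stop reading:
#
#   data: {..., "choices": [{"index": 0, "delta": {"role": "assistant", "content": ""}, ...}]}
#   data: {..., "choices": [{"index": 0, "delta": {"role": "assistant", "content": "Hi"}, ...}]}
#   data: {..., "choices": [{"index": 0, "delta": {}, "finish_reason": "stop"}]}
#   data: [DONE]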


def completions_generate(req: Dict) -> Dict:
    """Return the model's completions as one non-streaming response."""
    prompts = req["prompt"] if isinstance(req["prompt"], list) else [req["prompt"]]
    choices = []
    for index, prompt in enumerate(prompts):
        choice = CompletionsChoiceSchema().dump({
            "index": index,
            "text": completion_model.generate(prefix=prompt),
            "finish_reason": "length"
        })
        choices.append(choice)
    return CompletionsResponseSchema().dump({"model": completion_model.name, "choices": choices})


def completions_stream(req: Dict):
    """Stream the model's completions as SSE chunks."""
    prompts = req["prompt"] if isinstance(req["prompt"], list) else [req["prompt"]]
    choices = []
    for index, prompt in enumerate(prompts):
        choice = CompletionsChoiceSchema().dump({"index": index, "text": "\n\n", "finish_reason": None})
        yield sse(line=CompletionsResponseSchema().dump({"model": completion_model.name, "choices": [choice]}))
        for answer in completion_model.stream(prefix=prompt):
            choice = CompletionsChoiceSchema().dump({"index": index, "text": answer, "finish_reason": None})
            yield sse(line=CompletionsResponseSchema().dump({"model": completion_model.name, "choices": [choice]}))
        choice = CompletionsChoiceSchema().dump({"index": index, "text": "", "finish_reason": "length"})
        yield sse(line=CompletionsResponseSchema().dump({"model": completion_model.name, "choices": [choice]}))
        choices.append(choice)
    yield sse(line=CompletionsResponseSchema().dump({"model": completion_model.name, "choices": choices}))
    yield sse(line="[DONE]")
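
# Per prompt, the stream opens with a "\n\n" chunk, carries one chunk per
# generated fragment, and closes with an empty chunk whose finish_reason is
# "length"; one final frame aggregates the closing choices before [DONE].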


# The AI collaboration platform does not launch the service from the __main__
# block; it imports this module and serves the FastAPI object directly. The
# __main__ block below only matters when the file is run stand-alone.
if __name__ == "__main__":
    run(app=app, host=appHost, port=appPort)
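
# A minimal client-side sketch (hypothetical host/port; any OpenAI-compatible
# HTTP client works against the routes defined above):
#
#   curl -X POST http://127.0.0.1:8000/v1/chat/completions \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello"}], "stream": false}'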