|
- # Copyright 2023 PengChengLab, PCL
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
import json
import os
import threading
import time
import traceback
import urllib
import urllib.parse
from http.server import BaseHTTPRequestHandler, HTTPServer
from socketserver import ThreadingMixIn

from config import logger
from serving_7B import load_model, run_predict_serving, hf_project
from src.generate import generate, generate_increment
from src.pengcheng_mind_config import set_parse
from src.utils import get_args
-
##########################################
# Module-level initialisation: parse CLI args, load the model once, build the
# tokenizer, and issue one inference up front (apparently a warm-up; its
# result `tmp` is never used) so the first real request does not pay the
# model start-up cost.

args_opt = get_args(True)
set_parse(args_opt)
args_opt.distribute = 'false'  # serving runs with distribution disabled
args_opt.pre_trained = 'true'
model_predict, config = load_model(args_opt)
generate_func = generate_increment  # incremental generation variant from src.generate

# Define tokenizer
from transformers import LlamaTokenizer
llama_tokenizer_dir = hf_project("tokenizer/llama_vocab/llama_zh_hf")
print(llama_tokenizer_dir)
print(os.listdir(llama_tokenizer_dir))
llama_vocab_path = os.path.join(llama_tokenizer_dir, 'tokenizer_2.model')
tokenizer = LlamaTokenizer.from_pretrained(llama_vocab_path)
EOT = tokenizer.eos_token_id  # end-of-text token id
PAD = tokenizer.unk_token_id  # NOTE(review): padding uses the UNK id — confirm this is intended

# One throw-away request before the server starts accepting traffic.
data_list_cache = ["你好呀"]
gen_len = 50
top_k = 3
top_p = 1.0
tmp = run_predict_serving(model_predict, generate_func, tokenizer, args_opt, data_list_cache, gen_len, top_k, top_p)
-
class ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
    """HTTPServer variant that handles each request in its own thread."""
-
-
class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
    """HTTP handler that exposes the text-generation model on ``/insert``.

    GET requests carry their parameters in the query string; POST requests
    carry a JSON body.  Both paths are throttled by the number of live
    handler threads and then funnelled into :meth:`_normal_response`.
    """

    # Reject new requests once this many threads are alive.
    MAX_ACTIVE_THREADS = 100

    def _except_response(self):
        """Tell the client the server is overloaded and it should retry later.

        Fix: the original replied 204 (No Content) *with* a body, which is an
        HTTP protocol violation; 429 Too Many Requests carries the same
        message legally.
        """
        self.send_response(429)
        self.send_header("Content-type", "text/html")
        self.end_headers()
        self.wfile.write(
            "request limited, please wait 3 minutes, then send request again".encode()
        )

    def _normal_response(self, path, args):
        """Run model inference for one parsed request and write a JSON reply.

        Parameters
        ----------
        path : str
            Request path; anything other than ``/insert`` (with or without a
            trailing slash) gets a plain 404.
        args : dict
            Parsed parameters; must contain ``prompts`` and
            ``tokens_to_generate``, plus either ``top_p`` or ``top_k``.
        """
        start_time = time.time()
        rsp = {'code': 200}
        if path not in ("/insert", "/insert/"):
            self.send_response(404)
            self.send_header("Content-type", "text/html")
            self.end_headers()
            self.wfile.write("not found".encode())
            return

        try:
            data_list = args['prompts']
            gen_len = args['tokens_to_generate']
            if 'top_p' in args:
                top_p = args['top_p']
                # NOTE(review): using seq_length as top_k is kept from the
                # original but looks suspicious — confirm the intended value.
                top_k = args_opt.seq_length
            else:
                top_k = args['top_k']  # KeyError here falls to the 401 branch
                top_p = 1.0
            # Model inference.
            response = run_predict_serving(
                model_predict, generate_func, tokenizer, args_opt,
                data_list, gen_len, top_k, top_p,
            )
            rsp["text"] = [response]
            rsp["score"] = 0.98
            rsp["code"] = 200
        except Exception as e:
            rsp["code"] = 401
            rsp["text"] = ""
            rsp["score"] = 0.98
            rsp["m"] = '服务器错误:' + str(e) + "\n" + traceback.format_exc()
            logger.error(str(e), exc_info=True)

        rsp["cost_time"] = float(time.time() - start_time)
        try:
            rsp_str = json.dumps(rsp, ensure_ascii=False)
            logger.info(rsp_str)
        except Exception as e:
            rsp = {
                'code': 402,
                'm': '服务器返回数据错误:' + str(e) + "\n" + traceback.format_exc(),
                'response': '',
            }
            rsp_str = json.dumps(rsp, ensure_ascii=False)
            logger.info(rsp_str)

        # The HTTP status is always 200 (kept from the original); clients read
        # the real status from the JSON ``code`` field.
        self.send_response(200)
        self.send_header('Content-type', 'text/json; charset=utf-8')
        self.send_header('Access-Control-Allow-Origin', '*')
        self.end_headers()
        self.wfile.write(rsp_str.encode())

    def get_phrase(self, args):
        """Parse a URL query string into a flat dict (first value per key)."""
        parsed = urllib.parse.parse_qs(args)
        return {k: v[0] for k, v in parsed.items()}

    def post_phrase(self, args):
        """Parse a JSON request body string into a dict."""
        return json.loads(args)

    def do_GET(self):
        """Handle GET: parse the query string, throttle, then dispatch."""
        try:
            # Fix: urllib.parse.splitquery is deprecated and removed from
            # modern Python; urlsplit is the supported replacement.
            parts = urllib.parse.urlsplit(self.path)
            args = self.get_phrase(parts.query)
            thread_num = threading.active_count()
            logger.info(thread_num)
            if thread_num >= self.MAX_ACTIVE_THREADS:
                self._except_response()
            else:
                try:
                    self._normal_response(parts.path, args)
                except Exception:
                    # Fix: original passed the string "404"; the status must
                    # be an int.
                    self.send_response(404)
                    self.send_header('Content-type', 'text/json; charset=utf-8')
                    self.end_headers()
                    self.wfile.write("目前服务故障,请等待".encode())
        except Exception as e:
            self._except_response()
            logger.error(str(e), exc_info=True)

    def do_POST(self):
        """Handle POST: read and parse the JSON body, throttle, dispatch."""
        try:
            body = self.rfile.read(int(self.headers['content-length'])).decode("utf-8")
            args = self.post_phrase(body)
            logger.warning(args)
            logger.warning(self.path)
            thread_num = threading.active_count()
            logger.warning(thread_num)
            if thread_num >= self.MAX_ACTIVE_THREADS:
                self._except_response()
            else:
                try:
                    self._normal_response(self.path, args)
                except Exception:
                    self.send_response(404)
                    self.send_header('Content-type', 'text/json; charset=utf-8')
                    self.end_headers()
                    self.wfile.write("目前服务故障,请等待".encode())
        except Exception as e:
            self._except_response()
            logger.error(str(e), exc_info=True)
-
-
-
def main(host='0.0.0.0', port=6900):
    """Start the threaded HTTP model-serving loop (blocks forever).

    Parameters
    ----------
    host : str
        Interface to bind; defaults to all interfaces.
    port : int
        TCP port to listen on; defaults to 6900 as before.
    """
    web_server = ThreadingSimpleServer((host, port), SimpleHTTPRequestHandler)
    print(">>> start server on: '{}:{}' ".format(host, port))
    try:
        logger.info("Server started http://%s:%s" % (host, port))
        web_server.serve_forever()
    except Exception as e:
        # Top-level boundary: log and fall through so the process can exit.
        logger.error(str(e), exc_info=True)
    finally:
        # Fix: release the listening socket on the way out (this cleanup was
        # commented out and unreachable in the original).
        web_server.server_close()
        logger.info("Server stopped.")


if __name__ == '__main__':
    main()
|