|
- # Copyright 2023 PengChengLab, PCL
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- import traceback
- from http.server import HTTPServer, BaseHTTPRequestHandler
- from socketserver import ThreadingMixIn
- import threading
- import urllib
- import json
- import requests
- from config import logger
- import time
-
- from auto_curl_postprocess_with_token import refresh_user_token
-
# Inference endpoint of the ModelArts online service. The address differs
# between AICC deployments, so edit this value (or pass ``url=``) to match
# your environment.
MODELARTS_API_URL = 'https://d4d470159a804e1ebae8299ce1d2ca8a.infer.syaicenter.com/v1/infers/3925bb8d-8d90-42de-96ec-d189f65b9a4c'

# ``requests`` timeout in seconds as (connect, read). Without a timeout a
# stalled upstream would block the calling worker thread indefinitely.
MODELARTS_API_TIMEOUT = (10, 300)


def modelarts_api_post(input, top_k, top_p, max_generate_length,
                       url=MODELARTS_API_URL, timeout=MODELARTS_API_TIMEOUT):
    """Forward one generation request to the ModelArts inference service.

    Args:
        input: Prompt sent as the ``input`` field of the JSON payload.
        top_k: Forwarded unchanged as ``top_k``.
        top_p: Forwarded unchanged as ``top_p``.
        max_generate_length: Forwarded unchanged as ``max_generate_length``.
        url: Inference endpoint; defaults to ``MODELARTS_API_URL``.
        timeout: ``requests`` timeout, a number of seconds or a
            ``(connect, read)`` tuple; defaults to ``MODELARTS_API_TIMEOUT``.

    Returns:
        dict with ``code`` (200 on success, 401 for any non-200 upstream
        status), ``text`` (``[raw response body]`` on success, ``""``
        otherwise), a constant ``score`` and ``cost_time`` (seconds spent
        in the HTTP call). If that dict cannot be JSON-serialized, a
        ``{'code': 402, 'm': <error>, 'response': ''}`` dict is returned.

    Raises:
        Whatever ``refresh_user_token`` or ``requests.post`` raise (for
        example a ``requests`` timeout or connection error); the caller is
        expected to handle them.
    """
    now_token = refresh_user_token()

    heads = {
        "Content-Type": "application/json",
        "Connection": "keep-alive",
        "X-Auth-Token": "{}".format(now_token),
    }
    post_data = {
        "input": input,
        "top_k": top_k,
        "top_p": top_p,
        "max_generate_length": max_generate_length,
    }

    rsp = {'code': 200}
    start_time = time.time()
    # NOTE(review): verify=False disables TLS certificate verification for
    # this endpoint (kept from the original); confirm whether the service
    # certificate can be verified instead.
    r = requests.post(url, headers=heads, data=json.dumps(post_data),
                      verify=False, timeout=timeout)
    if r.status_code == 200:
        rsp["text"] = [r.text]
    else:
        # 401 is reported for every non-200 upstream status, not only for
        # authentication failures; log the real status to aid debugging.
        logger.warning('ModelArts API returned HTTP %s: %s', r.status_code, r.text)
        rsp["code"] = 401
        rsp["text"] = ""
    # Constant score; it is not derived from the model output.
    rsp["score"] = 0.98
    rsp["cost_time"] = time.time() - start_time
    try:
        rsp_str = json.dumps(rsp, ensure_ascii=False)
        logger.info(rsp_str)
    except Exception as e:
        rsp = {'code': 402, 'm': '服务器返回数据错误:' + str(e) + "\n" + traceback.format_exc(), 'response': ''}

    return rsp
-
-
class ThreadingSimpleServer(ThreadingMixIn, HTTPServer):
    """HTTPServer that handles each request in its own thread.

    ``daemon_threads`` is enabled, as in the stdlib's
    ``http.server.ThreadingHTTPServer``. A request thread that is still
    waiting on the upstream service therefore cannot keep the process alive
    after the server is asked to stop.
    """

    daemon_threads = True
-
-
class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
    """Front-end that forwards ``/insert`` generation requests to ModelArts.

    GET requests carry their arguments in the query string and POST requests
    carry them in a JSON body. Both are routed through ``_dispatch``. It
    refuses work when too many threads are alive, and otherwise calls
    ``_normal_response``.
    """

    # New requests are refused with 503 once this many threads are alive.
    max_active_threads = 100

    def _send(self, code, body, content_type='text/plain; charset=utf-8',
              headers=None):
        """Write a complete response: status line, headers and *body*.

        *body* is a str encoded as UTF-8. ``Content-Length`` is always sent
        so the client knows where the body ends. *headers* is an optional
        mapping of extra header names to values.
        """
        payload = body.encode('utf-8')
        self.send_response(code)
        self.send_header('Content-type', content_type)
        self.send_header('Content-Length', str(len(payload)))
        for name, value in (headers or {}).items():
            self.send_header(name, value)
        self.end_headers()
        self.wfile.write(payload)

    def _except_response(self):
        """Tell the client the server is overloaded and should retry later."""
        # A 204 No Content reply must not carry a body, so 503 is used
        # instead. Retry-After (seconds) matches the 3 minutes in the message.
        self._send(503,
                   "request limited, please wait 3 minutes, then send request again",
                   headers={'Retry-After': '180'})

    def _service_error(self):
        """Report an unexpected failure while producing the response."""
        self._send(500, "目前服务故障,请等待")

    @staticmethod
    def _parse_generation_args(args):
        """Extract ``(prompt, top_k, top_p, tokens_to_generate)`` from *args*.

        Raises:
            KeyError: ``prompts`` or ``tokens_to_generate`` is missing.
            TypeError: *args* is not a mapping.
        """
        prompts = args['prompts']
        gen_len = args['tokens_to_generate']
        top_p = args.get('top_p')
        if top_p is None:
            top_p = 1.0
        # NOTE(review): an absent top_k defaults to 10 but an explicit null
        # defaults to 2048. This is kept as-is; confirm which default is
        # intended.
        top_k = args.get('top_k', 10)
        if top_k is None:
            top_k = 2048
        # The JSON body presumably carries a list of prompts, of which only
        # the first is used. The GET query string yields a single str, which
        # must not be indexed, or only its first character would be sent.
        prompt = prompts if isinstance(prompts, str) else prompts[0]
        return prompt, top_k, top_p, gen_len

    def _normal_response(self, path, args):
        """Serve one generation request on ``/insert`` and write a JSON reply.

        Any other *path* gets a plain 404. Failures while parsing *args* or
        calling the model are reported inside the JSON body (``code`` 401,
        message in ``m``). The HTTP status is 200 either way.
        """
        if path not in ("/insert", "/insert/"):
            self._send(404, "not found")
            return

        start_time = time.time()
        rsp = {'code': 200}
        try:
            prompt, top_k, top_p, gen_len = self._parse_generation_args(args)
            rsp = modelarts_api_post(prompt, top_k, top_p, gen_len)
        except Exception as e:
            rsp["code"] = 401
            rsp["text"] = ""
            rsp["score"] = 0.98
            # NOTE(review): the full traceback is returned to the client;
            # consider keeping it in the server log only.
            rsp["m"] = '服务器错误:' + str(e) + "\n" + traceback.format_exc()
            logger.error(str(e), exc_info=True)

        rsp["cost_time"] = time.time() - start_time
        try:
            rsp_str = json.dumps(rsp, ensure_ascii=False)
            logger.info(rsp_str)
        except Exception as e:
            rsp = {'code': 402, 'm': '服务器返回数据错误:' + str(e) + "\n" + traceback.format_exc(), 'response': ''}
            rsp_str = json.dumps(rsp, ensure_ascii=False)
            logger.info(rsp)

        self._send(200, rsp_str, 'application/json; charset=utf-8',
                   {'Access-Control-Allow-Origin': '*'})

    def _dispatch(self, path, args):
        """Refuse the request when overloaded, otherwise serve it.

        Unexpected errors from ``_normal_response`` are logged and answered
        with a 500.
        """
        thread_num = threading.active_count()
        logger.info(thread_num)
        if thread_num >= self.max_active_threads:
            self._except_response()
            return
        try:
            self._normal_response(path, args)
        except Exception:
            logger.error('failed to handle %s %s', self.command, path, exc_info=True)
            self._service_error()

    def get_phrase(self, args):
        """Parse a query string into a dict, keeping the first value per key."""
        return {k: v[0] for k, v in urllib.parse.parse_qs(args).items()}

    def post_phrase(self, args):
        """Parse a JSON request body (str) into Python objects."""
        return json.loads(args)

    def do_GET(self):
        """Handle GET: the generation arguments come from the query string."""
        # Split on the first '?'; urllib.parse.splitquery is deprecated
        # since Python 3.8.
        path, _, query = self.path.partition('?')
        self._dispatch(path, self.get_phrase(query))

    def do_POST(self):
        """Handle POST: the generation arguments come from a JSON body."""
        try:
            # A missing or negative Content-Length reads nothing, so the
            # empty body is then rejected as invalid JSON below.
            length = max(int(self.headers.get('Content-Length') or 0), 0)
            args = self.post_phrase(self.rfile.read(length).decode("utf-8"))
        except ValueError as e:
            # Covers a non-numeric Content-Length, invalid UTF-8
            # (UnicodeDecodeError) and invalid JSON (JSONDecodeError).
            logger.error(str(e), exc_info=True)
            self._send(400, "bad request: " + str(e))
            return
        logger.info(args)
        logger.info(self.path)
        self._dispatch(self.path, args)
-
-
-
def main(host='0.0.0.0', port=6996):
    """Serve ``SimpleHTTPRequestHandler`` on *host*:*port* until stopped.

    Ctrl+C stops the server cleanly. Other errors from ``serve_forever`` are
    logged. The listening socket is closed on every exit path because the
    server is used as a context manager.
    """
    with ThreadingSimpleServer((host, port), SimpleHTTPRequestHandler) as webServer:
        logger.info("Server started http://%s:%s", host, port)
        try:
            webServer.serve_forever()
        except KeyboardInterrupt:
            pass
        except Exception as e:
            logger.error(str(e), exc_info=True)
    logger.info("Server stopped.")
-
# Start the server only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()
|