|
- # coding=utf-8
- # Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- """Sample Generate GPT2"""
-
- import os
- import sys
# Directory containing this script.  It is pushed onto sys.path so that
# sibling modules (model_baichuan, model_200B, model_url) and the
# dataset_processor package resolve no matter which working directory
# the script is launched from.
MAIN_DIR = os.path.dirname(
    os.path.abspath(__file__)
)
print(MAIN_DIR)
sys.path.insert(0, MAIN_DIR)
-
- import time
-
def get_model_resp_one_item(url, input_str, tokens_to_generate, top_k=3, logprobs=False):
    """Query one model backend for a single input item.

    *url* is either one of the sentinel strings 'MODEL_BAICHUAN' /
    'MODEL_200B' (served by a local model module) or an HTTP endpoint
    handled by the model_url module.  Imports are deferred to call time
    so only the selected backend's dependencies are loaded.
    """
    if url == 'MODEL_BAICHUAN':
        from model_baichuan import get_local_model_resp_one_item
        return get_local_model_resp_one_item(input_str, tokens_to_generate, top_k, logprobs)
    if url == 'MODEL_200B':
        from model_200B import get_local_model_resp_one_item
        return get_local_model_resp_one_item(input_str, tokens_to_generate, top_k, logprobs)
    from model_url import get_url_model_resp_one_item
    return get_url_model_resp_one_item(url, input_str, tokens_to_generate, top_k, logprobs)
-
def get_model_resp(url, input_str, tokens_to_generate, top_k=3, logprobs=False):
    """Query the model for a single prompt or a batch of prompts.

    Args:
        url: backend selector, forwarded to get_model_resp_one_item.
        input_str: one prompt (str) or a list of prompts.
        tokens_to_generate: number of tokens the backend should produce.
        top_k: sampling cutoff forwarded to the backend.
        logprobs: whether to request token log-probabilities.

    Returns:
        The single response for a str input; a list of responses for a
        list input; or None when any item of a batch came back None
        (e.g. on a non-main distributed rank).

    Raises:
        TypeError: if input_str is neither a str nor a list.
    """
    if isinstance(input_str, str):
        return get_model_resp_one_item(url, input_str, tokens_to_generate, top_k, logprobs)
    if not isinstance(input_str, list):
        # Was a bare `assert`, which is stripped under `python -O` and
        # then fails later with a confusing error; raise explicitly.
        raise TypeError(
            "input_str must be a str or a list, got {}".format(type(input_str).__name__))
    # Query every item before checking for failures (no short-circuit),
    # so each backend call still happens even if an earlier one failed —
    # this matters for distributed backends with collective calls.
    responses = [
        get_model_resp_one_item(url, item, tokens_to_generate, top_k, logprobs)
        for item in input_str
    ]
    if any(resp is None for resp in responses):
        return None
    return responses
-
-
def get_tokenizer(url):
    """Return the tokenizer matching the selected backend.

    The two local-model sentinels ('MODEL_BAICHUAN', 'MODEL_200B') load
    their own tokenizer module; any other *url* gets the URL backend's
    tokenizer.  Imports are deferred so unused backends are never loaded.
    """
    if url == 'MODEL_BAICHUAN':
        from model_baichuan import get_local_tokenizer
        return get_local_tokenizer()
    if url == 'MODEL_200B':
        from model_200B import get_local_tokenizer
        return get_local_tokenizer()
    from model_url import get_url_tokenizer
    return get_url_tokenizer()
-
-
def do_eval(url, task_processor):
    """Run few-shot and zero-shot evaluation of *task_processor* against *url*.

    For each applicable shot setting: iterate the task's eval data, build
    prompts, query the model, score the responses, and save the results.
    A None model response is treated as "not the main rank" (distributed
    backends only return data on rank 0), so scoring/saving is skipped
    for those responses.

    Args:
        url: backend selector forwarded to get_model_resp.
        task_processor: task object providing data iteration, prompt
            construction, response post-processing and result bookkeeping.
    """
    main_rank = False
    # The 200B backend with logprobs consumes (prompt, input_len, mask_len)
    # tuples and returns responses that need no further post-processing.
    model_200b_logprobs = (url == 'MODEL_200B' and task_processor.logprobs)

    print("Eval Task {} start!".format(task_processor.task_name))

    task_processor.init_example_list()
    print("example_list: {}".format(task_processor.example_list))

    for shot in ["few_shot", "zero_shot"]:
        # z_bench has no few-shot setting; samples_fdd has no zero-shot.
        if shot == 'few_shot' and task_processor.task_name in ['z_bench']:
            continue
        if shot == 'zero_shot' and task_processor.task_name in ['samples_fdd']:
            continue
        start_time = time.time()

        eval_data_iter = task_processor.get_eval_data_iter(shot)
        task_processor.init_eval_results()
        for input_items, ground_truth_items in eval_data_iter:
            example_str = task_processor.get_shot_example_str()
            input_str, useful_info = task_processor.get_input_str(input_items, example_str)
            if model_200b_logprobs:
                input_length_list, mask_length_list = useful_info
                input_str = list(zip(input_str, input_length_list, mask_length_list))
            model_resp = get_model_resp(url, input_str, task_processor.tokens_to_generate,
                                        task_processor.top_k, task_processor.logprobs)
            if model_resp is not None:
                main_rank = True
                if not model_200b_logprobs:
                    model_resp = task_processor.process_model_resp(model_resp, useful_info)
                task_processor.evaluate_result(model_resp, ground_truth_items)
        if main_rank:
            task_processor.save_eval_results()

        end_time = time.time()
        print("Eval Task {} {} end! Cost Time: {}".format(
            task_processor.task_name, shot, end_time - start_time))
-
-
-
- if __name__ == "__main__":
- base_path = MAIN_DIR + "/dataset"
- url = sys.argv[1]
- model_token_max_len = int(sys.argv[2])
- task_name = str(sys.argv[3])
- os.environ['TOKEN_VERSION'] = str(sys.argv[4])
- os.environ['URL_VERSION'] = str(sys.argv[5])
- # url = 'http://1.14.200.123:5010/api'
- # model_token_max_len = 2000
- # task_name = 'mmlu'
- # os.environ['TOKEN_VERSION'] = 'v2'
- # os.environ['URL_VERSION'] = 'v1'
- tokenizer = get_tokenizer(url)
- if task_name == 'webqa':
- from dataset_processor.webqa_processor import WebqaProcessor as MainProcessor
- elif task_name == 'cmnli':
- from dataset_processor.cmnli_processor import CmnliProcessor as MainProcessor
- elif task_name == 'c3_m':
- from dataset_processor.c3_m_processor import C3MProcessor as MainProcessor
- elif task_name == 'cmrc':
- from dataset_processor.cmrc_processor import CMRCProcessor as MainProcessor
- elif task_name == 'siqa':
- from dataset_processor.siqa_processor import SIQAProcessor as MainProcessor
- elif task_name == 'sst2':
- from dataset_processor.sst2_processor import SST2Processor as MainProcessor
- elif task_name == 'winogrande':
- from dataset_processor.winogrande_processor import WinoGrandeProcessor as MainProcessor
- elif task_name == 'iflytek':
- from dataset_processor.iflytek_processor import IflytekProcessor as MainProcessor
- elif task_name == 'dureader':
- from dataset_processor.dureader_processor import DuReaderProcessor as MainProcessor
- elif task_name == 'z_bench':
- from dataset_processor.z_bench_processor import ZBenchProcessor as MainProcessor
- elif task_name == 'c_eval':
- from dataset_processor.c_eval_processor import CEvalProcessor as MainProcessor
- elif task_name == 'gaokao':
- from dataset_processor.gaokao_processor import GaoKaoProcessor as MainProcessor
- elif task_name == 'agi_eval':
- from dataset_processor.agi_eval_processor import AGIEvalProcessor as MainProcessor
- elif task_name == 'mmlu':
- from dataset_processor.mmlu_processor import MMLUProcessor as MainProcessor
- elif task_name == 'c_eval_test':
- from dataset_processor.c_eval_test_processor import CEvalTestProcessor as MainProcessor
- elif task_name == 'samples_fdd':
- from dataset_processor.samples_for_data_distribution import SamplesFDPProcessor as MainProcessor
- os.environ['TEST_JSON_LIST_RANK'] = str(sys.argv[6])
-
- task_processor = MainProcessor(task_name, base_path, tokenizer, model_token_max_len)
- if task_processor.logprobs:
- os.environ['MODEL_PROBS'] = 'TRUE'
- else:
- os.environ['MODEL_PROBS'] = 'FALSE'
-
- do_eval(url, task_processor)
|