|
- import os
- from datasets import load_dataset
- from model_url import get_model_resp, get_url_tokenizer
-
-
_RACE_SUBSETS = ('high', 'middle')
_OPTION_LETTERS = ('A', 'B', 'C', 'D')


def _load_race_subset(root_dir, name):
    """Load one RACE subset and spread its ``options`` list into A-D fields."""
    dataset = load_dataset(os.path.join(root_dir, "task_dataset", "race", name))
    print(dataset)

    def preprocess(row):
        # Turn row['options'] into separate 'A'..'D' entries and drop the list.
        for letter, option in zip(_OPTION_LETTERS, row['options']):
            row[letter] = option
        del row['options']
        return row

    return dataset.map(preprocess)


def _mean_option_logprobs(model_resp, input_lengths, prompt_length):
    """Return, per candidate, the mean log-probability of its answer tokens.

    ``model_resp`` is consumed as one sequence of log-probabilities per
    candidate input; each sequence must hold one entry fewer than the
    candidate's token count.
    """
    scores = []
    for logprobs, input_length in zip(model_resp, input_lengths):
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(logprobs) != input_length - 1:
            raise ValueError(
                f"expected {input_length - 1} logprobs, got {len(logprobs)}"
            )
        # Entry i is paired with token i + 1, so the tokens appended after
        # the shared prompt start at index prompt_length - 1.
        answer_logprobs = logprobs[prompt_length - 1:input_length - 1]
        if len(answer_logprobs) == 0:
            # Previously surfaced as an opaque ZeroDivisionError.
            raise ValueError(
                "answer option added no tokens beyond the prompt; "
                "cannot score it"
            )
        scores.append(sum(answer_logprobs) / len(answer_logprobs))
    return scores


def run_predict(url, log_path, few_shot=True):
    """Score the RACE ``high`` and ``middle`` validation splits and log accuracy.

    For every question, four candidate inputs are built (the shared prompt
    followed by ``A``, ``B``, ``C`` or ``D``) and sent to ``get_model_resp``
    with ``tokens_to_generate=0`` and ``logprobs=True``. The option whose
    appended tokens have the highest mean log-probability is the prediction.
    A running accuracy line is printed after every question, and the final
    accuracy of each subset is written to ``race_<subset>.txt``.

    Args:
        url: Forwarded to ``get_model_resp`` as ``url``.
        log_path: Directory in which the per-subset result files are written.
        few_shot: Only selects the "few shot" / "zero shot" label in the
            result file; the prompt itself is identical either way.

    Raises:
        ValueError: If a response does not contain one log-probability per
            input token (minus the first), or if an answer option adds no
            tokens after the prompt.
    """
    import numpy as np

    tokenizer = get_url_tokenizer()
    main_dir = os.path.dirname(os.path.abspath(__file__))
    shot_label = "few shot" if few_shot else "zero shot"

    for name in _RACE_SUBSETS:
        validation = _load_race_subset(main_dir, name)['validation']
        count = 0
        correct_num = 0
        acc = 0  # stays 0 when the validation split is empty
        for info in validation:
            count += 1
            article, question, answer = info['article'], info['question'], info['answer']
            A, B, C, D = info['A'], info['B'], info['C'], info['D']
            # NOTE(review): the prompt starts with a stray single quote. It is
            # kept as-is because changing it would alter benchmark results;
            # confirm whether it is intentional.
            example = f"'Read the article, and answer the question by replying A, B, C or D.\n{article}\n\nQ: {question}\nA. {A}\nB. {B}\nC. {C}\nD. {D}\nAnswer:"

            input_str = [f"{example}{letter}" for letter in _OPTION_LETTERS]
            input_lengths = [len(tokenizer.encode(text)) for text in input_str]
            # The prompt is shared by all four candidates: tokenize it once.
            prompt_length = len(tokenizer.encode(example))

            model_resp = get_model_resp(
                url=url,
                input_str=input_str,
                tokens_to_generate=0,
                top_k=1,
                logprobs=True,
            )
            pred_list = _mean_option_logprobs(model_resp, input_lengths, prompt_length)
            answers_pred = int(np.argmax(pred_list))

            if _OPTION_LETTERS[answers_pred] == answer:
                correct_num += 1
            acc = correct_num / count
            print(f"race-{name}, 准确率Acc:{acc}, number: {count}")

        result_file = os.path.join(log_path, f"race_{name}.txt")
        with open(result_file, 'w', encoding='utf-8') as file:
            file.write(f"race-{name}, {shot_label} , Acc: {acc}, number: {count}")
|