import ast
import csv
import json
import os
import re

from model_url import get_model_resp, get_url_tokenizer


def general_postprocess(text: str) -> str:
    # Cut off everything after the first newline, period, or comma
    truncated_text = re.split(r'[\n.,]', text, maxsplit=1)[0]

    # Remove punctuation
    no_punctuation = re.sub(r'[^\w\s]', '', truncated_text)

    # Remove articles (a/an/the)
    no_articles = re.sub(r'\b(a|an|the)\b',
                         '',
                         no_punctuation,
                         flags=re.IGNORECASE)

    # Collapse duplicated blank spaces
    cleaned_text = re.sub(r'\s+', ' ', no_articles).strip()

    return cleaned_text
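# Illustrative example of general_postprocess(), not part of the original
# script: the text is truncated at the first newline, period, or comma,
# punctuation and articles are stripped, and whitespace is collapsed, e.g.
#     general_postprocess('The answer is Paris, France.')  ->  'answer is Paris'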


def score(predictions, references):
    """Compute exact-match accuracy of predictions against candidate answers."""
    if len(predictions) != len(references):
        return {'error': 'predictions and references have different length'}

    processed_predictions = []
    for prediction in predictions:
        # Keep only the first line and drop any leading "answer is" phrase.
        prediction = prediction.strip().split('\n')[0].lower()
        if 'answer is' in prediction:
            prediction = prediction.split('answer is')[-1]
        processed_predictions.append(general_postprocess(prediction))
    processed_answers = [[general_postprocess(ans).lower() for ans in candidates]
                         for candidates in references]

    details = []
    correct_count = 0
    for pred, cand_ans in zip(processed_predictions, processed_answers):
        # A prediction counts as correct if it exactly matches any candidate.
        correct = any(cand == pred for cand in cand_ans)
        correct_count += int(correct)
        details.append({'pred': pred, 'answer': cand_ans, 'correct': correct})
    accuracy = correct_count / len(predictions) * 100

    return {'score': accuracy, 'details': details}
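# Illustrative usage of score(), not part of the original script: a prediction
# in the requested "The answer is ..." format is normalised and matched
# against each candidate answer, e.g.
#     score(['The answer is Paris.'], [['Paris']])
#     ->  {'score': 100.0,
#          'details': [{'pred': 'paris', 'answer': ['paris'], 'correct': True}]}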


def run_predict(url, log_path, few_shot=True):
    """Run zero-shot prediction on the NQ test set and score the results.

    Note: ``few_shot`` is currently unused; only the zero-shot setting is
    implemented.
    """
    MAIN_DIR = os.path.dirname(os.path.abspath(__file__))
    tokenizer = get_url_tokenizer()
    predictions = []
    references = []

    file_dir = os.path.join(MAIN_DIR, 'task_dataset', 'nq', 'nq-test.qa.csv')
    with open(file_dir, 'r', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter='\t')
        raw_data = []
        for row in reader:
            # Each row is a tab-separated (question, answers) pair; the answers
            # field is a string holding a Python list literal.
            assert len(row) == 2
            question = row[0]
            answers = ast.literal_eval(row[1])
            raw_data.append({'question': question, 'answer': answers})

    for data in raw_data:
        question, answer = data['question'], data['answer']
        prompt = ('Answer these questions, your answer should be as simple as '
                  "possible, start your answer with the prompt 'The answer is '."
                  f'\nQ: {question}?')
        model_resp = get_model_resp(url=url,
                                    input_str=prompt,
                                    tokens_to_generate=100,
                                    top_k=1,
                                    logprobs=False)
        # Round-trip the response through the tokenizer (encode then decode)
        # to normalise the generated text.
        model_resp = tokenizer.decode(tokenizer.encode(model_resp))

        predictions.append(model_resp)
        references.append(answer)

    with open(os.path.join(log_path, 'nq_predictions.json'), 'w') as file:
        json.dump(predictions, file)
    with open(os.path.join(log_path, 'nq_references.json'), 'w') as file:
        json.dump(references, file)

    result = score(predictions, references)
    with open(os.path.join(log_path, 'nq_zeroshot.json'), 'w') as file:
        json.dump(result, file)
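

# Minimal invocation sketch, not part of the original script: it assumes the
# serving endpoint URL and an existing output directory are passed on the
# command line; the flag names are illustrative.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser(
        description='Zero-shot evaluation on the NQ test set')
    parser.add_argument('--url', required=True,
                        help='URL of the model serving endpoint')
    parser.add_argument('--log-path', required=True,
                        help='directory for the prediction and score JSON files')
    args = parser.parse_args()

    run_predict(args.url, args.log_path)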