|
- import os
- import sys
- import json
- from model_url import get_model_resp, get_url_tokenizer
- import pandas as pd
-
def remove_boxed(s):
    """Strip the ``\\boxed{...}`` (or ``\\fbox{...}``) wrapper from ``s``.

    Args:
        s: Typically the span returned by ``last_boxed_only_string``; may be
            ``None``.

    Returns:
        The text between the wrapper's opening ``{`` and the final ``}``, or
        ``None`` when ``s`` is not a string that starts with ``\\boxed{`` or
        ``\\fbox{`` and ends with ``}``.
    """
    # Explicit checks instead of ``assert``: asserts are stripped under
    # ``python -O``, which would make this return garbage instead of None.
    if not isinstance(s, str) or not s.endswith('}'):
        return None
    # last_boxed_only_string falls back to \fbox spans, so accept both here.
    for left in ('\\boxed{', '\\fbox{'):
        if s.startswith(left):
            return s[len(left):-1]
    return None
-
-
def last_boxed_only_string(string):
    """Return the last ``\\boxed{...}`` span of ``string``, braces included.

    The last ``\\fbox`` span is used only when ``string`` contains no
    ``\\boxed`` at all. Scanning forward from the command, the span ends at
    the first ``}`` that brings the brace depth back to zero, so nested
    braces inside the answer are kept intact.

    Returns:
        The matched substring, or ``None`` if neither command occurs or its
        braces never balance.
    """
    start = string.rfind('\\boxed')
    if start < 0:
        start = string.rfind('\\fbox')
    if start < 0:
        return None

    depth = 0
    for offset, char in enumerate(string[start:]):
        if char == '{':
            depth += 1
        elif char == '}':
            depth -= 1
            if depth == 0:
                return string[start:start + offset + 1]
    return None
-
-
def score(predictions, references):
    """Score predictions against references using ``is_equiv``.

    Args:
        predictions: Predicted answer strings.
        references: Reference answer strings, aligned with ``predictions``.

    Returns:
        ``{'accuracy': percent_correct, 'details': [...]}`` where each detail
        is a dict with keys ``'pred'``, ``'answer'`` and ``'correct'``.
        Accuracy is ``0.0`` for empty input rather than raising
        ZeroDivisionError. If the two sequences differ in length,
        ``{'error': ...}`` is returned instead.
    """
    if len(predictions) != len(references):
        return {
            'error': 'predictions and references have different '
            'length'
        }
    details = [
        {'pred': pred, 'answer': ref, 'correct': bool(is_equiv(pred, ref))}
        for pred, ref in zip(predictions, references)
    ]
    correct = sum(detail['correct'] for detail in details)
    # Guard the division: an empty evaluation set used to crash here.
    accuracy = 100 * correct / len(details) if details else 0.0
    return {'accuracy': accuracy, 'details': details}
-
def _fix_fracs(string):
    """Brace single-character ``\\frac`` arguments.

    ``\\frac12`` -> ``\\frac{1}{2}``, ``\\frac1b`` -> ``\\frac{1}{b}`` and
    ``\\frac1{72}`` -> ``\\frac{1}{72}``. A ``\\frac`` already followed by
    ``{`` is copied as is (so ``\\frac{72}1`` is not fixed). If any
    ``\\frac`` is followed by fewer than two characters, the whole input is
    returned unchanged.
    """
    head, *tails = string.split('\\frac')
    pieces = [head]
    for tail in tails:
        pieces.append('\\frac')
        if tail.startswith('{'):
            pieces.append(tail)
            continue
        # Explicit check instead of ``assert`` so the guard survives
        # ``python -O``; it also covers an empty tail (input ending in \frac),
        # which would otherwise raise IndexError.
        if len(tail) < 2:
            return string
        numerator, second, rest = tail[0], tail[1], tail[2:]
        if second == '{':
            # \frac1{72}: only the first argument needs braces.
            pieces.append('{' + numerator + '}' + second + rest)
        else:
            pieces.append('{' + numerator + '}{' + second + '}' + rest)
    return ''.join(pieces)
-
def _fix_a_slash_b(string):
    """Rewrite a plain integer ratio ``"a/b"`` as ``"\\frac{a}{b}"``.

    Only strings that are exactly ``int/int`` in canonical form (``"1/2"``,
    ``"-3/4"``) are rewritten. Anything else is returned unchanged, including
    non-numeric sides such as ``"x/2"`` and non-canonical spellings like
    ``"01/2"``.
    """
    parts = string.split('/')
    if len(parts) != 2:
        return string
    try:
        numerator, denominator = int(parts[0]), int(parts[1])
    except ValueError:
        # Not an integer ratio (e.g. "x/2"). Previously this escaped and
        # aborted the caller's whole normalization.
        return string
    # Require the canonical spelling; explicit test rather than ``assert`` so
    # it still applies under ``python -O``.
    if string != '{}/{}'.format(numerator, denominator):
        return string
    return '\\frac{' + str(numerator) + '}{' + str(denominator) + '}'
-
def _remove_right_units(string):
    """Cut off a trailing unit annotation introduced by ``\\text{ ``.

    ``"\\text{ "`` (note the space) only ever occurs -- at least in the val
    set -- when describing units, so everything from it onward is dropped:
    ``"5\\text{ cm}"`` -> ``"5"``. Strings without the marker are returned
    unchanged.

    Raises:
        AssertionError: if the marker occurs more than once. Raised
            explicitly rather than via ``assert`` so the check still runs
            under ``python -O``; the exception type is kept for callers that
            already handle it.
    """
    pieces = string.split('\\text{ ')
    if len(pieces) == 1:
        return string
    if len(pieces) != 2:
        raise AssertionError('expected at most one "\\text{ " unit marker')
    return pieces[0]
-
def _fix_sqrt(string):
    """Brace a single-character ``\\sqrt`` argument: ``\\sqrt3`` -> ``\\sqrt{3}``.

    Whitespace between ``\\sqrt`` and its argument is skipped, as TeX does,
    so ``\\sqrt 3`` also becomes ``\\sqrt{3}``. Occurrences that are already
    followed by ``{``, by an optional index ``[`` (``\\sqrt[3]{8}``), by
    another control sequence, or by nothing at all are left untouched.
    """
    if '\\sqrt' not in string:
        return string
    head, *tails = string.split('\\sqrt')
    pieces = [head]
    for tail in tails:
        argument = tail.lstrip()
        # Empty tails used to raise IndexError; "[" and "\" used to be
        # wrapped in braces, corrupting \sqrt[n]{x} and \sqrt\pi.
        if not argument or argument[0] in '{[\\':
            pieces.append('\\sqrt' + tail)
        else:
            pieces.append('\\sqrt{' + argument[0] + '}' + argument[1:])
    return ''.join(pieces)
-
def _strip_string(string):
    """Normalize a LaTeX answer so that equivalent spellings compare equal.

    Steps, in order:

    1. Textual clean-ups: newlines, ``\\!``, doubled backslashes,
       ``tfrac``/``dfrac``, ``\\left``/``\\right``, degree marks and ``\\$``.
    2. Trailing-unit removal.
    3. Percent removal.
    4. Leading zeros for bare decimals.
    5. Dropping a short ``lhs=`` prefix.
    6. ``\\sqrt`` bracing.
    7. Space removal.
    8. ``\\frac`` bracing.
    9. ``0.5`` -> ``\\frac{1}{2}``.
    10. Integer ``a/b`` -> ``\\frac{a}{b}``.

    Exceptions raised by the helpers (e.g. ``_remove_right_units``) are not
    caught here and propagate to the caller.
    """
    cleanups = (
        ('\n', ''),          # line breaks
        ('\\!', ''),         # negative thin spaces
        ('\\\\', '\\'),      # doubled backslash -> single backslash
        ('tfrac', 'frac'),
        ('dfrac', 'frac'),
        ('\\left', ''),
        ('\\right', ''),
        ('^{\\circ}', ''),   # degree marks
        ('^\\circ', ''),
        ('\\$', ''),         # escaped dollar signs
    )
    for old, new in cleanups:
        string = string.replace(old, new)

    # Units written as a trailing "\text{ ...}".
    string = _remove_right_units(string)

    # Two passes on purpose: deleting one "\%" can splice a preceding
    # backslash onto the next "%" and form a new "\%".
    for _ in range(2):
        string = string.replace('\\%', '')

    # Bare decimals get a leading zero: " .5" -> " 0.5", "{.5" -> "{0.5".
    string = string.replace(' .', ' 0.').replace('{.', '{0.')
    if not string:
        return string
    if string.startswith('.'):
        string = '0' + string

    # Drop a short left-hand side such as "k = " or "x=".
    sides = string.split('=')
    if len(sides) == 2 and len(sides[0]) <= 2:
        string = sides[1]

    # \sqrt3 -> \sqrt{3}, then remove every space.
    string = _fix_sqrt(string).replace(' ', '')

    # \frac12 -> \frac{1}{2}, \frac1{72} -> \frac{1}{72}.
    string = _fix_fracs(string)

    if string == '0.5':
        string = '\\frac{1}{2}'

    # Simple integer ratios "a/b" become \frac{a}{b}.
    return _fix_a_slash_b(string)
-
def is_equiv(str1, str2, verbose=False):
    """Return whether two answers match after ``_strip_string`` normalization.

    Two ``None`` values count as equal (a warning is printed); a single
    ``None`` never matches. If normalization raises, the raw strings are
    compared exactly instead.

    Args:
        str1: First answer string, or ``None``.
        str2: Second answer string, or ``None``.
        verbose: Print both normalized strings before comparing.
    """
    if str1 is None and str2 is None:
        print('WARNING: Both None')
        return True
    if str1 is None or str2 is None:
        return False

    try:
        ss1 = _strip_string(str1)
        ss2 = _strip_string(str2)
    except Exception:
        # Normalization is best-effort. ``Exception`` rather than a bare
        # ``except`` so KeyboardInterrupt / SystemExit are not swallowed.
        return str1 == str2
    if verbose:
        print(ss1, ss2)
    return ss1 == ss2
-
# Four-shot MATH prompt as (question, worked solution) pairs. Written with
# single braces and escaped backslashes: these strings are concatenated
# verbatim (never passed through str.format), so doubled "{{"/"}}" would
# reach the model literally.
_MATH_FEW_SHOT = (
    (
        "Problem:\nFind the domain of the expression $\\frac{\\sqrt{x-2}}{\\sqrt{5-x}}$.\nSolution:",
        "The expressions inside each square root must be non-negative. Therefore, $x-2 \\ge 0$, so $x\\ge2$, and $5 - x \\ge 0$, so $x \\le 5$. Also, the denominator cannot be equal to zero, so $5-x>0$, which gives $x<5$. Therefore, the domain of the expression is $\\boxed{[2,5)}$.\nFinal Answer: The final answer is $[2,5)$. I hope it is correct.\n",
    ),
    (
        "Problem:\nIf $\\det \\mathbf{A} = 2$ and $\\det \\mathbf{B} = 12,$ then find $\\det (\\mathbf{A} \\mathbf{B}).$\nSolution:",
        "We have that $\\det (\\mathbf{A} \\mathbf{B}) = (\\det \\mathbf{A})(\\det \\mathbf{B}) = (2)(12) = \\boxed{24}.$\nFinal Answer: The final answer is $24$. I hope it is correct.\n",
    ),
    (
        "Problem:\nTerrell usually lifts two 20-pound weights 12 times. If he uses two 15-pound weights instead, how many times must Terrell lift them in order to lift the same total weight?\nSolution:",
        "If Terrell lifts two 20-pound weights 12 times, he lifts a total of $2\\cdot 12\\cdot20=480$ pounds of weight. If he lifts two 15-pound weights instead for $n$ times, he will lift a total of $2\\cdot15\\cdot n=30n$ pounds of weight. Equating this to 480 pounds, we can solve for $n$: \\begin{align*} 30n&=480\\\\ \\Rightarrow\\qquad n&=480/30=\\boxed{16} \\end{align*}\nFinal Answer: The final answer is $16$. I hope it is correct.\n",
    ),
    (
        "Problem:\nIf the system of equations: \\begin{align*} 6x-4y&=a,\\\\ 6y-9x &=b. \\end{align*}has a solution $(x, y)$ where $x$ and $y$ are both nonzero, find $\\frac{a}{b},$ assuming $b$ is nonzero.\nSolution:",
        "If we multiply the first equation by $-\\frac{3}{2}$, we obtain $$6y-9x=-\\frac{3}{2}a.$$Since we also know that $6y-9x=b$, we have $$-\\frac{3}{2}a=b\\Rightarrow\\frac{a}{b}=\\boxed{-\\frac{2}{3}}.$$\nFinal Answer: The final answer is $-\\frac{2}{3}$. I hope it is correct.\n",
    ),
)


def _load_math_problems(path):
    """Read the MATH JSON file at ``path`` into ``(problem, answer)`` pairs.

    The file is a JSON object whose values carry ``'problem'`` and
    ``'solution'`` keys. The reference answer is the content of the last
    boxed span of the solution, with surrounding whitespace stripped.
    Records with no extractable boxed answer are skipped (with a printed
    warning) because there is nothing to score them against.
    """
    with open(path, 'r', encoding='utf-8') as fh:
        data = json.load(fh)
    problems = []
    skipped = 0
    for record in data.values():
        answer = remove_boxed(last_boxed_only_string(record['solution']))
        if answer is None:
            skipped += 1
            continue
        problems.append((record['problem'], answer.strip()))
    if skipped:
        print(f'WARNING: skipped {skipped} MATH records without a boxed answer')
    return problems


def _dump_json(path, payload):
    """Write ``payload`` to ``path`` as JSON using UTF-8."""
    with open(path, 'w', encoding='utf-8') as fh:
        json.dump(payload, fh)


def run_predict(url, log_path, few_shot=True, *, tokens_to_generate=100,
                max_seq_len=2048):
    """Run the MATH benchmark against a model endpoint and write the results.

    Args:
        url: Endpoint passed through to ``get_model_resp``.
        log_path: Directory (must already exist) that receives
            ``math_predictions.json``, ``math_references.json`` and the
            score dict in ``math_4shot.json``.
        few_shot: Prefix every problem with the four worked examples.
        tokens_to_generate: Generation budget per problem.
        max_seq_len: Context size. Prompts longer than
            ``max_seq_len - tokens_to_generate`` tokens keep only their tail.

    Returns:
        The dict produced by ``score`` (also written to ``math_4shot.json``).

    Raises:
        ValueError: if ``tokens_to_generate`` leaves no room for the prompt.
    """
    max_prompt_tokens = max_seq_len - tokens_to_generate
    if max_prompt_tokens <= 0:
        raise ValueError('tokens_to_generate must be smaller than max_seq_len')

    tokenizer = get_url_tokenizer()
    main_dir = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(main_dir, 'task_dataset', 'math', 'math.json')

    prefix = ''.join(f'{q}\n{a}' for q, a in _MATH_FEW_SHOT) if few_shot else ''

    predictions = []
    references = []
    for question, answer in _load_math_problems(data_path):
        sample = f'{prefix}Problem:\n{question}\nSolution:\n'
        input_token_ids = tokenizer.encode(sample)
        if len(input_token_ids) > max_prompt_tokens:
            # Keep the tail (the current problem) and drop the start, so that
            # prompt + generation fits within max_seq_len tokens.
            sample = tokenizer.decode(input_token_ids[-max_prompt_tokens:])
        sample = sample.strip()

        output = get_model_resp(url=url, input_str=sample,
                                tokens_to_generate=tokens_to_generate,
                                top_k=1, logprobs=False)
        # Encode/decode round trip, presumably to normalize detokenization
        # -- TODO confirm it is still needed.
        output = tokenizer.decode(tokenizer.encode(output))

        # NOTE(review): the raw model output is scored against the extracted
        # boxed reference answer. Unless get_model_resp already returns only
        # the final answer, it likely needs last_boxed_only_string /
        # remove_boxed first -- TODO confirm what get_model_resp returns.
        predictions.append(output)
        references.append(answer)

    _dump_json(os.path.join(log_path, 'math_predictions.json'), predictions)
    _dump_json(os.path.join(log_path, 'math_references.json'), references)

    result = score(predictions, references)
    _dump_json(os.path.join(log_path, 'math_4shot.json'), result)
    return result
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
|