|
- import os
- import sys
- import json
- import jpype
- import time
- from tqdm import tqdm
-
- from tqdm._tqdm import trange
# JVM option that puts the search jar and its Lucene / IKAnalyzer / MySQL
# dependencies on the Java classpath (paths are relative to the working dir).
jpath = (
    "-Djava.class.path=./search/Search.jar:./search/lib/lucene-core-8.3.1.jar"
    ":./search/lib/lucene-queryparser-8.3.1.jar:./search/lib/IKAnalyzer-5.0.jar:./search/lib/mysql-connector-java-8.0.18.jar"
)

# Start the JVM only if it is not already running in this process.
if not jpype.isJVMStarted():
    jpype.startJVM(
        jpype.getDefaultJVMPath(),
        '-ea',
        jpath,
        convertStrings=False,
    )

# Java class exposing the text index, and the single shared instance that
# the rest of this module queries.
QSearch = jpype.JClass('demo.TextIndex')
qs = QSearch()
-
- '''
- all_not_include_test_evidence_top_10.json.x.json
- right: 1203
- cnt: 2635
- ACC: 0.45654648956356736
- all_evidence_top_10.json.x.json
-
- '''
def score(q, topk=10):
    """Query the text index and return the hits as parallel lists.

    Args:
        q: Query string passed to ``qs.getText``.
        topk: Number of hits requested from the index.

    Returns:
        A 4-tuple ``(scores, text, touch, None)``. The first three entries
        are lists with one element per returned hit, taken from the hit's
        ``getScore()``, ``getText()`` and ``getTouch()`` respectively. The
        fourth entry is always ``None`` (the chapter-content lookup that
        used to fill it is disabled).
    """
    hits = qs.getText(topk, q)

    # Read each hit exactly once, calling the getters in a fixed order.
    rows = [(hit.getText(), hit.getScore(), hit.getTouch()) for hit in hits]

    texts = [row[0] for row in rows]
    hit_scores = [row[1] for row in rows]
    touches = [row[2] for row in rows]
    return hit_scores, texts, touches, None
-
def main(topk=10):
    """Attach retrieved evidence to each test question and report accuracy.

    Reads ``./split_data_mc/test.json`` (one JSON object per line). For every
    question whose first answer is non-empty and that has exactly five
    options, the index is queried once per option with
    ``backgroundText + questionText + option``. The retrieved text for each
    option is stored in ``instance['context']`` (empty string when retrieval
    fails or returns nothing) and the instance is written as one JSON line to
    ``./split_data_mc/test_evidence.json``.

    The option with the highest retrieval score is the prediction; it is
    counted as right when its index equals ``int(answer[0]) - 1``. Totals and
    accuracy are printed at the end.

    Args:
        topk: Kept for interface compatibility. It is currently unused:
            retrieval is fixed at the single best hit per option.
    """
    import traceback

    cnt = 0
    right = 0

    # Explicit UTF-8: the output is written with ensure_ascii=False, so the
    # platform default encoding must not be relied on.
    with open("./split_data_mc/test.json", "r", encoding="utf-8") as f, \
            open("./split_data_mc/test_evidence.json", "w", encoding="utf-8") as nf:
        lines = f.readlines()

        for raw_line in tqdm(lines):
            line = raw_line.strip()
            if not line:
                # Tolerate blank lines instead of failing in json.loads.
                continue

            # Expected fields include: questionType, questionId, questionText,
            # optionImg, questionImg, backgroundText, answer (list),
            # audiourl, subject, option.
            instance = json.loads(line)
            instance['q_s'] = ""
            instance['context_s'] = []
            instance['option_s'] = []

            answer = instance["answer"]
            if not answer or answer[0].strip() == "":
                continue
            if len(instance["option"]) != 5:
                continue

            instance['context'] = []
            q = (instance['backgroundText'] + instance['questionText']).replace("\n", " ")

            # (option index, best retrieval score) for options that got a hit.
            ranked = []
            for i, opt in enumerate(instance['option']):
                try:
                    hit_scores, hit_texts, _, _ = score(q + opt, topk=1)
                except Exception:
                    # Best effort: a failed lookup leaves this option without
                    # evidence but must not abort the whole run.
                    traceback.print_exc()
                    instance['context'].append("")
                    continue

                if not hit_scores:
                    # No hit for this option: keep the context list aligned
                    # with the option list.
                    instance['context'].append("")
                    continue

                ranked.append((i, hit_scores[0]))
                instance['context'].append("######".join(str(t) for t in hit_texts))

            nf.write(json.dumps(instance, ensure_ascii=False) + "\n")

            if not ranked:
                # Nothing retrieved for any option: written out, not scored.
                continue

            # max() returns the first maximal element, matching a stable
            # descending sort followed by taking element 0.
            best_index = max(ranked, key=lambda pair: pair[1])[0]
            if best_index == int(answer[0]) - 1:
                right += 1
                print("right")
            else:
                print("wrong")
            cnt += 1

    print("right: ", right)
    print("cnt: ", cnt)
    # Guard against an empty or fully filtered input file.
    print("ACC: ", 1.0 * right / cnt if cnt else 0.0)
-
if __name__ == "__main__":
    # main()  # full evaluation run is disabled; the block below is a manual check.

    # Manual check: run one fixed query and dump every returned hit.
    # Another query used while debugging: "葡萄糖注射液双歧三联活菌制剂".
    topk = 10
    output = score("混合易发生爆炸的是高锰酸钾与甘油", topk=topk)
    print(output)

    # Iterate over the hits actually returned; the index may yield fewer
    # than ``topk`` results, so indexing by range(topk) could fail.
    for i, (hit_score, hit_text, hit_touch) in enumerate(
            zip(output[0], output[1], output[2])):
        print(i)
        print(hit_score)
        print(hit_text)
        print(hit_touch)
        print()
|