|
- import os
- import cv2
- import argparse
- import torch
- import json
- from tqdm import tqdm
- import pickle as pkl
- import time
- import Levenshtein as L
-
- from utils import load_config, load_checkpoint
- from infer.Backbone import Backbone
- from dataset import Words
- from utils import compute_edit_distance
-
# ---------------------------------------------------------------------------
# Script-level setup: CLI arguments, config, model, and test data.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser(description='Spatial channel attention')
parser.add_argument('--config', default='14.yaml', type=str, help='配置文件路径')
parser.add_argument('--image_path', default='data/CROHME/19_test_images.pkl', type=str, help='测试image路径')
parser.add_argument('--label_path', default='data/CROHME/19_test_labels.txt', type=str, help='测试label路径')
args = parser.parse_args()

# NOTE(review): --config has a default, so this guard only fires when the flag
# is explicitly passed as an empty string.
if not args.config:
    print('请提供config yaml路径!')
    exit(-1)

# Load the experiment configuration from the yaml file.
params = load_config(args.config)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
params['device'] = device

# Vocabulary: symbol <-> id mapping consumed by the decoder.
words = Words(params['word_path'])
params['word_num'] = len(words)
params['struct_num'] = 7  # number of structural relations -- TODO confirm against dataset
params['words'] = words

model = Backbone(params).to(device)
model_time = 0  # accumulated forward-pass seconds; used for the fps report
load_checkpoint(model, None, params['checkpoint'])  # optimizer slot unused at inference

model.eval()

# Counters: only exp_right and e1/e2/e3 are updated below; the rest are unused here.
word_right, node_right, exp_right, length, cal_num, e1, e2, e3 = 0, 0, 0, 0, 0, 0, 0, 0

# Labels: one "name token token ..." line per sample.
# fix: explicit encoding so non-ASCII labels read identically on every platform.
with open(args.label_path, encoding='utf-8') as f:
    labels = f.readlines()

# SECURITY: pickle.load executes arbitrary code on load -- only use trusted files.
with open(args.image_path, 'rb') as f:
    images = pkl.load(f)
-
def convert(nodeid, gtd_list):
    """Linearize the relation tree rooted at *nodeid* into a list of LaTeX tokens.

    Each entry of ``gtd_list`` is ``[symbol, id, parent_id, relation]``; the
    entry at list position ``nodeid`` is taken as the node itself (positions
    are assumed to match ids).  Children are emitted in a fixed relation
    order so the resulting token sequence is deterministic.
    """
    # Gather (symbol, child_id, relation) for every direct child of this node.
    children = [(e[0], e[1], e[3]) for e in gtd_list if e[2] == nodeid]
    symbol = gtd_list[nodeid][0]

    # Leaf node: emit just its own symbol.
    if not children:
        return [symbol]

    tokens = [symbol]
    if symbol == '\\frac':
        # Numerator first, then denominator, then whatever sits to the right.
        for wanted in ('Above', 'Below'):
            for _, cid, rel in children:
                if rel == wanted:
                    tokens += ['{'] + convert(cid, gtd_list) + ['}']
        for _, cid, rel in children:
            if rel == 'Right':
                tokens += convert(cid, gtd_list)
        # Any other relation under \frac is structurally invalid.
        for _, _, rel in children:
            if rel not in ('Right', 'Above', 'Below'):
                tokens += ['illegal']
    else:
        # Fixed emission order: leading superscript, inside, subscript,
        # superscript, then right-neighbours.
        for _, cid, rel in children:
            if rel == 'l_sup':
                tokens += ['['] + convert(cid, gtd_list) + [']']
        for _, cid, rel in children:
            if rel == 'Inside':
                tokens += ['{'] + convert(cid, gtd_list) + ['}']
        for _, cid, rel in children:
            if rel in ('Sub', 'Below'):
                tokens += ['_', '{'] + convert(cid, gtd_list) + ['}']
        for _, cid, rel in children:
            if rel in ('Sup', 'Above'):
                tokens += ['^', '{'] + convert(cid, gtd_list) + ['}']
        for _, cid, rel in children:
            if rel == 'Right':
                tokens += convert(cid, gtd_list)
    return tokens
-
-
# Run the model over every labelled sample, compute expression-rate metrics,
# and dump the mismatched predictions to bad_case.json.
with torch.no_grad():
    bad_case = {}  # name -> {label, prediction, raw tree} for mismatched samples
    for item in tqdm(labels):
        # Each line: "<image name> <token> <token> ...".
        name, *label = item.split()
        label = ' '.join(label)
        if name.endswith('.jpg'):
            name = name.split('.')[0]

        image = torch.Tensor(images[name]) / 255  # normalize to [0, 1]
        # fix: single device transfer (the original moved `image` twice).
        image = image.unsqueeze(0).unsqueeze(0).to(device)  # (1, 1, H, W)
        image_mask = torch.ones(image.shape, device=device)  # all pixels valid

        start = time.time()
        prediction = model(image, image_mask, name)
        model_time += time.time() - start

        # Node 1 is the root of the predicted tree (node 0 is the start
        # symbol) -- presumably; confirm against the decoder's output format.
        latex_list = convert(1, prediction)
        latex_string = ' '.join(latex_list)
        if latex_string == label.strip():
            exp_right += 1
        else:
            bad_case[name] = {
                'label': label,
                'predi': latex_string,
                'list': prediction
            }
        # Edit-distance tolerance buckets: correct within <=1/<=2/<=3 edits.
        distance = compute_edit_distance(latex_string, label)
        if distance <= 1:
            e1 += 1
        if distance <= 2:
            e2 += 1
        if distance <= 3:
            e3 += 1

    print(f'ExpRate: {exp_right / len(labels)}')
    print(f'e1: {e1 / len(labels)}')
    print(f'e2: {e2 / len(labels)}')
    print(f'e3: {e3 / len(labels)}')
    if model_time > 0:  # fix: avoid ZeroDivisionError on an empty label file
        print(f'fps:{len(labels) / model_time}')
    # fix: explicit utf-8 since ensure_ascii=False writes raw non-ASCII text.
    with open('bad_case.json', 'w', encoding='utf-8') as f:
        # NOTE(review): the 'list' entries may hold non-JSON-serializable values
        # (e.g. tensors); default=str keeps the dump from raising TypeError.
        json.dump(bad_case, f, ensure_ascii=False, default=str)
-
-
-
-
-
-
-
-
-
-
-
-
-
|