# NOTE(review): this file contained an unresolved git merge conflict
# (<<<<<<< HEAD ... >>>>>>> 2f9921df). The HEAD side was an unrelated PIL
# grayscale scratch snippet that clashed with this script (both bound
# `image` at module level and performed I/O on import). Resolved in favour
# of the incoming inference script below; if the PIL snippet is still
# needed it belongs in its own file.

import argparse
import json
import os
import pickle as pkl

import cv2
import torch
from tqdm import tqdm

import Levenshtein as L

from dataset import Words
from infer.Backbone import Backbone
from utils import compute_edit_distance, load_checkpoint, load_config

parser = argparse.ArgumentParser(description='Spatial channel attention')
parser.add_argument('--config', default='14.yaml', type=str, help='配置文件路径')
parser.add_argument('--image_path', default='data/CROHME/14_test_images.pkl', type=str, help='测试image路径')
parser.add_argument('--label_path', default='data/CROHME/14_test_labels.txt', type=str, help='测试label路径')
args = parser.parse_args()

if not args.config:
    print('请提供config yaml路径!')
    exit(-1)

# Load the experiment configuration from the YAML file.
params = load_config(args.config)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
params['device'] = device

# Vocabulary: maps symbols <-> indices; the model needs its size up front.
words = Words(params['word_path'])
params['word_num'] = len(words)
params['struct_num'] = 7
params['words'] = words

model = Backbone(params)
model = model.to(device)

load_checkpoint(model, None, params['checkpoint'])

model.eval()

# Single-image smoke test: load as grayscale, normalize to [0, 1], and add
# batch + channel dims -> shape (1, 1, H, W).
image = cv2.imread('./data/off_image_test/18_em_0_0.bmp')
image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
image = torch.Tensor(image) / 255
image = image.unsqueeze(0).unsqueeze(0)
image_mask = torch.ones(image.shape)
# Move tensor and mask to the target device once (the original moved
# `image` to the device twice).
image, image_mask = image.to(device), image_mask.to(device)

# no_grad: pure inference — skip autograd bookkeeping.
with torch.no_grad():
    prediction = model(image, image_mask, 1)
print(prediction)

# ---------------------------------------------------------------------------
# Disabled full-test-set evaluation harness, kept from the incoming branch.
# It converts the model's tree (GTD) output to a LaTeX token list, compares
# against ground-truth labels, and reports exact-match plus edit-distance
# (<=1/2/3) accuracy, dumping mismatches to bad_case.json.
# ---------------------------------------------------------------------------
#
# word_right, node_right, exp_right, length, cal_num, e1, e2, e3 = 0, 0, 0, 0, 0, 0, 0, 0
#
# with open(args.label_path) as f:
#     labels = f.readlines()
#
# with open(args.image_path, 'rb') as f:
#     images = pkl.load(f)
#
# def convert(nodeid, gtd_list):
#     isparent = False
#     child_list = []
#     for i in range(len(gtd_list)):
#         if gtd_list[i][2] == nodeid:
#             isparent = True
#             child_list.append([gtd_list[i][0], gtd_list[i][1], gtd_list[i][3]])
#     if not isparent:
#         return [gtd_list[nodeid][0]]
#     else:
#         if gtd_list[nodeid][0] == '\\frac':
#             return_string = [gtd_list[nodeid][0]]
#             for i in range(len(child_list)):
#                 if child_list[i][2] == 'Above':
#                     return_string += ['{'] + convert(child_list[i][1], gtd_list) + ['}']
#             for i in range(len(child_list)):
#                 if child_list[i][2] == 'Below':
#                     return_string += ['{'] + convert(child_list[i][1], gtd_list) + ['}']
#             for i in range(len(child_list)):
#                 if child_list[i][2] == 'Right':
#                     return_string += convert(child_list[i][1], gtd_list)
#             for i in range(len(child_list)):
#                 if child_list[i][2] not in ['Right', 'Above', 'Below']:
#                     return_string += ['illegal']
#         else:
#             return_string = [gtd_list[nodeid][0]]
#             for i in range(len(child_list)):
#                 if child_list[i][2] in ['l_sup']:
#                     return_string += ['['] + convert(child_list[i][1], gtd_list) + [']']
#             for i in range(len(child_list)):
#                 if child_list[i][2] == 'Inside':
#                     return_string += ['{'] + convert(child_list[i][1], gtd_list) + ['}']
#             for i in range(len(child_list)):
#                 if child_list[i][2] in ['Sub', 'Below']:
#                     return_string += ['_', '{'] + convert(child_list[i][1], gtd_list) + ['}']
#             for i in range(len(child_list)):
#                 if child_list[i][2] in ['Sup', 'Above']:
#                     return_string += ['^', '{'] + convert(child_list[i][1], gtd_list) + ['}']
#             for i in range(len(child_list)):
#                 if child_list[i][2] in ['Right']:
#                     return_string += convert(child_list[i][1], gtd_list)
#         return return_string
#
#
# with torch.no_grad():
#     bad_case = {}
#     for item in tqdm(labels):
#         name, *label = item.split()
#         label = ' '.join(label)
#         if name.endswith('.jpg'):
#             name = name.split('.')[0]
#         image = images[name]
#         image = torch.Tensor(image) / 255
#         image = image.unsqueeze(0).unsqueeze(0)
#
#         image_mask = torch.ones(image.shape)
#         image, image_mask = image.to(device), image_mask.to(device)
#
#         prediction = model(image, image_mask, name)
#         print(prediction)
#         latex_list = convert(1, prediction)
#         latex_string = ' '.join(latex_list)
#         if latex_string == label.strip():
#             exp_right += 1
#         else:
#             bad_case[name] = {
#                 'label': label,
#                 'predi': latex_string,
#                 'list': prediction
#             }
#         distance = compute_edit_distance(latex_string, label)
#         if distance <= 1:
#             e1 += 1
#         if distance <= 2:
#             e2 += 1
#         if distance <= 3:
#             e3 += 1
#     print(exp_right / len(labels))
#     print(e1 / len(labels))
#     print(e2 / len(labels))
#     print(e3 / len(labels))
#
# with open('bad_case.json', 'w') as f:
#     json.dump(bad_case, f, ensure_ascii=False)