|
import argparse
import functools

import cv2
import numpy as np
import torch
from flask import Flask, request
from flask_cors import CORS

from dataset import Words
from infer.Backbone import Backbone
from utils import load_config, load_checkpoint
-
# Flask application object; the route decorators in this module register on it.
app = Flask(__name__)
# Enable flask-cors for the whole app with the library's default settings
# (NOTE(review): flask-cors defaults permit any origin — confirm that is intended).
CORS(app)
-
def convert(nodeid, gtd_list):
    """Render the subtree rooted at ``nodeid`` as a flat list of LaTeX tokens.

    Each entry of ``gtd_list`` is read positionally: ``entry[0]`` is the
    symbol, ``entry[1]`` its node id, ``entry[2]`` the id of its parent and
    ``entry[3]`` its relation to that parent. The root symbol itself is looked
    up as ``gtd_list[nodeid]``, so node ids are assumed to equal list positions.

    Children are emitted grouped by relation in a fixed order; children with
    the same relation keep their order from ``gtd_list``:

    * ``\\frac``: 'Above' as ``{..}``, then 'Below' as ``{..}``, then 'Right'
      children inline; every child with any other relation adds one
      ``'illegal'`` token at the end.
    * any other symbol: 'l_sup' as ``[..]``, 'Inside' as ``{..}``,
      'Sub'/'Below' as ``_{..}``, 'Sup'/'Above' as ``^{..}``, then 'Right'
      children inline; other relations are ignored.

    Args:
        nodeid: id of the subtree root to render.
        gtd_list: the whole tree as a flat list of node entries.

    Returns:
        list of str tokens for the subtree.
    """
    # (symbol, child id, relation) for every direct child of ``nodeid``.
    kids = [(node[0], node[1], node[3]) for node in gtd_list if node[2] == nodeid]
    head = gtd_list[nodeid][0]
    if not kids:
        return [head]

    is_fraction = head == '\\frac'
    # Each row: (relations handled, tokens before the subtree, tokens after it).
    if is_fraction:
        layout = (
            (('Above',), ['{'], ['}']),
            (('Below',), ['{'], ['}']),
            (('Right',), [], []),
        )
    else:
        layout = (
            (('l_sup',), ['['], [']']),
            (('Inside',), ['{'], ['}']),
            (('Sub', 'Below'), ['_', '{'], ['}']),
            (('Sup', 'Above'), ['^', '{'], ['}']),
            (('Right',), [], []),
        )

    tokens = [head]
    for relations, opener, closer in layout:
        for _, child_id, relation in kids:
            if relation in relations:
                tokens.extend(opener + convert(child_id, gtd_list) + closer)

    if is_fraction:
        # A fraction only accepts numerator, denominator and right siblings.
        tokens.extend(
            'illegal'
            for _, _, relation in kids
            if relation not in ('Right', 'Above', 'Below')
        )
    return tokens
-
-
@app.route('/')
def hello_world():
    """Root endpoint: reply with a fixed greeting string."""
    greeting = 'Hello, World!'
    return greeting
-
@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the recognition network on first use and reuse it afterwards.

    Building the config, vocabulary, network and checkpoint is expensive.
    It is therefore done once and cached, not on every request. If loading
    raises, lru_cache stores nothing, so the next request retries. Two
    simultaneous first requests may both build the model; only one result
    is kept.

    Returns:
        tuple: ``(net, device)``, where ``net`` is the ``Backbone`` in eval
        mode and ``device`` is the ``torch.device`` it was moved to.
    """
    parser = argparse.ArgumentParser(description='Spatial channel attention')
    # help text means "path to the config file"
    parser.add_argument('--config', default='14.yaml', type=str, help='配置文件路径')
    # parse_known_args: ignore argv that belongs to the launcher (e.g. `flask run`)
    # instead of raising SystemExit from inside a request handler.
    args, _ = parser.parse_known_args()
    params = load_config(args.config)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    params['device'] = device

    words = Words(params['word_path'])
    params['word_num'] = len(words)
    params['struct_num'] = 7
    params['words'] = words
    app.logger.debug('model params: %s', params)

    net = Backbone(params).to(device)
    load_checkpoint(net, None, params['checkpoint'])
    net.eval()
    return net, device


def _decode_image(raw_bytes, device):
    """Decode uploaded image bytes into a normalised grayscale batch tensor.

    Args:
        raw_bytes: encoded image file contents.
        device: torch.device to place the result on.

    Returns:
        A float tensor of shape (1, 1, H, W) with values in [0, 1], or None
        if ``raw_bytes`` is empty or cannot be decoded as an image.
    """
    if not raw_bytes:
        # cv2.imdecode asserts on an empty buffer; treat it as "not an image".
        return None
    img = cv2.imdecode(np.frombuffer(raw_bytes, np.uint8), cv2.IMREAD_COLOR)
    if img is None:
        return None
    if img.ndim == 3:
        # IMREAD_COLOR yields BGR channel order, so BGR2GRAY is the matching
        # conversion (an RGB/RGBA code would swap the red/blue weights).
        img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    tensor = torch.from_numpy(img).float() / 255
    return tensor.unsqueeze(0).unsqueeze(0).to(device)


@app.route('/model', methods=['POST'])
def model():
    """Run the recognition network on an uploaded image and return LaTeX.

    Expects a multipart upload in the form field ``file``.

    Returns:
        A JSON object with three keys:

        * ``latex``: the tokens from ``convert(1, prediction)`` joined by
          spaces.
        * ``ans``: the same string with every backslash replaced by three
          backslashes and wrapped in ``$$...$$``.
        * ``tree``: the raw prediction.

        Responds with status 400 and an ``error`` key when the field is
        missing or is not a decodable image.
    """
    upload = request.files.get('file')
    if upload is None:
        return {'error': "missing form field 'file'"}, 400

    net, device = _load_model()
    image = _decode_image(upload.read(), device)
    if image is None:
        return {'error': 'uploaded file is not a decodable image'}, 400
    image_mask = torch.ones_like(image)

    # Inference only: skip autograd bookkeeping.
    # NOTE(review): meaning of the third argument '1' is defined in Backbone.
    with torch.no_grad():
        prediction = net(image, image_mask, '1')
    app.logger.debug('prediction: %s', prediction)

    latex = ' '.join(convert(1, prediction))
    # Escaping kept as-is; presumably the client expects tripled backslashes.
    ans = '$$' + latex.replace('\\', '\\\\\\') + '$$'
    return {'ans': ans, 'tree': prediction, 'latex': latex}
-
if __name__ == '__main__':
    # Start Flask's built-in server when this file is executed directly.
    app.run()
|