import os
import cv2
import torch
import torch.nn as nn
import math
import copy
from models.densenet import DenseNet
from models.attention import Attention
from models.decoder import PositionEmbeddingSine
from models.counting import CountingDecoder as counting_decoder
from counting_utils import gen_counting_label
from utils import draw_attention_map, draw_counting_map


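# Inference wrapper: a DenseNet encoder, two multi-scale counting decoders, and an
# attention decoder. Returns the decoded symbol sequence and attention maps, plus
# counting MAE/RMSE against labels produced by gen_counting_label.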
class Inference(nn.Module):
    def __init__(self, params=None, draw_map=False):
        super(Inference, self).__init__()
        self.params = params
        self.draw_map = draw_map
        self.use_label_mask = params['use_label_mask']
        self.encoder = DenseNet(params=self.params)
        self.in_channel = params['counting_decoder']['in_channel']
        self.out_channel = params['counting_decoder']['out_channel']
        self.counting_decoder1 = counting_decoder(self.in_channel, self.out_channel, 3)
        self.counting_decoder2 = counting_decoder(self.in_channel, self.out_channel, 5)
        self.device = params['device']
        self.decoder = decoder_dict[params['decoder']['net']](params=self.params)

        # Ratio by which the feature map's height and width shrink relative to the original image after the CNN.
        self.ratio = params['densenet']['ratio']

        with open(params['word_path']) as f:
            words = f.readlines()
            print(f'{len(words)} symbol classes in total.')
        self.words_index_dict = {i: words[i].strip() for i in range(len(words))}
        self.cal_mae = nn.L1Loss(reduction='mean')
        self.cal_mse = nn.MSELoss(reduction='mean')

    def forward(self, images, labels, name, is_train=False):
        cnn_features = self.encoder(images)
        batch_size, _, height, width = cnn_features.shape
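        # Two counting decoders with kernel sizes 3 and 5; averaging their
        # predictions and maps gives a multi-scale symbol-count estimate.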
        counting_preds1, counting_maps1 = self.counting_decoder1(cnn_features, None)
        counting_preds2, counting_maps2 = self.counting_decoder2(cnn_features, None)
        counting_preds = (counting_preds1 + counting_preds2) / 2
        counting_maps = (counting_maps1 + counting_maps2) / 2

        counting_labels = gen_counting_label(labels, self.out_channel, True)
        mae = self.cal_mae(counting_preds, counting_labels).item()
        mse = math.sqrt(self.cal_mse(counting_preds, counting_labels).item())  # root-mean-square error

        word_probs, word_alphas = self.decoder(cnn_features, counting_preds, is_train=is_train)

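        # Optionally dump per-step attention heatmaps and per-class counting
        # heatmaps for visual inspection.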
        if self.draw_map:
            os.makedirs(os.path.join(self.params['attention_map_vis_path'], name), exist_ok=True)
            os.makedirs(os.path.join(self.params['counting_map_vis_path'], name), exist_ok=True)
            for i in range(images.shape[0]):
                img = images[i][0].detach().cpu().numpy() * 255
                # draw one attention map per decoding step
                for step in range(len(word_probs)):
                    word_atten = word_alphas[step][0].detach().cpu().numpy()
                    word_heatmap = draw_attention_map(img, word_atten)
                    cv2.imwrite(os.path.join(self.params['attention_map_vis_path'], name, f'word_{step}.jpg'),
                                word_heatmap)
                # draw one counting map per symbol class
                for idx in range(self.out_channel):
                    counting_map = counting_maps[0].permute(1, 2, 0)[:, :, idx].detach().cpu()
                    counting_heatmap = draw_counting_map(img, counting_map)
                    img_name = 'symbol_' + self.words_index_dict[idx] + '_map.jpg'
                    cv2.imwrite(os.path.join(self.params['counting_map_vis_path'], name, img_name), counting_heatmap)

        return word_probs, word_alphas, mae, mse


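# GRU-based attention decoder with coverage (word_alpha_sum) and a counting
# context vector; at inference it decodes with a width-2 beam search.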
class AttDecoder(nn.Module):
    def __init__(self, params):
        super(AttDecoder, self).__init__()
        self.params = params
        self.input_size = params['decoder']['input_size']
        self.hidden_size = params['decoder']['hidden_size']
        self.out_channel = params['encoder']['out_channel']
        self.attention_dim = params['attention']['attention_dim']
        self.dropout_prob = params['dropout']
        self.device = params['device']
        self.word_num = params['word_num']
        self.ratio = params['densenet']['ratio']

        self.init_weight = nn.Linear(self.out_channel, self.hidden_size)
        self.embedding = nn.Embedding(self.word_num, self.input_size)
        self.word_input_gru = nn.GRUCell(self.input_size, self.hidden_size)
        self.encoder_feature_conv = nn.Conv2d(self.out_channel, self.attention_dim, kernel_size=1)
        self.word_attention = Attention(params)

        self.word_state_weight = nn.Linear(self.hidden_size, self.hidden_size)
        self.word_embedding_weight = nn.Linear(self.input_size, self.hidden_size)
        self.word_context_weight = nn.Linear(self.out_channel, self.hidden_size)
        self.counting_context_weight = nn.Linear(self.word_num, self.hidden_size)
        self.word_convert = nn.Linear(self.hidden_size, self.word_num)

        if params['dropout']:
            self.dropout = nn.Dropout(params['dropout_ratio'])

    def forward(self, cnn_features, counting_preds, is_train=False):
        batch_size, _, height, width = cnn_features.shape
        image_mask = torch.ones((batch_size, 1, height, width)).to(self.device)

        # Project encoder features to the attention dimension and add a sinusoidal position embedding.
        cnn_features_trans = self.encoder_feature_conv(cnn_features)
        position_embedding = PositionEmbeddingSine(256, normalize=True)
        pos = position_embedding(cnn_features_trans, image_mask[:, 0, :, :])
        cnn_features_trans = cnn_features_trans + pos
        word_alpha_sum = torch.zeros((batch_size, 1, height, width)).to(device=self.device)
        hidden = self.init_hidden(cnn_features, image_mask)
        word_embedding = self.embedding(torch.ones([batch_size]).long().to(device=self.device))

        counting_context_weighted = self.counting_context_weight(counting_preds)
        word_probs = []
        word_probs1 = []  # finished (eos-terminated) hypotheses
        word_alphas = []
        i = 0

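        # First decoding step: one GRU step from the start-token embedding, then
        # take the top-2 symbols to seed the two beam hypotheses.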
        hidden = self.word_input_gru(word_embedding, hidden)
        hidden1 = hidden
        word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden,
                                                                           word_alpha_sum, image_mask)

        word_alpha_sum1 = word_alpha_sum

        current_state = self.word_state_weight(hidden)

        word_weighted_embedding = self.word_embedding_weight(word_embedding)
        word_context_weighted = self.word_context_weight(word_context_vec)

        if self.params['dropout']:
            word_out_state = self.dropout(
                current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted)
        else:
            word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted

        word_prob = self.word_convert(word_out_state)
        # Top-2 symbols of the first step; the reshapes assume batch size 1 at inference.
        first_prob, word = word_prob.topk(2, dim=1, largest=True)
        word = word.reshape(2, 1)
        first_prob = first_prob.reshape(2, 1)
        prob0, prob1 = predict(first_prob[0], word_prob), predict(first_prob[1], word_prob)
        word_embedding = self.embedding(word[0])
        word_embedding1 = self.embedding(word[1])
        start, start1 = [word[0]], [word[1]]
        beam0, beam1 = beam(prob0, start, word[0]), beam(prob1, start1, word[1])
        if word[0].item() == 0:  # index 0 is the end-of-sequence symbol
            return word_probs, word_alphas
        word_alphas.append(word_alpha)
        word_probs.append(word[0])
        beam_list = [beam0, beam1]
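        # Width-2 beam search, capped at 199 steps: expand each kept hypothesis
        # with its top-2 continuations, rescore the four candidates, keep the best
        # two, and retire a hypothesis when it emits <eos> (index 0).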
        while i < 199:
            beam0, beam1 = beam_list[0], beam_list[1]
            hidden = self.word_input_gru(word_embedding, hidden)
            hidden1 = self.word_input_gru(word_embedding1, hidden1)
            word_context_vec, word_alpha, word_alpha_sum = self.word_attention(cnn_features, cnn_features_trans, hidden,
                                                                               word_alpha_sum, image_mask)
            word_context_vec1, word_alpha1, word_alpha_sum1 = self.word_attention(cnn_features, cnn_features_trans,
                                                                                  hidden1,
                                                                                  word_alpha_sum1, image_mask)
            current_state = self.word_state_weight(hidden)
            current_state1 = self.word_state_weight(hidden1)
            word_weighted_embedding = self.word_embedding_weight(word_embedding)
            word_weighted_embedding1 = self.word_embedding_weight(word_embedding1)
            word_context_weighted = self.word_context_weight(word_context_vec)
            word_context_weighted1 = self.word_context_weight(word_context_vec1)

            if self.params['dropout']:
                word_out_state = self.dropout(
                    current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted)
                word_out_state1 = self.dropout(
                    current_state1 + word_weighted_embedding1 + word_context_weighted1 + counting_context_weighted)
            else:
                word_out_state = current_state + word_weighted_embedding + word_context_weighted + counting_context_weighted
                word_out_state1 = current_state1 + word_weighted_embedding1 + word_context_weighted1 + counting_context_weighted

            word_prob = self.word_convert(word_out_state)
            word_prob1 = self.word_convert(word_out_state1)
            prob_next1, word = word_prob.topk(2, dim=1, largest=True)
            prob_next2, word1 = word_prob1.topk(2, dim=1, largest=True)
            word = word.reshape(2, 1)
            word1 = word1.reshape(2, 1)
            prob_next1 = prob_next1.reshape(2, 1)
            prob_next2 = prob_next2.reshape(2, 1)
            # Deep copies keep the accumulated word histories of the two expansions independent.
            beam2 = copy.deepcopy(beam0)
            beam3 = copy.deepcopy(beam1)
            prob_next1 = [predict(prob_next1[0], word_prob), predict(prob_next1[1], word_prob)]
            prob_next2 = [predict(prob_next2[0], word_prob1), predict(prob_next2[1], word_prob1)]
            # beam0/beam2 extend the first hypothesis, beam1/beam3 the second.
            beam0.prob, beam2.prob = prob_next1[0].item() * prob0, prob_next1[1].item() * prob0
            beam1.prob, beam3.prob = prob_next2[0].item() * prob1, prob_next2[1].item() * prob1
            beam0.hidden, beam2.hidden = hidden, hidden
            beam1.hidden, beam3.hidden = hidden1, hidden1
            beam0.word, beam2.word = word[0], word[1]
            beam1.word, beam3.word = word1[0], word1[1]

            # Keep the two highest-scoring candidates and carry their cumulative
            # probabilities into the next step.
            beam_list = [beam0, beam1, beam2, beam3]
            beam_list.sort(reverse=True)
            beam0, beam1 = beam_list[0], beam_list[1]
            prob0, prob1 = beam0.prob, beam1.prob

            word_embedding = self.embedding(beam0.word)
            word_embedding1 = self.embedding(beam1.word)
            hidden = beam0.hidden
            hidden1 = beam1.hidden
            for k in range(2):  # a separate name so the step counter `i` is not clobbered
                if beam_list[k].word.item() == 0:
                    word_probs1.append(beam_list[k])
                    if len(word_probs1) > 1:
                        break
                else:
                    beam_list[k].word_probs.append(beam_list[k].word)

            if len(word_probs1) > 1:
                word_probs1.sort(reverse=True)
                return word_probs1[0].word_probs, word_alphas
            if len(word_probs1) > 0 and word_probs1[0] == beam_list[0]:
                return word_probs1[0].word_probs, word_alphas
            word_alphas.append(word_alpha)
            word_probs.append(word)
            i += 1
        if len(word_probs1) > 0:
            return word_probs1[0].word_probs, word_alphas
        return beam_list[0].word_probs, word_alphas

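    # Initialize the GRU hidden state from the masked global average of the
    # encoder features, projected through a linear layer and tanh.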
    def init_hidden(self, features, feature_mask):
        average = (features * feature_mask).sum(-1).sum(-1) / feature_mask.sum(-1).sum(-1)
        average = self.init_weight(average)
        return torch.tanh(average)


decoder_dict = {
    'AttDecoder': AttDecoder
}


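# A beam-search hypothesis: cumulative probability, the decoded word history,
# the last emitted word, and the GRU hidden state to resume decoding from.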
class beam:
    def __init__(self, prob, word_probs, word, hidden=1):
        self.word = word
        self.prob = prob
        self.word_probs = word_probs
        self.hidden = hidden

    def __lt__(self, other):
        return self.prob < other.prob


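# Softmax probability of a selected logit against its full logit row.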
def predict(prob, probs):
    return prob.exp() / probs.exp().sum(1)
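

# Minimal usage sketch (hypothetical: assumes a `params` dict loaded from the
# project's config with the keys accessed above, plus a preprocessed image
# tensor, its label tensor, and a sample name):
#
#     model = Inference(params, draw_map=False).to(params['device'])
#     model.eval()
#     with torch.no_grad():
#         word_probs, word_alphas, mae, mse = model(image, labels, 'sample_1')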