@@ -16,8 +16,8 @@ eps: 1e-6
 weight_decay: 1e-4
 beta: 0.9
 image_resize: True
-image_width: 3200
-image_height: 400
+image_width: 1600
+image_height: 320
 image_channel: 1
 dropout: True
 dropout_ratio: 0.5
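The resize hunk halves the input resolution from 3200x400 to 1600x320. A minimal sketch of what that implies for the data pipeline, assuming batches are bilinearly resized to the configured `image_height` x `image_width` (the repo's actual resize code is not shown here, so treat `resize_batch` as a hypothetical helper):

```python
import torch
import torch.nn.functional as F

def resize_batch(images, width=1600, height=320):
    """Bilinearly resize a (B, 1, H, W) grayscale batch to the configured size."""
    return F.interpolate(images, size=(height, width),
                         mode='bilinear', align_corners=False)

batch = torch.rand(2, 1, 400, 3200)   # a batch at the old 3200x400 size
print(resize_batch(batch).shape)      # torch.Size([2, 1, 320, 1600])
```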
@@ -3,8 +3,8 @@ import models
 from infer.san_decoder import SAN_decoder
 import torch.nn.functional as F
 from models.CAN.counting import CountingDecoder as counting_decoder
+import matplotlib.pyplot as plt
+import numpy as np

 class Backbone(nn.Module):
     def __init__(self, params=None):
         super(Backbone, self).__init__()
@@ -22,9 +22,10 @@ class Backbone(nn.Module):
         self.ratio = params['densenet']['ratio'] if params['encoder']['net'] == 'DenseNet' else 16 * params['resnet'][
             'conv1_stride']

-    def forward(self, images, images_mask):
+    def forward(self, images, images_mask, name):
         counting_mask = images_mask[:, :, ::self.ratio, ::self.ratio]
         cnn_features = self.encoder(images)
+        # visulize_all_channel_into_one(cnn_features, name)
         counting_preds1, _ = self.counting_decoder1(cnn_features, counting_mask)
         counting_preds2, _ = self.counting_decoder2(cnn_features, counting_mask)
         counting_preds = (counting_preds1 + counting_preds2) / 2
@@ -32,6 +33,22 @@ class Backbone(nn.Module):
         return prediction

+def visulize_all_channel_into_one(feature_map, i):
+    output = feature_map
+    output = output.data.squeeze()
+    output = output.cpu().numpy()
+    output = np.mean(output, axis=0)
+    height, width = 320, 1600
+    times = height / float(width)
+    plt.rcParams["figure.figsize"] = (1, times)
+    plt.axis('off')
+    plt.imshow(output, cmap='jet', interpolation='bilinear')
+    plt.savefig('vis/{}.png'.format(i), dpi=3 * height)
+
 class SupConHead(nn.Module):
     """backbone + projection head"""
@@ -15,8 +15,8 @@ from utils import compute_edit_distance
 parser = argparse.ArgumentParser(description='Spatial channel attention')
 parser.add_argument('--config', default='14.yaml', type=str, help='config file path')
-parser.add_argument('--image_path', default='data/CROHME/19_test_images.pkl', type=str, help='test image path')
-parser.add_argument('--label_path', default='data/CROHME/19_test_labels.txt', type=str, help='test label path')
+parser.add_argument('--image_path', default='data/CROHME/14_test_images.pkl', type=str, help='test image path')
+parser.add_argument('--label_path', default='data/CROHME/14_test_labels.txt', type=str, help='test label path')
 args = parser.parse_args()

 if not args.config:
@@ -108,8 +108,8 @@ with torch.no_grad():
         image_mask = torch.ones(image.shape)
         image, image_mask = image.to(device), image_mask.to(device)

-        prediction = model(image, image_mask)
+        prediction = model(image, image_mask, name)
+        print(prediction)

         latex_list = convert(1, prediction)
         latex_string = ' '.join(latex_list)
         if latex_string == label.strip():
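The string comparison above accumulates exact matches, i.e. the ExpRate metric commonly reported on CROHME. A toy sketch of the tallying (the `results` pairs and the print are illustrative, not repo code):

```python
results = [('x ^ { 2 }', 'x ^ { 2 }'),
           ('\\frac { a } { b }', '\\frac { a } { c }')]  # (prediction, label) pairs

correct = sum(1 for pred, label in results if pred == label.strip())
print('ExpRate: {:.4f}'.format(correct / max(len(results), 1)))  # 0.5000
```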
@@ -41,7 +41,7 @@ class Backbone(nn.Module):
         counting_loss = self.counting_loss(counting_preds1, counting_labels) + self.counting_loss(counting_preds2,
                                                                                                    counting_labels) \
                         + self.counting_loss(counting_preds, counting_labels)
+        counting_loss = 0.1 * counting_loss

         word_probs, struct_probs, words_alphas, struct_alphas, c2p_probs, c2p_alphas, word_states, c2p_out_states = self.decoder(
             cnn_features, labels, images_mask, labels_mask, counting_preds, is_train=is_train)
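The counting objective sums the losses of both counting branches and of their average, then scales the total by 0.1 so it stays small next to the recognition losses. A self-contained sketch, assuming a CAN-style SmoothL1 criterion and a 111-symbol vocabulary (both assumptions):

```python
import torch
import torch.nn as nn

criterion = nn.SmoothL1Loss()
counting_labels = torch.rand(2, 111)
counting_preds1 = torch.rand(2, 111)   # e.g. the 3x3-kernel counting branch
counting_preds2 = torch.rand(2, 111)   # e.g. the 5x5-kernel counting branch
counting_preds = (counting_preds1 + counting_preds2) / 2

counting_loss = (criterion(counting_preds1, counting_labels)
                 + criterion(counting_preds2, counting_labels)
                 + criterion(counting_preds, counting_labels))
counting_loss = 0.1 * counting_loss    # down-weight against the recognition loss
print(counting_loss.item())
```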
@@ -239,7 +239,7 @@ class SupConLoss(nn.Module):
         loss = - (self.temperature / self.base_temperature) * mean_log_prob_pos
         loss = loss.view(anchor_count, batch_size).mean()
-        loss = 0.1 * loss
+        loss = 0.05 * loss

         return loss
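This halves the weight of the contrastive term from 0.1 to 0.05. For context, a hedged usage sketch in the style of the SupCon reference implementation (Khosla et al.); the feature shapes and the default constructor arguments are assumptions:

```python
import torch
import torch.nn.functional as F

features = F.normalize(torch.randn(8, 2, 128), dim=-1)  # (batch, n_views, dim)
labels = torch.randint(0, 10, (8,))
criterion = SupConLoss()            # the class shown above
loss = criterion(features, labels)  # now returns 0.05x the mean SupCon term
```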
@@ -13,14 +13,18 @@ label_path = './data/train_caption.txt'
 # for item in labels:
 #     name, *labels = item.split()
 #     label = ' '.join(labels)
-#     if len(labels) > 15:
+#     if len(labels) > 25:
+#         continue
+#     if 'limits' in label or len(labels) == 0:
 #         continue
 #     inp.append(label)
 #     gt.append(label)
 #
 # train_voc = 'language_model_train.csv'
 # pd.DataFrame({'inp':inp, 'gt':gt}).to_csv(train_voc, index=None, sep='\t')
+#
+# Generate the evaluation set
 word_path = './data/word.txt'
 new_word_path = './data/word1.txt'

 with open(word_path) as f:
@@ -65,16 +69,18 @@ inp,gt = [], []
 c = 0
 for item in f:
     name, *label = item.split()
-    if len(label) > 15:
+    if len(label) > 25:
         continue
     label1 = ' '.join(disturb(label, 1, labels))
     label2 = ' '.join(label)
+    if 'limits' in label1:
+        continue
     inp.append(' '.join(disturb(label, 1, labels)))
     gt.append(' '.join(label))
     if label1 == label2:
         c = c+1

-# eval_voc = 'language_model_eval.csv'
-# pd.DataFrame({'inp':inp, 'gt':gt}).to_csv(eval_voc, index=None, sep='\t')
+inp = inp[:5000]
+gt = gt[:5000]
+eval_voc = 'language_model_eval.csv'
+pd.DataFrame({'inp':inp, 'gt':gt}).to_csv(eval_voc, index=None, sep='\t')
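The evaluation CSV is now actually written, truncated to 5,000 pairs. A quick read-back sanity check (illustrative, not repo code):

```python
import pandas as pd

df = pd.read_csv('language_model_eval.csv', sep='\t')
print(len(df), df.columns.tolist())   # at most 5000 rows; columns ['inp', 'gt']
```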