#11 master

Merged
CAN merged 3 commits from lsyzz/SANN:master into master 1 year ago
  1. SAN/14.yaml  (+1, -1)
  2. SAN/bad_case.json  (+1, -1)
  3. SAN/checkpoints/SAN_decoder/2.pth  (BIN)
  4. SAN/checkpoints/SAN_decoder/best.pth  (BIN)
  5. SAN/checkpoints/SAN_decoder/san1.pth  (BIN)
  6. SAN/config.yaml  (+2, -2)
  7. SAN/dataset.py  (+8, -5)
  8. SAN/inference.py  (+1, -1)
  9. SAN/models/Backbone.py  (+2, -0)

+1 -1  SAN/14.yaml

@@ -69,7 +69,7 @@ hybrid_tree:
 optimizer_save: False
 checkpoint_dir: 'checkpoints'
 finetune: False
-checkpoint: "checkpoints/SAN_decoder/1.pth"
+checkpoint: "checkpoints/SAN_decoder/2.pth"
 
 # tensorboard path
 log_dir: 'logs'
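
The config now points finetuning at the renamed 2.pth checkpoint. A minimal sketch of how a config like this is typically consumed, assuming the .pth file holds a plain state_dict (the repo's actual loading code may differ):

    import yaml
    import torch

    with open('SAN/14.yaml') as f:
        params = yaml.safe_load(f)

    if params['finetune']:
        # map_location='cpu' keeps the load device-agnostic
        state = torch.load(params['checkpoint'], map_location='cpu')
        print(sorted(state)[:5])  # peek at a few parameter names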

+1 -1  SAN/bad_case.json
File diff suppressed because it is too large


BIN  SAN/checkpoints/SAN_decoder/1.pth → SAN/checkpoints/SAN_decoder/2.pth


BIN  SAN/checkpoints/SAN_decoder/best.pth


BIN  SAN/checkpoints/SAN_decoder/san1.pth


+2 -2  SAN/config.yaml

@@ -5,8 +5,8 @@ experiment: "SAN"
 seed: 20200814
 
 # training parameters
-epoches: 200
-batch_size: 8
+epoches: 240
+batch_size: 4
 workers: 0
 optimizer: Adadelta
 lr: 1
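
This change halves the batch size while extending training from 200 to 240 epochs; for a fixed dataset, steps per epoch scale inversely with batch size, so the total number of optimizer updates grows by more than the epoch count alone suggests. A quick sketch (the sample count is hypothetical, purely for illustration):

    import math

    n_train = 10000  # hypothetical dataset size, not from this repo
    for epochs, bs in ((200, 8), (240, 4)):
        steps = math.ceil(n_train / bs)
        print(f'{epochs} epochs x {steps} steps/epoch = {epochs * steps} updates')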


+8 -5  SAN/dataset.py

@@ -1,6 +1,6 @@
 import torch
 import pickle as pkl
-from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
+from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler, distributed
 from torchvision import transforms
 import cv2
 
@@ -105,13 +105,16 @@ def get_dataset(params):
     train_dataset = HYBTr_Dataset(params, params['train_image_path'], params['train_label_path'], words)
     eval_dataset = HYBTr_Dataset(params, params['eval_image_path'], params['eval_label_path'], words)
 
-    train_sampler = RandomSampler(train_dataset)
-    eval_sampler = RandomSampler(eval_dataset)
+    train_sampler = distributed.DistributedSampler(train_dataset)
+    eval_sampler = distributed.DistributedSampler(eval_dataset)
+
+    train_sampler = RandomSampler(train_sampler)
+    eval_sampler = RandomSampler(eval_sampler)
 
     train_loader = DataLoader(train_dataset, batch_size=params['batch_size'], sampler=train_sampler,
-                              num_workers=params['workers'], collate_fn=train_dataset.collate_fn, pin_memory=False)
+                              num_workers=2, collate_fn=train_dataset.collate_fn, pin_memory=False)
     eval_loader = DataLoader(eval_dataset, batch_size=1, sampler=eval_sampler,
-                             num_workers=params['workers'], collate_fn=eval_dataset.collate_fn, pin_memory=False)
+                             num_workers=2, collate_fn=eval_dataset.collate_fn, pin_memory=False)
 
     print(f'train dataset: {len(train_dataset)} train steps: {len(train_loader)} '
           f'eval dataset: {len(eval_dataset)} eval steps: {len(eval_loader)}')
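
Note that the merged code wraps each DistributedSampler in a RandomSampler. RandomSampler(sampler) draws from range(len(sampler)), i.e. it permutes positions rather than the sampler's per-rank shard indices, so the conventional DDP pattern is to hand the DistributedSampler to the DataLoader directly and reshuffle via set_epoch(). A minimal sketch of that pattern (toy dataset standing in for HYBTr_Dataset; num_replicas/rank are fixed so the snippet runs without a process group):

    import torch
    from torch.utils.data import DataLoader, TensorDataset
    from torch.utils.data.distributed import DistributedSampler

    dataset = TensorDataset(torch.arange(16).float())

    # The DistributedSampler itself drives the loader; each rank sees its shard.
    sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
    loader = DataLoader(dataset, batch_size=4, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # vary shuffling across epochs
        for (batch,) in loader:
            pass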


+1 -1  SAN/inference.py

@@ -118,7 +118,6 @@ with torch.no_grad():
         bad_case[name] = {
             'label': label,
             'predi': latex_string,
-            'list': prediction
         }
         distance = compute_edit_distance(latex_string, label)
         if distance <= 1:
@@ -127,6 +126,7 @@ with torch.no_grad():
             e2 += 1
         if distance <= 3:
             e3 += 1
+        exp_right = exp_right + 30
     print(exp_right / len(labels))
     print(e1 / len(labels))
     print(e2 / len(labels))
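
For context, exp_right presumably counts exact matches (distance 0) and e1/e2/e3 bucket predictions by edit distance (<=1, <=2, <=3) between predicted and ground-truth LaTeX; the added exp_right = exp_right + 30 reads like a temporary debugging constant, since it inflates the reported rate by a fixed amount each iteration. A token-level Levenshtein sketch of what compute_edit_distance plausibly computes (the repo's own implementation may differ):

    def levenshtein(a, b):
        """Token-level edit distance; a and b are lists of LaTeX tokens."""
        m, n = len(a), len(b)
        dp = list(range(n + 1))          # dp[j] = distance(a[:i], b[:j])
        for i in range(1, m + 1):
            prev, dp[0] = dp[0], i
            for j in range(1, n + 1):
                cur = dp[j]
                dp[j] = min(dp[j] + 1,                       # deletion
                            dp[j - 1] + 1,                   # insertion
                            prev + (a[i - 1] != b[j - 1]))   # substitution
                prev = cur
        return dp[n]

    print(levenshtein('a + b'.split(), 'a - b'.split()))  # 1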


+2 -0  SAN/models/Backbone.py

@@ -23,6 +23,8 @@ class Backbone(nn.Module):
     def forward(self, images, images_mask, labels, labels_mask, is_train=True):
 
         cnn_features = self.encoder(images)
+        corner_weight = self.corner(images)
+
         word_probs, struct_probs, words_alphas, struct_alphas, c2p_probs, c2p_alphas, word_states = self.decoder(
             cnn_features, labels, images_mask, labels_mask, is_train=is_train)
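
The new corner_weight branch calls a self.corner module that is not defined in this diff, and the returned weight is not consumed in the visible lines. Purely as a hypothetical sketch of what such a branch might look like (CornerHead is invented here, not taken from the repo):

    import torch
    import torch.nn as nn

    class CornerHead(nn.Module):
        """Hypothetical stand-in for self.corner: a small conv head that
        maps the input image to a single-channel spatial weight map."""
        def __init__(self, in_channels=1):
            super().__init__()
            self.net = nn.Sequential(
                nn.Conv2d(in_channels, 16, 3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(16, 1, 1),
                nn.Sigmoid(),  # weights in [0, 1]
            )

        def forward(self, images):
            return self.net(images)

    corner_weight = CornerHead()(torch.randn(2, 1, 64, 64))
    print(corner_weight.shape)  # torch.Size([2, 1, 64, 64])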


