|
- import torch
- import torch.nn as nn
- import numpy as np
- from torch.autograd import Variable
- import torch.nn.functional as F
-
- from backbones.resnet import resnet101
- from semantic import semantic
- from hgnn import HGNN_Model
- from classifier_layer import Classifier_Layer
- from six.moves import cPickle
- import time
- import torchvision.models as models
-
class AdaHGNN(nn.Module):
    """Adaptive hypergraph neural network for multi-label image classification.

    Pipeline: a ResNet-101 backbone extracts stage-3/stage-4 feature maps, a
    semantic attention module projects them into per-class feature vectors
    guided by fixed word embeddings, an HGNN refines the class
    representations, and a per-class classifier layer produces the scores.
    """

    def __init__(self, image_feature_dim, output_dim, word_features,
                 num_classes=80, word_feature_dim=300):
        """
        Args:
            image_feature_dim: channel width of the backbone stage-4 features.
            output_dim: per-class embedding size fed to the classifier layer.
            word_features: path to a ``.npy`` file of class word embeddings.
            num_classes: number of target labels (default 80).
            word_feature_dim: dimensionality of the word embeddings.
        """
        super(AdaHGNN, self).__init__()
        self.resnet_101 = resnet101()

        self.num_classes = num_classes
        self.word_feature_dim = word_feature_dim
        self.image_feature_dim = image_feature_dim

        self.word_semantic = semantic(num_classes=self.num_classes,
                                      image_feature_dim=self.image_feature_dim,
                                      word_feature_dim=self.word_feature_dim)

        self.word_features = word_features
        # Fixed (non-trainable) word-embedding matrix, loaded once onto the GPU
        # at construction time and reused by every forward pass.
        self._word_features = self.load_features()

        self.hgnn_model = HGNN_Model(input_dim=self.image_feature_dim)

        self.output_dim = output_dim
        # NOTE(review): 5120 looks like image_feature_dim + the HGNN output
        # width for the default configuration -- TODO confirm against
        # HGNN_Model and derive it instead of hard-coding if those dims are
        # ever made configurable.
        self.fc = nn.Linear(5120, self.output_dim)
        self.classifiers = Classifier_Layer(self.num_classes, self.output_dim)

    def forward(self, x):
        """Compute per-class scores for a batch of images.

        Args:
            x: input image batch; only ``x.size(0)`` (the batch size) is used
               directly here, the rest is consumed by the backbone.

        Returns:
            Per-class score tensor produced by the classifier layer.
        """
        batch_size = x.size()[0]
        feature_3, feature_4 = self.resnet_101(x)
        # BUGFIX: pass the cached embedding tensor directly. The original code
        # re-wrapped it with torch.tensor(self._word_features).cuda() on every
        # forward pass, which copies the data each call and raises a
        # UserWarning on modern PyTorch (torch.tensor() on an existing tensor).
        # load_features() already returns a CUDA tensor.
        stage_3_input, stage_4_input = self.word_semantic(batch_size,
                                                          feature_3,
                                                          feature_4,
                                                          self._word_features)
        hgnn_feature = self.hgnn_model(stage_3_input, stage_4_input)

        # Concatenate the refined per-class HGNN features with the raw
        # semantic stage-4 features, project to the classifier embedding
        # size, and squash with tanh before per-class classification.
        output = torch.cat((hgnn_feature.view(batch_size * self.num_classes, -1),
                            stage_4_input.view(-1, self.image_feature_dim)), 1)
        output = self.fc(output)
        output = torch.tanh(output)
        output = output.contiguous().view(batch_size, self.num_classes,
                                          self.output_dim)
        result = self.classifiers(output)
        return result

    def load_features(self):
        """Load the class word embeddings from disk as a float32 CUDA tensor.

        torch.autograd.Variable is deprecated since PyTorch 0.4 (plain
        tensors carry autograd state now), so the tensor is returned
        directly; the embeddings are constant and never need gradients.
        """
        return torch.from_numpy(
            np.load(self.word_features).astype(np.float32)).cuda()
|