|
- # -*- coding = utf-8 -*-
- '''
- # @time:2023/2/18 18:49
- # Author:DFTL
- # @File:Fusion_block.py
- '''
-
- import torch
- import torch.nn as nn
-
class Fusion_block_1(nn.Module):
    """Fuse an image feature vector with a knowledge vector via a 1-D convolution.

    The two inputs are stacked as two channels and mixed by a kernel-3
    ``Conv1d`` (2 -> 1 channel; ``padding=1`` keeps the length). The result is
    flattened and projected by a linear layer.

    Args:
        feature_dim: Length of each input vector, i.e. the ``Linear`` input
            size. Defaults to 64, the value previously hard-coded.
        out_features: Size of the output vector. Defaults to 12.
    """

    def __init__(self, feature_dim=64, out_features=12):
        super().__init__()

        # Two input channels (image, knowledge) -> one fused channel. padding=1
        # with kernel_size=3 preserves the sequence length, so the flattened
        # output still has feature_dim elements for self.linear.
        self.conv1 = nn.Conv1d(2, 1, kernel_size=3, padding=1)

        self.linear = nn.Linear(feature_dim, out_features)

        self.flatten = nn.Flatten()

    def forward(self, image_, knowledge_):
        """Fuse ``image_`` with ``knowledge_``.

        Args:
            image_: Image features of shape ``(batch, 1, feature_dim)``.
            knowledge_: Knowledge features of shape ``(batch, feature_dim)``.

        Returns:
            Tensor of shape ``(batch, out_features)``.
        """
        # Insert the channel axis *after* the batch axis so each sample is
        # paired with its own knowledge vector. unsqueeze(0) would produce
        # (1, batch, L), which only concatenates correctly when batch == 1
        # (and for batch == 1 both give the same tensor).
        knowledge_ = torch.unsqueeze(knowledge_, 1)

        fused = torch.cat([image_, knowledge_], dim=1)  # (batch, 2, feature_dim)
        fused = self.conv1(fused)                        # (batch, 1, feature_dim)
        fused = self.flatten(fused)                      # (batch, feature_dim)

        return self.linear(fused)
-
class Fusion_Module(nn.Module):
    """Late fusion of an image feature vector with three further feature vectors.

    Each input passes through its own ``Linear(feature_dim, feature_dim)``
    followed by ReLU. The four results are concatenated along the feature axis
    and projected by ``linear5`` to ``out_features``.

    Args:
        feature_dim: Length of each input vector. Defaults to 64.
        out_features: Size of the output vector. Defaults to 12.
    """

    def __init__(self, feature_dim=64, out_features=12):
        super().__init__()
        self.relu = nn.ReLU()

        # One projection per input. The attribute names are kept as-is so
        # existing state_dict checkpoints still load.
        self.linear1 = nn.Linear(feature_dim, feature_dim)
        self.linear2 = nn.Linear(feature_dim, feature_dim)
        self.linear3 = nn.Linear(feature_dim, feature_dim)
        self.linear4 = nn.Linear(feature_dim, feature_dim)

        self.linear5 = nn.Linear(feature_dim * 4, out_features)

    def forward(self, input1, input2=None, input3=None, input4=None):
        """Project each input, concatenate the results and map them to the output.

        Args:
            input1: Image semantic features, shape ``(batch, feature_dim)``.
            input2: Knowledge vector, shape ``(batch, feature_dim)``.
            input3: Knowledge vector, shape ``(batch, feature_dim)``.
            input4: Further feature vector, shape ``(batch, feature_dim)``.
                NOTE(review): the original notes only describe input2/input3
                as knowledge vectors; the role of input4 should be confirmed.

        Returns:
            Tensor of shape ``(batch, out_features)``.

        Raises:
            TypeError: if any input is ``None``. All four are required; the
                ``None`` defaults exist only for signature compatibility.
        """
        inputs = {
            'input1': input1,
            'input2': input2,
            'input3': input3,
            'input4': input4,
        }
        missing = [name for name, value in inputs.items() if value is None]
        if missing:
            raise TypeError(
                'Fusion_Module.forward requires all four inputs; missing: '
                + ', '.join(missing)
            )

        projections = (self.linear1, self.linear2, self.linear3, self.linear4)
        features = [
            self.relu(proj(x)) for proj, x in zip(projections, inputs.values())
        ]

        # Concatenate on the last (feature) axis: (..., 4 * feature_dim).
        # For 2-D inputs this is the same as dim=1.
        output = torch.cat(features, dim=-1)

        return self.linear5(output)
-
if __name__ == '__main__':
    # Smoke test: run both fusion modules on random inputs and print the
    # resulting output shapes. (The previous version built the tensors but
    # never used them.)
    i = torch.randn([8, 64])
    t = torch.randn([8, 64])

    with torch.no_grad():
        fusion = Fusion_Module()
        # The same knowledge tensor is reused for input2..input4 for brevity.
        print('Fusion_Module:', fusion(i, t, t, t).shape)

        block = Fusion_block_1()
        # Single-sample call: image (1, 1, 64), knowledge (1, 64).
        print('Fusion_block_1:', block(i[:1].unsqueeze(1), t[:1]).shape)
|