|
- import torch
- from torch import nn
- import torchvision
-
-
class Model(nn.Module):
    """YOLOv1-style detection network (backbone + fully-connected head).

    Input:  a batch of images of shape (N, 3, 448, 448).
    Output: a grid prediction tensor of shape (N, 30, 7, 7) — 30 values
            per cell of a 7x7 grid (5 * 2 bbox values + 20 class scores,
            per the original YOLOv1 layout).

    NOTE(review): the head ends in a Softmax over ALL 7*7*30 outputs of a
    sample; the original author was unsure this matches the paper (it was
    added "to prevent gradient explosion") — confirm before training.
    """

    @staticmethod
    def _reduce_expand(channels, mid):
        """One 1x1-squeeze / 3x3-expand pair (channels -> mid -> channels),
        each conv followed by LeakyReLU. Returned as a flat module list so
        nn.Sequential indices (and state_dict keys) match an inline layout."""
        return [
            nn.Conv2d(in_channels=channels, out_channels=mid, kernel_size=(1, 1)),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=mid, out_channels=channels, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(),
        ]

    def __init__(self):
        super().__init__()

        # Stage 1: 448x448x3 -> 112x112x192.
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), padding=3),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),

            nn.Conv2d(in_channels=64, out_channels=192, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(),
        )
        # Stage 2: 112x112x192 -> 56x56x256.
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels=192, out_channels=256, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Stage 3: 56x56x256 -> 28x28x512.
        self.conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=(1, 1)),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=(1, 1)),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Stage 4: 28x28x512 -> 14x14x1024
        # (four squeeze/expand pairs, then widen to 1024 and pool).
        self.conv4 = nn.Sequential(
            *(layer for _ in range(4) for layer in self._reduce_expand(512, 256)),

            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=(1, 1)),
            nn.LeakyReLU(),

            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(),

            nn.MaxPool2d(kernel_size=2, stride=2),
        )
        # Stage 5: 14x14x1024 -> 7x7x1024 (final downsampling via stride-2 conv).
        self.conv5 = nn.Sequential(
            *(layer for _ in range(2) for layer in self._reduce_expand(1024, 512)),

            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=(3, 3), stride=(2, 2), padding=1),
            nn.LeakyReLU(),
        )
        # Stage 6: 7x7x1024 -> 7x7x1024 (two refinement convs).
        self.conv6 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=(3, 3), padding=1),
            nn.LeakyReLU(),
        )
        # Final fully-connected head. The paper uses a 4096-unit hidden
        # layer; 2048 is used here to reduce parameters / GPU memory.
        self.conn_layer = nn.Sequential(
            nn.Linear(in_features=7 * 7 * 1024, out_features=2048),
            nn.LeakyReLU(),
            nn.Linear(in_features=2048, out_features=7 * 7 * 30),
            # dim=1 made explicit: the input here is (N, 7*7*30), and
            # Softmax() without dim is deprecated (it resolves to dim=1
            # for 2-D input, so behavior is unchanged).
            nn.Softmax(dim=1),
        )

    def forward(self, X):
        """Run the full network; expects X of shape (N, 3, 448, 448)."""
        X = self.conv1(X)
        X = self.conv2(X)
        X = self.conv3(X)
        X = self.conv4(X)
        X = self.conv5(X)
        X = self.conv6(X)

        # Flatten the 7x7x1024 feature map for the fully-connected head.
        X = X.view(-1, 7 * 7 * 1024)
        X = self.conn_layer(X)

        # Reshape the flat head output back into the 7x7 prediction grid
        # (30 channels = 5 * 2 bbox values + 20 class scores).
        return X.reshape(-1, 30, 7, 7)
-
-
if __name__ == '__main__':
    # Smoke test: build the model and push one random 448x448 image through it.
    model = Model()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.2, weight_decay=5e-4)

    model.train()
    batch = torch.normal(0, 1, (1, 3, 448, 448))
    predictions = model(batch)

    print(predictions)
    print(predictions.size())
    print(predictions.max())
-
|