|
- import torch.nn as nn
- from collections import OrderedDict
-
-
class C1(nn.Module):
    """First convolutional stage: 5x5 conv (1 -> 6 channels), ReLU, 2x2 max-pool.

    The convolution uses no padding, so each spatial side shrinks by 4 and is
    then halved by the stride-2 pooling (e.g. 1x32x32 -> 6x14x14).
    """

    def __init__(self):
        super().__init__()

        # Sub-layer names ('c1', 'relu1', 's1') are part of the state_dict
        # keys, so they must stay exactly as they are.
        stage = OrderedDict()
        stage['c1'] = nn.Conv2d(1, 6, kernel_size=(5, 5))
        stage['relu1'] = nn.ReLU()
        stage['s1'] = nn.MaxPool2d(kernel_size=(2, 2), stride=2)
        self.c1 = nn.Sequential(stage)

    def forward(self, img):
        """Run ``img`` through conv -> ReLU -> max-pool and return the result."""
        return self.c1(img)
-
-
class C2(nn.Module):
    """Second convolutional stage: 5x5 conv (6 -> 16 channels), ReLU, 2x2 max-pool.

    With an unpadded 5x5 kernel followed by stride-2 pooling, a 6x14x14 input
    becomes 16x5x5.
    """

    def __init__(self):
        super().__init__()

        conv = nn.Conv2d(6, 16, kernel_size=(5, 5))
        activation = nn.ReLU()
        pool = nn.MaxPool2d(kernel_size=(2, 2), stride=2)

        # Keys define the sub-module names used in state_dict / repr.
        self.c2 = nn.Sequential(OrderedDict(
            [('c2', conv), ('relu2', activation), ('s2', pool)]
        ))

    def forward(self, img):
        """Run ``img`` through conv -> ReLU -> max-pool and return the result."""
        return self.c2(img)
-
-
class C3(nn.Module):
    """Third convolutional stage: 5x5 conv (16 -> 120 channels) followed by ReLU.

    There is no pooling here; an unpadded 5x5 kernel maps a 16x5x5 input to
    120x1x1.
    """

    def __init__(self):
        super().__init__()

        stage = OrderedDict()
        stage['c3'] = nn.Conv2d(16, 120, kernel_size=(5, 5))
        stage['relu3'] = nn.ReLU()
        self.c3 = nn.Sequential(stage)

    def forward(self, img):
        """Run ``img`` through conv -> ReLU and return the result."""
        return self.c3(img)
-
-
class F4(nn.Module):
    """Fully connected stage: Linear(120 -> 84) followed by ReLU."""

    def __init__(self):
        super().__init__()

        dense = nn.Linear(120, 84)
        activation = nn.ReLU()
        self.f4 = nn.Sequential(OrderedDict(
            [('f4', dense), ('relu4', activation)]
        ))

    def forward(self, img):
        """Apply the linear layer and ReLU to ``img`` (last dimension 120)."""
        return self.f4(img)
-
-
class F5(nn.Module):
    """Output stage: Linear(84 -> 10) followed by LogSoftmax over the last dim.

    The result is a vector of 10 log-probabilities per sample.
    """

    def __init__(self):
        super().__init__()

        head = OrderedDict()
        head['f5'] = nn.Linear(84, 10)
        # NOTE: despite the 'sig5' key, this layer is a LogSoftmax; the key is
        # kept unchanged because it is the sub-module's registered name.
        head['sig5'] = nn.LogSoftmax(dim=-1)
        self.f5 = nn.Sequential(head)

    def forward(self, img):
        """Apply the linear layer and log-softmax to ``img`` (last dimension 84)."""
        return self.f5(img)
-
-
class LeNet5(nn.Module):
    """
    LeNet-5 style classifier with two parallel C2 branches.

    Input - 1x32x32
    Output - 10

    Pipeline: C1 -> (C2 branch 1 + C2 branch 2, summed element-wise) -> C3
    -> flatten -> F4 -> F5.
    """

    def __init__(self):
        super().__init__()

        # Construction order is kept as-is so parameter initialisation
        # consumes the random number generator in the same sequence.
        self.c1 = C1()
        self.c2_1 = C2()
        self.c2_2 = C2()
        self.c3 = C3()
        self.f4 = F4()
        self.f5 = F5()

    def forward(self, img):
        """Return the F5 output (10 values per sample) for a batch ``img``."""
        stem = self.c1(img)

        # Both C2 branches see the same C1 features; their outputs are summed.
        branch = self.c2_1(stem)
        merged = self.c2_2(stem)
        merged += branch

        features = self.c3(merged)
        # Collapse everything except the batch dimension before the FC layers.
        flat = features.view(img.size(0), -1)
        return self.f5(self.f4(flat))
|