|
- import ms_adapter.pytorch.nn as nn
- from ms_adapter.pytorch.nn.modules.pooling import MaxPool2d
-
-
def test_add_modules():
    """Check that add_module() registers a submodule on the parent.

    The original test built the modules but asserted nothing; a silent
    pass proved nothing. We now verify the documented contract: after
    add_module(name, module), the child is reachable as an attribute.
    """
    layer1 = MaxPool2d(1, 2)
    layer2 = MaxPool2d(1, 2)
    layer1.add_module("layer2", layer2)
    # add_module must expose the child under the given name (identity, not copy).
    assert layer1.layer2 is layer2
-
-
def test_named_children1():
    """Check that named_children() yields an explicitly registered child.

    The original stored the generator in an unused local (`aa`) and never
    consumed it, so the test could not fail. Materialize and assert.
    """
    layer1 = MaxPool2d(1, 2)
    layer2 = MaxPool2d(1, 2)
    layer1.add_module("layer2", layer2)
    children = dict(layer1.named_children())
    # The only direct child is the one we registered, under its given name.
    assert "layer2" in children
    assert children["layer2"] is layer2
-
-
def test_modules():
    """Check that modules() walks the module tree including the root.

    The original discarded the generator (`aa`) without consuming it.
    We now assert the documented iteration order: the module itself
    first, then registered submodules.
    """
    layer1 = MaxPool2d(1, 2)
    layer2 = MaxPool2d(1, 2)
    layer1.add_module("layer2", layer2)
    found = list(layer1.modules())
    # Root module is yielded first; the registered child must appear too.
    assert found[0] is layer1
    assert any(m is layer2 for m in found)
-
-
def test_named_parameters():
    """Check named_parameters() on a parameter-free module tree.

    Pooling layers declare no learnable parameters, so recursing over the
    parent and its registered child must find nothing. The original only
    printed the (empty) iteration, which cannot fail.
    """
    layer1 = MaxPool2d(1, 2)
    layer2 = MaxPool2d(1, 2)
    layer1.add_module("layer2", layer2)
    params = list(layer1.named_parameters())
    # MaxPool2d holds no weights/biases, so nothing should be yielded.
    assert params == []
    for k, v in params:
        print(k, v)
-
def test_named_children2():
    """Smoke test: print every direct child registered on a module."""
    parent = MaxPool2d(1, 2)
    child = MaxPool2d(1, 2)
    parent.add_module("layer2", child)
    for name, module in parent.named_children():
        print(name, module)
-
def test_modulelist():
    """Smoke test for ModuleList mutation: insert, append, extend, repr."""
    conv = nn.Conv2d(100, 20, 3)
    bn = nn.BatchNorm2d(20)
    relu = nn.ReLU()
    # Start with one entry, then grow the list through every mutator.
    layers = nn.ModuleList([bn])
    layers.insert(0, conv)
    layers.append(relu)
    layers.extend([relu, relu])
    print(layers)
-
-
# Execute the module-container smoke tests, preserving the original order.
for _smoke_test in (
    test_named_children1,
    test_add_modules,
    test_named_children2,
    test_modules,
    test_named_parameters,
    test_modulelist,
):
    _smoke_test()
-
-
# Inspect a model's state_dict (nothing is actually written to disk here).
# `ms_adapter.pytorch` is imported as `torch` as a drop-in API replacement.
- import ms_adapter.pytorch as torch
-
class MyNet(torch.nn.Module):
    """Small conv net used to exercise ``state_dict()``.

    Layer sizing implies an input of shape ``(N, 3, 4, 4)``: conv1 (k=3,
    s=1, p=1) keeps 4x4, max_pooling1 (k=2, s=1) gives 3x3, conv2 (k=2,
    s=1, p=1) gives 4x4, max_pooling2 gives 3x3 — matching the declared
    ``dense1`` input of 32*3*3.
    """

    def __init__(self):
        super(MyNet, self).__init__()
        self.conv1 = torch.nn.Conv2d(3, 32, 3, 1, 1)
        self.relu1 = torch.nn.ReLU()
        self.max_pooling1 = torch.nn.MaxPool2d(2, 1)

        # BUG FIX: conv2 must accept the 32 channels produced by conv1;
        # the original Conv2d(3, 32, 2, 1, 1) would fail in forward().
        self.conv2 = torch.nn.Conv2d(32, 32, 2, 1, 1)
        self.relu2 = torch.nn.ReLU()
        self.max_pooling2 = torch.nn.MaxPool2d(2, 1)

        self.dense1 = torch.nn.Linear(32 * 3 * 3, 128)
        self.dense2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Run the conv stack, flatten, then the classifier head."""
        x = self.conv1(x)
        x = self.relu1(x)
        x = self.max_pooling1(x)

        x = self.conv2(x)
        x = self.relu2(x)
        x = self.max_pooling2(x)
        # BUG FIX: flatten (N, 32, 3, 3) -> (N, 288) before the linear
        # layers; the original passed a 4-D tensor straight to dense1.
        x = x.reshape(x.shape[0], -1)
        x = self.dense1(x)
        x = self.dense2(x)
        return x
-
# Build the model and inspect its state_dict: first the full
# name -> tensor mapping, then each parameter/buffer key on its own line.
model = MyNet()
state = model.state_dict()
print(state)
for param_tensor in state:
    print(param_tensor)
|