|
- import argparse
- import numpy as np
- import os
- import random
- import math
- import numpy as np
- import mindspore as ms
- import mindspore.nn as nn
- from mindspore import context
- from mindspore.context import ParallelMode
- import mindspore.dataset as ds
- from mindspore import save_checkpoint
- import mindspore.ops as ops
- from mindspore.communication.management import init, get_rank
- from mindspore import Tensor
-
-
# Location of the pretrained VGG-19 weights (.npy wrapping a mapping keyed by layer name,
# e.g. 'conv1_1'; get_conv reads entry[0] as the kernel and entry[1] as the bias).
# The original hard-coded path is kept as the default; set the VGG19_NPY_PATH environment
# variable to load the file from another location without editing this module.
VGG19_NPY_PATH = os.environ.get('VGG19_NPY_PATH', '/mnt/cloud_disk/ssk/dvc_p/vgg19.npy')

# Loaded eagerly at import time, so importing this module requires the file to exist.
# allow_pickle=True is required because the array wraps a Python object; unpickling can
# execute arbitrary code, so only point this at a trusted file.
# encoding='latin1' allows reading pickles written under Python 2.
data_dict = np.load(VGG19_NPY_PATH, allow_pickle=True, encoding='latin1').item()
-
def get_conv(C_in, C_out, layername, weights=None):
    """Build a 3x3 conv + ReLU cell initialised from pretrained VGG-19 weights.

    Args:
        C_in: number of input channels.
        C_out: number of output channels.
        layername: key of the layer in the weight mapping (e.g. 'conv1_1').
        weights: mapping of layer name -> indexable (kernel, bias) numpy arrays.
            Defaults to the module-level ``data_dict``.

    Returns:
        ``nn.SequentialCell(conv, nn.ReLU())`` whose conv weight and bias have been
        overwritten with the pretrained values.

    Raises:
        KeyError: if ``layername`` is missing from ``weights``.
        ValueError: if the stored kernel/bias shapes do not match ``C_in``/``C_out``.
    """
    if weights is None:
        weights = data_dict

    conv = nn.Conv2d(C_in, C_out, kernel_size=3, has_bias=True)
    entry = weights[layername]
    weight, bias = entry[0], entry[1]

    # The stored kernel is (kH, kW, C_in, C_out) (HWIO, as used by TensorFlow);
    # transpose(3, 2, 0, 1) turns it into (C_out, C_in, kH, kW) for MindSpore's Conv2d.
    expected_kernel = (3, 3, C_in, C_out)
    if tuple(weight.shape) != expected_kernel:
        raise ValueError(
            "VGG kernel for %r has shape %s, expected %s"
            % (layername, tuple(weight.shape), expected_kernel))
    if tuple(bias.shape) != (C_out,):
        raise ValueError(
            "VGG bias for %r has shape %s, expected %s"
            % (layername, tuple(bias.shape), (C_out,)))

    # transpose() yields a strided (non-C-contiguous) view; some MindSpore versions'
    # Tensor.from_numpy reject such arrays, so materialise a contiguous copy first.
    conv.weight.set_data(Tensor.from_numpy(np.ascontiguousarray(weight.transpose(3, 2, 0, 1))))
    conv.bias.set_data(Tensor.from_numpy(np.ascontiguousarray(bias)))
    return nn.SequentialCell(conv, nn.ReLU())
-
-
class Vgg19():
    """VGG-19 convolutional trunk (conv1_1 .. conv5_4) with pretrained weights.

    Every conv layer is produced by ``get_conv`` (3x3 conv followed by ReLU, loaded from
    the module-level weight mapping). ``extract_feature`` runs the trunk up to conv5_4;
    there is no pool5 and no fully connected head.

    NOTE(review): this is a plain Python class, not an ``nn.Cell``, so presumably an
    enclosing Cell will not collect these layers' parameters (keeping them out of any
    optimizer) — confirm a frozen feature extractor is the intent.
    """

    # (attribute name == weight-mapping key, in_channels, out_channels), in forward order.
    _CONV_SPECS = (
        ('conv1_1', 3, 64),
        ('conv1_2', 64, 64),
        ('conv2_1', 64, 128),
        ('conv2_2', 128, 128),
        ('conv3_1', 128, 256),
        ('conv3_2', 256, 256),
        ('conv3_3', 256, 256),
        ('conv3_4', 256, 256),
        ('conv4_1', 256, 512),
        ('conv4_2', 512, 512),
        ('conv4_3', 512, 512),
        ('conv4_4', 512, 512),
        ('conv5_1', 512, 512),
        ('conv5_2', 512, 512),
        ('conv5_3', 512, 512),
        ('conv5_4', 512, 512),
    )

    def __init__(self):
        # Per-channel means indexed B, G, R, on the 0-255 pixel scale used by extract_feature.
        self.VGG_MEAN = [103.939, 116.779, 123.68]

        # Conv layers are created in forward-pass order.
        for layer_name, c_in, c_out in self._CONV_SPECS:
            setattr(self, layer_name, get_conv(c_in, c_out, layer_name))

        # pool1..pool4: 2x2, stride-2 max-pooling after stages 1-4 (none after stage 5).
        for stage in range(1, 5):
            setattr(self, 'pool%d' % stage, nn.MaxPool2d(kernel_size=2, stride=2))

        # Splits the input along axis 1 into 3 equal parts (the colour channels).
        self.split = ops.Split(1, 3)

    def extract_feature(self, x):
        """Return the conv5_4 activation (post-ReLU) of the VGG-19 trunk for ``x``.

        ``x`` is split along axis 1 into three tensors treated as R, G, B. It is multiplied
        by 255 before the BGR mean subtraction, so presumably it holds NCHW images with
        values in [0, 1] — TODO confirm with callers.
        """
        # Rescale to 0-255, reorder RGB -> BGR and subtract the per-channel means.
        pixels = x * 255.0
        red, green, blue = self.split(pixels)
        mean = self.VGG_MEAN
        feat = ops.concat([blue - mean[0], green - mean[1], red - mean[2]], axis=1)

        # Stages 1-2: two convs each, then 2x2 max-pool.
        feat = self.pool1(self.conv1_2(self.conv1_1(feat)))
        feat = self.pool2(self.conv2_2(self.conv2_1(feat)))

        # Stages 3-4: four convs each, then 2x2 max-pool.
        feat = self.conv3_4(self.conv3_3(self.conv3_2(self.conv3_1(feat))))
        feat = self.pool3(feat)
        feat = self.conv4_4(self.conv4_3(self.conv4_2(self.conv4_1(feat))))
        feat = self.pool4(feat)

        # Stage 5: four convs, no pooling; the last conv's output is the feature map.
        return self.conv5_4(self.conv5_3(self.conv5_2(self.conv5_1(feat))))
|