|
- # import tensorflow.compat.v1 as tf
- import tensorlayer as tl
- from utils import color_space #无框架
- from tensorlayer.layers import Module
-
# Number of geometry coordinates (x, y, z) at the start of each point row.
GEO_DIM = 3
# Global tensor layout shared by the encoder/decoder layers below;
# 'channels_last' puts per-point features on the last axis.
data_format = 'channels_last'
# Axis indices implied by the layout, unpacked together as (channels, points).
channels_axis, points_axis = (1, 0) if data_format == 'channels_last' else (0, 1)
-
# Specialized Conv1D for stride 1 and no padding, equivalent of an MLP applied to each sample.
class Conv1D(Module):
    """Kernel-size-1 convolution: one shared linear map applied to every point.

    Args:
        input_dim: Number of input feature channels.
        output_dim: Number of output feature channels.
        activation: Callable applied to the result, or ``None`` for a linear layer.
        use_bias: Whether to add a learned bias (initialised to zeros).
        name: Prefix for the weight names (``<name>_W`` and ``<name>_b``).
        data_format: ``'channels_last'`` computes ``x @ W`` (features on the
            last axis, e.g. ``(points, input_dim)``); ``'channels_first'``
            computes ``W @ x`` (e.g. ``(input_dim, points)``).

    Raises:
        ValueError: If ``data_format`` is not one of the two supported layouts.
    """

    def __init__(self, input_dim, output_dim, activation=tl.ops.relu, use_bias=True, name='conv1d', data_format='channels_last'):
        super(Conv1D, self).__init__()
        # Raise instead of assert so the check is not stripped under ``python -O``.
        if data_format not in ('channels_last', 'channels_first'):
            raise ValueError(
                "data_format must be 'channels_last' or 'channels_first', "
                f"got {data_format!r}")
        W_shape = (input_dim, output_dim) if data_format == 'channels_last' else (output_dim, input_dim)
        # Bias keeps a singleton points axis so it broadcasts over all points.
        b_shape = (1, output_dim) if data_format == 'channels_last' else (output_dim, 1)
        self.data_format = data_format
        self.use_bias = use_bias
        # NOTE(review): name is set before _get_weights in case the backend
        # scopes weight names by self.name — keep this order.
        self.name = name
        self.W = self._get_weights(name + '_W', shape=W_shape,
                                   init=tl.initializers.XavierUniform(),  # same as 'glorot_uniform'
                                   trainable=True)
        if use_bias:
            self.b = self._get_weights(name + '_b', shape=b_shape,
                                       init=tl.initializers.zeros(),
                                       trainable=True)
        self.activation = activation

    def forward(self, x):
        """Apply the per-point linear map, then the optional bias and activation."""
        if self.data_format == 'channels_first':
            x = self.W @ x  # (output_dim, input_dim) @ (input_dim, ...)
        else:
            x = x @ self.W  # (..., input_dim) @ (input_dim, output_dim)
        if self.use_bias:
            x = x + self.b
        if self.activation is not None:
            x = self.activation(x)
        return x
-
# Sequence of 1D convolutions
class Conv1DSequence(Module):
    """Stack of Conv1D layers mapping ``feature_dims[0] -> ... -> feature_dims[-1]``.

    Layer ``i`` maps ``feature_dims[i]`` to ``feature_dims[i + 1]`` channels and
    is named ``conv1d_{i}``. All layers share the same activation, bias
    setting and data format. Fewer than two dims yields an empty stack.
    """

    def __init__(self, feature_dims, activation=tl.ops.relu, use_bias=True, name='conv1d_sequence', data_format=data_format):
        super(Conv1DSequence, self).__init__()
        self.name = name
        # Consecutive (in, out) channel pairs, e.g. (3, 128, 64) -> (3, 128), (128, 64).
        dim_pairs = zip(feature_dims[:-1], feature_dims[1:])
        self.convs = [
            Conv1D(d_in, d_out, name=f'conv1d_{idx}',
                   activation=activation, use_bias=use_bias, data_format=data_format)
            for idx, (d_in, d_out) in enumerate(dim_pairs)
        ]
        # Wrap the plain list so the layers run in order (and are presumably
        # registered as sub-modules by TensorLayer — TODO confirm).
        self.seq_convs = tl.layers.SequentialLayer(self.convs)

    def forward(self, x):
        """Run ``x`` through every Conv1D layer in order."""
        return self.seq_convs(x)
-
class Encoder(Module):
    """Encode a point set into a single global feature vector.

    A shared per-point Conv1DSequence (ReLU, with bias) lifts every point to
    ``feature_dims[-1]`` channels. A max over the points axis then pools them
    into one vector, and the pooled axis is kept with length 1.
    """

    def __init__(self, feature_dims, name='encoder'):
        super(Encoder, self).__init__()
        self.name = name
        self.convs = Conv1DSequence(feature_dims, activation=tl.ops.relu,
                                    use_bias=True, data_format=data_format)

    def forward(self, x):
        """Return the max-pooled global feature of ``x``.

        Under the module-level 'channels_last' layout this is presumably
        ``(1, feature_dims[-1])`` for a ``(points, feature_dims[0])`` input —
        TODO confirm against callers.
        """
        per_point = self.convs(x)
        global_feature = tl.ops.reduce_max(per_point, axis=points_axis)
        # Re-insert the pooled points axis (length 1) in the output.
        return tl.ops.expand_dims(global_feature, points_axis)
-
class FoldingLayer(Module):
    """One folding step: concatenate ``x`` with ``y`` per point, then apply two Conv1Ds.

    Args:
        input_dim: Channels of the concatenated input (channels of ``x`` plus channels of ``y``).
        filters: Width of the hidden layer.
        final_dim: Number of output channels.
        use_bias: Whether the Conv1D layers add a bias.
        name: Name of this layer.
    """

    def __init__(self, input_dim, filters, final_dim, use_bias=False, name='folding'):
        super(FoldingLayer, self).__init__()
        self.name = name
        layer_dims = (input_dim, filters, final_dim)
        self.convs = Conv1DSequence(layer_dims, activation=tl.ops.LeakyReLU(0.1),
                                    use_bias=use_bias, data_format=data_format)

    def forward(self, x, y):
        """Concatenate ``x`` and ``y`` along the channel axis and fold the result."""
        # Channels of x come first, then those of y (e.g. 3 grid coords + 128 features).
        stacked = tl.ops.concat([x, y], axis=channels_axis)
        return self.convs(stacked)
-
class Decoder(Module):
    """Fold a grid into a point set, conditioned on a global feature vector.

    Two FoldingLayers are applied in turn. Each one concatenates the current
    per-point values with the tiled global feature ``y``.

    Args:
        input_dim: Channels per grid point.
        filters: Hidden width of both folding layers and output width of the first.
        y_dim: Channels of the global feature ``y``.
        final_dim: Channels per decoded point.
        name: Name of this layer.
    """

    def __init__(self, input_dim, filters, y_dim, final_dim, name='decoder'):  # e.g. 3, 64, 128, 3
        super(Decoder, self).__init__()
        self.name = name
        self.foldlayer1 = FoldingLayer(input_dim + y_dim, filters, filters, name='folding_1')
        self.foldlayer2 = FoldingLayer(filters + y_dim, filters, final_dim, name='folding_2')

    def forward(self, y, grid):
        """Decode the global feature ``y`` onto the points of ``grid``.

        Args:
            y: Global feature, presumably of length 1 along the points axis
                (as Encoder produces) — it is tiled once per grid point.
            grid: Per-point starting coordinates (grid points). Its size along
                the points axis sets the number of decoded points.

        Returns:
            Decoded per-point values with ``final_dim`` channels.
        """
        n_points = grid.shape[points_axis]
        # Repeat y along the points axis only, e.g. (1, 128) -> (n_points, 128).
        reps = [n_points if axis == points_axis else 1 for axis in range(2)]
        y = tl.ops.tile(y, reps)
        # First fold, e.g. (n_points, 64): every grid point receives the same
        # global feature, and the MLP evolves it differently at each location.
        x = self.foldlayer1(grid, y)
        # Second fold, e.g. (n_points, final_dim): the global feature is combined
        # again with the per-point intermediate features to give the decoded
        # point coordinates.
        return self.foldlayer2(x, y)
-
class Model(Module):
    """Folding autoencoder: Encoder -> global feature -> Decoder over ``grid``.

    Args:
        points: Array read as ``points[0, :, :GEO_DIM]`` (geometry, stored as
            ``self.x``) and ``points[0, :, GEO_DIM:]`` (remaining channels,
            stored as ``self.ori_colors``). Its first dimension must be 1.
        grid: Per-point starting coordinates for the Decoder. Its size along
            the points axis sets the number of decoded points.
        filter_sizes: Channel widths of the encoder's Conv1D stack after the
            GEO_DIM input channels; the last entry is the global feature size.
            Any sequence (tuple or list) is accepted.
        filters: Hidden width of the decoder's folding layers.
        ndim: Channels per grid point and per decoded point.

    Raises:
        ValueError: If ``points.shape[0] != 1``.
    """

    def __init__(self, points, grid, filter_sizes=(128, 128, 128, 128), filters=64, ndim=3):
        super(Model, self).__init__()
        # Validate before building any sub-networks; raise instead of assert so
        # the check is not stripped under ``python -O``.
        if points.shape[0] != 1:
            raise ValueError(
                f'Model expects points.shape[0] == 1, got shape {points.shape}')
        # tuple() lets callers pass a list as well as a tuple.
        self.encoder = Encoder((GEO_DIM,) + tuple(filter_sizes))
        self.decoder = Decoder(ndim, filters, filter_sizes[-1], ndim)
        self.x = points[0, :, :GEO_DIM]  # geometry, e.g. (37183, 3) in the author's data
        self.ori_colors = points[0, :, GEO_DIM:]
        self.grid = grid  # grid-point coordinates; number of grid points = number of decoded points

    def forward(self, x):
        """Encode ``x`` and decode it onto ``self.grid``.

        ``x`` is the argument, not ``self.x``; this method does not read ``self.x``.
        """
        x = tl.convert_to_tensor(x, dtype=tl.float32)
        # Global feature summarising the input geometry, e.g. (1, 128).
        y = self.encoder(x)
        # NOTE(review): self.grid is passed through unconverted; it presumably
        # must already be float32 for the concat inside FoldingLayer — confirm.
        return self.decoder(y, self.grid)
|