|
- import tensorflow as tf
- from tensorflow.python.ops import rnn
- #import my_rnn
- import pdb
-
# Small constant added under the square root in squash() so the divisor
# stays strictly positive even when a capsule vector is all zeros.
eps = 1.0e-6
# For version compatibility
def reduce_sum(input_tensor, axis=None, keepdims=False):
    """Sum `input_tensor` over `axis`, working across TensorFlow versions.

    Newer TensorFlow spells the keep-dimensions flag `keepdims`; older
    releases only accept `keep_dims`. The newer spelling is tried first.

    Args:
        input_tensor: Tensor to reduce.
        axis: Axis or axes to reduce; None reduces over all axes.
        keepdims: If True, reduced axes are retained with size 1.

    Returns:
        The reduced Tensor.
    """
    try:
        return tf.reduce_sum(input_tensor, axis=axis, keepdims=keepdims)
    except TypeError:
        # An unrecognised keyword argument raises TypeError. The previous bare
        # `except:` also caught unrelated exceptions (including
        # KeyboardInterrupt and errors raised by the op itself) and hid them
        # behind a retry with the legacy keyword.
        return tf.reduce_sum(input_tensor, axis=axis, keep_dims=keepdims)
-
-
# For version compatibility
def softmax(logits, axis=None):
    """Softmax over `axis`, working across TensorFlow versions.

    Newer TensorFlow names the normalisation-axis argument `axis`; older
    releases only accept `dim`. The newer spelling is tried first.

    Args:
        logits: Tensor of unnormalised scores.
        axis: Axis to normalise over; None means the last axis.

    Returns:
        A Tensor with the same shape as `logits`.
    """
    try:
        return tf.nn.softmax(logits, axis=axis)
    except TypeError:
        # Catch only the "unexpected keyword" failure; a bare `except:` would
        # also mask genuine errors. The legacy API expects an int `dim`
        # (default -1), so translate None to the last axis explicitly rather
        # than forwarding dim=None.
        return tf.nn.softmax(logits, dim=-1 if axis is None else axis)
def get_shape(inputs, name=None):
    """Return the shape of `inputs`, preferring static sizes where known.

    Args:
        inputs: A Tensor.
        name: Name scope for the created ops; "shape" when None.

    Returns:
        A Python list with one entry per dimension. Each entry is the static
        size (an int) when it is known at graph-construction time, otherwise
        the matching element of the dynamic `tf.shape(inputs)` tensor.
    """
    scope = "shape" if name is None else name
    with tf.name_scope(scope):
        static_dims = inputs.get_shape().as_list()
        dynamic_dims = tf.shape(inputs)
        return [
            dynamic_dims[axis_idx] if size is None else size
            for axis_idx, size in enumerate(static_dims)
        ]
-
def routing(input, b_IJ, num_outputs=10, num_dims=16, iter_routing=3):
    """Route capsule outputs of layer l to layer l+1 by iterative agreement.

    Args:
        input: 5-D Tensor of shape [batch_size, num_caps_l, 1, len_u_i, 1]
            holding the output vectors u_i of layer l. Axis 2 must have size 1
            because it is broadcast against the weight tensor. Axis 1 and the
            last two axes must be statically known because they size the
            'Weight' variable.
        b_IJ: Initial routing logits, updated on every inner iteration. The
            softmax is taken over axis 2. The in-code shape notes imply
            [batch_size, num_caps_l, num_outputs, 1, 1]; confirm at the call
            site.
        num_outputs: Number of capsules in layer l+1.
        num_dims: Length of each output capsule vector v_j.
        iter_routing: Number of routing iterations; must be >= 1.

    Returns:
        Tensor of shape [batch_size, 1, num_outputs, num_dims, 1] holding the
        squashed output vectors v_j of layer l+1.

    Raises:
        ValueError: If iter_routing < 1. The loop would then never assign v_J,
            and the function previously failed with UnboundLocalError.

    Notes:
        Creates variables named 'Weight' and 'bias' via tf.get_variable, so
        each routing layer needs its own (or a reusing) tf.variable_scope.
    """
    if iter_routing < 1:
        raise ValueError("iter_routing must be >= 1, got %r" % (iter_routing,))

    input_shape = get_shape(input)  # [batch_size, num_caps_l, 1, len_u_i, 1]
    num_caps_l = input_shape[1]

    # W: [1, num_caps_l, num_outputs * num_dims, len_u_i, 1]
    W = tf.get_variable('Weight', shape=[1, num_caps_l, num_dims * num_outputs] + input_shape[-2:],
                        dtype=tf.float32, initializer=tf.random_normal_initializer(stddev=0.01))
    # NOTE(review): no initializer is passed here, so tf.get_variable's default
    # applies. Confirm this is intended for a bias term.
    biases = tf.get_variable('bias', shape=(1, 1, num_outputs, num_dims, 1))

    # Eq.2: compute u_hat without tf.matmul, using an element-wise multiply,
    # a reduce_sum over the len_u_i axis, and a reshape. `input` has size 1 on
    # axis 2, so it broadcasts against W's (num_outputs * num_dims) axis. This
    # gives the same result as tiling it explicitly, without materialising
    # the tiled copy.
    u_hat = reduce_sum(W * input, axis=3, keepdims=True)
    u_hat = tf.reshape(u_hat, shape=[-1, num_caps_l, num_outputs, num_dims, 1])
    # u_hat: [batch_size, num_caps_l, num_outputs, num_dims, 1]

    # Forward pass: identical to u_hat. Backward pass: blocks gradients, so
    # the inner routing updates do not back-propagate into u_hat (and W).
    u_hat_stopped = tf.stop_gradient(u_hat, name='stop_gradient')

    # line 3: for r iterations do
    for r_iter in range(iter_routing):
        with tf.variable_scope('iter_' + str(r_iter)):
            is_last = r_iter == iter_routing - 1

            # line 4: coupling coefficients, normalised over the output capsules
            # => [batch_size, num_caps_l, num_outputs, 1, 1]
            c_IJ = softmax(b_IJ, axis=2)

            # lines 5-6: weighted sum over the input capsules (axis 1), plus
            # bias, then squash (Eq.1). Only the last iteration reads u_hat, so
            # gradients from the returned v_J reach W. Inner iterations read
            # u_hat_stopped, but they still depend on `biases` and the
            # incoming b_IJ.
            u_hat_used = u_hat if is_last else u_hat_stopped
            s_J = reduce_sum(tf.multiply(c_IJ, u_hat_used), axis=1, keepdims=True) + biases
            v_J = squash(s_J)
            # v_J: [batch_size, 1, num_outputs, num_dims, 1]

            if not is_last:
                # line 7: agreement update b_ij += <u_hat_j|i, v_j>.
                # v_J broadcasts over the num_caps_l axis of u_hat_stopped.
                # Skipped on the last iteration because b_IJ is not read again.
                u_produce_v = reduce_sum(u_hat_stopped * v_J, axis=3, keepdims=True)
                # u_produce_v: [batch_size, num_caps_l, num_outputs, 1, 1]
                b_IJ += u_produce_v

    return v_J
-
-
def squash(vector):
    """Apply the capsule squashing non-linearity (Eq. 1).

    Each vector s along axis -2 is rescaled by
    ``|s|^2 / (1 + |s|^2) / sqrt(|s|^2 + eps)``, where ``eps`` keeps the
    square root away from zero.

    Args:
        vector: Tensor shaped [batch_size, 1, num_caps, vec_len, 1] or
            [batch_size, num_caps, vec_len, 1]; the norm is taken over the
            vec_len axis (axis -2).

    Returns:
        Tensor with the same shape as `vector`, squashed along vec_len.
    """
    sq_norm = reduce_sum(tf.square(vector), -2, keepdims=True)
    gain = sq_norm / (1 + sq_norm)
    unit_scale = gain / tf.sqrt(sq_norm + eps)
    return unit_scale * vector  # element-wise, broadcast over vec_len
-
-
-
|