@@ -35,8 +35,8 @@ class ELMoLSTM(nn.Cell):
         self.num_directions = 2
         self.batch_first = batch_first
         self.support_non_tensor_inputs = True
-        forward_layers = []
-        backward_layers = []
+        layers = nn.CellList()
         lstm_input_size = input_size
         for i in range(num_layers):
             forward_cell = LSTMCellWithProjection(lstm_input_size, hidden_size, cell_clip=cell_clip, proj_size=proj_size, proj_clip=proj_clip)
@@ -57,14 +57,12 @@ class ELMoLSTM(nn.Cell):
             backward_layer = DynamicRNN(backward_cell)
             lstm_input_size = proj_size
-            forward_layers.append(forward_layer)
-            backward_layers.append(backward_layer)
+            layers.append(forward_layer)
+            layers.append(backward_layer)
-        self.forward_layers = forward_layers
-        self.backward_layers = backward_layers
+        self.layers = layers
         self.dropout = nn.Dropout(keep_prob=keep_prob)
         self.cast = P.Cast()

-    @ms_function
     def construct(self, x, xr, h=None, seq_length=None):
         max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
@@ -72,24 +70,26 @@ class ELMoLSTM(nn.Cell):
             h = _init_state(self.num_layers * self.num_directions, max_batch_size, self.hidden_size, self.proj_size, x.dtype)
         if self.batch_first:
             x = P.Transpose()(x, (1, 0, 2))
-        xr = P.Transpose()(xr, (1, 0, 2))
+            xr = P.Transpose()(xr, (1, 0, 2))
         x_f, x_b = self._stacked_bi_dynamic_rnn(x, xr, h, seq_length)
         if self.batch_first:
             x_f = P.Transpose()(x_f, (1, 0, 2))
             x_b = P.Transpose()(x_b, (1, 0, 2))
         return x_f, x_b

-    def _stacked_bi_dynamic_rnn(self, x, xr, h, seq_length):
+    def _stacked_bi_dynamic_rnn(self, x, xr, h, seq_length=None):
         """stacked bidirectional dynamic_rnn"""
         input_forward = x
         input_backward = xr
         outputs_f = ()
         outputs_b = ()
-        for i, (forward_cell, backward_cell) in enumerate(zip(self.forward_layers, self.backward_layers)):
+        for i in range(self.num_layers):
             offset = i * 2
             h_f_i = (P.Squeeze(0)(h[0][offset : offset+1]), P.Squeeze(0)(h[1][offset : offset+1]))
             h_b_i = (P.Squeeze(0)(h[0][offset + 1 : offset + 2]), P.Squeeze(0)(h[1][offset+1 : offset + 2]))
+            forward_cell = self.layers[offset]
+            backward_cell = self.layers[offset + 1]
             output_f, _ = forward_cell(input_forward, h_f_i, seq_length)
             output_b, _ = backward_cell(input_backward, h_b_i, seq_length)
             if seq_length is None:
@@ -99,5 +99,7 @@ class ELMoLSTM(nn.Cell):
                 outputs_f += (self.dropout(output_f),)
                 outputs_b += (self.dropout(output_b),)
             input_forward = output_f
             input_backward = output_b
+        return outputs_f[-1], outputs_b[-1]
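The hunks above replace the two plain Python lists (`forward_layers` / `backward_layers`) with a single `nn.CellList` in which the forward wrapper for layer *i* sits at index `2*i` and the backward wrapper at `2*i + 1`; `_stacked_bi_dynamic_rnn` then recovers both via `offset = i * 2`. Keeping the wrappers in a `CellList` (rather than a bare list) is what lets MindSpore register them as child cells and track their parameters, which appears to be the point of the refactor, alongside dropping `@ms_function` so `construct` can accept the non-tensor `seq_length=None` default. A minimal sketch of that interleaved layout, with stand-in `nn.Dense` cells instead of the real `LSTMCellWithProjection`/`DynamicRNN` pair (class name and sizes are illustrative assumptions, not the module's actual API):

```python
import numpy as np
import mindspore
import mindspore.nn as nn
from mindspore import Tensor

class InterleavedBiStack(nn.Cell):
    """Toy stand-in: forward cell at index 2*i, backward cell at 2*i + 1."""
    def __init__(self, num_layers, size):
        super().__init__()
        layers = nn.CellList()
        for _ in range(num_layers):
            layers.append(nn.Dense(size, size))   # forward "cell" for this layer
            layers.append(nn.Dense(size, size))   # backward "cell" for this layer
        self.layers = layers                      # registered as child cells, so parameters are tracked
        self.num_layers = num_layers

    def construct(self, x_f, x_b):
        for i in range(self.num_layers):
            offset = i * 2
            x_f = self.layers[offset](x_f)        # forward direction
            x_b = self.layers[offset + 1](x_b)    # backward direction
        return x_f, x_b

stack = InterleavedBiStack(num_layers=2, size=8)
out_f, out_b = stack(Tensor(np.ones((3, 8), np.float32)), Tensor(np.ones((3, 8), np.float32)))
```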
@@ -170,9 +170,9 @@ class SampledSoftmaxLoss(_Loss):
         sampled_values = self.sampler(labels)
         (sampled, true_expected_count, sampled_expected_count) = sampled_values
-        # sampled = ops.stop_gradient(sampled)
-        # true_expected_count = ops.stop_gradient(true_expected_count)
-        # sampled_expected_count = ops.stop_gradient(sampled_expected_count)
+        sampled = ops.stop_gradient(sampled)
+        true_expected_count = ops.stop_gradient(true_expected_count)
+        sampled_expected_count = ops.stop_gradient(sampled_expected_count)
         if not sampled.dtype == mstype.int32:
             sampled = self.cast(sampled, mstype.int32)
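Un-commenting the `ops.stop_gradient` calls detaches the candidate sampler's outputs from the autodiff graph, so the loss backpropagates only through the logits, not through the sampled ids or the expected counts. A tiny illustration of the same helper the diff uses (values pass through unchanged; only gradient flow is cut):

```python
import numpy as np
import mindspore
import mindspore.ops as ops
from mindspore import Tensor

vals = Tensor(np.array([0.3, 0.7]), mindspore.float32)
detached = ops.stop_gradient(vals)   # same values as `vals`, but treated as a constant by autodiff
```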
@@ -216,14 +216,14 @@ class SampledSoftmaxLoss(_Loss):
         acc_indices_2d = self.reshape(acc_indices[:acc_weights_length], (-1, 1))
         acc_ids_2d_int32 = self.reshape(acc_ids[:acc_weights_length], (-1, 1))
         sparse_indices = self.concat_dim1((acc_indices_2d, acc_ids_2d_int32))
-        #sparse_indices = self.cast(sparse_indices, mstype.int32)
+        sparse_indices = self.cast(sparse_indices, mstype.int32)
         # Create sampled_logits_shape = [batch_size, num_sampled]
         sampled_logits_shape = sampled_logits.shape
-        # if self.dtype(sampled_logits) != self.dtype(acc_weights):
-        #     acc_weights = self.cast(acc_weights, self.dtype(sampled_logits))
+        if self.dtype(sampled_logits) != self.dtype(acc_weights):
+            acc_weights = self.cast(acc_weights, self.dtype(sampled_logits))
-        sampled_logits += self.sparse_to_dense(
+        sampled_logits += self.scatter_nd(
             sparse_indices,
             acc_weights,
             sampled_logits_shape)
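Here `self.sparse_to_dense` is swapped for `self.scatter_nd`, and the previously commented-out int32 cast and dtype alignment are enabled; scatter-style ops need integer indices and updates whose dtype matches the output. Presumably `ScatterNd` was chosen because it consumes exactly the `(indices, weights, shape)` triple already built above. A standalone sketch of what that scatter step computes, assuming `self.scatter_nd` wraps `ops.ScatterNd()` (the shapes and values below are made up):

```python
import numpy as np
import mindspore
import mindspore.ops as ops
from mindspore import Tensor

scatter_nd = ops.ScatterNd()
# two (row, column) positions to adjust inside a [2, 4] logits matrix
sparse_indices = Tensor(np.array([[0, 1], [1, 2]]), mindspore.int32)
acc_weights = Tensor(np.array([-1e7, -1e7]), mindspore.float32)   # e.g. masking accidental hits
dense = scatter_nd(sparse_indices, acc_weights, (2, 4))
# dense is zero everywhere except dense[0, 1] and dense[1, 2]
sampled_logits = Tensor(np.zeros((2, 4)), mindspore.float32) + dense
```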
@@ -7,53 +7,44 @@ from mindspore import context
 class TestELMoLSTM(unittest.TestCase):
     def test_elmo_lstm(self):
-        context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend')
-        inputs = Tensor(np.random.randn(3, 10, 10), mindspore.float32)
-        backward = Tensor(np.random.randn(3, 10, 10), mindspore.float32)
+        context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend', device_id=5)
+        inputs = Tensor(np.random.randn(10, 3, 10), mindspore.float32)
+        backward = Tensor(np.random.randn(10, 3, 10), mindspore.float32)
         hx = Tensor(np.random.randn(4, 3, 30), mindspore.float32)
         cx = Tensor(np.random.randn(4, 3, 20), mindspore.float32)
-        lstm = ELMoLSTM(10, 20, 30, 2, 0.5, 1.0, 1.0, True, True, True)
+        lstm = ELMoLSTM(10, 20, 30, 2, 0.5, 1.0, 1.0, True, True, False)
-        outputs, (hy, cy) = lstm(inputs, backward, (hx, cx))
+        outputs_f, outputs_b = lstm(inputs, backward, (hx, cx))
-        # (num_layers, batch_size, seq_length, hidden_size)
-        assert outputs.shape == (2, 3, 10, 60)
-        # (num_layers, batch_size, hidden_size)
-        assert hy.shape == (2, 3, 60)
-        # (num_layers, batch_size, hidden_size)
-        assert cy.shape == (2, 3, 40)
+        # (seq_length, batch_size, hidden_size)
+        assert outputs_f.shape == (10, 3, 30)
+        assert outputs_b.shape == (10, 3, 30)

     def test_elmo_lstm_batch_first(self):
-        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
+        context.set_context(mode=context.PYNATIVE_MODE, device_target='Ascend', device_id=5)
         inputs = Tensor(np.random.randn(3, 10, 10), mindspore.float32)
         backward = Tensor(np.random.randn(3, 10, 10), mindspore.float32)
         hx = Tensor(np.random.randn(4, 3, 30), mindspore.float32)
         cx = Tensor(np.random.randn(4, 3, 20), mindspore.float32)
         lstm = ELMoLSTM(10, 20, 30, 2, 0.5, 1.0, 1.0, True, True, True)
-        outputs, (hy, cy) = lstm(inputs, backward, (hx, cx))
+        outputs_f, outputs_b = lstm(inputs, backward, (hx, cx))
-        # (num_layers, batch_size, seq_length, hidden_size)
-        assert outputs.shape == (2, 3, 10, 60)
-        # (num_layers, batch_size, hidden_size)
-        assert hy.shape == (2, 3, 60)
-        # (num_layers, batch_size, hidden_size)
-        assert cy.shape == (2, 3, 40)
+        # (batch_size, seq_length, hidden_size)
+        assert outputs_f.shape == (3, 10, 30)
+        assert outputs_b.shape == (3, 10, 30)

     def test_elmo_lstm_train_one_step(self):
-        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend')
+        context.set_context(mode=context.GRAPH_MODE, device_target='Ascend', device_id=5)
         inputs = Tensor(np.random.randn(3, 10, 10), mindspore.float32)
         backward = Tensor(np.random.randn(3, 10, 10), mindspore.float32)
         hx = Tensor(np.random.randn(4, 3, 30), mindspore.float32)
         cx = Tensor(np.random.randn(4, 3, 20), mindspore.float32)
-        lstm = ELMoLSTM(10, 20, 30, 2, 0.5, 1.0, 1.0, True, True, True)
+        h = (hx, cx)
+        lstm = ELMoLSTM(10, 20, 30, 1, 0.5, 1.0, 1.0, True, True, True)
-        outputs, (hy, cy) = lstm(inputs, backward, (hx, cx))
+        outputs_f, outputs_b = lstm(inputs, backward, h)
-        # (num_layers, batch_size, seq_length, hidden_size)
-        assert outputs.shape == (2, 3, 10, 60)
-        # (num_layers, batch_size, hidden_size)
-        assert hy.shape == (2, 3, 60)
-        # (num_layers, batch_size, hidden_size)
-        assert cy.shape == (2, 3, 40)
+        # (batch_size, seq_length, hidden_size)
+        assert outputs_f.shape == (3, 10, 30)
+        assert outputs_b.shape == (3, 10, 30)