@@ -2,7 +2,7 @@ import mindspore
 import mindspore.nn as nn
 import mindspore.ops as P
 import numpy as np
-from mindspore import Tensor, Parameter
+from mindspore import Tensor, Parameter, ms_function
 from elmo.ops.sampled_softmax_loss import SampledSoftmaxLoss
 from mindspore.common.initializer import initializer, Normal, Zero
 
@@ -28,6 +28,8 @@ class LossCell(nn.Cell):
         self.sparse_softmax_cross_entropy_with_logits = nn.SoftmaxCrossEntropyWithLogits(sparse=True)
         self.matmul = nn.MatMul(False, True)
         self.reduce_mean = P.ReduceMean()
+
+    #@ms_function
     def construct(self, lstm_outputs, next_ids):
         total_loss = []
         for lstm_output, next_token_id in zip(lstm_outputs, next_ids):
@@ -42,4 +44,4 @@ class LossCell(nn.Cell):
             loss = self.sparse_softmax_cross_entropy_with_logits(output_scores, next_token_id_flat)
             total_loss.append(self.reduce_mean(loss))
 
-        return 0.5 * (total_loss[0] + total_loss[1])
+        return 0.5 * (total_loss[0] + total_loss[1]) * 20
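
For context, `ms_function` (newly imported above, but only applied as the commented-out decorator on `construct`) is MindSpore 1.x's just-in-time graph-compilation API. A minimal sketch of its usual standalone use, with an illustrative helper function and shapes that are not part of this patch:

import numpy as np
import mindspore.ops as P
from mindspore import Tensor, ms_function

reduce_mean = P.ReduceMean()

# Illustrative helper (not from the patch): ms_function traces the Python
# body once and executes subsequent calls as a compiled MindSpore graph.
@ms_function
def mean_square(x):
    return reduce_mean(x * x)

print(mean_square(Tensor(np.ones((2, 3), np.float32))))  # scalar Tensor, value 1.0

Note that the `construct` method of an `nn.Cell` is already compiled when the network runs in graph mode, which may be why the decorator is left commented out here.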