|
- # Copyright 2021 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- """Infer api."""
- import time
-
- import numpy as np
-
- import mindspore.nn as nn
- import mindspore.common.dtype as mstype
- from mindspore.common.tensor import Tensor
- from mindspore.ops import operations as P
- from mindspore import context, Parameter
- from mindspore.train.model import Model
-
- from src.dataset import load_dataset
- from .seq2seq import Seq2seqModel
- from ..utils import zero_weight
- from ..utils.load_weights import load_infer_weights
-
# Module-level side effect: importing this module configures the global
# MindSpore execution context (graph mode on the Ascend device target).
context.set_context(
    mode=context.GRAPH_MODE,
    save_graphs=False,
    device_target="Ascend",
    reserve_class_name_in_scope=False,
)
-
-
class Seq2seqInferCell(nn.Cell):
    """
    Inference wrapper around a Seq2seqModel network.

    The cell adds no computation of its own: it forwards the inputs to the
    wrapped network and hands back the result.

    Args:
        network (nn.Cell): Seq2seqModel model.

    Returns:
        The output of `network` for the given inputs, unchanged
        (named `predicted_ids` by the callers in this file).
    """

    def __init__(self, network):
        # auto_prefix=False is passed through to nn.Cell as in the original.
        super().__init__(auto_prefix=False)
        self.network = network

    def construct(self, source_ids, source_mask):
        """Run the wrapped network on one batch and return its output."""
        return self.network(source_ids, source_mask)
-
-
def _assign_infer_weights(params, weights, config):
    """
    Copy pre-loaded weights into the given parameters.

    Args:
        params (list): Parameters to fill, as returned by `trainable_params()`.
        weights (dict): Mapping from parameter name to the value to load
            (a Parameter, a Tensor, a numpy array, or anything else that
            `set_data` accepts).
        config (Seq2seqConfig): Config; `config.dtype` is the dtype applied
            to Tensor and numpy weights.

    Raises:
        ValueError: If a parameter name has no entry in `weights`.
    """
    for param in params:
        weights_name = param.name
        if weights_name not in weights:
            raise ValueError(f"{weights_name} is not found in weights.")
        if not isinstance(param.data, Tensor):
            continue

        weight = weights[weights_name]
        # Parameter must be tested before Tensor so that Parameter weights
        # take this branch (the original checked them in this order).
        if isinstance(weight, Parameter):
            # NOTE(review): the dtype is compared against a string; when
            # neither comparison is true the parameter is left untouched
            # without any message. Confirm this matches the dtype semantics
            # of the installed MindSpore version.
            if param.data.dtype == "Float32":
                param.set_data(Tensor(weight.data.asnumpy(), mstype.float32))
            elif param.data.dtype == "Float16":
                param.set_data(Tensor(weight.data.asnumpy(), mstype.float16))
        elif isinstance(weight, Tensor):
            param.set_data(Tensor(weight.asnumpy(), config.dtype))
        elif isinstance(weight, np.ndarray):
            param.set_data(Tensor(weight, config.dtype))
        else:
            param.set_data(weight)


def seq2seq_infer(config, dataset):
    """
    Run infer with Seq2seqModel.

    Builds the model, loads the inference weights, runs prediction batch by
    batch and collects the source ids together with the predicted ids.
    A final batch smaller than `config.batch_size` is padded up to the full
    batch size before prediction and the padding rows are dropped afterwards.

    Args:
        config (Seq2seqConfig): Config. Reads `batch_size`, `seq_length` and
            `dtype`; also passed to `Seq2seqModel` and `load_infer_weights`.
        dataset (Dataset): Dataset whose dict iterator yields
            "source_eos_ids" and "source_eos_mask".

    Returns:
        List[Dict], prediction, each example has 2 keys, "source" and
        "prediction", both plain Python lists.

    Raises:
        ValueError: If `dataset` is None, or if a trainable parameter is
            missing from the loaded weights.
    """
    if dataset is None:
        # Fail before the model is built instead of with an AttributeError
        # once the iteration starts.
        raise ValueError("dataset is None; check that config.test_dataset is set.")

    tfm_model = Seq2seqModel(
        config=config,
        is_training=False,
        use_one_hot_embeddings=False)

    params = tfm_model.trainable_params()
    weights = load_infer_weights(config)
    _assign_infer_weights(params, weights, config)
    print(" | Load weights successfully.")

    model = Model(Seq2seqInferCell(tfm_model))

    predictions = []
    source_sentences = []

    shape = P.Shape()
    concat = P.Concat(axis=0)
    pad_idx = 0
    sos_idx = 2
    eos_idx = 3
    # Full-batch padding rows: [sos, eos, pad, pad, ...] with a mask that
    # marks only the first two positions.
    pad_ids_row = [sos_idx, eos_idx] + [pad_idx] * (config.seq_length - 2)
    pad_mask_row = [1, 1] + [0] * (config.seq_length - 2)
    source_ids_pad = Tensor(np.tile(np.array([pad_ids_row]), [config.batch_size, 1]), mstype.int32)
    source_mask_pad = Tensor(np.tile(np.array([pad_mask_row]), [config.batch_size, 1]), mstype.int32)

    for batch_index, batch in enumerate(dataset.create_dict_iterator(), start=1):
        source_sentences.append(batch["source_eos_ids"].asnumpy())
        source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
        source_mask = Tensor(batch["source_eos_mask"], mstype.int32)

        active_num = shape(source_ids)[0]
        is_partial = active_num < config.batch_size
        if is_partial:
            # Top the batch up to config.batch_size rows with padding rows.
            source_ids = concat((source_ids, source_ids_pad[active_num:, :]))
            source_mask = concat((source_mask, source_mask_pad[active_num:, :]))

        start_time = time.time()
        predicted_ids = model.predict(source_ids, source_mask)
        elapsed = time.time() - start_time

        print(f" | BatchIndex = {batch_index}, Batch size: {config.batch_size}, active_num={active_num}, "
              f"Time cost: {elapsed}.")
        if is_partial:
            # Drop the rows that belong to the padding added above.
            predicted_ids = predicted_ids[:active_num, :]
        predictions.append(predicted_ids.asnumpy())

    output = []
    for inputs, batch_out in zip(source_sentences, predictions):
        # A rank-3 result is reduced to rank 2 by keeping index 0 of axis 1
        # (presumably the first candidate per example; confirm against the
        # model's output layout). Done once per batch, not once per row.
        if batch_out.ndim == 3:
            batch_out = batch_out[:, 0]

        for i, predicted_row in enumerate(batch_out):
            output.append({
                "source": inputs[i].tolist(),
                "prediction": predicted_row.tolist()
            })

    return output
-
-
def infer(config):
    """
    Seq2seqModel infer api.

    Loads the test dataset described by `config` (unshuffled, remainder kept,
    in translate mode) and runs `seq2seq_infer` on it. When
    `config.test_dataset` is falsy no dataset is loaded and None is passed
    on as the dataset.

    Args:
        config (Seq2seqConfig): Config. Reads `test_dataset`, `batch_size`
            and `dataset_sink_mode`.

    Returns:
        The value returned by `seq2seq_infer` for the loaded dataset.
    """
    eval_dataset = None
    if config.test_dataset:
        eval_dataset = load_dataset(
            data_files=config.test_dataset,
            batch_size=config.batch_size,
            sink_mode=config.dataset_sink_mode,
            drop_remainder=False,
            is_translate=True,
            shuffle=False,
        )

    return seq2seq_infer(config, eval_dataset)
|