# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from collections import UserDict
from typing import Optional, Tuple

import torch

from .file_utils import add_start_docstrings


PROCESS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from :class:`~transformers.PreTrainedTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        next_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
            Current scores of the top :obj:`2 * num_beams` non-finished beam hypotheses.
        next_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
            :obj:`input_ids` of the tokens corresponding to the top :obj:`2 * num_beams` non-finished beam hypotheses.
        next_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size, 2 * num_beams)`):
            Beam indices indicating to which beam hypothesis the :obj:`next_tokens` correspond.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.

    Return:
        :obj:`UserDict`: A dictionary composed of the fields as defined above:

            - **next_beam_scores** (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`) -- Updated
              scores of all non-finished beams.
            - **next_beam_tokens** (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams)`) -- Next tokens
              to be added to the non-finished beam hypotheses.
            - **next_beam_indices** (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams)`) -- Beam indices
              indicating to which beam the next tokens shall be added.

"""

FINALIZE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from :class:`~transformers.PreTrainedTokenizer`. See
            :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
            details.

            `What are input IDs? <../glossary.html#input-ids>`__
        final_beam_scores (:obj:`torch.FloatTensor` of shape :obj:`(batch_size * num_beams)`):
            The final scores of all non-finished beams.
        final_beam_tokens (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams)`):
            The last tokens to be added to the non-finished beam hypotheses.
        final_beam_indices (:obj:`torch.LongTensor` of shape :obj:`(batch_size * num_beams)`):
            The beam indices indicating to which beam the :obj:`final_beam_tokens` shall be added.
        pad_token_id (:obj:`int`, `optional`):
            The id of the `padding` token.
        eos_token_id (:obj:`int`, `optional`):
            The id of the `end-of-sequence` token.

    Return:
        :obj:`torch.LongTensor` of shape :obj:`(batch_size * num_return_sequences, sequence_length)`: The generated
        sequences. The second dimension (sequence_length) is either equal to :obj:`max_length` or shorter if all
        batches finished early due to the :obj:`eos_token_id`.

"""


class BeamScorer(ABC):
    """
    Abstract base class for all beam scorers that are used for :meth:`~transformers.PreTrainedModel.beam_search` and
    :meth:`~transformers.PreTrainedModel.beam_sample`.
    """

    @abstractmethod
    @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs
    ) -> Tuple[torch.Tensor]:
        raise NotImplementedError("This is an abstract method.")

    @abstractmethod
    @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        **kwargs
    ) -> torch.LongTensor:
        raise NotImplementedError("This is an abstract method.")


class BeamSearchScorer(BeamScorer):
    r"""
    :class:`transformers.BeamScorer` implementing standard beam search decoding.

    Adapted in part from `Facebook's XLM beam search code
    <https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529>`__.

    Args:
        batch_size (:obj:`int`):
            Batch size of :obj:`input_ids` for which beam search decoding is run in parallel.
        max_length (:obj:`int`):
            The maximum length of the sequence to be generated.
        num_beams (:obj:`int`):
            Number of beams for beam search.
        device (:obj:`torch.device`):
            Defines the device type (*e.g.*, :obj:`"cpu"` or :obj:`"cuda"`) on which this instance of
            :obj:`BeamSearchScorer` will be allocated.
        length_penalty (:obj:`float`, `optional`, defaults to 1.0):
            Exponential penalty to the length. 1.0 means no penalty. Set to values < 1.0 in order to encourage the
            model to generate shorter sequences, to values > 1.0 in order to encourage the model to produce longer
            sequences.
        do_early_stopping (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to stop the beam search when at least ``num_beams`` sentences are finished per batch or not.
        num_beam_hyps_to_keep (:obj:`int`, `optional`, defaults to 1):
            The number of beam hypotheses that shall be returned upon calling
            :meth:`~transformers.BeamSearchScorer.finalize`.
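
    Example (a minimal sketch with made-up candidate values; in real use, :obj:`next_scores`,
    :obj:`next_tokens` and :obj:`next_indices` come from a :obj:`topk` over the model's
    log-probabilities)::

        import torch

        scorer = BeamSearchScorer(
            batch_size=1, max_length=10, num_beams=2, device=torch.device("cpu")
        )
        input_ids = torch.zeros((2, 1), dtype=torch.long)  # (batch_size * num_beams, cur_len)
        next_scores = torch.tensor([[-0.5, -1.0, -1.5, -2.0]])
        next_tokens = torch.tensor([[3, 4, 2, 1]])
        next_indices = torch.tensor([[0, 0, 1, 1]])
        outputs = scorer.process(
            input_ids, next_scores, next_tokens, next_indices, pad_token_id=0, eos_token_id=2
        )
        # each output has shape (batch_size * num_beams,)
        beam_scores = outputs["next_beam_scores"]
        beam_tokens = outputs["next_beam_tokens"]
        beam_indices = outputs["next_beam_indices"]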
- """
-
    def __init__(
        self,
        batch_size: int,
        max_length: int,
        num_beams: int,
        device: torch.device,
        length_penalty: Optional[float] = 1.0,
        do_early_stopping: Optional[bool] = False,
        num_beam_hyps_to_keep: Optional[int] = 1,
    ):
        self.max_length = max_length
        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep

        self._is_init = False
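        # one BeamHypotheses container per batch entry; each keeps the `num_beams` best
        # finished hypotheses seen so far for that entry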
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.num_beams,
                max_length=self.max_length,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
            )
            for _ in range(batch_size)
        ]
        self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)

        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )

    @property
    def is_done(self) -> bool:
        return self._done.all()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> Tuple[torch.Tensor]:
        cur_len = input_ids.shape[-1]
        batch_size = len(self._beam_hyps)
        assert batch_size == (input_ids.shape[0] // self.num_beams)

        device = input_ids.device
        next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)

        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                assert (
                    len(beam_hyp) >= self.num_beams
                ), f"Batch can only be done if at least {self.num_beams} beams have been generated"
                assert (
                    eos_token_id is not None and pad_token_id is not None
                ), "generated beams >= num_beams -> `eos_token_id` and `pad_token_id` have to be defined"
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
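            # candidates arrive sorted by score; up to `num_beams` of the `2 * num_beams`
            # candidates may be eos_token_id, which is why twice as many are passed in:
            # `num_beams` open beams are guaranteed to survive this step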
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.num_beams + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() == eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
                    if is_beam_token_worse_than_top_num_beams:
                        continue
                    beam_hyp.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.num_beams:
                    break

            if beam_idx < self.num_beams:
                raise ValueError(
                    f"Could only fill {beam_idx} of {self.num_beams} next beams: at most {self.num_beams} tokens in"
                    f" {next_tokens[batch_idx]} can be equal to `eos_token_id: {eos_token_id}`."
                )

            # Check if we are done so that we can save a pad step if all(done)
            self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
                next_scores[batch_idx].max().item(), cur_len
            )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[int] = None,
    ) -> torch.LongTensor:
        batch_size = len(self._beam_hyps)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_idx]:
                continue

            # need to add best num_beams hypotheses to generated hyps
            for beam_id in range(self.num_beams):
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_hyp.add(final_tokens, final_score)

        # select the best hypotheses
        sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
        best = []

        # retrieve best hypotheses
        for i, beam_hyp in enumerate(self._beam_hyps):
            sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp = sorted_hyps.pop()[1]
                sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
                best.append(best_hyp)

        # prepare for adding eos
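        # the +1 reserves room for a trailing eos_token_id, unless the longest hypothesis
        # already fills max_length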
        sent_max_len = min(sent_lengths.max().item() + 1, self.max_length)
        decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            assert pad_token_id is not None, "`pad_token_id` has to be defined"
            decoded.fill_(pad_token_id)

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, hypo in enumerate(best):
            decoded[i, : sent_lengths[i]] = hypo
            if sent_lengths[i] < self.max_length:
                decoded[i, sent_lengths[i]] = eos_token_id
        return decoded


class BeamHypotheses:
    def __init__(self, num_beams: int, max_length: int, length_penalty: float, early_stopping: bool):
        """
        Initialize n-best list of hypotheses.
        """
        self.max_length = max_length - 1  # ignoring bos_token
        self.length_penalty = length_penalty
        self.early_stopping = early_stopping
        self.num_beams = num_beams
        self.beams = []
        self.worst_score = 1e9

    def __len__(self):
        """
        Number of hypotheses in the list.
        """
        return len(self.beams)

    def add(self, hyp: torch.LongTensor, sum_logprobs: float):
        """
        Add a new hypothesis to the list.
        """
        score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
        if len(self) < self.num_beams or score > self.worst_score:
            self.beams.append((score, hyp))
            if len(self) > self.num_beams:
                sorted_next_scores = sorted([(s, idx) for idx, (s, _) in enumerate(self.beams)])
                del self.beams[sorted_next_scores[0][1]]
                self.worst_score = sorted_next_scores[1][0]
            else:
                self.worst_score = min(score, self.worst_score)

    def is_done(self, best_sum_logprobs: float, cur_len: int) -> bool:
        """
        If there are enough hypotheses and none of the hypotheses being generated can become better than the worst
        one in the heap, then we are done with this sentence.
        """

        if len(self) < self.num_beams:
            return False
        elif self.early_stopping:
            return True
        else:
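            # the best possible score of any still-open hypothesis at the current length;
            # summed log-probs only decrease as tokens are appended, so (ignoring further
            # length normalization) no open beam can beat the worst finished one once this holds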
            cur_score = best_sum_logprobs / cur_len ** self.length_penalty
            ret = self.worst_score >= cur_score
            return ret