|
- # Copyright 2022 Huawei Technologies Co., Ltd
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- # ============================================================================
- # pylint: disable=C0412
-
- """crf module"""
-
- import mindspore
- from mindspore import nn, ops, Tensor
- from mindspore import Parameter
- from mindspore.common.initializer import initializer, Uniform
- from mindnlp.utils import less_min_pynative_first
- if less_min_pynative_first:
- from mindnlp._legacy.functional import full, arange, where, tensor_split
- else:
- from mindspore.ops import full, arange, where, tensor_split
-
def sequence_mask(seq_length, max_length, batch_first=False):
    """Build a boolean validity mask from per-sequence lengths.

    Position ``p`` of a sequence is marked valid when ``p < length``.

    Args:
        seq_length: Tensor holding the length of every sequence; callers in
            this file pass a 1-D tensor of shape ``(batch_size,)``.
        max_length: Number of positions to generate for each sequence.
        batch_first: When ``True`` the batch axis stays first; otherwise the
            first two axes are swapped before returning.

    Returns:
        Boolean tensor; for a 1-D ``seq_length`` it has shape
        ``(batch_size, max_length)`` if ``batch_first`` else
        ``(max_length, batch_size)``.
    """
    positions = arange(0, max_length, 1, dtype=seq_length.dtype)
    # Give the lengths a trailing singleton axis so the comparison broadcasts
    # every position against every sequence length.
    lengths_column = seq_length.view(seq_length.shape + (1,))
    batch_major = positions < lengths_column
    return batch_major if batch_first else batch_major.swapaxes(0, 1)
-
class CRF(nn.Cell):
    """Conditional random field.

    This module implements a conditional random field [LMP01]_. When tags are
    supplied, the forward computation returns the *negative* log likelihood of
    the given tag sequence under the emission scores (log-partition minus path
    score), so the value can be minimised directly as a loss. When tags are
    omitted, the forward computation runs the `Viterbi algorithm`_ and returns
    the tensors needed by `~CRF.post_decode` to recover the best tag sequence.

    Args:
        num_tags: Number of tags.
        batch_first: Whether the first dimension corresponds to the size of a minibatch.
        reduction: Specifies the reduction to apply to the output:
            ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
            ``sum``: the output will be summed over batches. ``mean``: the output will be
            averaged over batches. ``token_mean``: the output will be averaged over tokens.

    Attributes:
        start_transitions (`~Parameter`): Start transition score tensor of size
            ``(num_tags,)``.
        end_transitions (`~Parameter`): End transition score tensor of size
            ``(num_tags,)``.
        transitions (`~Parameter`): Transition score tensor of size
            ``(num_tags, num_tags)``.

    Raises:
        ValueError: If ``num_tags`` is not positive or ``reduction`` is not one
            of the supported values.


    .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
       "Conditional random fields: Probabilistic models for segmenting and
       labeling sequence data". *Proc. 18th International Conf. on Machine
       Learning*. Morgan Kaufmann. pp. 282–289.

    .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
    """

    def __init__(self, num_tags: int, batch_first: bool = False, reduction: str = 'sum') -> None:
        super().__init__()
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        if reduction not in ('none', 'sum', 'mean', 'token_mean'):
            raise ValueError(f'invalid reduction: {reduction}')
        self.num_tags = num_tags
        self.batch_first = batch_first
        self.reduction = reduction
        # All transition parameters are initialised with ``Uniform(0.1)``.
        self.start_transitions = Parameter(initializer(Uniform(0.1), (num_tags,)),
                                           name='start_transitions')
        self.end_transitions = Parameter(initializer(Uniform(0.1), (num_tags,)),
                                         name='end_transitions')
        self.transitions = Parameter(initializer(Uniform(0.1), (num_tags, num_tags)),
                                     name='transitions')

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}(num_tags={self.num_tags})'

    def construct(self, emissions, tags=None, seq_length=None):
        """Compute the loss when ``tags`` is given, otherwise run Viterbi decoding.

        Args:
            emissions: Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is
                ``False``, ``(batch_size, seq_length, num_tags)`` otherwise.
            tags: Optional tag index tensor laid out like the first two axes
                of ``emissions``.
            seq_length: Optional tensor of per-sequence lengths; when omitted
                every sequence is treated as full length.

        Returns:
            The reduced negative log likelihood when ``tags`` is given,
            otherwise the ``(score, history)`` pair produced by Viterbi
            decoding (pass it to `post_decode` to obtain tag sequences).
        """
        if tags is None:
            return self._decode(emissions, seq_length)
        return self._construct(emissions, tags, seq_length)

    def _construct(self, emissions, tags=None, seq_length=None):
        """Return the negative log likelihood of ``tags``, reduced per ``self.reduction``."""
        if self.batch_first:
            batch_size, max_length = tags.shape
            # Internally everything is time-major.
            emissions = emissions.swapaxes(0, 1)
            tags = tags.swapaxes(0, 1)
        else:
            max_length, batch_size = tags.shape

        if seq_length is None:
            seq_length = full((batch_size,), max_length, dtype=mindspore.int64)

        # shape: (max_length, batch_size)
        mask = sequence_mask(seq_length, max_length)

        # Score of the given tag path; shape: (batch_size,)
        numerator = self._compute_score(emissions, tags, seq_length - 1, mask)
        # Log partition function; shape: (batch_size,)
        denominator = self._compute_normalizer(emissions, mask)
        # Negative log likelihood (denominator minus numerator); shape: (batch_size,)
        llh = denominator - numerator

        if self.reduction == 'none':
            return llh
        if self.reduction == 'sum':
            return llh.sum()
        if self.reduction == 'mean':
            return llh.mean()
        # 'token_mean': normalise by the number of valid (unmasked) positions.
        return llh.sum() / mask.astype(emissions.dtype).sum()

    def _decode(self, emissions, seq_length=None):
        """Run the Viterbi forward pass over ``emissions``.

        Args:
            emissions: Emission score tensor of size
                ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
                ``(batch_size, seq_length, num_tags)`` otherwise.
            seq_length: Optional tensor of per-sequence lengths; when omitted
                every sequence is treated as full length.

        Returns:
            The ``(score, history)`` pair from `_viterbi_decode`. Use
            `post_decode` to trace back the best tag sequence for each sample.
        """
        if self.batch_first:
            batch_size, max_length = emissions.shape[:2]
            emissions = emissions.swapaxes(0, 1)
        else:
            max_length, batch_size = emissions.shape[:2]

        if seq_length is None:
            seq_length = full((batch_size,), max_length, dtype=mindspore.int64)

        mask = sequence_mask(seq_length, max_length)

        return self._viterbi_decode(emissions, mask)

    def _compute_score(self, emissions, tags, seq_ends, mask):
        """Score the given tag path for every sample.

        Args:
            emissions: ``(seq_length, batch_size, num_tags)`` emission scores.
            tags: ``(seq_length, batch_size)`` tag indices.
            seq_ends: Index of the last valid timestep of each sample.
            mask: ``(seq_length, batch_size)`` validity mask.

        Returns:
            Tensor of shape ``(batch_size,)`` with the path scores.
        """
        seq_length, batch_size = tags.shape
        mask = mask.astype(emissions.dtype)

        # Start transition score and first emission
        # shape: (batch_size,)
        score = self.start_transitions[tags[0]]
        # Equivalent to emissions[0, arange(batch_size), tags[0]], expressed via gather_nd.
        indices = ops.stack([ops.zeros((batch_size,), mindspore.int64), arange(batch_size), tags[0]])
        score += ops.gather_nd(emissions, indices.T)

        # Tensor-valued counter with a while loop instead of `for i in range(1, seq_length)`.
        i = Tensor(1, mindspore.int64)
        while i < seq_length:
            # Transition score to next tag, only added if next timestep is valid (mask == 1)
            # Equivalent to self.transitions[tags[i - 1], tags[i]] * mask[i]
            # shape: (batch_size,)
            t_indices = ops.stack([tags[i - 1], tags[i]])
            score += ops.gather_nd(self.transitions, t_indices.T) * mask[i]

            # Emission score for next tag, only added if next timestep is valid (mask == 1)
            # shape: (batch_size,)
            e_indices = ops.stack([ops.tile(i, (batch_size,)), arange(batch_size), tags[i]])
            score += ops.gather_nd(emissions, e_indices.T) * mask[i]
            i += 1

        # End transition score
        # Equivalent to tags[seq_ends, arange(batch_size)]
        # shape: (batch_size,)
        tag_indices = ops.stack([seq_ends, arange(batch_size)])
        last_tags = ops.gather_nd(tags, tag_indices.T)
        # shape: (batch_size,)
        score += self.end_transitions[last_tags]

        return score

    def _compute_normalizer(self, emissions, mask):
        """Compute the log partition function for every sample.

        Args:
            emissions: ``(seq_length, batch_size, num_tags)`` emission scores.
            mask: ``(seq_length, batch_size)`` validity mask.

        Returns:
            Tensor of shape ``(batch_size,)``.
        """
        seq_length = emissions.shape[0]

        # Start transition score and first emission; score has size of
        # (batch_size, num_tags) where for each batch, the j-th column stores
        # the score that the first timestep has tag j
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        i = Tensor(1, mindspore.int32)
        while i < seq_length:
            # Broadcast score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.expand_dims(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emissions = emissions[i].expand_dims(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the sum of scores of all
            # possible tag sequences so far that end with transitioning from tag i to tag j
            # and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emissions

            # Sum over all possible current tags, but we're in score space, so a sum
            # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
            # all possible tag sequences so far, that end in tag i
            # shape: (batch_size, num_tags)
            next_score = ops.logsumexp(next_score, axis=1)

            # Set score to the next score if this timestep is valid (mask == 1)
            # shape: (batch_size, num_tags)
            score = where(mask[i].expand_dims(1), next_score, score)
            i += 1

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        # Sum (log-sum-exp) over all possible tags
        # shape: (batch_size,)
        return ops.logsumexp(score, axis=1)

    def _viterbi_decode(self, emissions, mask):
        """Run the Viterbi recursion and record back-pointers.

        Args:
            emissions: ``(seq_length, batch_size, num_tags)`` emission scores.
            mask: ``(seq_length, batch_size)`` validity mask.

        Returns:
            Tuple ``(score, history)``. ``score`` has shape
            ``(batch_size, num_tags)`` and holds, per final tag, the score of
            the best path ending in that tag. ``history`` has the same shape as
            ``emissions``; row ``t - 1`` stores the best previous tag for every
            tag at timestep ``t`` (the last row is left as zeros).
        """
        seq_length = mask.shape[0]

        # Start transition and first emission
        # shape: (batch_size, num_tags)
        score = self.start_transitions + emissions[0]

        # score is a tensor of size (batch_size, num_tags) where for every batch,
        # value at column j stores the score of the best tag sequence so far that ends
        # with tag j
        # history saves where the best tags candidate transitioned from; this is used
        # when we trace back the best tag sequence

        # Viterbi algorithm recursive case: we compute the score of the best tag sequence
        # for every possible next tag
        i = Tensor(1, mindspore.int32)
        history = ops.zeros(emissions.shape, mindspore.int32)
        while i < seq_length:
            # Broadcast viterbi score for every possible next tag
            # shape: (batch_size, num_tags, 1)
            broadcast_score = score.expand_dims(2)

            # Broadcast emission score for every possible current tag
            # shape: (batch_size, 1, num_tags)
            broadcast_emission = emissions[i].expand_dims(1)

            # Compute the score tensor of size (batch_size, num_tags, num_tags) where
            # for each sample, entry at row i and column j stores the score of the best
            # tag sequence so far that ends with transitioning from tag i to tag j and emitting
            # shape: (batch_size, num_tags, num_tags)
            next_score = broadcast_score + self.transitions + broadcast_emission

            # Find the maximum score over all possible current tag
            # shape: (batch_size, num_tags)
            indices = next_score.argmax(axis=1)
            next_score = next_score.max(axis=1)
            # Set score to the next score if this timestep is valid (mask == 1)
            # and save the index that produces the next score
            # shape: (batch_size, num_tags)
            score = where(mask[i].expand_dims(1), next_score, score)
            # Back-pointers are stored for every step, masked or not; post_decode
            # only reads the first seq_length - 1 rows of each sample.
            history[i - 1] = indices
            i += 1

        # End transition score
        # shape: (batch_size, num_tags)
        score += self.end_transitions

        return score, history

    def post_decode(self, score, history, seq_length):
        """Trace back the best tag sequence based on the score and history tensors.

        Args:
            score: ``(batch_size, num_tags)`` final scores from `_viterbi_decode`.
            history: Back-pointer tensor from `_viterbi_decode`, time-major.
            seq_length: Per-sample sequence lengths, shape ``(batch_size,)``.

        Returns:
            A list with one entry per sample; each entry is the list of best
            tags for that sample in time order. The elements are the values
            produced by ``argmax``/indexing, not converted to Python ints.
        """
        batch_size = seq_length.shape[0]
        seq_ends = seq_length - 1
        best_tags_list = []

        # Split into one (1, batch_size, num_tags) slice per timestep, then drop
        # only the leading time axis. A bare squeeze() would also remove the
        # batch or tag axis when batch_size == 1 or num_tags == 1 and break the
        # hist[idx][tag] lookup below.
        history = tensor_split(history, history.shape[0], 0)
        history = [hist.squeeze(0) for hist in history]

        for idx in range(batch_size):
            # Find the tag which maximizes the score at the last timestep; this is our best tag
            # for the last timestep
            best_last_tag = score[idx].argmax(axis=0)
            best_tags = [best_last_tag]
            # We trace back where the best last tag comes from, append that to our best tag
            # sequence, and trace it back again, and so on
            for hist in reversed(history[:seq_ends[idx]]):
                best_last_tag = hist[idx][best_tags[-1]]
                best_tags.append(best_last_tag)
            # Reverse the order because we start from the last timestep
            best_tags.reverse()
            best_tags_list.append(best_tags)

        return best_tags_list
-
- __all__ = ["CRF", "sequence_mask"]
|