|
- # Copyright (c) Facebook, Inc. and its affiliates.
- #
- # This source code is licensed under the MIT license found in the
- # LICENSE file in the root directory of this source tree.
-
- import uuid
- from typing import Dict, Optional
-
- from torch import Tensor
-
-
class FairseqIncrementalState:
    """Mixin that keeps per-module incremental state in a shared dict.

    Each instance is given a random UUID string when it is constructed.
    Every key it reads or writes in the shared ``incremental_state`` dict is
    prefixed with that id (``"<id>.<key>"``). Several modules can therefore
    keep their own entries in one dict without overwriting each other.
    """

    def __init__(self, *args, **kwargs):
        # Cooperative init: let the rest of the MRO initialise first, then
        # assign this instance's unique namespace id.
        super().__init__(*args, **kwargs)
        self.init_incremental_state()

    def init_incremental_state(self):
        """Assign (or re-assign) a fresh random id used to namespace keys."""
        fresh_id = uuid.uuid4()
        self._incremental_state_id = str(fresh_id)

    def _get_full_incremental_state_key(self, key: str) -> str:
        """Return *key* prefixed with this instance's id: ``"<id>.<key>"``."""
        return "{}.{}".format(self._incremental_state_id, key)

    def get_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
    ) -> Optional[Dict[str, Optional[Tensor]]]:
        """Look up this module's entry in the shared incremental-state dict.

        Args:
            incremental_state: the shared state dict, or ``None``.
            key: the module-local key; it is namespaced with this
                instance's id before the lookup.

        Returns:
            The stored value. Returns ``None`` when *incremental_state* is
            ``None`` or holds nothing under the namespaced key.
        """
        namespaced = self._get_full_incremental_state_key(key)
        # Both checks live in one condition so the Optional is refined
        # before indexing.
        if incremental_state is not None and namespaced in incremental_state:
            return incremental_state[namespaced]
        return None

    def set_incremental_state(
        self,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
        key: str,
        value: Dict[str, Optional[Tensor]],
    ) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
        """Store *value* under this module's namespaced *key*.

        The dict is mutated in place. When *incremental_state* is ``None``,
        nothing is stored.

        Returns:
            The same *incremental_state* object that was passed in.
        """
        if incremental_state is not None:
            namespaced = self._get_full_incremental_state_key(key)
            incremental_state[namespaced] = value
        return incremental_state
-
-
def with_incremental_state(cls):
    """Class decorator that mixes ``FairseqIncrementalState`` into *cls*.

    The decorator rewrites ``cls.__bases__`` in place so that
    ``FairseqIncrementalState`` is the first direct base. It is removed from
    any later position first, so it never appears twice. The same class
    object is returned.

    If *cls* already inherits the mixin indirectly, through one of its other
    bases (for example, it subclasses another decorated class), *cls* is
    returned unchanged. Prepending the mixin in that case would require it
    to come both before and after that base in the MRO, and the
    ``__bases__`` assignment would raise ``TypeError``.

    NOTE(review): CPython may refuse ``__bases__`` assignment for a class
    whose only base is ``object``. Presumably callers decorate subclasses of
    another class (e.g. a module base class); confirm.
    """
    other_bases = tuple(
        base for base in cls.__bases__ if base is not FairseqIncrementalState
    )
    if any(issubclass(base, FairseqIncrementalState) for base in other_bases):
        # The mixin is already in the MRO via another base; nothing to do.
        return cls
    cls.__bases__ = (FairseqIncrementalState,) + other_bases
    return cls
|