|
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
-
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from typing import Any, Callable, Dict, List, NewType, Optional, Tuple, Union
-
- import numpy as np
- import paddle
- import random
- import warnings
- from dataclasses import dataclass
- from ..transformers.tokenizer_utils_base import BatchEncoding, PretrainedTokenizerBase, PaddingStrategy
-
- __all__ = [
- "DataCollatorWithPadding",
- "default_data_collator",
- "DataCollator",
- "DefaultDataCollator",
- "DataCollatorForTokenClassification",
- "DataCollatorForSeq2Seq",
- ]
-
- InputDataClass = NewType("InputDataClass", Any)
- """
- A DataCollator is a function that takes a list of samples from a Dataset and collates them into a batch, as a
- dictionary of Paddle tensors or NumPy arrays.
- """
- DataCollator = NewType("DataCollator", Callable[[List[InputDataClass]], Dict[str, Any]])
-
-
class DataCollatorMixin:
    """Mixin that routes a collator call to a framework-specific method.

    Classes using this mixin are expected to expose a ``return_tensors``
    attribute and to implement ``paddle_call`` and ``numpy_call``.
    """

    def __call__(self, features, return_tensors=None):
        """Collate ``features`` with the backend selected by ``return_tensors``.

        Args:
            features: Samples to collate; passed through untouched to the
                backend method.
            return_tensors: ``"pd"`` or ``"np"``. When ``None`` the value of
                ``self.return_tensors`` is used instead.

        Returns:
            Whatever ``paddle_call`` or ``numpy_call`` returns.

        Raises:
            ValueError: If the resolved framework is neither ``"pd"`` nor ``"np"``.
        """
        # An explicit argument always wins over the instance-level default.
        if return_tensors is None:
            return_tensors = self.return_tensors

        if return_tensors == "pd":
            return self.paddle_call(features)
        if return_tensors == "np":
            return self.numpy_call(features)
        raise ValueError(f"Framework '{return_tensors}' not recognized!")
-
-
def default_data_collator(features: List[InputDataClass], return_tensors="pd") -> Dict[str, Any]:
    """
    Very simple data collator that collates batches of dict-like objects and performs special handling for
    potential keys named:

        - `label`: handles a single value (int or float) per object
        - `label_ids`: handles a list of values per object

    Does not do any additional preprocessing: property names of the input object will be used as corresponding inputs
    to the model. See glue and ner for examples of how it's useful.

    Args:
        features: Samples to collate. All samples are assumed to expose the same attributes; the first one is
            used as a proxy for the whole batch by the backend collators.
        return_tensors: `"pd"` to build Paddle tensors or `"np"` to build NumPy arrays.

    Returns:
        The batch dictionary produced by the selected backend collator.

    Raises:
        ValueError: If `return_tensors` is neither `"pd"` nor `"np"`.
    """
    if return_tensors == "pd":
        return paddle_default_data_collator(features)
    if return_tensors == "np":
        return numpy_default_data_collator(features)
    # Fail loudly instead of silently returning None for an unknown framework.
    raise ValueError(f"Framework '{return_tensors}' not recognized!")
-
-
def paddle_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
    """Collate a list of samples into a dictionary of Paddle tensors.

    Samples that are neither `dict` nor `BatchEncoding` are converted with `vars()`. The first sample is used as
    a proxy for the keys present in the whole batch.

    Args:
        features: Non-empty list of samples sharing the same keys.

    Returns:
        A dict where `label` / `label_ids` (if present and not None in the first sample) is stored under
        `"labels"`, and every other non-None, non-string value is stacked or converted with
        `paddle.to_tensor` under its original key.
    """
    if not isinstance(features[0], (dict, BatchEncoding)):
        features = [vars(f) for f in features]
    first = features[0]
    batch = {}

    # Special handling for labels: pick the tensor dtype explicitly from the first sample.
    if "label" in first and first["label"] is not None:
        label = first["label"]
        # Unwrap tensors and NumPy scalars so that an integer label such as `np.int64` is
        # recognised as `int` below instead of silently falling back to float32.
        if isinstance(label, (paddle.Tensor, np.generic)):
            label = label.item()
        dtype = "int64" if isinstance(label, int) else "float32"
        batch["labels"] = paddle.to_tensor([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        if isinstance(first["label_ids"], paddle.Tensor):
            batch["labels"] = paddle.stack([f["label_ids"] for f in features])
        else:
            head = first["label_ids"][0]
            # `type(...) is int` is kept on purpose so that bool entries keep mapping to float32;
            # NumPy integer scalars are additionally treated as integers.
            is_integer = type(head) is int or isinstance(head, np.integer)
            dtype = "int64" if is_integer else "float32"
            batch["labels"] = paddle.to_tensor([f["label_ids"] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, the first element decides which key/values are not None for this model.
    for k, v in first.items():
        if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
            if isinstance(v, paddle.Tensor):
                batch[k] = paddle.stack([f[k] for f in features])
            else:
                batch[k] = paddle.to_tensor([f[k] for f in features])

    return batch
-
-
def numpy_default_data_collator(features: List[InputDataClass]) -> Dict[str, Any]:
    """Collate a list of samples into a dictionary of NumPy arrays.

    Samples that are neither `dict` nor `BatchEncoding` are converted with `vars()`. The first sample is used as
    a proxy for the keys present in the whole batch.

    Args:
        features: Non-empty list of samples sharing the same keys.

    Returns:
        A dict where `label` / `label_ids` (if present and not None in the first sample) is stored under
        `"labels"`, and every other non-None, non-string value is stacked or converted with `np.array`
        under its original key.
    """
    if not isinstance(features[0], (dict, BatchEncoding)):
        features = [vars(f) for f in features]
    first = features[0]
    batch = {}

    # Special handling for labels: pick the array dtype explicitly from the first sample.
    if "label" in first and first["label"] is not None:
        label = first["label"]
        # Unwrap arrays and NumPy scalars so that an integer label such as `np.int64` is
        # recognised as `int` below instead of silently falling back to float32.
        if isinstance(label, (np.ndarray, np.generic)):
            label = label.item()
        dtype = np.int64 if isinstance(label, int) else np.float32
        batch["labels"] = np.array([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        if isinstance(first["label_ids"], np.ndarray):
            batch["labels"] = np.stack([f["label_ids"] for f in features])
        else:
            head = first["label_ids"][0]
            # `type(...) is int` is kept on purpose so that bool entries keep mapping to float32;
            # NumPy integer scalars are additionally treated as integers.
            is_integer = type(head) is int or isinstance(head, np.integer)
            dtype = np.int64 if is_integer else np.float32
            batch["labels"] = np.array([f["label_ids"] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, the first element decides which key/values are not None for this model.
    for k, v in first.items():
        if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
            if isinstance(v, np.ndarray):
                batch[k] = np.stack([f[k] for f in features])
            else:
                batch[k] = np.array([f[k] for f in features])

    return batch
-
-
@dataclass
class DefaultDataCollator(DataCollatorMixin):
    """
    Object-style counterpart of `default_data_collator`.

    Collates batches of dict-like objects and performs special handling for potential keys named:

        - `label`: handles a single value (int or float) per object
        - `label_ids`: handles a list of values per object

    No additional preprocessing is done: property names of the input object are used as the corresponding
    inputs to the model. Being an object rather than a pure function, it lets `return_tensors` be fixed at
    initialization.

    Args:
        return_tensors (`str`):
            Framework of the returned batch, `"pd"` (the default) or `"np"`.
    """

    return_tensors: str = "pd"

    def __call__(self, features: List[Dict[str, Any]], return_tensors=None) -> Dict[str, Any]:
        """Collate `features`, using `self.return_tensors` when `return_tensors` is None."""
        framework = self.return_tensors if return_tensors is None else return_tensors
        return default_data_collator(features, framework)
-
-
@dataclass
class DataCollatorWithPadding:
    """
    Data collator that pads a batch through `tokenizer.pad` and normalizes the label key to `labels`.

    Args:
        tokenizer (`paddlenlp.transformers.PretrainedTokenizer`):
            The tokenizer used for encoding the data; its `pad` method builds the batch.
        padding (`bool`, `str` or `PaddingStrategy`, defaults to `True`):
            Forwarded to `tokenizer.pad`.
        max_length (`int`, *optional*):
            Forwarded to `tokenizer.pad`.
        pad_to_multiple_of (`int`, *optional*):
            Forwarded to `tokenizer.pad`.
        return_tensors (`str`, defaults to `"pd"`):
            Forwarded to `tokenizer.pad`.
        return_attention_mask (`bool`, *optional*):
            Forwarded to `tokenizer.pad`.
    """

    tokenizer: PretrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    return_tensors: str = "pd"
    return_attention_mask: Optional[bool] = None

    def __call__(self, features: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Pad `features` with the tokenizer and expose any `label` / `label_ids` entry as `labels`."""
        pad_options = dict(
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
            return_attention_mask=self.return_attention_mask,
        )
        batch = self.tokenizer.pad(features, **pad_options)

        # Models consume the target under "labels"; when both aliases exist, "label_ids" wins
        # because it is processed last.
        for alias in ("label", "label_ids"):
            if alias in batch:
                batch["labels"] = batch[alias]
                del batch[alias]
        return batch
-
-
@dataclass
class DataCollatorForTokenClassification(DataCollatorMixin):
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    Inputs are padded by `tokenizer.pad`; labels (looked up under `label`, otherwise `labels`) are padded here
    to the padded `input_ids` length, on the side given by `tokenizer.padding_side`. The padded labels are
    stored back under the same key they were read from, for both the Paddle and the NumPy backend.

    Args:
        tokenizer ([`PretrainedTokenizer`] or [`PretrainedFasterTokenizer`]):
            The tokenizer used for encoding the data.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
              is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
              lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels.
        return_tensors (`str`, defaults to `"pd"`):
            The type of batch to return. Allowable values are "pd" and "np".
    """

    tokenizer: PretrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pd"

    def _pad_inputs(self, features, return_tensors):
        """Pad the inputs with the tokenizer and return `(batch, label_name, labels)`; `labels` is None if absent."""
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        batch = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            # Conversion to tensors will fail if we have labels as they are not of the same length yet.
            return_tensors=return_tensors if labels is None else None,
        )
        return batch, label_name, labels

    def _pad_labels(self, labels, sequence_length):
        """Pad every label sequence to `sequence_length` with `label_pad_token_id` on the tokenizer's padding side."""
        pad_id = self.label_pad_token_id
        if self.tokenizer.padding_side == "right":
            return [list(label) + [pad_id] * (sequence_length - len(label)) for label in labels]
        return [[pad_id] * (sequence_length - len(label)) + list(label) for label in labels]

    def paddle_call(self, features):
        """Collate `features` into a dict of int64 Paddle tensors (or the tokenizer's batch if there are no labels)."""
        batch, label_name, labels = self._pad_inputs(features, "pd")
        if labels is None:
            return batch

        sequence_length = paddle.to_tensor(batch["input_ids"]).shape[1]
        batch[label_name] = self._pad_labels(labels, sequence_length)
        return {k: paddle.to_tensor(v, dtype="int64") for k, v in batch.items()}

    def numpy_call(self, features):
        """Collate `features` into a dict of int64 NumPy arrays (or the tokenizer's batch if there are no labels)."""
        batch, label_name, labels = self._pad_inputs(features, "np")
        if labels is None:
            return batch

        sequence_length = np.array(batch["input_ids"]).shape[1]
        # Store under `label_name` (as `paddle_call` does) so that an unpadded, ragged "label"
        # entry is overwritten rather than left behind for the array conversion below.
        batch[label_name] = self._pad_labels(labels, sequence_length)
        return {k: np.array(v, dtype=np.int64) for k, v in batch.items()}
-
-
@dataclass
class DataCollatorForSeq2Seq:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    Labels are padded here (to the longest label in the batch, optionally rounded up to a multiple of
    `pad_to_multiple_of`) before `tokenizer.pad` is called. The caller's feature objects are not modified:
    padding is applied to shallow copies.

    Args:
        tokenizer ([`PretrainedTokenizer`] or [`PretrainedFasterTokenizer`]):
            The tokenizer used for encoding the data.
        model ([`PreTrainedModel`]):
            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
            prepare the *decoder_input_ids*

            This is useful when using *label_smoothing* to avoid calculating loss twice.
        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
            among:

            - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single sequence
              is provided).
            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
              acceptable input length for the model if that argument is not provided.
            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different
              lengths).
        max_length (`int`, *optional*):
            Maximum length of the returned list and optionally padding length (see above).
        pad_to_multiple_of (`int`, *optional*):
            If set will pad the sequence to a multiple of the provided value.
        label_pad_token_id (`int`, *optional*, defaults to -100):
            The id to use when padding the labels.
        return_tensors (`str`, defaults to `"pd"`):
            The type of batch to return; forwarded to `tokenizer.pad` (e.g. "pd" or "np").
    """

    tokenizer: PretrainedTokenizerBase
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pd"

    def __call__(self, features, return_tensors=None):
        """Pad labels and inputs of `features` and return the batch built by `tokenizer.pad`.

        Args:
            features: Non-empty list of mapping-like samples; `labels` is optional.
            return_tensors: Overrides `self.return_tensors` when not None.

        Returns:
            The batch returned by `tokenizer.pad`, with `decoder_input_ids` added when labels are present and
            `self.model` provides `prepare_decoder_input_ids_from_labels`.
        """
        if return_tensors is None:
            return_tensors = self.return_tensors
        labels = [feature["labels"] for feature in features] if "labels" in features[0].keys() else None
        # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
        # same length to return tensors.
        if labels is not None:
            max_label_length = max(len(label) for label in labels)
            if self.pad_to_multiple_of is not None:
                # Round up to the next multiple of `pad_to_multiple_of`.
                max_label_length = (
                    (max_label_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )

            pad_right = self.tokenizer.padding_side == "right"
            padded_features = []
            for feature, label in zip(features, labels):
                remainder = [self.label_pad_token_id] * (max_label_length - len(label))
                if isinstance(label, list):
                    padded_label = label + remainder if pad_right else remainder + label
                else:
                    pieces = [label, remainder] if pad_right else [remainder, label]
                    padded_label = np.concatenate(pieces).astype(np.int64)
                # Work on a shallow copy so the caller's sample keeps its original, unpadded labels;
                # otherwise padding would accumulate when the same samples are collated again.
                padded_feature = dict(feature.items())
                padded_feature["labels"] = padded_label
                padded_features.append(padded_feature)
            features = padded_features

        features = self.tokenizer.pad(
            features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        # prepare decoder_input_ids
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=features["labels"])
            features["decoder_input_ids"] = decoder_input_ids

        return features
|