|
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
-
- import json
- import os
-
- import numpy as np
- import paddle
-
- from paddlenlp.utils.log import logger
-
-
def read_local_dataset(data_path, data_file=None, is_test=False):
    """
    Load datasets with one JSON example per line, formatted as:
        {"text_a": X, "text_b": X, "question": X, "choices": [A, B], "labels": [0, 1]}

    Args:
        data_path: Path of a single data file, or, when `data_file` is given,
            the directory that is scanned for data files.
        data_file: Optional file-name suffix. When set, every entry of
            `data_path` whose name ends with this suffix is read (in sorted
            name order).
        is_test: When True, "labels" is optional and each example is yielded
            with all of its original keys. When False, "labels" is required
            and only the standard keys are kept.

    Yields:
        dict: One example per valid line. "text_b" defaults to "" when it is
        absent, and "labels" (when processed) is replaced by a multi-hot list
        of floats with one entry per choice.

    Raises:
        KeyError: If an example lacks "choices" or "text_a", or lacks "labels"
            while `is_test` is False.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    if data_file is not None:
        # os.listdir returns entries in arbitrary order; sort the names so the
        # order in which files are read is reproducible across runs.
        file_paths = [
            os.path.join(data_path, fname)
            for fname in sorted(os.listdir(data_path))
            if fname.endswith(data_file)
        ]
    else:
        file_paths = [data_path]

    std_keys = ("text_a", "text_b", "question", "choices", "labels")
    skip_count = 0
    for file_path in file_paths:
        with open(file_path, "r", encoding="utf-8") as fp:
            for line in fp:
                line = line.strip()
                if not line:
                    # Blank lines (e.g. trailing newlines) carry no example and
                    # would otherwise make json.loads raise.
                    continue
                example = json.loads(line)

                # Drop examples with fewer than two choices or a missing/too-short text_a.
                if len(example["choices"]) < 2 or not isinstance(example["text_a"], str) or len(example["text_a"]) < 3:
                    skip_count += 1
                    continue

                if "text_b" not in example:
                    example["text_b"] = ""

                if not is_test or "labels" in example:
                    # Accept a single label index as well as a list of indices.
                    if not isinstance(example["labels"], list):
                        example["labels"] = [example["labels"]]
                    # Convert label indices to a multi-hot vector over the choices.
                    one_hots = np.zeros(len(example["choices"]), dtype="float32")
                    for x in example["labels"]:
                        one_hots[x] = 1
                    example["labels"] = one_hots.tolist()

                if is_test:
                    # Test examples keep every original key.
                    yield example
                    continue

                yield {k: example[k] for k in std_keys if k in example}
    logger.warning(f"Skip {skip_count} examples.")
-
-
class UTCLoss:
    """
    Callable loss that sums two log-sum-exp terms over the last axis and
    averages the result over the remaining axes.

    NOTE(review): presumably `label` holds 0/1 values per choice, with -100
    marking positions to ignore — confirm against the caller.
    """

    def forward(self, logit, label):
        """
        Compute the loss for `logit` given `label`.

        Entries whose label equals -100 are overwritten with -1e12 in both
        terms before the log-sum-exp, so they contribute essentially nothing.
        """
        # Multiply by (1 - 2 * label): sign is flipped where label is 1.
        signed = (1.0 - 2.0 * label) * logit

        # Push entries down by 1e12 where label is 1 (first term) and where
        # label is 0 (second term), so each term is dominated by the others.
        neg_scores = signed - label * 1e12
        pos_scores = signed - (1.0 - label) * 1e12

        # Append one zero-valued column along the last axis to scores and labels.
        pad = paddle.zeros_like(signed[..., :1])
        neg_scores = paddle.concat([neg_scores, pad], axis=-1)
        pos_scores = paddle.concat([pos_scores, pad], axis=-1)
        padded_label = paddle.concat([label, pad], axis=-1)

        # Overwrite positions labelled -100 in both score tensors.
        ignored = padded_label == -100
        neg_scores[ignored] = -1e12
        pos_scores[ignored] = -1e12

        neg_term = paddle.logsumexp(neg_scores, axis=-1)
        pos_term = paddle.logsumexp(pos_scores, axis=-1)
        return (neg_term + pos_term).mean()

    def __call__(self, logit, label):
        """Delegate to `forward` so instances can be called directly."""
        return self.forward(logit, label)
|