|
- # coding=utf-8
- # Copyright 2020 The HuggingFace Inc. team, Microsoft Corporation.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
-
- import os
- import unittest
-
- from transformers import MPNetTokenizerFast
- from transformers.models.mpnet.tokenization_mpnet import VOCAB_FILES_NAMES, MPNetTokenizer
- from transformers.testing_utils import require_tokenizers, slow
-
- from .test_tokenization_common import TokenizerTesterMixin
-
-
@require_tokenizers
class MPNetTokenizerTest(TokenizerTesterMixin, unittest.TestCase):
    """Tests for the MPNet tokenizer, combining the shared mixin checks with MPNet-specific cases.

    A small toy vocabulary is written to disk in ``setUp`` so the fast tests do
    not need a pretrained checkpoint; only ``test_sequence_builders`` (marked
    ``@slow``) loads one.
    """

    # Configuration attributes; presumably consumed by TokenizerTesterMixin
    # (defined outside this file) -- confirm against the mixin.
    tokenizer_class = MPNetTokenizer
    rust_tokenizer_class = MPNetTokenizerFast
    test_rust_tokenizer = True
    space_between_special_tokens = True

    def setUp(self):
        """Write the toy vocabulary, one token per line, to ``self.vocab_file``."""
        super().setUp()

        # The list position of each token is the id asserted in
        # test_full_tokenizer, so the order must not change.
        vocab_tokens = [
            "[UNK]",
            "[CLS]",
            "[SEP]",
            "[PAD]",
            "[MASK]",
            "want",
            "##want",
            "##ed",
            "wa",
            "un",
            "runn",
            "##ing",
            ",",
            "low",
            "lowest",
        ]
        # NOTE(review): self.tmpdirname is not set in this class; presumably
        # the super().setUp() call above creates it -- confirm.
        self.vocab_file = os.path.join(self.tmpdirname, VOCAB_FILES_NAMES["vocab_file"])
        with open(self.vocab_file, "w", encoding="utf-8") as vocab_writer:
            vocab_writer.write("".join(f"{token}\n" for token in vocab_tokens))

    def get_input_output_texts(self, tokenizer):
        """Return an ``(input_text, output_text)`` pair of sample strings.

        ``tokenizer`` is accepted but not used here.
        """
        input_text = "UNwant\u00E9d,running"
        output_text = "unwanted, running"
        return input_text, output_text

    def test_full_tokenizer(self):
        """Tokenize a sample string with the toy vocabulary and check tokens and ids."""
        tokenizer = self.tokenizer_class(self.vocab_file)

        tokens = tokenizer.tokenize("UNwant\u00E9d,running")
        self.assertListEqual(tokens, ["un", "##want", "##ed", ",", "runn", "##ing"])
        # Expected ids are the tokens' positions in the vocab list from setUp.
        self.assertListEqual(tokenizer.convert_tokens_to_ids(tokens), [9, 6, 7, 12, 10, 11])

    @slow
    def test_sequence_builders(self):
        """Check where special-token ids are placed around single and paired sequences."""
        tokenizer = self.tokenizer_class.from_pretrained("microsoft/mpnet-base")

        text = tokenizer.encode("sequence builders", add_special_tokens=False)
        text_2 = tokenizer.encode("multi-sequence build", add_special_tokens=False)

        encoded_sentence = tokenizer.build_inputs_with_special_tokens(text)
        encoded_pair = tokenizer.build_inputs_with_special_tokens(text, text_2)

        # NOTE(review): 0 and 2 are presumably the start and separator token
        # ids of the pretrained vocabulary -- confirm against the checkpoint.
        # unittest assertions are used instead of bare `assert` so the checks
        # still run under `python -O` and report a list diff on failure.
        self.assertListEqual(encoded_sentence, [0] + text + [2])
        self.assertListEqual(encoded_pair, [0] + text + [2] + [2] + text_2 + [2])
|