# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import collections
import json
import os
import random
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

import numpy as np
import paddle
from tqdm import tqdm

from paddlenlp.transformers import ErnieTokenizer
from paddlenlp.utils.log import logger

from extract_chinese_and_punct import ChineseAndPunctuationExtractor

InputFeature = collections.namedtuple("InputFeature", [
    "input_ids", "seq_len", "tok_to_orig_start_index", "tok_to_orig_end_index", "subject_labels", "object_labels"
])
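# Field reference for InputFeature (descriptive comments added for clarity):
#   input_ids:               token ids including [CLS]/[SEP]/[PAD], padded to max_length
#   seq_len:                 number of real (unpadded) tokens, excluding [CLS]/[SEP]
#   tok_to_orig_start_index: per token, the start char offset in the raw text (-1 for special tokens)
#   tok_to_orig_end_index:   per token, the end char offset in the raw text (-1 for special tokens)
#   subject_labels:          per-token multi-hot label vectors for subject (head) entities
#   object_labels:           per-token multi-hot label vectors for object (tail) entities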


def parse_label(spo_list, label_map, tokens, tokenizer):
    # Label scheme: one "O" tag plus a "B" and an "I" tag per relation;
    # the same scheme is shared by subject (head) and object (tail) entities.
    num_labels = 2 * (len(label_map.keys()) - 2) + 1
    seq_len = len(tokens)
    # initialize tags
    subject_labels = [[0] * num_labels for i in range(seq_len)]
    object_labels = [[0] * num_labels for i in range(seq_len)]
    # find all entities and tag them with the corresponding "B"/"I" labels
    for spo in spo_list:
        # Assign the relation label. In predicate2id.json the first relation id
        # is 2, so index 0 is reserved for "O", indices 1..rel_num hold the "B"
        # (entity head) tags, and indices rel_num+1..2*rel_num hold the "I" tags.
        label_predicate = label_map[spo['predicate']] - 1
        subject_tokens = tokenizer._tokenize(spo['subject'])
        object_tokens = tokenizer._tokenize(spo['object'])

        subject_tokens_len = len(subject_tokens)
        object_tokens_len = len(object_tokens)

        # Tag the longer entity first and remember where it matched, so that the
        # shorter entity is never tagged inside the longer entity's span.
        forbidden_index = None
        if subject_tokens_len > object_tokens_len:
            for index in range(seq_len - subject_tokens_len + 1):
                if tokens[index:index + subject_tokens_len] == subject_tokens:
                    subject_labels[index][label_predicate] = 1
                    for i in range(subject_tokens_len - 1):
                        subject_labels[index + i + 1][label_predicate + len(label_map.keys()) - 2] = 1
                    forbidden_index = index
                    break

            for index in range(seq_len - object_tokens_len + 1):
                if tokens[index:index + object_tokens_len] == object_tokens:
                    if forbidden_index is None:
                        object_labels[index][label_predicate] = 1
                        for i in range(object_tokens_len - 1):
                            object_labels[index + i + 1][label_predicate + len(label_map.keys()) - 2] = 1
                        break
                    # accept the match only if it does not overlap the subject span
                    elif index < forbidden_index or index >= forbidden_index + len(subject_tokens):
                        object_labels[index][label_predicate] = 1
                        for i in range(object_tokens_len - 1):
                            object_labels[index + i + 1][label_predicate + len(label_map.keys()) - 2] = 1
                        break
        else:
            for index in range(seq_len - object_tokens_len + 1):
                if tokens[index:index + object_tokens_len] == object_tokens:
                    object_labels[index][label_predicate] = 1
                    for i in range(object_tokens_len - 1):
                        object_labels[index + i + 1][label_predicate + len(label_map.keys()) - 2] = 1
                    forbidden_index = index
                    break

            for index in range(seq_len - subject_tokens_len + 1):
                if tokens[index:index + subject_tokens_len] == subject_tokens:
                    if forbidden_index is None:
                        subject_labels[index][label_predicate] = 1
                        for i in range(subject_tokens_len - 1):
                            subject_labels[index + i + 1][label_predicate + len(label_map.keys()) - 2] = 1
                        break
                    # accept the match only if it does not overlap the object span
                    elif index < forbidden_index or index >= forbidden_index + len(object_tokens):
                        subject_labels[index][label_predicate] = 1
                        for i in range(subject_tokens_len - 1):
                            subject_labels[index + i + 1][label_predicate + len(label_map.keys()) - 2] = 1
                        break

    # if a token was not assigned any "B"/"I" tag, mark it as "O" (outside)
    for i in range(seq_len):
        if subject_labels[i] == [0] * num_labels:
            subject_labels[i][0] = 1
        if object_labels[i] == [0] * num_labels:
            object_labels[i][0] = 1

    return subject_labels, object_labels
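
# A small worked example of the label layout (illustrative numbers, not taken
# from the dataset): suppose predicate2id.json holds 3 predicates with ids
# 2, 3, 4 plus the two reserved entries with ids 0 and 1, so len(label_map) = 5.
# Then rel_num = 5 - 2 = 3 and num_labels = 2 * 3 + 1 = 7: index 0 is "O",
# indices 1-3 are the "B" tags, and indices 4-6 are the "I" tags. The first
# token of a subject of predicate id 3 gets index 2 set (3 - 1), and its
# remaining tokens get index 5 set (2 + rel_num).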


def convert_example_to_feature(
        example,
        tokenizer,
        chineseandpunctuationextractor: ChineseAndPunctuationExtractor,
        label_map,
        max_length: Optional[int] = 512,
        pad_to_max_length: Optional[bool] = None):
    # note: pad_to_max_length is accepted for API compatibility but is unused;
    # sequences shorter than max_length are always padded below
    spo_list = example.get('spo_list')
    text_raw = example['text']

    # Split the raw text into single Chinese characters / punctuation marks and
    # contiguous runs of other characters (e.g., Latin words, digits).
    sub_text = []
    buff = ""
    for char in text_raw:
        if chineseandpunctuationextractor.is_chinese_or_punct(char):
            if buff != "":
                sub_text.append(buff)
                buff = ""
            sub_text.append(char)
        else:
            buff += char
    if buff != "":
        sub_text.append(buff)

    # Tokenize each chunk and record, for every sub-token, the start/end char
    # offsets of its source chunk in the raw text; stop once the sequence fills
    # max_length, leaving room for [CLS] and [SEP].
    tok_to_orig_start_index = []
    tok_to_orig_end_index = []
    orig_to_tok_index = []
    tokens = []
    text_tmp = ''
    for (i, token) in enumerate(sub_text):
        orig_to_tok_index.append(len(tokens))
        sub_tokens = tokenizer._tokenize(token)
        text_tmp += token
        for sub_token in sub_tokens:
            tok_to_orig_start_index.append(len(text_tmp) - len(token))
            tok_to_orig_end_index.append(len(text_tmp) - 1)
            tokens.append(sub_token)
            if len(tokens) >= max_length - 2:
                break
        else:
            continue
        break  # the inner loop hit the length limit, so stop the outer loop too

    seq_len = len(tokens)
    # Label scheme: one "O" tag plus a "B" and an "I" tag per relation, shared
    # by subject (head) and object (tail) entities.
    num_labels = 2 * (len(label_map.keys()) - 2) + 1
    # initialize tags
    subject_labels = [[0] * num_labels for i in range(seq_len)]
    object_labels = [[0] * num_labels for i in range(seq_len)]
    if spo_list is not None:
        subject_labels, object_labels = parse_label(spo_list, label_map, tokens, tokenizer)

    # truncate to max_length - 2 to leave room for [CLS] and [SEP]
    if seq_len > max_length - 2:
        tokens = tokens[0:(max_length - 2)]
        subject_labels = subject_labels[0:(max_length - 2)]
        object_labels = object_labels[0:(max_length - 2)]
        tok_to_orig_start_index = tok_to_orig_start_index[0:(max_length - 2)]
        tok_to_orig_end_index = tok_to_orig_end_index[0:(max_length - 2)]

    # add the [CLS] and [SEP] tokens; both are tagged "O"
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    # "O" tag for the [PAD], [CLS] and [SEP] tokens
    outside_label = [[1] + [0] * (num_labels - 1)]

    subject_labels = outside_label + subject_labels + outside_label
    object_labels = outside_label + object_labels + outside_label

    tok_to_orig_start_index = [-1] + tok_to_orig_start_index + [-1]
    tok_to_orig_end_index = [-1] + tok_to_orig_end_index + [-1]
    if seq_len < max_length:
        # pad everything out to max_length ([CLS]/[SEP] already account for 2)
        tokens = tokens + ["[PAD]"] * (max_length - seq_len - 2)
        subject_labels = subject_labels + outside_label * (max_length - len(subject_labels))
        object_labels = object_labels + outside_label * (max_length - len(object_labels))

        tok_to_orig_start_index = tok_to_orig_start_index + [-1] * (max_length - len(tok_to_orig_start_index))
        tok_to_orig_end_index = tok_to_orig_end_index + [-1] * (max_length - len(tok_to_orig_end_index))

    token_ids = tokenizer.convert_tokens_to_ids(tokens)

    return InputFeature(
        input_ids=np.array(token_ids),
        seq_len=np.array(seq_len),
        tok_to_orig_start_index=np.array(tok_to_orig_start_index),
        tok_to_orig_end_index=np.array(tok_to_orig_end_index),
        subject_labels=np.array(subject_labels),
        object_labels=np.array(object_labels))
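
# A minimal usage sketch of convert_example_to_feature (the checkpoint name,
# toy label_map, and example are illustrative assumptions, not taken from this
# module):
#
#   tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
#   extractor = ChineseAndPunctuationExtractor()
#   label_map = {"O": 0, "I": 1, "作者": 2}   # toy predicate2id mapping
#   example = {"text": "《三体》的作者是刘慈欣",
#              "spo_list": [{"predicate": "作者", "subject": "三体", "object": "刘慈欣"}]}
#   feature = convert_example_to_feature(example, tokenizer, extractor, label_map, 32)
#   # feature.input_ids.shape == (32,); feature.subject_labels.shape == (32, 3),
#   # since num_labels = 2 * (3 - 2) + 1 = 3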


class DuIEDataset(paddle.io.Dataset):
    """
    Dataset for the DuIE relation-extraction task. Each item holds the padded
    token ids, the unpadded sequence length, token-to-character offset maps,
    and the per-token subject/object label vectors.
    """

    def __init__(
            self,
            input_ids: List[Union[List[int], np.ndarray]],
            seq_lens: List[Union[List[int], np.ndarray]],
            tok_to_orig_start_index: List[Union[List[int], np.ndarray]],
            tok_to_orig_end_index: List[Union[List[int], np.ndarray]],
            subject_labels: List[Union[List[int], np.ndarray, List[str], List[Dict]]],
            object_labels: List[Union[List[int], np.ndarray, List[str], List[Dict]]]):
        super().__init__()

        self.input_ids = input_ids
        self.seq_lens = seq_lens
        self.tok_to_orig_start_index = tok_to_orig_start_index
        self.tok_to_orig_end_index = tok_to_orig_end_index
        self.subject_labels = subject_labels
        self.object_labels = object_labels

    def __len__(self):
        if isinstance(self.input_ids, np.ndarray):
            return self.input_ids.shape[0]
        else:
            return len(self.input_ids)

    def __getitem__(self, item):
        return {
            "input_ids": np.array(self.input_ids[item]),
            "seq_lens": np.array(self.seq_lens[item]),
            "tok_to_orig_start_index": np.array(self.tok_to_orig_start_index[item]),
            "tok_to_orig_end_index": np.array(self.tok_to_orig_end_index[item]),
            # labels are cast to float32 so they can feed a multi-label loss
            "subject_labels": np.array(self.subject_labels[item], dtype=np.float32),
            "object_labels": np.array(self.object_labels[item], dtype=np.float32),
        }

    @classmethod
    def from_file(cls,
                  file_path: Union[str, os.PathLike],
                  tokenizer,
                  args,
                  pad_to_max_length: Optional[bool] = None):
        assert os.path.exists(file_path) and os.path.isfile(file_path), f"{file_path} does not exist or is not a file."
        label_map_path = os.path.join(os.path.dirname(file_path), "predicate2id.json")
        assert os.path.exists(label_map_path) and os.path.isfile(label_map_path), f"{label_map_path} does not exist or is not a file."

        # predicate2id.json maps each predicate to an integer id starting at 2,
        # with ids 0 and 1 reserved (see parse_label above)
        with open(label_map_path, 'r', encoding='utf-8') as fp:
            label_map = json.load(fp)
        chineseandpunctuationextractor = ChineseAndPunctuationExtractor()

        input_ids, seq_lens, tok_to_orig_start_index, tok_to_orig_end_index, subject_labels, object_labels = ([] for _ in range(6))
        logger.info("Preprocessing data, loaded from %s" % file_path)

        with open(file_path, "r", encoding="utf-8") as fp:
            if args.envir == "local" and args.run_mode == "train":
                # local debug mode: only load the first 8 examples
                lines = fp.readlines()[:8]
            else:
                lines = fp.readlines()

            for line in tqdm(lines, desc="Loading dataset: " + file_path):
                example = json.loads(line)
                input_feature = convert_example_to_feature(example, tokenizer, chineseandpunctuationextractor, label_map, args.max_seq_length, pad_to_max_length)

                input_ids.append(input_feature.input_ids)
                seq_lens.append(input_feature.seq_len)
                tok_to_orig_start_index.append(input_feature.tok_to_orig_start_index)
                tok_to_orig_end_index.append(input_feature.tok_to_orig_end_index)
                subject_labels.append(input_feature.subject_labels)
                object_labels.append(input_feature.object_labels)

        return cls(input_ids, seq_lens, tok_to_orig_start_index, tok_to_orig_end_index, subject_labels, object_labels)


@dataclass
class DataCollator:
    """Stacks a list of per-example dicts into a tuple of batched arrays."""

    def __call__(self, examples: List[Dict[str, Union[list, np.ndarray]]]):
        batched_input_ids = np.stack([x['input_ids'] for x in examples])
        seq_lens = np.stack([x['seq_lens'] for x in examples])
        tok_to_orig_start_index = np.stack([x['tok_to_orig_start_index'] for x in examples])
        tok_to_orig_end_index = np.stack([x['tok_to_orig_end_index'] for x in examples])
        subject_labels = np.stack([x['subject_labels'] for x in examples])
        object_labels = np.stack([x['object_labels'] for x in examples])

        return (batched_input_ids, seq_lens, tok_to_orig_start_index, tok_to_orig_end_index, subject_labels, object_labels)
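

if __name__ == "__main__":
    # A minimal end-to-end sketch of how this module is typically wired up.
    # The data path, checkpoint name, and args fields are assumptions for
    # illustration only; they are not defined by this module.
    from types import SimpleNamespace

    from paddle.io import DataLoader

    tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")
    # envir="local" + run_mode="train" triggers the 8-example debug subset
    args = SimpleNamespace(envir="local", run_mode="train", max_seq_length=128)

    # expects ./data/train.json with ./data/predicate2id.json next to it
    dataset = DuIEDataset.from_file("./data/train.json", tokenizer, args)
    loader = DataLoader(dataset=dataset, batch_size=4, collate_fn=DataCollator())
    for batch in loader:
        input_ids, seq_lens, _, _, subject_labels, object_labels = batch
        print(input_ids.shape, subject_labels.shape)
        break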