import pickle
import logging

import numpy as np
import torch
import torch.nn as nn
from sklearn.feature_extraction.text import TfidfVectorizer
from transformers import BertModel, BertTokenizer
from tqdm import tqdm

from modeling_nezha import NeZhaModel


class SparseEncoder(object):
    """Character-level TF-IDF encoder for entity mentions."""

    def __init__(self):
        print("Sparse encoder params: ngram_range=(1, 1)")
        self.encoder = TfidfVectorizer(analyzer='char', ngram_range=(1, 1))

    def fit(self, train_corpus):
        self.encoder.fit(train_corpus)
        # get_feature_names() was removed in scikit-learn 1.2; the vocabulary
        # size gives the same feature count on any version.
        print("Sparse encoder feature count:", len(self.encoder.vocabulary_))
        return self

    def transform(self, mentions):
        # Dense TF-IDF matrix of shape (num_mentions, vocab_size).
        vec = self.encoder.transform(mentions).toarray()
        return vec

    def __call__(self, mentions):
        return self.transform(mentions)

    def vocab(self):
        return self.encoder.vocabulary_

    def save_encoder(self, path):
        with open(path, 'wb') as fout:
            pickle.dump(self.encoder, fout)
        logging.info("Sparse encoder saved in {}".format(path))

    def load_encoder(self, path):
        with open(path, 'rb') as fin:
            self.encoder = pickle.load(fin)
        logging.info("Sparse encoder loaded from {}".format(path))
        return self
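
# Usage sketch for SparseEncoder (illustrative only; the mention strings below
# are made-up examples, not data shipped with this repo):
#
#     sparse_encoder = SparseEncoder().fit(["急性胃炎", "偏头痛"])
#     vecs = sparse_encoder(["头痛"])            # ndarray, shape (1, vocab_size)
#     sparse_encoder.save_encoder("sparse_encoder.pkl")
#     sparse_encoder = SparseEncoder().load_encoder("sparse_encoder.pkl")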


class DenseEncoder(nn.Module):
    """BERT/NeZha encoder that produces dense [CLS] embeddings for mentions."""

    def __init__(self, bert_path, bert_type="bert"):
        super(DenseEncoder, self).__init__()
        if bert_type == "bert":
            self.encoder = BertModel.from_pretrained(bert_path)
        elif bert_type == "nezha":
            self.encoder = NeZhaModel.from_pretrained(bert_path)
        else:
            raise ValueError("Unsupported bert_type: {}".format(bert_type))
        print("Dense encoder model type: {}, pretrained model path: {}".format(bert_type, bert_path))

        self.encoder.cuda()
        self.encoder.eval()  # inference only: disable dropout
        self.tokenizer = BertTokenizer.from_pretrained(bert_path)

    def transform(self, mentions, max_len, desc="bert encoding"):
        if not isinstance(mentions, list):
            mentions = mentions.tolist()

        batch_size = 512
        # Ceiling division: the old `len(mentions) // batch_size + 1` produced
        # an extra empty batch whenever the corpus size was an exact multiple.
        iter_num = (len(mentions) + batch_size - 1) // batch_size

        # Sorting mentions by length could reduce padding and speed up encoding;
        # left disabled to keep output rows aligned with the input order:
        # lens = [len(k) for k in mentions]
        # argsort_list = np.argsort(lens)
        # mentions = sorted(mentions, key=lambda x: len(x))
        dense_output = []

        with torch.no_grad():
            for i in tqdm(range(iter_num), desc=desc, ncols=100):
                batch_mentions = mentions[batch_size * i: batch_size * (i + 1)]
                if max_len != -1:
                    batch_input = self.tokenizer(
                        batch_mentions,
                        padding="max_length",
                        max_length=max_len,
                        truncation=True
                    )
                else:
                    # max_len == -1: pad each batch to its longest sequence.
                    batch_input = self.tokenizer(
                        batch_mentions,
                        padding="longest"
                    )
                for k in batch_input.keys():
                    batch_input[k] = torch.tensor(batch_input[k]).cuda()

                output = self.encoder(**batch_input)
                # [CLS] embedding: (batch, seq_len, hidden) -> (batch, hidden).
                # The original `.squeeze(1)` after this slice was a no-op.
                output = output[0][:, 0, :].cpu().numpy()

                dense_output.append(output)

        dense_output = np.concatenate(dense_output, axis=0)

        return dense_output
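
# Usage sketch for DenseEncoder (illustrative only; the checkpoint directory is
# an assumed local path, and hidden_size depends on the pretrained model):
#
#     dense_encoder = DenseEncoder(bert_path="./pretrained/bert-base-chinese", bert_type="bert")
#     embeddings = dense_encoder.transform(["急性胃炎", "偏头痛"], max_len=32)
#     embeddings.shape                           # (2, hidden_size), e.g. (2, 768)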