import re
import math
import json
import argparse
import sys
import time
import os
import random
import warnings

import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import torchtext

from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import AutoTokenizer, AutoModel

from sklearn import metrics
from sklearn.preprocessing import MultiLabelBinarizer
from torchmetrics import Precision

from sentence_transformers import SentenceTransformer
from nltk.corpus import stopwords
from nltk.tokenize import sent_tokenize, word_tokenize
import wikipediaapi

warnings.filterwarnings("ignore")


def seed_all(seed=42):
    if not seed:
        seed = 10

    print(f"[ Using Seed : {seed} ]")

    os.environ['PYTHONHASHSEED'] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # seeds every GPU, including the current one
    np.random.seed(seed)
    random.seed(seed)
    # Deterministic cuDNN trades some speed for run-to-run reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# log: tee stdout to a file while still writing to the terminal
class Logger(object):
    def __init__(self, fileN=None):
        self.terminal = sys.stdout
        self.filename = fileN

    def write(self, message):
        with open(self.filename, 'a+') as log:
            self.terminal.write(message)
            log.write(message)

    def flush(self):
        self.terminal.flush()


def clean_string(string):
    string = re.sub(r"[^A-Za-z0-9,!?]", " ", string)
    string = re.sub(r"\s{2,}", " ", string)
    # Remove stopwords; lowercase first so the case-sensitive lookup matches.
    # (Truncation to max_length happens later, in the tokenizer.)
    stop_words = set(stopwords.words('english'))
    word_tokens = word_tokenize(string.lower())
    filtered_sentence = [w for w in word_tokens if w not in stop_words]
    filtered_sentence = ' '.join(filtered_sentence)
    filtered_sentence = re.sub(r"\s{2,}", " ", filtered_sentence)
    return filtered_sentence.strip()
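# Illustrative example (assumes the NLTK 'punkt' and 'stopwords' data are
# downloaded): clean_string("The markets rallied!") -> "markets rallied !"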


def get_embedding_from_ggnews(text, ggnews):
    """Average Google News word2vec vectors over the comma-separated label words."""
    text = text.lower().replace('/', ',')
    try:
        # KeyedVectors returns numpy arrays; convert before stacking.
        embed = torch.stack([torch.tensor(ggnews[t]) for t in text.split(',')])
        return torch.mean(embed, dim=0).detach().cpu()
    except KeyError:
        # Out-of-vocabulary label word: fall back to a random 300-d vector.
        return torch.rand(300)


def get_embedding_from_glove(text):
    """Average GloVe vectors over the comma-separated label words."""
    text = text.lower().replace('/', ',')
    # torchtext caches the downloaded vectors on disk, but this still reloads
    # them into memory on every call; hoist it out of the label loop if slow.
    glove = torchtext.vocab.GloVe(name='6B', dim=300)
    embed = torch.stack([glove[t] for t in text.split(',')])
    return torch.mean(embed, dim=0)


def get_embedding_from_fasttext(text):
    """Average FastText vectors over the comma-separated label words."""
    text = text.lower().replace('/', ',')
    fasttext = torchtext.vocab.FastText('simple')
    try:
        embed = torch.stack([fasttext[t] for t in text.split(',')])
        return torch.mean(embed, dim=0)
    except KeyError:
        return torch.rand(300)  # must match the [300] slot in the feature matrix


def get_paragraph_from_wiki(text, n_sent=2):
    """Fetch up to n_sent summary sentences from Wikipedia for each label word."""
    text = text.lower()
    wiki_wiki = wikipediaapi.Wikipedia('en')
    text = text.replace('/', ',')
    all_label_paragraph = []
    for t in text.split(','):
        paragr = []
        page = wiki_wiki.page(t)
        paragraph = sent_tokenize(page.summary)
        if len(paragraph) == 0:
            # No summary found: fall back to the label word itself.
            paragr.append(t)
        else:
            paragr.extend(paragraph[:n_sent])
        all_label_paragraph.append(' '.join(paragr))
    return all_label_paragraph
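# Illustrative call (requires network access; output depends on the live
# Wikipedia summary): get_paragraph_from_wiki('economics', n_sent=1)
# -> ['Economics is a social science ...']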


def get_embedding_from_wiki(text, sbert, n_sent=2):
    all_label_paragraph = get_paragraph_from_wiki(text, n_sent)
    embedding = torch.stack(
        [sbert.encode(t, convert_to_tensor=True) for t in all_label_paragraph])
    # Dynamic weighting: softmax over each paragraph's mean activation,
    # then a weighted sum of the paragraph embeddings.
    W = embedding.sum(1) / embedding.size(1)      # [n_paragraphs]
    weights = F.softmax(W, dim=0).unsqueeze(0)    # [1, n_paragraphs]
    embedding = torch.matmul(weights, embedding)  # [1, hidden]
    # Return a flat [hidden] vector so it fits a row of the feature matrix.
    return embedding.squeeze(0)
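# Illustrative usage (requires network access; the model name matches the
# one used in PMI_create_features below):
#   sbert = SentenceTransformer('paraphrase-distilroberta-base-v1', device='cpu')
#   emb = get_embedding_from_wiki('economics', sbert)  # emb.shape == (768,)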


def PMI_create_features(mlb, embedding_type, threshold):
    # `threshold` is part of the call signature but unused here.
    label2id = {v: k for k, v in enumerate(mlb.classes_)}

    if embedding_type == 'random':
        features = torch.rand(len(label2id), 768)
    elif embedding_type == 'wiki':
        # Building wiki embeddings is slow, so reuse a cached copy if one exists.
        if os.path.exists("./embedding_type/wiki_embedding_type_train_embedding4.pt"):
            features = torch.load('./embedding_type/wiki_embedding_type_train_embedding4.pt')
        else:
            features = torch.zeros(len(label2id), 768)
            sbert = SentenceTransformer('paraphrase-distilroberta-base-v1', device='cpu')
            for label, idx in tqdm(label2id.items()):
                features[idx] = get_embedding_from_wiki(label, sbert, n_sent=10)
            torch.save(features, './embedding_type/wiki_embedding_type_train_embedding4.pt')
    else:
        features = torch.zeros(len(label2id), 300)
        if embedding_type == 'fasttext':
            for label, idx in tqdm(label2id.items()):
                features[idx] = get_embedding_from_fasttext(label)
        elif embedding_type == 'ggnews':
            import gensim
            ggnews = gensim.models.KeyedVectors.load_word2vec_format('./GoogleNews-vectors-negative300.bin',
                                                                     binary=True)
            for label, idx in tqdm(label2id.items()):
                features[idx] = get_embedding_from_ggnews(label, ggnews)
        elif embedding_type == 'glove':
            for label, idx in tqdm(label2id.items()):
                features[idx] = get_embedding_from_glove(label)
    return features
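# Resulting shapes: 'random' and 'wiki' yield [n_labels, 768] (the SBERT
# hidden size); 'fasttext', 'ggnews', and 'glove' yield [n_labels, 300].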


def read_dataset(train_path, val_path, embedding_type, threshold):
    train = json.load(open(train_path))
    val = json.load(open(val_path))

    train_sents = [clean_string(text) for text in train['text']]
    val_sents = [clean_string(text) for text in val['text']]

    # Optional exploration: inspect the token-length distribution...
    # A = np.array([len(t.split()) for t in train_sents])
    # B = np.array([len(t.split()) for t in val_sents])
    # print(np.where(A < 300)[0].shape)
    # ...and plot it:
    # plt.bar(np.arange(len(train_sents)), A)
    # plt.bar(np.arange(len(val_sents)), B)
    # plt.show()
    mlb = MultiLabelBinarizer()
    train_labels = mlb.fit_transform(train['label'])
    val_labels = mlb.transform(val['label'])

    label_features = PMI_create_features(mlb, embedding_type, threshold)

    return train_sents, train_labels, val_sents, val_labels, label_features


def load_data(args, train_path, val_path, max_length, batch_size, device, embedding_type, threshold):
    tokenizer = AutoTokenizer.from_pretrained('pre_model/' + args.pre_model_type)
    train_sents, train_labels, val_sents, val_labels, label_features = read_dataset(train_path, val_path,
                                                                                    embedding_type, threshold)
    X_train = tokenizer.batch_encode_plus(train_sents, padding=True, truncation=True, max_length=max_length,
                                          return_tensors='pt')
    y_train = torch.tensor(train_labels)
    X_val = tokenizer.batch_encode_plus(val_sents, padding=True, truncation=True, max_length=max_length,
                                        return_tensors='pt')
    y_val = torch.tensor(val_labels)

    # Indexing 'token_type_ids' assumes a BERT-style tokenizer; RoBERTa
    # tokenizers do not return this field.
    train_tensor = TensorDataset(X_train['input_ids'].to(device), X_train['attention_mask'].to(device),
                                 X_train['token_type_ids'].to(device), y_train.to(device))
    train_loader = DataLoader(train_tensor, batch_size=batch_size, shuffle=True)

    val_tensor = TensorDataset(X_val['input_ids'].to(device), X_val['attention_mask'].to(device),
                               X_val['token_type_ids'].to(device), y_val.to(device))
    val_loader = DataLoader(val_tensor, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader, label_features
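# Design note: every encoded tensor is moved to `device` up front, which is
# convenient for small corpora; for larger datasets, keep the TensorDatasets
# on the CPU and move each batch to the device inside the train/eval loops.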


class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.inf
        self.delta = delta

    def __call__(self, val_loss):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
        elif score < self.best_score + self.delta:
            # No sufficient improvement: count towards patience.
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.counter = 0


class BertAttention(nn.Module):
    def __init__(self, features, args):
        super(BertAttention, self).__init__()
        self.bert = AutoModel.from_pretrained('pre_model/' + args.pre_model_type)
        self.dropout = nn.Dropout(0.2)
        self.args = args

        self.label_features = features

        self.linear = nn.Linear(self.bert.config.hidden_size, features.size(0))  # currently unused in forward()

    def forward(self, input_ids, attention_mask, token_type_ids):
        # Pass by keyword: BERT's positional order is (input_ids,
        # attention_mask, token_type_ids).
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask,
                                token_type_ids=token_type_ids)
        bert_output_last_hidden = self.dropout(bert_output['last_hidden_state'])  # [B, 300, 768]
        bert_pooler_output = self.dropout(bert_output['pooler_output'])  # [B, 768]
        # Reshape for batched matmuls.
        bert_pooler_output = bert_pooler_output.unsqueeze(1)  # [B, 1, 768]
        dm0 = self.label_features.size(0)
        dm1 = self.label_features.size(1)
        # Use the actual batch size; the last batch may be smaller than args.batch_size.
        label_embedding = self.label_features.expand(input_ids.size(0), dm0, dm1)  # [B, 101, 768]
        # Pairwise cosine-similarity matrices.
        Label_similarity_matrix = similarity_matrix(label_embedding)  # [B, 101, 101]
        Text_similarity_matrix = similarity_matrix(bert_output_last_hidden)  # [B, 300, 300]
        # Label attention: the pooled text attends over the label embeddings.
        score = torch.bmm(bert_pooler_output, label_embedding.permute(0, 2, 1))  # [B, 1, 101]
        score = torch.bmm(score, Label_similarity_matrix)
        # Scale by sqrt(d) inside the softmax, as in scaled dot-product attention.
        score = F.softmax(score / math.sqrt(label_embedding.size(-1)), dim=-1)
        label_reps = torch.bmm(score, label_embedding)  # [B, 1, 768]
        # Text attention: the label representation attends over the token states.
        score = torch.bmm(label_reps, bert_output_last_hidden.permute(0, 2, 1))  # [B, 1, 300]
        score = torch.bmm(score, Text_similarity_matrix)
        score = F.softmax(score / math.sqrt(label_embedding.size(-1)), dim=-1)
        text_reps = torch.bmm(score, bert_output_last_hidden)  # [B, 1, 768]
        # Output layer: score each label by its dot product with the text representation.
        output = torch.bmm(text_reps, label_embedding.permute(0, 2, 1)).squeeze(1)

        return output


def similarity_matrix(features):
    """Batched pairwise cosine similarity: [B, N, D] -> [B, N, N].

    Vectorized (normalize + matmul) and computed on features.device, so it
    works for both CPU and GPU inputs without per-element Python loops.
    """
    normed = F.normalize(features, p=2, dim=-1)
    return torch.bmm(normed, normed.transpose(1, 2))
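# Sanity sketch: each row has cosine similarity 1 with itself, so the
# diagonal of the result is (approximately) all ones:
#   x = torch.randn(2, 4, 8)
#   S = similarity_matrix(x)  # S.shape == (2, 4, 4); S[b, i, i] ~= 1.0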


# Init model and optimizer & schedule
def initialize_model(label_features, device, len_trainloader, epochs, args, lr=3e-5):
    model = BertAttention(label_features.to(device), args)
    print(model)
    model.to(device)

    no_decay = ['bias', 'LayerNorm.weight']
    param_optimizer = [(name, para) for name, para in model.named_parameters() if para.requires_grad]
    optimizer_grouped_parameters = [
        {'params': [param for name, param in param_optimizer if not any(nd in name for nd in no_decay)],
         'weight_decay': 0.01},
        {'params': [param for name, param in param_optimizer if any(nd in name for nd in no_decay)],
         'weight_decay': 0.0}
    ]

    optimizer = AdamW(optimizer_grouped_parameters, lr=lr)
    n_steps = len_trainloader * epochs
    scheduler = get_linear_schedule_with_warmup(optimizer=optimizer, num_training_steps=n_steps,
                                                num_warmup_steps=100)
    criterion = nn.BCEWithLogitsLoss()

    return model, optimizer, scheduler, criterion
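# The two parameter groups above follow the standard BERT fine-tuning recipe:
# weight decay applies to most weights but is skipped for biases and LayerNorm
# parameters, matching the convention from the original BERT codebase.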


def step(model, optimizer, scheduler, criterion, batch):
    input_ids, attention_mask, token_type_ids, label = batch
    optimizer.zero_grad()
    y_pred = model(input_ids, attention_mask, token_type_ids)

    loss = criterion(y_pred, label.float())
    loss.backward()

    optimizer.step()
    scheduler.step()

    return loss.item()
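# Note: get_linear_schedule_with_warmup is a per-step scheduler, so
# scheduler.step() is called alongside optimizer.step() once per batch,
# not once per epoch.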


def validate(model, criterion, val_loader):
    print("Evaluating...")
    model.eval()
    with torch.no_grad():
        running_loss = 0.0
        pred_labels, targets = list(), list()

        for _, batch in enumerate(tqdm(val_loader)):
            input_ids, attention_mask, token_type_ids, y_true = batch
            output = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(output, y_true.float())

            running_loss += loss.item()

            pred_labels.extend(torch.sigmoid(output).detach().cpu().numpy())
            targets.extend(y_true.detach().cpu().numpy())

        val_loss = running_loss / len(val_loader)

        pred_labels, targets = np.array(pred_labels), np.array(targets)
        accuracy = metrics.accuracy_score(targets, pred_labels.round())
        micro_f1 = metrics.f1_score(targets, pred_labels.round(), average='micro')
        macro_f1 = metrics.f1_score(targets, pred_labels.round(), average='macro')

        ndcg1 = metrics.ndcg_score(targets, pred_labels, k=1)
        ndcg3 = metrics.ndcg_score(targets, pred_labels, k=3)
        ndcg5 = metrics.ndcg_score(targets, pred_labels, k=5)

        # Older torchmetrics API; recent versions expect task='multilabel'
        # and num_labels instead. Infer the label count from the targets.
        num_labels = targets.shape[1]
        p1 = Precision(num_classes=num_labels, top_k=1)(torch.tensor(pred_labels), torch.tensor(targets))
        p3 = Precision(num_classes=num_labels, top_k=3)(torch.tensor(pred_labels), torch.tensor(targets))
        p5 = Precision(num_classes=num_labels, top_k=5)(torch.tensor(pred_labels), torch.tensor(targets))

    return val_loss, accuracy, micro_f1, macro_f1, ndcg1, ndcg3, ndcg5, p1, p3, p5
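# Reading the ranking metrics: P@k is the fraction of the k highest-scoring
# labels that are true labels, averaged over samples; nDCG@k also discounts
# hits by their rank within the top k, so earlier correct labels count more.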


def train(model, optimizer, scheduler, criterion, train_loader, val_loader, checkpoint, epochs=20):
    early_stopping = EarlyStopping(delta=1e-5, patience=10)
    train_losses, val_losses, val_accs = [], [], []

    for epoch in range(epochs):
        running_loss = 0.0
        model.train()
        print("-------------Epoch: {}/{}---------------".format(epoch + 1, int(epochs)))
        for i, batch in enumerate(tqdm(train_loader, desc="Train-Batch Progress")):
            loss = step(model, optimizer, scheduler, criterion, batch)
            running_loss += loss
            # Log the running average on the first batch and every 100th batch.
            if (i + 1) % 100 == 0 or i == 0:
                print("Epoch: {} - iter: {}/{} - train_loss: {}".format(
                    epoch + 1, i + 1, len(train_loader), running_loss / (i + 1)))
        val_loss, accuracy, micro_f1, macro_f1, ndcg1, ndcg3, ndcg5, p1, p3, p5 = validate(
            model, criterion, val_loader)

        train_losses.append(running_loss / (i + 1))
        val_losses.append(val_loss)
        val_accs.append(accuracy)
        print("Val_loss: {} - Accuracy: {} - Micro-F1: {} - Macro-F1: {}".format(
            val_loss, accuracy, micro_f1, macro_f1))
        print("nDCG@1: {} - nDCG@3: {} - nDCG@5: {} - P@1: {} - P@3: {} - P@5: {}".format(
            ndcg1, ndcg3, ndcg5, p1, p3, p5))

        early_stopping(val_loss)
        if early_stopping.early_stop:
            # The log file is written after the loop, so just report and stop.
            print('Early stopping. Previous model saved in: ', checkpoint)
            break
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            # 'optimizer_state_dict': optimizer.state_dict(),
            # 'scheduler': scheduler.state_dict(),
            'val_loss': val_loss
        }, os.path.join(checkpoint, 'cp' + str(epoch + 1) + '.pt'))

    train_losses = np.array(train_losses).reshape(-1, 1)
    val_losses = np.array(val_losses).reshape(-1, 1)
    val_accs = np.array(val_accs).reshape(-1, 1)
    np.savetxt(os.path.join(checkpoint, 'log.txt'), np.hstack((train_losses, val_losses, val_accs)),
               delimiter='#')


def main():
    parser = argparse.ArgumentParser(description='RCV1 Classification')

    parser.add_argument('--model_name', type=str, default='train_embedding5', help='model name')
    parser.add_argument('--train_data', type=str, default='data/rcv1_train_data.json',
                        help='The train dataset directory.')
    parser.add_argument('--val_data', type=str, default='data/rcv1_val_data.json', help='The val dataset directory')
    parser.add_argument('--lr', type=float, default=5e-5, help='learning rate')
    parser.add_argument('--batch_size', type=int, default=2, help='batch size')
    parser.add_argument('--epochs', type=int, default=10, help='number of epochs')
    parser.add_argument('--max_length', type=int, default=300, help='max sequence length')
    parser.add_argument('--embedding_type', type=str, default='wiki',
                        help='type of the word embedding: wiki, random, glove, fasttext, ggnews')
    parser.add_argument('--checkpoint', type=str, default='checkpoint', help='check point')
    # parser.add_argument('--resume', type=int, default=0, help='resume train model from checkpoint')
    # parser.add_argument('--graph_feature', type=str, default='./data/graph_feature.pth',
    #                     help='path to feature of graph: adjacency, node feature')
    parser.add_argument('--threshold', type=float, default=0.0)
    parser.add_argument('--pre_model_type', type=str, default='bert-base-uncased',
                        help='type of the model: roberta-base, bert-base-uncased')
    parser.add_argument('--device', type=str, default='cpu', help='cuda or cpu')
    parser.add_argument('--log', type=str, default='True', help='True or False')
    args = parser.parse_args()

    # Auto-detect CUDA; note this overrides the --device flag.
    args.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    print("Representation acquisition: C. Fetch label sentences from Wikipedia, concatenate all\n"
          "sentences of each label, feed them through the BERT encoder, and pool with dynamic\n"
          "weights to obtain the label embedding.\n"
          "Label-information fusion mechanism: A")

    # log
    time_now = time.strftime('%Y-%m-%d_%H-%M-%S', time.localtime(time.time()))
    if args.log == 'True':
        sys.stdout = Logger(args.checkpoint + '/' + args.model_name + '+' + time_now + '.txt')

    seed_all()

    # read dataset
    print('reading dataset...')
    train_loader, val_loader, label_features = load_data(args=args, train_path=args.train_data,
                                                         val_path=args.val_data, max_length=args.max_length,
                                                         batch_size=args.batch_size, device=args.device,
                                                         embedding_type=args.embedding_type,
                                                         threshold=args.threshold)

    print('initialize model')
    model, optimizer, scheduler, criterion = initialize_model(label_features=label_features, device=args.device,
                                                              len_trainloader=len(train_loader), epochs=args.epochs,
                                                              args=args, lr=args.lr)

    print('training model...')
    train(model=model, optimizer=optimizer, scheduler=scheduler, criterion=criterion, train_loader=train_loader,
          val_loader=val_loader, checkpoint=args.checkpoint, epochs=args.epochs)


if __name__ == '__main__':
    main()