|
- # coding=utf-8
- # Copyright 2018 The Google AI Language Team Authors and The HugginFace Inc. team.
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """BERT finetuning runner."""
- # best acc: 0.680999180999181, loss 0.11498013142432276
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
-
- import csv
- import os
- import logging
- import json
- import argparse
- import random
- from tqdm import tqdm, trange
-
- import numpy as np
- import torch
- from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
- from torch.utils.data.distributed import DistributedSampler
- import re
- from transformers import BertTokenizer
- # from tokenization import BertTokenizer
- from optimization import BertAdam
- from file_utils import PYTORCH_PRETRAINED_BERT_CACHE
- from modeling import BertMultiwayMatch, BertVaillna, BertMultiwayMatch_V0215, Bert_MY
-
- logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
- datefmt='%m/%d/%Y %H:%M:%S',
- level=logging.INFO)
- logger = logging.getLogger(__name__)
-
- os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
- # os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"
-
def strQ2B(ustring):
    """Convert full-width (SBC) characters to their half-width (DBC) forms.

    Iterates over the elements of *ustring* and converts each one character
    by character; returns a list of the converted strings.
    """
    converted = []
    for piece in ustring:
        chars = []
        for ch in piece:
            code = ord(ch)
            if code == 12288:  # full-width space -> ASCII space
                code = 32
            elif 65281 <= code <= 65374:  # full-width forms block
                code -= 65248
            chars.append(chr(code))
        converted.append("".join(chars))
    return converted
-
def nomalize(text):
    """Clean a question/option string from the exam dataset.

    Strips placeholder blanks ("( )。" variants), image markers, HTML-ish
    tags, and several hard-coded dataset-specific artifacts, then collapses
    runs of spaces.  NOTE(review): the (misspelled) name is kept because
    callers in this file use it.
    """
    # NOTE(review): several "( )" patterns below look identical but may
    # differ only in invisible/full-width whitespace — do not deduplicate
    # them without checking the raw bytes.
    text = text.replace("( )。[图1]","").replace("( )。", "").replace("( )","")\
        .replace("image.png( )。", "").replace("( )。", "").replace("( )。","").replace("( )。","").replace("( )", "")\
        .replace("[图1]", "").replace("[图2]", "").replace("( )","").replace("<br>", "").replace("([图1])","").replace("</sub>", "")\
        .replace("<sub>", "").replace("</br>","").replace("<sup>", "").replace("</sup>", "")\
        .replace("\n", " ").replace(")。","").replace("℃","度")\
        .replace("“", "\"").replace("”", "\"").replace("image.png", "")\
        .replace("A.血小板减少B.急性溶血性贫血C.红人综合征D.肌肉震颤E.抗生素相关性腹泻","")\
        .replace("A.还原型谷胱甘肽B.葡醛内酯C.熊去氧胆酸D.多烯磷脂酰胆碱E.甘草酸二铵", "")\
        .replace("A.大环内酯类B.四环素类C.氟喹诺酮类D.酰胺醇类(氯霉素)E.氨基糖苷类", "")\
        .replace("A.卡铂B.氟尿嘧啶C.吉非替尼D.顺铂E.奥沙利铂", "")\
        .replace("A.甲羟孕酮B.左炔诺孕酮C.环丙孕酮D.屈螺酮E.地屈孕酮", "")\
        .replace("A.胶体果胶铋B.复方碳酸钙C.奥美拉唑D.法莫替丁E.替普瑞酮", "")\
        .replace("维C银翅片", "维C银翘片").replace("[67-68]", "")\
        .replace("患者,女,63岁,因心绞痛急诊就医,诊断为ST段抬高心肌梗死。血压","患者,女,63岁,因心绞痛急诊就医,诊断为ST段抬高心肌梗死,血压175/95mmHg,心率110次/分,该患者适宜使用的药物是")\
        .replace("\t", " ")
    # Collapse any remaining runs of spaces into one.
    text = re.sub(" +", " ", text)
    # text = ''.join(strQ2B(text))

    return text.strip()
-
-
-
- # def nomalize(text):
- # text = text.replace("( )。[图1]","").replace("( )。", "").replace("( )","")\
- # .replace("image.png( )。", "").replace("( )。", "").replace("( )。","").replace("( )。","").replace("( )", "")\
- # .replace("[图1]", "").replace("[图2]", "").replace("( )","").replace("<br>", "").replace("([图1])","").replace("</sub>", "")\
- # .replace("<sub>", "").replace("</br>","").replace("<sup>", "").replace("</sup>", "").replace(" ","").strip()
- # return text
class SwagExample(object):
    """A single training/test example for the SWAG-style multiple-choice task.

    Holds the background passage, the question stem, five candidate endings
    (options), one knowledge-graph text per option, and an optional label.
    """

    def __init__(self,
                 swag_id,
                 context_sentence,
                 start_ending,
                 ending_0,
                 ending_1,
                 ending_2,
                 ending_3,
                 ending_4,
                 kg_text_0,
                 kg_text_1,
                 kg_text_2,
                 kg_text_3,
                 kg_text_4,
                 label=None):
        self.swag_id = swag_id
        self.context_sentence = context_sentence
        self.start_ending = start_ending
        self.endings = [ending_0, ending_1, ending_2, ending_3, ending_4]
        self.kg_texts = [kg_text_0, kg_text_1, kg_text_2, kg_text_3, kg_text_4]
        self.label = label

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        parts = [
            "swag_id: {}".format(self.swag_id),
            "context_sentence: {}".format(self.context_sentence),
            "start_ending: {}".format(self.start_ending),
        ]
        parts.extend("ending_{}: {}".format(i, e) for i, e in enumerate(self.endings))
        parts.extend("kg_text_{}: {}".format(i, k) for i, k in enumerate(self.kg_texts))
        if self.label is not None:
            parts.append("label: {}".format(self.label))
        return ", ".join(parts)
-
-
class InputExample(object):
    """A single training/test example for simple sequence classification."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        """Build an example.

        Args:
            guid: Unique id for the example.
            text_a: string. Untokenized text of the first sequence; the only
                required sequence for single-sequence tasks.
            text_b: (Optional) string. Untokenized text of the second
                sequence, needed only for sequence-pair tasks.
            label: (Optional) string. Label of the example; set for train and
                dev examples, left as None for test examples.
        """
        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.label = label
-
-
class InputExampleWithListFourFields(object):
    """A single training/test example for simple multiple choice classification."""

    def __init__(self, guid, text_a, text_b, text_c, text_d, label=None):
        """Build an example.

        Args:
            guid: Unique id for the example.
            text_a: list of untokenized texts.
            text_b: list, same length as text_a.
            text_c: list or None; when a list, same length as text_a.
            text_d: list or None; when a list, same length as text_a.
            label: (Optional) string label; set for train/dev, None for test.
        """
        # Validate shapes up front so downstream code can assume alignment.
        assert isinstance(text_a, list)
        assert isinstance(text_b, list)
        assert text_c is None or isinstance(text_c, list)
        assert text_d is None or isinstance(text_d, list)
        assert len(text_b) == len(text_a)
        for optional in (text_c, text_d):
            if optional is not None:
                assert len(optional) == len(text_a)

        self.guid = guid
        self.text_a = text_a
        self.text_b = text_b
        self.text_c = text_c
        self.text_d = text_d
        self.label = label
-
-
class InputFeatures(object):
    """Per-example feature container: one dict of tensors per answer choice.

    Each entry of *choices_features* is an 11-tuple whose first element (the
    raw token list) is intentionally dropped here.
    """

    def __init__(self, example_id, choices_features, label):
        self.example_id = example_id
        self.choices_features = []
        for (_, input_ids, input_mask, segment_ids, doc_len, kg_len,
             ques_len, option_len, kg_text_ids, kg_text_masks,
             kg_text_segment_ids) in choices_features:
            self.choices_features.append({
                'input_ids': input_ids,
                'input_mask': input_mask,
                'segment_ids': segment_ids,
                'doc_len': doc_len,
                'kg_len': kg_len,
                'ques_len': ques_len,
                'option_len': option_len,
                'kg_text_ids': kg_text_ids,
                "kg_text_masks": kg_text_masks,
                'kg_text_segment_ids': kg_text_segment_ids,
            })
        self.label = label
-
-
class DataProcessor(object):
    """Base class for data converters for sequence classification data sets."""

    def get_train_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the train set."""
        raise NotImplementedError()

    def get_dev_examples(self, data_dir):
        """Gets a collection of `InputExample`s for the dev set."""
        raise NotImplementedError()

    def get_labels(self):
        """Gets the list of labels for this data set."""
        raise NotImplementedError()

    @classmethod
    def _read_csv(cls, input_file):
        """Reads a csv file into a list of per-row dicts keyed by header."""
        lines = []
        # BUGFIX: open with an explicit utf-8 encoding, matching _read_tsv.
        # Without it the platform default (e.g. cp1252 on Windows) breaks on
        # the Chinese text this script processes.
        with open(input_file, 'r', encoding='utf-8') as csv_file:
            reader = csv.DictReader(csv_file)
            for row in reader:
                lines.append(row)
        return lines

    @classmethod
    def _read_tsv(cls, input_file, quotechar=None):
        """Reads a tab separated value file into a list of row lists."""
        with open(input_file, "r", encoding='utf-8') as f:
            reader = csv.reader(f, delimiter="\t", quotechar=quotechar)
            lines = []
            for line in reader:
                lines.append(line)
            return lines
-
-
class nlpecProcessor(DataProcessor):
    """Processor for the NLPEC (Chinese medical exam multiple-choice) data set."""

    def get_train_examples(self, data_dir):
        """See base class."""
        '''
        all_not_include_test_evidence_top_10_entities_score.json
        '''
        # train_filter_evidence.json
        # NOTE(review): this reads 'test.json' rather than a train split —
        # looks like a leftover from an experiment; confirm before training.
        return self.read_baidu_exmaple(
            os.path.join(data_dir, 'test.json'), range1=False)

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self.read_baidu_exmaple(os.path.join(data_dir, 'dev.json'))

    def get_test_examples(self, data_dir):
        """See base class."""
        return self.read_baidu_exmaple(os.path.join(data_dir, 'test.json'))

    def get_examples_from_file(self, input_file):
        # NOTE(review): _create_examples is not defined on this class or its
        # visible base; this path raises AttributeError if ever used.
        return self._create_examples(
            self._read_csv(input_file), "to-pred")

    def get_labels(self):
        """Five answer options A-E, encoded as strings "0".."4"."""
        return ["0", "1", "2", "3", "4"]

    def read_baidu_exmaple(self, input_file, have_answer=True, range1=False):
        """Read a JSON-lines file and build one SwagExample per question.

        Args:
            input_file: path to a file with one JSON object per line; each
                object carries 'option', 'context', 'questionText',
                'backgroundText' and (when have_answer) 'answer'.
            have_answer: when True, derive the 0-based label from the first
                character of the 'answer' field; otherwise label is 0.
            range1: when True, take kg_text from a global `kgs` table.
                NOTE(review): `kgs` is not defined in this module, so that
                branch raises NameError unless injected elsewhere.

        Returns:
            A list of SwagExample objects; malformed questions (not exactly
            five options, or any empty option/context) are skipped.
        """
        examples = []
        with open(input_file, "r", encoding="utf-8") as f:
            for i, line in enumerate(f.readlines()):
                try:
                    csqa_json = json.loads(line.strip())
                except ValueError:  # json.JSONDecodeError subclasses ValueError
                    print(line)
                    # NOTE(review): exits with *success* status on a parse
                    # error — kept to preserve existing behavior.
                    exit(0)

                # Keep only well-formed five-option questions.
                if len(csqa_json["option"]) != 5:
                    continue
                if any(csqa_json['option'][j].strip() == "" or
                       csqa_json['context'][j].strip() == ""
                       for j in range(5)):
                    continue

                csqa_json['score_kg_option'] = [0, 0, 0, 0, 0]
                csqa_json['kg_text'] = ["", "", "", "", ""]
                if range1:
                    kg_text = kgs[i]
                else:
                    kg_text = csqa_json['kg_text']

                assert len(kg_text) == 5

                # 'answer' is like "3" / "35"; only the first choice is used.
                label = int(csqa_json["answer"][0]) - 1 if have_answer else 0

                for jj in range(5):
                    # Keep only the first "######"-separated evidence segment.
                    csqa_json['context'][jj] = '#'.join(
                        csqa_json['context'][jj].split("######")[0:1])

                    # Keep at most the first three KG sentences.
                    kg_text[jj] = ";".join(kg_text[jj].split(";")[0:3])
                    if kg_text[jj].strip() == "":
                        kg_text[jj] = "无。"
                    elif not (kg_text[jj].endswith(".") or kg_text[jj].endswith("。")):
                        # BUGFIX: the original condition used `or` between two
                        # `!=` tests, which is always true, so "。" was
                        # appended even after an existing period.
                        kg_text[jj] += "。"

                csqa_json['questionText'] = nomalize(csqa_json['questionText'])

                examples.append(
                    SwagExample(
                        swag_id=i,
                        context_sentence=csqa_json['backgroundText'],
                        start_ending=csqa_json['questionText'].strip() + "#",
                        ending_0=csqa_json["option"][0].strip() + "###" + csqa_json['context'][0].strip(),
                        ending_1=csqa_json["option"][1].strip() + "###" + csqa_json['context'][1].strip(),
                        ending_2=csqa_json["option"][2].strip() + "###" + csqa_json['context'][2].strip(),
                        ending_3=csqa_json["option"][3].strip() + "###" + csqa_json['context'][3].strip(),
                        ending_4=csqa_json["option"][4].strip() + "###" + csqa_json['context'][4].strip(),
                        label=label,
                        kg_text_0=kg_text[0],
                        kg_text_1=kg_text[1],
                        kg_text_2=kg_text[2],
                        kg_text_3=kg_text[3],
                        kg_text_4=kg_text[4]
                    ))

        return examples

    def label_field(self):
        return "AnswerRightEnding"
-
-
def convert_examples_to_features_c_qa(examples, tokenizer, max_seq_length,
                                      is_training):
    """Loads a data file into a list of `InputBatch`s.

    For each SwagExample, every one of the five options is encoded as one
    BERT input of the form
        [CLS] evidence [SEP] context question option [SEP]
    and the per-option KG sentence is encoded separately as
        [CLS] kg_text [SEP].
    Returns a list of InputFeatures (one per example, five choices each).
    """

    # Swag is a multiple choice task. To perform this task using Bert,
    # we will use the formatting proposed in "Improving Language
    # Understanding by Generative Pre-Training" and suggested by
    # @jacobdevlin-google in this issue
    # https://github.com/google-research/bert/issues/38.
    #
    # Each choice will correspond to a sample on which we run the
    # inference. For a given Swag example, we will create the 4
    # following inputs:
    # - [CLS] context [SEP] choice_1 [SEP]
    # - [CLS] context [SEP] choice_2 [SEP]
    # - [CLS] context [SEP] choice_3 [SEP]
    # - [CLS] context [SEP] choice_4 [SEP]
    # The model will output a single value for each input. To get the
    # final decision of the model, we will run a softmax over these 4
    # outputs.
    features = []
    for example_index, example in enumerate(examples):
        context_tokens = tokenizer.tokenize(example.context_sentence)
        start_ending_tokens = tokenizer.tokenize(example.start_ending)

        choices_features = []

        for ending_index, ending in enumerate(example.endings):
            # We create a copy of the context tokens in order to be
            # able to shrink it according to ending_tokens
            '''
            ending 选项
            context_sentence 材料
            start_ending 问题
            opt_context 证据
            '''
            # Translation of the note above: ending = option text,
            # context_sentence = background material, start_ending = question,
            # opt_context = retrieved evidence.
            # NOTE(review): assumes exactly one "###" separator in `ending`
            # (as built by read_baidu_exmaple); more would raise ValueError.
            ending, opt_context = ending.split("###")
            context_tokens_choice = context_tokens[:]
            opt_context_tokens = tokenizer.tokenize(opt_context)
            kg_text = example.kg_texts[ending_index]
            kg_text_tokens = tokenizer.tokenize(kg_text)
            evidence_tokens = opt_context_tokens  # + ["#"] + kg_text_tokens
            ending_tokens = tokenizer.tokenize(ending)
            option_len = len(ending_tokens)
            ques_len = len(start_ending_tokens)
            ending_tokens1 = start_ending_tokens + ending_tokens
            ending_tokens = ending_tokens1
            # Pre-truncate the background so question+option+context use at
            # most roughly half of the token budget.
            # NOTE(review): if ques_len + option_len exceeds
            # (max_seq_length - 3) // 2 the slice bound goes negative and this
            # silently drops tokens from the *end* instead of clamping at 0 —
            # confirm inputs are short enough that this cannot happen.
            context_tokens_choice = context_tokens_choice[:(int((max_seq_length - 3)/ 2) - ques_len - option_len)]
            # Modifies `context_tokens_choice` and `ending_tokens` in
            # place so that the total length is less than the
            # specified length. Account for [CLS], [SEP], [SEP] with "- 3"
            # ending_tokens = start_ending_tokens + ending_tokens
            # NOTE(review): the second argument is a *temporary* concatenated
            # list, so only the truncation of `evidence_tokens` persists; any
            # tokens popped from the temporary list are silently lost, and the
            # length asserts below are what actually guard against overflow.
            _truncate_seq_pair(evidence_tokens, context_tokens_choice + ending_tokens, max_seq_length - 3)
            doc_len = len(evidence_tokens)

            # ques_len = len(context_tokens_choice) + len(ending_tokens) - option_len
            tokens = ["[CLS]"] + evidence_tokens + ["[SEP]"] + context_tokens_choice + ending_tokens + ["[SEP]"]
            # Segment 0 covers [CLS] evidence [SEP]; segment 1 the rest.
            segment_ids = [0] * (len(evidence_tokens) + 2) + [1] * (len(context_tokens_choice) + len(ending_tokens) + 1)

            input_ids = tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(input_ids)

            # Zero-pad up to the sequence length.
            padding = [0] * (max_seq_length - len(input_ids))
            input_ids += padding
            input_mask += padding
            segment_ids += padding

            # KG text gets its own single-segment sequence: [CLS] kg [SEP].
            kg_text_tokens = kg_text_tokens[:max_seq_length - 2]

            # _truncate_seq_pair(kg_text_tokens, context_tokens_choice + ending_tokens, max_seq_length - 3)
            kg_text_segment_ids = [0] * (len(kg_text_tokens) + 2) #+ [1] * (len(context_tokens_choice) + len(ending_tokens) + 1)
            kg_len = len(kg_text_tokens)
            kg_text_tokens = ["[CLS]"] + kg_text_tokens + ["[SEP]"]

            kg_text_input_ids = tokenizer.convert_tokens_to_ids(kg_text_tokens)
            kg_text_input_mask = [1] * len(kg_text_input_ids)

            # Zero-pad up to the sequence length.
            kg_text_padding = [0] * (max_seq_length - len(kg_text_input_ids))
            kg_text_input_ids += kg_text_padding
            kg_text_input_mask += kg_text_padding
            kg_text_segment_ids += kg_text_padding

            # Sanity checks.  NOTE(review): the `print(...)` expressions used
            # as assert messages evaluate only when the assertion fails; they
            # print the diagnostic and the AssertionError message is None.
            assert len(input_ids) == max_seq_length
            assert len(input_mask) == max_seq_length
            assert len(segment_ids) == max_seq_length
            assert ques_len > 0
            assert option_len > 0
            assert doc_len > 0, print(tokens)
            assert kg_len > 0
            assert len(kg_text_input_ids) == max_seq_length
            assert len(kg_text_input_mask) == max_seq_length
            assert len(kg_text_segment_ids) == max_seq_length, print(len(kg_text_segment_ids))

            assert (doc_len + ques_len + option_len) <= max_seq_length, print(doc_len, ques_len, option_len,
                                                                              example.start_ending)

            choices_features.append((tokens, input_ids, input_mask, segment_ids,
                                     doc_len, kg_len, ques_len, option_len, kg_text_input_ids,
                                     kg_text_input_mask, kg_text_segment_ids))

        label = int(example.label)
        # NOTE(review): `and False` makes this logging block dead code; drop
        # the `and False` to re-enable example logging.
        if example_index < 5 and False:
            logger.info("*** Example ***")
            logger.info(f"swag_id: {example.swag_id}")
            for choice_idx, (tokens, input_ids, input_mask, segment_ids, doc_len, kg_len,
                             ques_len, option_len, _, _, _) in enumerate(choices_features):
                logger.info(f"choice: {choice_idx}")
                logger.info(f"tokens: {' '.join(tokens)}")
                logger.info(f"input_ids: {' '.join(map(str, input_ids))}")
                logger.info(f"input_mask: {' '.join(map(str, input_mask))}")
                logger.info(f"segment_ids: {' '.join(map(str, segment_ids))}")
            if is_training:
                logger.info(f"label: {label}")

        features.append(
            InputFeatures(
                example_id=example.swag_id,
                choices_features=choices_features,
                label=label
            )
        )

    return features
-
-
- def _truncate_seq_pair(tokens_a, tokens_b, max_length):
- """Truncates a sequence pair in place to the maximum length."""
-
- # This is a simple heuristic which will always truncate the longer sequence
- # one token at a time. This makes more sense than truncating an equal percent
- # of tokens from each, since if one sequence is very short then each token
- # that's truncated likely contains more information than a longer sequence.
- while True:
- total_length = len(tokens_a) + len(tokens_b)
- if total_length <= max_length:
- break
- if len(tokens_a) > len(tokens_b):
- tokens_a.pop()
- else:
- tokens_b.pop()
-
-
def _truncate_sequences(max_length, inputs):
    """Pairwise-truncate inputs[0][i] and inputs[1][i] in place to fit
    *max_length* tokens combined.

    Only the first two entries of *inputs* are consulted; any further entries
    are ignored (matching the original behavior).  Returns None.
    """
    # BUGFIX(minor): removed the unused local `idx = 0` left over from an
    # earlier revision; behavior is unchanged.
    for seq_a, seq_b in zip(inputs[0], inputs[1]):
        _truncate_seq_pair(seq_a, seq_b, max_length)
-
-
def accuracy(out, labels):
    """Return the number of rows in *out* whose argmax equals the label."""
    predicted = np.argmax(out, axis=1)
    return (predicted == labels).sum()
-
-
def warmup_linear(x, warmup=0.002):
    """Learning-rate multiplier for progress fraction *x* in [0, 1].

    Ramps linearly from 0 to 1 over the first *warmup* fraction of training,
    then decays linearly as 1 - x.
    """
    return x / warmup if x < warmup else 1.0 - x
-
-
def select_field(features, field):
    """Gather one named field from every choice of every feature.

    Returns a nested list: one inner list per feature, one entry per choice.
    """
    batched = []
    for feature in features:
        batched.append([choice[field] for choice in feature.choices_features])
    return batched
-
-
- def main():
- parser = argparse.ArgumentParser()
-
- ## Required parameters
- parser.add_argument("--data_dir",
- default=None,
- type=str,
- required=True,
- help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
- parser.add_argument("--bert_model", default=None, type=str, required=True,
- help="Bert pre-trained model selected in the list: bert-base-uncased, "
- "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
- "bert-base-multilingual-cased, bert-base-chinese.")
- parser.add_argument("--task_name",
- default=None,
- type=str,
- required=True,
- help="The name of the task to train.")
- parser.add_argument("--output_dir",
- default=None,
- type=str,
- required=True,
- help="The output directory where the model predictions and checkpoints will be written.")
-
- ## Other parameters
- parser.add_argument("--max_seq_length",
- default=128,
- type=int,
- help="The maximum total input sequence length after WordPiece tokenization. \n"
- "Sequences longer than this will be truncated, and sequences shorter \n"
- "than this will be padded.")
- parser.add_argument("--do_train",
- action='store_true',
- help="Whether to run training.")
- parser.add_argument("--do_eval",
- action='store_true',
- help="Whether to run eval on the dev set.")
- parser.add_argument("--do_lower_case",
- action='store_true',
- help="Set this flag if you are using an uncased model.")
- parser.add_argument("--train_batch_size",
- default=32,
- type=int,
- help="Total batch size for training.")
- parser.add_argument("--eval_batch_size",
- default=8,
- type=int,
- help="Total batch size for eval.")
- parser.add_argument("--learning_rate",
- default=5e-5,
- type=float,
- help="The initial learning rate for Adam.")
- parser.add_argument("--num_train_epochs",
- default=3.0,
- type=float,
- help="Total number of training epochs to perform.")
- parser.add_argument("--warmup_proportion",
- default=0.1,
- type=float,
- help="Proportion of training to perform linear learning rate warmup for. "
- "E.g., 0.1 = 10%% of training.")
- parser.add_argument("--no_cuda",
- action='store_true',
- help="Whether not to use CUDA when available")
- parser.add_argument("--local_rank",
- type=int,
- default=-1,
- help="local_rank for distributed training on gpus")
- parser.add_argument('--seed',
- type=int,
- default=42,
- help="random seed for initialization")
- parser.add_argument('--gradient_accumulation_steps',
- type=int,
- default=1,
- help="Number of updates steps to accumulate before performing a backward/update pass.")
- parser.add_argument('--fp16',
- action='store_true',
- help="Whether to use 16-bit float precision instead of 32-bit")
- parser.add_argument('--early_stop',
- action='store_true',
- help="Whether to use early stop")
- parser.add_argument('--loss_scale',
- type=float, default=0,
- help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
- "0 (default value): dynamic loss scaling.\n"
- "Positive power of 2: static loss scaling value.\n")
- parser.add_argument('--cuda',
- type=str,
- default="",
- help="cuda index")
-
- args = parser.parse_args()
-
- processors = {
- "nlpec": nlpecProcessor,
- }
-
- num_labels_task = {
- "nlpec": 5,
- }
-
- if args.local_rank == -1 or args.no_cuda:
- device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
- n_gpu = torch.cuda.device_count()
- else:
- torch.cuda.set_device(args.local_rank)
- device = torch.device("cuda", args.local_rank)
- n_gpu = 1
- # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
- torch.distributed.init_process_group(backend='nccl')
-
- logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
- device, n_gpu, bool(args.local_rank != -1), args.fp16))
- print("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
- device, n_gpu, bool(args.local_rank != -1), args.fp16))
-
- if args.gradient_accumulation_steps < 1:
- raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
- args.gradient_accumulation_steps))
-
- args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)
-
- random.seed(args.seed)
- np.random.seed(args.seed)
- torch.manual_seed(args.seed)
- if n_gpu > 0:
- torch.cuda.manual_seed_all(args.seed)
-
- if not args.do_train and not args.do_eval:
- raise ValueError("At least one of `do_train` or `do_eval` must be True.")
-
- if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
- raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
- os.makedirs(args.output_dir, exist_ok=True)
-
- task_name = args.task_name.lower()
-
- if task_name not in processors:
- raise ValueError("Task not found: %s" % (task_name))
-
- print("current task is " + str(task_name))
-
- processor = processors[task_name]()
- num_labels = num_labels_task[task_name]
- label_list = processor.get_labels()
-
- # tokenizer = BertTokenizer.from_pretrained(args.bert_model)
- tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
- # tokenizer = BertTokenizer.from_pretrained("../cn_medical_bert/vocab.txt", do_lower_case=args.do_lower_case)
- train_examples = None
- num_train_steps = None
- if args.do_train:
- train_examples = processor.get_train_examples(args.data_dir)
- num_train_steps = int(
- len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
-
- # Prepare model
- model = Bert_MY.from_pretrained(args.bert_model,
- cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(
- args.local_rank),
- num_choices=num_labels)
- # model = Bert_MY.from_pretrained("../cn_medical_bert/", num_choices=num_labels)
-
- pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
- logger.info('Parameters size:' + str(pytorch_total_params))
- if args.fp16:
- model.half()
- model.to(device)
- if args.local_rank != -1:
- try:
- from apex.parallel import DistributedDataParallel as DDP
- except ImportError:
- raise ImportError(
- "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
- model = DDP(model)
- elif n_gpu > 1:
- model = torch.nn.DataParallel(model)
-
- # Prepare optimizer
- # param_optimizer = filter(lambda p: p.requires_grad, model.parameters())
- param_optimizer = list(model.named_parameters())
- param_optimizer = [(n, p) for n, p in param_optimizer if p.requires_grad]
- no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
- optimizer_grouped_parameters = [
- {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
- {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
- ]
- t_total = num_train_steps
- if args.local_rank != -1:
- t_total = t_total // torch.distributed.get_world_size()
- if args.fp16:
- try:
- from apex.optimizers import FP16_Optimizer
- from apex.optimizers import FusedAdam
- except ImportError:
- raise ImportError(
- "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-
- optimizer = FusedAdam(optimizer_grouped_parameters,
- lr=args.learning_rate,
- bias_correction=False,
- max_grad_norm=1.0)
- if args.loss_scale == 0:
- optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
- else:
- optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
-
- else:
- optimizer = BertAdam(optimizer_grouped_parameters,
- lr=args.learning_rate,
- warmup=args.warmup_proportion,
- t_total=t_total)
- # t_total = num_train_steps
- # param_optimizer = list(model.named_parameters())
- # # print("#" * 100)
- # # for n, p in param_optimizer:
- # # print(n)
- # # print("#" * 100)
- # no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
- # if n_gpu > 1:
- # name = "module.bert"
- # else:
- # name = "bert"
- # optimizer_grouped_parameters = [
- # {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay) and n.startswith(name)],
- # 'weight_decay': 0.01},
- # {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay) and n.startswith(name)],
- # 'weight_decay': 0.0}
- # ]
- #
- # optimizer_grouped_parameters_not_bert = [
- # {'params': [p for n, p in list(model.named_parameters()) if not n.startswith(name)], 'weight_decay': 0.0}
- # ]
- #
- # # optimizer_bert = Adafactor(optimizer_grouped_parameters, lr=5e-5, warmup_init=True)
- # optimizer_bert = BertAdam(optimizer_grouped_parameters,
- # lr=args.learning_rate,
- # warmup=0.1,
- # t_total=t_total)
- #
- # # optimizer_dec = AdaFactor(optimizer_grouped_parameters_not_bert, lr=0.001, weight_decay=0.)
- # # optimizer_dec = Adafactor(optimizer_grouped_parameters_not_bert, lr=0.0005, weight_decay=0.0)
- #
- # optimizer_other = torch.optim.Adam(optimizer_grouped_parameters_not_bert, lr=0.001, weight_decay=0.0001, amsgrad=True)
- #
- # optimizer = [optimizer_other, optimizer_bert]
-
- global_step = 0
- nb_tr_steps = 0
- tr_loss = 0
-
- best_eval_accuracy = 0.0
-
- if args.do_train:
- train_features = convert_examples_to_features_c_qa(train_examples, tokenizer,
- args.max_seq_length, True)
- logger.info("***** Running training *****")
- logger.info(" Num examples = %d", len(train_examples))
- logger.info(" Batch size = %d", args.train_batch_size)
- logger.info(" Num steps = %d", num_train_steps)
-
- all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
- all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
- all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
- all_doc_len = torch.tensor(select_field(train_features, 'doc_len'), dtype=torch.long)
- all_kg_len = torch.tensor(select_field(train_features, 'kg_len'), dtype=torch.long)
- all_ques_len = torch.tensor(select_field(train_features, 'ques_len'), dtype=torch.long)
- all_option_len = torch.tensor(select_field(train_features, 'option_len'), dtype=torch.long)
- all_kg_text_ids = torch.tensor(select_field(train_features, "kg_text_ids"), dtype=torch.long)
- all_kg_text_masks = torch.tensor(select_field(train_features, "kg_text_masks"), dtype=torch.long)
- all_kg_text_segment_ids = torch.tensor(select_field(train_features, "kg_text_segment_ids"), dtype=torch.long)
- all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
-
- train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label,
- all_doc_len, all_kg_len, all_ques_len, all_option_len, all_kg_text_ids, all_kg_text_masks, all_kg_text_segment_ids)
- if args.local_rank == -1:
- train_sampler = RandomSampler(train_data)
- else:
- train_sampler = DistributedSampler(train_data)
- train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)
-
- model.train()
-
- # Save a trained model
- model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
- output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
- best_epoch_idx = 0
- for epoch_idx in trange(int(args.num_train_epochs), desc="Epoch"):
- tr_loss = 0
- nb_tr_examples, nb_tr_steps = 0, 0
- for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
- batch = tuple(t.to(device) for t in batch)
- input_ids, input_mask, segment_ids, label_ids, doc_len, kg_len, ques_len, option_len, kg_text_input_ids, kg_text_mask, kg_text_segment_ids = batch
- loss, logits = model(input_ids, segment_ids, input_mask, doc_len, kg_len,
- ques_len, option_len, kg_text_input_ids, kg_text_mask, kg_text_segment_ids, label_ids)
- if n_gpu > 1:
- loss = loss.mean() # mean() to average on multi-gpu.
- if args.gradient_accumulation_steps > 1:
- loss = loss / args.gradient_accumulation_steps
-
- if args.fp16:
- optimizer.backward(loss)
- else:
- loss.backward()
-
- tr_loss += loss.item()
- nb_tr_examples += input_ids.size(0)
- nb_tr_steps += 1
- if (step + 1) % args.gradient_accumulation_steps == 0:
- # modify learning rate with special warm up BERT uses
- lr_this_step = args.learning_rate * warmup_linear(global_step / t_total,
- args.warmup_proportion)
- for param_group in optimizer.param_groups:
- param_group['lr'] = lr_this_step
- # for o in optimizer:
- # o.step()
- # o.zero_grad()
- optimizer.step()
- optimizer.zero_grad()
- global_step += 1
-
- # --- Per-epoch dev-set evaluation (runs after each training epoch) ---
- # Only the main process evaluates in distributed mode.
- if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
- eval_examples = processor.get_dev_examples(args.data_dir)
- eval_features = convert_examples_to_features_c_qa(eval_examples, tokenizer,
- args.max_seq_length, True)
- logger.info("***** Running evaluation *****")
- logger.info(" Num examples = %d", len(eval_examples))
- logger.info(" Batch size = %d", args.eval_batch_size)
- # Collect each per-choice feature field into one long tensor per field.
- all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
- all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
- all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
- all_doc_len = torch.tensor(select_field(eval_features, 'doc_len'), dtype=torch.long)
- all_kg_len = torch.tensor(select_field(eval_features, 'kg_len'), dtype=torch.long)
- all_ques_len = torch.tensor(select_field(eval_features, 'ques_len'), dtype=torch.long)
- all_option_len = torch.tensor(select_field(eval_features, 'option_len'), dtype=torch.long)
- # 'kg_text_ids': kg_text_ids,
- # "kg_text_masks": kg_text_mask,
- # 'kg_text_segment_ids': kg_text_segment_ids
- all_kg_text_ids = torch.tensor(select_field(eval_features, "kg_text_ids"), dtype=torch.long)
- all_kg_text_masks = torch.tensor(select_field(eval_features, "kg_text_masks"), dtype=torch.long)
- all_kg_text_segment_ids = torch.tensor(select_field(eval_features, "kg_text_segment_ids"), dtype=torch.long)
- all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)
-
- eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label,
- all_doc_len, all_kg_len, all_ques_len, all_option_len, all_kg_text_ids, all_kg_text_masks,
- all_kg_text_segment_ids)
- # Run prediction for full data
- eval_sampler = SequentialSampler(eval_data)
- eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
-
- model.eval()
- eval_loss, eval_accuracy = 0, 0
- nb_eval_steps, nb_eval_examples = 0, 0
-
- for input_ids, input_mask, segment_ids, label_ids, doc_len, kg_len, ques_len, option_len, kg_text_input_ids, kg_text_mask, kg_text_segment_ids \
- in tqdm(eval_dataloader, desc="Evaluating"):
- input_ids = input_ids.to(device)
- input_mask = input_mask.to(device)
- segment_ids = segment_ids.to(device)
- label_ids = label_ids.to(device)
- doc_len = doc_len.to(device)
- kg_len = kg_len.to(device)
- ques_len = ques_len.to(device)
- option_len = option_len.to(device)
- kg_text_input_ids = kg_text_input_ids.to(device)
- kg_text_mask = kg_text_mask.to(device)
- kg_text_segment_ids = kg_text_segment_ids.to(device)
- # Passing label_ids makes the model return (loss, logits).
- with torch.no_grad():
- tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask,
- doc_len, kg_len, ques_len, option_len, kg_text_input_ids, kg_text_mask, kg_text_segment_ids, label_ids)
-
- logits = logits.detach().cpu().numpy()
- label_ids = label_ids.to('cpu').numpy()
- tmp_eval_accuracy = accuracy(logits, label_ids)
-
- eval_loss += tmp_eval_loss.mean().item()
- eval_accuracy += tmp_eval_accuracy
-
- nb_eval_examples += input_ids.size(0)
- nb_eval_steps += 1
-
- eval_accuracy = eval_accuracy / nb_eval_examples
- # NOTE(review): eval_loss sums per-batch MEAN losses, yet is divided by the
- # example count here while the final evaluation below divides by
- # nb_eval_steps — the two normalizations disagree; per-step division is the
- # consistent choice. Verify which one the reported numbers assume.
- eval_loss = eval_loss / nb_eval_examples
- logger.info(
- "the current eval accuracy is: {}, best acc: {}, loss {}\n".format(eval_accuracy, best_eval_accuracy
- , eval_loss))
- # Track the best epoch; save the model only when accuracy improves.
- # model_to_save / output_model_file are defined earlier in main() (outside
- # this fragment).
- if eval_accuracy > best_eval_accuracy:
-
- best_eval_accuracy = eval_accuracy
- best_epoch_idx = epoch_idx
-
- if args.do_train:
- torch.save(model_to_save.state_dict(), output_model_file)
- # Early stopping: give up after 5 epochs with no improvement.
- if args.early_stop:
- if epoch_idx + 1 - best_epoch_idx >= 5:
- break
- # Restore training mode for the next epoch.
- model.train()
-
- # Load a trained model that you have fine-tuned
- # Re-instantiate the model and overwrite its weights with the best
- # checkpoint saved during training (state_dict takes precedence over the
- # pretrained weights fetched by from_pretrained).
- output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
- model_state_dict = torch.load(output_model_file)
-
- # model = Bert_MY.from_pretrained("./output_dir/pkukg_ft/",
- # num_choices=num_labels,
- # state_dict=model_state_dict)
- model = Bert_MY.from_pretrained(args.bert_model ,
- cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(
- args.local_rank),
- num_choices=num_labels,
- state_dict=model_state_dict)
- # args.bert_model,
- # cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(
- # args.local_rank),
- # Convert to half precision before moving to the device when fp16 is on.
- if args.fp16:
- model.half()
- model.to(device)
-
- # --- Final evaluation with the reloaded best model; also dumps per-example
- # predictions and logits to eval_results.txt ---
- if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
- eval_examples = processor.get_dev_examples(args.data_dir)
- eval_features = convert_examples_to_features_c_qa(eval_examples, tokenizer,
- args.max_seq_length, True)
- logger.info("***** Running evaluation *****")
- logger.info(" Num examples = %d", len(eval_examples))
- logger.info(" Batch size = %d", args.eval_batch_size)
- all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
- all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
- all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
- all_doc_len = torch.tensor(select_field(eval_features, 'doc_len'), dtype=torch.long)
- all_kg_len = torch.tensor(select_field(eval_features, 'kg_len'), dtype=torch.long)
- all_ques_len = torch.tensor(select_field(eval_features, 'ques_len'), dtype=torch.long)
- all_option_len = torch.tensor(select_field(eval_features, 'option_len'), dtype=torch.long)
- all_kg_text_ids = torch.tensor(select_field(eval_features, "kg_text_ids"), dtype=torch.long)
- all_kg_text_masks = torch.tensor(select_field(eval_features, "kg_text_masks"), dtype=torch.long)
- all_kg_text_segment_ids = torch.tensor(select_field(eval_features, "kg_text_segment_ids"), dtype=torch.long)
- all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)
-
- eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label,
- all_doc_len, all_kg_len, all_ques_len, all_option_len, all_kg_text_ids, all_kg_text_masks,all_kg_text_segment_ids )
- # Run prediction for full data
- eval_sampler = SequentialSampler(eval_data)
- eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)
-
- model.eval()
- eval_loss, eval_accuracy = 0, 0
- nb_eval_steps, nb_eval_examples = 0, 0
-
- # Accumulators for the per-example dump written at the end.
- all_pred_labels = []
- all_anno_labels = []
- all_logits = []
-
- for input_ids, input_mask, segment_ids, label_ids, doc_len, kg_len, ques_len, option_len, kg_text_input_ids, kg_text_mask, kg_text_segment_ids \
- in tqdm(eval_dataloader, desc="Evaluating"):
- input_ids = input_ids.to(device)
- input_mask = input_mask.to(device)
- segment_ids = segment_ids.to(device)
- label_ids = label_ids.to(device)
- doc_len = doc_len.to(device)
- kg_len = kg_len.to(device)
- ques_len = ques_len.to(device)
- option_len = option_len.to(device)
- kg_text_input_ids = kg_text_input_ids.to(device)
- kg_text_mask = kg_text_mask.to(device)
- kg_text_segment_ids = kg_text_segment_ids.to(device)
-
- with torch.no_grad():
- tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, doc_len, kg_len,
- ques_len, option_len, kg_text_input_ids, kg_text_mask, kg_text_segment_ids, label_ids)
-
- logits = logits.detach().cpu().numpy()
- label_ids = label_ids.to('cpu').numpy()
-
- # Predicted choice = argmax over the per-choice logits.
- output_labels = np.argmax(logits, axis=1)
- all_pred_labels.extend(output_labels.tolist())
- all_logits.extend(list(logits))
- all_anno_labels.extend(list(label_ids))
-
- tmp_eval_accuracy = accuracy(logits, label_ids)
-
- eval_loss += tmp_eval_loss.mean().item()
- eval_accuracy += tmp_eval_accuracy
-
- nb_eval_examples += input_ids.size(0)
- nb_eval_steps += 1
-
- # NOTE(review): loss is normalized per step here, but per example in the
- # per-epoch evaluation above — the two normalizations disagree.
- eval_loss = eval_loss / nb_eval_steps
- eval_accuracy = eval_accuracy / nb_eval_examples
- loss = tr_loss / nb_tr_steps if args.do_train else None
- result = {'eval_loss': eval_loss,
- 'eval_accuracy': eval_accuracy,
- 'best_eval_accuracy': best_eval_accuracy,
- 'global_step': global_step,
- 'loss': loss}
-
- # Write summary metrics, then one "index<TAB>gold<TAB>pred<TAB>logits" line
- # per evaluated example.
- output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
- with open(output_eval_file, "w") as writer:
- logger.info("***** Eval results *****")
- for key in sorted(result.keys()):
- logger.info(" %s = %s", key, str(result[key]))
- writer.write("%s = %s\n" % (key, str(result[key])))
- for i in range(len(all_pred_labels)):
- writer.write(str(i) + "\t" + str(all_anno_labels[i]) + "\t" +
- str(all_pred_labels[i]) + "\t" + str(all_logits[i]) + "\n")
-
- # Script entry point: run the full train/eval pipeline defined in main().
- if __name__ == "__main__":
- main()
|