# Copyright (c) 2021 Baidu.com, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import random
import time
import json

import numpy as np
import paddle
from paddle.io import DataLoader
from paddlenlp.transformers import LinearDecayWithWarmup

from data_loader import DuIEDataset, DataCollator
from utils import decoding, write_content, write_compare_result, get_precision_recall_f1_v1, get_optimizer, get_model_size
from RE_Global_Model import RE_Global_Model
from DefaultLogger import DefaultLogger
from config import get_config

parser = argparse.ArgumentParser()
parser.add_argument("--envir", type=str, choices=['local', 'cloud'], default="local")
parser.add_argument("--run_mode", type=str, choices=['train', 'predict'], default="predict")
parser.add_argument("--data_name", type=str, choices=['qianyan_60_100', 'qianyan_60_1000',
                                                      'nyt_deal', 'nyt_star_deal', 'webnlg_deal', 'webnlg_star_deal'],
                    default="webnlg_star_deal")

args = parser.parse_args()
# Expand the command-line options into the full configuration
args = get_config(args.envir, args.run_mode, args.data_name)
# Set up the logger and its output directory
logger = DefaultLogger(args)
logger.write_args(args.prn_obj())
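
# Example invocation (the script file name below is illustrative, not taken from the repo):
#   python train.py --envir local --run_mode train --data_name webnlg_star_deal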

def set_random_seed(seed):
    """Sets the random seed for random, numpy and paddle."""
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)

set_random_seed(args.seed)
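# Seeding fixes the Python/NumPy/Paddle random streams; full determinism on GPU
# may additionally require cuDNN determinism flags (e.g. FLAGS_cudnn_deterministic=1).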


@paddle.no_grad()
def evaluate(model, data_loader, file_path, mode, epoch):
    """Evaluates `model` on `data_loader` and scores its predictions against the gold file."""
    # Keep the raw examples so decoding can map predictions back to the original text
    example_all = []
    with open(file_path, "r", encoding="utf-8") as fp:
        for line in fp:
            example_all.append(json.loads(line))

    with open(os.path.join(args.data_path, "id2spo.json"), 'r', encoding='utf-8') as fp:
        id2spo = json.load(fp)

    model.eval()

    loss_all = 0
    current_idx = 0
    formatted_outputs = []
    start_time = time.time()

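    # Each batch yields subject/object logits; `decoding` (a project helper) converts them
    # back into SPO triples on the original text via the token-to-original index maps.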
    for batch in data_loader:
        input_ids, seq_len, tok_to_orig_start_index, tok_to_orig_end_index, subject_labels, object_labels = batch
        loss, subject_logits, object_logits = model(input_ids=input_ids, subject_labels=subject_labels, object_labels=object_labels)
        loss_all += loss.numpy().item()
        formatted_outputs.extend(decoding(example_all[current_idx: current_idx + len(subject_logits)],
                                          id2spo,
                                          subject_logits,
                                          object_logits,
                                          seq_len.numpy(),
                                          tok_to_orig_start_index.numpy(),
                                          tok_to_orig_end_index.numpy()))

        current_idx = current_idx + len(subject_logits)

    loss_avg = loss_all / len(data_loader)
    logger.info(f"eval loss: {loss_avg}, avg time per batch: {(time.time() - start_time) / len(data_loader)}")

    os.makedirs(os.path.join(args.output_path, "predict"), exist_ok=True)
    predict_file_path = os.path.join(args.output_path, f"predict/{mode}_{epoch}.json")
    compare_file_path = os.path.join(args.output_path, f"predict/{mode}_comp_{epoch}.json")

    # Write this round's predictions to file
    write_content(formatted_outputs, predict_file_path)
    # Write a side-by-side comparison of the gold labels and the predictions to file
    write_compare_result(example_all, formatted_outputs, compare_file_path)

    # Scoring on the public dataset
    precision, recall, f1 = get_precision_recall_f1_v1(compare_file_path, match_pattern=args.match_pattern)
    return precision, recall, f1


def do_train():
    paddle.set_device(args.device)

    # ========== Build the model ==========
    model = RE_Global_Model(args)
    logger.info(get_model_size(model, framework="paddlepaddle"))

    # ========== Load the datasets ==========
    train_file_path = os.path.join(args.data_path, 'train_data.json')
    train_dataset = DuIEDataset.from_file(train_file_path, model.tokenizer, args, True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(train_dataset, batch_size=args.batch_size, shuffle=True, drop_last=True)
    collator = DataCollator()
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_sampler=train_batch_sampler,
        collate_fn=collator,
        return_list=True)
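    # return_list=True makes each batch come back as a plain list of tensors (rather
    # than a dict), which matches the tuple unpacking used in the train/eval loops.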
- logger.info(f"train_dataset: {len(train_dataset)}")
-
- eval_file_path = os.path.join(args.data_path, 'dev_data.json')
- eval_dataset = DuIEDataset.from_file(eval_file_path, model.tokenizer, args, True)
- eval_batch_sampler = paddle.io.BatchSampler(eval_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)
- eval_data_loader = DataLoader(
- dataset=eval_dataset,
- batch_sampler=eval_batch_sampler,
- collate_fn=collator,
- return_list=True)
- logger.info(f"eval_dataset: {len(eval_dataset)}")
-
- # ========== 定义优化器 ==========
- steps_by_epoch = len(train_data_loader)
- num_training_steps = steps_by_epoch * args.num_train_epochs
-
- lr_scheduler_pre = LinearDecayWithWarmup(args.learning_rate_encoder, num_training_steps, args.warmup_ratio)
- optimizer = get_optimizer(model, lr_scheduler_pre, args.weight_decay)
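    # LinearDecayWithWarmup ramps the learning rate linearly from 0 up to
    # learning_rate_encoder over the first warmup_ratio * num_training_steps steps,
    # then decays it linearly back toward 0 over the remaining steps.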

    max_f1 = 0.0
    global_step = 0
    logging_steps = 200
    tic_train = time.time()
    for epoch in range(args.num_train_epochs):
        logger.info("\n=====start training epoch %d=====" % epoch)
        tic_epoch = time.time()
        model.train()

        for step, batch in enumerate(train_data_loader):
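            # Paddle accumulates gradients across backward() calls, so clear them
            # at the start of every step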
            optimizer.clear_grad()

            input_ids, seq_lens, tok_to_orig_start_index, tok_to_orig_end_index, subject_labels, object_labels = batch
            loss, _, _ = model(input_ids=input_ids, subject_labels=subject_labels, object_labels=object_labels)

            loss.backward()
            optimizer.step()
            lr_scheduler_pre.step()

            loss_item = loss.numpy().item()
            if global_step % logging_steps == 0:
                logger.info("epoch: %d / %d, steps: %d / %d, loss: %f, speed: %.2f step/s" % (epoch, args.num_train_epochs, step, steps_by_epoch, loss_item, logging_steps / (time.time() - tic_train)))
                tic_train = time.time()

            global_step += 1

        logger.info("\n=====start evaluating epoch %d=====" % epoch)
        precision, recall, f1 = evaluate(model, eval_data_loader, eval_file_path, "eval", epoch)
        if f1 >= max_f1:
            max_f1 = f1
            os.makedirs(os.path.join(args.output_path, "checkpoints"), exist_ok=True)
            logger.info(f"saving checkpoint to {os.path.join(args.output_path, f'checkpoints/model_{args.data_name}.pdparams')}")
            paddle.save(model.state_dict(), os.path.join(args.output_path, f"checkpoints/model_{args.data_name}.pdparams"))

        logger.info("precision: %.2f\t recall: %.2f\t f1: %.2f\t best f1: %.2f" % (100 * precision, 100 * recall, 100 * f1, 100 * max_f1))
        tic_epoch = time.time() - tic_epoch
        logger.info("epoch time footprint: %d hour %d min %d sec" % (tic_epoch // 3600, (tic_epoch % 3600) // 60, tic_epoch % 60))


def valid_(model, testfile_path):
    # ========== Load the dataset ==========
    test_dataset = DuIEDataset.from_file(testfile_path, model.tokenizer, args, True)
    test_batch_sampler = paddle.io.BatchSampler(test_dataset, batch_size=args.batch_size, shuffle=False, drop_last=False)
    collator = DataCollator()
    test_data_loader = DataLoader(
        dataset=test_dataset,
        batch_sampler=test_batch_sampler,
        collate_fn=collator,
        return_list=True)
    logger.info(f"test_dataset: {len(test_dataset)}")

    logger.info(f"\n=====start predicting: {testfile_path}=====")
    precision, recall, f1 = evaluate(model, test_data_loader, testfile_path, "pred", epoch=-1)
    logger.info("precision: %.2f\t recall: %.2f\t f1: %.2f" % (100 * precision, 100 * recall, 100 * f1))
    logger.info(f"=====predicting complete: {testfile_path}=====")


def do_predict():
    paddle.set_device(args.device)

    # ========== Load the model ==========
    model = RE_Global_Model(args)
    state_dict = paddle.load(os.path.join(args.output_path, f"checkpoints/model_{args.data_name}.pdparams"))
    model.set_dict(state_dict)

    if "star" in args.data_name:
        for data_sign in ['_1', '_2', '_3', '_4', '_5', '_epo', '_seo', '_normal', '']:
            testfile_path = os.path.join(args.data_path, f'test_triples{data_sign}.json')
            valid_(model, testfile_path)
    else:
        testfile_path = os.path.join(args.data_path, 'test_data.json')
        valid_(model, testfile_path)


if __name__ == "__main__":
    if args.run_mode == "train":
        do_train()
    elif args.run_mode == "predict":
        do_predict()