|
- # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import argparse
- import os
- from typing import Optional
- from functools import partial
- from dataclasses import dataclass, field
-
- import paddle
- from paddlenlp.datasets import load_dataset, MapDataset
- from paddlenlp.transformers import AutoTokenizer, UIEX
- from paddlenlp.metrics import SpanEvaluator
- from paddlenlp.utils.log import logger
- from paddlenlp.utils.ie_utils import unify_prompt_name, get_relation_type_dict, uie_loss_func, compute_metrics
- from paddlenlp.data import DataCollatorWithPadding
- from paddlenlp.trainer import PdArgumentParser, TrainingArguments, Trainer
-
- from utils import convert_example, reader
-
-
- @dataclass
- class DataArguments:
- """
- Arguments pertaining to what data we are going to input our model for evaluation.
- Using `PdArgumentParser` we can turn this class into argparse arguments to be able to
- specify them on the command line.
- """
- test_path: str = field(default=None,
- metadata={"help": "The path of test set."})
-
- schema_lang: str = field(
- default="ch",
- metadata={
- "help": "Select the language type for schema, such as 'ch', 'en'"
- })
-
- max_seq_len: Optional[int] = field(
- default=512,
- metadata={
- "help":
- "The maximum total input sequence length after tokenization. Sequences longer "
- "than this will be truncated, sequences shorter will be padded."
- })
-
- debug: bool = field(
- default=False,
- metadata={"help": "Whether choose debug mode."},
- )
-
-
- @dataclass
- class ModelArguments:
- """
- Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
- """
- model_path: Optional[str] = field(
- default=None,
- metadata={"help": "The path of saved model that you want to load."})
-
-
- def do_eval():
- parser = PdArgumentParser(
- (ModelArguments, DataArguments, TrainingArguments))
- model_args, data_args, training_args = parser.parse_args_into_dataclasses()
-
- # Log model and data config
- training_args.print_config(model_args, "Model")
- training_args.print_config(data_args, "Data")
-
- paddle.set_device(training_args.device)
-
- tokenizer = AutoTokenizer.from_pretrained(model_args.model_path)
- model = UIEX.from_pretrained(model_args.model_path)
-
- test_ds = load_dataset(reader,
- data_path=data_args.test_path,
- max_seq_len=data_args.max_seq_len,
- lazy=False)
- trans_fn = partial(convert_example,
- tokenizer=tokenizer,
- max_seq_len=data_args.max_seq_len)
- if data_args.debug:
- class_dict = {}
- relation_data = []
-
- for data in test_ds:
- class_name = unify_prompt_name(data['prompt'])
- # Only positive examples are evaluated in debug mode
- if len(data['result_list']) != 0:
- p = "的" if data_args.schema_lang == "ch" else " of "
- if p not in data['prompt']:
- class_dict.setdefault(class_name, []).append(data)
- else:
- relation_data.append((data['prompt'], data))
-
- relation_type_dict = get_relation_type_dict(
- relation_data, schema_lang=data_args.schema_lang)
- test_ds = test_ds.map(trans_fn)
-
- trainer = Trainer(
- model=model,
- criterion=uie_loss_func,
- args=training_args,
- eval_dataset=test_ds,
- tokenizer=tokenizer,
- compute_metrics=compute_metrics,
- )
- eval_metrics = trainer.evaluate()
- logger.info("-----Evaluate model-------")
- logger.info("Class Name: ALL CLASSES")
- logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
- (eval_metrics['eval_precision'], eval_metrics['eval_recall'],
- eval_metrics['eval_f1']))
- logger.info("-----------------------------")
- if data_args.debug:
- for key in class_dict.keys():
- test_ds = MapDataset(class_dict[key])
- test_ds = test_ds.map(trans_fn)
- eval_metrics = trainer.evaluate(eval_dataset=test_ds)
-
- logger.info("Class Name: %s" % key)
- logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
- (eval_metrics['eval_precision'],
- eval_metrics['eval_recall'], eval_metrics['eval_f1']))
- logger.info("-----------------------------")
- for key in relation_type_dict.keys():
- test_ds = MapDataset(relation_type_dict[key])
- test_ds = test_ds.map(trans_fn)
- eval_metrics = trainer.evaluate(eval_dataset=test_ds)
- logger.info("-----------------------------")
- if data_args.schema_lang == "ch":
- logger.info("Class Name: X的%s" % key)
- else:
- logger.info("Class Name: %s of X" % key)
- logger.info("Evaluation Precision: %.5f | Recall: %.5f | F1: %.5f" %
- (eval_metrics['eval_precision'],
- eval_metrics['eval_recall'], eval_metrics['eval_f1']))
-
-
- if __name__ == "__main__":
- do_eval()
|