import time
import importlib

from os.path import join, isdir, isfile, relpath
from glob import glob

from evaluation import BaseConfig, ModelForEvaluation, DEFAULT_CLASS, print_rank_0
from initialize import initialize, initialize_model_and_tokenizer


def add_evaluation_specific_args(parser):
    """Arguments for evaluation"""
    group = parser.add_argument_group("evaluation", "Evaluation configurations")

    # Task
    group.add_argument("--task", nargs="+", default=[], help="Task config paths (files or directories) to evaluate")
    group.add_argument("--data-path", type=str, required=True, help="Data dir path for all tasks")
    return parser


def find_all_tasks(all_task_config_path):
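    """Expand task config paths into a flat list of YAML config files.

    Directories are searched recursively for *.yaml files; regular files
    are kept as-is.
    """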
    tasks = []
    for task in all_task_config_path:
        if isdir(task):
            tasks += [relpath(path, ".") for path in glob(join(task, "**/*.yaml"), recursive=True)]
        elif isfile(task):
            tasks.append(task)
    return tasks


def evaluate_all_tasks(data_path, model, tokenizer, all_task_config_path, task_classes):
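    """Instantiate each task from its YAML config and run its evaluation.

    Each config's data path is resolved relative to data_path before the
    task is constructed.
    """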
    for config_path, task_class in zip(all_task_config_path, task_classes):
        config = task_class.config_class().from_yaml_file(config_path)
        config.path = join(data_path, config.path)
        task = task_class(model, tokenizer, config)
        task.evaluate()


def main():
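    """Parse arguments, load task configs, and evaluate every task in turn."""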
    args = initialize(extra_args_provider=add_evaluation_specific_args)
    args.task = find_all_tasks(args.task)

    task_classes = []
    print_rank_0("> Loading task configs")
    for task_config_path in args.task:
        config = BaseConfig.from_yaml_file(task_config_path)
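        # A custom task class may be given as a dotted path
        # ("package.module.ClassName"); otherwise fall back to the default
        # class registered for this task type.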
        if config.module:
            path = ".".join(config.module.split(".")[:-1])
            module = importlib.import_module(path)
            class_name = config.module.split(".")[-1]
            task_class = getattr(module, class_name)
            task_classes.append(task_class)
        else:
            task_classes.append(DEFAULT_CLASS[config.type])
        print_rank_0(f"  Task {config.name} loaded from config {task_config_path}")
    print_rank_0(f"> Successfully loaded {len(task_classes)} task{'s' if len(task_classes) > 1 else ''}")

    model, tokenizer = initialize_model_and_tokenizer(args)
    model = ModelForEvaluation(model)

    start = time.time()
    evaluate_all_tasks(args.data_path, model, tokenizer, args.task, task_classes)
    print_rank_0(f"Finished {len(task_classes)} task{'s' if len(task_classes) > 1 else ''} in {time.time() - start:.1f}s")


if __name__ == "__main__":
    main()
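
# A sketch of the task config fields this script reads (the authoritative
# schema lives in BaseConfig; the values below are placeholders, not real
# task names or types):
#
#   name: <task name>               # used in log messages
#   type: <task type>               # selects a default class from DEFAULT_CLASS
#   path: <relative data dir>       # joined onto --data-path
#   module: <pkg.module.ClassName>  # optional: overrides the default task class
#
# Example invocation (hypothetical script name and paths):
#   python evaluate.py --task tasks/ --data-path /path/to/evaluation-data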