|
- # !/usr/bin/env python
- # coding=UTF-8
- """
- @Author: WEN Hao
- @LastEditors: WEN Hao
- @Description:
- @Date: 2021-12-02
- @LastEditTime: 2023-05-08
-
- 加载数据集命令行参数
-
- """
-
import importlib
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Optional

from datasets import list_datasets

from ..const import BUILTIN_DATASETS, DATASET_NAME_MAPPING, MODEL_DEFAULT_DATASET
from ..Datasets import NLPDataset
from ..utils.misc import module_name
from ..utils.strings import LANGUAGE, normalize_language  # noqa: F401
-
- __all__ = [
- "DatasetArgs",
- ]
-
-
@dataclass
class DatasetArgs:
    """Command-line arguments for selecting and loading an evaluation dataset.

    ``_add_parser_args`` registers the CLI flags; ``_create_dataset_from_args``
    turns the parsed argument dict into a dataset object, resolving built-in
    names, local table paths, and Huggingface dataset names (in that order).
    """

    __name__ = "DatasetArgs"

    # Built-in dataset name, local table path, or Huggingface datasets name.
    # None means "derive a default from the chosen model".
    dataset: Optional[str] = None
    # Split to load ("train" or "test"); None loads all subsets.
    subset: Optional[str] = None
    # Upper bound on the character length of adversarial samples.
    max_len: int = 512

    @classmethod
    def _add_parser_args(cls, parser: ArgumentParser) -> ArgumentParser:
        """Register dataset-related CLI arguments on ``parser`` and return it.

        Parameters
        ----------
        parser : ArgumentParser
            The parser to extend (modified in place).

        Returns
        -------
        ArgumentParser
            The same parser, for chaining.
        """
        default_obj = cls()
        ds_group = parser.add_argument_group()
        ds_group.add_argument(
            "-d",
            "--dataset",
            type=str,
            help="模型评测数据集名称或者路径",
            default=default_obj.dataset,
            dest="dataset",
            # NOTE: deliberately no ``choices`` restriction -- besides the
            # names in BUILTIN_DATASETS, ``_create_dataset_from_args`` also
            # accepts a local file path or a Huggingface datasets name, which
            # a ``choices`` list would reject at parse time.
        )
        ds_group.add_argument(
            "--subset",
            type=str,
            help="模型评测数据集子集名称",
            default=default_obj.subset,
            dest="subset",
            choices=["train", "test"],
        )
        ds_group.add_argument(
            "--max-len",
            type=int,
            help="对抗样本字符数目上限",
            default=default_obj.max_len,
            dest="max_len",
        )
        return parser

    @classmethod
    def _create_dataset_from_args(cls, args: Dict) -> Any:
        """Create a dataset object from a dict of parsed CLI arguments.

        Resolution order for ``args["dataset"]``: a model-specific default
        (when None), then built-in datasets, then an existing local table
        file, then a Huggingface datasets name.

        Parameters
        ----------
        args : dict
            Parsed arguments; keys used: ``dataset``, ``subset``,
            ``max_len``, ``language``, ``model``.

        Returns
        -------
        Any
            The instantiated dataset object.

        Raises
        ------
        ValueError
            If neither dataset nor model is given, if the model has no
            default dataset, or if the dataset cannot be resolved.
        """
        obj = cls()
        obj.dataset = args.get("dataset")
        obj.subset = args.get("subset")
        # fall back to the class default instead of None when the key is absent
        obj.max_len = args.get("max_len", cls.max_len)
        language = normalize_language(args.get("language"))
        try:
            huggingface_datasets = list_datasets()
        except Exception:
            # best-effort: on network error, Huggingface names simply
            # cannot be matched below
            huggingface_datasets = []
        if obj.dataset is None:
            # no dataset given: load the default dataset of the chosen model
            model_name = args.get("model")
            if model_name is None:
                raise ValueError("模型名称与数据集名称不能同时为空")
            obj.dataset = MODEL_DEFAULT_DATASET.get(model_name)
            if obj.dataset is None:
                raise ValueError(f"模型 {model_name} 没有默认数据集")
            print(f"未指定数据集, 将加载默认数据集 {obj.dataset}")
        if obj.dataset in BUILTIN_DATASETS:
            # built-in dataset: import its module lazily and instantiate
            ds_cls = getattr(
                importlib.import_module(f"{module_name}.Datasets.{DATASET_NAME_MAPPING[obj.dataset]}"),
                BUILTIN_DATASETS[obj.dataset],
            )
            ds = ds_cls(subsets=obj.subset, max_len=obj.max_len)
        elif Path(obj.dataset).exists():
            # custom dataset from a local table file
            ds = NLPDataset.from_table(
                obj.dataset,
                subsets=obj.subset,
                max_len=obj.max_len,
                language=language,
                dataset_name="custom",
            )
        elif obj.dataset in huggingface_datasets:
            warnings.warn(
                f"将从 Huggingface datasets 加载数据集 {obj.dataset}, " "过程中会从 raw.githubusercontent.com 下载部分代码,可能会比较慢,甚至失败",
                RuntimeWarning,
            )
            ds = NLPDataset.from_huggingface_dataset(obj.dataset, split=obj.subset, max_len=obj.max_len, language=language)
        else:
            raise ValueError(
                f"未知数据集 {obj.dataset}"
                "若要使用自定义数据集,请指定正确的数据集路径;"
                "若要使用 Huggingface datasets, 请指定正确的数据集名称;"
                "若要使用内置数据集,请指定正确的数据集名称。"
            )

        return ds
|