import typer
from pathlib import Path
from datasets import load_dataset
from transformers import AutoTokenizer
from rich import print
from typer import Argument, Option
from typing import Optional


def prepare_data(dataset_save_name: str = Argument(..., help='Name under which to save the processed dataset'),
                 tokenizer_name_or_path: str = Argument(..., help='Tokenizer name or path'),
                 cache_dir: str = Argument(..., help='Cache directory, also used to store the processed dataset'),
                 corpus_dir: str = Argument(..., help='Raw corpus directory; every .jsonl file in it is read'),
                 num_proc: int = Option(20, help='Number of worker processes, default 20'),
                 use_prompt: Optional[bool] = Option(None, help='Whether to wrap each text with its pre/post prompts; default None, which disables prompts'),
                 max_length: int = Option(512, help='Maximum tokenized length, default 512')):
    """Preprocess raw jsonl corpora into a tokenized dataset."""

    print('load tokenizer')
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path)
    if not tokenizer.pad_token:
        # Fall back to the EOS token for padding when the tokenizer defines none.
        tokenizer.pad_token_id = tokenizer.eos_token_id
    raw_dir = Path(corpus_dir)
    # Collect every .jsonl file directly under the corpus directory.
    data_files = [str(p) for p in raw_dir.iterdir()
                  if p.is_file() and p.suffix == '.jsonl']

    print(f'all data files: {data_files}')
    ds = load_dataset('json', data_files=data_files, cache_dir=cache_dir)
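
    # Each jsonl record is expected to provide a 'content' field, plus
    # 'pre-prompt' and 'post-prompt' fields when prompts are enabled; a
    # hypothetical example record:
    #   {"pre-prompt": "Q: ", "content": "...", "post-prompt": "\nA: "}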

    def tokenize_with_prompt(examples):
        # Wrap each document with its per-example pre/post prompts before tokenizing.
        texts = [pre_prompt + content + post_prompt
                 for pre_prompt, content, post_prompt
                 in zip(examples['pre-prompt'], examples['content'], examples['post-prompt'])]
        inputs = tokenizer(texts,
                           return_overflowing_tokens=True,
                           max_length=max_length,
                           truncation=True)
        return {'input_ids': inputs['input_ids']}

    def tokenize_no_prompt(examples):
        # Tokenize the raw content only, without pre/post prompts.
        inputs = tokenizer(examples['content'],
                           return_overflowing_tokens=True,
                           max_length=max_length,
                           truncation=True)
        return {'input_ids': inputs['input_ids']}
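
    # Note: with return_overflowing_tokens=True, a document longer than
    # max_length is split into several rows of input_ids, so the mapped
    # dataset may contain more rows than the raw corpus.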

    print('preprocess dataset')
    # map() drops the original text columns, keeping only input_ids.
    if use_prompt:
        ds = ds.map(tokenize_with_prompt, batched=True, num_proc=num_proc,
                    remove_columns=ds['train'].column_names)
    else:
        ds = ds.map(tokenize_no_prompt, batched=True, num_proc=num_proc,
                    remove_columns=ds['train'].column_names)

    print('save dataset')
    save_path = Path(cache_dir, dataset_save_name)
    ds.save_to_disk(save_path)


if __name__ == "__main__":
    typer.run(prepare_data)
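
# Example invocation (script name, model, and paths below are placeholders,
# not taken from the original; Typer derives --num-proc style flags from the
# keyword options and a --use-prompt/--no-use-prompt pair from the bool):
#
#   python prepare_data.py my_corpus_v1 gpt2 ./cache ./corpus \
#       --num-proc 8 --use-prompt --max-length 1024
#
# The processed dataset can later be reloaded with datasets.load_from_disk:
#
#   from datasets import load_from_disk
#   ds = load_from_disk('./cache/my_corpus_v1')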