|
- import os
- import torch
- import stat
- import re
-
- from functools import partial
- from typing import List, Tuple
-
- from SwissArmyTransformer import mpu
- from evaluation.model import batch_filling_sequence
- from generation import BeamSearchStrategy, BaseStrategy
- from SwissArmyTransformer.generation.utils import timed_name, generate_continually
- from initialize import initialize, initialize_model_and_tokenizer
-
-
def add_generation_specific_args(parser):
    """Attach generation-only command-line options to *parser*."""
    parser.add_argument(
        "--sampling-strategy",
        type=str,
        default="BaseStrategy",
        help="Type of sampling strategy.",
    )
    parser.add_argument(
        "--min-gen-length",
        type=int,
        default=0,
        help="The minimum length each blank should generate.",
    )
    parser.add_argument(
        "--print-all-beams",
        action="store_true",
        help="Print all output generated by beam search strategy.",
    )
-
-
def isEnglish(s):
    """Return True if *s* contains only ASCII characters.

    Equivalent to the old ``s.encode("utf-8").decode("ascii")``
    round-trip (which raised UnicodeDecodeError on non-ASCII input),
    but uses the built-in ``str.isascii()`` — clearer and faster.
    Name kept camelCase for backward compatibility with callers.
    """
    return s.isascii()
-
-
def get_masks_and_position_ids(seq, mask_position, max_gen_length, gmask=False):
    """Build padded tokens, attention mask and position ids for blank filling.

    Args:
        seq: LongTensor of shape (1, context_length) holding the prompt ids.
        mask_position: index of the mask token in the prompt; when ``gmask``
            is False, all generated positions reuse this position id.
        max_gen_length: number of extra slots appended for generation,
            filled with -1 as a "not generated yet" placeholder.
        gmask: True when the prompt uses [gMASK]; position ids then keep
            counting upward instead of collapsing onto ``mask_position``.

    Returns:
        (tokens, attention_mask, position_ids) where ``attention_mask`` is a
        bool tensor of shape (1, 1, total, total) that is True at positions
        which must be masked OUT (i.e. not attended to).
    """
    context_length = seq.shape[1]
    # Reserve generation slots; -1 marks tokens not produced yet.
    tokens = torch.nn.functional.pad(seq, (0, max_gen_length), mode="constant", value=-1)
    attention_mask = torch.ones((1, tokens.shape[-1], tokens.shape[-1]), device=tokens.device)
    attention_mask.tril_()  # causal (lower-triangular) visibility
    # All rows may see every prompt token except the last context position.
    attention_mask[..., : context_length - 1] = 1
    attention_mask.unsqueeze_(1)  # broadcast dim for attention heads
    # Invert: True == blocked. The comparison already yields a bool tensor,
    # so the former redundant `.bool()` cast is dropped.
    attention_mask = attention_mask < 0.5

    position_ids = torch.arange(tokens.shape[-1], dtype=torch.long, device=tokens.device)
    if not gmask:
        # 2D positional scheme: every generated token shares the mask position.
        position_ids[context_length - 1 :] = mask_position

    position_ids = position_ids.unsqueeze(0)

    return tokens, attention_mask, position_ids
-
-
def fill_blanks(raw_text: str, model, tokenizer, strategy) -> Tuple[List[str], List[str], List[List[str]]]:
    """Fill every mask token in *raw_text* by iterative generation.

    NOTE(review): this function reads the module-level ``args`` set in the
    ``__main__`` block (``max_sequence_length``, ``sampling_strategy``,
    ``num_beams``, ``device``, ``out_seq_length``) instead of taking it as
    a parameter — it only works when the script is run as a program.

    Returns:
        answers: one fully detokenized answer string per output beam.
        answers_with_style: the same answers with ANSI escapes marking
            each generated blank (underline, optionally green).
        blanks: per beam, the list of generated blank strings in order.
    """
    # add MASK
    generation_mask = "[gMASK]"
    if "[MASK]" in raw_text:
        generation_mask = "[MASK]"
    elif "[sMASK]" in raw_text:
        generation_mask = "[sMASK]"
    use_gmask = "[MASK]" not in raw_text and "[sMASK]" not in raw_text

    mask_pattern = r"\[[sg]?MASK\]"
    text_list = re.split(mask_pattern, raw_text)
    pattern_list = re.compile(mask_pattern).findall(raw_text)
    seq = []
    # Interleave tokenized text segments with the mask command tokens.
    for i in range(len(pattern_list)):
        pattern = pattern_list[i]
        sub_text = text_list[i]
        seq.extend(tokenizer.tokenize(sub_text))
        seq.append(tokenizer.get_command(pattern))

    seq.extend(tokenizer.tokenize(text_list[-1]))

    # No mask in the input: append a trailing generation mask for
    # open-ended continuation, mirroring it in the printed text.
    if "MASK]" not in raw_text:
        seq += [tokenizer.get_command(generation_mask)]
        raw_text += " " + generation_mask
    if not raw_text.endswith("MASK]"):
        seq = seq + [tokenizer.get_command("eos")]
    if mpu.get_model_parallel_rank() == 0:
        print("\nInput: {}\n".format(raw_text))
    if len(seq) > args.max_sequence_length:
        raise ValueError("text too long.")

    # generation
    is_english = isEnglish(raw_text)
    output_list = [seq]
    # Beam search keeps every beam; sampling keeps a single candidate.
    num_output = args.num_beams if args.sampling_strategy == "BeamSearchStrategy" else 1
    last_pos, answers, answers_with_style, blanks = (
        [0] * num_output,
        ["" for _ in range(num_output)],
        ["" for _ in range(num_output)],
        [[] for _ in range(num_output)],
    )

    # continually detect the first mark position
    while True:
        seq = output_list[0]
        # detect mask position
        mask_token = tokenizer.get_command(generation_mask)
        if mask_token not in seq:
            break
        mask_position = seq.index(mask_token)

        output_list = []

        # Append [sop] (start-of-piece) so the model begins generating.
        input_seq = torch.cuda.LongTensor(
            [seq + [tokenizer.get_command("sop")]],
            device=args.device,
        )
        output, _ = batch_filling_sequence(
            model,
            input_seq,
            torch.cuda.LongTensor([input_seq.shape[-1]], device=args.device),
            strategy=strategy,
            get_masks_and_position_ids=partial(
                get_masks_and_position_ids,
                mask_position=mask_position,
                max_gen_length=args.out_seq_length - input_seq.shape[-1],
                gmask=use_gmask,
            ),
        )
        if isinstance(output, torch.Tensor):  # different strategies
            output = output.tolist()
        output = output[0]  # batch_size = 1
        output_list.extend(output)

        # clip -1s and fill back generated things into seq
        for i in range(len(output_list)):
            output = output_list[i].tolist() if isinstance(output_list[i], torch.Tensor) else output_list[i]
            try:
                # -1 is the "not generated" padding value from
                # get_masks_and_position_ids; everything after it is unused.
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            # Drop a trailing end token from the generated span.
            if output[unfinished - 1] in strategy.end_tokens:
                unfinished -= 1
            bog = output.index(tokenizer.get_command("sop"))

            # Text between the previous blank and this mask, then the blank.
            prefix = tokenizer.detokenize(output[last_pos[i] : mask_position])
            blank = tokenizer.detokenize(output[bog + 1 : unfinished])
            # ANSI escapes: underline for [gMASK] blanks, green+underline
            # for [MASK]/[sMASK] blanks.
            answers_with_style[i] += (
                prefix
                + (" " if is_english else "")
                + ("\033[4m" if use_gmask else "\x1b[0;32m\033[4m")
                + blank
                + ("\033[0m" if use_gmask else "\033[0m\x1b[0m")
                + (" " if is_english else "")
            )
            blanks[i].append(blank)
            last_pos[i] = mask_position + unfinished - (bog + 1)
            # Splice the generated blank back in place of the mask token so
            # the next loop iteration can fill the next remaining mask.
            output_list[i] = output[:mask_position] + output[bog + 1 : unfinished] + output[mask_position + 1 : bog]

    for i, output in enumerate(output_list):
        if output[-1] == tokenizer.get_command("eos"):
            output = output[:-1]
        answers_with_style[i] += tokenizer.detokenize(output[last_pos[i] :])
        answers[i] = tokenizer.detokenize(output)

    return answers, answers_with_style, blanks
-
-
def main(args):
    """Entry point: build model and strategy, then fill blanks continually.

    Args:
        args: parsed namespace from ``initialize``; provides sampling
            hyper-parameters plus ``input_source`` and ``output_path``.

    Raises:
        ValueError: if ``args.sampling_strategy`` names an unknown strategy.
    """
    model, tokenizer = initialize_model_and_tokenizer(args)

    # Generation stops on end-of-paragraph or end-of-sequence tokens.
    end_tokens = [tokenizer.get_command("eop"), tokenizer.get_command("eos")]

    if args.sampling_strategy == "BaseStrategy":
        strategy = BaseStrategy(
            batch_size=1, temperature=args.temperature, top_k=args.top_k, top_p=args.top_p, end_tokens=end_tokens
        )
    elif args.sampling_strategy == "BeamSearchStrategy":
        strategy = BeamSearchStrategy(
            1,  # batch size
            args.num_beams,
            length_penalty=args.length_penalty,
            consider_end=True,
            end_tokens=end_tokens,
            no_repeat_ngram_size=args.no_repeat_ngram_size,
            min_gen_length=args.min_gen_length,
        )
    else:
        raise ValueError(f"unknown strategy {args.sampling_strategy}")

    def process(raw_text):
        # Handle one prompt: generate answers and persist them to a .txt file.
        if args.with_id:
            # Input lines look like "<query_id>\t<text>" when --with-id is set.
            query_id, raw_text = raw_text.split("\t")

        answers, answers_with_style, blanks = fill_blanks(raw_text, model, tokenizer, strategy)

        # save
        if args.with_id:
            full_path = os.path.join(args.output_path, query_id + ".txt")
        else:
            # Derive a timestamped file name from the slash-stripped prefix.
            prefix = raw_text.replace("/", "")[:20]
            full_path = timed_name(prefix, ".txt", args.output_path)
        if mpu.get_model_parallel_rank() == 0:
            # Only model-parallel rank 0 prints and writes output.
            if args.print_all_beams and len(answers) > 1:
                for idx, answer_with_style in enumerate(answers_with_style):
                    print(f"Output beam {idx}:", answer_with_style)  # print the first.
                    if len(answer_with_style) > 120:
                        print("")
            else:
                print(f"Output:", answers_with_style[0])  # print the first.
            with open(full_path, "w", encoding="utf-8") as fout:
                for answer in answers:
                    fout.write(answer + "\n")

            # 0o777: world-readable/writable so other users can consume it.
            os.chmod(full_path, stat.S_IRWXO + stat.S_IRWXG + stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)
-
-
- if __name__ == "__main__":
- args = initialize(extra_args_provider=add_generation_specific_args)
-
- with torch.no_grad():
- main(args)
|