# coding=utf-8
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Sample Generate GPT2"""

import os
import time
from datetime import datetime

import torch
import torch.nn.functional as F

from arguments import get_args
from utils import set_random_seed
from utils import load_checkpoint, get_checkpoint_iteration
from utils import print_rank_0
from data_utils import make_tokenizer
import mpu

from fp16 import FP16_Module
from model import GPT2Model

USE_TORCH_DDP = True


def get_model(args):
    """Build the model."""

    print_rank_0('building GPT2 model ...')
    model = GPT2Model(num_layers=args.num_layers,
                      vocab_size=args.vocab_size,
                      hidden_size=args.hidden_size,
                      num_attention_heads=args.num_attention_heads,
                      embedding_dropout_prob=args.hidden_dropout,
                      attention_dropout_prob=args.attention_dropout,
                      output_dropout_prob=args.hidden_dropout,
                      max_sequence_length=args.max_position_embeddings,
                      max_memory_length=args.mem_length,
                      checkpoint_activations=args.checkpoint_activations,
                      checkpoint_num_layers=args.checkpoint_num_layers,
                      parallel_output=False,
                      relative_encoding=args.transformer_xl)

    if mpu.get_data_parallel_rank() == 0:
        print(' > number of parameters on model parallel rank {}: {}'.format(
            mpu.get_model_parallel_rank(),
            sum([p.nelement() for p in model.parameters()])), flush=True)

    # To prevent OOM for model sizes that cannot fit in GPU memory in full precision.
    if hasattr(args, "deepspeed") and args.deepspeed and args.fp16:
        model.half()

    # GPU allocation.
    model.cuda(torch.cuda.current_device())

    # Fp16 conversion.
    if args.fp16:
        model = FP16_Module(model)

    # Wrap model for distributed training.
    if USE_TORCH_DDP:
        from model import PyTorchDistributedDataParallel as DDP
        i = torch.cuda.current_device()
        model = DDP(model, device_ids=[i], output_device=i,
                    process_group=mpu.get_data_parallel_group())
    else:
        from model import DistributedDataParallel as DDP
        model = DDP(model)

    return model


def get_masks_and_position_ids(data,
                               eod_token,
                               reset_position_ids,
                               reset_attention_mask,
                               loss_mask=None,
                               attention_mask=None,
                               transformer_xl=False,
                               mem_length=None):
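    """Build the attention mask, loss mask, and position ids for a token batch.

    For the transformer-xl path a banded causal mask over the current tokens
    plus the cached memory is built; otherwise a standard lower-triangular
    causal mask is used. EOD handling (zeroing the loss mask and optionally
    resetting positions / attention across documents) only applies to the
    non-transformer-xl path.
    """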
    # Extract batch size and sequence length.
    batch_size, seq_length = data.size()

    # Attention mask (lower triangular).
    if transformer_xl:
        if attention_mask is None:
            attention_mask = torch.ones((1, seq_length, seq_length + mem_length), device=data.device)
        attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + mem_length), mem_length)
    else:
        if reset_attention_mask:
            att_mask_batch = batch_size
        else:
            att_mask_batch = 1
        if attention_mask is None:
            attention_mask = torch.ones((att_mask_batch, seq_length, seq_length), device=data.device)
        attention_mask = torch.tril(attention_mask)
    attention_mask = attention_mask.unsqueeze(1)

    # Loss mask.
    if loss_mask is None:
        loss_mask = torch.ones(data.size(), dtype=torch.float, device=data.device)

    # Position ids.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=data.device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    if not transformer_xl:
        loss_mask[data == eod_token] = 0.0
        # We need to clone as the ids will be modified based on batch index.
        if reset_position_ids:
            position_ids = position_ids.clone()

        if reset_position_ids or reset_attention_mask:
            # Loop through the batches:
            for b in range(batch_size):

                # Find indices where the EOD token is.
                eod_index = position_ids[b, data[b] == eod_token]
                # Detach indices from positions if going to modify positions.
                if reset_position_ids:
                    eod_index = eod_index.clone()

                # Loop through EOD indices:
                prev_index = 0
                for j in range(eod_index.size()[0]):
                    i = eod_index[j]
                    # Prevent attention across the EOD boundary.
                    if reset_attention_mask:
                        attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                    # Reset positions.
                    if reset_position_ids:
                        position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                        prev_index = i + 1

    return attention_mask, loss_mask, position_ids


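# `set_deepspeed_activation_checkpointing` is called from `initialize_distributed`
# below but is not defined (or imported) in this file. The following is a minimal
# sketch, assuming DeepSpeed's activation-checkpointing API
# (`deepspeed.checkpointing.configure` and friends) and an `args.deepspeed_config`
# path as produced by DeepSpeed's argument parser. In this script the branch is
# never reached, because main() hard-sets args.deepspeed = False.
def set_deepspeed_activation_checkpointing(args):
    """Route mpu's activation checkpointing through DeepSpeed (sketch)."""
    import deepspeed

    deepspeed.checkpointing.configure(mpu, deepspeed_config=args.deepspeed_config,
                                      num_checkpoints=args.num_layers)
    mpu.checkpoint = deepspeed.checkpointing.checkpoint
    mpu.get_cuda_rng_tracker = deepspeed.checkpointing.get_cuda_rng_tracker
    mpu.model_parallel_cuda_manual_seed = deepspeed.checkpointing.model_parallel_cuda_manual_seed

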
def initialize_distributed(args):
    """Initialize torch.distributed."""

    # Manually set the device ids.
    device = args.rank % torch.cuda.device_count()
    if args.local_rank is not None:
        device = args.local_rank
    torch.cuda.set_device(device)
    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=args.distributed_backend,
        world_size=args.world_size, rank=args.rank,
        init_method=init_method)

    # Set the model-parallel / data-parallel communicators.
    mpu.initialize_model_parallel(args.model_parallel_size)

    # Optional DeepSpeed activation checkpointing features.
    if hasattr(args, "deepspeed") and args.deepspeed and args.deepspeed_activation_checkpointing:
        set_deepspeed_activation_checkpointing(args)


def setup_model(args):
    """Set up the model (no optimizer is needed for sampling)."""

    model = get_model(args)

    # if args.deepspeed:
    #     print_rank_0("DeepSpeed is enabled.")
    #
    #     model, _, _, _ = deepspeed.initialize(
    #         model=model,
    #         model_parameters=model.parameters(),
    #         args=args,
    #         mpu=mpu,
    #         dist_init_required=False
    #     )
    if args.load is not None:
        if args.deepspeed:
            iteration, release, success = get_checkpoint_iteration(args)
            path = os.path.join(args.load, str(iteration), "mp_rank_00_model_states.pt")
            checkpoint = torch.load(path)
            model.load_state_dict(checkpoint["module"])
            print(f"Loaded model file {path}")
        else:
            _ = load_checkpoint(
                model, None, None, args, load_optimizer_states=False)
    # if args.deepspeed:
    #     model = model.module

    return model


def get_batch(context_tokens, device, args):
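    """Reshape the context tokens to (args.batch_size, -1), move them to
    `device`, and build the matching attention mask and position ids.
    The loss mask is not needed for sampling and is discarded."""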
    tokens = context_tokens
    tokens = tokens.view(args.batch_size, -1).contiguous()
    tokens = tokens.to(device)

    # Get the masks and position ids.
    attention_mask, loss_mask, position_ids = get_masks_and_position_ids(
        tokens,
        args.eod_token,
        reset_position_ids=False,
        reset_attention_mask=False,
        transformer_xl=args.transformer_xl,
        mem_length=args.mem_length)

    return tokens, attention_mask, position_ids


def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
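    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Logits outside the top-k set, or outside the smallest set of tokens whose
    cumulative probability exceeds `top_p`, are set to `filter_value` so they
    are never sampled.
    """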
    # This function has been mostly taken from huggingface conversational ai code at
    # https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313

    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

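    # Note: the nucleus (top-p) branch below flattens the logits to 1D and
    # reshapes back to (1, vocab_size), so it assumes a single sequence
    # (batch size 1), which matches how this script is run.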
    if top_p > 0.0:
        # Convert to 1D.
        logits = logits.view(logits.size()[1]).contiguous()
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold.
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
        # Going back to 2D.
        logits = logits.view(1, -1).contiguous()

    return logits


def sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args, device, mems=None, end_token=None):
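    """Autoregressively sample up to args.out_seq_length tokens from `model`,
    starting from the given context and stopping early at `end_token`.
    Returns the flattened token tensor (context plus generation) and the
    transformer-xl memories."""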
    tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, device, args)

    counter = 0
    if mems is None:
        mems = []
    if end_token is None:
        end_token = args.eod_token
    org_context_length = context_length
    while counter < (args.out_seq_length - org_context_length):
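        # First iteration: run the full context through the model to build the
        # transformer-xl memories. Later iterations feed only the newest token,
        # its single position id, and an all-ones attention mask over the
        # cached memory plus that token.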
        if counter == 0:
            logits, *mems = model(tokens, position_ids, attention_mask, *mems)
        else:
            index = org_context_length + counter
            logits, *mems = model(tokens[:, index - 1: index], tokens.new_ones((1, 1)) * (index - 1),
                                  tokens.new_ones(1, 1, 1, args.mem_length + 1, device=tokens.device,
                                                  dtype=torch.float), *mems)
        logits = logits[:, -1]
        logits /= args.temperature
        logits = top_k_logits(logits, top_k=args.top_k, top_p=args.top_p)
        probs = F.softmax(logits, dim=-1)
        prev = torch.multinomial(probs, num_samples=1)[0]
        is_end = prev == end_token
        if is_end:
            break
        tokens = torch.cat((tokens, prev.view(1, 1)), dim=1)
        context_length += 1
        counter += 1
        if mpu.get_model_parallel_rank() == 0 and counter % 16 == 0:
            output_tokens_list = tokens.view(-1).contiguous()
            decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
            if counter % 128 == 0:
                os.system('clear')
                print(decode_tokens, flush=True)
    output_tokens_list = tokens.view(-1).contiguous()
    return output_tokens_list, mems


def read_context(tokenizer, args, output):
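    """Prompt the user for a context on model-parallel rank 0 and broadcast the
    tokenized context (plus a termination flag) to the other model-parallel
    ranks. Returns (terminate_runs, raw_text, context_tokens_tensor, context_length)."""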
    terminate_runs, skip_run = 0, 0
    if mpu.get_model_parallel_rank() == 0:
        while True:
            raw_text = input("\nContext prompt (stop to exit) >>> ")
            if not raw_text:
                print('Prompt should not be empty!')
                continue
            if raw_text == "stop":
                terminate_runs = 1
                break
            output.write(raw_text)
            context_tokens = tokenizer.EncodeAsIds(raw_text).tokenization
            context_length = len(context_tokens)

            if context_length >= args.seq_length:
                print("\nContext length", context_length,
                      "\nPlease give a context shorter than the sequence length!")
                continue
            break
    else:
        context_length = 0

    terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
    torch.distributed.broadcast(terminate_runs_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    terminate_runs = terminate_runs_tensor[0].item()

    if terminate_runs == 1:
        return terminate_runs, None, None, None

    context_length_tensor = torch.cuda.LongTensor([context_length])

    torch.distributed.broadcast(context_length_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    context_length = context_length_tensor[0].item()
    if mpu.get_model_parallel_rank() == 0:
        context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
    else:
        context_tokens_tensor = torch.cuda.LongTensor([0] * context_length)
    torch.distributed.broadcast(context_tokens_tensor, mpu.get_model_parallel_src_rank(),
                                group=mpu.get_model_parallel_group())
    if mpu.get_model_parallel_rank() != 0:
        raw_text = tokenizer.DecodeIds(context_tokens_tensor.tolist())
    return terminate_runs, raw_text, context_tokens_tensor, context_length


def generate_samples(model, tokenizer, args, device):
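    """Interactively read prompts, sample continuations, print them, and append
    the generated text to a timestamped file under ./samples."""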
    model.eval()
    output_dir = "./samples"
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    output_path = os.path.join(output_dir, f"sample-{datetime.now().strftime('%m-%d-%H-%M')}.txt")
    with torch.no_grad(), open(output_path, "w") as output:
        while True:
            torch.distributed.barrier(group=mpu.get_model_parallel_group())

            terminate_runs, raw_text, context_tokens_tensor, context_length = read_context(tokenizer, args, output)
            if terminate_runs == 1:
                return
            start_time = time.time()
            output_tokens_list, _ = sample_sequence(model, tokenizer, context_tokens_tensor, context_length, args,
                                                    device)
            if mpu.get_model_parallel_rank() == 0:
                os.system('clear')
                print("\nTime taken: {:.2f}s\n".format(time.time() - start_time), flush=True)
                print("\nContext:", raw_text, flush=True)
                decode_tokens = tokenizer.DecodeIds(output_tokens_list.tolist())
                trim_decode_tokens = decode_tokens[len(raw_text):]
                print("\nGPT2:", trim_decode_tokens, flush=True)
                output.write(trim_decode_tokens + "\n")

            torch.distributed.barrier(group=mpu.get_model_parallel_group())


def prepare_tokenizer(args):
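    """Build the tokenizer, pad its vocabulary to a multiple of
    args.make_vocab_size_divisible_by, and record the padded vocab size and
    EOD token id on args."""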
    tokenizer_args = {
        'tokenizer_type': args.tokenizer_type,
        'corpus': None,
        'model_path': args.tokenizer_path,
        'vocab_size': args.vocab_size,
        'model_type': args.tokenizer_model_type,
        'cache_dir': args.cache_dir,
        'add_eop': args.hierarchical}
    tokenizer = make_tokenizer(**tokenizer_args)

    num_tokens = tokenizer.num_tokens
    before = num_tokens
    after = before
    multiple = args.make_vocab_size_divisible_by
    while (after % multiple) != 0:
        after += 1
    print_rank_0('> padded vocab (size: {}) with {} dummy '
                 'tokens (new size: {})'.format(
                     before, after - before, after))

    args.tokenizer_num_tokens = after
    args.tokenizer_num_type_tokens = tokenizer.num_type_tokens
    args.eod_token = tokenizer.get_command('eos').Id

    # after = tokenizer.num_tokens
    # while after % mpu.get_model_parallel_world_size() != 0:
    #     after += 1

    args.vocab_size = after
    print("prepare tokenizer done", flush=True)

    return tokenizer


def main():
    """Main generation program."""

    print('Generate Samples')

    # Disable CuDNN.
    torch.backends.cudnn.enabled = False

    # Arguments.
    args = get_args()
    args.deepspeed = False
    args.mem_length = args.seq_length + args.mem_length - 1

    # Pytorch distributed.
    initialize_distributed(args)

    # Random seeds for reproducibility.
    set_random_seed(args.seed)

    # Get the tokenizer.
    tokenizer = prepare_tokenizer(args)

    # Model (no optimizer or learning-rate scheduler is needed for sampling).
    model = setup_model(args)

    # Set the default batch size to 1.
    args.batch_size = 1

    # Generate samples.
    generate_samples(model, tokenizer, args, torch.cuda.current_device())


if __name__ == "__main__":
    main()