# Copyright (C) 2021 NVIDIA Corporation. All rights reserved.
# Licensed under The MIT License (MIT)
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
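
"""Helpers for multi-GPU training built on torch.distributed.

Every function degrades gracefully: when torch.distributed is unavailable
or no process group has been initialized, it falls back to single-process
behavior, so the same training code runs with or without a distributed
launcher.
"""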
import pickle

import torch
from torch import distributed as dist


def get_rank():
    """Return the rank of the current process, or 0 when not distributed."""
    if not dist.is_available():
        return 0

    if not dist.is_initialized():
        return 0

    return dist.get_rank()


def synchronize():
    """Barrier across all processes; a no-op when not running distributed."""
    if not dist.is_available():
        return

    if not dist.is_initialized():
        return

    world_size = dist.get_world_size()

    if world_size == 1:
        return

    dist.barrier()


def get_world_size():
    """Return the number of participating processes, or 1 when not distributed."""
    if not dist.is_available():
        return 1

    if not dist.is_initialized():
        return 1

    return dist.get_world_size()


def reduce_sum(tensor):
    """Sum `tensor` element-wise across all processes and return the result.

    The input is cloned before the collective, so the caller's tensor is
    left untouched.
    """
    if not dist.is_available():
        return tensor

    if not dist.is_initialized():
        return tensor

    tensor = tensor.clone()
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM)

    return tensor


def gather_grad(params):
    """Average the gradients of `params` in place across all processes."""
    world_size = get_world_size()

    if world_size == 1:
        return

    for param in params:
        if param.grad is not None:
            dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM)
            param.grad.data.div_(world_size)


def all_gather(data):
    """Gather an arbitrary picklable object from every process.

    Each object is pickled into a CUDA byte tensor and right-padded to a
    common length, since `dist.all_gather` requires equal tensor shapes.
    Returns a list with one entry per rank.
    """
    world_size = get_world_size()

    if world_size == 1:
        return [data]

    # Serialize the payload into a CUDA byte tensor.
    buffer = pickle.dumps(data)
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to('cuda')

    # Exchange payload sizes so every rank knows the padded length.
    local_size = torch.IntTensor([tensor.numel()]).to('cuda')
    size_list = [torch.IntTensor([0]).to('cuda') for _ in range(world_size)]
    dist.all_gather(size_list, local_size)
    size_list = [int(size.item()) for size in size_list]
    max_size = max(size_list)

    tensor_list = []
    for _ in size_list:
        tensor_list.append(torch.ByteTensor(size=(max_size,)).to('cuda'))

    # Right-pad the local payload up to the largest size before the collective.
    local_numel = int(local_size.item())
    if local_numel != max_size:
        padding = torch.ByteTensor(size=(max_size - local_numel,)).to('cuda')
        tensor = torch.cat((tensor, padding), 0)

    dist.all_gather(tensor_list, tensor)

    # Strip the padding from each rank's payload and unpickle it.
    data_list = []

    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))

    return data_list


def reduce_loss_dict(loss_dict):
    """Reduce a dict of scalar loss tensors so rank 0 holds their average.

    Only rank 0 receives the averaged values; results returned on other
    ranks are not meaningful and should not be logged.
    """
    world_size = get_world_size()

    if world_size < 2:
        return loss_dict

    with torch.no_grad():
        # Iterate in sorted key order so every rank stacks losses identically.
        keys = []
        losses = []

        for k in sorted(loss_dict.keys()):
            keys.append(k)
            losses.append(loss_dict[k])

        losses = torch.stack(losses, 0)
        dist.reduce(losses, dst=0)

        if dist.get_rank() == 0:
            losses /= world_size

        reduced_losses = {k: v for k, v in zip(keys, losses)}

    return reduced_losses
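

# Minimal smoke test (not part of the original module): with no process
# group initialized, every helper falls back to its single-process
# behavior, so this runs on CPU as a plain script.
if __name__ == '__main__':
    assert get_rank() == 0
    assert get_world_size() == 1

    synchronize()  # no-op without an initialized process group

    t = torch.ones(3)
    assert torch.equal(reduce_sum(t), t)  # returned unchanged

    print(all_gather({'rank': get_rank()}))  # -> [{'rank': 0}]
    print(reduce_loss_dict({'loss': torch.tensor(1.0)}))  # returned as-is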