# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging

import torch


logger = logging.getLogger(__name__)


class NanDetector:
    """
    Detects the first NaN or Inf in the forward and/or backward pass and logs
    it together with the name of the module that produced it. A usage sketch
    appears at the bottom of this file.
    """

    def __init__(self, model, forward=True, backward=True):
        self.bhooks = []
        self.fhooks = []
        self.forward = forward
        self.backward = backward
        self.named_parameters = list(model.named_parameters())
        self.reset()

        for name, mod in model.named_modules():
            # Within this class "__module_name" is name-mangled to
            # "_NanDetector__module_name"; _apply() reads it back under the
            # same mangled name, so the attribute stays private to this class.
            mod.__module_name = name
            self.add_hooks(mod)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # Dump all model gradient norms to enable better debugging
        norm = {}
        gradients = {}
        for name, param in self.named_parameters:
            if param.grad is not None:
                grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32)
                norm[name] = grad_norm.item()
                if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any():
                    gradients[name] = param.grad.data
        if len(gradients) > 0:
            logger.info("Detected nan/inf grad norm, dumping norms...")
            logger.info(f"norms: {norm}")
            logger.info(f"gradients: {gradients}")

        self.close()

    def add_hooks(self, module):
        if self.forward:
            self.fhooks.append(module.register_forward_hook(self.fhook_fn))
        if self.backward:
            # register_backward_hook is deprecated in recent PyTorch versions
            # in favor of register_full_backward_hook, but is kept as-is here.
            self.bhooks.append(module.register_backward_hook(self.bhook_fn))

    def reset(self):
        self.has_printed_f = False
        self.has_printed_b = False

    def _detect(self, tensor, name, backward):
        err = None
        if (
            torch.is_floating_point(tensor)
            # single value tensors (like the loss) will not provide much info
            and tensor.numel() >= 2
        ):
            with torch.no_grad():
                if torch.isnan(tensor).any():
                    err = "NaN"
                elif torch.isinf(tensor).any():
                    err = "Inf"
        if err is not None:
            err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}"
        return err

    def _apply(self, module, inp, x, backward):
        if torch.is_tensor(x):
            if isinstance(inp, tuple) and len(inp) > 0:
                inp = inp[0]
            err = self._detect(x, module.__module_name, backward)
            if err is not None:
                if torch.is_tensor(inp) and not backward:
                    err += (
                        f" input max: {inp.max().item()}, input min: {inp.min().item()}"
                    )

                has_printed_attr = "has_printed_b" if backward else "has_printed_f"
                logger.warning(err)
                setattr(self, has_printed_attr, True)
        elif isinstance(x, dict):
            for v in x.values():
                self._apply(module, inp, v, backward)
        elif isinstance(x, (list, tuple)):
            for v in x:
                self._apply(module, inp, v, backward)

    def fhook_fn(self, module, inp, output):
        if not self.has_printed_f:
            self._apply(module, inp, output, backward=False)

    def bhook_fn(self, module, inp, output):
        if not self.has_printed_b:
            self._apply(module, inp, output, backward=True)

    def close(self):
        for hook in self.fhooks + self.bhooks:
            hook.remove()
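

# A minimal usage sketch, not part of the original module: it builds a toy
# model (a hypothetical fixture, assumed here for illustration), poisons one
# layer's weights with NaN so the forward hook fires, then runs a forward and
# backward pass inside the context manager.
if __name__ == "__main__":
    import torch.nn as nn

    logging.basicConfig(level=logging.INFO)

    model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
    with torch.no_grad():
        model[1].weight.fill_(float("nan"))  # force NaNs in the second layer

    with NanDetector(model):
        out = model(torch.randn(2, 4))  # expect a "NaN detected ..." warning
        out.sum().backward()  # NaN grads also trigger the __exit__ norm dump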