|
- import torch
- from torch.utils.cpp_extension import load
- import os
-
# The C++ and CUDA sources are looked up in the same directory as this module.
script_dir = os.path.dirname(__file__)
sources = [
    os.path.join(script_dir, filename)
    for filename in ("chamfer_distance.cpp", "chamfer_distance.cu")
]

# Build the extension at import time via torch's cpp_extension loader and bind
# the resulting module as `cd`; the class below calls into its
# forward/backward entry points.
cd = load(name="cd", sources=sources)
-
-
class ChamferDistanceFunction(torch.autograd.Function):
    """Autograd wrapper around the compiled ``cd`` extension kernels.

    ``forward`` allocates the distance/index buffers and lets
    ``cd.forward`` (CPU) or ``cd.forward_cuda`` (CUDA) fill them in place;
    ``backward`` does the same for the gradient buffers with
    ``cd.backward`` / ``cd.backward_cuda``. The exact distance definition is
    determined by the extension sources, not by this wrapper.
    """

    @staticmethod
    def forward(ctx, xyz1, xyz2):
        """Run the forward kernel for two batched point sets.

        Args:
            ctx: Autograd context; receives the tensors needed by ``backward``.
            xyz1: 3-D tensor; its first two dimensions are read as
                ``(batchsize, n)``.
            xyz2: 3-D tensor on the same device as ``xyz1``; its second
                dimension is read as ``m``.

        Returns:
            Tuple ``(dist1, dist2)`` with shapes ``(batchsize, n)`` and
            ``(batchsize, m)``, allocated on the device of ``xyz1``.

        Raises:
            ValueError: If either input is not 3-D, or the inputs live on
                different devices.
        """
        if xyz1.dim() != 3 or xyz2.dim() != 3:
            raise ValueError(
                "xyz1 and xyz2 must be 3-D tensors, got "
                f"{tuple(xyz1.size())} and {tuple(xyz2.size())}"
            )
        if xyz1.device != xyz2.device:
            raise ValueError(
                "xyz1 and xyz2 must be on the same device, got "
                f"{xyz1.device} and {xyz2.device}"
            )

        batchsize, n, _ = xyz1.size()
        _, m, _ = xyz2.size()
        xyz1 = xyz1.contiguous()
        xyz2 = xyz2.contiguous()

        # Allocate outputs directly on the input's device. A bare `.cuda()`
        # would target the *current* CUDA device, which can differ from the
        # device holding xyz1 on multi-GPU machines, and would also cost an
        # extra host-to-device copy.
        device = xyz1.device
        dist1 = torch.zeros(batchsize, n, device=device)
        dist2 = torch.zeros(batchsize, m, device=device)
        idx1 = torch.zeros(batchsize, n, dtype=torch.int, device=device)
        idx2 = torch.zeros(batchsize, m, dtype=torch.int, device=device)

        if xyz1.is_cuda:
            # Make the tensors' device current for the kernel launch.
            with torch.cuda.device(device):
                cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2)
        else:
            cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)

        ctx.save_for_backward(xyz1, xyz2, idx1, idx2)

        return dist1, dist2

    @staticmethod
    def backward(ctx, graddist1, graddist2):
        """Run the backward kernel and return gradients for both inputs.

        Args:
            ctx: Autograd context holding ``xyz1, xyz2, idx1, idx2`` saved by
                ``forward``.
            graddist1: Gradient with respect to ``dist1``.
            graddist2: Gradient with respect to ``dist2``.

        Returns:
            Tuple ``(gradxyz1, gradxyz2)`` shaped like ``xyz1`` and ``xyz2``.
        """
        xyz1, xyz2, idx1, idx2 = ctx.saved_tensors

        graddist1 = graddist1.contiguous()
        graddist2 = graddist2.contiguous()

        # Same reasoning as in forward: create the buffers where the saved
        # inputs live instead of on the CPU / current CUDA device.
        device = xyz1.device
        gradxyz1 = torch.zeros(xyz1.size(), device=device)
        gradxyz2 = torch.zeros(xyz2.size(), device=device)

        if graddist1.is_cuda:
            with torch.cuda.device(device):
                cd.backward_cuda(
                    xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
                )
        else:
            cd.backward(
                xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
            )

        return gradxyz1, gradxyz2


# Functional entry point: chamfer_distance(xyz1, xyz2) -> (dist1, dist2).
chamfer_distance = ChamferDistanceFunction.apply
|