|
- import torch
- import numpy as np
- from scipy import ndimage
- import time
-
- _small = 1e-100
-
-
- def _U(x):
- return (x ** 2) * np.where(x < _small, 0, np.log(x))
-
-
- def _interpoint_distances(points):
- xd = np.subtract.outer(points[:, 0], points[:, 0])
- yd = np.subtract.outer(points[:, 1], points[:, 1])
- zd = np.subtract.outer(points[:, 2], points[:, 2])
- return np.sqrt(xd ** 2 + yd ** 2 + zd ** 2)
-
-
def _calculate_f(coeffs, points, x, y, z):
    """Evaluate one output component of a fitted 3-D thin-plate spline.

    ``f(p) = a1 + ax*x + ay*y + az*z + sum_i w_i * U(|p - points_i|)``

    Args:
        coeffs: vector of ``n + 4`` coefficients: the ``n`` kernel weights
            followed by the affine terms ``a1, ax, ay, az``.
        points: (n, 3) control points the spline was fitted on.
        x, y, z: query coordinates (arrays of identical or broadcastable
            shape).

    Returns:
        Array with the broadcast shape of ``x, y, z`` holding f at each
        query point.

    Raises:
        ValueError: if the number of kernel weights differs from the number
            of control points.
    """
    points = np.asarray(points)
    w = coeffs[:-4]
    a1, ax, ay, az = coeffs[-4:]
    if len(w) != len(points):
        raise ValueError(
            f"expected {len(points) + 4} coefficients for {len(points)} "
            f"control points, got {len(coeffs)}")
    # Accumulate one control point at a time. Materialising the full
    # (query_shape + (n,)) distance array at once needs n times the memory of
    # the query grid, which is prohibitive for dense volumes.
    summation = np.zeros(np.broadcast(x, y, z).shape)
    for weight, (px, py, pz) in zip(w, points[:, :3]):
        r = np.sqrt((px - x) ** 2 + (py - y) ** 2 + (pz - z) ** 2)
        summation += weight * _U(r)
    return a1 + ax * x + ay * y + az * z + summation
-
-
def _make_L_matrix(points):
    """Build the ``(n + 4, n + 4)`` thin-plate-spline system matrix.

        L = [[K,   P],
             [P^T, 0]]

    where ``K[i, j] = U(|p_i - p_j|)`` and row ``i`` of ``P`` is
    ``[1, x_i, y_i, z_i]``.

    Args:
        points: (n, 3) control points.

    Returns:
        Plain ``ndarray`` of shape ``(n + 4, n + 4)``.
    """
    points = np.asarray(points)
    n = len(points)
    K = _U(_interpoint_distances(points))
    P = np.ones((n, 4))
    P[:, 1:] = points
    # np.block returns a plain ndarray directly. np.bmat builds the np.matrix
    # subclass, which NumPy discourages and which emits a
    # PendingDeprecationWarning.
    return np.block([[K, P], [P.T, np.zeros((4, 4))]])
-
-
def _make_warp(from_points, to_points, x_vals, y_vals, z_vals):
    """Fit a thin-plate spline taking *from_points* to *to_points* and evaluate it.

    Args:
        from_points: (n, 3) control points (the spline's domain).
        to_points: (n, 3) positions the control points must map to.
        x_vals, y_vals, z_vals: query coordinates, arrays of identical shape.

    Returns:
        ``[x_warp, y_warp, z_warp]``: the spline's three output coordinates
        evaluated at every query point.
    """
    from_points, to_points = np.asarray(from_points), np.asarray(to_points)
    # Right-hand side of the TPS linear system: the target coordinates
    # followed by four zero rows (the side conditions on the kernel weights).
    V = np.zeros((len(to_points) + 4, 3))
    V[:-4] = to_points
    # errstate restores NumPy's error settings even if fitting or evaluation
    # raises; the warning silenced is log(0) for zero distances inside _U.
    with np.errstate(divide='ignore'):
        coeffs = np.linalg.pinv(_make_L_matrix(from_points)) @ V
        return [_calculate_f(coeffs[:, k], from_points, x_vals, y_vals, z_vals)
                for k in range(3)]
-
-
def _make_inverse_warp(from_points, to_points, output_region, approximate_grid):
    """Compute, for every output voxel, the input coordinate to sample from.

    The spline is fitted in the reverse direction (``to_points`` ->
    ``from_points``) because resampling works backwards: each output voxel
    pulls its value from a position in the input volume.

    Args:
        from_points: (n, 3) control points in the input volume.
        to_points: (n, 3) positions of those control points in the output.
        output_region: ``(x_min, x_max, y_min, y_max, z_min, z_max)``,
            inclusive bounds of the output grid (integers expected).
        approximate_grid: spacing, in voxels, of a coarse grid on which the
            spline is evaluated and then trilinearly upsampled. ``None`` or 1
            evaluates the spline at every output voxel.

    Returns:
        List of three arrays (x, y, z input coordinates), each shaped like
        the output grid, usable as ``coordinates`` for
        ``ndimage.map_coordinates``.
    """
    x_min, x_max, y_min, y_max, z_min, z_max = output_region
    if approximate_grid is None:
        approximate_grid = 1
    bounds = ((x_min, x_max), (y_min, y_max), (z_min, z_max))

    if approximate_grid == 1:
        # One spline sample per output voxel (bounds are inclusive).
        counts = [int(hi - lo) + 1 for lo, hi in bounds]
    else:
        # np.mgrid truncates a fractional complex step to an integer number
        # of samples. The upsampling below must use that same integer count
        # to locate the coarse samples. Two samples per axis is the minimum
        # for linear interpolation.
        counts = [max(int((hi - lo) / approximate_grid), 2) for lo, hi in bounds]
    nx, ny, nz = counts
    x, y, z = np.mgrid[x_min:x_max:nx * 1j, y_min:y_max:ny * 1j, z_min:z_max:nz * 1j]

    transform = _make_warp(to_points, from_points, x, y, z)
    if approximate_grid == 1:
        return transform

    # Trilinearly interpolate the coarse transform at every integer voxel of
    # the output region.
    fine = np.mgrid[x_min:x_max + 1, y_min:y_max + 1, z_min:z_max + 1]
    corners = []  # per axis: ((lower_index, weight), (upper_index, weight))
    for positions, (lo, hi), n in zip(fine, bounds, counts):
        span = float(hi - lo)
        # Fractional index into the coarse axis, whose n samples are spaced
        # span / (n - 1) apart. A zero-width axis maps everything to sample 0.
        coord = (n - 1) * (positions - lo) / span if span else np.zeros(positions.shape)
        frac, whole = np.modf(coord)
        lower = np.clip(whole.astype(int), 0, n - 1)
        upper = np.clip(lower + 1, 0, n - 1)
        corners.append(((lower, 1 - frac), (upper, frac)))

    result = []
    for coarse in transform:
        value = 0
        for xi, wx in corners[0]:
            for yi, wy in corners[1]:
                for zi, wz in corners[2]:
                    value = value + coarse[xi, yi, zi] * wx * wy * wz
        result.append(value)
    return result
-
-
def warp_voxel(from_points, to_points, voxels, output_region, interpolation_order=1, approximate_grid=10):
    """Resample each volume in *voxels* through a thin-plate-spline warp.

    The warp carries ``from_points`` onto ``to_points``. Every output is
    sampled on the grid described by ``output_region``.

    Args:
        from_points: (n, 3) control points in the input volumes.
        to_points: (n, 3) positions the control points move to.
        voxels: iterable of 3-D volumes (anything ``np.asarray`` accepts).
        output_region: ``(x_min, x_max, y_min, y_max, z_min, z_max)`` bounds
            of the output grid.
        interpolation_order: spline order for ``ndimage.map_coordinates``.
        approximate_grid: spacing of the coarse grid on which the spline is
            evaluated before being upsampled. ``None`` or 1 evaluates it at
            every output voxel.

    Returns:
        List with one warped array per input volume. Samples that fall
        outside an input volume use ``mode='reflect'``.
    """
    sample_coords = _make_inverse_warp(from_points, to_points, output_region, approximate_grid)
    warped = []
    for volume in voxels:
        resampled = ndimage.map_coordinates(
            np.asarray(volume), sample_coords, order=interpolation_order, mode='reflect')
        warped.append(resampled)
    return warped
-
-
- def _thin_plate_spline_warp(voxel, src_points, dst_points, keep_corners=True):
- assert len(voxel.shape) == 3 or 4
- if len(voxel.shape) == 3:
- voxel = np.expand_dims(voxel, axis=0)
- x_max, y_max, z_max = voxel.shape[1:]
- if keep_corners:
- corner_points = np.array(
- [[0, 0, 0], [0, 0, z_max], [0, y_max, 0], [0, y_max, z_max], [x_max, 0, 0], [x_max, 0, z_max],
- [x_max, y_max, 0], [x_max, y_max, z_max]])
- src_points = np.concatenate((src_points, corner_points))
- dst_points = np.concatenate((dst_points, corner_points))
- output_region = [0, x_max - 1, 0, y_max - 1, 0, z_max - 1]
- result = warp_voxel(src_points, dst_points, voxel, output_region, 1, 5)
- result = np.array(result)
- # result = np.squeeze(result)
- return result
-
-
def tps_warp_3D(voxel, src, dst, keep_corners=True):
    """Warp a volume with a thin-plate spline defined by src -> dst control points.

    Args:
        voxel: volume of shape (X, Y, Z) or (C, X, Y, Z), e.g.
            28 x 256 x 256 x 256.
        src: (n, 3) control points in the input volume.
        dst: (n, 3) positions the control points move to.
        keep_corners: pin the volume corners as extra fixed control points.

    Returns:
        The warped volume as an array of shape (C, X, Y, Z).
    """
    return _thin_plate_spline_warp(voxel, src, dst, keep_corners)
-
-
def create_control_points(grid_size, points_per_dim=3, scale=0.1, rng=None):
    """Create a lattice of control points and randomly displaced targets.

    Args:
        grid_size: ``(x_max, y_max, z_max)`` extent of the volume; any
            sequence or array of three numbers.
        points_per_dim: number of lattice points along each axis.
        scale: maximum displacement along each axis, as a fraction of that
            axis' extent. Offsets are drawn uniformly from
            ``[-scale * size, scale * size]``.
        rng: optional ``numpy.random.Generator``/``RandomState`` for the
            offsets. Defaults to the global ``numpy.random`` state.

    Returns:
        ``(from_points, to_points)``, two ``(points_per_dim ** 3, 3)`` arrays,
        both clipped to ``[1, size - 1]`` per axis.
    """
    # asarray so plain lists work with the per-axis arithmetic below.
    grid_size = np.asarray(grid_size)
    n = int(points_per_dim)
    axes = [np.linspace(1, size - 1, n) for size in grid_size[:3]]
    x, y, z = np.meshgrid(*axes)
    from_points = np.dstack([x.flat, y.flat, z.flat])[0]
    upper = grid_size - 1
    from_points = np.clip(from_points, 1, upper)

    sampler = np.random if rng is None else rng
    offsets = sampler.uniform(-scale * grid_size, scale * grid_size, from_points.shape)
    to_points = np.clip(from_points + offsets, 1, upper)
    return from_points, to_points
-
def get_tps_coeffs(from_points, to_points, voxel_size, save_path=None, keep_corners=True):
    """Fit a thin-plate spline on *to_points* whose targets are *from_points*.

    Args:
        from_points: (n, 3) original control points.
        to_points: (n, 3) displaced control points.
        voxel_size: ``(x_max, y_max, z_max)``. Used only to build the corner
            points when *keep_corners* is true.
        save_path: if truthy, ``coeffs``, ``from_points`` and ``to_points``
            are written there with ``np.savez``.
        keep_corners: append the eight points ``{0, x_max} x {0, y_max} x
            {0, z_max}`` to both point sets as fixed control points.

    Returns:
        ``(coeffs, from_points, to_points)``. ``coeffs`` has shape
        ``(m + 4, 3)``, where m is the number of control points including
        any corners. The point arrays are returned with corners appended.
    """
    from_points, to_points = np.asarray(from_points), np.asarray(to_points)
    if keep_corners:
        x_max, y_max, z_max = voxel_size
        corner_points = np.array(
            [[0, 0, 0], [0, 0, z_max], [0, y_max, 0], [0, y_max, z_max], [x_max, 0, 0], [x_max, 0, z_max],
             [x_max, y_max, 0], [x_max, y_max, z_max]])
        from_points = np.concatenate((from_points, corner_points))
        to_points = np.concatenate((to_points, corner_points))

    # Reverse fit: the system matrix is built on to_points and the targets
    # are from_points. Four zero rows hold the side conditions on the kernel
    # weights.
    targets = np.zeros((len(from_points) + 4, 3))
    targets[:-4] = from_points
    # errstate restores NumPy's error settings even if the fit raises; the
    # warning silenced is log(0) for zero distances in the kernel.
    with np.errstate(divide='ignore'):
        coeffs = np.linalg.pinv(_make_L_matrix(to_points)) @ targets
    if save_path:
        np.savez(save_path, coeffs=coeffs, from_points=from_points, to_points=to_points)
    return coeffs, from_points, to_points
-
-
def get_tps_deform(pos, tps_params):
    """Map query positions through a spline fitted by ``get_tps_coeffs``.

    Args:
        pos: (m, 3) array of query positions (extra columns, if any, are left
            as they are).
        tps_params: mapping with ``'coeffs'`` and ``'to_points'``, as built
            from ``get_tps_coeffs``' outputs. The spline was fitted on
            ``to_points``, so those are its control points.

    Returns:
        The mapped positions. For a floating-point *pos* the result is
        written into *pos* in place and *pos* is returned, as before. For an
        integer *pos* a new float array is returned instead, because an
        in-place write would silently truncate the mapped coordinates.
    """
    centers = tps_params['to_points']
    coeffs = tps_params['coeffs']
    xs, ys, zs = pos[:, 0], pos[:, 1], pos[:, 2]
    # errstate restores NumPy's error settings even if evaluation raises.
    with np.errstate(divide='ignore'):
        mapped = np.stack(
            [_calculate_f(coeffs[:, k], centers, xs, ys, zs) for k in range(3)],
            axis=-1)
    out = pos if np.issubdtype(pos.dtype, np.inexact) else pos.astype(float)
    out[:, :3] = mapped
    return out
-
def main():
    """Smoke run: fit a random TPS on a 3x3x3 grid and map four points through it."""
    size = np.array([3, 3, 3])
    src, dst = create_control_points(size)
    coeffs, src, dst = get_tps_coeffs(src, dst, size)
    params = {'from_points': src, 'to_points': dst, 'coeffs': coeffs}
    get_tps_deform(np.arange(12).reshape(4, 3), params)
-
# Run the demo only when this module is executed directly, not when imported.
if __name__ == "__main__":
    main()
|