-
Notifications
You must be signed in to change notification settings - Fork 52
/
inverse_warp.py
executable file
·110 lines (87 loc) · 4.22 KB
/
inverse_warp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
from __future__ import division
import torch
from torch.autograd import Variable
# Module-level cache for the homogeneous pixel grid ([1, 3, H, W]).
# It starts empty and is (re)built by set_id_grid().
pixel_coords = None
def set_id_grid(depth):
    """Build the homogeneous pixel grid and store it in the module cache.

    The grid holds, for every pixel, the vector (x, y, 1) where x is the
    column index and y is the row index. It is created directly with the
    dtype and device of ``depth`` (no deprecated ``Variable`` wrapper and no
    intermediate CPU tensor).

    Args:
        depth: tensor of shape [B, H, W]; only H, W, dtype and device are used.

    Side effects:
        Overwrites the module-level ``pixel_coords`` with a [1, 3, H, W] tensor.
    """
    global pixel_coords
    _, h, w = depth.size()
    like = {'dtype': depth.dtype, 'device': depth.device}
    rows = torch.arange(h, **like).view(1, h, 1).expand(1, h, w)  # [1, H, W]
    cols = torch.arange(w, **like).view(1, 1, w).expand(1, h, w)  # [1, H, W]
    ones = torch.ones(1, h, w, **like)
    # Channel order is (x, y, 1): column index first, then row index.
    pixel_coords = torch.stack((cols, rows, ones), dim=1)  # [1, 3, H, W]
def check_sizes(input, input_name, expected):
    """Validate the shape of a tensor against a compact size pattern.

    Args:
        input: object exposing ``size()`` (e.g. a torch tensor).
        input_name: name used in the error message.
        expected: one character per dimension; a digit fixes the size of
            that dimension, any other character (e.g. 'B', 'H', 'W')
            accepts any size.

    Raises:
        AssertionError: if the number of dimensions differs from
            ``len(expected)`` or a digit dimension has a different size.
    """
    actual = list(input.size())
    # Compare the rank first: indexing a missing dimension would otherwise
    # raise IndexError and hide the descriptive message below.
    matches = len(actual) == len(expected) and all(
        actual[i] == int(size)
        for i, size in enumerate(expected)
        if size.isdigit()
    )
    if not matches:
        # Raised explicitly so the check is not stripped under ``python -O``.
        raise AssertionError(
            "wrong size for {}, expected {}, got {}".format(
                input_name, 'x'.join(expected), actual))
def pixel2cam(depth, intrinsics_inv):
    """Transform coordinates in the pixel frame to the camera frame.

    Every pixel (x, y, 1) is multiplied by the inverse intrinsics and scaled
    by its depth.

    Args:
        depth: depth maps -- [B, H, W]
        intrinsics_inv: inverse intrinsics matrix for each element of the
            batch -- [B, 3, 3]
    Returns:
        camera-frame coordinates of every pixel -- [B, 3, H, W]
    """
    b, h, w = depth.size()
    # Rebuild the cached grid when it is missing, too small in either
    # dimension, or was created for another dtype/device.
    stale = (
        pixel_coords is None
        or pixel_coords.size(2) < h
        or pixel_coords.size(3) < w
        or pixel_coords.dtype != depth.dtype
        or pixel_coords.device != depth.device
    )
    if stale:
        set_id_grid(depth)
    # The grid already lives on depth's device, so no explicit .cuda() is
    # needed and CPU tensors work as well.
    grid = pixel_coords[:, :, :h, :w].expand(b, 3, h, w).reshape(b, 3, -1)  # [B, 3, H*W]
    cam_coords = intrinsics_inv.bmm(grid).view(b, 3, h, w)
    return cam_coords * depth.unsqueeze(1)
def cam2pixel(cam_coords, proj_c2p_rot, proj_c2p_tr, padding_mode):
    """Transform coordinates in the camera frame to the pixel frame.

    Args:
        cam_coords: points in the first camera coordinate system
            -- [B, 3, H, W]
        proj_c2p_rot: left 3x3 part of the projection matrix, or None to
            skip the multiplication -- [B, 3, 3]
        proj_c2p_tr: translation column of the projection matrix, or None
            to skip the addition -- [B, 3, 1]
        padding_mode: if 'zeros', coordinates outside [-1, 1] are set to 2.
    Returns:
        sampling grid of normalized (x, y) coordinates, where -1 is the
        first and 1 the last pixel along an axis -- [B, H, W, 2]
    """
    b, _, h, w = cam_coords.size()
    # reshape (rather than view) also accepts sliced / non-contiguous input.
    pcoords = cam_coords.reshape(b, 3, -1)  # [B, 3, H*W]
    if proj_c2p_rot is not None:
        pcoords = proj_c2p_rot.bmm(pcoords)
    if proj_c2p_tr is not None:
        pcoords = pcoords + proj_c2p_tr  # [B, 3, H*W]

    X = pcoords[:, 0]
    Y = pcoords[:, 1]
    Z = pcoords[:, 2].clamp(min=1e-3)  # avoid dividing by ~0 or negative depth

    # Normalized: -1 on the extreme left/top, 1 on the extreme right/bottom
    # (x = w-1, y = h-1). max(..., 1) keeps single-pixel axes finite.
    X_norm = 2 * (X / Z) / max(w - 1, 1) - 1  # [B, H*W]
    Y_norm = 2 * (Y / Z) / max(h - 1, 1) - 1  # [B, H*W]

    if padding_mode == 'zeros':
        # Push out-of-range coordinates further out (to 2) so that no point
        # in the warped image is a combination of image and padding.
        X_norm = X_norm.masked_fill((X_norm > 1) | (X_norm < -1), 2)
        Y_norm = Y_norm.masked_fill((Y_norm > 1) | (Y_norm < -1), 2)

    grid = torch.stack([X_norm, Y_norm], dim=2)  # [B, H*W, 2]
    return grid.view(b, h, w, 2)
def inverse_warp(feat, depth, pose, intrinsics, intrinsics_inv, padding_mode='zeros'):
    """Inverse warp a source feature map to the target image plane.

    Args:
        feat: the source feature (where to sample pixels) -- [B, CH, H, W]
        depth: depth map of the target image -- [B, H, W]
        pose: transformation matrix from target to source -- [B, 3, 4]
        intrinsics: camera intrinsic matrix -- [B, 3, 3]
        intrinsics_inv: inverse of the intrinsic matrix -- [B, 3, 3]
        padding_mode: padding mode forwarded to ``grid_sample``; with
            'zeros', out-of-view samples are pushed fully outside the image.
    Returns:
        Source feature warped to the target image plane
    Raises:
        AssertionError: if an argument does not have the expected shape.
    """
    check_sizes(feat, 'feat', 'BCHW')
    check_sizes(depth, 'depth', 'BHW')
    check_sizes(pose, 'pose', 'B34')
    check_sizes(intrinsics, 'intrinsics', 'B33')
    check_sizes(intrinsics_inv, 'intrinsics_inv', 'B33')
    assert intrinsics_inv.size() == intrinsics.size()

    cam_coords = pixel2cam(depth, intrinsics_inv)  # [B, 3, H, W]

    # Follow the intrinsics' device instead of forcing CUDA, so the function
    # also runs on CPU-only machines.
    pose_mat = pose.to(intrinsics.device)

    # Get projection matrix for tgt camera frame to source pixel frame
    proj_cam_to_src_pixel = intrinsics.bmm(pose_mat)  # [B, 3, 4]

    src_pixel_coords = cam2pixel(
        cam_coords,
        proj_cam_to_src_pixel[:, :, :3],
        proj_cam_to_src_pixel[:, :, -1:],
        padding_mode,
    )  # [B, H, W, 2]

    # cam2pixel maps pixel 0 to -1 and pixel (size - 1) to 1, which is the
    # align_corners=True convention. It was the implicit default before
    # PyTorch 1.3 and must be requested explicitly since then.
    return torch.nn.functional.grid_sample(
        feat, src_pixel_coords, padding_mode=padding_mode, align_corners=True)