Skip to content

Commit 70acb3e

Browse files
bottlerfacebook-github-bot
authored andcommitted
new tests demonstrating pixel matching
Summary: Demonstrate current behavior of pixels with new tests of all renderers. Reviewed By: gkioxari Differential Revision: D32651141 fbshipit-source-id: 3ca30b4274ed2699bc5e1a9c6437eb3f0b738cbf
1 parent bf3bc6f commit 70acb3e

File tree

1 file changed

+256
-0
lines changed

1 file changed

+256
-0
lines changed

tests/test_camera_pixels.py

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import unittest
8+
9+
import torch
10+
from common_testing import TestCaseMixin
11+
from pytorch3d.renderer import (
12+
MeshRasterizer,
13+
NDCGridRaysampler,
14+
PerspectiveCameras,
15+
PointsRasterizationSettings,
16+
PointsRasterizer,
17+
PulsarPointsRenderer,
18+
RasterizationSettings,
19+
)
20+
from pytorch3d.structures import Meshes, Pointclouds
21+
22+
23+
"""
24+
PyTorch3D renderers operate in an align_corners=False manner.
25+
This file demonstrates the pixel-perfect calculation by very simple
26+
examples.
27+
"""
28+
29+
30+
class _CommonData:
31+
"""
32+
Contains data for all these tests.
33+
34+
- Firstly, a non-square at the origin specified in ndc space and
35+
screen space. Principal point is in the center of the image.
36+
Focal length is 1.0 in world space.
37+
This camera has the identity as its world to view transformation, so
38+
it is facing down the positive z axis with y being up and x being left.
39+
A point on the z=1.0 focal plane has its x,y world coordinate equal to
40+
its NDC.
41+
42+
- Secondly, batched together with that, is a camera with the same
43+
focal length facing in the same direction but located so that it faces
44+
the corner of the corner pixel of the first image, with its principal
45+
point located at its corner, so that it maps the z=1 plane to NDC just
46+
like the first.
47+
48+
- a single point self.point in world space which is located on a plane 1.0
49+
in front from the camera which is located exactly in the center
50+
of a known pixel (self.x, self.y), specifically with negative x and slightly
51+
positive y, so it is in the top right quadrant of the image.
52+
53+
- A second batch of cameras defined in screen space which exactly match the
54+
first ones.
55+
56+
So that this data can be copied for making demos, it is easiest to leave
57+
it as a freestanding class.
58+
"""
59+
60+
def __init__(self):
61+
self.H, self.W = 249, 125
62+
self.image_size = (self.H, self.W)
63+
self.camera_ndc = PerspectiveCameras(
64+
focal_length=1.0,
65+
image_size=(self.image_size,),
66+
in_ndc=True,
67+
T=torch.tensor([[0.0, 0.0, 0.0], [-1.0, self.H / self.W, 0.0]]),
68+
principal_point=((-0.0, -0.0), (1.0, -self.H / self.W)),
69+
)
70+
# Note how principal point is specifiied
71+
self.camera_screen = PerspectiveCameras(
72+
focal_length=self.W / 2.0,
73+
principal_point=((self.W / 2.0, self.H / 2.0), (0.0, self.H)),
74+
image_size=(self.image_size,),
75+
T=torch.tensor([[0.0, 0.0, 0.0], [-1.0, self.H / self.W, 0.0]]),
76+
in_ndc=False,
77+
)
78+
79+
# 81 is more than half of 125, 113 is a bit less than half of 249
80+
self.x, self.y = 81, 113
81+
self.point = [-0.304, 0.176, 1]
82+
# The point is in the center of pixel (81, 113)
83+
# where pixel (0,0) is the top left.
84+
# 81 is 38/2 pixels over the midpoint (125-1)/2=62
85+
# and 38/125=0.304
86+
# 113 is 22/2 pixels under the midpoint (249-1)/2=124
87+
# and 22/125=0.176
88+
89+
90+
class TestPixels(TestCaseMixin, unittest.TestCase):
91+
def test_mesh(self):
92+
data = _CommonData()
93+
# Three points on the plane at unit 1 from the camera in
94+
# world space, whose mean is the known point.
95+
verts = torch.tensor(
96+
[[-0.288, 0.192, 1], [-0.32, 0.192, 1], [-0.304, 0.144, 1]]
97+
)
98+
self.assertClose(verts.mean(0), torch.tensor(data.point))
99+
faces = torch.LongTensor([[0, 1, 2]])
100+
# A mesh of one triangular face whose centroid is the known point
101+
# duplicated so it can be rendered from two cameras.
102+
meshes = Meshes(verts=[verts], faces=[faces]).extend(2)
103+
faces_per_pixel = 2
104+
for camera in (data.camera_ndc, data.camera_screen):
105+
rasterizer = MeshRasterizer(
106+
cameras=camera,
107+
raster_settings=RasterizationSettings(
108+
image_size=data.image_size, faces_per_pixel=faces_per_pixel
109+
),
110+
)
111+
barycentric_coords_found = rasterizer(meshes).bary_coords
112+
self.assertTupleEqual(
113+
barycentric_coords_found.shape,
114+
(2,) + data.image_size + (faces_per_pixel, 3),
115+
)
116+
# We see that the barycentric coordinates at the expected
117+
# pixel are (1/3, 1/3, 1/3), indicating that this pixel
118+
# hits the centroid of the triangle.
119+
self.assertClose(
120+
barycentric_coords_found[:, data.y, data.x, 0],
121+
torch.full((2, 3), 1 / 3.0),
122+
atol=1e-5,
123+
)
124+
125+
def test_pointcloud(self):
126+
data = _CommonData()
127+
clouds = Pointclouds(points=torch.tensor([[data.point]])).extend(2)
128+
colorful_cloud = Pointclouds(
129+
points=torch.tensor([[data.point]]), features=torch.ones(1, 1, 3)
130+
).extend(2)
131+
points_per_pixel = 2
132+
# for camera in [data.camera_screen]:
133+
for camera in (data.camera_ndc, data.camera_screen):
134+
rasterizer = PointsRasterizer(
135+
cameras=camera,
136+
raster_settings=PointsRasterizationSettings(
137+
image_size=data.image_size,
138+
radius=0.0001,
139+
points_per_pixel=points_per_pixel,
140+
),
141+
)
142+
# when rasterizing we expect only one pixel to be occupied
143+
rasterizer_output = rasterizer(clouds).idx
144+
self.assertTupleEqual(
145+
rasterizer_output.shape, (2,) + data.image_size + (points_per_pixel,)
146+
)
147+
found = torch.nonzero(rasterizer_output != -1)
148+
self.assertTupleEqual(found.shape, (2, 4))
149+
self.assertListEqual(found[0].tolist(), [0, data.y, data.x, 0])
150+
self.assertListEqual(found[1].tolist(), [1, data.y, data.x, 0])
151+
152+
if camera.in_ndc():
153+
# Pulsar not currently working in screen space.
154+
pulsar_renderer = PulsarPointsRenderer(rasterizer=rasterizer)
155+
pulsar_output = pulsar_renderer(
156+
colorful_cloud, gamma=(0.1, 0.1), znear=(0.1, 0.1), zfar=(70, 70)
157+
)
158+
self.assertTupleEqual(
159+
pulsar_output.shape, (2,) + data.image_size + (3,)
160+
)
161+
# Look for points rendered in the red channel only, expecting our one.
162+
# Check the first batch element only.
163+
# TODO: Something is odd with the second.
164+
found = torch.nonzero(pulsar_output[0, :, :, 0])
165+
self.assertTupleEqual(found.shape, (1, 2))
166+
self.assertListEqual(found[0].tolist(), [data.y, data.x])
167+
# Should be:
168+
# found = torch.nonzero(pulsar_output[:, :, :, 0])
169+
# self.assertTupleEqual(found.shape, (2, 3))
170+
# self.assertListEqual(found[0].tolist(), [0, data.y, data.x])
171+
# self.assertListEqual(found[1].tolist(), [1, data.y, data.x])
172+
173+
def test_raysampler(self):
174+
data = _CommonData()
175+
gridsampler = NDCGridRaysampler(
176+
image_width=data.W,
177+
image_height=data.H,
178+
n_pts_per_ray=2,
179+
min_depth=1.0,
180+
max_depth=2.0,
181+
)
182+
for camera in (data.camera_ndc, data.camera_screen):
183+
bundle = gridsampler(camera)
184+
self.assertTupleEqual(bundle.xys.shape, (2,) + data.image_size + (2,))
185+
self.assertTupleEqual(
186+
bundle.directions.shape, (2,) + data.image_size + (3,)
187+
)
188+
self.assertClose(
189+
bundle.xys[:, data.y, data.x],
190+
torch.tensor(data.point[:2]).expand(2, -1),
191+
)
192+
# We check only the first batch element.
193+
# Second element varies because of camera location.
194+
self.assertClose(
195+
bundle.directions[0, data.y, data.x],
196+
torch.tensor(data.point),
197+
)
198+
199+
def test_camera(self):
200+
data = _CommonData()
201+
# Our point, plus the image center, and a corner of the image.
202+
# Located at the focal-length distance away
203+
points = torch.tensor([data.point, [0, 0, 1], [1, data.H / data.W, 1]])
204+
for cameras in (data.camera_ndc, data.camera_screen):
205+
ndc_points = cameras.transform_points_ndc(points)
206+
screen_points = cameras.transform_points_screen(points)
207+
camera_points = cameras.transform_points(points)
208+
for batch_idx in range(2):
209+
# NDC space agrees with the original
210+
self.assertClose(ndc_points[batch_idx], points, atol=1e-5)
211+
# First point in screen space is the center of our expected pixel
212+
self.assertClose(
213+
screen_points[batch_idx][0],
214+
torch.tensor([data.x + 0.5, data.y + 0.5, 1.0]),
215+
atol=1e-5,
216+
)
217+
# Second point in screen space is the center of the screen
218+
self.assertClose(
219+
screen_points[batch_idx][1],
220+
torch.tensor([data.W / 2.0, data.H / 2.0, 1.0]),
221+
atol=1e-5,
222+
)
223+
# Third point in screen space is the corner of the screen
224+
# (corner of corner pixels)
225+
self.assertClose(
226+
screen_points[batch_idx][2],
227+
torch.tensor([0.0, 0.0, 1.0]),
228+
atol=1e-5,
229+
)
230+
231+
if cameras.in_ndc():
232+
self.assertClose(camera_points[batch_idx], ndc_points[batch_idx])
233+
else:
234+
# transform_points does something strange for screen cameras
235+
if batch_idx == 0:
236+
wanted = torch.stack(
237+
[
238+
data.W - screen_points[batch_idx, :, 0],
239+
data.H - screen_points[batch_idx, :, 1],
240+
torch.ones(3),
241+
],
242+
dim=1,
243+
)
244+
else:
245+
wanted = torch.stack(
246+
[
247+
-screen_points[batch_idx, :, 0],
248+
2 * data.H - screen_points[batch_idx, :, 1],
249+
torch.ones(3),
250+
],
251+
dim=1,
252+
)
253+
254+
print(wanted)
255+
print(camera_points[batch_idx])
256+
self.assertClose(camera_points[batch_idx], wanted)

0 commit comments

Comments
 (0)