Skip to content

Commit e9f4e0d

Browse files
bottlerfacebook-github-bot
authored andcommitted
PLY color scaling
Summary: When a PLY file contains colors in byte format, these are now scaled from 0..255 to [0,1], as they should be Reviewed By: gkioxari Differential Revision: D27765254 fbshipit-source-id: 526b5f5149d5e8cbffd7412b411be52c935fa4ad
1 parent 6c3fe95 commit e9f4e0d

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

pytorch3d/io/ply_io.py

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from collections import namedtuple
1515
from io import BytesIO, TextIOBase
1616
from pathlib import Path
17-
from typing import List, Optional, Tuple, Union
17+
from typing import List, Optional, Tuple, Union, cast
1818

1919
import numpy as np
2020
import torch
@@ -780,10 +780,24 @@ def _load_ply_raw(f, path_manager: PathManager) -> Tuple[_PlyHeader, dict]:
780780

781781
def _get_verts_column_indices(
782782
vertex_head: _PlyElementType,
783-
) -> Tuple[List[int], Optional[List[int]]]:
783+
) -> Tuple[List[int], Optional[List[int]], float]:
784784
"""
785785
Get the columns of verts and verts_colors in the vertex
786-
element of a parsed ply file.
786+
element of a parsed ply file, together with a color scale factor.
787+
When the colors are in byte format, they are scaled from 0..255 to [0,1].
788+
Otherwise they are not scaled.
789+
790+
For example, if the vertex element looks as follows:
791+
792+
element vertex 892
793+
property double x
794+
property double y
795+
property double z
796+
property uchar red
797+
property uchar green
798+
property uchar blue
799+
800+
then the return value will be ([0,1,2], [6,7,8], 1.0/255)
787801
788802
Args:
789803
vertex_head: as returned from load_ply_raw.
@@ -792,6 +806,7 @@ def _get_verts_column_indices(
792806
point_idxs: List[int] of 3 point columns.
793807
color_idxs: List[int] of 3 color columns if they are present,
794808
otherwise None.
809+
color_scale: value to scale colors by.
795810
"""
796811
point_idxs: List[Optional[int]] = [None, None, None]
797812
color_idxs: List[Optional[int]] = [None, None, None]
@@ -806,9 +821,17 @@ def _get_verts_column_indices(
806821
color_idxs[j] = i
807822
if None in point_idxs:
808823
raise ValueError("Invalid vertices in file.")
809-
if None in color_idxs:
810-
return point_idxs, None
811-
return point_idxs, color_idxs
824+
color_scale = 1.0
825+
if all(
826+
idx is not None and _PLY_TYPES[vertex_head.properties[idx].data_type].size == 1
827+
for idx in color_idxs
828+
):
829+
color_scale = 1.0 / 255
830+
return (
831+
point_idxs,
832+
None if None in color_idxs else cast(List[int], color_idxs),
833+
color_scale,
834+
)
812835

813836

814837
def _get_verts(
@@ -831,7 +854,7 @@ def _get_verts(
831854
if not isinstance(vertex, list):
832855
raise ValueError("Invalid vertices in file.")
833856
vertex_head = next(head for head in header.elements if head.name == "vertex")
834-
point_idxs, color_idxs = _get_verts_column_indices(vertex_head)
857+
point_idxs, color_idxs, color_scale = _get_verts_column_indices(vertex_head)
835858

836859
# Case of no vertices
837860
if vertex_head.count == 0:
@@ -856,7 +879,9 @@ def _get_verts(
856879
# so it was read as a single array and we can index straight into it.
857880
verts = torch.tensor(vertex[0][:, point_idxs], dtype=torch.float32)
858881
if color_idxs is not None:
859-
vertex_colors = torch.tensor(vertex[0][:, color_idxs], dtype=torch.float32)
882+
vertex_colors = color_scale * torch.tensor(
883+
vertex[0][:, color_idxs], dtype=torch.float32
884+
)
860885
else:
861886
# The vertex element is heterogeneous. It was read as several arrays,
862887
# part by part, where a part is a set of properties with the same type.
@@ -887,6 +912,7 @@ def _get_verts(
887912
for color in range(3):
888913
partnum, col = prop_to_partnum_col[color_idxs[color]]
889914
vertex_colors.numpy()[:, color] = vertex[partnum][:, col]
915+
vertex_colors *= color_scale
890916

891917
return verts, vertex_colors
892918

tests/test_io_ply.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def test_load_cloudcompare_pointcloud(self):
448448
torch.FloatTensor([0, 1, 2]) + 7 * torch.arange(8)[:, None],
449449
)
450450
self.assertClose(
451-
pointcloud.features_padded()[0],
451+
pointcloud.features_padded()[0] * 255,
452452
torch.FloatTensor([3, 4, 5]) + 7 * torch.arange(8)[:, None],
453453
)
454454

@@ -518,7 +518,7 @@ def test_load_pointcloud_bad_order(self):
518518
self.assertEqual(pointcloud_gpu.device, torch.device("cuda:0"))
519519
pointcloud = pointcloud_gpu.to(torch.device("cpu"))
520520
expected_points = torch.tensor([[[2, 5, 3]]], dtype=torch.float32)
521-
expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32)
521+
expected_features = torch.tensor([[[4, 1, 6]]], dtype=torch.float32) / 255.0
522522
self.assertClose(pointcloud.points_padded(), expected_points)
523523
self.assertClose(pointcloud.features_padded(), expected_features)
524524

0 commit comments

Comments
 (0)