Skip to content

Commit 54eb76d

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Loosening the checks in eval script for CO3Dv2 style eval
Summary: V2 dataset does not have the concept of known/unseen frames. Test-time conditining is done with train-set frames, which violates the previous check. Also fixing a corner case in VideoWriter. Reviewed By: bottler Differential Revision: D42706976 fbshipit-source-id: d43be3dd3060d18cb9f46d5dcf6252d9f084110f
1 parent 9dc28f5 commit 54eb76d

File tree

2 files changed

+8
-12
lines changed

2 files changed

+8
-12
lines changed

pytorch3d/implicitron/evaluation/evaluate_new_view_synthesis.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -219,17 +219,10 @@ def eval_batch(
219219
frame_type = [frame_type]
220220

221221
is_train = is_train_frame(frame_type)
222-
if not (is_train[0] == is_train).all():
223-
raise ValueError("All frames in the eval batch have to be either train/test.")
224-
225-
# pyre-fixme[16]: `Optional` has no attribute `device`.
226-
is_known = is_known_frame(frame_type, device=frame_data.image_rgb.device)
227-
228-
if not ((is_known[1:] == 1).all() and (is_known[0] == 0).all()):
222+
if len(is_train) > 1 and (is_train[1] != is_train[1:]).any():
229223
raise ValueError(
230-
"For evaluation the first element of the batch has to be"
231-
+ " a target view while the rest should be source views."
232-
) # TODO: do we need to enforce this?
224+
"All (conditioning) frames in the eval batch have to be either train/test."
225+
)
233226

234227
for k in [
235228
"depth_map",
@@ -362,7 +355,7 @@ def eval_batch(
362355

363356
results["meta"] = {
364357
# store the size of the batch (corresponds to n_src_views+1)
365-
"batch_size": int(is_known.numel()),
358+
"batch_size": len(frame_type),
366359
# store the type of the target frame
367360
# pyre-fixme[16]: `None` has no attribute `__getitem__`.
368361
"frame_type": str(frame_data.frame_type[0]),

pytorch3d/implicitron/tools/video_writer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,11 @@ def get_video(self, quiet: bool = True) -> str:
124124
quiet: If `True`, suppresses logging messages.
125125
126126
Returns:
127-
video_path: The path to the generated video.
127+
video_path: The path to the generated video if any frames were added.
128+
Otherwise returns an empty string.
128129
"""
130+
if self.frame_num == 0:
131+
return ""
129132

130133
regexp = os.path.join(self.cache_dir, self.regexp)
131134

0 commit comments

Comments
 (0)