Skip to content

Commit 4e44534

Browse files
Update rerender_a_video.py fix dtype error (#10451)
Update rerender_a_video.py
1 parent a17832b commit 4e44534

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

examples/community/rerender_a_video.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -782,7 +782,7 @@ def __call__(
782782
self.attn_state.reset()
783783

784784
# 4.1 prepare frames
785-
image = self.image_processor.preprocess(frames[0]).to(dtype=torch.float32)
785+
image = self.image_processor.preprocess(frames[0]).to(dtype=self.dtype)
786786
first_image = image[0] # C, H, W
787787

788788
# 4.2 Prepare controlnet_conditioning_image
@@ -926,8 +926,8 @@ def __call__(
926926
prev_image = frames[idx - 1]
927927
control_image = control_frames[idx]
928928
# 5.1 prepare frames
929-
image = self.image_processor.preprocess(image).to(dtype=torch.float32)
930-
prev_image = self.image_processor.preprocess(prev_image).to(dtype=torch.float32)
929+
image = self.image_processor.preprocess(image).to(dtype=self.dtype)
930+
prev_image = self.image_processor.preprocess(prev_image).to(dtype=self.dtype)
931931

932932
warped_0, bwd_occ_0, bwd_flow_0 = get_warped_and_mask(
933933
self.flow_model, first_image, image[0], first_result, False, self.device

0 commit comments

Comments
 (0)