Skip to content

Commit 37bd280

Browse files
Darijan Gudeljfacebook-github-bot
Darijan Gudelj
authored andcommitted
load whole dataset in train loop
Summary: Loads the whole dataset and moves it to the device and sends it to for sampling to enable full dataset heterogeneous raysampling. Reviewed By: bottler Differential Revision: D39263009 fbshipit-source-id: c527537dfc5f50116849656c9e171e868f6845b1
1 parent c311a4c commit 37bd280

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

projects/implicitron_trainer/experiment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -222,6 +222,7 @@ def run(self) -> None:
222222
train_loader=train_loader,
223223
val_loader=val_loader,
224224
test_loader=test_loader,
225+
# pyre-ignore[6]
225226
train_dataset=datasets.train,
226227
model=model,
227228
optimizer=optimizer,

projects/implicitron_trainer/impl/training_loop.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
)
2323
from pytorch3d.implicitron.tools.stats import Stats
2424
from pytorch3d.renderer.cameras import CamerasBase
25-
from torch.utils.data import DataLoader
25+
from torch.utils.data import DataLoader, Dataset
2626

2727
from .utils import seed_all_random_engines
2828

@@ -44,6 +44,7 @@ def run(
4444
train_loader: DataLoader,
4545
val_loader: Optional[DataLoader],
4646
test_loader: Optional[DataLoader],
47+
train_dataset: Dataset,
4748
model: ImplicitronModelBase,
4849
optimizer: torch.optim.Optimizer,
4950
scheduler: Any,
@@ -116,6 +117,7 @@ def run(
116117
train_loader: DataLoader,
117118
val_loader: Optional[DataLoader],
118119
test_loader: Optional[DataLoader],
120+
train_dataset: Dataset,
119121
model: ImplicitronModelBase,
120122
optimizer: torch.optim.Optimizer,
121123
scheduler: Any,

pytorch3d/implicitron/models/generic_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,8 @@ def safe_slice_targets(
389389
)
390390

391391
# (1) Sample rendering rays with the ray sampler.
392-
ray_bundle: ImplicitronRayBundle = self.raysampler( # pyre-fixme[29]
392+
# pyre-ignore[29]
393+
ray_bundle: ImplicitronRayBundle = self.raysampler(
393394
target_cameras,
394395
evaluation_mode,
395396
mask=mask_crop[:n_targets]

0 commit comments

Comments
 (0)