Description
Describe the bug
The MOEFeedForward for HiDream has the auxiliary loss commented out in the upstream prototype code.
Additionally, MoEGate appears to leak memory.
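For context on the auxiliary loss mentioned above, here is a minimal sketch of what a load-balancing loss for a top-k router typically computes. It uses the Switch Transformer-style formulation in plain Python and is not the HiDream or Diffusers code. The function name, the `top_k` and `alpha` defaults, and the input layout are assumptions for illustration only.

```python
def load_balancing_loss(router_probs, top_k=1, alpha=0.01):
    """Illustrative auxiliary loss: alpha * N * sum_i(f_i * P_i).

    router_probs: list of per-token lists of softmax probabilities,
                  one entry per expert (hypothetical layout).
    f_i: fraction of routing slots assigned to expert i.
    P_i: mean router probability for expert i over all tokens.
    """
    num_tokens = len(router_probs)
    num_experts = len(router_probs[0])

    # Count how often each expert lands in a token's top-k selection.
    counts = [0] * num_experts
    for probs in router_probs:
        ranked = sorted(range(num_experts), key=lambda e: probs[e], reverse=True)
        for expert in ranked[:top_k]:
            counts[expert] += 1

    f = [c / (num_tokens * top_k) for c in counts]
    p = [
        sum(probs[e] for probs in router_probs) / num_tokens
        for e in range(num_experts)
    ]
    return alpha * num_experts * sum(fi * pi for fi, pi in zip(f, p))


# Two tokens, two experts: one balanced routing, one collapsed onto expert 0.
balanced = load_balancing_loss([[0.9, 0.1], [0.1, 0.9]])
collapsed = load_balancing_loss([[0.9, 0.1], [0.9, 0.1]])
```

With perfectly balanced routing this term equals `alpha`, and it grows as tokens collapse onto fewer experts. In training it would be added to the main loss; when it is commented out, nothing penalizes that collapse.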
Reproduction
- Train HiDream
- Observe an out-of-memory (OOM) error on the backward pass
- Resolve MoEGate OOM by implementing gradient checkpointing
- Observe extraordinarily high loss values
Logs
2025-04-12 10:03:59,851 [INFO] cls: <class 'helpers.training.optimizers.adamw_bfloat16.AdamWBF16'>, settings: {'betas': (0.9, 0.999), 'weight_decay': 0.01, 'eps': 1e-06}
2025-04-12 10:03:59,855 [INFO] Optimizer arguments={'lr': 4e-05, 'betas': (0.9, 0.999), 'weight_decay': 0.01, 'eps': 1e-06}
2025-04-12 10:03:59,855 [INFO] Loading constant learning rate scheduler with 100 warmup steps
2025-04-12 10:03:59,855 [INFO] Using generic 'constant' learning rate scheduler.
2025-04-12 10:03:59,857 [INFO] Preparing models..
2025-04-12 10:03:59,858 [INFO] Loading our accelerator...
2025-04-12 10:03:59,875 [INFO] Resuming from checkpoint checkpoint-8000
2025-04-12 10:04:00,033 [INFO] Previous checkpoint had 0 exhausted buckets.
2025-04-12 10:04:00,034 [INFO] Previous checkpoint was on epoch 471.
2025-04-12 10:04:00,034 [INFO] Previous checkpoint had 10 seen images.
2025-04-12 10:04:00,034 [INFO] Resuming from global_step 8000.
2025-04-12 10:04:00,034 [INFO]
(Rank: 0) -> Number of seen images: 10
(Rank: 0) -> Number of unseen images: 7
(Rank: 0) -> Current Bucket: None
(Rank: 0) -> 1 Buckets: ['1.0']
(Rank: 0) -> 0 Exhausted Buckets: []
2025-04-12 10:04:00,093 [INFO]
***** Running training *****
- Num batches = 17
- Num Epochs = 589
- Current Epoch = 471
- Total train batch size (w. parallel, distributed & accumulation) = 1
- Instantaneous batch size per device = 1
- Gradient Accumulation steps = 1
- Total optimization steps = 10000
- Steps completed: 8000
- Total optimization steps remaining = 2000
Epoch 478/589, Steps: 81%|████████████▏ | 8114/10000 [03:30<57:25, 1.83s/it, grad_absmax=0.00149, lr=4e-5, step_loss=1.13]
System Info
Diffusers git main
Who can help?
No response