HiDream auxiliary loss for MoE experts not tied to computation graph #11301

Open
@bghira

Describe the bug

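In the HiDream MOEFeedForward, the auxiliary loss is commented out, as it is in the upstream prototype code. As a result, the auxiliary loss for the MoE experts is never tied to the computation graph and contributes no gradient during training.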

Additionally, MoEGate has a memory leak, which surfaces as an out-of-memory error on the backward pass.
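
To illustrate what "tied to the computation graph" means for the first point, here is a minimal sketch of a router auxiliary loss that is added to the training loss so that the gate receives a gradient. It is not the HiDream or diffusers implementation: the function `load_balancing_loss`, the `gate` layer, the `alpha` weight, and the tensor shapes are all hypothetical, and the formula is the commonly used balance term `alpha * sum_i(f_i * P_i)`, which is assumed here rather than taken from the commented-out code.

```python
import torch
import torch.nn.functional as F


def load_balancing_loss(router_logits, top_k, alpha=0.01):
    """Illustrative balance loss: alpha * sum_i(f_i * P_i). Not HiDream's code."""
    num_experts = router_logits.shape[-1]
    scores = router_logits.softmax(dim=-1)           # (tokens, experts)
    topk_idx = scores.topk(top_k, dim=-1).indices    # (tokens, top_k)
    # f_i: share of routing slots sent to expert i, scaled by the expert count.
    # This term comes from hard indices, so it carries no gradient.
    slot_share = F.one_hot(topk_idx.reshape(-1), num_classes=num_experts).float().mean(0)
    f = slot_share * num_experts
    # P_i: mean router probability for expert i. This term is differentiable.
    p = scores.mean(0)
    return alpha * (f * p).sum()


torch.manual_seed(0)
gate = torch.nn.Linear(16, 4, bias=False)  # hypothetical router with 4 experts
tokens = torch.randn(8, 16)                # 8 tokens, hidden size 16

aux_loss = load_balancing_loss(gate(tokens), top_k=2)
task_loss = torch.tensor(1.0)              # stand-in for the real training loss

# Adding aux_loss to the total is what attaches it to the graph.
total_loss = task_loss + aux_loss
total_loss.backward()

print(aux_loss.requires_grad)          # the attached loss tracks gradients
print(gate.weight.grad is not None)    # the router received a gradient
print(aux_loss.detach().requires_grad) # a detached copy would not
```

If the auxiliary term is computed but never added to the loss that `backward()` is called on (or is detached first), the router gets no balancing signal, which matches the symptom described in the title.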

Reproduction

  • Train HiDream
  • Observe an out-of-memory (OOM) error on the backward pass
    • Work around the MoEGate OOM by implementing gradient checkpointing
  • Observe extraordinarily high loss values
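
The gradient checkpointing workaround from the steps above can be sketched as follows. This is a generic example using `torch.utils.checkpoint.checkpoint`, not the change actually made for HiDream; `TinyExpertBlock` and its dimensions are hypothetical stand-ins for an expert feed-forward block.

```python
import torch
from torch.utils.checkpoint import checkpoint


class TinyExpertBlock(torch.nn.Module):
    """Hypothetical stand-in for an expert feed-forward block."""

    def __init__(self, dim=16):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Linear(dim, 4 * dim),
            torch.nn.SiLU(),
            torch.nn.Linear(4 * dim, dim),
        )

    def forward(self, x):
        return self.net(x)


torch.manual_seed(0)
block = TinyExpertBlock()
x = torch.randn(8, 16, requires_grad=True)

# Intermediate activations inside `block` are not stored during the forward
# pass; they are recomputed during backward, trading compute for memory.
out = checkpoint(block, x, use_reentrant=False)
out.sum().backward()

print(x.grad.shape)
```

Checkpointing lowers peak activation memory but does not fix a leak by itself, so it is best read as a workaround for the OOM rather than a resolution of the underlying MoEGate problem.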

Logs

2025-04-12 10:03:59,851 [INFO] cls: <class 'helpers.training.optimizers.adamw_bfloat16.AdamWBF16'>, settings: {'betas': (0.9, 0.999), 'weight_decay': 0.01, 'eps': 1e-06}
2025-04-12 10:03:59,855 [INFO] Optimizer arguments={'lr': 4e-05, 'betas': (0.9, 0.999), 'weight_decay': 0.01, 'eps': 1e-06}
2025-04-12 10:03:59,855 [INFO] Loading constant learning rate scheduler with 100 warmup steps
2025-04-12 10:03:59,855 [INFO] Using generic 'constant' learning rate scheduler.
2025-04-12 10:03:59,857 [INFO] Preparing models..
2025-04-12 10:03:59,858 [INFO] Loading our accelerator...
2025-04-12 10:03:59,875 [INFO] Resuming from checkpoint checkpoint-8000
2025-04-12 10:04:00,033 [INFO] Previous checkpoint had 0 exhausted buckets.
2025-04-12 10:04:00,034 [INFO] Previous checkpoint was on epoch 471.
2025-04-12 10:04:00,034 [INFO] Previous checkpoint had 10 seen images.
2025-04-12 10:04:00,034 [INFO] Resuming from global_step 8000.
2025-04-12 10:04:00,034 [INFO] 
(Rank: 0)     -> Number of seen images: 10
(Rank: 0)     -> Number of unseen images: 7
(Rank: 0)     -> Current Bucket: None
(Rank: 0)     -> 1 Buckets: ['1.0']
(Rank: 0)     -> 0 Exhausted Buckets: []
2025-04-12 10:04:00,093 [INFO] 
***** Running training *****
-  Num batches = 17
-  Num Epochs = 589
  - Current Epoch = 471
-  Total train batch size (w. parallel, distributed & accumulation) = 1
  - Instantaneous batch size per device = 1
  - Gradient Accumulation steps = 1
-  Total optimization steps = 10000
  - Steps completed: 8000
-  Total optimization steps remaining = 2000
Epoch 478/589, Steps:  81%|████████████▏  | 8114/10000 [03:30<57:25,  1.83s/it, grad_absmax=0.00149, lr=4e-5, step_loss=1.13]

System Info

Diffusers git main

Who can help?

No response


Labels: bug
