sync the max_sequence_length parameter of sd3 pipeline with official … #9921
base: main
Conversation
you can't. the model was trained with 256 tokens at the end, and limiting it to 77 prevents the model from using the end positions as registers for storing data that would otherwise leak into the prediction.
@bghira I found that the images generated by the original repo and comfyui look better than those from diffusers. Therefore, the sequence length of T5 may be important. I think it's better to be consistent with other open-source implementations.
in simpletuner we hardcode it to 256 because using 77 actually breaks the model. it is meant to be used and trained on at 256 tokens. @Dango233 can confirm
maybe you can provide sample images?
cc @asomoza here too
so in my mind i thought the CLIP and T5 embeds were treated separately, but i was mixing it up with Flux.

```python
t5_prompt_embed = _encode_sd3_prompt_with_t5(
    text_encoders[-1],
    tokenizers[-1],
    prompt=prompt,
    num_images_per_prompt=num_images_per_prompt,
    device=self.accelerator.device,
    zero_padding_tokens=zero_padding_tokens,
    max_sequence_length=StateTracker.get_args().tokenizer_max_length,
)
clip_prompt_embeds = torch.nn.functional.pad(
    clip_prompt_embeds,
    (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]),
)
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
return prompt_embeds, pooled_prompt_embeds
```

for SD 3.x the CLIP embed is padded to match the dimension of the T5 input (2048 to 4096). doing inference on 77 tokens looks marginally better, but training on it still totally breaks coherence since the released weights were finetuned on 256 tokens. it's a hard decision to make for this model because clearly no ideal text input configuration works for it. if ComfyUI really isn't doing any long input handling for this model then perhaps it makes sense for Diffusers not to either.
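For reference, a minimal shape walk-through of what that padding does, with dimensions assumed from the SD3 defaults (concatenated CLIP-L + CLIP-G at 2048 channels, T5-XXL at 4096, 256 T5 tokens):

```python
import torch
import torch.nn.functional as F

clip_prompt_embeds = torch.randn(1, 77, 2048)  # concatenated CLIP-L + CLIP-G token embeds
t5_prompt_embed = torch.randn(1, 256, 4096)    # T5-XXL token embeds

# A 2-tuple pad argument only pads the last (channel) dimension: 2048 -> 4096.
# The sequence dimension is untouched.
clip_prompt_embeds = F.pad(
    clip_prompt_embeds,
    (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]),
)
print(clip_prompt_embeds.shape)  # torch.Size([1, 77, 4096])

# The concat then happens along the sequence dimension (dim=-2).
prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
print(prompt_embeds.shape)  # torch.Size([1, 333, 4096])
```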
The initial diffusers version of SD3 had the 77-token limit, matching the original implementation; we changed it because a lot of people were requesting the removal of that limit, and also because the SD3 technical paper shows an example prompt longer than 77 tokens.

This is a circular issue: if we match the limits we're going to get more issues about removing them. The only difference is that this one is the only such issue in several months, while the 'allow more tokens' issues were a lot more numerous and with less time in between. We still get some issues about enabling more tokens for the CLIP models.

We also need some clear images to show that it's really worse, and this is probably something the SD3 authors should address instead: if they didn't train the model with more than 77 tokens, or if they don't allow more tokens for the T5, why did they write an example with more tokens? And if you look at the image, all the elements are present, so the prompt wasn't clipped.

Finally, ComfyUI probably doesn't limit the tokens (I will look into this) because people like to write long prompts, and a lot of them are using LLMs for this (with really long descriptive prompts).

Thanks for the PR though, I will do some experiments and post the results.
after 16k steps of training on a low LR it's recovered at 77 tokens. in comparison to training on 256 tokens, which keeps degrading over time, training on 77 tokens actually works much better at a lower learning rate. the problem then becomes how many steps and how many hours it takes to learn anything.

thanks for starting the discussion, and apologies for repeating stuff without first testing it more thoroughly. i think it's a documentation issue from StabilityAI really, since there is no technical report on the models' training or any issues with sequence length.

maybe the SD3 training script --help output can be updated to provide the option to switch between 77 and 256, with an explanation of the trade-offs in degradation or step count (a sketch follows below). i haven't checked whether the caption length is easily biased (e.g. losing long prompts by training on short only), but here is the longer test result:
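A hypothetical sketch of what such a switch could look like in a training script's argument parser (the flag name, default, and choices are assumptions, not a claim about the current diffusers scripts):

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag; the trade-offs live in the help text.
parser.add_argument(
    "--max_sequence_length",
    type=int,
    default=256,
    choices=[77, 256],
    help=(
        "T5 token limit for caption encoding. 256 matches the released "
        "weights' final finetune; 77 appears to degrade less when further "
        "trained at a low learning rate, but may bias against long captions."
    ),
)
args = parser.parse_args()
```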
Let's add more about this in the docs, no? What do you think?
Sorry for chiming in late. Our intention with the multi-sequence-length T5 training stage is to enhance the model's long-prompt-following ability and avoid artifacts when a long prompt is provided. As for the clip padding - we DID NOT pad the clip on the sequence length dimension during the training phase. This means you can have clip(77) + t5(154) (clip truncated). The only compromise with this is that it creates an unbalanced embedding that weights the T5 embedding more heavily when the prompt is over 77 tokens.
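In other words, under the training layout described above the joint text context would look like this (a sketch with assumed shapes, channel width 4096 after any feature-dim padding):

```python
import torch

clip_embeds = torch.randn(1, 77, 4096)  # CLIP truncated at 77 tokens, never padded along the seq dim
t5_embeds = torch.randn(1, 154, 4096)   # T5 at the 154-token stage

context = torch.cat([clip_embeds, t5_embeds], dim=-2)
print(context.shape)  # torch.Size([1, 231, 4096]) -- 77 + 154 = 231, not 256
```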
so the clip padding shouldn't be happening at all then?
for what it's worth, the encoder attn mask should be passed through to the transformer model's SDPA call so that it can avoid attending to the padding positions, which is the main reason for the artifacts when long prompts are used on a 77-token-trained model.
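A self-contained sketch of that suggestion, assuming a boolean key-padding mask derived from the tokenizer output (none of this is in the current SD3 pipeline; shapes are illustrative):

```python
import torch
import torch.nn.functional as F

batch, heads, img_len, txt_len, head_dim = 1, 8, 16, 256, 64
query = torch.randn(batch, heads, img_len, head_dim)
key = torch.randn(batch, heads, txt_len, head_dim)
value = torch.randn(batch, heads, txt_len, head_dim)

# Pretend only the first 40 text tokens are real; the rest is padding.
key_padding_mask = torch.zeros(batch, txt_len, dtype=torch.bool)
key_padding_mask[:, :40] = True

# SDPA's boolean attn_mask: True = may attend, False = masked out.
# Broadcast to (batch, 1, 1, txt_len) so it applies to every head and query.
attn_mask = key_padding_mask[:, None, None, :]
out = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```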
@yiyixuxu actually i missed it, it's subtle, but the big takeaway is that the 256 covers the seq len of 77 clip tokens and 154 t5 tokens, while we have the assumption that 256 tokens is just t5. so i think all of the references to 256 should probably be updated 🥲 it doesn't add up (literally, 77 + 154 != 256) but this is from the lead, i suppose.
ohh thanks for explaining it, @Dango233! So I think our implementation is consistent with what you described here, i.e. the clip padding is not happening on the sequence length dimension: diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py, line 437 (at cd89204)
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread. Please note that issues that do not follow the contributing guidelines are likely to be ignored.
not stale, the seq len is still incorrect
Change the `max_sequence_length` default parameter of the SD3 pipeline from 256 to 77, matching the official SD3.5 code.
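Whatever the default ends up being, the limit can also be set per call. A sketch using the standard SD3 pipeline (model id assumed):

```python
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    "a photo of a cat holding a sign that says hello world",
    max_sequence_length=77,  # T5 token limit; this PR changes the default from 256 to 77
).images[0]
```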