
Commit 0300563

Simplify the UTransformer2DModel / UniDiffuserModel implementation and fix some more bugs.
1 parent a492e0c commit 0300563

File tree

4 files changed, +104 -162 lines changed

src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py

Lines changed: 18 additions & 5 deletions
@@ -17,7 +17,8 @@ def __init__(
         self,
         prefix_length: int,
         prefix_hidden_dim: Optional[int] = None,
-        n_positions: int = 1024,  # Start of GPT2 config args
+        vocab_size: int = 50257,  # Start of GPT2 config args
+        n_positions: int = 1024,
         n_embd: int = 768,
         n_layer: int = 12,
         n_head: int = 12,
@@ -28,6 +29,10 @@ def __init__(
         attn_pdrop: float = 0.1,
         layer_norm_epsilon: float = 1e-5,
         initializer_range: float = 0.02,
+        scale_attn_weights: bool = True,
+        use_cache: bool = True,
+        scale_attn_by_inverse_layer_idx: bool = False,
+        reorder_and_upcast_attn: bool = False,
     ):
         """
         Text decoder model for a image-text [UniDiffuser](https://arxiv.org/pdf/2303.06555.pdf) model. This is used to
@@ -52,6 +57,7 @@ def __init__(
         )
 
         gpt_config = GPT2Config(
+            vocab_size=vocab_size,
             n_positions=n_positions,
             n_embd=n_embd,
             n_layer=n_layer,
@@ -63,6 +69,10 @@ def __init__(
             attn_pdrop=attn_pdrop,
             layer_norm_epsilon=layer_norm_epsilon,
             initializer_range=initializer_range,
+            scale_attn_weights=scale_attn_weights,
+            use_cache=use_cache,
+            scale_attn_by_inverse_layer_idx=scale_attn_by_inverse_layer_idx,
+            reorder_and_upcast_attn=reorder_and_upcast_attn,
         )
         self.transformer = GPT2LMHeadModel(gpt_config)
 
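For reference, a minimal sketch (not part of this commit; it only assumes the transformers library) showing that the newly exposed constructor arguments are ordinary GPT-2 config fields, so they can be forwarded to GPT2Config unchanged:

from transformers import GPT2Config, GPT2LMHeadModel

# Values mirror the defaults added in the diff above.
gpt_config = GPT2Config(
    vocab_size=50257,
    n_positions=1024,
    n_embd=768,
    n_layer=12,
    n_head=12,
    scale_attn_weights=True,              # newly forwarded GPT-2 options
    use_cache=True,
    scale_attn_by_inverse_layer_idx=False,
    reorder_and_upcast_attn=False,
)
decoder = GPT2LMHeadModel(gpt_config)     # same wrapping as self.transformer above
print(decoder.config.vocab_size)          # 50257

Exposing these arguments lets a caller reproduce a specific GPT-2 variant rather than relying on the library defaults.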
@@ -143,8 +153,7 @@ def generate_beam(
         TODO: args
         """
         # Generates text until stop_token is reached using beam search with the desired beam size.
-        # TODO: get the stop token index directly from tokenizer rather than manually specifying the EOS token?
-        stop_token_index = tokenizer.encode(stop_token)[0]
+        stop_token_index = tokenizer.eos_token_id
         tokens = None
         scores = None
         seq_lengths = torch.ones(beam_size, device=device)
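The replaced lines also resolve the removed TODO: reading eos_token_id asks the tokenizer for its own end-of-sequence ID instead of encoding a caller-supplied stop string. A hedged illustration, assuming a standard GPT-2 tokenizer:

from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
print(tokenizer.eos_token_id)                    # 50256 for GPT-2
print(tokenizer.encode(tokenizer.eos_token)[0])  # same ID, but only if the right
                                                 # stop string is passed in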
@@ -159,7 +168,7 @@ def generate_beam(
             generated = self.transformer.transformer.wte(tokens)
 
         for i in range(entry_length):
-            outputs = self.transformer(input_embeds=generated)
+            outputs = self.transformer(inputs_embeds=generated)
             logits = outputs.logits
             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
             logits = logits.softmax(-1).log()
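The keyword fix matters because GPT2LMHeadModel.forward accepts inputs_embeds, not input_embeds. A small sketch of the corrected call (assumes torch and transformers; the tiny config is arbitrary):

import torch
from transformers import GPT2Config, GPT2LMHeadModel

model = GPT2LMHeadModel(GPT2Config(n_layer=2, n_head=2, n_embd=64))
embeds = torch.randn(1, 5, 64)               # (batch, sequence, n_embd)
logits = model(inputs_embeds=embeds).logits  # `inputs_embeds` is the expected keyword
print(logits.shape)                          # torch.Size([1, 5, 50257])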
@@ -198,8 +207,12 @@ def generate_beam(
 
         scores = scores / seq_lengths
         output_list = tokens.cpu().numpy()
+        # print(f"Output list: {output_list}")
+        # print(f"Output list length: {len(output_list)}")
+        # print(f"Seq lengths: {seq_lengths}")
+        # print(f"Seq lengths length: {len(seq_lengths)}")
         output_texts = [
-            self.tokenizer.decode(output[: int(length)], skip_special_tokens=True)
+            tokenizer.decode(output[: int(length)], skip_special_tokens=True)
             for output, length in zip(output_list, seq_lengths)
         ]
         order = scores.argsort(descending=True)
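The last hunk decodes with the tokenizer that generate_beam receives as an argument rather than a self.tokenizer attribute. A small illustration of that decode step (hypothetical token IDs; assumes a GPT-2 tokenizer):

import numpy as np
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
output_list = np.array([[15496, 995, 50256, 50256]])  # hypothetical beam tokens
seq_lengths = [3.0]                                   # length kept for this beam
output_texts = [
    tokenizer.decode(output[: int(length)], skip_special_tokens=True)
    for output, length in zip(output_list, seq_lengths)
]
print(output_texts)  # decoded text for the single beam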

0 commit comments
