
Commit f81b094

Add 'qkv_bias_separate' flag for EVA/beit/swinv2 attn modules to allow an override for easy quantization wrappers. Fix #2098
1 parent 83c2c2f commit f81b094
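
Why the flag exists: by default these modules build the biased qkv projection with F.linear(x, weight=self.qkv.weight, bias=qkv_bias), reaching into the weight tensor directly, so any module that wraps or replaces self.qkv (e.g. a quantization wrapper) never has its own forward executed. With qkv_bias_separate=True the projection goes through self.qkv(x) and the concatenated (q, k, v) bias is added afterwards. A minimal sketch of the difference follows, using a hypothetical call-counting wrapper that is not part of timm:

# Sketch (not timm code) of the two paths this commit toggles. A hypothetical
# wrapper counts how often its forward runs; the fused path reads .weight
# directly and never calls it, the separate path does.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CountingWrapper(nn.Module):
    """Stand-in for a quantization wrapper around the fused qkv Linear."""
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.inner = linear
        self.calls = 0

    @property
    def weight(self):
        return self.inner.weight  # what the fused F.linear path reaches for

    def forward(self, x):
        self.calls += 1           # a real wrapper would (de)quantize here
        return self.inner(x)

dim = 8
x = torch.randn(2, 4, dim)
qkv = CountingWrapper(nn.Linear(dim, dim * 3, bias=False))
qkv_bias = torch.zeros(dim * 3)

fused = F.linear(x, weight=qkv.weight, bias=qkv_bias)   # qkv_bias_separate=False
separate = qkv(x) + qkv_bias                            # qkv_bias_separate=True
print(qkv.calls)  # 1 -- only the separate path went through the wrapper

A real quantization wrapper typically fake-quantizes weights or activations inside forward, a step the fused path would silently skip; the flag gives such wrappers a clean override point.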

File tree: 3 files changed, +33 −7 lines

  timm/models/beit.py
  timm/models/eva.py
  timm/models/swin_transformer_v2.py

timm/models/beit.py

Lines changed: 11 additions & 2 deletions

@@ -86,6 +86,7 @@ def __init__(
             dim: int,
             num_heads: int = 8,
             qkv_bias: bool = False,
+            qkv_bias_separate: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             window_size: Optional[Tuple[int, int]] = None,
@@ -99,6 +100,7 @@ def __init__(
         all_head_dim = head_dim * self.num_heads
         self.scale = head_dim ** -0.5
         self.fused_attn = use_fused_attn()
+        self.qkv_bias_separate = qkv_bias_separate

         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
@@ -136,8 +138,15 @@ def _get_rel_pos_bias(self):
     def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
         B, N, C = x.shape

-        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
-        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        if self.q_bias is None:
+            qkv = self.qkv(x)
+        else:
+            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
+            if self.qkv_bias_separate:
+                qkv = self.qkv(x)
+                qkv += qkv_bias
+            else:
+                qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
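
The two branches are numerically equivalent for the fused projection; the change only affects which call produces it. A standalone sanity check under the same bias layout BEiT uses (learned q/v biases, zero k bias), assuming nothing about the actual module:

# Standalone equivalence check (sketch, not timm code): the fused-bias path
# and the separate-bias path produce the same qkv tensor.
import torch
import torch.nn.functional as F

dim = 64
x = torch.randn(2, 5, dim)
qkv = torch.nn.Linear(dim, dim * 3, bias=False)
q_bias = torch.randn(dim)
k_bias = torch.zeros(dim)   # k projection carries no learned bias here
v_bias = torch.randn(dim)

qkv_bias = torch.cat((q_bias, k_bias, v_bias))
fused = F.linear(x, weight=qkv.weight, bias=qkv_bias)   # qkv_bias_separate=False
separate = qkv(x) + qkv_bias                            # qkv_bias_separate=True
assert torch.allclose(fused, separate, atol=1e-6)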

timm/models/eva.py

Lines changed: 11 additions & 2 deletions

@@ -54,6 +54,7 @@ def __init__(
             qkv_bias: bool = True,
             qkv_fused: bool = True,
             num_prefix_tokens: int = 1,
+            qkv_bias_separate: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             attn_head_dim: Optional[int] = None,
@@ -80,6 +81,7 @@ def __init__(
         self.scale = head_dim ** -0.5
         self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()
+        self.qkv_bias_separate = qkv_bias_separate

         if qkv_fused:
             self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
@@ -111,8 +113,15 @@ def forward(
         B, N, C = x.shape

         if self.qkv is not None:
-            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
-            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+            if self.q_bias is None:
+                qkv = self.qkv(x)
+            else:
+                qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
+                if self.qkv_bias_separate:
+                    qkv = self.qkv(x)
+                    qkv += qkv_bias
+                else:
+                    qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
             qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
             q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
         else:
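
In eva.py the new branch sits under `if self.qkv is not None:` because EVA attention also supports unfused q/k/v projections (qkv_fused=False), where no concatenated bias is needed. Since the commit only touches the attention modules, one way to use the flag is to flip the attribute on an already-built model; the entrypoint name below is only an example, and any module exposing the attribute is affected:

# Sketch: enable the separate-bias path on every attention module that has the
# new attribute (the model name is illustrative, not prescribed by the commit).
import timm

model = timm.create_model('eva02_tiny_patch14_224', pretrained=False)
for m in model.modules():
    if hasattr(m, 'qkv_bias_separate'):
        m.qkv_bias_separate = True   # bias is now added after self.qkv(x)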

timm/models/swin_transformer_v2.py

Lines changed: 11 additions & 3 deletions

@@ -86,6 +86,7 @@ def __init__(
             window_size: Tuple[int, int],
             num_heads: int,
             qkv_bias: bool = True,
+            qkv_bias_separate: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             pretrained_window_size: Tuple[int, int] = (0, 0),
@@ -95,6 +96,7 @@ def __init__(
         self.window_size = window_size  # Wh, Ww
         self.pretrained_window_size = pretrained_window_size
         self.num_heads = num_heads
+        self.qkv_bias_separate = qkv_bias_separate

         self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))

@@ -156,10 +158,16 @@ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
             mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
         """
         B_, N, C = x.shape
-        qkv_bias = None
-        if self.q_bias is not None:
+
+        if self.q_bias is None:
+            qkv = self.qkv(x)
+        else:
             qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
-        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+            if self.qkv_bias_separate:
+                qkv = self.qkv(x)
+                qkv += qkv_bias
+            else:
+                qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
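
The SwinV2 change follows the same pattern, with the extra cleanup that the old `qkv_bias = None` pre-assignment goes away. A hypothetical smoke test constructing the window attention module directly; the class name and constructor arguments are taken from timm/models/swin_transformer_v2.py and should be treated as assumptions:

# Sketch: build the SwinV2 window attention with the new flag and run a
# forward pass (argument names assumed from the source file).
import torch
from timm.models.swin_transformer_v2 import WindowAttention

attn = WindowAttention(dim=96, window_size=(7, 7), num_heads=3, qkv_bias_separate=True)
x = torch.randn(2, 7 * 7, 96)   # (num_windows * B, Wh * Ww, C)
out = attn(x)
print(out.shape)                # expected: torch.Size([2, 49, 96])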
