Skip to content

Commit ab8cb07

Browse files
committed
Add xavier_uniform init of MNVC hybrid attention modules. Small improvement in training stability.
1 parent 9558a7f commit ab8cb07

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

timm/layers/attention2d.py

+10
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,16 @@ def __init__(
205205

206206
self.einsum = False
207207

208+
def init_weights(self):
    """Apply Xavier-uniform initialization to this attention module's projection weights.

    Per the original author's note: using xavier appeared to improve
    stability for mobilenetv4 hybrid w/ this layer.
    """
    # Gather the target weight tensors in the exact order the initializers
    # were originally invoked, so RNG consumption (and therefore the
    # resulting values under a fixed seed) is unchanged.
    targets = [
        self.query.proj.weight,
        self.key.proj.weight,
        self.value.proj.weight,
    ]
    if self.kv_stride > 1:
        # Down-sampling convs exist only when spatial k/v reduction is enabled.
        targets.extend([self.key.down_conv.weight, self.value.down_conv.weight])
    targets.append(self.output.proj.weight)
    for weight in targets:
        nn.init.xavier_uniform_(weight)
217+
208218
def _reshape_input(self, t: torch.Tensor):
209219
"""Reshapes a tensor to three dimensions, keeping the batch and channels."""
210220
s = t.shape

timm/models/_efficientnet_builder.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,9 @@
1616

1717
import torch.nn as nn
1818

19-
from ._efficientnet_blocks import *
2019
from timm.layers import CondConv2d, get_condconv_initializer, get_act_layer, get_attn, make_divisible, LayerType
20+
from ._efficientnet_blocks import *
21+
from ._manipulate import named_modules
2122

2223
__all__ = ["EfficientNetBuilder", "decode_arch_def", "efficientnet_init_weights",
2324
'resolve_bn_args', 'resolve_act_layer', 'round_channels', 'BN_MOMENTUM_TF_DEFAULT', 'BN_EPS_TF_DEFAULT']
@@ -569,3 +570,7 @@ def efficientnet_init_weights(model: nn.Module, init_fn=None):
569570
for n, m in model.named_modules():
570571
init_fn(m, n)
571572

573+
# iterate and call any module.init_weights() fn, children first
574+
for n, m in named_modules(model):
575+
if hasattr(m, 'init_weights'):
576+
m.init_weights()

0 commit comments

Comments
 (0)