Skip to content

Commit ed7aaf8

Browse files
authored
Merge pull request #2264 from huggingface/group_size_eff
Allow group_size override for more efficientnet and mobilenetv3 based…
2 parents bef0c12 + 39e92f0 commit ed7aaf8

File tree

2 files changed

+55
-27
lines changed

2 files changed

+55
-27
lines changed

timm/models/efficientnet.py

+44-22
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
488488

489489
def _gen_mobilenet_v1(
490490
variant, channel_multiplier=1.0, depth_multiplier=1.0,
491-
fix_stem_head=False, head_conv=False, pretrained=False, **kwargs):
491+
group_size=None, fix_stem_head=False, head_conv=False, pretrained=False, **kwargs
492+
):
492493
"""
493494
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
494495
Paper: https://arxiv.org/abs/1801.04381
@@ -503,7 +504,12 @@ def _gen_mobilenet_v1(
503504
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
504505
head_features = (1024 if fix_stem_head else max(1024, round_chs_fn(1024))) if head_conv else 0
505506
model_kwargs = dict(
506-
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
507+
block_args=decode_arch_def(
508+
arch_def,
509+
depth_multiplier=depth_multiplier,
510+
fix_first_last=fix_stem_head,
511+
group_size=group_size,
512+
),
507513
num_features=head_features,
508514
stem_size=32,
509515
fix_stem=fix_stem_head,
@@ -517,7 +523,9 @@ def _gen_mobilenet_v1(
517523

518524

519525
def _gen_mobilenet_v2(
520-
variant, channel_multiplier=1.0, depth_multiplier=1.0, fix_stem_head=False, pretrained=False, **kwargs):
526+
variant, channel_multiplier=1.0, depth_multiplier=1.0,
527+
group_size=None, fix_stem_head=False, pretrained=False, **kwargs
528+
):
521529
""" Generate MobileNet-V2 network
522530
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
523531
Paper: https://arxiv.org/abs/1801.04381
@@ -533,7 +541,12 @@ def _gen_mobilenet_v2(
533541
]
534542
round_chs_fn = partial(round_channels, multiplier=channel_multiplier)
535543
model_kwargs = dict(
536-
block_args=decode_arch_def(arch_def, depth_multiplier=depth_multiplier, fix_first_last=fix_stem_head),
544+
block_args=decode_arch_def(
545+
arch_def,
546+
depth_multiplier=depth_multiplier,
547+
fix_first_last=fix_stem_head,
548+
group_size=group_size,
549+
),
537550
num_features=1280 if fix_stem_head else max(1280, round_chs_fn(1280)),
538551
stem_size=32,
539552
fix_stem=fix_stem_head,
@@ -613,7 +626,8 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
613626

614627
def _gen_efficientnet(
615628
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
616-
group_size=None, pretrained=False, **kwargs):
629+
group_size=None, pretrained=False, **kwargs
630+
):
617631
"""Creates an EfficientNet model.
618632
619633
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
@@ -661,7 +675,8 @@ def _gen_efficientnet(
661675

662676

663677
def _gen_efficientnet_edge(
664-
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs):
678+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
679+
):
665680
""" Creates an EfficientNet-EdgeTPU model
666681
667682
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
@@ -692,7 +707,8 @@ def _gen_efficientnet_edge(
692707

693708

694709
def _gen_efficientnet_condconv(
695-
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs):
710+
variant, channel_multiplier=1.0, depth_multiplier=1.0, experts_multiplier=1, pretrained=False, **kwargs
711+
):
696712
"""Creates an EfficientNet-CondConv model.
697713
698714
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
@@ -764,7 +780,8 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
764780

765781

766782
def _gen_efficientnetv2_base(
767-
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
783+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
784+
):
768785
""" Creates an EfficientNet-V2 base model
769786
770787
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -780,7 +797,7 @@ def _gen_efficientnetv2_base(
780797
]
781798
round_chs_fn = partial(round_channels, multiplier=channel_multiplier, round_limit=0.)
782799
model_kwargs = dict(
783-
block_args=decode_arch_def(arch_def, depth_multiplier),
800+
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
784801
num_features=round_chs_fn(1280),
785802
stem_size=32,
786803
round_chs_fn=round_chs_fn,
@@ -793,7 +810,8 @@ def _gen_efficientnetv2_base(
793810

794811

795812
def _gen_efficientnetv2_s(
796-
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs):
813+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, rw=False, pretrained=False, **kwargs
814+
):
797815
""" Creates an EfficientNet-V2 Small model
798816
799817
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -831,7 +849,9 @@ def _gen_efficientnetv2_s(
831849
return model
832850

833851

834-
def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
852+
def _gen_efficientnetv2_m(
853+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
854+
):
835855
""" Creates an EfficientNet-V2 Medium model
836856
837857
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -849,7 +869,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
849869
]
850870

851871
model_kwargs = dict(
852-
block_args=decode_arch_def(arch_def, depth_multiplier),
872+
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
853873
num_features=1280,
854874
stem_size=24,
855875
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
@@ -861,7 +881,9 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
861881
return model
862882

863883

864-
def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
884+
def _gen_efficientnetv2_l(
885+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
886+
):
865887
""" Creates an EfficientNet-V2 Large model
866888
867889
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -879,7 +901,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
879901
]
880902

881903
model_kwargs = dict(
882-
block_args=decode_arch_def(arch_def, depth_multiplier),
904+
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
883905
num_features=1280,
884906
stem_size=32,
885907
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
@@ -891,7 +913,9 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
891913
return model
892914

893915

894-
def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
916+
def _gen_efficientnetv2_xl(
917+
variant, channel_multiplier=1.0, depth_multiplier=1.0, group_size=None, pretrained=False, **kwargs
918+
):
895919
""" Creates an EfficientNet-V2 Xtra-Large model
896920
897921
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -909,7 +933,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
909933
]
910934

911935
model_kwargs = dict(
912-
block_args=decode_arch_def(arch_def, depth_multiplier),
936+
block_args=decode_arch_def(arch_def, depth_multiplier, group_size=group_size),
913937
num_features=1280,
914938
stem_size=32,
915939
round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
@@ -923,7 +947,8 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
923947

924948
def _gen_efficientnet_x(
925949
variant, channel_multiplier=1.0, depth_multiplier=1.0, channel_divisor=8,
926-
group_size=None, version=1, pretrained=False, **kwargs):
950+
group_size=None, version=1, pretrained=False, **kwargs
951+
):
927952
"""Creates an EfficientNet model.
928953
929954
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
@@ -1069,9 +1094,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
10691094
return model
10701095

10711096

1072-
def _gen_tinynet(
1073-
variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs
1074-
):
1097+
def _gen_tinynet(variant, model_width=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
10751098
"""Creates a TinyNet model.
10761099
"""
10771100
arch_def = [
@@ -1183,8 +1206,7 @@ def _arch_def(chs: List[int], group_size: int):
11831206
return model
11841207

11851208

1186-
def _gen_test_efficientnet(
1187-
variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
1209+
def _gen_test_efficientnet(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrained=False, **kwargs):
11881210
""" Minimal test EfficientNet generator.
11891211
"""
11901212
arch_def = [

timm/models/mobilenetv3.py

+11-5
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,9 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV
412412
return model
413413

414414

415-
def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
415+
def _gen_mobilenet_v3_rw(
416+
variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs
417+
) -> MobileNetV3:
416418
"""Creates a MobileNet-V3 model.
417419
418420
Ref impl: ?
@@ -450,7 +452,9 @@ def _gen_mobilenet_v3_rw(variant: str, channel_multiplier: float = 1.0, pretrain
450452
return model
451453

452454

453-
def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
455+
def _gen_mobilenet_v3(
456+
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs
457+
) -> MobileNetV3:
454458
"""Creates a MobileNet-V3 model.
455459
456460
Ref impl: ?
@@ -533,7 +537,7 @@ def _gen_mobilenet_v3(variant: str, channel_multiplier: float = 1.0, pretrained:
533537
]
534538
se_layer = partial(SqueezeExcite, gate_layer='hard_sigmoid', force_act_layer=nn.ReLU, rd_round_fn=round_channels)
535539
model_kwargs = dict(
536-
block_args=decode_arch_def(arch_def),
540+
block_args=decode_arch_def(arch_def, group_size=group_size),
537541
num_features=num_features,
538542
stem_size=16,
539543
fix_stem=channel_multiplier < 0.75,
@@ -646,7 +650,9 @@ def _gen_lcnet(variant: str, channel_multiplier: float = 1.0, pretrained: bool =
646650
return model
647651

648652

649-
def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained: bool = False, **kwargs) -> MobileNetV3:
653+
def _gen_mobilenet_v4(
654+
variant: str, channel_multiplier: float = 1.0, group_size=None, pretrained: bool = False, **kwargs,
655+
) -> MobileNetV3:
650656
"""Creates a MobileNet-V4 model.
651657
652658
Ref impl: ?
@@ -877,7 +883,7 @@ def _gen_mobilenet_v4(variant: str, channel_multiplier: float = 1.0, pretrained:
877883
assert False, f'Unknown variant {variant}.'
878884

879885
model_kwargs = dict(
880-
block_args=decode_arch_def(arch_def),
886+
block_args=decode_arch_def(arch_def, group_size=group_size),
881887
head_bias=False,
882888
head_norm=True,
883889
num_features=num_features,

0 commit comments

Comments
 (0)