@@ -488,7 +488,8 @@ def _gen_mnasnet_small(variant, channel_multiplier=1.0, pretrained=False, **kwar
488
488
489
489
def _gen_mobilenet_v1 (
490
490
variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 ,
491
- fix_stem_head = False , head_conv = False , pretrained = False , ** kwargs ):
491
+ group_size = None , fix_stem_head = False , head_conv = False , pretrained = False , ** kwargs
492
+ ):
492
493
"""
493
494
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
494
495
Paper: https://arxiv.org/abs/1801.04381
@@ -503,7 +504,12 @@ def _gen_mobilenet_v1(
503
504
round_chs_fn = partial (round_channels , multiplier = channel_multiplier )
504
505
head_features = (1024 if fix_stem_head else max (1024 , round_chs_fn (1024 ))) if head_conv else 0
505
506
model_kwargs = dict (
506
- block_args = decode_arch_def (arch_def , depth_multiplier = depth_multiplier , fix_first_last = fix_stem_head ),
507
+ block_args = decode_arch_def (
508
+ arch_def ,
509
+ depth_multiplier = depth_multiplier ,
510
+ fix_first_last = fix_stem_head ,
511
+ group_size = group_size ,
512
+ ),
507
513
num_features = head_features ,
508
514
stem_size = 32 ,
509
515
fix_stem = fix_stem_head ,
@@ -517,7 +523,9 @@ def _gen_mobilenet_v1(
517
523
518
524
519
525
def _gen_mobilenet_v2 (
520
- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , fix_stem_head = False , pretrained = False , ** kwargs ):
526
+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 ,
527
+ group_size = None , fix_stem_head = False , pretrained = False , ** kwargs
528
+ ):
521
529
""" Generate MobileNet-V2 network
522
530
Ref impl: https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet_v2.py
523
531
Paper: https://arxiv.org/abs/1801.04381
@@ -533,7 +541,12 @@ def _gen_mobilenet_v2(
533
541
]
534
542
round_chs_fn = partial (round_channels , multiplier = channel_multiplier )
535
543
model_kwargs = dict (
536
- block_args = decode_arch_def (arch_def , depth_multiplier = depth_multiplier , fix_first_last = fix_stem_head ),
544
+ block_args = decode_arch_def (
545
+ arch_def ,
546
+ depth_multiplier = depth_multiplier ,
547
+ fix_first_last = fix_stem_head ,
548
+ group_size = group_size ,
549
+ ),
537
550
num_features = 1280 if fix_stem_head else max (1280 , round_chs_fn (1280 )),
538
551
stem_size = 32 ,
539
552
fix_stem = fix_stem_head ,
@@ -613,7 +626,8 @@ def _gen_spnasnet(variant, channel_multiplier=1.0, pretrained=False, **kwargs):
613
626
614
627
def _gen_efficientnet (
615
628
variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , channel_divisor = 8 ,
616
- group_size = None , pretrained = False , ** kwargs ):
629
+ group_size = None , pretrained = False , ** kwargs
630
+ ):
617
631
"""Creates an EfficientNet model.
618
632
619
633
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
@@ -661,7 +675,8 @@ def _gen_efficientnet(
661
675
662
676
663
677
def _gen_efficientnet_edge (
664
- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs ):
678
+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
679
+ ):
665
680
""" Creates an EfficientNet-EdgeTPU model
666
681
667
682
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/edgetpu
@@ -692,7 +707,8 @@ def _gen_efficientnet_edge(
692
707
693
708
694
709
def _gen_efficientnet_condconv (
695
- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , experts_multiplier = 1 , pretrained = False , ** kwargs ):
710
+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , experts_multiplier = 1 , pretrained = False , ** kwargs
711
+ ):
696
712
"""Creates an EfficientNet-CondConv model.
697
713
698
714
Ref impl: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet/condconv
@@ -764,7 +780,8 @@ def _gen_efficientnet_lite(variant, channel_multiplier=1.0, depth_multiplier=1.0
764
780
765
781
766
782
def _gen_efficientnetv2_base (
767
- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
783
+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
784
+ ):
768
785
""" Creates an EfficientNet-V2 base model
769
786
770
787
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -780,7 +797,7 @@ def _gen_efficientnetv2_base(
780
797
]
781
798
round_chs_fn = partial (round_channels , multiplier = channel_multiplier , round_limit = 0. )
782
799
model_kwargs = dict (
783
- block_args = decode_arch_def (arch_def , depth_multiplier ),
800
+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
784
801
num_features = round_chs_fn (1280 ),
785
802
stem_size = 32 ,
786
803
round_chs_fn = round_chs_fn ,
@@ -793,7 +810,8 @@ def _gen_efficientnetv2_base(
793
810
794
811
795
812
def _gen_efficientnetv2_s (
796
- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , rw = False , pretrained = False , ** kwargs ):
813
+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , rw = False , pretrained = False , ** kwargs
814
+ ):
797
815
""" Creates an EfficientNet-V2 Small model
798
816
799
817
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -831,7 +849,9 @@ def _gen_efficientnetv2_s(
831
849
return model
832
850
833
851
834
- def _gen_efficientnetv2_m (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
852
+ def _gen_efficientnetv2_m (
853
+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
854
+ ):
835
855
""" Creates an EfficientNet-V2 Medium model
836
856
837
857
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -849,7 +869,7 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
849
869
]
850
870
851
871
model_kwargs = dict (
852
- block_args = decode_arch_def (arch_def , depth_multiplier ),
872
+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
853
873
num_features = 1280 ,
854
874
stem_size = 24 ,
855
875
round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
@@ -861,7 +881,9 @@ def _gen_efficientnetv2_m(variant, channel_multiplier=1.0, depth_multiplier=1.0,
861
881
return model
862
882
863
883
864
- def _gen_efficientnetv2_l (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
884
+ def _gen_efficientnetv2_l (
885
+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
886
+ ):
865
887
""" Creates an EfficientNet-V2 Large model
866
888
867
889
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -879,7 +901,7 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
879
901
]
880
902
881
903
model_kwargs = dict (
882
- block_args = decode_arch_def (arch_def , depth_multiplier ),
904
+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
883
905
num_features = 1280 ,
884
906
stem_size = 32 ,
885
907
round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
@@ -891,7 +913,9 @@ def _gen_efficientnetv2_l(variant, channel_multiplier=1.0, depth_multiplier=1.0,
891
913
return model
892
914
893
915
894
- def _gen_efficientnetv2_xl (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
916
+ def _gen_efficientnetv2_xl (
917
+ variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , group_size = None , pretrained = False , ** kwargs
918
+ ):
895
919
""" Creates an EfficientNet-V2 Xtra-Large model
896
920
897
921
Ref impl: https://github.com/google/automl/tree/master/efficientnetv2
@@ -909,7 +933,7 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
909
933
]
910
934
911
935
model_kwargs = dict (
912
- block_args = decode_arch_def (arch_def , depth_multiplier ),
936
+ block_args = decode_arch_def (arch_def , depth_multiplier , group_size = group_size ),
913
937
num_features = 1280 ,
914
938
stem_size = 32 ,
915
939
round_chs_fn = partial (round_channels , multiplier = channel_multiplier ),
@@ -923,7 +947,8 @@ def _gen_efficientnetv2_xl(variant, channel_multiplier=1.0, depth_multiplier=1.0
923
947
924
948
def _gen_efficientnet_x (
925
949
variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , channel_divisor = 8 ,
926
- group_size = None , version = 1 , pretrained = False , ** kwargs ):
950
+ group_size = None , version = 1 , pretrained = False , ** kwargs
951
+ ):
927
952
"""Creates an EfficientNet model.
928
953
929
954
Ref impl: https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
@@ -1069,9 +1094,7 @@ def _gen_mixnet_m(variant, channel_multiplier=1.0, depth_multiplier=1.0, pretrai
1069
1094
return model
1070
1095
1071
1096
1072
- def _gen_tinynet (
1073
- variant , model_width = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs
1074
- ):
1097
+ def _gen_tinynet (variant , model_width = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
1075
1098
"""Creates a TinyNet model.
1076
1099
"""
1077
1100
arch_def = [
@@ -1183,8 +1206,7 @@ def _arch_def(chs: List[int], group_size: int):
1183
1206
return model
1184
1207
1185
1208
1186
- def _gen_test_efficientnet (
1187
- variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
1209
+ def _gen_test_efficientnet (variant , channel_multiplier = 1.0 , depth_multiplier = 1.0 , pretrained = False , ** kwargs ):
1188
1210
""" Minimal test EfficientNet generator.
1189
1211
"""
1190
1212
arch_def = [
0 commit comments