@@ -620,12 +620,28 @@ def save_bpte_program(exec_prog, original_model: torch.nn.Module, output_name: s
     save_bundled_program(exec_prog, method_test_suites, output_name)


+def quantize_model(
+    exported_program, args, model: torch.nn.Module, example_inputs, compile_spec
+):
+    model_int8 = quantize(
+        model,
+        args.model_name,
+        compile_spec,
+        example_inputs,
+        args.evaluate,
+        args.evaluate_config,
+    )
+    # Wrap quantized model back into an exported_program
+    exported_program = torch.export.export_for_training(
+        model_int8, example_inputs, strict=True
+    )
+
+    return model_int8, exported_program
+
+
 def to_edge_TOSA_delegate(
-    exported_program,
-    args,
-    model: torch.nn.Module,
+    exported_program, args, model: torch.nn.Module, example_inputs
 ):
-    model_int8 = None
     # As we can target multiple output encodings, one must
     # be specified.
     compile_spec = get_compile_spec(
@@ -634,23 +650,13 @@ def to_edge_TOSA_delegate(
         args.system_config,
         args.memory_mode,
     )
+
+    model_int8 = None
     if args.quantize:
-        model = quantize(
-            model,
-            args.model_name,
-            compile_spec,
-            example_inputs,
-            args.evaluate,
-            args.evaluate_config,
+        model_int8, exported_program = quantize_model(
+            exported_program, args, model, example_inputs, compile_spec
         )
-        model_int8 = model
-        # Wrap quantized model back into an exported_program
-        exported_program = torch.export.export_for_training(
-            model, example_inputs, strict=True
-        )
-
-    if args.intermediates:
-        os.makedirs(args.intermediates, exist_ok=True)
+        model = model_int8

     if is_ethosu(compile_spec):
         partitioner = EthosUPartitioner(compile_spec)
@@ -669,6 +675,31 @@ def to_edge_TOSA_delegate(
     return model_int8, edge


+def to_edge_no_delegate(exported_program, args, model: torch.nn.Module, example_inputs):
+    model_int8 = None
+    if args.quantize:
+        # As we can target multiple output encodings, one must
+        # be specified.
+        compile_spec = get_compile_spec(
+            args.target,
+            args.intermediates,
+            args.system_config,
+            args.memory_mode,
+        )
+        model, exported_program = quantize_model(
+            exported_program, args, model, example_inputs, compile_spec
+        )
+        model_int8 = model
+
+    edge = to_edge_transform_and_lower(
+        exported_program,
+        compile_config=EdgeCompileConfig(
+            _check_ir_validity=False,
+        ),
+    )
+    return model_int8, edge
+
+
 if __name__ == "__main__":  # noqa: C901
     args = get_args()

@@ -686,16 +717,18 @@ def to_edge_TOSA_delegate(
     model = exported_program.module()
     model_fp32 = model

+    if args.intermediates:
+        os.makedirs(args.intermediates, exist_ok=True)
+
     # Quantize if required
     model_int8 = None
     if args.delegate:
-        model_int8, edge = to_edge_TOSA_delegate(exported_program, args, model)
+        model_int8, edge = to_edge_TOSA_delegate(
+            exported_program, args, model, example_inputs
+        )
     else:
-        edge = to_edge_transform_and_lower(
-            exported_program,
-            compile_config=EdgeCompileConfig(
-                _check_ir_validity=False,
-            ),
+        model_int8, edge = to_edge_no_delegate(
+            exported_program, args, model, example_inputs
         )

     dump_delegation_info(edge, args.intermediates)