|
31 | 31 | ToEdgeTransformAndLower,
|
32 | 32 | )
|
33 | 33 |
|
| 34 | +from torch.export.graph_signature import ExportGraphSignature, InputKind |
| 35 | + |
34 | 36 | try:
|
35 | 37 | from torchao.quantization.quant_api import (
|
36 | 38 | int8_dynamic_activation_int4_weight,
|
@@ -871,3 +873,71 @@ def test_linear_qd8_as_fp32(self):
|
871 | 873 | "dequantize_per_channel.default": 1, # 1: weight
|
872 | 874 | },
|
873 | 875 | )
|
| 876 | + |
    def test_linear_fp32_with_force_as_mm(self):
        """Exercise ``XnnpackPartitioner(force_fp32_dynamic_linear=...)`` over
        every combination of force flag, bias, and lowering path (legacy
        ``to_edge`` + ``partition`` vs ``to_edge_transform_and_lower``).

        Beyond the usual lower/serialize/run round-trip, it verifies via the
        exported program's graph signature that forcing fp32 dynamic linear
        keeps the linear's weight (and bias, if present) behind as PARAMETER
        inputs instead of embedding them in the delegate payload.
        """

        def check_signature(
            signature: ExportGraphSignature,
            force_flag: bool,
            use_bias: bool,
            legacy_mode: bool,
        ):
            # Expected number of PARAMETER inputs left in the graph signature:
            # only when the force flag is set do the linear's weight/bias stay
            # outside the delegate; otherwise everything is consumed by it.
            num_params = 0
            if force_flag:
                num_params = 1  # weight_param
                if use_bias:
                    num_params += 1  # bias_param
            # Count the PARAMETER-kind input specs actually present.
            sign_params: int = 0
            input_specs = signature.input_specs
            for input_spec in input_specs:
                if input_spec.kind == InputKind.PARAMETER:
                    sign_params += 1
            assert (
                sign_params == num_params
            ), f"Expected {num_params} params, got {sign_params} with force_flag={force_flag}, use_bias={use_bias}, legacy_mode={legacy_mode}"

        # Sweep all 2x2x2 combinations of force flag / bias / lowering path.
        for force_flag in (True, False):
            for use_bias in (True, False):
                for legacy_mode in (True, False):
                    module = BaseLinear(
                        in_size=8,
                        input_channels=13,
                        output_channels=17,
                        use_bias=use_bias,
                    )
                    inputs = module.get_inputs()
                    tester = Tester(module, inputs).export()
                    partitioner = XnnpackPartitioner(
                        force_fp32_dynamic_linear=force_flag
                    )
                    if legacy_mode:
                        # Legacy path: decompose to the edge dialect first,
                        # then run the partitioner as a separate stage.
                        tester.to_edge()
                        partitioner_stage = Partition(partitioner=partitioner)
                        tester.partition(partition_stage=partitioner_stage)
                        # NOTE(review): this asserts absence of the op for the
                        # *opposite* bias setting (mm when use_bias, addmm when
                        # not), which looks inverted relative to the usual
                        # aten.linear decomposition (bias -> addmm, no bias ->
                        # mm) — confirm this is intentional, otherwise the
                        # check may be vacuously true.
                        tester.check_not(
                            [
                                (
                                    "executorch_exir_dialects_edge__ops_aten_mm_default"
                                    if use_bias
                                    else "executorch_exir_dialects_edge__ops_aten_addmm_default"
                                )
                            ]
                        )
                    else:
                        # Preferred path: partitioning happens inside
                        # to_edge_transform_and_lower, so aten.linear itself
                        # must have been consumed by the delegate.
                        to_edge_and_transform_stage = ToEdgeTransformAndLower(
                            partitioners=[partitioner]
                        )
                        tester.to_edge_transform_and_lower(
                            to_edge_and_transform_stage=to_edge_and_transform_stage
                        )
                        tester.check_not(
                            ["executorch_exir_dialects_edge__ops_aten_linear_default"]
                        )

                    # Inspect the partitioned program's signature to confirm
                    # which tensors remained as graph parameters.
                    signature: ExportGraphSignature = (
                        tester.get_artifact().exported_program().graph_signature
                    )
                    check_signature(signature, force_flag, use_bias, legacy_mode)

                    # Full round-trip: lower to executorch, serialize, and
                    # compare runtime outputs against eager execution.
                    tester.to_executorch()
                    tester.serialize()
                    tester.run_method_and_compare_outputs()
0 commit comments