|
31 | 31 | ToEdgeTransformAndLower,
|
32 | 32 | )
|
33 | 33 |
|
| 34 | +from torch.export.graph_signature import ( |
| 35 | + ExportGraphSignature, |
| 36 | + InputKind, |
| 37 | +) |
| 38 | + |
34 | 39 | try:
|
35 | 40 | from torchao.quantization.quant_api import (
|
36 | 41 | int8_dynamic_activation_int4_weight,
|
@@ -871,3 +876,38 @@ def test_linear_qd8_as_fp32(self):
|
871 | 876 | "dequantize_per_channel.default": 1, # 1: weight
|
872 | 877 | },
|
873 | 878 | )
|
| 879 | + |
def test_linear_fp32_with_force_as_mm(self):
    """Exercise fp32 linear lowering with ``force_fp32_dynamic_linear`` on and off.

    Every combination of ``force_flag``, ``use_bias`` and ``legacy_mode`` is
    run. For each combination the module is exported, lowered with an
    ``XnnpackPartitioner`` (``to_edge`` + ``partition`` in legacy mode,
    ``to_edge_transform_and_lower`` otherwise), the number of
    ``InputKind.PARAMETER`` entries in the resulting graph signature is
    verified, and the program is serialized and compared against the
    reference outputs.
    """

    def check_signature(
        signature: ExportGraphSignature,
        num_params: int = 0,
        *,
        force_flag: bool,
        use_bias: bool,
        legacy_mode: bool,
    ) -> None:
        """Assert that ``signature`` holds exactly ``num_params`` PARAMETER inputs.

        The three flags are only used to make the failure message
        self-describing; they are passed in explicitly instead of being
        read from the enclosing loops.
        """
        sign_params = sum(
            1
            for input_spec in signature.input_specs
            if input_spec.kind == InputKind.PARAMETER
        )
        assert (
            sign_params == num_params
        ), f"Expected {num_params} params, got {sign_params} with force_flag={force_flag}, use_bias={use_bias}, legacy_mode={legacy_mode}"

    for force_flag in (True, False):
        for use_bias in (True, False):
            # With the force flag the test expects the weight (and the bias,
            # when the module has one) to remain PARAMETER inputs of the
            # lowered program; without it, none are expected.
            expected_params = (1 + int(use_bias)) if force_flag else 0
            for legacy_mode in (True, False):
                module = BaseLinear(
                    in_size=8,
                    input_channels=13,
                    output_channels=17,
                    use_bias=use_bias,
                )
                inputs = module.get_inputs()
                tester = Tester(module, inputs).export()
                # A fresh partitioner per case keeps the cases independent
                # of any state a previous lowering may have left behind.
                partitioner = XnnpackPartitioner(
                    force_fp32_dynamic_linear=force_flag
                )
                if legacy_mode:
                    tester.to_edge()
                    partitioner_stage = Partition(partitioner=partitioner)
                    tester.partition(partition_stage=partitioner_stage)
                    # NOTE(review): the op selection (mm when use_bias,
                    # addmm otherwise) is kept exactly as written — confirm
                    # it is the intended pairing.
                    tester.check_not(
                        [
                            "executorch_exir_dialects_edge__ops_aten_mm_default"
                            if use_bias
                            else "executorch_exir_dialects_edge__ops_aten_addmm_default"
                        ]
                    )
                else:
                    to_edge_and_transform_stage = ToEdgeTransformAndLower(
                        partitioners=[partitioner]
                    )
                    tester.to_edge_transform_and_lower(
                        to_edge_and_transform_stage=to_edge_and_transform_stage
                    )
                    tester.check_not(
                        ["executorch_exir_dialects_edge__ops_aten_linear_default"]
                    )

                signature: ExportGraphSignature = (
                    tester.get_artifact().exported_program().graph_signature
                )
                check_signature(
                    signature,
                    num_params=expected_params,
                    force_flag=force_flag,
                    use_bias=use_bias,
                    legacy_mode=legacy_mode,
                )

                tester.to_executorch()
                tester.serialize()
                tester.run_method_and_compare_outputs()
0 commit comments