Skip to content

Commit a073668

Browse files
Arm backend: Add check to not partition ops with float64 input (#10106)
- Float64 placeholders are not supported in the Arm backend. They will cause a crash when processed in the process_placeholder function. This patch rejects Float64 placeholders early to prevent crashes during the partition. Signed-off-by: Yufeng Shi <[email protected]>
1 parent 0039b3e commit a073668

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

backends/arm/operator_support/tosa_supported_operators.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,7 @@ def tosa_support_factory(
112112
# Negative checks: Remove nodes from partitioning
113113
negative_checks: list[OperatorSupportBase] = [
114114
CheckInt64Inputs(exported_program, reporter),
115+
CheckFloat64Inputs(exported_program, reporter),
115116
*[
116117
reporter.wrap_check(check, f"Rejected by {check.__class__.__name__}")
117118
for check in (additional_checks if additional_checks else [])
@@ -443,3 +444,26 @@ def is_node_supported(
443444
)
444445
return False
445446
return True
447+
448+
449+
class CheckFloat64Inputs(OperatorSupportBase):
450+
451+
def __init__(
452+
self, exported_program: ExportedProgram, reporter: WhyNoPartitionReporter
453+
):
454+
self.reporter = reporter
455+
super().__init__()
456+
457+
def is_node_supported(
458+
self, submodules: typing.Mapping[str, torch.nn.Module], node: fx.Node
459+
) -> bool:
460+
461+
for input_node in node.all_input_nodes:
462+
tensor = get_first_fake_tensor(input_node)
463+
if tensor.dtype == torch.float64:
464+
self.reporter.report_reject(
465+
node,
466+
f"Had float64 input {input_node.name} that couldn't be handled.",
467+
)
468+
return False
469+
return True

0 commit comments

Comments
 (0)