
Commit ea8d59f

ThomasJannaud authored and facebook-github-bot committed
RMSNorm support - Executorch (pytorch#9844)
Summary: This follows D72014553, which adds support for RMSNorm (CPU backend). This is a separate diff for ExecuTorch / GitHub.

Reviewed By: Vysarat

Differential Revision: D72258890
1 parent 3600d4f commit ea8d59f
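For context, the newly registered `rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)` schema corresponds to the standard RMSNorm computation. Below is a minimal eager-mode sketch of that computation, assuming normalization over the last dimension and a weight W sized to that dimension; the diff only registers the schema and a fake kernel, so the exact reduction behavior of the Cadence kernel is an assumption here.

import torch

def rms_norm_reference(X: torch.Tensor, eps: float, W: torch.Tensor) -> torch.Tensor:
    # Root-mean-square over the last dimension, as in standard RMSNorm.
    rms = torch.sqrt(X.pow(2).mean(dim=-1, keepdim=True) + eps)
    # Normalize, then scale elementwise by the weight W (broadcast over leading dims).
    return (X / rms) * W

x = torch.randn(2, 8)
w = torch.ones(8)
y = rms_norm_reference(x, 1e-6, w)
assert y.shape == x.shape  # output shape equals input shape, matching rms_norm_meta below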

File tree

1 file changed: +9 additions, -0 deletions


backends/cadence/aot/ops_registrations.py

Lines changed: 9 additions & 0 deletions
@@ -139,6 +139,7 @@
     "int in_zero_point, bool channel_last=False) -> (Tensor out)"
 )
 lib.define("linalg_vector_norm(Tensor X) -> (Tensor Y)")
+lib.define("rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)")
 lib.define(
     "transposed_im2row(Tensor input, int[2] kernel_size, int[2] dilation, int[2] padding, int[2] stride, "
     "int[2] output_padding, Tensor in_zero_point, bool channel_last=False) -> (Tensor out)"
@@ -210,6 +211,7 @@
     "fully_connected.out(Tensor input, Tensor weight, Tensor? bias=None, *, Tensor(a!) out) -> Tensor(a!)"
 )
 lib.define("linalg_vector_norm.out(Tensor X, *, Tensor(a!) out) -> Tensor(a!)")
+lib.define("rms_norm.out(Tensor X, float eps, Tensor W, *, Tensor(a!) out) -> Tensor(a!)")
 lib.define(
     "quantized_fully_connected.out(Tensor src, Tensor weight, Tensor bias, int src_zero_point, "
     "Tensor weight_zero_point, Tensor out_multiplier, Tensor out_shift, int out_zero_point, Tensor? offset, *, Tensor(a!) out) -> Tensor(a!)"
@@ -614,6 +616,13 @@ def linalg_vector_norm_meta(
     # Output of norm is a scalar, so we return a [] tensor
     return X.new_empty([], dtype=X.dtype)
 
+@register_fake("cadence::rms_norm")
+def rms_norm_meta(
+    X: torch.Tensor,
+    eps: float,
+    weight: torch.Tensor,
+) -> torch.Tensor:
+    return X.new_empty(X.shape, dtype=X.dtype)
 
 @register_fake("cadence::requantize")
 def requantize_meta(
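The lib.define(...) / @register_fake(...) pair above follows the usual torch.library pattern: the schema string declares the operator (plus an .out overload for the ExecuTorch memory-planning flow), and the fake kernel lets export-time tracing infer output shapes without running a real implementation. Here is a self-contained sketch of that pattern under a hypothetical demo namespace, not the actual cadence library:

import torch
from torch.library import Library, register_fake

# Hypothetical "demo" namespace for illustration; the commit registers into "cadence".
lib = Library("demo", "DEF")
lib.define("rms_norm(Tensor X, float eps, Tensor W) -> (Tensor Y)")

@register_fake("demo::rms_norm")
def demo_rms_norm_meta(X: torch.Tensor, eps: float, weight: torch.Tensor) -> torch.Tensor:
    # Shape/dtype propagation only: RMSNorm preserves the input shape and dtype.
    return X.new_empty(X.shape, dtype=X.dtype)

# Under FakeTensorMode the call exercises only the fake kernel, so no real kernel is needed.
from torch._subclasses.fake_tensor import FakeTensorMode
with FakeTensorMode():
    x = torch.empty(4, 16)
    w = torch.empty(16)
    y = torch.ops.demo.rms_norm(x, 1e-6, w)
    assert y.shape == (4, 16)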
