
Commit 8648723 (parent: c3fcf97)

fix linalg op

2 files changed: 13 additions, 1 deletion

include/gc/Dialect/Linalgx/LinalgxStructuredOps.td

Lines changed: 6 additions & 1 deletion
@@ -328,9 +328,14 @@ def Linalgx_ScaledDotProductAttentionOp
     Variadic<TensorOrMemref>:$inputs,
     Variadic<TensorOrMemref>:$outputs);
   let results = (outs Variadic<TensorOrMemref>:$results);
-  let regions = (region AnyRegion:$region);
 
   let hasVerifier = 1;
+  let assemblyFormat = [{
+    attr-dict
+    `ins` `(` $inputs `:` type($inputs) `)`
+    `outs` `(` $outputs `:` type($outputs) `)`
+    (`->` type($results)^)?
+  }];
 }
 
 #endif // LINALGX_STRUCTURED_OPS
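In the new declarative assemblyFormat, `attr-dict` covers the op's attributes, the backtick-quoted tokens are literal keywords, and the optional group `(`->` type($results)^)?` prints the result types only when results are present. MLIR's declarative format must account for every component of the op, so dropping the (apparently unused) `let regions = (region AnyRegion:$region);` declaration goes hand in hand with adding the format. As a minimal sketch of the resulting textual syntax (the SSA names %query, %key, %value, %mask, and %init are illustrative, not from the commit; shapes follow the test below):

  // Hypothetical IR accepted/printed by the new assemblyFormat:
  // ins(...) takes query, key, value, mask; outs(...) takes the init tensor.
  %out = linalgx.scaled_dot_product_attention
      ins(%query, %key, %value, %mask
          : tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>,
            tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>)
      outs(%init : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>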

test/gc/Transform/flashAttention.mlir

Lines changed: 7 additions & 0 deletions
@@ -0,0 +1,7 @@
+// RUN: gc-opt --split-input-file --flash-attention-conversion %s
+
+func.func @flash_attention(%arg0: tensor<1x16x384x64xf32>, %arg1: tensor<1x16x384x64xf32>, %arg2: tensor<1x16x384x64xf32>, %arg3: tensor<1x16x384x384xf32>) -> tensor<1x16x384x64xf32> {
+  %0 = tensor.empty() : tensor<1x16x384x64xf32>
+  %1 = linalgx.scaled_dot_product_attention ins(%arg0, %arg1, %arg2, %arg3: tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x64xf32>, tensor<1x16x384x384xf32>) outs(%0 : tensor<1x16x384x64xf32>) -> tensor<1x16x384x64xf32>
+  return %1 : tensor<1x16x384x64xf32>
+}
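This new test is the lit coverage for the change: the RUN line expands %s to the file's own path, so it amounts to running

  gc-opt --split-input-file --flash-attention-conversion test/gc/Transform/flashAttention.mlir

Since the file contains no FileCheck CHECK lines, the test passes as long as gc-opt parses the op in its new assembly format and the flash-attention conversion exits without error.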
