Skip to content

Commit e13ec10

Browse files
committed
fix sdpa op definition
1 parent 3567769 commit e13ec10

File tree

2 files changed

+25
-26
lines changed

2 files changed

+25
-26
lines changed

include/gc/Dialect/Linalgx/LinalgxOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,33 @@
1111

1212
include "LinalgxDialect.td"
1313

14+
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
15+
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
16+
1417
// Base class for Linalg dialect ops that do not correspond to library calls.
1518
class Linalgx_Op<string mnemonic, list<Trait> traits = []> :
1619
Op<LinalgxDialect, mnemonic, traits>;
1720

21+
def Linalgx_ScaledDotProductAttentionOp
22+
: Linalgx_Op<"scaled_dot_product_attention",
23+
[AttrSizedOperandSegments,
24+
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>]> {
25+
let summary = "Attention structure.";
26+
let description = [{
27+
Q, K, V, attention_mask.
28+
Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V.
29+
}];
30+
let arguments = (ins
31+
Variadic<AnyRankedTensor>:$inputs,
32+
Variadic<AnyRankedTensor>:$outputs);
33+
let results = (outs Variadic<AnyRankedTensor>:$results);
34+
35+
let hasVerifier = 1;
36+
let assemblyFormat = [{
37+
attr-dict
38+
`ins` `(` $inputs `:` type($inputs) `)`
39+
`outs` `(` $outputs `:` type($outputs) `)`
40+
(`->` type($results)^)?
41+
}];
42+
}
1843
#endif // LINALGX_OPS

include/gc/Dialect/Linalgx/LinalgxStructuredOps.td

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
2323
include "mlir/Interfaces/SideEffectInterfaces.td"
2424
include "mlir/IR/OpAsmInterface.td"
2525

26-
class Linalgx_Op<string mnemonic, list<Trait> traits = []> :
27-
Op<LinalgxDialect, mnemonic, traits>;
28-
2926
// Base Tablegen class for Linalg ops.
3027
// Linalg ops that correspond to library calls operate on ShapedType as their
3128
// first operands. These may be optionally followed by non-view operands
@@ -315,27 +312,4 @@ def Linalgx_MultiBatchMatmulOp : LinalgxStructuredBase_Op<"multi_batch_matmul",
315312
}];
316313
}
317314

318-
def Linalgx_ScaledDotProductAttentionOp
319-
: Linalgx_Op<"scaled_dot_product_attention",
320-
[AttrSizedOperandSegments,
321-
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>]> {
322-
let summary = "Attention structure.";
323-
let description = [{
324-
Q, K, V, attention_mask.
325-
Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V.
326-
}];
327-
let arguments = (ins
328-
Variadic<TensorOrMemref>:$inputs,
329-
Variadic<TensorOrMemref>:$outputs);
330-
let results = (outs Variadic<TensorOrMemref>:$results);
331-
332-
let hasVerifier = 1;
333-
let assemblyFormat = [{
334-
attr-dict
335-
`ins` `(` $inputs `:` type($inputs) `)`
336-
`outs` `(` $outputs `:` type($outputs) `)`
337-
(`->` type($results)^)?
338-
}];
339-
}
340-
341315
#endif // LINALGX_STRUCTURED_OPS

0 commit comments

Comments
 (0)