Skip to content

Commit f741fbd

Browse files
committed
support flash attention
1 parent de0376f commit f741fbd

File tree

7 files changed

+534
-0
lines changed

7 files changed

+534
-0
lines changed

include/gc/Dialect/Linalgx/LinalgxOps.td

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,33 @@
1111

1212
include "LinalgxDialect.td"
1313

14+
include "mlir/Dialect/Linalg/IR/LinalgBase.td"
15+
include "mlir/Dialect/Linalg/IR/LinalgInterfaces.td"
16+
1417
// Base class for Linalg dialect ops that do not correspond to library calls.
1518
class Linalgx_Op<string mnemonic, list<Trait> traits = []> :
1619
Op<LinalgxDialect, mnemonic, traits>;
1720

21+
def Linalgx_ScaledDotProductAttentionOp
22+
: Linalgx_Op<"scaled_dot_product_attention",
23+
[AttrSizedOperandSegments,
24+
DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>]> {
25+
let summary = "Attention structure.";
26+
let description = [{
27+
Q, K, V, attention_mask.
28+
Output = SoftMax(Q @ K.transpose(-2, -1) + attention_mask) @ V.
29+
}];
30+
let arguments = (ins
31+
Variadic<AnyRankedTensor>:$inputs,
32+
Variadic<AnyRankedTensor>:$outputs);
33+
let results = (outs Variadic<AnyRankedTensor>:$results);
34+
35+
let hasVerifier = 1;
36+
let assemblyFormat = [{
37+
attr-dict
38+
`ins` `(` $inputs `:` type($inputs) `)`
39+
`outs` `(` $outputs `:` type($outputs) `)`
40+
(`->` type($results)^)?
41+
}];
42+
}
1843
#endif // LINALGX_OPS

include/gc/Transforms/Passes.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,17 @@ def ConvertOneDNNGraphToLinalg : Pass<"convert-onednn-graph-to-linalg"> {
3232
];
3333
}
3434

35+
def FlashAttentionConversion
36+
: Pass<"flash-attention-conversion", "func::FuncOp"> {
37+
let summary = "Flash Attention Conversion";
38+
let description =
39+
[{The pass converts MHA to flash attention implementation.}];
40+
let dependentDialects = [
41+
"func::FuncDialect", "linalg::LinalgDialect", "scf::SCFDialect",
42+
"tensor::TensorDialect"
43+
];
44+
}
45+
3546
#ifdef GC_USE_GPU
3647
def LinalgToXeGPU : Pass<"linalg-to-xegpu", "func::FuncOp"> {
3748
let summary = "Convert linalg dialect to XeGPU dialect.";

lib/gc/Dialect/Linalgx/LinalgxOps.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
#include "gc/Dialect/Linalgx/LinalgxOps.h"
1010
#include "gc/Dialect/Linalgx/LinalgxDialect.h"
1111
#include "mlir/IR/OpImplementation.h"
12+
#include <utility>
1213

1314
//===----------------------------------------------------------------------===//
1415
// Builder helper from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -608,6 +609,80 @@ void MultiBatchMatmulOp::getEffects(
608609
getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
609610
}
610611

612+
//===----------------------------------------------------------------------===//
613+
// ScaledDotProductAttentionOp
614+
//===----------------------------------------------------------------------===//
615+
616+
LogicalResult ScaledDotProductAttentionOp::verify() { return success(); }
617+
618+
/// This method converts ScaledDotProductAttention into the following
619+
/// sequence of operations:
620+
/// output = softmax(ins[0] @ transpose(ins[1]) * scale + ins[3]) @ ins[2]
621+
FailureOr<SmallVector<Value>>
622+
ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) {
623+
OpBuilder::InsertionGuard guard(b);
624+
b.setInsertionPoint(*this);
625+
Location loc = getLoc();
626+
Value query = getInputs()[0], key = getInputs()[1], value = getInputs()[2],
627+
mask = getInputs()[3];
628+
auto dtype = cast<RankedTensorType>(query.getType()).getElementType();
629+
auto shape = cast<RankedTensorType>(query.getType()).getShape();
630+
float rsqrt_head = 1 / sqrt(shape[3]);
631+
632+
SmallVector<int64_t> permutation{0, 1, 3, 2};
633+
SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]};
634+
auto transposeOut = b.create<tensor::EmptyOp>(loc, transposeShape, dtype);
635+
auto transpose = b.create<linalg::TransposeOp>(
636+
/*location=*/loc,
637+
/*inputs=*/key,
638+
/*outputs=*/transposeOut,
639+
/*permutation=*/permutation);
640+
641+
SmallVector<int64_t> matmulQKShape{shape[0], shape[1], shape[2], shape[2]};
642+
auto matmulQKOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
643+
auto matmulQK = b.create<linalgx::MultiBatchMatmulOp>(
644+
/*location=*/loc, matmulQKOut.getResult().getType(),
645+
/*inputs=*/ValueRange{query, transpose->getResult(0)},
646+
/*outputs=*/ValueRange{matmulQKOut.getResult()});
647+
648+
auto mulOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
649+
// Broadcast the initial value to the output tensor before convolving.
650+
SmallVector<AffineMap, 4> indexingMaps;
651+
indexingMaps.push_back(b.getMultiDimIdentityMap(4));
652+
indexingMaps.push_back(b.getMultiDimIdentityMap(4));
653+
auto mul = b.create<linalg::GenericOp>(
654+
/*location=*/loc, matmulQKOut.getResult().getType(),
655+
/*inputs=*/ValueRange{matmulQK->getResult(0)},
656+
/*outputs=*/ValueRange{mulOut.getResult()}, indexingMaps,
657+
SmallVector<utils::IteratorType>(4, utils::IteratorType::parallel),
658+
[&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) {
659+
Value constant = b.create<arith::ConstantOp>(
660+
loc, nestedBuilder.getFloatAttr(dtype, rsqrt_head));
661+
Value added =
662+
nestedBuilder.create<arith::MulFOp>(loc, args[0], constant);
663+
nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
664+
});
665+
666+
auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
667+
auto add = b.create<linalg::AddOp>(
668+
/*location=*/loc, addOut.getResult().getType(),
669+
/*inputs=*/ValueRange{mul->getResult(0), mask},
670+
/*outputs=*/ValueRange{addOut.getResult()});
671+
672+
auto softmaxOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype);
673+
auto softmax = b.create<linalg::SoftmaxOp>(
674+
/*location=*/loc, softmaxOut.getResult().getType(),
675+
/*inputs=*/add->getResult(0),
676+
/*outputs=*/softmaxOut.getResult(), 3);
677+
678+
auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype);
679+
auto matmulV = b.create<linalgx::MultiBatchMatmulOp>(
680+
/*location=*/loc, matmulVOut.getResult().getType(),
681+
/*inputs=*/ValueRange{softmax->getResult(0), value},
682+
/*outputs=*/ValueRange{matmulVOut.getResult()});
683+
return SmallVector<Value>{matmulV.getResults()[0]};
684+
}
685+
611686
/////// Operations corresponding to library calls defined with Tablegen ////////
612687

613688
#define GET_OP_CLASSES

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_library(GCPasses
1313
OneDNNGraphToLinalg.cpp
1414
Pipeline.cpp
1515
TileNamed.cpp
16+
FlashAttentionConversion.cpp
1617

1718
ADDITIONAL_HEADER_DIRS
1819
${PROJECT_SOURCE_DIR}/include

0 commit comments

Comments
 (0)