|
#include "gc/Dialect/Linalgx/LinalgxOps.h"

#include "gc/Dialect/Linalgx/LinalgxDialect.h"

#include "mlir/IR/OpImplementation.h"

#include <cmath>
#include <utility>
12 | 13 |
|
13 | 14 | //===----------------------------------------------------------------------===//
|
14 | 15 | // Builder helper from mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
|
@@ -608,6 +609,80 @@ void MultiBatchMatmulOp::getEffects(
|
608 | 609 | getGenericEffectsImpl(effects, cast<LinalgOp>(getOperation()));
|
609 | 610 | }
|
610 | 611 |
|
| 612 | +//===----------------------------------------------------------------------===// |
| 613 | +// ScaledDotProductAttentionOp |
| 614 | +//===----------------------------------------------------------------------===// |
| 615 | + |
// NOTE(review): no structural checks yet — decomposeOperation below indexes
// getInputs()[0..3] and shape[3], so it assumes four rank-4 static-shaped
// operands; consider rejecting anything else here instead of accepting all IR.
LogicalResult ScaledDotProductAttentionOp::verify() { return success(); }
| 617 | + |
| 618 | +/// This method converts ScaledDotProductAttention into the following |
| 619 | +/// sequence of operations: |
| 620 | +/// output = softmax(ins[0] @ transpose(ins[1]) * scale + ins[3]) @ ins[2] |
| 621 | +FailureOr<SmallVector<Value>> |
| 622 | +ScaledDotProductAttentionOp::decomposeOperation(OpBuilder &b) { |
| 623 | + OpBuilder::InsertionGuard guard(b); |
| 624 | + b.setInsertionPoint(*this); |
| 625 | + Location loc = getLoc(); |
| 626 | + Value query = getInputs()[0], key = getInputs()[1], value = getInputs()[2], |
| 627 | + mask = getInputs()[3]; |
| 628 | + auto dtype = cast<RankedTensorType>(query.getType()).getElementType(); |
| 629 | + auto shape = cast<RankedTensorType>(query.getType()).getShape(); |
| 630 | + float rsqrt_head = 1 / sqrt(shape[3]); |
| 631 | + |
| 632 | + SmallVector<int64_t> permutation{0, 1, 3, 2}; |
| 633 | + SmallVector<int64_t> transposeShape{shape[0], shape[1], shape[3], shape[2]}; |
| 634 | + auto transposeOut = b.create<tensor::EmptyOp>(loc, transposeShape, dtype); |
| 635 | + auto transpose = b.create<linalg::TransposeOp>( |
| 636 | + /*location=*/loc, |
| 637 | + /*inputs=*/key, |
| 638 | + /*outputs=*/transposeOut, |
| 639 | + /*permutation=*/permutation); |
| 640 | + |
| 641 | + SmallVector<int64_t> matmulQKShape{shape[0], shape[1], shape[2], shape[2]}; |
| 642 | + auto matmulQKOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype); |
| 643 | + auto matmulQK = b.create<linalgx::MultiBatchMatmulOp>( |
| 644 | + /*location=*/loc, matmulQKOut.getResult().getType(), |
| 645 | + /*inputs=*/ValueRange{query, transpose->getResult(0)}, |
| 646 | + /*outputs=*/ValueRange{matmulQKOut.getResult()}); |
| 647 | + |
| 648 | + auto mulOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype); |
| 649 | + // Broadcast the initial value to the output tensor before convolving. |
| 650 | + SmallVector<AffineMap, 4> indexingMaps; |
| 651 | + indexingMaps.push_back(b.getMultiDimIdentityMap(4)); |
| 652 | + indexingMaps.push_back(b.getMultiDimIdentityMap(4)); |
| 653 | + auto mul = b.create<linalg::GenericOp>( |
| 654 | + /*location=*/loc, matmulQKOut.getResult().getType(), |
| 655 | + /*inputs=*/ValueRange{matmulQK->getResult(0)}, |
| 656 | + /*outputs=*/ValueRange{mulOut.getResult()}, indexingMaps, |
| 657 | + SmallVector<utils::IteratorType>(4, utils::IteratorType::parallel), |
| 658 | + [&](OpBuilder &nestedBuilder, Location nestedLoc, ValueRange args) { |
| 659 | + Value constant = b.create<arith::ConstantOp>( |
| 660 | + loc, nestedBuilder.getFloatAttr(dtype, rsqrt_head)); |
| 661 | + Value added = |
| 662 | + nestedBuilder.create<arith::MulFOp>(loc, args[0], constant); |
| 663 | + nestedBuilder.create<linalg::YieldOp>(nestedLoc, added); |
| 664 | + }); |
| 665 | + |
| 666 | + auto addOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype); |
| 667 | + auto add = b.create<linalg::AddOp>( |
| 668 | + /*location=*/loc, addOut.getResult().getType(), |
| 669 | + /*inputs=*/ValueRange{mul->getResult(0), mask}, |
| 670 | + /*outputs=*/ValueRange{addOut.getResult()}); |
| 671 | + |
| 672 | + auto softmaxOut = b.create<tensor::EmptyOp>(loc, matmulQKShape, dtype); |
| 673 | + auto softmax = b.create<linalg::SoftmaxOp>( |
| 674 | + /*location=*/loc, softmaxOut.getResult().getType(), |
| 675 | + /*inputs=*/add->getResult(0), |
| 676 | + /*outputs=*/softmaxOut.getResult(), 3); |
| 677 | + |
| 678 | + auto matmulVOut = b.create<tensor::EmptyOp>(loc, shape, dtype); |
| 679 | + auto matmulV = b.create<linalgx::MultiBatchMatmulOp>( |
| 680 | + /*location=*/loc, matmulVOut.getResult().getType(), |
| 681 | + /*inputs=*/ValueRange{softmax->getResult(0), value}, |
| 682 | + /*outputs=*/ValueRange{matmulVOut.getResult()}); |
| 683 | + return SmallVector<Value>{matmulV.getResults()[0]}; |
| 684 | +} |
| 685 | + |
611 | 686 | /////// Operations corresponding to library calls defined with Tablegen ////////
|
612 | 687 |
|
613 | 688 | #define GET_OP_CLASSES
|
|
0 commit comments