|
11 | 11 | #include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
|
12 | 12 | #include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
|
13 | 13 | #include "mlir/Conversion/LLVMCommon/Pattern.h"
|
| 14 | +#include "mlir/Dialect/Arith/IR/Arith.h" |
14 | 15 | #include "mlir/Dialect/GPU/IR/GPUDialect.h"
|
15 | 16 | #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
16 | 17 | #include "mlir/Dialect/LLVMIR/LLVMTypes.h"
|
@@ -394,8 +395,8 @@ struct ConvertNVGPUToNVVMPass
|
394 | 395 | using Base::Base;
|
395 | 396 |
|
396 | 397 | void getDependentDialects(DialectRegistry &registry) const override {
|
397 |
| - registry |
398 |
| - .insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect>(); |
| 398 | + registry.insert<memref::MemRefDialect, LLVM::LLVMDialect, NVVM::NVVMDialect, |
| 399 | + arith::ArithDialect>(); |
399 | 400 | }
|
400 | 401 |
|
401 | 402 | void runOnOperation() override {
|
@@ -436,6 +437,7 @@ struct ConvertNVGPUToNVVMPass
|
436 | 437 | populateNVGPUToNVVMConversionPatterns(converter, patterns);
|
437 | 438 | LLVMConversionTarget target(getContext());
|
438 | 439 | target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
|
| 440 | + target.addLegalDialect<::mlir::arith::ArithDialect>(); |
439 | 441 | target.addLegalDialect<::mlir::memref::MemRefDialect>();
|
440 | 442 | target.addLegalDialect<::mlir::NVVM::NVVMDialect>();
|
441 | 443 | mlir::scf::populateSCFStructuralTypeConversionsAndLegality(
|
@@ -1434,11 +1436,88 @@ struct NVGPUWarpgroupMmaOpLowering
|
1434 | 1436 | }
|
1435 | 1437 | };
|
1436 | 1438 |
|
| 1439 | +struct NVGPUWarpgroupMmaStoreOpLowering |
| 1440 | + : public ConvertOpToLLVMPattern<nvgpu::WarpgroupMmaStoreOp> { |
| 1441 | + using ConvertOpToLLVMPattern< |
| 1442 | + nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern; |
| 1443 | + |
| 1444 | + void storeFragmentedMatrix(Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op, |
| 1445 | + OpAdaptor adaptor, |
| 1446 | + ConversionPatternRewriter &rewriter, |
| 1447 | + int offset) const { |
| 1448 | + Location loc = op->getLoc(); |
| 1449 | + Type i32 = rewriter.getI32Type(); |
| 1450 | + |
| 1451 | + auto makeConst = [&](int32_t index) -> Value { |
| 1452 | + return rewriter.create<LLVM::ConstantOp>( |
| 1453 | + loc, i32, rewriter.getI32IntegerAttr(index)); |
| 1454 | + }; |
| 1455 | + Value c4 = makeConst(4); |
| 1456 | + Value c32 = makeConst(kWarpSize); |
| 1457 | + Value c8 = makeConst(8); |
| 1458 | + Value c2 = makeConst(2); |
| 1459 | + Value c1 = makeConst(1); |
| 1460 | + Value c16 = makeConst(16); |
| 1461 | + |
| 1462 | + auto makeMul = [&](Value lhs, Value rhs) -> Value { |
| 1463 | + return rewriter.create<LLVM::MulOp>(loc, lhs.getType(), lhs, rhs); |
| 1464 | + }; |
| 1465 | + auto makeAdd = [&](Value lhs, Value rhs) -> Value { |
| 1466 | + return rewriter.create<LLVM::AddOp>(loc, lhs.getType(), lhs, rhs); |
| 1467 | + }; |
| 1468 | + |
| 1469 | + Value tidx = rewriter.create<NVVM::ThreadIdXOp>(loc, i32); |
| 1470 | + Value laneId = rewriter.create<LLVM::URemOp>(loc, i32, tidx, c32); |
| 1471 | + Value warpId = rewriter.create<LLVM::UDivOp>(loc, i32, tidx, c32); |
| 1472 | + Value lane4Id = rewriter.create<LLVM::UDivOp>(loc, i32, laneId, c4); |
| 1473 | + Value lane4modId = rewriter.create<LLVM::URemOp>(loc, i32, laneId, c4); |
| 1474 | + |
| 1475 | + auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y, |
| 1476 | + TypedValue<::mlir::MemRefType> memref) { |
| 1477 | + Type it = rewriter.getIndexType(); |
| 1478 | + Value idx = rewriter.create<arith::IndexCastOp>(loc, it, x); |
| 1479 | + Value idy0 = rewriter.create<arith::IndexCastOp>(loc, it, y); |
| 1480 | + Value idy1 = rewriter.create<arith::IndexCastOp>(loc, it, makeAdd(y, c1)); |
| 1481 | + Value d0 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i); |
| 1482 | + Value d1 = rewriter.create<LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1); |
| 1483 | + rewriter.create<memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0}); |
| 1484 | + rewriter.create<memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1}); |
| 1485 | + }; |
| 1486 | + |
| 1487 | + Value tj = makeMul(lane4modId, c2); |
| 1488 | + Value ti = makeAdd(lane4Id, makeMul(warpId, c16)); |
| 1489 | + if (offset) |
| 1490 | + ti = makeAdd(ti, makeConst(offset)); |
| 1491 | + for (int i = 0; i < 2; ++i) { |
| 1492 | + Value idx = makeAdd(ti, makeMul(makeConst(i), c8)); |
| 1493 | + for (int j = 0; j < 16; ++j) { |
| 1494 | + Value idy = makeAdd(tj, makeMul(makeConst(j), c8)); |
| 1495 | + int sIndex = i * 2 + j * 4; |
| 1496 | + makeExtractAndStore(sIndex, wgmmaResult, idx, idy, op.getDstMemref()); |
| 1497 | + } |
| 1498 | + } |
| 1499 | + } |
| 1500 | + |
| 1501 | + LogicalResult |
| 1502 | + matchAndRewrite(nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor, |
| 1503 | + ConversionPatternRewriter &rewriter) const override { |
| 1504 | + int offset = 0; |
| 1505 | + for (auto result : adaptor.getMatrixD()) { |
| 1506 | + auto stype = result.getType().cast<LLVM::LLVMStructType>(); |
| 1507 | + storeFragmentedMatrix(result, op, adaptor, rewriter, offset); |
| 1508 | + offset += stype.getBody().size(); |
| 1509 | + } |
| 1510 | + rewriter.eraseOp(op); |
| 1511 | + return success(); |
| 1512 | + } |
| 1513 | +}; |
| 1514 | + |
1437 | 1515 | } // namespace
|
1438 | 1516 |
|
1439 | 1517 | void mlir::populateNVGPUToNVVMConversionPatterns(LLVMTypeConverter &converter,
|
1440 | 1518 | RewritePatternSet &patterns) {
|
1441 | 1519 | patterns.add<
|
| 1520 | + NVGPUWarpgroupMmaStoreOpLowering, // nvgpu.warpgroup.mma.store |
1442 | 1521 | NVGPUMBarrierCreateLowering, // nvgpu.mbarrier.create
|
1443 | 1522 | NVGPUMBarrierInitLowering, // nvgpu.mbarrier.init
|
1444 | 1523 | NVGPUMBarrierArriveLowering, // nvgpu.mbarrier.arrive
|
|
0 commit comments