@@ -54,6 +54,26 @@ static Value truncToI32(ImplicitLocOpBuilder &b, Value value) {
54
54
return b.create <LLVM::TruncOp>(b.getI32Type (), value);
55
55
}
56
56
57
+ // / Returns warp-size as a value.
58
+ static Value getWarpSizeValue (ImplicitLocOpBuilder &b) {
59
+ static std::optional<Value> warpSize = std::nullopt;
60
+ if (!warpSize.has_value ()) {
61
+ warpSize = b.create <LLVM::ConstantOp>(IntegerType::get (b.getContext (), 32 ),
62
+ b.getI32IntegerAttr (kWarpSize ));
63
+ }
64
+ return warpSize.value ();
65
+ }
66
+
67
+ // / Returns warp-size as a value.
68
+ static Value getWarpSizeValue (ImplicitLocOpBuilder &b) {
69
+ static std::optional<Value> warpSize = std::nullopt;
70
+ if (!warpSize.has_value ()) {
71
+ warpSize = b.create <LLVM::ConstantOp>(IntegerType::get (b.getContext (), 32 ),
72
+ b.getI32IntegerAttr (kWarpSize ));
73
+ }
74
+ return warpSize.value ();
75
+ }
76
+
57
77
// / Returns the type for the intrinsic given the vectorResultType of the
58
78
// / `gpu.mma.sync` operation.
59
79
static Type inferIntrinsicResultType (Type vectorResultType) {
@@ -1441,47 +1461,80 @@ struct NVGPUWarpgroupMmaStoreOpLowering
1441
1461
using ConvertOpToLLVMPattern<
1442
1462
nvgpu::WarpgroupMmaStoreOp>::ConvertOpToLLVMPattern;
1443
1463
1444
- void storeFragmentedMatrix (Value wgmmaResult, nvgpu::WarpgroupMmaStoreOp op,
1445
- OpAdaptor adaptor,
1446
- ConversionPatternRewriter &rewriter,
1464
+ // / This function stores a fragmented register matrix owned by a warp group
1465
+ // / (128 threads) into a memref. Each thread has 64 registers, each the size
1466
+ // / of a struct.
1467
+ // / Here is what each threads (T) holds, each `d` is struct value with a
1468
+ // / number.
1469
+ // /
1470
+ // / Threads in warp-group (128 threads) and what they owns in the matriD:
1471
+ // / 0-31 Warp-0 -> MatrixD[0:15 ][0:N]
1472
+ // / 32-63 Warp-1 -> MatrixD[16:31][0:N]
1473
+ // / 64-95 Warp-2 -> MatrixD[32:47][0:N]
1474
+ // / 96-127 Warp-3 -> MatrixD[48:64][0:N]
1475
+ // /
1476
+ // / Matrix-D:
1477
+ // / +______________________________________________________________________+
1478
+ // / | 0-1 | 2-3 | 4-5 | 6-7 | 8-9 | 10-11|..|N-8,N-7 |
1479
+ // / 0 | T0:d0-d1 |T1:d0-d1 |T2:d0-d1 |T3:d0-d1 |T0:d4-d5| T1:d4-d5..|T0:dX-dY|
1480
+ // / 1 | T4:d0-d1 |T5:d0-d1 |T6:d0-d1 |T7:d0-d1 |T4:d4-d5| T5:d4-d5..|T4:dX-dY|
1481
+ // / ..| .........|.........|.........|.........|........|...........|........|
1482
+ // / 8 | T0:d2-d3 |T1:d2-d3 |T2:d2-d3 |T3:d2-d3 |T0:d6-d7|T1:d6-d7,..|T0:dZ-dW|
1483
+ // / 9 | T4:d2-d3 |T5:d2-d3 |T6:d2-d3 |T7:d2-d3 |T4:d6-d7| T5:d6-d7..|T4:dZ-dW|
1484
+ // / ..| .........|.........|.........|.........|........|...........|........|
1485
+ // / 15| T28:d2-d3|T29:d2-d3|T30:d2-d3|T31:d2-d3|........|...........|........|
1486
+ // / 16| T32:d2-d3|T33:d2-d3|T34:d2-d3|T35:d2-d3|........|...........|........|
1487
+ // / ..| .........|.........|.........|.........|........|...........|........|
1488
+ // / 32| T64:d2-d3|T65:d2-d3|T66:d2-d3|T67:d2-d3|........|...........|........|
1489
+ // / ..| .........|.........|.........|.........|........|...........|........|
1490
+ // / 48| T96:d2-d3|T97:d2-d3|T98:d2-d3|T99:d2-d3|........|...........|........|
1491
+ // / ..| .........|.........|.........|.........|........|...........|........|
1492
+ // / +______________________________________________________________________+
1493
+ // /
1494
+ // / \param rewriter: The pattern rewriter.
1495
+ // / \param matrixD: Result of the warp-group MMA operation (fragmented
1496
+ // / matrix). It is holded by a thread and a struct with 64 elements.
1497
+ // / \param dstMemref: The memref where the registers will be stored.
1498
+ // / \param offset: the offset within the memref where the registers will be
1499
+ // / stored.
1500
+ void storeFragmentedMatrix (ImplicitLocOpBuilder &b, Value matrixD,
1501
+ TypedValue<MemRefType> dstMemref,
1447
1502
int offset) const {
1448
- Location loc = op->getLoc ();
1449
- Type i32 = rewriter.getI32Type ();
1503
+ Type i32 = b.getI32Type ();
1450
1504
1451
1505
auto makeConst = [&](int32_t index ) -> Value {
1452
- return rewriter.create <LLVM::ConstantOp>(
1453
- loc, i32, rewriter.getI32IntegerAttr (index ));
1506
+ return b.create <LLVM::ConstantOp>(i32, b.getI32IntegerAttr (index ));
1454
1507
};
1508
+ Value c1 = makeConst (1 );
1509
+ Value c2 = makeConst (2 );
1455
1510
Value c4 = makeConst (4 );
1456
- Value c32 = makeConst (kWarpSize );
1457
1511
Value c8 = makeConst (8 );
1458
- Value c2 = makeConst (2 );
1459
- Value c1 = makeConst (1 );
1460
1512
Value c16 = makeConst (16 );
1513
+ Value warpSize = getWarpSizeValue (b);
1461
1514
1462
1515
auto makeMul = [&](Value lhs, Value rhs) -> Value {
1463
- return rewriter .create <LLVM::MulOp>(loc, lhs.getType (), lhs, rhs);
1516
+ return b .create <LLVM::MulOp>(lhs.getType (), lhs, rhs);
1464
1517
};
1465
1518
auto makeAdd = [&](Value lhs, Value rhs) -> Value {
1466
- return rewriter .create <LLVM::AddOp>(loc, lhs.getType (), lhs, rhs);
1519
+ return b .create <LLVM::AddOp>(lhs.getType (), lhs, rhs);
1467
1520
};
1468
1521
1469
- Value tidx = rewriter .create <NVVM::ThreadIdXOp>(loc, i32);
1470
- Value laneId = rewriter .create <LLVM::URemOp>(loc, i32, tidx, c32 );
1471
- Value warpId = rewriter .create <LLVM::UDivOp>(loc, i32, tidx, c32 );
1472
- Value lane4Id = rewriter .create <LLVM::UDivOp>(loc, i32, laneId, c4);
1473
- Value lane4modId = rewriter .create <LLVM::URemOp>(loc, i32, laneId, c4);
1522
+ Value tidx = b .create <NVVM::ThreadIdXOp>(i32);
1523
+ Value laneId = b .create <LLVM::URemOp>(i32, tidx, warpSize );
1524
+ Value warpId = b .create <LLVM::UDivOp>(i32, tidx, warpSize );
1525
+ Value lane4Id = b .create <LLVM::UDivOp>(i32, laneId, c4);
1526
+ Value lane4modId = b .create <LLVM::URemOp>(i32, laneId, c4);
1474
1527
1475
1528
auto makeExtractAndStore = [&](int i, Value wgmmaResult, Value x, Value y,
1476
1529
TypedValue<::mlir::MemRefType> memref) {
1477
- Type it = rewriter .getIndexType ();
1478
- Value idx = rewriter .create <arith::IndexCastOp>(loc, it, x);
1479
- Value idy0 = rewriter .create <arith::IndexCastOp>(loc, it, y);
1480
- Value idy1 = rewriter .create <arith::IndexCastOp>(loc, it, makeAdd (y, c1));
1481
- Value d0 = rewriter .create <LLVM::ExtractValueOp>(loc, wgmmaResult, i);
1482
- Value d1 = rewriter .create <LLVM::ExtractValueOp>(loc, wgmmaResult, i + 1 );
1483
- rewriter .create <memref::StoreOp>(loc, d0, memref, ValueRange{idx, idy0});
1484
- rewriter .create <memref::StoreOp>(loc, d1, memref, ValueRange{idx, idy1});
1530
+ Type it = b .getIndexType ();
1531
+ Value idx = b .create <arith::IndexCastOp>(it, x);
1532
+ Value idy0 = b .create <arith::IndexCastOp>(it, y);
1533
+ Value idy1 = b .create <arith::IndexCastOp>(it, makeAdd (y, c1));
1534
+ Value d0 = b .create <LLVM::ExtractValueOp>(wgmmaResult, i);
1535
+ Value d1 = b .create <LLVM::ExtractValueOp>(wgmmaResult, i + 1 );
1536
+ b .create <memref::StoreOp>(d0, memref, ValueRange{idx, idy0});
1537
+ b .create <memref::StoreOp>(d1, memref, ValueRange{idx, idy1});
1485
1538
};
1486
1539
1487
1540
Value tj = makeMul (lane4modId, c2);
@@ -1493,7 +1546,7 @@ struct NVGPUWarpgroupMmaStoreOpLowering
1493
1546
for (int j = 0 ; j < 16 ; ++j) {
1494
1547
Value idy = makeAdd (tj, makeMul (makeConst (j), c8));
1495
1548
int sIndex = i * 2 + j * 4 ;
1496
- makeExtractAndStore (sIndex , wgmmaResult , idx, idy, op. getDstMemref () );
1549
+ makeExtractAndStore (sIndex , matrixD , idx, idy, dstMemref );
1497
1550
}
1498
1551
}
1499
1552
}
@@ -1502,10 +1555,11 @@ struct NVGPUWarpgroupMmaStoreOpLowering
1502
1555
matchAndRewrite (nvgpu::WarpgroupMmaStoreOp op, OpAdaptor adaptor,
1503
1556
ConversionPatternRewriter &rewriter) const override {
1504
1557
int offset = 0 ;
1505
- for (auto result : adaptor.getMatrixD ()) {
1506
- auto stype = result.getType ().cast <LLVM::LLVMStructType>();
1507
- storeFragmentedMatrix (result, op, adaptor, rewriter, offset);
1508
- offset += stype.getBody ().size ();
1558
+ ImplicitLocOpBuilder lb (op->getLoc (), rewriter);
1559
+ for (Value matrixD : adaptor.getMatrixD ()) {
1560
+ auto structType = matrixD.getType ().cast <LLVM::LLVMStructType>();
1561
+ storeFragmentedMatrix (lb, matrixD, op.getDstMemref (), offset);
1562
+ offset += structType.getBody ().size ();
1509
1563
}
1510
1564
rewriter.eraseOp (op);
1511
1565
return success ();
0 commit comments