Skip to content

Commit 831236e

Browse files
[MLIR][NVVM] Add support for nvvm.breakpoint Op (#107193)
This commit adds support for `nvvm.breakpoint` Op which lowers to the PTX brkpt instruction. Also, added the respective tests in `nvvmir.mlir`
1 parent 958f59d commit 831236e

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2101,6 +2101,23 @@ def NVVM_WgmmaMmaAsyncOp : NVVM_Op<"wgmma.mma_async",
21012101
}];
21022102
}
21032103

2104+
//===----------------------------------------------------------------------===//
2105+
// NVVM breakpoint Op
2106+
//===----------------------------------------------------------------------===//
2107+
2108+
def NVVM_Breakpoint : NVVM_Op<"breakpoint"> {
2109+
let summary = "Breakpoint Op";
2110+
let description = [{
2111+
Breakpoint suspends execution of the program for debugging.
2112+
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#miscellaneous-instructions-brkpt)
2113+
}];
2114+
string llvmBuilder = [{
2115+
createIntrinsicCall(builder, llvm::Intrinsic::debugtrap);
2116+
}];
2117+
2118+
let assemblyFormat = "attr-dict";
2119+
}
2120+
21042121
//===----------------------------------------------------------------------===//
21052122
// NVVM target attribute.
21062123
//===----------------------------------------------------------------------===//

mlir/test/Target/LLVMIR/nvvmir.mlir

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -610,4 +610,12 @@ llvm.func @nvvm_fence_proxy_tensormap_generic_acquire(%addr : !llvm.ptr) {
610610
// CHECK: call void @llvm.nvvm.fence.proxy.tensormap_generic.acquire.sys(ptr {{%[0-9]+}}, i32 128)
611611
nvvm.fence.proxy.acquire #nvvm.mem_scope<sys> %addr, %c128
612612
llvm.return
613-
}
613+
}
614+
615+
// -----
616+
// CHECK-LABEL: @nvvm_breakpoint
617+
llvm.func @nvvm_breakpoint() {
618+
// CHECK: call void @llvm.debugtrap()
619+
nvvm.breakpoint
620+
llvm.return
621+
}

0 commit comments

Comments
 (0)