Description
Issue
I encountered an issue when using `core::arch::nvptx::_syncthreads()`
while the MIR pass JumpThreading is enabled. The issue can be reproduced with a simple kernel when executed with the following parameters:
block_dim = BlockDim{x : 512, y : 1, z : 1};
grid_dim = GridDim{x : 2, y : 1, z : 1};
n = 1000;
#[no_mangle]
pub unsafe extern "ptx-kernel" fn memcpy(input: *const f32, output: *mut f32, n: u32) {
let n = n as usize;
// thread id x in grid
let gtid_x = (_block_dim_x() * _block_idx_x() + _thread_idx_x()) as usize;
// read input
let temp = if gtid_x < n {
input.add(gtid_x).read()
} else {
0f32
};
_syncthreads();
// write output
if gtid_x < n {
output.add(gtid_x).write(temp);
}
}
Running compute-sanitizer --tool synccheck
complains about barrier errors. The resulting .ptx
shows the reason for that: the code is transformed to the following:
...
if gtid_x < n {
let temp = input.add(gtid_x).read();
_syncthreads();
output.add(gtid_x).write(temp);
} else {
_syncthreads();
}
...
I was able to track this transformation down to the MIR pass JumpThreading. However, _syncthreads()
is a convergent operation, and this property must be considered when doing code transformations (see this LLVM issue for reference).
Therefore turning off the MIR pass JumpThreading completely prevents this transformation from happening and the resulting code is correct (compute-sanitizer
does also not complain any longer).
PTX with JumpThreading
//
// Generated by LLVM NVPTX Back-End
//
.version 5.0
.target sm_61
.address_size 64
// .globl memcpy // -- Begin function memcpy
// @memcpy
.visible .entry memcpy(
.param .u64 memcpy_param_0,
.param .u64 memcpy_param_1,
.param .u32 memcpy_param_2
)
{
.reg .pred %p<2>;
.reg .b32 %r<5>;
.reg .f32 %f<2>;
.reg .b64 %rd<10>;
// %bb.0:
ld.param.u32 %rd6, [memcpy_param_2];
mov.u32 %r1, %ntid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %tid.x;
mad.lo.s32 %r4, %r1, %r2, %r3;
cvt.s64.s32 %rd3, %r4;
setp.lt.u64 %p1, %rd3, %rd6;
@%p1 bra $L__BB0_2;
bra.uni $L__BB0_1;
$L__BB0_2:
ld.param.u64 %rd4, [memcpy_param_0];
ld.param.u64 %rd5, [memcpy_param_1];
cvta.to.global.u64 %rd1, %rd5;
cvta.to.global.u64 %rd2, %rd4;
shl.b64 %rd7, %rd3, 2;
add.s64 %rd8, %rd2, %rd7;
ld.global.f32 %f1, [%rd8];
bar.sync 0;
add.s64 %rd9, %rd1, %rd7;
st.global.f32 [%rd9], %f1;
bra.uni $L__BB0_3;
$L__BB0_1:
bar.sync 0;
$L__BB0_3:
ret;
// -- End function
}
PTX without JumpThreading
//
// Generated by LLVM NVPTX Back-End
//
.version 5.0
.target sm_61
.address_size 64
// .globl memcpy // -- Begin function memcpy
// @memcpy
.visible .entry memcpy(
.param .u64 memcpy_param_0,
.param .u64 memcpy_param_1,
.param .u32 memcpy_param_2
)
{
.reg .pred %p<3>;
.reg .b32 %r<5>;
.reg .f32 %f<5>;
.reg .b64 %rd<11>;
// %bb.0:
ld.param.u32 %rd2, [memcpy_param_2];
mov.u32 %r1, %ntid.x;
mov.u32 %r2, %ctaid.x;
mov.u32 %r3, %tid.x;
mad.lo.s32 %r4, %r1, %r2, %r3;
cvt.s64.s32 %rd3, %r4;
setp.ge.u64 %p1, %rd3, %rd2;
mov.f32 %f4, 0f00000000;
@%p1 bra $L__BB0_2;
// %bb.1:
ld.param.u64 %rd6, [memcpy_param_0];
cvta.to.global.u64 %rd8, %rd6;
mul.wide.s32 %rd9, %r4, 4;
add.s64 %rd4, %rd8, %rd9;
ld.global.f32 %f4, [%rd4];
$L__BB0_2:
setp.lt.u64 %p2, %rd3, %rd2;
bar.sync 0;
@%p2 bra $L__BB0_4;
bra.uni $L__BB0_3;
$L__BB0_4:
ld.param.u64 %rd7, [memcpy_param_1];
cvta.to.global.u64 %rd1, %rd7;
shl.b64 %rd10, %rd3, 2;
add.s64 %rd5, %rd1, %rd10;
st.global.f32 [%rd5], %f4;
$L__BB0_3:
ret;
// -- End function
}
In the above example _syncthreads()
does nothing useful and could be omitted. However, I encountered this issue in a more complex stencil kernel, where these transformations lead to side effects and race conditions.
Compiler arguments
With JumpThreading:
cargo +nightly-2025-02-14 rustc --release -- -C target-cpu=sm_61 -Clinker-flavor=llbc -Zunstable-options
Without JumpThreading:
cargo +nightly-2025-02-14 rustc --release -- -C target-cpu=sm_61 -Clinker-flavor=llbc -Zunstable-options -Zmir-enable-passes=-JumpThreading
Background
Targets like nvptx64-nvidia-cuda, amdgpu, and probably also spir-v (rust-gpu) make use of so-called convergent operations (like _syncthreads()).
LLVM provides a detailed explanation of this type of operation. Special care must be taken when code that involves convergent operations is transformed.
To my knowledge, rustc does not know whether an operation is convergent, so MIR passes do not handle these operations correctly.