Skip to content

Commit dbc412f

Browse files
committed
[mlir][gpu][spirv] Add patterns for gpu.shuffle up/down
Convert gpu.shuffle down %val, %offset, %width to spirv.GroupNonUniformRotateKHR <Subgroup> %val, %offset, cluster_size(%width) Convert gpu.shuffle up %val, %offset, %width to %down_offset = arith.subi %width, %offset spirv.GroupNonUniformRotateKHR <Subgroup> %val, %down_offset, cluster_size(%width) In addition, update the spirv.GroupNonUniformRotateKHR assembly format to be consistent with other gpu non-uniform operations.
1 parent 316a6ff commit dbc412f

File tree

4 files changed

+79
-14
lines changed

4 files changed

+79
-14
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1404,8 +1404,8 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
14041404

14051405
```mlir
14061406
%four = spirv.Constant 4 : i32
1407-
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %value, %delta : f32, i32 -> f32
1408-
%1 = spirv.GroupNonUniformRotateKHR <Workgroup>, %value, %delta,
1407+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup> %value, %delta : f32, i32 -> f32
1408+
%1 = spirv.GroupNonUniformRotateKHR <Workgroup> %value, %delta,
14091409
clustersize(%four) : f32, i32, i32 -> f32
14101410
```
14111411
}];
@@ -1429,7 +1429,7 @@ def SPIRV_GroupNonUniformRotateKHROp : SPIRV_Op<"GroupNonUniformRotateKHR", [
14291429
);
14301430

14311431
let assemblyFormat = [{
1432-
$execution_scope `,` $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
1432+
$execution_scope $value `,` $delta (`,` `cluster_size` `(` $cluster_size^ `)`)? attr-dict `:` type($value) `,` type($delta) (`,` type($cluster_size)^)? `->` type(results)
14331433
}];
14341434
}
14351435

mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -450,8 +450,16 @@ LogicalResult GPUShuffleConversion::matchAndRewrite(
450450
result = rewriter.create<spirv::GroupNonUniformShuffleOp>(
451451
loc, scope, adaptor.getValue(), adaptor.getOffset());
452452
break;
453-
default:
454-
return rewriter.notifyMatchFailure(shuffleOp, "unimplemented shuffle mode");
453+
case gpu::ShuffleMode::DOWN:
454+
result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
455+
loc, scope, adaptor.getValue(), adaptor.getOffset(), shuffleOp.getWidth());
456+
break;
457+
case gpu::ShuffleMode::UP: {
458+
Value offsetForShuffleDown = rewriter.create<arith::SubIOp>(loc, shuffleOp.getWidth(), adaptor.getOffset());
459+
result = rewriter.create<spirv::GroupNonUniformRotateKHROp>(
460+
loc, scope, adaptor.getValue(), offsetForShuffleDown, shuffleOp.getWidth());
461+
break;
462+
}
455463
}
456464

457465
rewriter.replaceOp(shuffleOp, {result, trueVal});

mlir/test/Conversion/GPUToSPIRV/shuffle.mlir

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,60 @@ gpu.module @kernels {
7272
}
7373

7474
}
75+
76+
// -----
77+
78+
module attributes {
79+
gpu.container_module,
80+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformRotateKHR], []>,
81+
#spirv.resource_limits<subgroup_size = 16>>
82+
} {
83+
84+
gpu.module @kernels {
85+
// CHECK-LABEL: spirv.func @shuffle_down()
86+
gpu.func @shuffle_down() kernel
87+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
88+
%offset = arith.constant 4 : i32
89+
%width = arith.constant 16 : i32
90+
%val = arith.constant 42.0 : f32
91+
92+
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
93+
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
94+
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
95+
// CHECK: %{{.+}} = spirv.Constant true
96+
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
97+
%result, %valid = gpu.shuffle down %val, %offset, %width : f32
98+
gpu.return
99+
}
100+
}
101+
102+
}
103+
104+
// -----
105+
106+
module attributes {
107+
gpu.container_module,
108+
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Shader, GroupNonUniformShuffle, GroupNonUniformRotateKHR], []>,
109+
#spirv.resource_limits<subgroup_size = 16>>
110+
} {
111+
112+
gpu.module @kernels {
113+
// CHECK-LABEL: spirv.func @shuffle_up()
114+
gpu.func @shuffle_up() kernel
115+
attributes {spirv.entry_point_abi = #spirv.entry_point_abi<workgroup_size = [16, 1, 1]>} {
116+
%offset = arith.constant 4 : i32
117+
%width = arith.constant 16 : i32
118+
%val = arith.constant 42.0 : f32
119+
120+
// CHECK: %[[OFFSET:.+]] = spirv.Constant 4 : i32
121+
// CHECK: %[[WIDTH:.+]] = spirv.Constant 16 : i32
122+
// CHECK: %[[VAL:.+]] = spirv.Constant 4.200000e+01 : f32
123+
// CHECK: %{{.+}} = spirv.Constant true
124+
// CHECK: %[[DOWN_OFFSET:.+]] = spirv.Constant 12 : i32
125+
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %[[VAL]], %[[DOWN_OFFSET]], cluster_size(%[[WIDTH]]) : f32, i32, i32 -> f32
126+
%result, %valid = gpu.shuffle up %val, %offset, %width : f32
127+
gpu.return
128+
}
129+
}
130+
131+
}

mlir/test/Dialect/SPIRV/IR/non-uniform-ops.mlir

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -613,18 +613,18 @@ func.func @group_non_uniform_logical_xor(%val: i32) -> i32 {
613613

614614
// CHECK-LABEL: @group_non_uniform_rotate_khr
615615
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
616-
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup>, %{{.+}} : f32, i32 -> f32
617-
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta : f32, i32 -> f32
616+
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Subgroup> %{{.+}} : f32, i32 -> f32
617+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta : f32, i32 -> f32
618618
return %0: f32
619619
}
620620

621621
// -----
622622

623623
// CHECK-LABEL: @group_non_uniform_rotate_khr
624624
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
625-
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup>, %{{.+}} : f32, i32, i32 -> f32
625+
// CHECK: %{{.+}} = spirv.GroupNonUniformRotateKHR <Workgroup> %{{.+}} : f32, i32, i32 -> f32
626626
%four = spirv.Constant 4 : i32
627-
%0 = spirv.GroupNonUniformRotateKHR <Workgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
627+
%0 = spirv.GroupNonUniformRotateKHR <Workgroup> %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
628628
return %0: f32
629629
}
630630

@@ -633,7 +633,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
633633
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
634634
%four = spirv.Constant 4 : i32
635635
// expected-error @+1 {{execution scope must be 'Workgroup' or 'Subgroup'}}
636-
%0 = spirv.GroupNonUniformRotateKHR <Device>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
636+
%0 = spirv.GroupNonUniformRotateKHR <Device> %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
637637
return %0: f32
638638
}
639639

@@ -642,7 +642,7 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
642642
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 {
643643
%four = spirv.Constant 4 : i32
644644
// expected-error @+1 {{op operand #1 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
645-
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32
645+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%four) : f32, si32, i32 -> f32
646646
return %0: f32
647647
}
648648

@@ -651,15 +651,15 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: si32) -> f32 {
651651
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
652652
%four = spirv.Constant 4 : si32
653653
// expected-error @+1 {{op operand #2 must be 8/16/32/64-bit signless/unsigned integer, but got 'si32'}}
654-
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32
654+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%four) : f32, i32, si32 -> f32
655655
return %0: f32
656656
}
657657

658658
// -----
659659

660660
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32, %four: i32) -> f32 {
661661
// expected-error @+1 {{cluster size operand must come from a constant op}}
662-
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
662+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%four) : f32, i32, i32 -> f32
663663
return %0: f32
664664
}
665665

@@ -668,6 +668,6 @@ func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32, %four: i32) -> f
668668
func.func @group_non_uniform_rotate_khr(%val: f32, %delta: i32) -> f32 {
669669
%five = spirv.Constant 5 : i32
670670
// expected-error @+1 {{cluster size operand must be a power of two}}
671-
%0 = spirv.GroupNonUniformRotateKHR <Subgroup>, %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
671+
%0 = spirv.GroupNonUniformRotateKHR <Subgroup> %val, %delta, cluster_size(%five) : f32, i32, i32 -> f32
672672
return %0: f32
673673
}

0 commit comments

Comments
 (0)