Skip to content

Commit a16164d

Browse files
[MLIR][ROCDL] Add dynamically legal ops to LowerGpuOpsToROCDLOpsPass (llvm#108302)
Similar to llvm#108266 After llvm#102971 It is legal to generate `LLVM::ExpOp` and `LLVM::LogOp` if the type is is a float16 or float32
1 parent 7aad873 commit a16164d

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2727
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2828
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
29+
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
2930
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
3031
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
3132
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
@@ -290,6 +291,7 @@ struct LowerGpuOpsToROCDLOpsPass
290291
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
291292
*maybeChipset);
292293
populateVectorToLLVMConversionPatterns(converter, llvmPatterns);
294+
populateMathToLLVMConversionPatterns(converter, llvmPatterns);
293295
cf::populateControlFlowToLLVMConversionPatterns(converter, llvmPatterns);
294296
populateFuncToLLVMConversionPatterns(converter, llvmPatterns);
295297
populateFinalizeMemRefToLLVMConversionPatterns(converter, llvmPatterns);
@@ -332,7 +334,11 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
332334
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
333335
LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
334336
LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();
335-
337+
// These ops are legal for f16 and f32 type.
338+
target.addDynamicallyLegalOp<LLVM::ExpOp, LLVM::LogOp>([](Operation *op) {
339+
return any_of(op->getOperandTypes(),
340+
llvm::IsaPred<Float16Type, Float32Type>);
341+
});
336342
// TODO: Remove once we support replacing non-root ops.
337343
target.addLegalOp<gpu::YieldOp, gpu::GPUModuleOp>();
338344
}

mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,68 @@ gpu.module @test_module {
131131

132132
// -----
133133

134+
gpu.module @test_module {
135+
// CHECK-LABEL: func @gpu_sqrt
136+
func.func @gpu_sqrt(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
137+
%result16 = math.sqrt %arg_f16 : f16
138+
// CHECK: llvm.intr.sqrt(%{{.*}}) : (f16) -> f16
139+
%result32 = math.sqrt %arg_f32 : f32
140+
// CHECK: llvm.intr.sqrt(%{{.*}}) : (f32) -> f32
141+
%result64 = math.sqrt %arg_f64 : f64
142+
// CHECK: llvm.intr.sqrt(%{{.*}}) : (f64) -> f64
143+
func.return %result16, %result32, %result64 : f16, f32, f64
144+
}
145+
}
146+
147+
// -----
148+
149+
gpu.module @test_module {
150+
// CHECK-LABEL: func @gpu_fabs
151+
func.func @gpu_fabs(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
152+
%result16 = math.absf %arg_f16 : f16
153+
// CHECK: llvm.intr.fabs(%{{.*}}) : (f16) -> f16
154+
%result32 = math.absf %arg_f32 : f32
155+
// CHECK: llvm.intr.fabs(%{{.*}}) : (f32) -> f32
156+
%result64 = math.absf %arg_f64 : f64
157+
// CHECK: llvm.intr.fabs(%{{.*}}) : (f64) -> f64
158+
func.return %result16, %result32, %result64 : f16, f32, f64
159+
}
160+
}
161+
162+
// -----
163+
164+
gpu.module @test_module {
165+
// CHECK: llvm.func @__ocml_exp_f64(f64) -> f64
166+
// CHECK-LABEL: func @gpu_exp
167+
func.func @gpu_exp(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
168+
%result16 = math.exp %arg_f16 : f16
169+
// CHECK: llvm.intr.exp(%{{.*}}) : (f16) -> f16
170+
%result32 = math.exp %arg_f32 : f32
171+
// CHECK: llvm.intr.exp(%{{.*}}) : (f32) -> f32
172+
%result64 = math.exp %arg_f64 : f64
173+
// CHECK: llvm.call @__ocml_exp_f64(%{{.*}}) : (f64) -> f64
174+
func.return %result16, %result32, %result64 : f16, f32, f64
175+
}
176+
}
177+
178+
// -----
179+
180+
gpu.module @test_module {
181+
// CHECK: llvm.func @__ocml_log_f64(f64) -> f64
182+
// CHECK-LABEL: func @gpu_log
183+
func.func @gpu_log(%arg_f16 : f16, %arg_f32 : f32, %arg_f64 : f64) -> (f16, f32, f64) {
184+
%result16 = math.log %arg_f16 : f16
185+
// CHECK: llvm.intr.log(%{{.*}}) : (f16) -> f16
186+
%result32 = math.log %arg_f32 : f32
187+
// CHECK: llvm.intr.log(%{{.*}}) : (f32) -> f32
188+
%result64 = math.log %arg_f64 : f64
189+
// CHECK: llvm.call @__ocml_log_f64(%{{.*}}) : (f64) -> f64
190+
func.return %result16, %result32, %result64 : f16, f32, f64
191+
}
192+
}
193+
194+
// -----
195+
134196
gpu.module @test_module {
135197
// CHECK: llvm.func @__ocml_cbrt_f32(f32) -> f32
136198
// CHECK: llvm.func @__ocml_cbrt_f64(f64) -> f64

utils/bazel/llvm-project-overlay/mlir/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6004,6 +6004,7 @@ cc_library(
60046004
":LLVMCommonConversion",
60056005
":LLVMDialect",
60066006
":MathDialect",
6007+
":MathToLLVM",
60076008
":MathToROCDL",
60086009
":MemRefDialect",
60096010
":MemRefToLLVM",

0 commit comments

Comments
 (0)