Skip to content

Commit ca920aa

Browse files
jsjodinyuxuanchen1997
authored andcommitted
[MLIR][ROCDL] Refactor conversion of math operations to ROCDL calls to a separate pass (#98653)
Summary: This patch refactors the conversion of math operations to ROCDL library calls. This pass will also be used in flang to lower Fortran intrinsics/math functions for OpenMP target offloading codegen. Test Plan: Reviewers: Subscribers: Tasks: Tags: Differential Revision: https://phabricator.intern.facebook.com/D60250922
1 parent 241c24f commit ca920aa

File tree

9 files changed

+652
-44
lines changed

9 files changed

+652
-44
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//===- MathToROCDL.h - Utils to convert from the math dialect -----------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
9+
#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
10+
11+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12+
#include "mlir/IR/PatternMatch.h"
13+
#include <memory>
14+
15+
namespace mlir {
16+
class Pass;
17+
18+
#define GEN_PASS_DECL_CONVERTMATHTOROCDL
19+
#include "mlir/Conversion/Passes.h.inc"
20+
21+
/// Populate the given list with patterns that convert from Math to ROCDL calls.
22+
void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
23+
RewritePatternSet &patterns);
24+
} // namespace mlir
25+
26+
#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_

mlir/include/mlir/Conversion/Passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
4747
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
4848
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
49+
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
4950
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
5051
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
5152
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"

mlir/include/mlir/Conversion/Passes.td

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,23 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
733733
];
734734
}
735735

736+
//===----------------------------------------------------------------------===//
737+
// MathToROCDL
738+
//===----------------------------------------------------------------------===//
739+
740+
// Declares the `convert-math-to-rocdl` pass, anchored on `ModuleOp`.
// NOTE(review): the pass implementation in MathToROCDL.cpp marks
// LLVM::LLVMDialect as legal on its conversion target, but the LLVM dialect is
// not listed in `dependentDialects` below -- confirm whether it should be.
def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
741+
let summary = "Convert Math dialect to ROCDL library calls";
742+
let description = [{
743+
This pass converts supported Math ops to ROCDL library calls.
744+
}];
745+
let dependentDialects = [
746+
"arith::ArithDialect",
747+
"func::FuncDialect",
748+
"ROCDL::ROCDLDialect",
749+
"vector::VectorDialect",
750+
];
751+
}
752+
736753
//===----------------------------------------------------------------------===//
737754
// MathToSPIRV
738755
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ add_subdirectory(LLVMCommon)
3636
add_subdirectory(MathToFuncs)
3737
add_subdirectory(MathToLibm)
3838
add_subdirectory(MathToLLVM)
39+
add_subdirectory(MathToROCDL)
3940
add_subdirectory(MathToSPIRV)
4041
add_subdirectory(MemRefToEmitC)
4142
add_subdirectory(MemRefToLLVM)

mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
1313
MLIRArithToLLVM
1414
MLIRArithTransforms
1515
MLIRMathToLLVM
16+
MLIRMathToROCDL
1617
MLIRAMDGPUToROCDL
1718
MLIRFuncToLLVM
1819
MLIRGPUDialect

mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 2 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
2727
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2828
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
29+
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
2930
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
3031
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
3132
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -386,50 +387,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
386387

387388
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
388389

389-
populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
390-
"__ocml_fabs_f64");
391-
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
392-
"__ocml_atan_f64");
393-
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
394-
"__ocml_atan2_f64");
395-
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
396-
"__ocml_cbrt_f64");
397-
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
398-
"__ocml_ceil_f64");
399-
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
400-
"__ocml_cos_f64");
401-
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
402-
"__ocml_exp_f64");
403-
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
404-
"__ocml_exp2_f64");
405-
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
406-
"__ocml_expm1_f64");
407-
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
408-
"__ocml_floor_f64");
409-
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
410-
"__ocml_fmod_f64");
411-
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
412-
"__ocml_log_f64");
413-
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
414-
"__ocml_log10_f64");
415-
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
416-
"__ocml_log1p_f64");
417-
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
418-
"__ocml_log2_f64");
419-
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
420-
"__ocml_pow_f64");
421-
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
422-
"__ocml_rsqrt_f64");
423-
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
424-
"__ocml_sin_f64");
425-
populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
426-
"__ocml_sqrt_f64");
427-
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
428-
"__ocml_tanh_f64");
429-
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
430-
"__ocml_tan_f64");
431-
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
432-
"__ocml_erf_f64");
390+
populateMathToROCDLConversionPatterns(converter, patterns);
433391
}
434392

435393
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Defines the MLIRMathToROCDL conversion library, built from MathToROCDL.cpp.
# It depends on MLIRConversionPassIncGen because MathToROCDL.cpp includes the
# generated "mlir/Conversion/Passes.h.inc".
# NOTE(review): MathToROCDL.cpp includes the LLVM and ROCDL dialect headers,
# but no LLVM/ROCDL dialect library is listed under LINK_LIBS -- confirm they
# are reached transitively.
add_mlir_conversion_library(MLIRMathToROCDL
2+
MathToROCDL.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL
6+
7+
DEPENDS
8+
MLIRConversionPassIncGen
9+
10+
LINK_COMPONENTS
11+
Core
12+
13+
LINK_LIBS PUBLIC
14+
MLIRDialectUtils
15+
MLIRFuncDialect
16+
MLIRGPUToGPURuntimeTransforms
17+
MLIRMathDialect
18+
MLIRLLVMCommonConversion
19+
MLIRPass
20+
MLIRTransformUtils
21+
MLIRVectorDialect
22+
MLIRVectorUtils
23+
)
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
10+
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
11+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
12+
#include "mlir/Dialect/Func/IR/FuncOps.h"
13+
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
14+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
15+
#include "mlir/Dialect/Math/IR/Math.h"
16+
#include "mlir/Dialect/Utils/IndexingUtils.h"
17+
#include "mlir/Dialect/Vector/IR/VectorOps.h"
18+
#include "mlir/IR/BuiltinDialect.h"
19+
#include "mlir/IR/PatternMatch.h"
20+
#include "mlir/Pass/Pass.h"
21+
#include "mlir/Transforms/DialectConversion.h"
22+
23+
#include "../GPUCommon/GPUOpsLowering.h"
24+
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
25+
#include "../GPUCommon/OpToFuncCallLowering.h"
26+
#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
27+
28+
namespace mlir {
29+
#define GEN_PASS_DEF_CONVERTMATHTOROCDL
30+
#include "mlir/Conversion/Passes.h.inc"
31+
} // namespace mlir
32+
33+
using namespace mlir;
34+
35+
#define DEBUG_TYPE "math-to-rocdl"
36+
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
37+
38+
/// Adds the two patterns used in this file for an op of type `OpTy` to
/// `patterns`: a `ScalarizeVectorOpLowering<OpTy>` and an
/// `OpToFuncCallLowering<OpTy>`. The latter additionally receives `f32Func`
/// and `f64Func`; callers in this file pass the `__ocml_*_f32` and
/// `__ocml_*_f64` symbol names for these.
/// NOTE(review): presumably `f32Func`/`f64Func` are the callee names chosen
/// for f32 and f64 operands respectively -- confirm against
/// ../GPUCommon/OpToFuncCallLowering.h.
template <typename OpTy>
39+
static void populateOpPatterns(LLVMTypeConverter &converter,
40+
RewritePatternSet &patterns, StringRef f32Func,
41+
StringRef f64Func) {
42+
patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
43+
patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
44+
}
45+
46+
/// Adds to `patterns` the lowerings from Math dialect ops (plus
/// `arith::RemFOp`) to calls of the corresponding `__ocml_*_f32` /
/// `__ocml_*_f64` functions. Each op is registered through
/// `populateOpPatterns` with `converter`.
void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
47+
RewritePatternSet &patterns) {
48+
// Handled by mathToLLVM: math::AbsIOp
49+
// Handled by mathToLLVM: math::CopySignOp
50+
// Handled by mathToLLVM: math::CountLeadingZerosOp
51+
// Handled by mathToLLVM: math::CountTrailingZerosOp
52+
// Handled by mathToLLVM: math::CtPopOp
53+
// Handled by mathToLLVM: math::FmaOp
54+
// FIXME: math::IPowIOp
55+
// FIXME: math::FPowIOp
56+
// Handled by mathToLLVM: math::RoundEvenOp
57+
// Handled by mathToLLVM: math::RoundOp
58+
// Handled by mathToLLVM: math::TruncOp
59+
populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
60+
"__ocml_fabs_f64");
61+
populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
62+
"__ocml_acos_f64");
63+
populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
64+
"__ocml_acosh_f64");
65+
populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
66+
"__ocml_asin_f64");
67+
populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
68+
"__ocml_asinh_f64");
69+
populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
70+
"__ocml_atan_f64");
71+
populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
72+
"__ocml_atanh_f64");
73+
populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
74+
"__ocml_atan2_f64");
75+
populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
76+
"__ocml_cbrt_f64");
77+
populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
78+
"__ocml_ceil_f64");
79+
populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
80+
"__ocml_cos_f64");
81+
populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
82+
"__ocml_cosh_f64");
83+
populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
84+
"__ocml_sinh_f64");
85+
populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
86+
"__ocml_exp_f64");
87+
populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
88+
"__ocml_exp2_f64");
89+
populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
90+
"__ocml_expm1_f64");
91+
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
92+
"__ocml_floor_f64");
93+
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
94+
"__ocml_log_f64");
95+
populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
96+
"__ocml_log10_f64");
97+
populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
98+
"__ocml_log1p_f64");
99+
populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
100+
"__ocml_log2_f64");
101+
populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
102+
"__ocml_pow_f64");
103+
populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
104+
"__ocml_rsqrt_f64");
105+
populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
106+
"__ocml_sin_f64");
107+
populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
108+
"__ocml_sqrt_f64");
109+
populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
110+
"__ocml_tanh_f64");
111+
populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
112+
"__ocml_tan_f64");
113+
populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
114+
"__ocml_erf_f64");
115+
// Single arith pattern that needs a ROCDL call, probably not
116+
// worth creating a separate pass for it.
117+
populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
118+
"__ocml_fmod_f64");
119+
}
120+
121+
namespace {
122+
/// Pass implementation for `ConvertMathToROCDL`. It derives from the
/// `impl::ConvertMathToROCDLBase` class pulled in through
/// GEN_PASS_DEF_CONVERTMATHTOROCDL and only overrides `runOnOperation`, which
/// is defined out of line below.
struct ConvertMathToROCDLPass
123+
: public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
124+
ConvertMathToROCDLPass() = default;
125+
void runOnOperation() override;
126+
};
127+
} // namespace
128+
129+
/// Runs the Math-to-ROCDL conversion on the operation this pass is anchored
/// on: builds an LLVM type converter from the op's data layout, collects the
/// Math-to-ROCDL patterns, and applies them as a partial conversion. Signals
/// pass failure if the conversion fails.
void ConvertMathToROCDLPass::runOnOperation() {
130+
auto m = getOperation();
131+
MLIRContext *ctx = m.getContext();
132+
133+
// NOTE(review): both `ctx` and `getContext()` are used below; presumably
// they refer to the same MLIRContext -- confirm.
RewritePatternSet patterns(&getContext());
134+
LowerToLLVMOptions options(ctx, DataLayout(m));
135+
LLVMTypeConverter converter(ctx, options);
136+
populateMathToROCDLConversionPatterns(converter, patterns);
137+
ConversionTarget target(getContext());
138+
target.addLegalDialect<BuiltinDialect, func::FuncDialect,
139+
vector::VectorDialect, LLVM::LLVMDialect>();
140+
// These LLVM ops are explicitly declared illegal even though the LLVM
// dialect as a whole is marked legal above.
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
141+
LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
142+
LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
143+
LLVM::SqrtOp>();
144+
if (failed(applyPartialConversion(m, target, std::move(patterns))))
145+
signalPassFailure();
146+
}

0 commit comments

Comments
 (0)