[MLIR][ROCDL] Refactor conversion of math operations to ROCDL calls to a separate pass #98653
…o a separate pass
This patch refactors the conversion of math operations to ROCDL library calls. This pass will also be used in flang to lower Fortran intrinsics/math functions for OpenMP target offloading codegen.
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir
Author: Jan Leyonberg (jsjodin)
Changes
This patch refactors the conversion of math operations to ROCDL library calls. This pass will also be used in flang to lower Fortran intrinsics/math functions for OpenMP target offloading codegen.
Patch is 31.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98653.diff
9 Files Affected:
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
new file mode 100644
index 0000000000000..fa7a635568c7c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -0,0 +1,26 @@
+//===- MathToROCDL.h - Utils to convert from the math dialect -----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to ROCDL calls.
+void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8c6f85d461aea..208f26489d6c3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -46,6 +46,7 @@
#include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
#include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..64835b1b660b4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -733,6 +733,24 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
];
}
+//===----------------------------------------------------------------------===//
+// MathToROCDL
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
+ let summary = "Convert Math dialect to ROCDL library calls";
+ let description = [{
+ This pass converts supported Math ops to ROCDL library calls.
+ }];
+ let dependentDialects = [
+ "arith::ArithDialect",
+ "func::FuncDialect",
+ "math::MathDialect",
+ "ROCDL::ROCDLDialect",
+ "vector::VectorDialect",
+ ];
+}
+
//===----------------------------------------------------------------------===//
// MathToSPIRV
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e107738a4c50c..80c8b84d9ae89 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -36,6 +36,7 @@ add_subdirectory(LLVMCommon)
add_subdirectory(MathToFuncs)
add_subdirectory(MathToLibm)
add_subdirectory(MathToLLVM)
+add_subdirectory(MathToROCDL)
add_subdirectory(MathToSPIRV)
add_subdirectory(MemRefToEmitC)
add_subdirectory(MemRefToLLVM)
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 70707b5c3a049..945e3ccdfa87b 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
MLIRArithToLLVM
MLIRArithTransforms
MLIRMathToLLVM
+ MLIRMathToROCDL
MLIRAMDGPUToROCDL
MLIRFuncToLLVM
MLIRGPUDialect
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 40eb15a491063..100181cdc69fe 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -26,6 +26,7 @@
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -386,50 +387,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
- populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
- "__ocml_fabs_f64");
- populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
- "__ocml_atan_f64");
- populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
- "__ocml_atan2_f64");
- populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
- "__ocml_cbrt_f64");
- populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
- "__ocml_ceil_f64");
- populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
- "__ocml_cos_f64");
- populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
- "__ocml_exp_f64");
- populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
- "__ocml_exp2_f64");
- populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
- "__ocml_expm1_f64");
- populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
- "__ocml_floor_f64");
- populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
- "__ocml_fmod_f64");
- populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
- "__ocml_log_f64");
- populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
- "__ocml_log10_f64");
- populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
- "__ocml_log1p_f64");
- populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
- "__ocml_log2_f64");
- populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
- "__ocml_pow_f64");
- populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
- "__ocml_rsqrt_f64");
- populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
- "__ocml_sin_f64");
- populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
- "__ocml_sqrt_f64");
- populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
- "__ocml_tanh_f64");
- populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
- "__ocml_tan_f64");
- populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
- "__ocml_erf_f64");
+ populateMathToROCDLConversionPatterns(converter, patterns);
}
std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..2771955aa9493
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_conversion_library(MLIRMathToROCDL
+ MathToROCDL.cpp
+
+ ADDITIONAL_HEADER_DIRS
+ ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL
+
+ DEPENDS
+ MLIRConversionPassIncGen
+
+ LINK_COMPONENTS
+ Core
+
+ LINK_LIBS PUBLIC
+ MLIRDialectUtils
+ MLIRFuncDialect
+ MLIRGPUToGPURuntimeTransforms
+ MLIRMathDialect
+ MLIRLLVMCommonConversion
+ MLIRPass
+ MLIRTransformUtils
+ MLIRVectorDialect
+ MLIRVectorUtils
+ )
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
new file mode 100644
index 0000000000000..03c7ce5dac0d1
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -0,0 +1,146 @@
+//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-rocdl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns, StringRef f32Func,
+ StringRef f64Func) {
+ patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+ patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
+void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
+ // Handled by mathToLLVM: math::AbsIOp
+ // Handled by mathToLLVM: math::CopySignOp
+ // Handled by mathToLLVM: math::CountLeadingZerosOp
+ // Handled by mathToLLVM: math::CountTrailingZerosOp
+ // Handled by mathToLLVM: math::CtPopOp
+ // Handled by mathToLLVM: math::FmaOp
+ // FIXME: math::IPowIOp
+ // FIXME: math::FPowIOp
+ // Handled by mathToLLVM: math::RoundEvenOp
+ // Handled by mathToLLVM: math::RoundOp
+ // Handled by mathToLLVM: math::TruncOp
+ populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
+ "__ocml_fabs_f64");
+ populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
+ "__ocml_acos_f64");
+ populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
+ "__ocml_acosh_f64");
+ populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
+ "__ocml_asin_f64");
+ populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
+ "__ocml_asinh_f64");
+ populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
+ "__ocml_atan_f64");
+ populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
+ "__ocml_atanh_f64");
+ populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
+ "__ocml_atan2_f64");
+ populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
+ "__ocml_cbrt_f64");
+ populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
+ "__ocml_ceil_f64");
+ populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
+ "__ocml_cos_f64");
+ populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
+ "__ocml_cosh_f64");
+ populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
+ "__ocml_sinh_f64");
+ populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
+ "__ocml_exp_f64");
+ populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
+ "__ocml_exp2_f64");
+ populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
+ "__ocml_expm1_f64");
+ populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
+ "__ocml_floor_f64");
+ populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
+ "__ocml_log_f64");
+ populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
+ "__ocml_log10_f64");
+ populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
+ "__ocml_log1p_f64");
+ populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
+ "__ocml_log2_f64");
+ populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
+ "__ocml_pow_f64");
+ populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
+ "__ocml_rsqrt_f64");
+ populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
+ "__ocml_sin_f64");
+ populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
+ "__ocml_sqrt_f64");
+ populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
+ "__ocml_tanh_f64");
+ populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
+ "__ocml_tan_f64");
+ populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
+ "__ocml_erf_f64");
+ // Single arith pattern that needs a ROCDL call, probably not
+ // worth creating a separate pass for it.
+ populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
+ "__ocml_fmod_f64");
+}
+
+namespace {
+struct ConvertMathToROCDLPass
+ : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+ ConvertMathToROCDLPass() = default;
+ void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToROCDLPass::runOnOperation() {
+ auto m = getOperation();
+ MLIRContext *ctx = m.getContext();
+
+ RewritePatternSet patterns(&getContext());
+ LowerToLLVMOptions options(ctx, DataLayout(m));
+ LLVMTypeConverter converter(ctx, options);
+ populateMathToROCDLConversionPatterns(converter, patterns);
+ ConversionTarget target(getContext());
+ target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+ vector::VectorDialect, LLVM::LLVMDialect>();
+ target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+ LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
+ LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
+ LLVM::SqrtOp>();
+ if (failed(applyPartialConversion(m, target, std::move(patterns))))
+ signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
new file mode 100644
index 0000000000000..a406ec45a7f10
--- /dev/null
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -0,0 +1,435 @@
+// RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
+ // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
+ // CHECK-LABEL: func @arith_remf
+ func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = arith.remf %arg_f32, %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+ %result64 = arith.remf %arg_f64, %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
+ // CHECK-LABEL: func @math_absf
+ func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.absf %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.absf %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
+ // CHECK-LABEL: func @math_acos
+ func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acos %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.acos %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_acosh
+ func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.acosh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.acosh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
+ // CHECK-LABEL: func @math_asin
+ func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asin %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.asin %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_asinh
+ func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.asinh %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.asinh %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
+ // CHECK-LABEL: func @math_atan
+ func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+ %result32 = math.atan %arg_f32 : f32
+ // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
+ %result64 = math.atan %arg_f64 : f64
+ // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
+ func.return %result32, %result64 : f32, f64
+ }
+}
+
+// -----
+
+module @test_module {
+ // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
+ // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
+ // CHECK-LABEL: func @math_atanh
+ func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, ...
[truncated]
Mainly nitpicking and generality, the refactor seems fine
let dependentDialects = [
"arith::ArithDialect",
"func::FuncDialect",
"math::MathDialect",
You don't need to depend on math unless there are cases where you create a math op without there being one in the input.
(The point of the dependencies field is to let you access stuff from dialects that wouldn't "naturally" be loaded)
"__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
"__ocml_floor_f64");
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
While you're here (tm), the OCML math libraries now have native f16 implementations of these functions - could you add that case?
(IIRC "these" is all of them, I just commented here arbitrarily)
... ah, we're using the GPU pattern, might not be worth the bother
I'd probably add those in a separate patch since this is only a refactoring change.
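For reference, a minimal standalone sketch of what per-width name selection might look like if an f16 case were added in that follow-up. It is not code from this patch: `ocmlSymbolFor` is a hypothetical helper, and the `_f16` suffix is an assumption made by analogy with the `_f32`/`_f64` names above, so it would need checking against the OCML library before use.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical helper, not part of the patch: build the OCML symbol name for
// a math function `base` (e.g. "exp") at a given floating-point bit width.
// The "_f16" suffix is an assumption by analogy with the "_f32"/"_f64"
// names that the patch passes to populateOpPatterns.
std::optional<std::string> ocmlSymbolFor(const std::string &base,
                                         unsigned bitWidth) {
  switch (bitWidth) {
  case 16:
  case 32:
  case 64:
    return "__ocml_" + base + "_f" + std::to_string(bitWidth);
  default:
    // No OCML variant is assumed for other widths; a caller would leave
    // such an op unconverted.
    return std::nullopt;
  }
}

// Usage example: the f32/f64 results match the strings used in the patch.
inline void ocmlSymbolExamples() {
  assert(ocmlSymbolFor("exp", 32) == "__ocml_exp_f32");
  assert(ocmlSymbolFor("fabs", 64) == "__ocml_fabs_f64");
  assert(!ocmlSymbolFor("sin", 8).has_value());
}
```

In the actual lowering the names are passed as string literals per op, so an f16 case would more likely be an extra argument to the existing pattern than a helper like this; the sketch only illustrates the naming scheme.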
// MathToROCDL
//===----------------------------------------------------------------------===//

def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
This pass doesn't need to operate on a ModuleOp specifically - you probably want this to work on anything that has a symbol table
It needs to have a DataLayoutOpInterface as well. Should this be checked dynamically, so that the pass can be used on any operation but should only be triggered on ops with a symbol table and data layout interface?
Hm, yeah, good point
(I'm basically trying to make sure this works on gpu.module and builtin.module)
Yes, I was thinking about making it work for both as well, but opted to use the "add patterns" function in the GPUToROCDL conversion since the options and converter are more complicated, and I didn't want to put that in this simpler pass, so it seemed okay for it to operate on the ModuleOp. I can try to see if I can make it work with both if you need it.
... Good point, this pass is for testing anyway, let's keep it at ModuleOp
LGTM
…o a separate pass (#98653)
Summary: This patch refactors the conversion of math operations to ROCDL library calls. This pass will also be used in flang to lower Fortran intrinsics/math functions for OpenMP target offloading codegen.
Differential Revision: https://phabricator.intern.facebook.com/D60250922