Skip to content

[MLIR][ROCDL] Refactor conversion of math operations to ROCDL calls to a separate pass #98653

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 17, 2024

Conversation

jsjodin
Copy link
Contributor

@jsjodin jsjodin commented Jul 12, 2024

This patch refactors the conversion of math operations to ROCDL library calls. This pass will also be used in flang to lower Fortran intrinsics/math functions for OpenMP target offloading codgen.

…o a separaate pass

This patch refactors the conversion of math operations to ROCDL library
calls. This pass will also be used in flang to lower Fortran intrinsics/math
functions for OpenMP target offloading codgen.
@llvmbot
Copy link
Member

llvmbot commented Jul 12, 2024

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Jan Leyonberg (jsjodin)

Changes

This patch refactors the conversion of math operations to ROCDL library calls. This pass will also be used in flang to lower Fortran intrinsics/math functions for OpenMP target offloading codgen.


Patch is 31.13 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/98653.diff

9 Files Affected:

  • (added) mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h (+26)
  • (modified) mlir/include/mlir/Conversion/Passes.h (+1)
  • (modified) mlir/include/mlir/Conversion/Passes.td (+18)
  • (modified) mlir/lib/Conversion/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt (+1)
  • (modified) mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp (+2-44)
  • (added) mlir/lib/Conversion/MathToROCDL/CMakeLists.txt (+23)
  • (added) mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp (+146)
  • (added) mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir (+435)
diff --git a/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
new file mode 100644
index 0000000000000..fa7a635568c7c
--- /dev/null
+++ b/mlir/include/mlir/Conversion/MathToROCDL/MathToROCDL.h
@@ -0,0 +1,26 @@
+//===- MathToROCDL.h - Utils to convert from the complex dialect --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+#define MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
+
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/IR/PatternMatch.h"
+#include <memory>
+
+namespace mlir {
+class Pass;
+
+#define GEN_PASS_DECL_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+
+/// Populate the given list with patterns that convert from Math to ROCDL calls.
+void populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+                                           RewritePatternSet &patterns);
+} // namespace mlir
+
+#endif // MLIR_CONVERSION_MATHTOROCDL_MATHTOROCDL_H_
diff --git a/mlir/include/mlir/Conversion/Passes.h b/mlir/include/mlir/Conversion/Passes.h
index 8c6f85d461aea..208f26489d6c3 100644
--- a/mlir/include/mlir/Conversion/Passes.h
+++ b/mlir/include/mlir/Conversion/Passes.h
@@ -46,6 +46,7 @@
 #include "mlir/Conversion/MathToFuncs/MathToFuncs.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
 #include "mlir/Conversion/MathToLibm/MathToLibm.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MathToSPIRV/MathToSPIRVPass.h"
 #include "mlir/Conversion/MemRefToEmitC/MemRefToEmitCPass.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
diff --git a/mlir/include/mlir/Conversion/Passes.td b/mlir/include/mlir/Conversion/Passes.td
index 560b088dbe5cd..64835b1b660b4 100644
--- a/mlir/include/mlir/Conversion/Passes.td
+++ b/mlir/include/mlir/Conversion/Passes.td
@@ -733,6 +733,24 @@ def ConvertMathToLLVMPass : Pass<"convert-math-to-llvm"> {
   ];
 }
 
+//===----------------------------------------------------------------------===//
+// MathToLibm
+//===----------------------------------------------------------------------===//
+
+def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
+  let summary = "Convert Math dialect to ROCDL library calls";
+  let description = [{
+    This pass converts supported Math ops to ROCDL library calls.
+  }];
+  let dependentDialects = [
+    "arith::ArithDialect",
+    "func::FuncDialect",
+    "math::MathDialect",
+    "ROCDL::ROCDLDialect",
+    "vector::VectorDialect",
+  ];
+}
+
 //===----------------------------------------------------------------------===//
 // MathToSPIRV
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/CMakeLists.txt b/mlir/lib/Conversion/CMakeLists.txt
index e107738a4c50c..80c8b84d9ae89 100644
--- a/mlir/lib/Conversion/CMakeLists.txt
+++ b/mlir/lib/Conversion/CMakeLists.txt
@@ -36,6 +36,7 @@ add_subdirectory(LLVMCommon)
 add_subdirectory(MathToFuncs)
 add_subdirectory(MathToLibm)
 add_subdirectory(MathToLLVM)
+add_subdirectory(MathToROCDL)
 add_subdirectory(MathToSPIRV)
 add_subdirectory(MemRefToEmitC)
 add_subdirectory(MemRefToLLVM)
diff --git a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
index 70707b5c3a049..945e3ccdfa87b 100644
--- a/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
+++ b/mlir/lib/Conversion/GPUToROCDL/CMakeLists.txt
@@ -13,6 +13,7 @@ add_mlir_conversion_library(MLIRGPUToROCDLTransforms
   MLIRArithToLLVM
   MLIRArithTransforms
   MLIRMathToLLVM
+  MLIRMathToROCDL
   MLIRAMDGPUToROCDL
   MLIRFuncToLLVM
   MLIRGPUDialect
diff --git a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
index 40eb15a491063..100181cdc69fe 100644
--- a/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
+++ b/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
 #include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -386,50 +387,7 @@ void mlir::populateGpuToROCDLConversionPatterns(
 
   patterns.add<GPUShuffleOpLowering, GPULaneIdOpToROCDL>(converter);
 
-  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
-                                   "__ocml_fabs_f64");
-  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
-                                   "__ocml_atan_f64");
-  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
-                                    "__ocml_atan2_f64");
-  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
-                                   "__ocml_cbrt_f64");
-  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
-                                   "__ocml_ceil_f64");
-  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
-                                  "__ocml_cos_f64");
-  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
-                                  "__ocml_exp_f64");
-  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
-                                   "__ocml_exp2_f64");
-  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
-                                    "__ocml_expm1_f64");
-  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
-                                    "__ocml_floor_f64");
-  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
-                                    "__ocml_fmod_f64");
-  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
-                                  "__ocml_log_f64");
-  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
-                                    "__ocml_log10_f64");
-  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
-                                    "__ocml_log1p_f64");
-  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
-                                   "__ocml_log2_f64");
-  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
-                                   "__ocml_pow_f64");
-  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
-                                    "__ocml_rsqrt_f64");
-  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
-                                  "__ocml_sin_f64");
-  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
-                                   "__ocml_sqrt_f64");
-  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
-                                   "__ocml_tanh_f64");
-  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
-                                  "__ocml_tan_f64");
-  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
-                                  "__ocml_erf_f64");
+  populateMathToROCDLConversionPatterns(converter, patterns);
 }
 
 std::unique_ptr<OperationPass<gpu::GPUModuleOp>>
diff --git a/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
new file mode 100644
index 0000000000000..2771955aa9493
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/CMakeLists.txt
@@ -0,0 +1,23 @@
+add_mlir_conversion_library(MLIRMathToROCDL
+  MathToROCDL.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/MathToROCDL
+
+  DEPENDS
+  MLIRConversionPassIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRDialectUtils
+  MLIRFuncDialect
+  MLIRGPUToGPURuntimeTransforms
+  MLIRMathDialect
+  MLIRLLVMCommonConversion
+  MLIRPass
+  MLIRTransformUtils
+  MLIRVectorDialect
+  MLIRVectorUtils
+  )
diff --git a/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
new file mode 100644
index 0000000000000..03c7ce5dac0d1
--- /dev/null
+++ b/mlir/lib/Conversion/MathToROCDL/MathToROCDL.cpp
@@ -0,0 +1,146 @@
+//===-- MathToROCDL.cpp - conversion from Math to rocdl calls -------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
+#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/BuiltinDialect.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include "../GPUCommon/GPUOpsLowering.h"
+#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
+#include "../GPUCommon/OpToFuncCallLowering.h"
+#include "mlir/Conversion/GPUCommon/GPUCommonPass.h"
+
+namespace mlir {
+#define GEN_PASS_DEF_CONVERTMATHTOROCDL
+#include "mlir/Conversion/Passes.h.inc"
+} // namespace mlir
+
+using namespace mlir;
+
+#define DEBUG_TYPE "math-to-rocdl"
+#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
+
+template <typename OpTy>
+static void populateOpPatterns(LLVMTypeConverter &converter,
+                               RewritePatternSet &patterns, StringRef f32Func,
+                               StringRef f64Func) {
+  patterns.add<ScalarizeVectorOpLowering<OpTy>>(converter);
+  patterns.add<OpToFuncCallLowering<OpTy>>(converter, f32Func, f64Func);
+}
+
+void mlir::populateMathToROCDLConversionPatterns(LLVMTypeConverter &converter,
+                                                 RewritePatternSet &patterns) {
+  // Handled by mathToLLVM: math::AbsIOp
+  // Handled by mathToLLVM: math::CopySignOp
+  // Handled by mathToLLVM: math::CountLeadingZerosOp
+  // Handled by mathToLLVM: math::CountTrailingZerosOp
+  // Handled by mathToLLVM: math::CgPopOp
+  // Handled by mathToLLVM: math::FmaOp
+  // FIXME: math::IPowIOp
+  // FIXME: math::FPowIOp
+  // Handled by mathToLLVM: math::RoundEvenOp
+  // Handled by mathToLLVM: math::RoundOp
+  // Handled by mathToLLVM: math::TruncOp
+  populateOpPatterns<math::AbsFOp>(converter, patterns, "__ocml_fabs_f32",
+                                   "__ocml_fabs_f64");
+  populateOpPatterns<math::AcosOp>(converter, patterns, "__ocml_acos_f32",
+                                   "__ocml_acos_f64");
+  populateOpPatterns<math::AcoshOp>(converter, patterns, "__ocml_acosh_f32",
+                                    "__ocml_acosh_f64");
+  populateOpPatterns<math::AsinOp>(converter, patterns, "__ocml_asin_f32",
+                                   "__ocml_asin_f64");
+  populateOpPatterns<math::AsinhOp>(converter, patterns, "__ocml_asinh_f32",
+                                    "__ocml_asinh_f64");
+  populateOpPatterns<math::AtanOp>(converter, patterns, "__ocml_atan_f32",
+                                   "__ocml_atan_f64");
+  populateOpPatterns<math::AtanhOp>(converter, patterns, "__ocml_atanh_f32",
+                                    "__ocml_atanh_f64");
+  populateOpPatterns<math::Atan2Op>(converter, patterns, "__ocml_atan2_f32",
+                                    "__ocml_atan2_f64");
+  populateOpPatterns<math::CbrtOp>(converter, patterns, "__ocml_cbrt_f32",
+                                   "__ocml_cbrt_f64");
+  populateOpPatterns<math::CeilOp>(converter, patterns, "__ocml_ceil_f32",
+                                   "__ocml_ceil_f64");
+  populateOpPatterns<math::CosOp>(converter, patterns, "__ocml_cos_f32",
+                                  "__ocml_cos_f64");
+  populateOpPatterns<math::CoshOp>(converter, patterns, "__ocml_cosh_f32",
+                                   "__ocml_cosh_f64");
+  populateOpPatterns<math::SinhOp>(converter, patterns, "__ocml_sinh_f32",
+                                   "__ocml_sinh_f64");
+  populateOpPatterns<math::ExpOp>(converter, patterns, "__ocml_exp_f32",
+                                  "__ocml_exp_f64");
+  populateOpPatterns<math::Exp2Op>(converter, patterns, "__ocml_exp2_f32",
+                                   "__ocml_exp2_f64");
+  populateOpPatterns<math::ExpM1Op>(converter, patterns, "__ocml_expm1_f32",
+                                    "__ocml_expm1_f64");
+  populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
+                                    "__ocml_floor_f64");
+  populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
+                                  "__ocml_log_f64");
+  populateOpPatterns<math::Log10Op>(converter, patterns, "__ocml_log10_f32",
+                                    "__ocml_log10_f64");
+  populateOpPatterns<math::Log1pOp>(converter, patterns, "__ocml_log1p_f32",
+                                    "__ocml_log1p_f64");
+  populateOpPatterns<math::Log2Op>(converter, patterns, "__ocml_log2_f32",
+                                   "__ocml_log2_f64");
+  populateOpPatterns<math::PowFOp>(converter, patterns, "__ocml_pow_f32",
+                                   "__ocml_pow_f64");
+  populateOpPatterns<math::RsqrtOp>(converter, patterns, "__ocml_rsqrt_f32",
+                                    "__ocml_rsqrt_f64");
+  populateOpPatterns<math::SinOp>(converter, patterns, "__ocml_sin_f32",
+                                  "__ocml_sin_f64");
+  populateOpPatterns<math::SqrtOp>(converter, patterns, "__ocml_sqrt_f32",
+                                   "__ocml_sqrt_f64");
+  populateOpPatterns<math::TanhOp>(converter, patterns, "__ocml_tanh_f32",
+                                   "__ocml_tanh_f64");
+  populateOpPatterns<math::TanOp>(converter, patterns, "__ocml_tan_f32",
+                                  "__ocml_tan_f64");
+  populateOpPatterns<math::ErfOp>(converter, patterns, "__ocml_erf_f32",
+                                  "__ocml_erf_f64");
+  // Single arith pattern that needs a ROCDL call, probably not
+  // worth creating a separate pass for it.
+  populateOpPatterns<arith::RemFOp>(converter, patterns, "__ocml_fmod_f32",
+                                    "__ocml_fmod_f64");
+}
+
+namespace {
+struct ConvertMathToROCDLPass
+    : public impl::ConvertMathToROCDLBase<ConvertMathToROCDLPass> {
+  ConvertMathToROCDLPass() = default;
+  void runOnOperation() override;
+};
+} // namespace
+
+void ConvertMathToROCDLPass::runOnOperation() {
+  auto m = getOperation();
+  MLIRContext *ctx = m.getContext();
+
+  RewritePatternSet patterns(&getContext());
+  LowerToLLVMOptions options(ctx, DataLayout(m));
+  LLVMTypeConverter converter(ctx, options);
+  populateMathToROCDLConversionPatterns(converter, patterns);
+  ConversionTarget target(getContext());
+  target.addLegalDialect<BuiltinDialect, func::FuncDialect,
+                         vector::VectorDialect, LLVM::LLVMDialect>();
+  target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FAbsOp,
+                      LLVM::FCeilOp, LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp,
+                      LLVM::Log10Op, LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp,
+                      LLVM::SqrtOp>();
+  if (failed(applyPartialConversion(m, target, std::move(patterns))))
+    signalPassFailure();
+}
diff --git a/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
new file mode 100644
index 0000000000000..a406ec45a7f10
--- /dev/null
+++ b/mlir/test/Conversion/MathToROCDL/math-to-rocdl.mlir
@@ -0,0 +1,435 @@
+// RUN: mlir-opt %s -convert-math-to-rocdl -split-input-file | FileCheck %s
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_fmod_f32(f32, f32) -> f32
+  // CHECK: llvm.func @__ocml_fmod_f64(f64, f64) -> f64
+  // CHECK-LABEL: func @arith_remf
+  func.func @arith_remf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = arith.remf %arg_f32, %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_fmod_f32(%{{.*}}, %{{.*}}) : (f32, f32) -> f32
+    %result64 = arith.remf %arg_f64, %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_fmod_f64(%{{.*}}, %{{.*}}) : (f64, f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_fabs_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_fabs_f64(f64) -> f64
+  // CHECK-LABEL: func @math_absf
+  func.func @math_absf(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.absf %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_fabs_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.absf %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_fabs_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_acos_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_acos_f64(f64) -> f64
+  // CHECK-LABEL: func @math_acos
+  func.func @math_acos(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acos %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_acos_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.acos %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_acos_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_acosh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_acosh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_acosh
+  func.func @math_acosh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.acosh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_acosh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.acosh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_acosh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_asin_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_asin_f64(f64) -> f64
+  // CHECK-LABEL: func @math_asin
+  func.func @math_asin(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asin %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_asin_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.asin %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_asin_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_asinh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_asinh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_asinh
+  func.func @math_asinh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.asinh %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_asinh_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.asinh %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_asinh_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_atan_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_atan_f64(f64) -> f64
+  // CHECK-LABEL: func @math_atan
+  func.func @math_atan(%arg_f32 : f32, %arg_f64 : f64) -> (f32, f64) {
+    %result32 = math.atan %arg_f32 : f32
+    // CHECK: llvm.call @__ocml_atan_f32(%{{.*}}) : (f32) -> f32
+    %result64 = math.atan %arg_f64 : f64
+    // CHECK: llvm.call @__ocml_atan_f64(%{{.*}}) : (f64) -> f64
+    func.return %result32, %result64 : f32, f64
+  }
+}
+
+// -----
+
+module @test_module {
+  // CHECK: llvm.func @__ocml_atanh_f32(f32) -> f32
+  // CHECK: llvm.func @__ocml_atanh_f64(f64) -> f64
+  // CHECK-LABEL: func @math_atanh
+  func.func @math_atanh(%arg_f32 : f32, %arg_f64 : f64) -> (f32, ...
[truncated]

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mainly nitpicking and generality, the refactor seems fine

let dependentDialects = [
"arith::ArithDialect",
"func::FuncDialect",
"math::MathDialect",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You don't need to depend on math if there are cases where you create a math op without there being one in the input.

(The point of the dependencies field is to let you access stuff from dialects that wouldn't "naturally" be loaded)

"__ocml_expm1_f64");
populateOpPatterns<math::FloorOp>(converter, patterns, "__ocml_floor_f32",
"__ocml_floor_f64");
populateOpPatterns<math::LogOp>(converter, patterns, "__ocml_log_f32",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While you're here (tm), the OCML math libraries now have native f16 implementations of these functions - could you add that case?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(IIRC "these" is all of them, I just commented here arbitrarily)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... ah, we're using the GPU pattern, might not be worth the bother

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably add those in a separate patch since this is only a refactoring change.

// MathToLibm
//===----------------------------------------------------------------------===//

def ConvertMathToROCDL : Pass<"convert-math-to-rocdl", "ModuleOp"> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This pass doesn't need to operate on a ModuleOp specifically - you probably want this to work on anything that has a symbol table

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It needs to have a DataLayoutOpInterface as well. Should this be checked dynamically, so that the pass can be used on any operation but should only be triggered on ops with a symbol table and data layout interface?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, yeah, good point

(I'm basically trying to make sure this works on gpu.module and builtin.module)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I was thinking about making it work for both as well, but opted to use the "add patterns" function in the GPUToROCDL conversion since the options and converter are more complicated, and I didn't want to put that in this simpler pass, so it seemed okay for it to operate on the ModuleOp. I can try to see if I can make it work with both if you need it.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

... Good point, this pass is for testing anyway, let's keep it at ModuleOp

Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jsjodin jsjodin merged commit 3fae555 into llvm:main Jul 17, 2024
7 checks passed
keith added a commit to keith/llvm-project that referenced this pull request Jul 17, 2024
keith added a commit that referenced this pull request Jul 17, 2024
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
…o a separate pass (#98653)

Summary:
This patch refactors the conversion of math operations to ROCDL library
calls. This pass will also be used in flang to lower Fortran
intrinsics/math functions for OpenMP target offloading codgen.

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60250922
yuxuanchen1997 pushed a commit that referenced this pull request Jul 25, 2024
Summary: 

Test Plan: 

Reviewers: 

Subscribers: 

Tasks: 

Tags: 


Differential Revision: https://phabricator.intern.facebook.com/D60251611
@jsjodin jsjodin deleted the jleyonberg/mathtorocdl branch October 1, 2024 14:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants