-
Notifications
You must be signed in to change notification settings - Fork 13.6k
[mlir][spirv] Add a generic convert-to-spirv
pass
#95942
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
//===- ConvertToSPIRVPass.h - Conversion to SPIR-V pass ---*- C++ -*-=========// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
|
||
#ifndef MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H | ||
#define MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H | ||
|
||
#include <memory> | ||
kuhar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
namespace mlir { | ||
class Pass; | ||
|
||
#define GEN_PASS_DECL_CONVERTTOSPIRVPASS | ||
#include "mlir/Conversion/Passes.h.inc" | ||
|
||
} // namespace mlir | ||
|
||
#endif // MLIR_CONVERSION_CONVERTTOSPIRV_CONVERTTOSPIRVPASS_H |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -31,6 +31,18 @@ def ConvertToLLVMPass : Pass<"convert-to-llvm"> { | |
]; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// ToSPIRV | ||
//===----------------------------------------------------------------------===// | ||
|
||
def ConvertToSPIRVPass : Pass<"convert-to-spirv"> { | ||
let summary = "Convert to SPIR-V"; | ||
let description = [{ | ||
This is a generic pass to convert to SPIR-V. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This claims to be a "generic" pass but it seems to me instead to be a "monolithic" pass. Why didn't we align this on the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @angelz913 talked about this in https://youtu.be/-qoMMrlYvGs?t=436 (starts around the 7m 15s mark). The TL;DR is we didn't know what interfaces to have, and decided to start with a monolithic multi-stage v0 implementation, with the plan to add make it interface-based when gain confidence in this design. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I don't quite understand? Why isn't it just copy the convert-to-llvm one and rename it? I have strong concerns with in-tree monolithic passes like this: what is the timeline to remove this and migrate to a pluggable one? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree that in-tree monolithic passes arent great, but the change to use interfaces and go to more LLVM based approach is more involved (I dont think it is as simple as "copy the convert-to-llvm and rename it". Angel and Jakub can fill in more details here). I think timeline will depend on how much community involvement we get here. Jakub and Angel are trying to get upstream support flushed out more. So this is a strict improvement anyway. Its probably better to iterate on this in a bit. FWIW, for conversion to LLVM in IREE we just throw all the conversion patterns into one pass and run them together. I personally find the split of conversion from each dialect to LLVM kind of artificial. Everything needs to be translated to LLVM. Running multiple passes that walk the IR multiple times seems like a waste. But that is a downstream decision and not having monoliths in upstream MLIR is useful. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
In this case I would want to see this pass moved to the test folder: I'm not comfortable with monolithic passes alongside the rest of the transformations right now.
I suspect you missed the intent of upstream design: these individual passes are all made that way upstream for testing, the intention has always been that downstream projects create a monolithic pass for their own purpose and don't use the upstream passes as-is (which is why the populatePatterns method are exposed): that's exactly "work as intended". There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Why would it have to be moved to the runner? The current pattern is that the transformations are made with
I think we're talking about different things maybe? That is, from a very high-level, the specific runner tool goes away and you do instead something like Note that in this model, the test pass is available for the opt tool, the runner does not need to know about the test passes (or any pass). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the clarification, Mehdi.
Right, that's a good way to put it. From my point of view, what I mostly care about is that we do keep these tests (as in, both the implementation of convert-to-spirv and the e2e .mlir tests) while the cleanup in this area (both gpu-to-spirv and the runner implementation). I think we all aggrege on the end state, so I just want to make sure that we can find a path that allows us to make incremental improvements to get there. From your perspective, would it be an option to make this a test pass and temporarily have it registered it in the vulkan runner (similar to how they are manually listed in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
The problem is that we optionally compile the test passes I think? (With It may be easier to just keep the pass as-is until we migrate tests from mlir-vulkan-runner to mlir-cpu-runner? The only thing that makes be nervous is that unless there is a timeline with people making progress on it, we'll still be adding mlir-vulkan-runner based tests multiple years from now. What I'm trying to see is some sort of a "gradient" on the progression of all this (and some timeline), rather than a quick immediate solution. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, let me check internally if we can commit to something more concrete along this axis. Overall, I think we did make a lot of recent progress on the spirv conversion test coverage and general cleanup in this area, but I understand why you'd want to prioritize the runner cleanup next. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @joker-eph I checked and we will have @andfau-amd work on the mlir runner migration, starting from ~next week. |
||
}]; | ||
let dependentDialects = ["spirv::SPIRVDialect"]; | ||
} | ||
|
||
//===----------------------------------------------------------------------===// | ||
// AffineToStandard | ||
//===----------------------------------------------------------------------===// | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
set(LLVM_OPTIONAL_SOURCES | ||
ConvertToSPIRVPass.cpp | ||
) | ||
|
||
add_mlir_conversion_library(MLIRConvertToSPIRVPass | ||
ConvertToSPIRVPass.cpp | ||
|
||
ADDITIONAL_HEADER_DIRS | ||
${MLIR_MAIN_INCLUDE_DIR}/mlir/Conversion/ConvertToSPIRV | ||
|
||
DEPENDS | ||
MLIRConversionPassIncGen | ||
|
||
LINK_LIBS PUBLIC | ||
MLIRIR | ||
MLIRPass | ||
MLIRRewrite | ||
MLIRSPIRVConversion | ||
MLIRSPIRVDialect | ||
MLIRSupport | ||
MLIRTransformUtils | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
//===- ConvertToSPIRVPass.cpp - MLIR SPIR-V Conversion --------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===----------------------------------------------------------------------===// | ||
kuhar marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
#include "mlir/Conversion/ConvertToSPIRV/ConvertToSPIRVPass.h" | ||
#include "mlir/Conversion/ArithToSPIRV/ArithToSPIRV.h" | ||
#include "mlir/Conversion/FuncToSPIRV/FuncToSPIRV.h" | ||
#include "mlir/Conversion/IndexToSPIRV/IndexToSPIRV.h" | ||
#include "mlir/Conversion/SCFToSPIRV/SCFToSPIRV.h" | ||
#include "mlir/Conversion/UBToSPIRV/UBToSPIRV.h" | ||
#include "mlir/Conversion/VectorToSPIRV/VectorToSPIRV.h" | ||
#include "mlir/Dialect/Arith/Transforms/Passes.h" | ||
#include "mlir/Dialect/SPIRV/IR/SPIRVAttributes.h" | ||
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" | ||
#include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" | ||
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h" | ||
#include "mlir/IR/PatternMatch.h" | ||
#include "mlir/Pass/Pass.h" | ||
#include "mlir/Rewrite/FrozenRewritePatternSet.h" | ||
#include "mlir/Transforms/DialectConversion.h" | ||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
#include <memory> | ||
|
||
#define DEBUG_TYPE "convert-to-spirv" | ||
|
||
namespace mlir { | ||
#define GEN_PASS_DEF_CONVERTTOSPIRVPASS | ||
#include "mlir/Conversion/Passes.h.inc" | ||
} // namespace mlir | ||
|
||
using namespace mlir; | ||
|
||
namespace { | ||
|
||
/// A pass to perform the SPIR-V conversion. | ||
struct ConvertToSPIRVPass final | ||
: impl::ConvertToSPIRVPassBase<ConvertToSPIRVPass> { | ||
|
||
void runOnOperation() override { | ||
MLIRContext *context = &getContext(); | ||
Operation *op = getOperation(); | ||
|
||
spirv::TargetEnvAttr targetAttr = spirv::lookupTargetEnvOrDefault(op); | ||
SPIRVTypeConverter typeConverter(targetAttr); | ||
|
||
RewritePatternSet patterns(context); | ||
ScfToSPIRVContext scfToSPIRVContext; | ||
|
||
// Populate patterns. | ||
arith::populateCeilFloorDivExpandOpsPatterns(patterns); | ||
arith::populateArithToSPIRVPatterns(typeConverter, patterns); | ||
populateBuiltinFuncToSPIRVPatterns(typeConverter, patterns); | ||
populateFuncToSPIRVPatterns(typeConverter, patterns); | ||
index::populateIndexToSPIRVPatterns(typeConverter, patterns); | ||
populateVectorToSPIRVPatterns(typeConverter, patterns); | ||
populateSCFToSPIRVPatterns(typeConverter, scfToSPIRVContext, patterns); | ||
ub::populateUBToSPIRVConversionPatterns(typeConverter, patterns); | ||
|
||
std::unique_ptr<ConversionTarget> target = | ||
SPIRVConversionTarget::get(targetAttr); | ||
|
||
if (failed(applyPartialConversion(op, *target, std::move(patterns)))) | ||
return signalPassFailure(); | ||
} | ||
}; | ||
|
||
} // namespace |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,218 @@ | ||
// RUN: mlir-opt -convert-to-spirv -split-input-file %s | FileCheck %s | ||
|
||
//===----------------------------------------------------------------------===// | ||
// arithmetic ops | ||
//===----------------------------------------------------------------------===// | ||
|
||
// CHECK-LABEL: @int32_scalar | ||
func.func @int32_scalar(%lhs: i32, %rhs: i32) { | ||
// CHECK: spirv.IAdd %{{.*}}, %{{.*}}: i32 | ||
%0 = arith.addi %lhs, %rhs: i32 | ||
// CHECK: spirv.ISub %{{.*}}, %{{.*}}: i32 | ||
%1 = arith.subi %lhs, %rhs: i32 | ||
// CHECK: spirv.IMul %{{.*}}, %{{.*}}: i32 | ||
%2 = arith.muli %lhs, %rhs: i32 | ||
// CHECK: spirv.SDiv %{{.*}}, %{{.*}}: i32 | ||
%3 = arith.divsi %lhs, %rhs: i32 | ||
// CHECK: spirv.UDiv %{{.*}}, %{{.*}}: i32 | ||
%4 = arith.divui %lhs, %rhs: i32 | ||
// CHECK: spirv.UMod %{{.*}}, %{{.*}}: i32 | ||
%5 = arith.remui %lhs, %rhs: i32 | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @int32_scalar_srem | ||
// CHECK-SAME: (%[[LHS:.+]]: i32, %[[RHS:.+]]: i32) | ||
func.func @int32_scalar_srem(%lhs: i32, %rhs: i32) { | ||
// CHECK: %[[LABS:.+]] = spirv.GL.SAbs %[[LHS]] : i32 | ||
// CHECK: %[[RABS:.+]] = spirv.GL.SAbs %[[RHS]] : i32 | ||
// CHECK: %[[ABS:.+]] = spirv.UMod %[[LABS]], %[[RABS]] : i32 | ||
// CHECK: %[[POS:.+]] = spirv.IEqual %[[LHS]], %[[LABS]] : i32 | ||
// CHECK: %[[NEG:.+]] = spirv.SNegate %[[ABS]] : i32 | ||
// CHECK: %{{.+}} = spirv.Select %[[POS]], %[[ABS]], %[[NEG]] : i1, i32 | ||
%0 = arith.remsi %lhs, %rhs: i32 | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
//===----------------------------------------------------------------------===// | ||
// arith bit ops | ||
//===----------------------------------------------------------------------===// | ||
|
||
// CHECK-LABEL: @bitwise_scalar | ||
func.func @bitwise_scalar(%arg0 : i32, %arg1 : i32) { | ||
// CHECK: spirv.BitwiseAnd | ||
%0 = arith.andi %arg0, %arg1 : i32 | ||
// CHECK: spirv.BitwiseOr | ||
%1 = arith.ori %arg0, %arg1 : i32 | ||
// CHECK: spirv.BitwiseXor | ||
%2 = arith.xori %arg0, %arg1 : i32 | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @bitwise_vector | ||
func.func @bitwise_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { | ||
// CHECK: spirv.BitwiseAnd | ||
%0 = arith.andi %arg0, %arg1 : vector<4xi32> | ||
// CHECK: spirv.BitwiseOr | ||
%1 = arith.ori %arg0, %arg1 : vector<4xi32> | ||
// CHECK: spirv.BitwiseXor | ||
%2 = arith.xori %arg0, %arg1 : vector<4xi32> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @logical_scalar | ||
func.func @logical_scalar(%arg0 : i1, %arg1 : i1) { | ||
// CHECK: spirv.LogicalAnd | ||
%0 = arith.andi %arg0, %arg1 : i1 | ||
// CHECK: spirv.LogicalOr | ||
%1 = arith.ori %arg0, %arg1 : i1 | ||
// CHECK: spirv.LogicalNotEqual | ||
%2 = arith.xori %arg0, %arg1 : i1 | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @logical_vector | ||
func.func @logical_vector(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { | ||
// CHECK: spirv.LogicalAnd | ||
%0 = arith.andi %arg0, %arg1 : vector<4xi1> | ||
// CHECK: spirv.LogicalOr | ||
%1 = arith.ori %arg0, %arg1 : vector<4xi1> | ||
// CHECK: spirv.LogicalNotEqual | ||
%2 = arith.xori %arg0, %arg1 : vector<4xi1> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @shift_scalar | ||
func.func @shift_scalar(%arg0 : i32, %arg1 : i32) { | ||
// CHECK: spirv.ShiftLeftLogical | ||
%0 = arith.shli %arg0, %arg1 : i32 | ||
// CHECK: spirv.ShiftRightArithmetic | ||
%1 = arith.shrsi %arg0, %arg1 : i32 | ||
// CHECK: spirv.ShiftRightLogical | ||
%2 = arith.shrui %arg0, %arg1 : i32 | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @shift_vector | ||
func.func @shift_vector(%arg0 : vector<4xi32>, %arg1 : vector<4xi32>) { | ||
// CHECK: spirv.ShiftLeftLogical | ||
%0 = arith.shli %arg0, %arg1 : vector<4xi32> | ||
// CHECK: spirv.ShiftRightArithmetic | ||
%1 = arith.shrsi %arg0, %arg1 : vector<4xi32> | ||
// CHECK: spirv.ShiftRightLogical | ||
%2 = arith.shrui %arg0, %arg1 : vector<4xi32> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
//===----------------------------------------------------------------------===// | ||
// arith.cmpf | ||
//===----------------------------------------------------------------------===// | ||
|
||
// CHECK-LABEL: @cmpf | ||
func.func @cmpf(%arg0 : f32, %arg1 : f32) { | ||
// CHECK: spirv.FOrdEqual | ||
%1 = arith.cmpf oeq, %arg0, %arg1 : f32 | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @vec1cmpf | ||
func.func @vec1cmpf(%arg0 : vector<1xf32>, %arg1 : vector<1xf32>) { | ||
// CHECK: spirv.FOrdGreaterThan | ||
%0 = arith.cmpf ogt, %arg0, %arg1 : vector<1xf32> | ||
// CHECK: spirv.FUnordLessThan | ||
%1 = arith.cmpf ult, %arg0, %arg1 : vector<1xf32> | ||
return | ||
} | ||
|
||
// ----- | ||
|
||
//===----------------------------------------------------------------------===// | ||
// arith.cmpi | ||
//===----------------------------------------------------------------------===// | ||
|
||
// CHECK-LABEL: @cmpi | ||
func.func @cmpi(%arg0 : i32, %arg1 : i32) { | ||
// CHECK: spirv.IEqual | ||
%0 = arith.cmpi eq, %arg0, %arg1 : i32 | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @indexcmpi | ||
func.func @indexcmpi(%arg0 : index, %arg1 : index) { | ||
// CHECK: spirv.IEqual | ||
%0 = arith.cmpi eq, %arg0, %arg1 : index | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @vec1cmpi | ||
func.func @vec1cmpi(%arg0 : vector<1xi32>, %arg1 : vector<1xi32>) { | ||
// CHECK: spirv.ULessThan | ||
%0 = arith.cmpi ult, %arg0, %arg1 : vector<1xi32> | ||
// CHECK: spirv.SGreaterThan | ||
%1 = arith.cmpi sgt, %arg0, %arg1 : vector<1xi32> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @boolcmpi_equality | ||
func.func @boolcmpi_equality(%arg0 : i1, %arg1 : i1) { | ||
// CHECK: spirv.LogicalEqual | ||
%0 = arith.cmpi eq, %arg0, %arg1 : i1 | ||
// CHECK: spirv.LogicalNotEqual | ||
%1 = arith.cmpi ne, %arg0, %arg1 : i1 | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @boolcmpi_unsigned | ||
func.func @boolcmpi_unsigned(%arg0 : i1, %arg1 : i1) { | ||
// CHECK-COUNT-2: spirv.Select | ||
// CHECK: spirv.UGreaterThanEqual | ||
%0 = arith.cmpi uge, %arg0, %arg1 : i1 | ||
// CHECK-COUNT-2: spirv.Select | ||
// CHECK: spirv.ULessThan | ||
%1 = arith.cmpi ult, %arg0, %arg1 : i1 | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @vec1boolcmpi_equality | ||
func.func @vec1boolcmpi_equality(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) { | ||
// CHECK: spirv.LogicalEqual | ||
%0 = arith.cmpi eq, %arg0, %arg1 : vector<1xi1> | ||
// CHECK: spirv.LogicalNotEqual | ||
%1 = arith.cmpi ne, %arg0, %arg1 : vector<1xi1> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @vec1boolcmpi_unsigned | ||
func.func @vec1boolcmpi_unsigned(%arg0 : vector<1xi1>, %arg1 : vector<1xi1>) { | ||
// CHECK-COUNT-2: spirv.Select | ||
// CHECK: spirv.UGreaterThanEqual | ||
%0 = arith.cmpi uge, %arg0, %arg1 : vector<1xi1> | ||
// CHECK-COUNT-2: spirv.Select | ||
// CHECK: spirv.ULessThan | ||
%1 = arith.cmpi ult, %arg0, %arg1 : vector<1xi1> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @vecboolcmpi_equality | ||
func.func @vecboolcmpi_equality(%arg0 : vector<4xi1>, %arg1 : vector<4xi1>) { | ||
// CHECK: spirv.LogicalEqual | ||
%0 = arith.cmpi eq, %arg0, %arg1 : vector<4xi1> | ||
// CHECK: spirv.LogicalNotEqual | ||
%1 = arith.cmpi ne, %arg0, %arg1 : vector<4xi1> | ||
return | ||
} | ||
|
||
// CHECK-LABEL: @vecboolcmpi_unsigned | ||
func.func @vecboolcmpi_unsigned(%arg0 : vector<3xi1>, %arg1 : vector<3xi1>) { | ||
// CHECK-COUNT-2: spirv.Select | ||
// CHECK: spirv.UGreaterThanEqual | ||
%0 = arith.cmpi uge, %arg0, %arg1 : vector<3xi1> | ||
// CHECK-COUNT-2: spirv.Select | ||
// CHECK: spirv.ULessThan | ||
%1 = arith.cmpi ult, %arg0, %arg1 : vector<3xi1> | ||
return | ||
} |
Uh oh!
There was an error while loading. Please reload this page.