Skip to content

[mlir][Vector] Add fastmath flags to vector.reduction #66905

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 8 additions & 26 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMIntrinsicOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -654,32 +654,14 @@ class LLVM_VecReductionI<string mnem>
// LLVM vector reduction over a single vector, with an initial value,
// and with permission to reassociate the reduction operations.
class LLVM_VecReductionAccBase<string mnem, Type element>
: LLVM_OneResultIntrOp<"vector.reduce." # mnem, [], [0],
[Pure, SameOperandsAndResultElementType]>,
Arguments<(ins element:$start_value, LLVM_VectorOf<element>:$input,
DefaultValuedAttr<BoolAttr, "false">:$reassoc)> {
let llvmBuilder = [{
llvm::Module *module = builder.GetInsertBlock()->getModule();
llvm::Function *fn = llvm::Intrinsic::getDeclaration(
module,
llvm::Intrinsic::vector_reduce_}] # mnem # [{,
{ }] # !interleave(ListIntSubst<LLVM_IntrPatterns.operand, [1]>.lst,
", ") # [{
});
auto operands = moduleTranslation.lookupValues(opInst.getOperands());
llvm::FastMathFlags origFM = builder.getFastMathFlags();
llvm::FastMathFlags tempFM = origFM;
tempFM.setAllowReassoc($reassoc);
builder.setFastMathFlags(tempFM); // set fastmath flag
$res = builder.CreateCall(fn, operands);
builder.setFastMathFlags(origFM); // restore fastmath flag
}];
let mlirBuilder = [{
bool allowReassoc = inst->getFastMathFlags().allowReassoc();
$res = $_builder.create<$_qualCppClassName>($_location,
$_resultType, $start_value, $input, allowReassoc);
}];
}
: LLVM_OneResultIntrOp</*mnem=*/"vector.reduce." # mnem,
/*overloadedResults=*/[],
/*overloadedOperands=*/[1],
/*traits=*/[Pure, SameOperandsAndResultElementType],
/*equiresFastmath=*/1>,
Arguments<(ins element:$start_value,
LLVM_VectorOf<element>:$input,
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags)>;

class LLVM_VecReductionAccF<string mnem>
: LLVM_VecReductionAccBase<mnem, AnyFloat>;
Expand Down
24 changes: 16 additions & 8 deletions mlir/include/mlir/Dialect/Vector/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
add_mlir_dialect(VectorOps vector)
add_mlir_doc(VectorOps VectorOps Dialects/ -gen-op-doc)
add_mlir_dialect(Vector vector)
add_mlir_doc(Vector Vector Dialects/ -gen-op-doc -dialect=vector)

# Add Vector operations
set(LLVM_TARGET_DEFINITIONS VectorOps.td)
mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorOpsAttrDefs.h.inc -gen-attrdef-decls)
mlir_tablegen(VectorOpsAttrDefs.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRVectorOpsEnumsIncGen)
add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen)
mlir_tablegen(VectorOps.h.inc -gen-op-decls)
mlir_tablegen(VectorOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRVectorOpsIncGen)
add_dependencies(mlir-generic-headers MLIRVectorOpsIncGen)

# Add Vector attributes
set(LLVM_TARGET_DEFINITIONS VectorAttributes.td)
mlir_tablegen(VectorEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorEnums.cpp.inc -gen-enum-defs)
mlir_tablegen(VectorAttributes.h.inc -gen-attrdef-decls)
mlir_tablegen(VectorAttributes.cpp.inc -gen-attrdef-defs)
add_public_tablegen_target(MLIRVectorAttributesIncGen)
add_dependencies(mlir-generic-headers MLIRVectorAttributesIncGen)
31 changes: 31 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/Vector.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- Vector.td - Vector Dialect --------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the Vector dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR
#define MLIR_DIALECT_VECTOR_IR_VECTOR

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this file be called VectorBase.td as other dialects do?

include "mlir/IR/OpBase.td"

def Vector_Dialect : Dialect {
let name = "vector";
let cppNamespace = "::mlir::vector";

let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
let dependentDialects = ["arith::ArithDialect"];
}

// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<Trait> traits = []> :
Op<Vector_Dialect, mnemonic, traits>;

#endif // MLIR_DIALECT_VECTOR_IR_VECTOR
85 changes: 85 additions & 0 deletions mlir/include/mlir/Dialect/Vector/IR/VectorAttributes.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
//===- VectorAttributes.td - Vector Dialect ----------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file declares the attributes used in the Vector dialect.
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
#define MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES

include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/IR/EnumAttr.td"

// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;

def CombiningKind : I32BitEnumAttr<
"CombiningKind",
"Kind of combining function for contractions and reductions",
[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
COMBINING_KIND_OR, COMBINING_KIND_XOR,
COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}

/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
I32EnumAttrCase<"parallel", 0>,
I32EnumAttrCase<"reduction", 1>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::vector";
}

def Vector_IteratorTypeEnum
: EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_IteratorTypeArrayAttr
: TypedArrayAttrBase<Vector_IteratorTypeEnum,
"Iterator type should be an enum.">;

def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
"Punctuation for separating vectors or vector elements", [
I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
I32EnumAttrCase<"NewLine", 1, "newline">,
I32EnumAttrCase<"Comma", 2, "comma">,
I32EnumAttrCase<"Open", 3, "open">,
I32EnumAttrCase<"Close", 4, "close">
]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}

def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
let assemblyFormat = "`<` $value `>`";
}

#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_ATTRIBUTES
7 changes: 4 additions & 3 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#define MLIR_DIALECT_VECTOR_IR_VECTOROPS_H

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.h"
#include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.h"
#include "mlir/IR/AffineMap.h"
Expand All @@ -31,10 +32,10 @@
#include "llvm/ADT/StringExtras.h"

// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Vector/IR/VectorOpsEnums.h.inc"
#include "mlir/Dialect/Vector/IR/VectorEnums.h.inc"

#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/Vector/IR/VectorOpsAttrDefs.h.inc"
#include "mlir/Dialect/Vector/IR/VectorAttributes.h.inc"

namespace mlir {
class MLIRContext;
Expand Down Expand Up @@ -157,7 +158,7 @@ Value selectPassthru(OpBuilder &builder, Value mask, Value newValue,
} // namespace mlir

#define GET_OP_CLASSES
#include "mlir/Dialect/Vector/IR/VectorDialect.h.inc"
#include "mlir/Dialect/Vector/IR/VectorOps.h.inc"
#include "mlir/Dialect/Vector/IR/VectorOpsDialect.h.inc"

#endif // MLIR_DIALECT_VECTOR_IR_VECTOROPS_H
107 changes: 20 additions & 87 deletions mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,13 @@
//
//===----------------------------------------------------------------------===//

#ifndef VECTOR_OPS
#define VECTOR_OPS
#ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
#define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS

include "mlir/Dialect/Vector/IR/Vector.td"
include "mlir/Dialect/Vector/IR/VectorAttributes.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td"
include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td"
include "mlir/IR/EnumAttr.td"
Expand All @@ -23,69 +27,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/Interfaces/VectorInterfaces.td"
include "mlir/Interfaces/ViewLikeInterface.td"

def Vector_Dialect : Dialect {
let name = "vector";
let cppNamespace = "::mlir::vector";

let useDefaultAttributePrinterParser = 1;
let hasConstantMaterializer = 1;
let dependentDialects = ["arith::ArithDialect"];
}

// Base class for Vector dialect ops.
class Vector_Op<string mnemonic, list<Trait> traits = []> :
Op<Vector_Dialect, mnemonic, traits>;

// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : I32BitEnumAttrCaseBit<"ADD", 0, "add">;
def COMBINING_KIND_MUL : I32BitEnumAttrCaseBit<"MUL", 1, "mul">;
def COMBINING_KIND_MINUI : I32BitEnumAttrCaseBit<"MINUI", 2, "minui">;
def COMBINING_KIND_MINSI : I32BitEnumAttrCaseBit<"MINSI", 3, "minsi">;
def COMBINING_KIND_MINF : I32BitEnumAttrCaseBit<"MINF", 4, "minf">;
def COMBINING_KIND_MAXUI : I32BitEnumAttrCaseBit<"MAXUI", 5, "maxui">;
def COMBINING_KIND_MAXSI : I32BitEnumAttrCaseBit<"MAXSI", 6, "maxsi">;
def COMBINING_KIND_MAXF : I32BitEnumAttrCaseBit<"MAXF", 7, "maxf">;
def COMBINING_KIND_AND : I32BitEnumAttrCaseBit<"AND", 8, "and">;
def COMBINING_KIND_OR : I32BitEnumAttrCaseBit<"OR", 9, "or">;
def COMBINING_KIND_XOR : I32BitEnumAttrCaseBit<"XOR", 10, "xor">;
def COMBINING_KIND_MINIMUMF : I32BitEnumAttrCaseBit<"MINIMUMF", 11, "minimumf">;
def COMBINING_KIND_MAXIMUMF : I32BitEnumAttrCaseBit<"MAXIMUMF", 12, "maximumf">;

def CombiningKind : I32BitEnumAttr<
"CombiningKind",
"Kind of combining function for contractions and reductions",
[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MINUI,
COMBINING_KIND_MINSI, COMBINING_KIND_MINF, COMBINING_KIND_MAXUI,
COMBINING_KIND_MAXSI, COMBINING_KIND_MAXF, COMBINING_KIND_AND,
COMBINING_KIND_OR, COMBINING_KIND_XOR,
COMBINING_KIND_MAXIMUMF, COMBINING_KIND_MINIMUMF]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}

/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
def Vector_CombiningKindAttr : EnumAttr<Vector_Dialect, CombiningKind, "kind"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_IteratorType : I32EnumAttr<"IteratorType", "Iterator type", [
I32EnumAttrCase<"parallel", 0>,
I32EnumAttrCase<"reduction", 1>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::vector";
}

def Vector_IteratorTypeEnum
: EnumAttr<Vector_Dialect, Vector_IteratorType, "iterator_type"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_IteratorTypeArrayAttr
: TypedArrayAttrBase<Vector_IteratorTypeEnum,
"Iterator type should be an enum.">;

// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
Expand Down Expand Up @@ -274,12 +215,16 @@ def Vector_ReductionOp :
Vector_Op<"reduction", [Pure,
PredOpTrait<"source operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
DeclareOpInterfaceMethods<ArithFastMathInterface>,
DeclareOpInterfaceMethods<MaskableOpInterface>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface,
["getShapeForUnroll"]>]>,
DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>
]>,
Arguments<(ins Vector_CombiningKindAttr:$kind,
AnyVectorOfAnyRank:$vector,
Optional<AnyType>:$acc)>,
Optional<AnyType>:$acc,
DefaultValuedAttr<
Arith_FastMathAttr,
"::mlir::arith::FastMathFlags::none">:$fastmath)>,
Results<(outs AnyType:$dest)> {
let summary = "reduction operation";
let description = [{
Expand Down Expand Up @@ -309,9 +254,13 @@ def Vector_ReductionOp :
}];
let builders = [
// Builder that infers the type of `dest`.
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc)>,
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector, "Value":$acc,
CArg<"::mlir::arith::FastMathFlags",
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>,
// Builder that infers the type of `dest` and has no accumulator.
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector)>
OpBuilder<(ins "CombiningKind":$kind, "Value":$vector,
CArg<"::mlir::arith::FastMathFlags",
"::mlir::arith::FastMathFlags::none">:$fastMathFlags)>
];

// TODO: Migrate to assemblyFormat once `AllTypesMatch` supports optional
Expand Down Expand Up @@ -2466,22 +2415,6 @@ def Vector_TransposeOp :
let hasVerifier = 1;
}

def PrintPunctuation : I32EnumAttr<"PrintPunctuation",
"Punctuation for separating vectors or vector elements", [
I32EnumAttrCase<"NoPunctuation", 0, "no_punctuation">,
I32EnumAttrCase<"NewLine", 1, "newline">,
I32EnumAttrCase<"Comma", 2, "comma">,
I32EnumAttrCase<"Open", 3, "open">,
I32EnumAttrCase<"Close", 4, "close">
]> {
let cppNamespace = "::mlir::vector";
let genSpecializedAttr = 0;
}

def Vector_PrintPunctuation : EnumAttr<Vector_Dialect, PrintPunctuation, "punctuation"> {
let assemblyFormat = "`<` $value `>`";
}

def Vector_PrintOp :
Vector_Op<"print", []>,
Arguments<(ins Optional<Type<Or<[
Expand Down Expand Up @@ -2936,4 +2869,4 @@ def Vector_WarpExecuteOnLane0Op : Vector_Op<"warp_execute_on_lane_0",
}];
}

#endif // VECTOR_OPS
#endif // MLIR_DIALECT_VECTOR_IR_VECTOR_OPS
Loading