Skip to content

Commit c15539c

Browse files
authored
[mlir][x86vector] Improve intrinsic operands creation (llvm#138666)
Refactors intrinsic op interface to delegate initial operands mapping to the dialect converter and allow intrinsic operands getters to only perform last mile post-processing.
1 parent aa9f859 commit c15539c

File tree

4 files changed

+72
-52
lines changed

4 files changed

+72
-52
lines changed

mlir/include/mlir/Dialect/X86Vector/X86Vector.td

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [Pure,
8383
}
8484
}];
8585
let extraClassDeclaration = [{
86-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
86+
SmallVector<Value> getIntrinsicOperands(
87+
::mlir::ArrayRef<Value> operands,
88+
const ::mlir::LLVMTypeConverter &typeConverter,
89+
::mlir::RewriterBase &rewriter);
8790
}];
8891
}
8992

@@ -404,7 +407,10 @@ def DotOp : AVX_LowOp<"dot", [Pure,
404407
}
405408
}];
406409
let extraClassDeclaration = [{
407-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
410+
SmallVector<Value> getIntrinsicOperands(
411+
::mlir::ArrayRef<Value> operands,
412+
const ::mlir::LLVMTypeConverter &typeConverter,
413+
::mlir::RewriterBase &rewriter);
408414
}];
409415
}
410416

@@ -452,7 +458,10 @@ def BcstToPackedF32Op : AVX_Op<"bcst_to_f32.packed", [MemoryEffects<[MemRead]>,
452458
}];
453459

454460
let extraClassDeclaration = [{
455-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
461+
SmallVector<Value> getIntrinsicOperands(
462+
::mlir::ArrayRef<Value> operands,
463+
const ::mlir::LLVMTypeConverter &typeConverter,
464+
::mlir::RewriterBase &rewriter);
456465
}];
457466

458467
}
@@ -500,7 +509,10 @@ def CvtPackedEvenIndexedToF32Op : AVX_Op<"cvt.packed.even.indexed_to_f32", [Memo
500509
}];
501510

502511
let extraClassDeclaration = [{
503-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
512+
SmallVector<Value> getIntrinsicOperands(
513+
::mlir::ArrayRef<Value> operands,
514+
const ::mlir::LLVMTypeConverter &typeConverter,
515+
::mlir::RewriterBase &rewriter);
504516
}];
505517
}
506518

@@ -543,7 +555,10 @@ def CvtPackedOddIndexedToF32Op : AVX_Op<"cvt.packed.odd.indexed_to_f32", [Memory
543555
}];
544556

545557
let extraClassDeclaration = [{
546-
SmallVector<Value> getIntrinsicOperands(::mlir::RewriterBase&, const LLVMTypeConverter&);
558+
SmallVector<Value> getIntrinsicOperands(
559+
::mlir::ArrayRef<Value> operands,
560+
const ::mlir::LLVMTypeConverter &typeConverter,
561+
::mlir::RewriterBase &rewriter);
547562
}];
548563
}
549564
#endif // X86VECTOR_OPS

mlir/include/mlir/Dialect/X86Vector/X86VectorInterfaces.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,11 @@ def OneToOneIntrinsicOpInterface : OpInterface<"OneToOneIntrinsicOp"> {
5858
}],
5959
/*retType=*/"SmallVector<Value>",
6060
/*methodName=*/"getIntrinsicOperands",
61-
/*args=*/(ins "::mlir::RewriterBase &":$rewriter, "const LLVMTypeConverter &":$typeConverter),
61+
/*args=*/(ins "::mlir::ArrayRef<Value>":$operands,
62+
"const ::mlir::LLVMTypeConverter &":$typeConverter,
63+
"::mlir::RewriterBase &":$rewriter),
6264
/*methodBody=*/"",
63-
/*defaultImplementation=*/"return SmallVector<Value>($_op->getOperands());"
65+
/*defaultImplementation=*/"return SmallVector<Value>(operands);"
6466
>,
6567
];
6668
}

mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp

Lines changed: 36 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -31,24 +31,11 @@ void x86vector::X86VectorDialect::initialize() {
3131
>();
3232
}
3333

34-
static SmallVector<Value>
35-
getMemrefBuffPtr(Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
36-
RewriterBase &rewriter,
37-
const LLVMTypeConverter &typeConverter) {
38-
SmallVector<Value> operands;
39-
auto opType = memrefVal.getType();
40-
41-
Type llvmStructType = typeConverter.convertType(opType);
42-
Value llvmStruct =
43-
rewriter
44-
.create<UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
45-
.getResult(0);
46-
MemRefDescriptor memRefDescriptor(llvmStruct);
47-
48-
Value ptr = memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, opType);
49-
operands.push_back(ptr);
50-
51-
return operands;
34+
static Value getMemrefBuffPtr(Location loc, MemRefType type, Value buffer,
35+
const LLVMTypeConverter &typeConverter,
36+
RewriterBase &rewriter) {
37+
MemRefDescriptor memRefDescriptor(buffer);
38+
return memRefDescriptor.bufferPtr(rewriter, loc, typeConverter, type);
5239
}
5340

5441
LogicalResult x86vector::MaskCompressOp::verify() {
@@ -66,48 +53,61 @@ LogicalResult x86vector::MaskCompressOp::verify() {
6653
}
6754

6855
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands(
69-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
56+
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
57+
RewriterBase &rewriter) {
7058
auto loc = getLoc();
59+
Adaptor adaptor(operands, *this);
7160

72-
auto opType = getA().getType();
61+
auto opType = adaptor.getA().getType();
7362
Value src;
74-
if (getSrc()) {
75-
src = getSrc();
76-
} else if (getConstantSrc()) {
77-
src = rewriter.create<LLVM::ConstantOp>(loc, opType, getConstantSrcAttr());
63+
if (adaptor.getSrc()) {
64+
src = adaptor.getSrc();
65+
} else if (adaptor.getConstantSrc()) {
66+
src = rewriter.create<LLVM::ConstantOp>(loc, opType,
67+
adaptor.getConstantSrcAttr());
7868
} else {
7969
auto zeroAttr = rewriter.getZeroAttr(opType);
8070
src = rewriter.create<LLVM::ConstantOp>(loc, opType, zeroAttr);
8171
}
8272

83-
return SmallVector<Value>{getA(), src, getK()};
73+
return SmallVector<Value>{adaptor.getA(), src, adaptor.getK()};
8474
}
8575

8676
SmallVector<Value>
87-
x86vector::DotOp::getIntrinsicOperands(RewriterBase &rewriter,
88-
const LLVMTypeConverter &typeConverter) {
89-
SmallVector<Value> operands(getOperands());
77+
x86vector::DotOp::getIntrinsicOperands(ArrayRef<Value> operands,
78+
const LLVMTypeConverter &typeConverter,
79+
RewriterBase &rewriter) {
80+
SmallVector<Value> intrinsicOperands(operands);
9081
// Dot product of all elements, broadcasted to all elements.
9182
Value scale =
9283
rewriter.create<LLVM::ConstantOp>(getLoc(), rewriter.getI8Type(), 0xff);
93-
operands.push_back(scale);
84+
intrinsicOperands.push_back(scale);
9485

95-
return operands;
86+
return intrinsicOperands;
9687
}
9788

9889
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands(
99-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
100-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
90+
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
91+
RewriterBase &rewriter) {
92+
Adaptor adaptor(operands, *this);
93+
return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
94+
typeConverter, rewriter)};
10195
}
10296

10397
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands(
104-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
105-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
98+
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
99+
RewriterBase &rewriter) {
100+
Adaptor adaptor(operands, *this);
101+
return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
102+
typeConverter, rewriter)};
106103
}
107104

108105
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands(
109-
RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
110-
return getMemrefBuffPtr(getLoc(), getA(), rewriter, typeConverter);
106+
ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
107+
RewriterBase &rewriter) {
108+
Adaptor adaptor(operands, *this);
109+
return {getMemrefBuffPtr(getLoc(), getA().getType(), adaptor.getA(),
110+
typeConverter, rewriter)};
111111
}
112112

113113
#define GET_OP_CLASSES

mlir/lib/Dialect/X86Vector/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -84,20 +84,23 @@ LogicalResult intrinsicRewrite(Operation *op, StringAttr intrinsic,
8484
/// Generic one-to-one conversion of simply mappable operations into calls
8585
/// to their respective LLVM intrinsics.
8686
struct OneToOneIntrinsicOpConversion
87-
: public OpInterfaceRewritePattern<x86vector::OneToOneIntrinsicOp> {
88-
using OpInterfaceRewritePattern<
89-
x86vector::OneToOneIntrinsicOp>::OpInterfaceRewritePattern;
87+
: public OpInterfaceConversionPattern<x86vector::OneToOneIntrinsicOp> {
88+
using OpInterfaceConversionPattern<
89+
x86vector::OneToOneIntrinsicOp>::OpInterfaceConversionPattern;
9090

9191
OneToOneIntrinsicOpConversion(const LLVMTypeConverter &typeConverter,
9292
PatternBenefit benefit = 1)
93-
: OpInterfaceRewritePattern(&typeConverter.getContext(), benefit),
93+
: OpInterfaceConversionPattern(typeConverter, &typeConverter.getContext(),
94+
benefit),
9495
typeConverter(typeConverter) {}
9596

96-
LogicalResult matchAndRewrite(x86vector::OneToOneIntrinsicOp op,
97-
PatternRewriter &rewriter) const override {
98-
return intrinsicRewrite(op, rewriter.getStringAttr(op.getIntrinsicName()),
99-
op.getIntrinsicOperands(rewriter, typeConverter),
100-
typeConverter, rewriter);
97+
LogicalResult
98+
matchAndRewrite(x86vector::OneToOneIntrinsicOp op, ArrayRef<Value> operands,
99+
ConversionPatternRewriter &rewriter) const override {
100+
return intrinsicRewrite(
101+
op, rewriter.getStringAttr(op.getIntrinsicName()),
102+
op.getIntrinsicOperands(operands, typeConverter, rewriter),
103+
typeConverter, rewriter);
101104
}
102105

103106
private:

0 commit comments

Comments
 (0)