@@ -31,24 +31,11 @@ void x86vector::X86VectorDialect::initialize() {
31
31
>();
32
32
}
33
33
34
- static SmallVector<Value>
35
- getMemrefBuffPtr (Location loc, ::mlir::TypedValue<::mlir::MemRefType> memrefVal,
36
- RewriterBase &rewriter,
37
- const LLVMTypeConverter &typeConverter) {
38
- SmallVector<Value> operands;
39
- auto opType = memrefVal.getType ();
40
-
41
- Type llvmStructType = typeConverter.convertType (opType);
42
- Value llvmStruct =
43
- rewriter
44
- .create <UnrealizedConversionCastOp>(loc, llvmStructType, memrefVal)
45
- .getResult (0 );
46
- MemRefDescriptor memRefDescriptor (llvmStruct);
47
-
48
- Value ptr = memRefDescriptor.bufferPtr (rewriter, loc, typeConverter, opType);
49
- operands.push_back (ptr);
50
-
51
- return operands;
34
+ static Value getMemrefBuffPtr (Location loc, MemRefType type, Value buffer,
35
+ const LLVMTypeConverter &typeConverter,
36
+ RewriterBase &rewriter) {
37
+ MemRefDescriptor memRefDescriptor (buffer);
38
+ return memRefDescriptor.bufferPtr (rewriter, loc, typeConverter, type);
52
39
}
53
40
54
41
LogicalResult x86vector::MaskCompressOp::verify () {
@@ -66,48 +53,61 @@ LogicalResult x86vector::MaskCompressOp::verify() {
66
53
}
67
54
68
55
SmallVector<Value> x86vector::MaskCompressOp::getIntrinsicOperands (
69
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
56
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
57
+ RewriterBase &rewriter) {
70
58
auto loc = getLoc ();
59
+ Adaptor adaptor (operands, *this );
71
60
72
- auto opType = getA ().getType ();
61
+ auto opType = adaptor. getA ().getType ();
73
62
Value src;
74
- if (getSrc ()) {
75
- src = getSrc ();
76
- } else if (getConstantSrc ()) {
77
- src = rewriter.create <LLVM::ConstantOp>(loc, opType, getConstantSrcAttr ());
63
+ if (adaptor.getSrc ()) {
64
+ src = adaptor.getSrc ();
65
+ } else if (adaptor.getConstantSrc ()) {
66
+ src = rewriter.create <LLVM::ConstantOp>(loc, opType,
67
+ adaptor.getConstantSrcAttr ());
78
68
} else {
79
69
auto zeroAttr = rewriter.getZeroAttr (opType);
80
70
src = rewriter.create <LLVM::ConstantOp>(loc, opType, zeroAttr);
81
71
}
82
72
83
- return SmallVector<Value>{getA (), src, getK ()};
73
+ return SmallVector<Value>{adaptor. getA (), src, adaptor. getK ()};
84
74
}
85
75
86
76
SmallVector<Value>
87
- x86vector::DotOp::getIntrinsicOperands (RewriterBase &rewriter,
88
- const LLVMTypeConverter &typeConverter) {
89
- SmallVector<Value> operands (getOperands ());
77
+ x86vector::DotOp::getIntrinsicOperands (ArrayRef<Value> operands,
78
+ const LLVMTypeConverter &typeConverter,
79
+ RewriterBase &rewriter) {
80
+ SmallVector<Value> intrinsicOperands (operands);
90
81
// Dot product of all elements, broadcasted to all elements.
91
82
Value scale =
92
83
rewriter.create <LLVM::ConstantOp>(getLoc (), rewriter.getI8Type (), 0xff );
93
- operands .push_back (scale);
84
+ intrinsicOperands .push_back (scale);
94
85
95
- return operands ;
86
+ return intrinsicOperands ;
96
87
}
97
88
98
89
SmallVector<Value> x86vector::BcstToPackedF32Op::getIntrinsicOperands (
99
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
100
- return getMemrefBuffPtr (getLoc (), getA (), rewriter, typeConverter);
90
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
91
+ RewriterBase &rewriter) {
92
+ Adaptor adaptor (operands, *this );
93
+ return {getMemrefBuffPtr (getLoc (), getA ().getType (), adaptor.getA (),
94
+ typeConverter, rewriter)};
101
95
}
102
96
103
97
SmallVector<Value> x86vector::CvtPackedEvenIndexedToF32Op::getIntrinsicOperands (
104
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
105
- return getMemrefBuffPtr (getLoc (), getA (), rewriter, typeConverter);
98
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
99
+ RewriterBase &rewriter) {
100
+ Adaptor adaptor (operands, *this );
101
+ return {getMemrefBuffPtr (getLoc (), getA ().getType (), adaptor.getA (),
102
+ typeConverter, rewriter)};
106
103
}
107
104
108
105
SmallVector<Value> x86vector::CvtPackedOddIndexedToF32Op::getIntrinsicOperands (
109
- RewriterBase &rewriter, const LLVMTypeConverter &typeConverter) {
110
- return getMemrefBuffPtr (getLoc (), getA (), rewriter, typeConverter);
106
+ ArrayRef<Value> operands, const LLVMTypeConverter &typeConverter,
107
+ RewriterBase &rewriter) {
108
+ Adaptor adaptor (operands, *this );
109
+ return {getMemrefBuffPtr (getLoc (), getA ().getType (), adaptor.getA (),
110
+ typeConverter, rewriter)};
111
111
}
112
112
113
113
#define GET_OP_CLASSES
0 commit comments