Skip to content

Commit 7feba5f

Browse files
authored
[CIR] Upstream extract op for VectorType (llvm#138413)
This change adds extract op for VectorType Issue llvm#136487
1 parent 7a66746 commit 7feba5f

File tree

7 files changed

+239
-3
lines changed

7 files changed

+239
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1969,4 +1969,32 @@ def VecCreateOp : CIR_Op<"vec.create", [Pure]> {
19691969
let hasVerifier = 1;
19701970
}
19711971

1972+
//===----------------------------------------------------------------------===//
1973+
// VecExtractOp
1974+
//===----------------------------------------------------------------------===//
1975+
1976+
def VecExtractOp : CIR_Op<"vec.extract", [Pure,
1977+
TypesMatchWith<"type of 'result' matches element type of 'vec'", "vec",
1978+
"result", "cast<VectorType>($_self).getElementType()">]> {
1979+
1980+
let summary = "Extract one element from a vector object";
1981+
let description = [{
1982+
The `cir.vec.extract` operation extracts the element at the given index
1983+
from a vector object.
1984+
1985+
```mlir
1986+
%tmp = cir.load %vec : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1987+
%idx = cir.const #cir.int<1> : !s32i
1988+
%element = cir.vec.extract %tmp[%idx : !s32i] : !cir.vector<4 x !s32i>
1989+
```
1990+
}];
1991+
1992+
let arguments = (ins CIR_VectorType:$vec, CIR_AnyFundamentalIntType:$index);
1993+
let results = (outs CIR_AnyType:$result);
1994+
1995+
let assemblyFormat = [{
1996+
$vec `[` $index `:` type($index) `]` attr-dict `:` qualified(type($vec))
1997+
}];
1998+
}
1999+
19722000
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,11 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
161161
mlir::Value VisitArraySubscriptExpr(ArraySubscriptExpr *e) {
162162
if (e->getBase()->getType()->isVectorType()) {
163163
assert(!cir::MissingFeatures::scalableVectors());
164-
cgf.getCIRGenModule().errorNYI("VisitArraySubscriptExpr: VectorType");
165-
return {};
164+
165+
const mlir::Location loc = cgf.getLoc(e->getSourceRange());
166+
const mlir::Value vecValue = Visit(e->getBase());
167+
const mlir::Value indexValue = Visit(e->getIdx());
168+
return cgf.builder.create<cir::VecExtractOp>(loc, vecValue, indexValue);
166169
}
167170
// Just load the lvalue formed by the subscript expression.
168171
return emitLoadOfLValue(e);

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1600,7 +1600,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
16001600
CIRToLLVMStackRestoreOpLowering,
16011601
CIRToLLVMTrapOpLowering,
16021602
CIRToLLVMUnaryOpLowering,
1603-
CIRToLLVMVecCreateOpLowering
1603+
CIRToLLVMVecCreateOpLowering,
1604+
CIRToLLVMVecExtractOpLowering
16041605
// clang-format on
16051606
>(converter, patterns.getContext());
16061607

@@ -1709,6 +1710,14 @@ mlir::LogicalResult CIRToLLVMVecCreateOpLowering::matchAndRewrite(
17091710
return mlir::success();
17101711
}
17111712

1713+
mlir::LogicalResult CIRToLLVMVecExtractOpLowering::matchAndRewrite(
1714+
cir::VecExtractOp op, OpAdaptor adaptor,
1715+
mlir::ConversionPatternRewriter &rewriter) const {
1716+
rewriter.replaceOpWithNewOp<mlir::LLVM::ExtractElementOp>(
1717+
op, adaptor.getVec(), adaptor.getIndex());
1718+
return mlir::success();
1719+
}
1720+
17121721
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
17131722
return std::make_unique<ConvertCIRToLLVMPass>();
17141723
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,16 @@ class CIRToLLVMVecCreateOpLowering
303303
mlir::ConversionPatternRewriter &) const override;
304304
};
305305

306+
class CIRToLLVMVecExtractOpLowering
307+
: public mlir::OpConversionPattern<cir::VecExtractOp> {
308+
public:
309+
using mlir::OpConversionPattern<cir::VecExtractOp>::OpConversionPattern;
310+
311+
mlir::LogicalResult
312+
matchAndRewrite(cir::VecExtractOp op, OpAdaptor,
313+
mlir::ConversionPatternRewriter &) const override;
314+
};
315+
306316
} // namespace direct
307317
} // namespace cir
308318

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,80 @@ void foo2(vi4 p) {}
136136

137137
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
138138
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
139+
140+
void foo3() {
141+
vi4 a = { 1, 2, 3, 4 };
142+
int e = a[1];
143+
}
144+
145+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
146+
// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
147+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
148+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
149+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
150+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
151+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
152+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
153+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
154+
// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
155+
// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
156+
// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
157+
// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
158+
159+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
160+
// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
161+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
162+
// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
163+
// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
164+
// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
165+
166+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
167+
// OGCG: %[[INIT:.*]] = alloca i32, align 4
168+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
169+
// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
170+
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
171+
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
172+
173+
void foo4() {
174+
vi4 a = { 1, 2, 3, 4 };
175+
176+
int idx = 2;
177+
int e = a[idx];
178+
}
179+
180+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
181+
// CIR: %[[IDX:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["idx", init]
182+
// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
183+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
184+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
185+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
186+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
187+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
188+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
189+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
190+
// CIR: %[[CONST_IDX:.*]] = cir.const #cir.int<2> : !s32i
191+
// CIR: cir.store %[[CONST_IDX]], %[[IDX]] : !s32i, !cir.ptr<!s32i>
192+
// CIR: %[[TMP1:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
193+
// CIR: %[[TMP2:.*]] = cir.load %[[IDX]] : !cir.ptr<!s32i>, !s32i
194+
// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP1]][%[[TMP2]] : !s32i] : !cir.vector<4 x !s32i>
195+
// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
196+
197+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
198+
// LLVM: %[[IDX:.*]] = alloca i32, i64 1, align 4
199+
// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
200+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
201+
// LLVM: store i32 2, ptr %[[IDX]], align 4
202+
// LLVM: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
203+
// LLVM: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
204+
// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
205+
// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
206+
207+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
208+
// OGCG: %[[IDX:.*]] = alloca i32, align 4
209+
// OGCG: %[[INIT:.*]] = alloca i32, align 4
210+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
211+
// OGCG: store i32 2, ptr %[[IDX]], align 4
212+
// OGCG: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
213+
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
214+
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
215+
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,3 +124,80 @@ void foo2(vi4 p) {}
124124

125125
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
126126
// OGCG: store <4 x i32> %{{.*}}, ptr %[[VEC_A]], align 16
127+
128+
void foo3() {
129+
vi4 a = { 1, 2, 3, 4 };
130+
int e = a[1];
131+
}
132+
133+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
134+
// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
135+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
136+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
137+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
138+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
139+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
140+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
141+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
142+
// CIR: %[[TMP:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
143+
// CIR: %[[IDX:.*]] = cir.const #cir.int<1> : !s32i
144+
// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP]][%[[IDX]] : !s32i] : !cir.vector<4 x !s32i>
145+
// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
146+
147+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
148+
// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
149+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
150+
// LLVM: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
151+
// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
152+
// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
153+
154+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
155+
// OGCG: %[[INIT:.*]] = alloca i32, align 4
156+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
157+
// OGCG: %[[TMP:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
158+
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP]], i32 1
159+
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4
160+
161+
void foo4() {
162+
vi4 a = { 1, 2, 3, 4 };
163+
164+
int idx = 2;
165+
int e = a[idx];
166+
}
167+
168+
// CIR: %[[VEC:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a", init]
169+
// CIR: %[[IDX:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["idx", init]
170+
// CIR: %[[INIT:.*]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
171+
// CIR: %[[CONST_1:.*]] = cir.const #cir.int<1> : !s32i
172+
// CIR: %[[CONST_2:.*]] = cir.const #cir.int<2> : !s32i
173+
// CIR: %[[CONST_3:.*]] = cir.const #cir.int<3> : !s32i
174+
// CIR: %[[CONST_4:.*]] = cir.const #cir.int<4> : !s32i
175+
// CIR: %[[VEC_VAL:.*]] = cir.vec.create(%[[CONST_1]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]] :
176+
// CIR-SAME: !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
177+
// CIR: cir.store %[[VEC_VAL]], %[[VEC]] : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
178+
// CIR: %[[CONST_IDX:.*]] = cir.const #cir.int<2> : !s32i
179+
// CIR: cir.store %[[CONST_IDX]], %[[IDX]] : !s32i, !cir.ptr<!s32i>
180+
// CIR: %[[TMP1:.*]] = cir.load %[[VEC]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
181+
// CIR: %[[TMP2:.*]] = cir.load %[[IDX]] : !cir.ptr<!s32i>, !s32i
182+
// CIR: %[[ELE:.*]] = cir.vec.extract %[[TMP1]][%[[TMP2]] : !s32i] : !cir.vector<4 x !s32i>
183+
// CIR: cir.store %[[ELE]], %[[INIT]] : !s32i, !cir.ptr<!s32i>
184+
185+
// LLVM: %[[VEC:.*]] = alloca <4 x i32>, i64 1, align 16
186+
// LLVM: %[[IDX:.*]] = alloca i32, i64 1, align 4
187+
// LLVM: %[[INIT:.*]] = alloca i32, i64 1, align 4
188+
// LLVM: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
189+
// LLVM: store i32 2, ptr %[[IDX]], align 4
190+
// LLVM: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
191+
// LLVM: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
192+
// LLVM: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
193+
// LLVM: store i32 %[[ELE]], ptr %[[INIT]], align 4
194+
195+
// OGCG: %[[VEC:.*]] = alloca <4 x i32>, align 16
196+
// OGCG: %[[IDX:.*]] = alloca i32, align 4
197+
// OGCG: %[[INIT:.*]] = alloca i32, align 4
198+
// OGCG: store <4 x i32> <i32 1, i32 2, i32 3, i32 4>, ptr %[[VEC]], align 16
199+
// OGCG: store i32 2, ptr %[[IDX]], align 4
200+
// OGCG: %[[TMP1:.*]] = load <4 x i32>, ptr %[[VEC]], align 16
201+
// OGCG: %[[TMP2:.*]] = load i32, ptr %[[IDX]], align 4
202+
// OGCG: %[[ELE:.*]] = extractelement <4 x i32> %[[TMP1]], i32 %[[TMP2]]
203+
// OGCG: store i32 %[[ELE]], ptr %[[INIT]], align 4

clang/test/CIR/IR/vector.cir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,4 +65,36 @@ cir.func @local_vector_create_test() {
6565
// CHECK: cir.return
6666
// CHECK: }
6767

68+
cir.func @vector_extract_element_test() {
69+
%0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
70+
%1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
71+
%2 = cir.const #cir.int<1> : !s32i
72+
%3 = cir.const #cir.int<2> : !s32i
73+
%4 = cir.const #cir.int<3> : !s32i
74+
%5 = cir.const #cir.int<4> : !s32i
75+
%6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
76+
cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
77+
%7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
78+
%8 = cir.const #cir.int<1> : !s32i
79+
%9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
80+
cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
81+
cir.return
82+
}
83+
84+
// CHECK: cir.func @vector_extract_element_test() {
85+
// CHECK: %0 = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["d", init]
86+
// CHECK: %1 = cir.alloca !s32i, !cir.ptr<!s32i>, ["e", init]
87+
// CHECK: %2 = cir.const #cir.int<1> : !s32i
88+
// CHECK: %3 = cir.const #cir.int<2> : !s32i
89+
// CHECK: %4 = cir.const #cir.int<3> : !s32i
90+
// CHECK: %5 = cir.const #cir.int<4> : !s32i
91+
// CHECK: %6 = cir.vec.create(%2, %3, %4, %5 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
92+
// CHECK: cir.store %6, %0 : !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>
93+
// CHECK: %7 = cir.load %0 : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
94+
// CHECK: %8 = cir.const #cir.int<1> : !s32i
95+
// CHECK: %9 = cir.vec.extract %7[%8 : !s32i] : !cir.vector<4 x !s32i>
96+
// CHECK: cir.store %9, %1 : !s32i, !cir.ptr<!s32i>
97+
// CHECK: cir.return
98+
// CHECK: }
99+
68100
}

0 commit comments

Comments
 (0)