Skip to content

Commit 99f1f0e

Browse files
authored
[CIR] Upstream comparison ops for VectorType (#140597)
This change adds support for Cmp ops for VectorType Issue #136487
1 parent 524ef16 commit 99f1f0e

File tree

7 files changed

+806
-5
lines changed

7 files changed

+806
-5
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2105,4 +2105,33 @@ def VecExtractOp : CIR_Op<"vec.extract", [Pure,
21052105
let hasFolder = 1;
21062106
}
21072107

2108+
//===----------------------------------------------------------------------===//
2109+
// VecCmpOp
2110+
//===----------------------------------------------------------------------===//
2111+
2112+
def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {
2113+
2114+
let summary = "Compare two vectors";
2115+
let description = [{
2116+
The `cir.vec.cmp` operation does an element-wise comparison of two vectors
2117+
of the same type. The result is a vector of the same size as the operands
2118+
whose element type is the signed integral type that is the same size as the
2119+
element type of the operands. The values in the result are 0 or -1.
2120+
2121+
```mlir
2122+
%eq = cir.vec.cmp(eq, %vec_a, %vec_b) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
2123+
%lt = cir.vec.cmp(lt, %vec_a, %vec_b) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
2124+
```
2125+
}];
2126+
2127+
let arguments = (ins Arg<CmpOpKind, "cmp kind">:$kind, CIR_VectorType:$lhs,
2128+
CIR_VectorType:$rhs);
2129+
let results = (outs CIR_VectorType:$result);
2130+
2131+
let assemblyFormat = [{
2132+
`(` $kind `,` $lhs `,` $rhs `)` `:` qualified(type($lhs)) `,`
2133+
qualified(type($result)) attr-dict
2134+
}];
2135+
}
2136+
21082137
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -786,22 +786,30 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
786786
}
787787
};
788788

789+
cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode());
789790
if (lhsTy->getAs<MemberPointerType>()) {
790791
assert(!cir::MissingFeatures::dataMemberType());
791792
assert(e->getOpcode() == BO_EQ || e->getOpcode() == BO_NE);
792793
mlir::Value lhs = cgf.emitScalarExpr(e->getLHS());
793794
mlir::Value rhs = cgf.emitScalarExpr(e->getRHS());
794-
cir::CmpOpKind kind = clangCmpToCIRCmp(e->getOpcode());
795795
result = builder.createCompare(loc, kind, lhs, rhs);
796796
} else if (!lhsTy->isAnyComplexType() && !rhsTy->isAnyComplexType()) {
797797
BinOpInfo boInfo = emitBinOps(e);
798798
mlir::Value lhs = boInfo.lhs;
799799
mlir::Value rhs = boInfo.rhs;
800800

801801
if (lhsTy->isVectorType()) {
802-
assert(!cir::MissingFeatures::vectorType());
803-
cgf.cgm.errorNYI(loc, "vector comparisons");
804-
result = builder.getBool(false, loc);
802+
if (!e->getType()->isVectorType()) {
803+
// If AltiVec, the comparison results in a numeric type, so we use
804+
// intrinsics comparing vectors and giving 0 or 1 as a result
805+
cgf.cgm.errorNYI(loc, "AltiVec comparison");
806+
} else {
807+
// Other kinds of vectors. Element-wise comparison returning
808+
// a vector.
809+
result = builder.create<cir::VecCmpOp>(
810+
cgf.getLoc(boInfo.loc), cgf.convertType(boInfo.fullType), kind,
811+
boInfo.lhs, boInfo.rhs);
812+
}
805813
} else if (boInfo.isFixedPointOp()) {
806814
assert(!cir::MissingFeatures::fixedPointType());
807815
cgf.cgm.errorNYI(loc, "fixed point comparisons");

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1716,7 +1716,8 @@ void ConvertCIRToLLVMPass::runOnOperation() {
17161716
CIRToLLVMUnaryOpLowering,
17171717
CIRToLLVMVecCreateOpLowering,
17181718
CIRToLLVMVecExtractOpLowering,
1719-
CIRToLLVMVecInsertOpLowering
1719+
CIRToLLVMVecInsertOpLowering,
1720+
CIRToLLVMVecCmpOpLowering
17201721
// clang-format on
17211722
>(converter, patterns.getContext());
17221723

@@ -1841,6 +1842,35 @@ mlir::LogicalResult CIRToLLVMVecInsertOpLowering::matchAndRewrite(
18411842
return mlir::success();
18421843
}
18431844

1845+
mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
1846+
cir::VecCmpOp op, OpAdaptor adaptor,
1847+
mlir::ConversionPatternRewriter &rewriter) const {
1848+
assert(mlir::isa<cir::VectorType>(op.getType()) &&
1849+
mlir::isa<cir::VectorType>(op.getLhs().getType()) &&
1850+
mlir::isa<cir::VectorType>(op.getRhs().getType()) &&
1851+
"Vector compare with non-vector type");
1852+
mlir::Type elementType = elementTypeIfVector(op.getLhs().getType());
1853+
mlir::Value bitResult;
1854+
if (auto intType = mlir::dyn_cast<cir::IntType>(elementType)) {
1855+
bitResult = rewriter.create<mlir::LLVM::ICmpOp>(
1856+
op.getLoc(),
1857+
convertCmpKindToICmpPredicate(op.getKind(), intType.isSigned()),
1858+
adaptor.getLhs(), adaptor.getRhs());
1859+
} else if (mlir::isa<cir::CIRFPTypeInterface>(elementType)) {
1860+
bitResult = rewriter.create<mlir::LLVM::FCmpOp>(
1861+
op.getLoc(), convertCmpKindToFCmpPredicate(op.getKind()),
1862+
adaptor.getLhs(), adaptor.getRhs());
1863+
} else {
1864+
return op.emitError() << "unsupported type for VecCmpOp: " << elementType;
1865+
}
1866+
1867+
// LLVM IR vector comparison returns a vector of i1. This one-bit vector
1868+
// must be sign-extended to the correct result type.
1869+
rewriter.replaceOpWithNewOp<mlir::LLVM::SExtOp>(
1870+
op, typeConverter->convertType(op.getType()), bitResult);
1871+
return mlir::success();
1872+
}
1873+
18441874
std::unique_ptr<mlir::Pass> createConvertCIRToLLVMPass() {
18451875
return std::make_unique<ConvertCIRToLLVMPass>();
18461876
}

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,16 @@ class CIRToLLVMVecInsertOpLowering
342342
mlir::ConversionPatternRewriter &) const override;
343343
};
344344

345+
class CIRToLLVMVecCmpOpLowering
346+
: public mlir::OpConversionPattern<cir::VecCmpOp> {
347+
public:
348+
using mlir::OpConversionPattern<cir::VecCmpOp>::OpConversionPattern;
349+
350+
mlir::LogicalResult
351+
matchAndRewrite(cir::VecCmpOp op, OpAdaptor,
352+
mlir::ConversionPatternRewriter &) const override;
353+
};
354+
345355
} // namespace direct
346356
} // namespace cir
347357

0 commit comments

Comments
 (0)