Skip to content

[Clang] support vector subscript expressions in constant evaluator (WIP) #76379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 57 additions & 4 deletions clang/lib/AST/ExprConstant.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ namespace {
ArraySize = 2;
MostDerivedLength = I + 1;
IsArray = true;
} else if (Type->isVectorType()) {
const VectorType *CT = Type->castAs<VectorType>();
Type = CT->getElementType();
ArraySize = CT->getNumElements();
MostDerivedLength = I + 1;
IsArray = true;
} else if (const FieldDecl *FD = getAsField(Path[I])) {
Type = FD->getType();
ArraySize = 0;
Expand Down Expand Up @@ -437,6 +443,15 @@ namespace {
MostDerivedArraySize = 2;
MostDerivedPathLength = Entries.size();
}
/// Update this designator to refer to the given vector component.
void addVectorUnchecked(const VectorType *VecTy) {
Entries.push_back(PathEntry::ArrayIndex(0));

MostDerivedType = VecTy->getElementType();
MostDerivedIsArrayElement = true;
MostDerivedArraySize = VecTy->getNumElements();
MostDerivedPathLength = Entries.size();
}
void diagnoseUnsizedArrayPointerArithmetic(EvalInfo &Info, const Expr *E);
void diagnosePointerArithmetic(EvalInfo &Info, const Expr *E,
const APSInt &N);
Expand Down Expand Up @@ -1732,6 +1747,10 @@ namespace {
if (checkSubobject(Info, E, Imag ? CSK_Imag : CSK_Real))
Designator.addComplexUnchecked(EltTy, Imag);
}
void addVector(EvalInfo &Info, const Expr *E, const VectorType *VecTy) {
if (checkSubobject(Info, E, CSK_ArrayIndex))
Designator.addVectorUnchecked(VecTy);
}
void clearIsNullPointer() {
IsNullPtr = false;
}
Expand Down Expand Up @@ -1890,6 +1909,8 @@ static bool EvaluateFixedPointOrInteger(const Expr *E, APFixedPoint &Result,
static bool EvaluateFixedPoint(const Expr *E, APFixedPoint &Result,
EvalInfo &Info);

static bool EvaluateVector(const Expr *E, APValue &Result, EvalInfo &Info);

//===----------------------------------------------------------------------===//
// Misc utilities
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -3278,6 +3299,19 @@ static bool HandleLValueComplexElement(EvalInfo &Info, const Expr *E,
return true;
}

static bool HandeLValueVectorComponent(EvalInfo &Info, const Expr *E,
LValue &LVal, const VectorType *VecTy,
APSInt &Adjustment) {
LVal.addVector(Info, E, VecTy);

CharUnits SizeOfComponent;
if (!HandleSizeof(Info, E->getExprLoc(), VecTy->getElementType(),
SizeOfComponent))
return false;
LVal.adjustOffsetAndIndex(Info, E, Adjustment, SizeOfComponent);
return true;
}

/// Try to evaluate the initializer for a variable declaration.
///
/// \param Info Information about the ongoing evaluation.
Expand Down Expand Up @@ -3718,7 +3752,8 @@ findSubobject(EvalInfo &Info, const Expr *E, const CompleteObject &Obj,
}

// If this is our last pass, check that the final object type is OK.
if (I == N || (I == N - 1 && ObjType->isAnyComplexType())) {
if (I == N || (I == N - 1 &&
(ObjType->isAnyComplexType() || ObjType->isVectorType()))) {
// Accesses to volatile objects are prohibited.
if (ObjType.isVolatileQualified() && isFormalAccess(handler.AccessKind)) {
if (Info.getLangOpts().CPlusPlus) {
Expand Down Expand Up @@ -3823,6 +3858,10 @@ findSubobject(EvalInfo &Info, const Expr *E, const CompleteObject &Obj,
return handler.found(Index ? O->getComplexFloatImag()
: O->getComplexFloatReal(), ObjType);
}
} else if (ObjType->isVectorType()) {
// Next Subobject is a vector element
uint64_t Index = Sub.Entries[I].getAsArrayIndex();
O = &O->getVectorElt(Index);
} else if (const FieldDecl *Field = getAsField(Sub.Entries[I])) {
if (Field->isMutable() &&
!Obj.mayAccessMutableMembers(Info, handler.AccessKind)) {
Expand Down Expand Up @@ -8756,14 +8795,28 @@ bool LValueExprEvaluator::VisitMemberExpr(const MemberExpr *E) {
}

bool LValueExprEvaluator::VisitArraySubscriptExpr(const ArraySubscriptExpr *E) {
// FIXME: Deal with vectors as array subscript bases.
if (E->getBase()->getType()->isVectorType() ||
E->getBase()->getType()->isSveVLSBuiltinType())

if (E->getBase()->getType()->isSveVLSBuiltinType())
return Error(E);

APSInt Index;
bool Success = true;

if (E->getBase()->getType()->isVectorType()) {
for (const Expr *SubExpr : {E->getLHS(), E->getRHS()}) {
Success = (SubExpr == E->getBase())
? EvaluateLValue(SubExpr, Result, Info, true)
: EvaluateInteger(SubExpr, Index, Info);
}
if (Success) {
Success = HandeLValueVectorComponent(
Info, E, Result, E->getBase()->getType()->castAs<VectorType>(),
Index);
return Success;
}
return false;
}

// C++17's rules require us to evaluate the LHS first, regardless of which
// side is the base.
for (const Expr *SubExpr : {E->getLHS(), E->getRHS()}) {
Expand Down
12 changes: 4 additions & 8 deletions clang/test/CodeGenCXX/temporaries.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -673,16 +673,12 @@ namespace Vector {
vi4a v;
vi4b w;
};
// CHECK: alloca
// CHECK: extractelement
// CHECK: store i32 {{.*}}, ptr @_ZGRN6Vector1rE_
// CHECK: store ptr @_ZGRN6Vector1rE_, ptr @_ZN6Vector1rE,
// @_ZGRN6Vector1rE_ = internal global i32 0, align 4
// @_ZN6Vector1rE = constant ptr @_ZGRN6Vector1rE_, align 8
int &&r = S().v[1];

// CHECK: alloca
// CHECK: extractelement
// CHECK: store i32 {{.*}}, ptr @_ZGRN6Vector1sE_
// CHECK: store ptr @_ZGRN6Vector1sE_, ptr @_ZN6Vector1sE,
// @_ZGRN6Vector1sE_ = internal global i32 0, align 4
// @_ZN6Vector1sE = constant ptr @_ZGRN6Vector1sE_, align 8
int &&s = S().w[1];
// FIXME PR16204: The following code leads to an assertion in Sema.
//int &&s = S().w.y;
Expand Down
Loading