@@ -8079,6 +8079,15 @@ static OpenCLParamType getOpenCLKernelParameterType(Sema &S, QualType PT) {
8079
8079
if (PT->isRecordType())
8080
8080
return RecordKernelParam;
8081
8081
8082
+ // Look into an array argument to check if it has a forbidden type.
8083
+ if (PT->isArrayType()) {
8084
+ const Type *UnderlyingTy = PT->getPointeeOrArrayElementType();
8085
+ // Call ourself to check an underlying type of an array. Since the
8086
+ // getPointeeOrArrayElementType returns an innermost type which is not an
8087
+ // array, this recusive call only happens once.
8088
+ return getOpenCLKernelParameterType(S, QualType(UnderlyingTy, 0));
8089
+ }
8090
+
8082
8091
return ValidKernelParam;
8083
8092
}
8084
8093
@@ -8146,9 +8155,14 @@ static void checkIsValidOpenCLKernelParameter(
8146
8155
SmallVector<const FieldDecl *, 4> HistoryStack;
8147
8156
HistoryStack.push_back(nullptr);
8148
8157
8149
- const RecordDecl *PD = PT->castAs<RecordType>()->getDecl();
8150
- VisitStack.push_back(PD);
8158
+ // At this point we already handled everything except of a RecordType or
8159
+ // an ArrayType of a RecordType.
8160
+ assert((PT->isArrayType() || PT->isRecordType()) && "Unexpected type.");
8161
+ const RecordType *RecTy =
8162
+ PT->getPointeeOrArrayElementType()->getAs<RecordType>();
8163
+ const RecordDecl *OrigRecDecl = RecTy->getDecl();
8151
8164
8165
+ VisitStack.push_back(RecTy->getDecl());
8152
8166
assert(VisitStack.back() && "First decl null?");
8153
8167
8154
8168
do {
@@ -8167,7 +8181,15 @@ static void checkIsValidOpenCLKernelParameter(
8167
8181
const RecordDecl *RD;
8168
8182
if (const FieldDecl *Field = dyn_cast<FieldDecl>(Next)) {
8169
8183
HistoryStack.push_back(Field);
8170
- RD = Field->getType()->castAs<RecordType>()->getDecl();
8184
+
8185
+ QualType FieldTy = Field->getType();
8186
+ // Other field types (known to be valid or invalid) are handled while we
8187
+ // walk around RecordDecl::fields().
8188
+ assert((FieldTy->isArrayType() || FieldTy->isRecordType()) &&
8189
+ "Unexpected type.");
8190
+ const Type *FieldRecTy = FieldTy->getPointeeOrArrayElementType();
8191
+
8192
+ RD = FieldRecTy->castAs<RecordType>()->getDecl();
8171
8193
} else {
8172
8194
RD = cast<RecordDecl>(Next);
8173
8195
}
@@ -8204,8 +8226,8 @@ static void checkIsValidOpenCLKernelParameter(
8204
8226
S.Diag(Param->getLocation(), diag::err_bad_kernel_param_type) << PT;
8205
8227
}
8206
8228
8207
- S.Diag(PD ->getLocation(), diag::note_within_field_of_type)
8208
- << PD ->getDeclName();
8229
+ S.Diag(OrigRecDecl ->getLocation(), diag::note_within_field_of_type)
8230
+ << OrigRecDecl ->getDeclName();
8209
8231
8210
8232
// We have an error, now let's go back up through history and show where
8211
8233
// the offending field came from
0 commit comments