Skip to content

Commit a7f4044

Browse files
authored
[clang][SME] Emit error for OpenMP captured regions in SME functions (#124590)
Currently, these generate incorrect code, as streaming/SME attributes are not propagated to the outlined function. As we've yet to work on mixing OpenMP and streaming functions (and determine how they should interact with OpenMP's runtime), we think it is best to disallow this for now.
1 parent 4310245 commit a7f4044

File tree

6 files changed

+125
-14
lines changed

6 files changed

+125
-14
lines changed

clang/include/clang/AST/Decl.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5139,6 +5139,12 @@ static constexpr StringRef getOpenMPVariantManglingSeparatorStr() {
51395139
bool IsArmStreamingFunction(const FunctionDecl *FD,
51405140
bool IncludeLocallyStreaming);
51415141

5142+
/// Returns whether the given FunctionDecl has Arm ZA state.
5143+
bool hasArmZAState(const FunctionDecl *FD);
5144+
5145+
/// Returns whether the given FunctionDecl has Arm ZT0 state.
5146+
bool hasArmZT0State(const FunctionDecl *FD);
5147+
51425148
} // namespace clang
51435149

51445150
#endif // LLVM_CLANG_AST_DECL_H

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3870,6 +3870,9 @@ def err_sme_definition_using_za_in_non_sme_target : Error<
38703870
"function using ZA state requires 'sme'">;
38713871
def err_sme_definition_using_zt0_in_non_sme2_target : Error<
38723872
"function using ZT0 state requires 'sme2'">;
3873+
def err_sme_openmp_captured_region : Error<
3874+
"OpenMP captured regions are not yet supported in "
3875+
"%select{streaming functions|functions with ZA state|functions with ZT0 state}0">;
38733876
def warn_sme_streaming_pass_return_vl_to_non_streaming : Warning<
38743877
"%select{returning|passing}0 a VL-dependent argument %select{from|to}0 a function with a different"
38753878
" streaming-mode is undefined behaviour when the streaming and non-streaming vector lengths are different at runtime">,

clang/lib/AST/Decl.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5845,3 +5845,17 @@ bool clang::IsArmStreamingFunction(const FunctionDecl *FD,
58455845

58465846
return false;
58475847
}
5848+
5849+
// Reports whether FD has Arm ZA state. This is the case when FD's function
// prototype records a ZA state other than ARM_None in its AArch64 SME
// attributes, or when FD carries an ArmNewAttr for which isNewZA() is true.
bool clang::hasArmZAState(const FunctionDecl *FD) {
5850+
// T may be null (FD's type need not be a FunctionProtoType), hence the guard.
const auto *T = FD->getType()->getAs<FunctionProtoType>();
5851+
return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
5852+
FunctionType::ARM_None) ||
5853+
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
5854+
}
5855+
5856+
// Reports whether FD has Arm ZT0 state. This is the case when FD's function
// prototype records a ZT0 state other than ARM_None in its AArch64 SME
// attributes, or when FD carries an ArmNewAttr for which isNewZT0() is true.
bool clang::hasArmZT0State(const FunctionDecl *FD) {
5857+
// T may be null (FD's type need not be a FunctionProtoType), hence the guard.
const auto *T = FD->getType()->getAs<FunctionProtoType>();
5858+
return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
5859+
FunctionType::ARM_None) ||
5860+
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
5861+
}

clang/lib/Sema/SemaARM.cpp

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -624,20 +624,6 @@ static bool checkArmStreamingBuiltin(Sema &S, CallExpr *TheCall,
624624
return true;
625625
}
626626

627-
static bool hasArmZAState(const FunctionDecl *FD) {
628-
const auto *T = FD->getType()->getAs<FunctionProtoType>();
629-
return (T && FunctionType::getArmZAState(T->getAArch64SMEAttributes()) !=
630-
FunctionType::ARM_None) ||
631-
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZA());
632-
}
633-
634-
static bool hasArmZT0State(const FunctionDecl *FD) {
635-
const auto *T = FD->getType()->getAs<FunctionProtoType>();
636-
return (T && FunctionType::getArmZT0State(T->getAArch64SMEAttributes()) !=
637-
FunctionType::ARM_None) ||
638-
(FD->hasAttr<ArmNewAttr>() && FD->getAttr<ArmNewAttr>()->isNewZT0());
639-
}
640-
641627
static ArmSMEState getSMEState(unsigned BuiltinID) {
642628
switch (BuiltinID) {
643629
default:

clang/lib/Sema/SemaStmt.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4568,9 +4568,27 @@ buildCapturedStmtCaptureList(Sema &S, CapturedRegionScopeInfo *RSI,
45684568
return false;
45694569
}
45704570

4571+
// Decides whether a captured region of the given Kind must be rejected because
// it is an OpenMP captured region inside an Arm SME function.
//
// Returns an empty optional when OpenMP is not enabled, when Kind is not
// CR_OpenMP, when there is no current function declaration (lambdas are
// allowed as the current function), or when that function is neither
// streaming nor has ZA/ZT0 state. Otherwise returns an index that callers
// stream into err_sme_openmp_captured_region to pick the message variant:
//   0 - streaming function (locally streaming functions included)
//   1 - function with ZA state
//   2 - function with ZT0 state
// The checks run in that order, so a function matching several categories is
// reported under the first one that applies.
static std::optional<int>
4572+
isOpenMPCapturedRegionInArmSMEFunction(Sema const &S, CapturedRegionKind Kind) {
4573+
if (!S.getLangOpts().OpenMP || Kind != CR_OpenMP)
4574+
return {};
4575+
if (const FunctionDecl *FD = S.getCurFunctionDecl(/*AllowLambda=*/true)) {
4576+
if (IsArmStreamingFunction(FD, /*IncludeLocallyStreaming=*/true))
4577+
return /* in streaming functions */ 0;
4578+
if (hasArmZAState(FD))
4579+
return /* in functions with ZA state */ 1;
4580+
if (hasArmZT0State(FD))
4581+
return /* in functions with ZT0 state */ 2;
4582+
}
4583+
return {};
4584+
}
4585+
45714586
void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
45724587
CapturedRegionKind Kind,
45734588
unsigned NumParams) {
4589+
if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
4590+
Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;
4591+
45744592
CapturedDecl *CD = nullptr;
45754593
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, NumParams);
45764594

@@ -4602,6 +4620,9 @@ void Sema::ActOnCapturedRegionStart(SourceLocation Loc, Scope *CurScope,
46024620
CapturedRegionKind Kind,
46034621
ArrayRef<CapturedParamNameType> Params,
46044622
unsigned OpenMPCaptureLevel) {
4623+
if (auto ErrorIndex = isOpenMPCapturedRegionInArmSMEFunction(*this, Kind))
4624+
Diag(Loc, diag::err_sme_openmp_captured_region) << *ErrorIndex;
4625+
46054626
CapturedDecl *CD = nullptr;
46064627
RecordDecl *RD = CreateCapturedStmtRecordDecl(CD, Loc, Params.size());
46074628

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify %s
2+
// RUN: %clang_cc1 -triple aarch64-none-linux-gnu -target-feature +sme -target-feature +sme2 -fopenmp -fsyntax-only -verify=expected-cpp -x c++ %s
3+
4+
int compute(int);
5+
6+
void streaming_openmp_captured_region(int * out) __arm_streaming {
7+
// expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}}
8+
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}}
9+
#pragma omp parallel for num_threads(32)
10+
for (int ci = 0; ci < 8; ci++) {
11+
out[ci] = compute(ci);
12+
}
13+
}
14+
15+
__arm_locally_streaming void locally_streaming_openmp_captured_region(int * out) {
16+
// expected-error@+2 {{OpenMP captured regions are not yet supported in streaming functions}}
17+
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in streaming functions}}
18+
#pragma omp parallel for num_threads(32)
19+
for (int ci = 0; ci < 8; ci++) {
20+
out[ci] = compute(ci);
21+
}
22+
}
23+
24+
void za_state_captured_region(int * out) __arm_inout("za") {
25+
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
26+
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
27+
#pragma omp parallel for num_threads(32)
28+
for (int ci = 0; ci < 8; ci++) {
29+
out[ci] = compute(ci);
30+
}
31+
}
32+
33+
__arm_new("za") void new_za_state_captured_region(int * out) {
34+
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZA state}}
35+
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZA state}}
36+
#pragma omp parallel for num_threads(32)
37+
for (int ci = 0; ci < 8; ci++) {
38+
out[ci] = compute(ci);
39+
}
40+
}
41+
42+
void zt0_state_openmp_captured_region(int * out) __arm_inout("zt0") {
43+
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
44+
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
45+
#pragma omp parallel for num_threads(32)
46+
for (int ci = 0; ci < 8; ci++) {
47+
out[ci] = compute(ci);
48+
}
49+
}
50+
51+
__arm_new("zt0") void new_zt0_state_openmp_captured_region(int * out) {
52+
// expected-error@+2 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
53+
// expected-cpp-error@+1 {{OpenMP captured regions are not yet supported in functions with ZT0 state}}
54+
#pragma omp parallel for num_threads(32)
55+
for (int ci = 0; ci < 8; ci++) {
56+
out[ci] = compute(ci);
57+
}
58+
}
59+
60+
/// OpenMP directives that don't create a captured region are okay:
61+
62+
void streaming_function_openmp(int * out) __arm_streaming __arm_inout("za", "zt0") {
63+
#pragma omp unroll full
64+
for (int ci = 0; ci < 8; ci++) {
65+
out[ci] = compute(ci);
66+
}
67+
}
68+
69+
__arm_locally_streaming void locally_streaming_openmp(int * out) __arm_inout("za", "zt0") {
70+
#pragma omp unroll full
71+
for (int ci = 0; ci < 8; ci++) {
72+
out[ci] = compute(ci);
73+
}
74+
}
75+
76+
__arm_new("za", "zt0") void arm_new_openmp(int * out) {
77+
#pragma omp unroll full
78+
for (int ci = 0; ci < 8; ci++) {
79+
out[ci] = compute(ci);
80+
}
81+
}

0 commit comments

Comments
 (0)