Skip to content

[clang] Improve diagnostics for vector builtins #125673

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 18, 2025

Conversation

frasercrmck
Copy link
Contributor

This commit improves the diagnostics for vector (elementwise) builtins in a couple of ways.

It primarily provides more precise type-checking diagnostics for builtins with specific type requirements. Previously many builtins were receiving a catch-all diagnostic suggesting types which aren't valid.

It also makes consistent the type-checking behaviour between various binary and ternary builtins. The binary builtins would check for mismatched argument types before specific type requirements, whereas ternary builtins would perform the checks in the reverse order. The binary builtins now behave as the ternary ones do.

@llvmbot llvmbot added clang Clang issues not falling into any other category backend:RISC-V clang:frontend Language frontend issues, e.g. anything involving "Sema" HLSL HLSL Language Support labels Feb 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-clang

@llvm/pr-subscribers-backend-risc-v

Author: Fraser Cormack (frasercrmck)

Changes

This commit improves the diagnostics for vector (elementwise) builtins in a couple of ways.

It primarily provides more precise type-checking diagnostics for builtins with specific type requirements. Previously many builtins were receiving a catch-all diagnostic suggesting types which aren't valid.

It also makes consistent the type-checking behaviour between various binary and ternary builtins. The binary builtins would check for mismatched argument types before specific type requirements, whereas ternary builtins would perform the checks in the reverse order. The binary builtins now behave as the ternary ones do.


Patch is 42.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125673.diff

14 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2-1)
  • (modified) clang/include/clang/Sema/Sema.h (+21-6)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+72-113)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+8-4)
  • (modified) clang/test/Sema/aarch64-sve-vector-exp-ops.c (+2-2)
  • (modified) clang/test/Sema/aarch64-sve-vector-log-ops.c (+3-3)
  • (modified) clang/test/Sema/aarch64-sve-vector-trig-ops.c (+9-9)
  • (modified) clang/test/Sema/builtins-elementwise-math.c (+35-32)
  • (modified) clang/test/Sema/riscv-rvv-vector-exp-ops.c (+2-2)
  • (modified) clang/test/Sema/riscv-rvv-vector-log-ops.c (+3-3)
  • (modified) clang/test/Sema/riscv-rvv-vector-trig-ops.c (+9-9)
  • (modified) clang/test/SemaHLSL/BuiltIns/exp-errors.hlsl (+1-1)
  • (modified) clang/test/SemaHLSL/BuiltIns/reversebits-errors.hlsl (+1-1)
  • (modified) clang/test/SemaHLSL/BuiltIns/round-errors.hlsl (+1-1)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 00a94eb7a303671..e43d7293a1b45a5 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12403,7 +12403,8 @@ def err_builtin_invalid_arg_type: Error <
   "a vector of integers|"
   "an unsigned integer|"
   "an 'int'|"
-  "a vector of floating points}1 (was %2)">;
+  "a vector of floating points|"
+  "an integer or vector of integers}1 (was %2)">;
 
 def err_builtin_matrix_disabled: Error<
   "matrix types extension is disabled. Pass -fenable-matrix to enable it">;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 472a0e25adc9752..5e7a2d70df2dfb8 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2331,9 +2331,18 @@ class Sema final : public SemaBase {
   bool CheckFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall,
                          const FunctionProtoType *Proto);
 
+  enum class EltwiseBuiltinArgTyRestriction {
+    None,
+    FloatTy,
+    IntegerTy,
+    SignedIntOrFloatTy,
+  };
+
   /// \param FPOnly restricts the arguments to floating-point types.
-  std::optional<QualType> BuiltinVectorMath(CallExpr *TheCall,
-                                            bool FPOnly = false);
+  std::optional<QualType>
+  BuiltinVectorMath(CallExpr *TheCall,
+                    EltwiseBuiltinArgTyRestriction ArgTyRestr =
+                        EltwiseBuiltinArgTyRestriction::None);
   bool BuiltinVectorToScalarMath(CallExpr *TheCall);
 
   void checkLifetimeCaptureBy(FunctionDecl *FDecl, bool IsMemberFunction,
@@ -2418,9 +2427,13 @@ class Sema final : public SemaBase {
                                bool *ICContext = nullptr,
                                bool IsListInit = false);
 
-  bool BuiltinElementwiseTernaryMath(CallExpr *TheCall,
-                                     bool CheckForFloatArgs = true);
-  bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
+  bool
+  BuiltinElementwiseTernaryMath(CallExpr *TheCall,
+                                EltwiseBuiltinArgTyRestriction ArgTyRestr =
+                                    EltwiseBuiltinArgTyRestriction::FloatTy);
+  bool PrepareBuiltinElementwiseMathOneArgCall(
+      CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr =
+                             EltwiseBuiltinArgTyRestriction::None);
 
 private:
   void CheckArrayAccess(const Expr *BaseExpr, const Expr *IndexExpr,
@@ -2529,7 +2542,9 @@ class Sema final : public SemaBase {
                                  AtomicExpr::AtomicOp Op);
 
   /// \param FPOnly restricts the arguments to floating-point types.
-  bool BuiltinElementwiseMath(CallExpr *TheCall, bool FPOnly = false);
+  bool BuiltinElementwiseMath(CallExpr *TheCall,
+                              EltwiseBuiltinArgTyRestriction ArgTyRestr =
+                                  EltwiseBuiltinArgTyRestriction::None);
   bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
 
   bool BuiltinNonDeterministicValue(CallExpr *TheCall);
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 61b2c8cf1cad72c..6dacce85d1719ca 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1968,26 +1968,40 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
 // Check if \p Ty is a valid type for the elementwise math builtins. If it is
 // not a valid type, emit an error message and return true. Otherwise return
 // false.
-static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
-                                        QualType ArgTy, int ArgIndex) {
-  if (!ArgTy->getAs<VectorType>() &&
-      !ConstantMatrixType::isValidElementType(ArgTy)) {
-    return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
-  }
-
-  return false;
-}
-
-static bool checkFPMathBuiltinElementType(Sema &S, SourceLocation Loc,
-                                          QualType ArgTy, int ArgIndex) {
+static bool
+checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
+                            Sema::EltwiseBuiltinArgTyRestriction ArgTyRestr,
+                            int ArgOrdinal) {
   QualType EltTy = ArgTy;
   if (auto *VecTy = EltTy->getAs<VectorType>())
     EltTy = VecTy->getElementType();
 
-  if (!EltTy->isRealFloatingType()) {
-    return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgIndex << /* vector or float ty*/ 5 << ArgTy;
+  switch (ArgTyRestr) {
+  case Sema::EltwiseBuiltinArgTyRestriction::None:
+    if (!ArgTy->getAs<VectorType>() &&
+        !ConstantMatrixType::isValidElementType(ArgTy)) {
+      return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* vector, integer or float ty*/ 0 << ArgTy;
+    }
+    break;
+  case Sema::EltwiseBuiltinArgTyRestriction::FloatTy:
+    if (!EltTy->isRealFloatingType()) {
+      return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* vector or float ty*/ 5 << ArgTy;
+    }
+    break;
+  case Sema::EltwiseBuiltinArgTyRestriction::IntegerTy:
+    if (!EltTy->isIntegerType()) {
+      return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* vector or int ty*/ 10 << ArgTy;
+    }
+    break;
+  case Sema::EltwiseBuiltinArgTyRestriction::SignedIntOrFloatTy:
+    if (EltTy->isUnsignedIntegerType()) {
+      return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << 1 << /* signed integer or float ty*/ 3 << ArgTy;
+    }
+    break;
   }
 
   return false;
@@ -2694,23 +2708,11 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
 
   // __builtin_elementwise_abs restricts the element type to signed integers or
   // floating point types only.
-  case Builtin::BI__builtin_elementwise_abs: {
-    if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
+  case Builtin::BI__builtin_elementwise_abs:
+    if (PrepareBuiltinElementwiseMathOneArgCall(
+            TheCall, EltwiseBuiltinArgTyRestriction::SignedIntOrFloatTy))
       return ExprError();
-
-    QualType ArgTy = TheCall->getArg(0)->getType();
-    QualType EltTy = ArgTy;
-
-    if (auto *VecTy = EltTy->getAs<VectorType>())
-      EltTy = VecTy->getElementType();
-    if (EltTy->isUnsignedIntegerType()) {
-      Diag(TheCall->getArg(0)->getBeginLoc(),
-           diag::err_builtin_invalid_arg_type)
-          << 1 << /* signed integer or float ty*/ 3 << ArgTy;
-      return ExprError();
-    }
     break;
-  }
 
   // These builtins restrict the element type to floating point
   // types only.
@@ -2736,21 +2738,15 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
   case Builtin::BI__builtin_elementwise_tan:
   case Builtin::BI__builtin_elementwise_tanh:
   case Builtin::BI__builtin_elementwise_trunc:
-  case Builtin::BI__builtin_elementwise_canonicalize: {
-    if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
-      return ExprError();
-
-    QualType ArgTy = TheCall->getArg(0)->getType();
-    if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
-                                      ArgTy, 1))
+  case Builtin::BI__builtin_elementwise_canonicalize:
+    if (PrepareBuiltinElementwiseMathOneArgCall(
+            TheCall, EltwiseBuiltinArgTyRestriction::FloatTy))
       return ExprError();
     break;
-  }
-  case Builtin::BI__builtin_elementwise_fma: {
+  case Builtin::BI__builtin_elementwise_fma:
     if (BuiltinElementwiseTernaryMath(TheCall))
       return ExprError();
     break;
-  }
 
   // These builtins restrict the element type to floating point
   // types only, and take in two arguments.
@@ -2758,59 +2754,30 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
   case Builtin::BI__builtin_elementwise_maximum:
   case Builtin::BI__builtin_elementwise_atan2:
   case Builtin::BI__builtin_elementwise_fmod:
-  case Builtin::BI__builtin_elementwise_pow: {
-    if (BuiltinElementwiseMath(TheCall, /*FPOnly=*/true))
+  case Builtin::BI__builtin_elementwise_pow:
+    if (BuiltinElementwiseMath(TheCall,
+                               EltwiseBuiltinArgTyRestriction::FloatTy))
       return ExprError();
     break;
-  }
-
   // These builtins restrict the element type to integer
   // types only.
   case Builtin::BI__builtin_elementwise_add_sat:
-  case Builtin::BI__builtin_elementwise_sub_sat: {
-    if (BuiltinElementwiseMath(TheCall))
-      return ExprError();
-
-    const Expr *Arg = TheCall->getArg(0);
-    QualType ArgTy = Arg->getType();
-    QualType EltTy = ArgTy;
-
-    if (auto *VecTy = EltTy->getAs<VectorType>())
-      EltTy = VecTy->getElementType();
-
-    if (!EltTy->isIntegerType()) {
-      Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 1 << /* integer ty */ 6 << ArgTy;
+  case Builtin::BI__builtin_elementwise_sub_sat:
+    if (BuiltinElementwiseMath(TheCall,
+                               EltwiseBuiltinArgTyRestriction::IntegerTy))
       return ExprError();
-    }
     break;
-  }
-
   case Builtin::BI__builtin_elementwise_min:
   case Builtin::BI__builtin_elementwise_max:
     if (BuiltinElementwiseMath(TheCall))
       return ExprError();
     break;
   case Builtin::BI__builtin_elementwise_popcount:
-  case Builtin::BI__builtin_elementwise_bitreverse: {
-    if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
-      return ExprError();
-
-    const Expr *Arg = TheCall->getArg(0);
-    QualType ArgTy = Arg->getType();
-    QualType EltTy = ArgTy;
-
-    if (auto *VecTy = EltTy->getAs<VectorType>())
-      EltTy = VecTy->getElementType();
-
-    if (!EltTy->isIntegerType()) {
-      Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 1 << /* integer ty */ 6 << ArgTy;
+  case Builtin::BI__builtin_elementwise_bitreverse:
+    if (PrepareBuiltinElementwiseMathOneArgCall(
+            TheCall, EltwiseBuiltinArgTyRestriction::IntegerTy))
       return ExprError();
-    }
     break;
-  }
-
   case Builtin::BI__builtin_elementwise_copysign: {
     if (checkArgCount(TheCall, 2))
       return ExprError();
@@ -2822,10 +2789,12 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
 
     QualType MagnitudeTy = Magnitude.get()->getType();
     QualType SignTy = Sign.get()->getType();
-    if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
-                                      MagnitudeTy, 1) ||
-        checkFPMathBuiltinElementType(*this, TheCall->getArg(1)->getBeginLoc(),
-                                      SignTy, 2)) {
+    if (checkMathBuiltinElementType(
+            *this, TheCall->getArg(0)->getBeginLoc(), MagnitudeTy,
+            EltwiseBuiltinArgTyRestriction::FloatTy, 1) ||
+        checkMathBuiltinElementType(
+            *this, TheCall->getArg(1)->getBeginLoc(), SignTy,
+            EltwiseBuiltinArgTyRestriction::FloatTy, 2)) {
       return ExprError();
     }
 
@@ -14661,7 +14630,8 @@ static ExprResult BuiltinVectorMathConversions(Sema &S, Expr *E) {
   return S.UsualUnaryFPConversions(Res.get());
 }
 
-bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
+bool Sema::PrepareBuiltinElementwiseMathOneArgCall(
+    CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr) {
   if (checkArgCount(TheCall, 1))
     return true;
 
@@ -14672,15 +14642,17 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
   TheCall->setArg(0, A.get());
   QualType TyA = A.get()->getType();
 
-  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
+  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA,
+                                  ArgTyRestr, 1))
     return true;
 
   TheCall->setType(TyA);
   return false;
 }
 
-bool Sema::BuiltinElementwiseMath(CallExpr *TheCall, bool FPOnly) {
-  if (auto Res = BuiltinVectorMath(TheCall, FPOnly); Res.has_value()) {
+bool Sema::BuiltinElementwiseMath(CallExpr *TheCall,
+                                  EltwiseBuiltinArgTyRestriction ArgTyRestr) {
+  if (auto Res = BuiltinVectorMath(TheCall, ArgTyRestr); Res.has_value()) {
     TheCall->setType(*Res);
     return false;
   }
@@ -14713,8 +14685,9 @@ static bool checkBuiltinVectorMathMixedEnums(Sema &S, Expr *LHS, Expr *RHS,
   return false;
 }
 
-std::optional<QualType> Sema::BuiltinVectorMath(CallExpr *TheCall,
-                                                bool FPOnly) {
+std::optional<QualType>
+Sema::BuiltinVectorMath(CallExpr *TheCall,
+                        EltwiseBuiltinArgTyRestriction ArgTyRestr) {
   if (checkArgCount(TheCall, 2))
     return std::nullopt;
 
@@ -14735,26 +14708,21 @@ std::optional<QualType> Sema::BuiltinVectorMath(CallExpr *TheCall,
   QualType TyA = Args[0]->getType();
   QualType TyB = Args[1]->getType();
 
+  if (checkMathBuiltinElementType(*this, LocA, TyA, ArgTyRestr, 1))
+    return std::nullopt;
+
   if (TyA.getCanonicalType() != TyB.getCanonicalType()) {
     Diag(LocA, diag::err_typecheck_call_different_arg_types) << TyA << TyB;
     return std::nullopt;
   }
 
-  if (FPOnly) {
-    if (checkFPMathBuiltinElementType(*this, LocA, TyA, 1))
-      return std::nullopt;
-  } else {
-    if (checkMathBuiltinElementType(*this, LocA, TyA, 1))
-      return std::nullopt;
-  }
-
   TheCall->setArg(0, Args[0]);
   TheCall->setArg(1, Args[1]);
   return TyA;
 }
 
-bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
-                                         bool CheckForFloatArgs) {
+bool Sema::BuiltinElementwiseTernaryMath(
+    CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr) {
   if (checkArgCount(TheCall, 3))
     return true;
 
@@ -14774,20 +14742,11 @@ bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
     Args[I] = Converted.get();
   }
 
-  if (CheckForFloatArgs) {
-    int ArgOrdinal = 1;
-    for (Expr *Arg : Args) {
-      if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(),
-                                        Arg->getType(), ArgOrdinal++))
-        return true;
-    }
-  } else {
-    int ArgOrdinal = 1;
-    for (Expr *Arg : Args) {
-      if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
-                                      ArgOrdinal++))
-        return true;
-    }
+  int ArgOrdinal = 1;
+  for (Expr *Arg : Args) {
+    if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
+                                    ArgTyRestr, ArgOrdinal++))
+      return true;
   }
 
   for (int I = 1; I < 3; ++I) {
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index d748c10455289b9..b45879314727048 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2259,8 +2259,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
       return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
-            TheCall, /*CheckForFloatArgs*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()))
+            TheCall, /*ArgTyRestr*/
+            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
+                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
+                : Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
@@ -2393,8 +2395,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
       return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
-            TheCall, /*CheckForFloatArgs*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()))
+            TheCall, /*ArgTyRestr*/
+            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
+                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
+                : Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
diff --git a/clang/test/Sema/aarch64-sve-vector-exp-ops.c b/clang/test/Sema/aarch64-sve-vector-exp-ops.c
index f2bba8c7eeb196d..4b411babbc3471e 100644
--- a/clang/test/Sema/aarch64-sve-vector-exp-ops.c
+++ b/clang/test/Sema/aarch64-sve-vector-exp-ops.c
@@ -7,11 +7,11 @@
 svfloat32_t test_exp_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_exp(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_exp2_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_exp2(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
diff --git a/clang/test/Sema/aarch64-sve-vector-log-ops.c b/clang/test/Sema/aarch64-sve-vector-log-ops.c
index ef16e8581844d7f..bc81323b560c9c4 100644
--- a/clang/test/Sema/aarch64-sve-vector-log-ops.c
+++ b/clang/test/Sema/aarch64-sve-vector-log-ops.c
@@ -7,17 +7,17 @@
 svfloat32_t test_log_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_log(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_log10_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_log10(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_log2_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_log2(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
diff --git a/clang/test/Sema/aarch64-sve-vector-trig-ops.c b/clang/test/Sema/aarch64-sve-vector-trig-ops.c
index 3fe6834be2e0b7f..46df63cbba42bf7 100644
--- a/clang/test/Sema/aarch64-sve-vector-trig-ops.c
+++ b/clang/test/Sema/aarch64-sve-vector-trig-ops.c
@@ -7,19 +7,19 @@
 svfloat32_t test_asin_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_asin(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_acos_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_acos(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_atan_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_atan(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_atan2_vv_i8mf8(svfloat32_t v) {
@@ -31,35 +31,35 @@ svfloat32_t test_atan2_vv_i8mf8(svfloat32_t v) {
 svfloat32_t test_sin_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_sin(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_cos_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_cos(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_tan_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_tan(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_sinh_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_sinh(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_cosh_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_cosh(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Feb 4, 2025

@llvm/pr-subscribers-hlsl

Author: Fraser Cormack (frasercrmck)

Changes

This commit improves the diagnostics for vector (elementwise) builtins in a couple of ways.

It primarily provides more precise type-checking diagnostics for builtins with specific type requirements. Previously many builtins were receiving a catch-all diagnostic suggesting types which aren't valid.

It also makes consistent the type-checking behaviour between various binary and ternary builtins. The binary builtins would check for mismatched argument types before specific type requirements, whereas ternary builtins would perform the checks in the reverse order. The binary builtins now behave as the ternary ones do.


Patch is 42.59 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/125673.diff

14 Files Affected:

  • (modified) clang/include/clang/Basic/DiagnosticSemaKinds.td (+2-1)
  • (modified) clang/include/clang/Sema/Sema.h (+21-6)
  • (modified) clang/lib/Sema/SemaChecking.cpp (+72-113)
  • (modified) clang/lib/Sema/SemaHLSL.cpp (+8-4)
  • (modified) clang/test/Sema/aarch64-sve-vector-exp-ops.c (+2-2)
  • (modified) clang/test/Sema/aarch64-sve-vector-log-ops.c (+3-3)
  • (modified) clang/test/Sema/aarch64-sve-vector-trig-ops.c (+9-9)
  • (modified) clang/test/Sema/builtins-elementwise-math.c (+35-32)
  • (modified) clang/test/Sema/riscv-rvv-vector-exp-ops.c (+2-2)
  • (modified) clang/test/Sema/riscv-rvv-vector-log-ops.c (+3-3)
  • (modified) clang/test/Sema/riscv-rvv-vector-trig-ops.c (+9-9)
  • (modified) clang/test/SemaHLSL/BuiltIns/exp-errors.hlsl (+1-1)
  • (modified) clang/test/SemaHLSL/BuiltIns/reversebits-errors.hlsl (+1-1)
  • (modified) clang/test/SemaHLSL/BuiltIns/round-errors.hlsl (+1-1)
diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td
index 00a94eb7a303671..e43d7293a1b45a5 100644
--- a/clang/include/clang/Basic/DiagnosticSemaKinds.td
+++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td
@@ -12403,7 +12403,8 @@ def err_builtin_invalid_arg_type: Error <
   "a vector of integers|"
   "an unsigned integer|"
   "an 'int'|"
-  "a vector of floating points}1 (was %2)">;
+  "a vector of floating points|"
+  "an integer or vector of integers}1 (was %2)">;
 
 def err_builtin_matrix_disabled: Error<
   "matrix types extension is disabled. Pass -fenable-matrix to enable it">;
diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h
index 472a0e25adc9752..5e7a2d70df2dfb8 100644
--- a/clang/include/clang/Sema/Sema.h
+++ b/clang/include/clang/Sema/Sema.h
@@ -2331,9 +2331,18 @@ class Sema final : public SemaBase {
   bool CheckFunctionCall(FunctionDecl *FDecl, CallExpr *TheCall,
                          const FunctionProtoType *Proto);
 
+  enum class EltwiseBuiltinArgTyRestriction {
+    None,
+    FloatTy,
+    IntegerTy,
+    SignedIntOrFloatTy,
+  };
+
   /// \param FPOnly restricts the arguments to floating-point types.
-  std::optional<QualType> BuiltinVectorMath(CallExpr *TheCall,
-                                            bool FPOnly = false);
+  std::optional<QualType>
+  BuiltinVectorMath(CallExpr *TheCall,
+                    EltwiseBuiltinArgTyRestriction ArgTyRestr =
+                        EltwiseBuiltinArgTyRestriction::None);
   bool BuiltinVectorToScalarMath(CallExpr *TheCall);
 
   void checkLifetimeCaptureBy(FunctionDecl *FDecl, bool IsMemberFunction,
@@ -2418,9 +2427,13 @@ class Sema final : public SemaBase {
                                bool *ICContext = nullptr,
                                bool IsListInit = false);
 
-  bool BuiltinElementwiseTernaryMath(CallExpr *TheCall,
-                                     bool CheckForFloatArgs = true);
-  bool PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall);
+  bool
+  BuiltinElementwiseTernaryMath(CallExpr *TheCall,
+                                EltwiseBuiltinArgTyRestriction ArgTyRestr =
+                                    EltwiseBuiltinArgTyRestriction::FloatTy);
+  bool PrepareBuiltinElementwiseMathOneArgCall(
+      CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr =
+                             EltwiseBuiltinArgTyRestriction::None);
 
 private:
   void CheckArrayAccess(const Expr *BaseExpr, const Expr *IndexExpr,
@@ -2529,7 +2542,9 @@ class Sema final : public SemaBase {
                                  AtomicExpr::AtomicOp Op);
 
   /// \param FPOnly restricts the arguments to floating-point types.
-  bool BuiltinElementwiseMath(CallExpr *TheCall, bool FPOnly = false);
+  bool BuiltinElementwiseMath(CallExpr *TheCall,
+                              EltwiseBuiltinArgTyRestriction ArgTyRestr =
+                                  EltwiseBuiltinArgTyRestriction::None);
   bool PrepareBuiltinReduceMathOneArgCall(CallExpr *TheCall);
 
   bool BuiltinNonDeterministicValue(CallExpr *TheCall);
diff --git a/clang/lib/Sema/SemaChecking.cpp b/clang/lib/Sema/SemaChecking.cpp
index 61b2c8cf1cad72c..6dacce85d1719ca 100644
--- a/clang/lib/Sema/SemaChecking.cpp
+++ b/clang/lib/Sema/SemaChecking.cpp
@@ -1968,26 +1968,40 @@ bool Sema::CheckTSBuiltinFunctionCall(const TargetInfo &TI, unsigned BuiltinID,
 // Check if \p Ty is a valid type for the elementwise math builtins. If it is
 // not a valid type, emit an error message and return true. Otherwise return
 // false.
-static bool checkMathBuiltinElementType(Sema &S, SourceLocation Loc,
-                                        QualType ArgTy, int ArgIndex) {
-  if (!ArgTy->getAs<VectorType>() &&
-      !ConstantMatrixType::isValidElementType(ArgTy)) {
-    return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgIndex << /* vector, integer or float ty*/ 0 << ArgTy;
-  }
-
-  return false;
-}
-
-static bool checkFPMathBuiltinElementType(Sema &S, SourceLocation Loc,
-                                          QualType ArgTy, int ArgIndex) {
+static bool
+checkMathBuiltinElementType(Sema &S, SourceLocation Loc, QualType ArgTy,
+                            Sema::EltwiseBuiltinArgTyRestriction ArgTyRestr,
+                            int ArgOrdinal) {
   QualType EltTy = ArgTy;
   if (auto *VecTy = EltTy->getAs<VectorType>())
     EltTy = VecTy->getElementType();
 
-  if (!EltTy->isRealFloatingType()) {
-    return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
-           << ArgIndex << /* vector or float ty*/ 5 << ArgTy;
+  switch (ArgTyRestr) {
+  case Sema::EltwiseBuiltinArgTyRestriction::None:
+    if (!ArgTy->getAs<VectorType>() &&
+        !ConstantMatrixType::isValidElementType(ArgTy)) {
+      return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* vector, integer or float ty*/ 0 << ArgTy;
+    }
+    break;
+  case Sema::EltwiseBuiltinArgTyRestriction::FloatTy:
+    if (!EltTy->isRealFloatingType()) {
+      return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* vector or float ty*/ 5 << ArgTy;
+    }
+    break;
+  case Sema::EltwiseBuiltinArgTyRestriction::IntegerTy:
+    if (!EltTy->isIntegerType()) {
+      return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << ArgOrdinal << /* vector or int ty*/ 10 << ArgTy;
+    }
+    break;
+  case Sema::EltwiseBuiltinArgTyRestriction::SignedIntOrFloatTy:
+    if (EltTy->isUnsignedIntegerType()) {
+      return S.Diag(Loc, diag::err_builtin_invalid_arg_type)
+             << 1 << /* signed integer or float ty*/ 3 << ArgTy;
+    }
+    break;
   }
 
   return false;
@@ -2694,23 +2708,11 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
 
   // __builtin_elementwise_abs restricts the element type to signed integers or
   // floating point types only.
-  case Builtin::BI__builtin_elementwise_abs: {
-    if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
+  case Builtin::BI__builtin_elementwise_abs:
+    if (PrepareBuiltinElementwiseMathOneArgCall(
+            TheCall, EltwiseBuiltinArgTyRestriction::SignedIntOrFloatTy))
       return ExprError();
-
-    QualType ArgTy = TheCall->getArg(0)->getType();
-    QualType EltTy = ArgTy;
-
-    if (auto *VecTy = EltTy->getAs<VectorType>())
-      EltTy = VecTy->getElementType();
-    if (EltTy->isUnsignedIntegerType()) {
-      Diag(TheCall->getArg(0)->getBeginLoc(),
-           diag::err_builtin_invalid_arg_type)
-          << 1 << /* signed integer or float ty*/ 3 << ArgTy;
-      return ExprError();
-    }
     break;
-  }
 
   // These builtins restrict the element type to floating point
   // types only.
@@ -2736,21 +2738,15 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
   case Builtin::BI__builtin_elementwise_tan:
   case Builtin::BI__builtin_elementwise_tanh:
   case Builtin::BI__builtin_elementwise_trunc:
-  case Builtin::BI__builtin_elementwise_canonicalize: {
-    if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
-      return ExprError();
-
-    QualType ArgTy = TheCall->getArg(0)->getType();
-    if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
-                                      ArgTy, 1))
+  case Builtin::BI__builtin_elementwise_canonicalize:
+    if (PrepareBuiltinElementwiseMathOneArgCall(
+            TheCall, EltwiseBuiltinArgTyRestriction::FloatTy))
       return ExprError();
     break;
-  }
-  case Builtin::BI__builtin_elementwise_fma: {
+  case Builtin::BI__builtin_elementwise_fma:
     if (BuiltinElementwiseTernaryMath(TheCall))
       return ExprError();
     break;
-  }
 
   // These builtins restrict the element type to floating point
   // types only, and take in two arguments.
@@ -2758,59 +2754,30 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
   case Builtin::BI__builtin_elementwise_maximum:
   case Builtin::BI__builtin_elementwise_atan2:
   case Builtin::BI__builtin_elementwise_fmod:
-  case Builtin::BI__builtin_elementwise_pow: {
-    if (BuiltinElementwiseMath(TheCall, /*FPOnly=*/true))
+  case Builtin::BI__builtin_elementwise_pow:
+    if (BuiltinElementwiseMath(TheCall,
+                               EltwiseBuiltinArgTyRestriction::FloatTy))
       return ExprError();
     break;
-  }
-
   // These builtins restrict the element type to integer
   // types only.
   case Builtin::BI__builtin_elementwise_add_sat:
-  case Builtin::BI__builtin_elementwise_sub_sat: {
-    if (BuiltinElementwiseMath(TheCall))
-      return ExprError();
-
-    const Expr *Arg = TheCall->getArg(0);
-    QualType ArgTy = Arg->getType();
-    QualType EltTy = ArgTy;
-
-    if (auto *VecTy = EltTy->getAs<VectorType>())
-      EltTy = VecTy->getElementType();
-
-    if (!EltTy->isIntegerType()) {
-      Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 1 << /* integer ty */ 6 << ArgTy;
+  case Builtin::BI__builtin_elementwise_sub_sat:
+    if (BuiltinElementwiseMath(TheCall,
+                               EltwiseBuiltinArgTyRestriction::IntegerTy))
       return ExprError();
-    }
     break;
-  }
-
   case Builtin::BI__builtin_elementwise_min:
   case Builtin::BI__builtin_elementwise_max:
     if (BuiltinElementwiseMath(TheCall))
       return ExprError();
     break;
   case Builtin::BI__builtin_elementwise_popcount:
-  case Builtin::BI__builtin_elementwise_bitreverse: {
-    if (PrepareBuiltinElementwiseMathOneArgCall(TheCall))
-      return ExprError();
-
-    const Expr *Arg = TheCall->getArg(0);
-    QualType ArgTy = Arg->getType();
-    QualType EltTy = ArgTy;
-
-    if (auto *VecTy = EltTy->getAs<VectorType>())
-      EltTy = VecTy->getElementType();
-
-    if (!EltTy->isIntegerType()) {
-      Diag(Arg->getBeginLoc(), diag::err_builtin_invalid_arg_type)
-          << 1 << /* integer ty */ 6 << ArgTy;
+  case Builtin::BI__builtin_elementwise_bitreverse:
+    if (PrepareBuiltinElementwiseMathOneArgCall(
+            TheCall, EltwiseBuiltinArgTyRestriction::IntegerTy))
       return ExprError();
-    }
     break;
-  }
-
   case Builtin::BI__builtin_elementwise_copysign: {
     if (checkArgCount(TheCall, 2))
       return ExprError();
@@ -2822,10 +2789,12 @@ Sema::CheckBuiltinFunctionCall(FunctionDecl *FDecl, unsigned BuiltinID,
 
     QualType MagnitudeTy = Magnitude.get()->getType();
     QualType SignTy = Sign.get()->getType();
-    if (checkFPMathBuiltinElementType(*this, TheCall->getArg(0)->getBeginLoc(),
-                                      MagnitudeTy, 1) ||
-        checkFPMathBuiltinElementType(*this, TheCall->getArg(1)->getBeginLoc(),
-                                      SignTy, 2)) {
+    if (checkMathBuiltinElementType(
+            *this, TheCall->getArg(0)->getBeginLoc(), MagnitudeTy,
+            EltwiseBuiltinArgTyRestriction::FloatTy, 1) ||
+        checkMathBuiltinElementType(
+            *this, TheCall->getArg(1)->getBeginLoc(), SignTy,
+            EltwiseBuiltinArgTyRestriction::FloatTy, 2)) {
       return ExprError();
     }
 
@@ -14661,7 +14630,8 @@ static ExprResult BuiltinVectorMathConversions(Sema &S, Expr *E) {
   return S.UsualUnaryFPConversions(Res.get());
 }
 
-bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
+bool Sema::PrepareBuiltinElementwiseMathOneArgCall(
+    CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr) {
   if (checkArgCount(TheCall, 1))
     return true;
 
@@ -14672,15 +14642,17 @@ bool Sema::PrepareBuiltinElementwiseMathOneArgCall(CallExpr *TheCall) {
   TheCall->setArg(0, A.get());
   QualType TyA = A.get()->getType();
 
-  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA, 1))
+  if (checkMathBuiltinElementType(*this, A.get()->getBeginLoc(), TyA,
+                                  ArgTyRestr, 1))
     return true;
 
   TheCall->setType(TyA);
   return false;
 }
 
-bool Sema::BuiltinElementwiseMath(CallExpr *TheCall, bool FPOnly) {
-  if (auto Res = BuiltinVectorMath(TheCall, FPOnly); Res.has_value()) {
+bool Sema::BuiltinElementwiseMath(CallExpr *TheCall,
+                                  EltwiseBuiltinArgTyRestriction ArgTyRestr) {
+  if (auto Res = BuiltinVectorMath(TheCall, ArgTyRestr); Res.has_value()) {
     TheCall->setType(*Res);
     return false;
   }
@@ -14713,8 +14685,9 @@ static bool checkBuiltinVectorMathMixedEnums(Sema &S, Expr *LHS, Expr *RHS,
   return false;
 }
 
-std::optional<QualType> Sema::BuiltinVectorMath(CallExpr *TheCall,
-                                                bool FPOnly) {
+std::optional<QualType>
+Sema::BuiltinVectorMath(CallExpr *TheCall,
+                        EltwiseBuiltinArgTyRestriction ArgTyRestr) {
   if (checkArgCount(TheCall, 2))
     return std::nullopt;
 
@@ -14735,26 +14708,21 @@ std::optional<QualType> Sema::BuiltinVectorMath(CallExpr *TheCall,
   QualType TyA = Args[0]->getType();
   QualType TyB = Args[1]->getType();
 
+  if (checkMathBuiltinElementType(*this, LocA, TyA, ArgTyRestr, 1))
+    return std::nullopt;
+
   if (TyA.getCanonicalType() != TyB.getCanonicalType()) {
     Diag(LocA, diag::err_typecheck_call_different_arg_types) << TyA << TyB;
     return std::nullopt;
   }
 
-  if (FPOnly) {
-    if (checkFPMathBuiltinElementType(*this, LocA, TyA, 1))
-      return std::nullopt;
-  } else {
-    if (checkMathBuiltinElementType(*this, LocA, TyA, 1))
-      return std::nullopt;
-  }
-
   TheCall->setArg(0, Args[0]);
   TheCall->setArg(1, Args[1]);
   return TyA;
 }
 
-bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
-                                         bool CheckForFloatArgs) {
+bool Sema::BuiltinElementwiseTernaryMath(
+    CallExpr *TheCall, EltwiseBuiltinArgTyRestriction ArgTyRestr) {
   if (checkArgCount(TheCall, 3))
     return true;
 
@@ -14774,20 +14742,11 @@ bool Sema::BuiltinElementwiseTernaryMath(CallExpr *TheCall,
     Args[I] = Converted.get();
   }
 
-  if (CheckForFloatArgs) {
-    int ArgOrdinal = 1;
-    for (Expr *Arg : Args) {
-      if (checkFPMathBuiltinElementType(*this, Arg->getBeginLoc(),
-                                        Arg->getType(), ArgOrdinal++))
-        return true;
-    }
-  } else {
-    int ArgOrdinal = 1;
-    for (Expr *Arg : Args) {
-      if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
-                                      ArgOrdinal++))
-        return true;
-    }
+  int ArgOrdinal = 1;
+  for (Expr *Arg : Args) {
+    if (checkMathBuiltinElementType(*this, Arg->getBeginLoc(), Arg->getType(),
+                                    ArgTyRestr, ArgOrdinal++))
+      return true;
   }
 
   for (int I = 1; I < 3; ++I) {
diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp
index d748c10455289b9..b45879314727048 100644
--- a/clang/lib/Sema/SemaHLSL.cpp
+++ b/clang/lib/Sema/SemaHLSL.cpp
@@ -2259,8 +2259,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
       return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
-            TheCall, /*CheckForFloatArgs*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()))
+            TheCall, /*ArgTyRestr*/
+            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
+                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
+                : Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
@@ -2393,8 +2395,10 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) {
     if (CheckVectorElementCallArgs(&SemaRef, TheCall))
       return true;
     if (SemaRef.BuiltinElementwiseTernaryMath(
-            TheCall, /*CheckForFloatArgs*/
-            TheCall->getArg(0)->getType()->hasFloatingRepresentation()))
+            TheCall, /*ArgTyRestr*/
+            TheCall->getArg(0)->getType()->hasFloatingRepresentation()
+                ? Sema::EltwiseBuiltinArgTyRestriction::FloatTy
+                : Sema::EltwiseBuiltinArgTyRestriction::None))
       return true;
     break;
   }
diff --git a/clang/test/Sema/aarch64-sve-vector-exp-ops.c b/clang/test/Sema/aarch64-sve-vector-exp-ops.c
index f2bba8c7eeb196d..4b411babbc3471e 100644
--- a/clang/test/Sema/aarch64-sve-vector-exp-ops.c
+++ b/clang/test/Sema/aarch64-sve-vector-exp-ops.c
@@ -7,11 +7,11 @@
 svfloat32_t test_exp_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_exp(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_exp2_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_exp2(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
diff --git a/clang/test/Sema/aarch64-sve-vector-log-ops.c b/clang/test/Sema/aarch64-sve-vector-log-ops.c
index ef16e8581844d7f..bc81323b560c9c4 100644
--- a/clang/test/Sema/aarch64-sve-vector-log-ops.c
+++ b/clang/test/Sema/aarch64-sve-vector-log-ops.c
@@ -7,17 +7,17 @@
 svfloat32_t test_log_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_log(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_log10_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_log10(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_log2_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_log2(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
diff --git a/clang/test/Sema/aarch64-sve-vector-trig-ops.c b/clang/test/Sema/aarch64-sve-vector-trig-ops.c
index 3fe6834be2e0b7f..46df63cbba42bf7 100644
--- a/clang/test/Sema/aarch64-sve-vector-trig-ops.c
+++ b/clang/test/Sema/aarch64-sve-vector-trig-ops.c
@@ -7,19 +7,19 @@
 svfloat32_t test_asin_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_asin(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_acos_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_acos(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_atan_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_atan(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_atan2_vv_i8mf8(svfloat32_t v) {
@@ -31,35 +31,35 @@ svfloat32_t test_atan2_vv_i8mf8(svfloat32_t v) {
 svfloat32_t test_sin_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_sin(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_cos_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_cos(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_tan_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_tan(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_sinh_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_sinh(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat32_t test_cosh_vv_i8mf8(svfloat32_t v) {
 
   return __builtin_elementwise_cosh(v);
-  // expected-error@-1 {{1st argument must be a vector, integer or floating point type}}
+  // expected-error@-1 {{1st argument must be a floating point type}}
 }
 
 svfloat...
[truncated]

@frasercrmck
Copy link
Contributor Author

There are a couple of things I thought I'd leave to the PR discussion, like:

  • Should we standardise the diagnostics a bit better? Specifically with regards to whether the possibility of "vector" needs to be made explicit when we say a floating point type, an unsigned integer? That's why I added an integer or vector of integers but now it stands out.
  • The ternary builtins check each individual type for validity, then check whether they're all the same. The binary builtins only check the first argument, then check whether the second matches. Should these be made consistent?

@farzonl
Copy link
Member

farzonl commented Feb 4, 2025

This PR has me thinking about a related problem.

In HLSL we have elementwise builtins like __builtin_hlsl_elementwise_frac. We do not add the CustomTypeChecking attribute to these builtins. This puts us down the variadic type check rules for scalar inputs to these builtins. Which mean for floating point cases the type defaults to double. in -O0 our builtin is wrapped in
%8 = fpext float %7 to double
%9 = <call to builitn>(%8)
fptrunc double %9 to float

Does it make more sense to switch all these to CustomTypeChecking Or should we consider type constaining like what you did here?

@frasercrmck
Copy link
Contributor Author

This PR has me thinking about a related problem.

In HLSL we have elementwise builtins like __builtin_hlsl_elementwise_frac. We do not add the CustomTypeChecking attribute to these builtins. This puts us down the variadic type check rules for scalar inputs to these builtins. Which mean for floating point cases the type defaults to double. in -O0 our builtin is wrapped in %8 = fpext float %7 to double %9 = <call to builitn>(%8) fptrunc double %9 to float

Does it make more sense to switch all these to CustomTypeChecking Or should we consider type constaining like what you did here?

I wouldn't say I'm knowledgeable enough in this area or in HLSL to answer this question, sorry.

@frasercrmck
Copy link
Contributor Author

There are a couple of things I thought I'd leave to the PR discussion, like:

* Should we standardise the diagnostics a bit better? Specifically with regards to whether the possibility of "vector" needs to be made explicit when we say `a floating point type`, `an unsigned integer`? That's why I added `an integer or vector of integers` but now it stands out.

The latest version of this PR does just that. I've tried to split up the error diagnostic so that it's more modular. Downstream I needed a way of accepting a scalar or vector of 'int' which felt like the straw that broke the camel's back.

@frasercrmck
Copy link
Contributor Author

ping, thanks

Copy link
Member

@farzonl farzonl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as the HLSL changes go. This looks fine to me. Approved! Please wait for other stakholders before merging.

@frasercrmck
Copy link
Contributor Author

ping, thanks

Thie commit improves the diagnostics for vector (elementwise) builtins
in a couple of ways.

It primarily provides more precise type-checking diagnostics for
builtins with specific type requirements. Previously many builtins were
receiving a catch-all diagnostic suggesting types which aren't valid.

It also makes consistent the type-checking behaviour between various
binary and ternary builtins. The binary builtins would check for
mismatched argument types before specific type requirements, whereas
ternary builtins would perform the checks in the reverse order. The
binary builtins now behave as the ternary ones do.
@frasercrmck frasercrmck force-pushed the clang-builtin-errors branch from c22dc5b to 6a687d7 Compare March 18, 2025 10:17
@frasercrmck
Copy link
Contributor Author

ping. This PR is blocking the addition of some new elementwise builtins we want for libclc.

Copy link
Collaborator

@AaronBallman AaronBallman left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM aside from a nit

@frasercrmck frasercrmck merged commit 8cc9a48 into llvm:main Mar 18, 2025
12 checks passed
@frasercrmck frasercrmck deleted the clang-builtin-errors branch March 18, 2025 18:11
@damyanp damyanp moved this to Closed in HLSL Support Apr 25, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:RISC-V clang:frontend Language frontend issues, e.g. anything involving "Sema" clang Clang issues not falling into any other category HLSL HLSL Language Support
Projects
Status: Closed
Development

Successfully merging this pull request may close these issues.

6 participants