Skip to content

[HLSL][DXIL] Implement refract intrinsic #136026

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/BuiltinsSPIRV.td
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ def SPIRVReflect : Builtin {
let Prototype = "void(...)";
}

def SPIRVRefract : Builtin {
let Spellings = ["__builtin_spirv_refract"];
let Attributes = [NoThrow, Const];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need to add CustomTypeChecking.

Suggested change
let Attributes = [NoThrow, Const];
let Attributes = [NoThrow, Const, CustomTypeChecking];

let Prototype = "void(...)";
}

def SPIRVSmoothStep : Builtin {
let Spellings = ["__builtin_spirv_smoothstep"];
let Attributes = [NoThrow, Const, CustomTypeChecking];
Expand Down
15 changes: 15 additions & 0 deletions clang/lib/CodeGen/TargetBuiltins/SPIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ Value *CodeGenFunction::EmitSPIRVBuiltinExpr(unsigned BuiltinID,
/*ReturnType=*/I->getType(), Intrinsic::spv_reflect,
ArrayRef<Value *>{I, N}, nullptr, "spv.reflect");
}
case SPIRV::BI__builtin_spirv_refract: {
Value *I = EmitScalarExpr(E->getArg(0));
Value *N = EmitScalarExpr(E->getArg(1));
Value *eta = EmitScalarExpr(E->getArg(2));
assert(E->getArg(0)->getType()->hasFloatingRepresentation() &&
E->getArg(1)->getType()->hasFloatingRepresentation() &&
E->getArg(2)->getType()->hasFloatingRepresentation() &&
"refract operands must have a float representation");
assert(E->getArg(0)->getType()->isVectorType() &&
E->getArg(1)->getType()->isVectorType() &&
"refract I and N operands must be a vector");
return Builder.CreateIntrinsic(
/*ReturnType=*/I->getType(), Intrinsic::spv_refract,
ArrayRef<Value *>{I, N, eta}, nullptr, "spv.refract");
}
case SPIRV::BI__builtin_spirv_smoothstep: {
Value *Min = EmitScalarExpr(E->getArg(0));
Value *Max = EmitScalarExpr(E->getArg(1));
Expand Down
25 changes: 23 additions & 2 deletions clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#define _HLSL_HLSL_INTRINSIC_HELPERS_H_

namespace hlsl {
namespace __detail {
namespace __dETAil {

constexpr vector<uint, 4> d3d_color_to_ubyte4_impl(vector<float, 4> V) {
// Use the same scaling factor used by FXC, and DXC for DXIL
Expand Down Expand Up @@ -71,6 +71,27 @@ constexpr vector<T, L> reflect_vec_impl(vector<T, L> I, vector<T, L> N) {
#endif
}

template <typename T> constexpr T refract_impl(T I, T N, T Eta) {
T K = 1 - Eta * Eta * (1 - (N * I * N * I));
if (K < 0)
return 0;
else
return (Eta * I - (Eta * N * I + sqrt(K)) * N);
}

template <typename T, int L>
constexpr vector<T, L> refract_vec_impl(vector<T, L> I, vector<T, L> N, T Eta) {
#if (__has_builtin(__builtin_spirv_refract))
return __builtin_spirv_refract(I, N, Eta);
#else
vector<T, L> K = 1 - Eta * Eta * (1 - dot(N, I) * dot(N, I));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we should really be storing dot(N, I) somewhere so that we aren't repeating the computation 3 times. That should simplify the codegen for the -O0 cases.

if (K < 0)
return 0;
else
return (Eta * I - (Eta * dot(N, I) + sqrt(K)) * N);
#endif
}

template <typename T> constexpr T fmod_impl(T X, T Y) {
#if !defined(__DIRECTX__)
return __builtin_elementwise_fmod(X, Y);
Expand Down Expand Up @@ -126,7 +147,7 @@ template <typename T> constexpr vector<T, 4> lit_impl(T NDotL, T NDotH, T M) {
return Result;
}

} // namespace __detail
} // namespace __dETAil
} // namespace hlsl

#endif // _HLSL_HLSL_INTRINSIC_HELPERS_H_
59 changes: 59 additions & 0 deletions clang/lib/Headers/hlsl/hlsl_intrinsics.h
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,65 @@ reflect(__detail::HLSL_FIXED_VECTOR<float, L> I,
return __detail::reflect_vec_impl(I, N);
}

//===----------------------------------------------------------------------===//
// refract builtin
//===----------------------------------------------------------------------===//

/// \fn T refract(T I, T N, T eta)
/// \brief Returns a refraction using an entering ray, \a I, a surface
/// normal, \a N and refraction index \a eta
/// \param I The entering ray.
/// \param N The surface normal.
/// \param eta The refraction index.
///
/// The return value is a floating-point vector that represents the refraction
/// using the refraction index, \a eta, for the direction of the entering ray,
/// \a I, off a surface with the normal \a N.
///
/// This function calculates the refraction vector using the following formulas:
/// k = 1.0 - eta * eta * (1.0 - dot(N, I) * dot(N, I))
/// if k < 0.0 the result is 0.0
/// otherwise, the result is eta * I - (eta * dot(N, I) + sqrt(k)) * N
///
/// I and N must already be normalized in order to achieve the desired result.
///
/// I and N must be a scalar or vector whose component type is
/// floating-point.
///
/// eta must be a 16-bit or 32-bit floating-point scalar.
///
/// Result type, the type of I, and the type of N must all be the same type.

template <typename T>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::enable_if_t<__detail::is_arithmetic<T>::Value &&
__detail::is_same<half, T>::value,
T> refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <typename T>
const inline __detail::enable_if_t<
__detail::is_arithmetic<T>::Value && __detail::is_same<float, T>::value, T>
refract(T I, T N, T eta) {
return __detail::refract_impl(I, N, eta);
}

template <int L>
_HLSL_16BIT_AVAILABILITY(shadermodel, 6.2)
const inline __detail::HLSL_FIXED_VECTOR<half, L> refract(
__detail::HLSL_FIXED_VECTOR<half, L> I,
__detail::HLSL_FIXED_VECTOR<half, L> N, half eta) {
return __detail::refract_vec_impl(I, N, eta);
}

template <int L>
const inline __detail::HLSL_FIXED_VECTOR<float, L>
refract(__detail::HLSL_FIXED_VECTOR<float, L> I,
__detail::HLSL_FIXED_VECTOR<float, L> N, float eta) {
return __detail::refract_vec_impl(I, N, eta);
}

//===----------------------------------------------------------------------===//
// smoothstep builtin
//===----------------------------------------------------------------------===//
Expand Down
45 changes: 43 additions & 2 deletions clang/lib/Sema/SemaSPIRV.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
Expand Down Expand Up @@ -69,6 +69,47 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_refract: {
if (SemaRef.checkArgCount(TheCall, 3))
return true;

ExprResult A = TheCall->getArg(0);
QualType ArgTyA = A.get()->getType();
auto *VTyA = ArgTyA->getAs<VectorType>();
if (VTyA == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyA
<< SemaRef.Context.getVectorType(ArgTyA, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}

ExprResult B = TheCall->getArg(1);
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
<< 0 << 0;
return true;
}

ExprResult C = TheCall->getArg(2);
QualType ArgTyC = C.get()->getType();
if (!ArgTyC->hasFloatingRepresentation()) {
SemaRef.Diag(C.get()->getBeginLoc(), diag::err_builtin_invalid_arg_type)
<< 3 << /* scalar or vector */ 5 << /* no int */ 0 << /* fp */ 1
<< ArgTyC;
return true;
}

QualType RetTy = ArgTyA;
TheCall->setType(RetTy);
break;
}
case SPIRV::BI__builtin_spirv_reflect: {
if (SemaRef.checkArgCount(TheCall, 2))
return true;
Expand All @@ -89,7 +130,7 @@ bool SemaSPIRV::CheckSPIRVBuiltinFunctionCall(unsigned BuiltinID,
QualType ArgTyB = B.get()->getType();
auto *VTyB = ArgTyB->getAs<VectorType>();
if (VTyB == nullptr) {
SemaRef.Diag(A.get()->getBeginLoc(),
SemaRef.Diag(B.get()->getBeginLoc(),
diag::err_typecheck_convert_incompatible)
<< ArgTyB
<< SemaRef.Context.getVectorType(ArgTyB, 2, VectorKind::Generic) << 1
Expand Down
Loading