Skip to content

Commit 5a45711

Browse files
[HLSL] Implement SpirvType and SpirvOpaqueType (#134034)
This implements the design proposed by [Representing SpirvType in Clang's Type System](llvm/wg-hlsl#181). It creates `HLSLInlineSpirvType` as a new `Type` subclass, and `__hlsl_spirv_type` as a new builtin type template to create such a type. This new type is lowered to the `spirv.Type` target extension type, as described in [Target Extension Types for Inline SPIR-V and Decorated Types](https://github.com/llvm/wg-hlsl/blob/main/proposals/0017-inline-spirv-and-decorated-types.md).
1 parent 58f8053 commit 5a45711

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1034
-75
lines changed

clang/include/clang-c/Index.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3034,7 +3034,8 @@ enum CXTypeKind {
30343034

30353035
/* HLSL Types */
30363036
CXType_HLSLResource = 179,
3037-
CXType_HLSLAttributedResource = 180
3037+
CXType_HLSLAttributedResource = 180,
3038+
CXType_HLSLInlineSpirv = 181
30383039
};
30393040

30403041
/**

clang/include/clang/AST/ASTContext.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ class ASTContext : public RefCountedBase<ASTContext> {
260260
DependentBitIntTypes;
261261
mutable llvm::FoldingSet<BTFTagAttributedType> BTFTagAttributedTypes;
262262
llvm::FoldingSet<HLSLAttributedResourceType> HLSLAttributedResourceTypes;
263+
llvm::FoldingSet<HLSLInlineSpirvType> HLSLInlineSpirvTypes;
263264

264265
mutable llvm::FoldingSet<CountAttributedType> CountAttributedTypes;
265266

@@ -1808,6 +1809,10 @@ class ASTContext : public RefCountedBase<ASTContext> {
18081809
QualType Wrapped, QualType Contained,
18091810
const HLSLAttributedResourceType::Attributes &Attrs);
18101811

1812+
QualType getHLSLInlineSpirvType(uint32_t Opcode, uint32_t Size,
1813+
uint32_t Alignment,
1814+
ArrayRef<SpirvOperand> Operands);
1815+
18111816
QualType getSubstTemplateTypeParmType(QualType Replacement,
18121817
Decl *AssociatedDecl, unsigned Index,
18131818
UnsignedOrNone PackIndex,

clang/include/clang/AST/ASTNodeTraverser.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,24 @@ class ASTNodeTraverser
450450
if (!Contained.isNull())
451451
Visit(Contained);
452452
}
453+
void VisitHLSLInlineSpirvType(const HLSLInlineSpirvType *T) {
454+
for (auto &Operand : T->getOperands()) {
455+
using SpirvOperandKind = SpirvOperand::SpirvOperandKind;
456+
457+
switch (Operand.getKind()) {
458+
case SpirvOperandKind::ConstantId:
459+
case SpirvOperandKind::Literal:
460+
break;
461+
462+
case SpirvOperandKind::TypeId:
463+
Visit(Operand.getResultType());
464+
break;
465+
466+
default:
467+
llvm_unreachable("Invalid SpirvOperand kind!");
468+
}
469+
}
470+
}
453471
void VisitSubstTemplateTypeParmType(const SubstTemplateTypeParmType *) {}
454472
void
455473
VisitSubstTemplateTypeParmPackType(const SubstTemplateTypeParmPackType *T) {

clang/include/clang/AST/PropertiesBase.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ def UnsignedOrNone : PropertyType;
148148
def UnaryTypeTransformKind : EnumPropertyType<"UnaryTransformType::UTTKind">;
149149
def VectorKind : EnumPropertyType<"VectorKind">;
150150
def TypeCoupledDeclRefInfo : PropertyType;
151+
def HLSLSpirvOperand : PropertyType<"SpirvOperand"> { let PassByReference = 1; }
151152

152153
def ExceptionSpecInfo : PropertyType<"FunctionProtoType::ExceptionSpecInfo"> {
153154
let BufferElementTypes = [ QualType ];

clang/include/clang/AST/RecursiveASTVisitor.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,14 @@ DEF_TRAVERSE_TYPE(BTFTagAttributedType,
11541154
DEF_TRAVERSE_TYPE(HLSLAttributedResourceType,
11551155
{ TRY_TO(TraverseType(T->getWrappedType())); })
11561156

1157+
DEF_TRAVERSE_TYPE(HLSLInlineSpirvType, {
1158+
for (auto &Operand : T->getOperands()) {
1159+
if (Operand.isConstant() || Operand.isType()) {
1160+
TRY_TO(TraverseType(Operand.getResultType()));
1161+
}
1162+
}
1163+
})
1164+
11571165
DEF_TRAVERSE_TYPE(ParenType, { TRY_TO(TraverseType(T->getInnerType())); })
11581166

11591167
DEF_TRAVERSE_TYPE(MacroQualifiedType,
@@ -1457,6 +1465,9 @@ DEF_TRAVERSE_TYPELOC(BTFTagAttributedType,
14571465
DEF_TRAVERSE_TYPELOC(HLSLAttributedResourceType,
14581466
{ TRY_TO(TraverseTypeLoc(TL.getWrappedLoc())); })
14591467

1468+
DEF_TRAVERSE_TYPELOC(HLSLInlineSpirvType,
1469+
{ TRY_TO(TraverseType(TL.getType())); })
1470+
14601471
DEF_TRAVERSE_TYPELOC(ElaboratedType, {
14611472
if (TL.getQualifierLoc()) {
14621473
TRY_TO(TraverseNestedNameSpecifierLoc(TL.getQualifierLoc()));

clang/include/clang/AST/Type.h

Lines changed: 144 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2691,6 +2691,7 @@ class alignas(TypeAlignment) Type : public ExtQualsTypeCommonBase {
26912691
bool isHLSLSpecificType() const; // Any HLSL specific type
26922692
bool isHLSLBuiltinIntangibleType() const; // Any HLSL builtin intangible type
26932693
bool isHLSLAttributedResourceType() const;
2694+
bool isHLSLInlineSpirvType() const;
26942695
bool isHLSLResourceRecord() const;
26952696
bool isHLSLIntangibleType()
26962697
const; // Any HLSL intangible type (builtin, array, class)
@@ -6364,6 +6365,143 @@ class HLSLAttributedResourceType : public Type, public llvm::FoldingSetNode {
63646365
findHandleTypeOnResource(const Type *RT);
63656366
};
63666367

6368+
/// Instances of this class represent operands to a SPIR-V type instruction.
6369+
class SpirvOperand {
6370+
public:
6371+
enum SpirvOperandKind : unsigned char {
6372+
Invalid, ///< Uninitialized.
6373+
ConstantId, ///< Integral value to represent as a SPIR-V OpConstant
6374+
///< instruction ID.
6375+
Literal, ///< Integral value to represent as an immediate literal.
6376+
TypeId, ///< Type to represent as a SPIR-V type ID.
6377+
6378+
Max,
6379+
};
6380+
6381+
private:
6382+
SpirvOperandKind Kind = Invalid;
6383+
6384+
QualType ResultType;
6385+
llvm::APInt Value; // Signedness of constants is represented by ResultType.
6386+
6387+
public:
6388+
SpirvOperand() : Kind(Invalid), ResultType(), Value() {}
6389+
6390+
SpirvOperand(SpirvOperandKind Kind, QualType ResultType, llvm::APInt Value)
6391+
: Kind(Kind), ResultType(ResultType), Value(Value) {}
6392+
6393+
SpirvOperand(const SpirvOperand &Other) { *this = Other; }
6394+
~SpirvOperand() {}
6395+
6396+
SpirvOperand &operator=(const SpirvOperand &Other) {
6397+
this->Kind = Other.Kind;
6398+
this->ResultType = Other.ResultType;
6399+
this->Value = Other.Value;
6400+
return *this;
6401+
}
6402+
6403+
bool operator==(const SpirvOperand &Other) const {
6404+
return Kind == Other.Kind && ResultType == Other.ResultType &&
6405+
Value == Other.Value;
6406+
}
6407+
6408+
bool operator!=(const SpirvOperand &Other) const { return !(*this == Other); }
6409+
6410+
SpirvOperandKind getKind() const { return Kind; }
6411+
6412+
bool isValid() const { return Kind != Invalid && Kind < Max; }
6413+
bool isConstant() const { return Kind == ConstantId; }
6414+
bool isLiteral() const { return Kind == Literal; }
6415+
bool isType() const { return Kind == TypeId; }
6416+
6417+
llvm::APInt getValue() const {
6418+
assert((isConstant() || isLiteral()) &&
6419+
"This is not an operand with a value!");
6420+
return Value;
6421+
}
6422+
6423+
QualType getResultType() const {
6424+
assert((isConstant() || isType()) &&
6425+
"This is not an operand with a result type!");
6426+
return ResultType;
6427+
}
6428+
6429+
static SpirvOperand createConstant(QualType ResultType, llvm::APInt Val) {
6430+
return SpirvOperand(ConstantId, ResultType, Val);
6431+
}
6432+
6433+
static SpirvOperand createLiteral(llvm::APInt Val) {
6434+
return SpirvOperand(Literal, QualType(), Val);
6435+
}
6436+
6437+
static SpirvOperand createType(QualType T) {
6438+
return SpirvOperand(TypeId, T, llvm::APSInt());
6439+
}
6440+
6441+
void Profile(llvm::FoldingSetNodeID &ID) const {
6442+
ID.AddInteger(Kind);
6443+
ID.AddPointer(ResultType.getAsOpaquePtr());
6444+
Value.Profile(ID);
6445+
}
6446+
};
6447+
6448+
/// Represents an arbitrary, user-specified SPIR-V type instruction.
6449+
class HLSLInlineSpirvType final
6450+
: public Type,
6451+
public llvm::FoldingSetNode,
6452+
private llvm::TrailingObjects<HLSLInlineSpirvType, SpirvOperand> {
6453+
friend class ASTContext; // ASTContext creates these
6454+
friend TrailingObjects;
6455+
6456+
private:
6457+
uint32_t Opcode;
6458+
uint32_t Size;
6459+
uint32_t Alignment;
6460+
size_t NumOperands;
6461+
6462+
HLSLInlineSpirvType(uint32_t Opcode, uint32_t Size, uint32_t Alignment,
6463+
ArrayRef<SpirvOperand> Operands)
6464+
: Type(HLSLInlineSpirv, QualType(), TypeDependence::None), Opcode(Opcode),
6465+
Size(Size), Alignment(Alignment), NumOperands(Operands.size()) {
6466+
for (size_t I = 0; I < NumOperands; I++) {
6467+
// Since Operands are stored as a trailing object, they have not been
6468+
// initialized yet. Call the constructor manually.
6469+
auto *Operand =
6470+
new (&getTrailingObjects<SpirvOperand>()[I]) SpirvOperand();
6471+
*Operand = Operands[I];
6472+
}
6473+
}
6474+
6475+
public:
6476+
uint32_t getOpcode() const { return Opcode; }
6477+
uint32_t getSize() const { return Size; }
6478+
uint32_t getAlignment() const { return Alignment; }
6479+
ArrayRef<SpirvOperand> getOperands() const {
6480+
return {getTrailingObjects<SpirvOperand>(), NumOperands};
6481+
}
6482+
6483+
bool isSugared() const { return false; }
6484+
QualType desugar() const { return QualType(this, 0); }
6485+
6486+
void Profile(llvm::FoldingSetNodeID &ID) {
6487+
Profile(ID, Opcode, Size, Alignment, getOperands());
6488+
}
6489+
6490+
static void Profile(llvm::FoldingSetNodeID &ID, uint32_t Opcode,
6491+
uint32_t Size, uint32_t Alignment,
6492+
ArrayRef<SpirvOperand> Operands) {
6493+
ID.AddInteger(Opcode);
6494+
ID.AddInteger(Size);
6495+
ID.AddInteger(Alignment);
6496+
for (auto &Operand : Operands)
6497+
Operand.Profile(ID);
6498+
}
6499+
6500+
static bool classof(const Type *T) {
6501+
return T->getTypeClass() == HLSLInlineSpirv;
6502+
}
6503+
};
6504+
63676505
class TemplateTypeParmType : public Type, public llvm::FoldingSetNode {
63686506
friend class ASTContext; // ASTContext creates these
63696507

@@ -8495,13 +8633,18 @@ inline bool Type::isHLSLBuiltinIntangibleType() const {
84958633
}
84968634

84978635
inline bool Type::isHLSLSpecificType() const {
8498-
return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType();
8636+
return isHLSLBuiltinIntangibleType() || isHLSLAttributedResourceType() ||
8637+
isHLSLInlineSpirvType();
84998638
}
85008639

85018640
inline bool Type::isHLSLAttributedResourceType() const {
85028641
return isa<HLSLAttributedResourceType>(this);
85038642
}
85048643

8644+
inline bool Type::isHLSLInlineSpirvType() const {
8645+
return isa<HLSLInlineSpirvType>(this);
8646+
}
8647+
85058648
inline bool Type::isTemplateTypeParmType() const {
85068649
return isa<TemplateTypeParmType>(CanonicalType);
85078650
}

clang/include/clang/AST/TypeLoc.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -973,6 +973,25 @@ class HLSLAttributedResourceTypeLoc
973973
}
974974
};
975975

976+
struct HLSLInlineSpirvTypeLocInfo {
977+
SourceLocation Loc;
978+
}; // Nothing.
979+
980+
class HLSLInlineSpirvTypeLoc
981+
: public ConcreteTypeLoc<UnqualTypeLoc, HLSLInlineSpirvTypeLoc,
982+
HLSLInlineSpirvType, HLSLInlineSpirvTypeLocInfo> {
983+
public:
984+
SourceLocation getSpirvTypeLoc() const { return getLocalData()->Loc; }
985+
void setSpirvTypeLoc(SourceLocation loc) const { getLocalData()->Loc = loc; }
986+
987+
SourceRange getLocalSourceRange() const {
988+
return SourceRange(getSpirvTypeLoc(), getSpirvTypeLoc());
989+
}
990+
void initializeLocal(ASTContext &Context, SourceLocation loc) {
991+
setSpirvTypeLoc(loc);
992+
}
993+
};
994+
976995
struct ObjCObjectTypeLocInfo {
977996
SourceLocation TypeArgsLAngleLoc;
978997
SourceLocation TypeArgsRAngleLoc;

clang/include/clang/AST/TypeProperties.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -701,6 +701,24 @@ let Class = HLSLAttributedResourceType in {
701701
}]>;
702702
}
703703

704+
let Class = HLSLInlineSpirvType in {
705+
def : Property<"opcode", UInt32> {
706+
let Read = [{ node->getOpcode() }];
707+
}
708+
def : Property<"size", UInt32> {
709+
let Read = [{ node->getSize() }];
710+
}
711+
def : Property<"alignment", UInt32> {
712+
let Read = [{ node->getAlignment() }];
713+
}
714+
def : Property<"operands", Array<HLSLSpirvOperand>> {
715+
let Read = [{ node->getOperands() }];
716+
}
717+
def : Creator<[{
718+
return ctx.getHLSLInlineSpirvType(opcode, size, alignment, operands);
719+
}]>;
720+
}
721+
704722
let Class = DependentAddressSpaceType in {
705723
def : Property<"pointeeType", QualType> {
706724
let Read = [{ node->getPointeeType() }];

clang/include/clang/Basic/BuiltinTemplates.td

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,37 @@ class BuiltinNTTP<string type_name> : TemplateArg<""> {
2828
}
2929

3030
def SizeT : BuiltinNTTP<"size_t"> {}
31+
def Uint32T: BuiltinNTTP<"uint32_t"> {}
3132

3233
class BuiltinTemplate<list<TemplateArg> template_head> {
3334
list<TemplateArg> TemplateHead = template_head;
3435
}
3536

37+
class CPlusPlusBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>;
38+
39+
class HLSLBuiltinTemplate<list<TemplateArg> template_head> : BuiltinTemplate<template_head>;
40+
3641
// template <template <class T, T... Ints> IntSeq, class T, T N>
37-
def __make_integer_seq : BuiltinTemplate<
42+
def __make_integer_seq : CPlusPlusBuiltinTemplate<
3843
[Template<[Class<"T">, NTTP<"T", "Ints", /*is_variadic=*/1>], "IntSeq">, Class<"T">, NTTP<"T", "N">]>;
3944

4045
// template <size_t, class... T>
41-
def __type_pack_element : BuiltinTemplate<
46+
def __type_pack_element : CPlusPlusBuiltinTemplate<
4247
[SizeT, Class<"T", /*is_variadic=*/1>]>;
4348

4449
// template <template <class... Args> BaseTemplate,
4550
// template <class TypeMember> HasTypeMember,
4651
// class HasNoTypeMember
4752
// class... Ts>
48-
def __builtin_common_type : BuiltinTemplate<
53+
def __builtin_common_type : CPlusPlusBuiltinTemplate<
4954
[Template<[Class<"Args", /*is_variadic=*/1>], "BaseTemplate">,
5055
Template<[Class<"TypeMember">], "HasTypeMember">,
5156
Class<"HasNoTypeMember">,
5257
Class<"Ts", /*is_variadic=*/1>]>;
58+
59+
// template <uint32_t Opcode,
60+
// uint32_t Size,
61+
// uint32_t Alignment,
62+
// typename ...Operands>
63+
def __hlsl_spirv_type : HLSLBuiltinTemplate<
64+
[Uint32T, Uint32T, Uint32T, Class<"Operands", /*is_variadic=*/1>]>;

clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12998,6 +12998,9 @@ def err_hlsl_expect_arg_const_int_one_or_neg_one: Error<
1299812998
def err_invalid_hlsl_resource_type: Error<
1299912999
"invalid __hlsl_resource_t type attributes">;
1300013000

13001+
def err_hlsl_spirv_only: Error<"%0 is only available for the SPIR-V target">;
13002+
def err_hlsl_vk_literal_must_contain_constant: Error<"the argument to vk::Literal must be a vk::integral_constant">;
13003+
1300113004
// Layout randomization diagnostics.
1300213005
def err_non_designated_init_used : Error<
1300313006
"a randomized struct can only be initialized with a designated initializer">;

clang/include/clang/Basic/TypeNodes.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ def ElaboratedType : TypeNode<Type>, NeverCanonical;
9494
def AttributedType : TypeNode<Type>, NeverCanonical;
9595
def BTFTagAttributedType : TypeNode<Type>, NeverCanonical;
9696
def HLSLAttributedResourceType : TypeNode<Type>;
97+
def HLSLInlineSpirvType : TypeNode<Type>;
9798
def TemplateTypeParmType : TypeNode<Type>, AlwaysDependent, LeafType;
9899
def SubstTemplateTypeParmType : TypeNode<Type>, NeverCanonical;
99100
def SubstTemplateTypeParmPackType : TypeNode<Type>, AlwaysDependent;

clang/include/clang/Serialization/ASTRecordReader.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,8 @@ class ASTRecordReader
214214

215215
TypeCoupledDeclRefInfo readTypeCoupledDeclRefInfo();
216216

217+
SpirvOperand readHLSLSpirvOperand();
218+
217219
/// Read a declaration name, advancing Idx.
218220
// DeclarationName readDeclarationName(); (inherited)
219221
DeclarationNameLoc readDeclarationNameLoc(DeclarationName Name);

clang/include/clang/Serialization/ASTRecordWriter.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,20 @@ class ASTRecordWriter
151151
writeBool(Info.isDeref());
152152
}
153153

154+
void writeHLSLSpirvOperand(SpirvOperand Op) {
155+
QualType ResultType;
156+
llvm::APInt Value;
157+
158+
if (Op.isConstant() || Op.isType())
159+
ResultType = Op.getResultType();
160+
if (Op.isConstant() || Op.isLiteral())
161+
Value = Op.getValue();
162+
163+
Record->push_back(Op.getKind());
164+
writeQualType(ResultType);
165+
writeAPInt(Value);
166+
}
167+
154168
/// Emit a source range.
155169
void AddSourceRange(SourceRange Range, LocSeq *Seq = nullptr) {
156170
return Writer->AddSourceRange(Range, *Record, Seq);

0 commit comments

Comments
 (0)