Skip to content

Commit 21fdad3

Browse files
committed
[spirv] Allows spec constants as attribute arguments (for selected attributes).
Fixes #3092.
1 parent facd05a commit 21fdad3

File tree

20 files changed

+682
-214
lines changed

20 files changed

+682
-214
lines changed

tools/clang/include/clang/AST/Expr.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,9 @@ class Expr : public Stmt {
531531
bool isConstantInitializer(ASTContext &Ctx, bool ForRef,
532532
const Expr **Culprit = nullptr) const;
533533

534+
bool isVulkanSpecConstantExpr(const ASTContext &Ctx,
535+
APValue *Result = nullptr) const;
536+
534537
/// EvalStatus is a struct with detailed info about an evaluation in progress.
535538
struct EvalStatus {
536539
/// HasSideEffects - Whether the evaluated expression has side effects.

tools/clang/include/clang/Basic/Attr.td

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -668,7 +668,7 @@ def HLSLMaxTessFactor: InheritableAttr {
668668
}
669669
def HLSLNumThreads: InheritableAttr {
670670
let Spellings = [CXX11<"", "numthreads", 2015>];
671-
let Args = [IntArgument<"X">, IntArgument<"Y">, IntArgument<"Z">];
671+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
672672
let Documentation = [Undocumented];
673673
}
674674
def HLSLRootSignature: InheritableAttr {
@@ -1016,7 +1016,7 @@ def HLSLNodeIsProgramEntry : InheritableAttr {
10161016

10171017
def HLSLNodeId : InheritableAttr {
10181018
let Spellings = [CXX11<"", "nodeid", 2017>];
1019-
let Args = [StringArgument<"Name">,DefaultIntArgument<"ArrayIndex", 0>];
1019+
let Args = [StringArgument<"Name">, ExprArgument<"ArrayIndex", 1>];
10201020
let Documentation = [Undocumented];
10211021
}
10221022

@@ -1028,25 +1028,25 @@ def HLSLNodeLocalRootArgumentsTableIndex : InheritableAttr {
10281028

10291029
def HLSLNodeShareInputOf : InheritableAttr {
10301030
let Spellings = [CXX11<"", "nodeshareinputof", 2017>];
1031-
let Args = [StringArgument<"Name">,UnsignedArgument<"ArrayIndex", 1>];
1031+
let Args = [StringArgument<"Name">,ExprArgument<"ArrayIndex", 1>];
10321032
let Documentation = [Undocumented];
10331033
}
10341034

10351035
def HLSLNodeDispatchGrid: InheritableAttr {
10361036
let Spellings = [CXX11<"", "nodedispatchgrid", 2015>];
1037-
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
1037+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
10381038
let Documentation = [Undocumented];
10391039
}
10401040

10411041
def HLSLNodeMaxDispatchGrid: InheritableAttr {
10421042
let Spellings = [CXX11<"", "nodemaxdispatchgrid", 2015>];
1043-
let Args = [UnsignedArgument<"X">, UnsignedArgument<"Y">, UnsignedArgument<"Z">];
1043+
let Args = [ExprArgument<"X">, ExprArgument<"Y">, ExprArgument<"Z">];
10441044
let Documentation = [Undocumented];
10451045
}
10461046

10471047
def HLSLNodeMaxRecursionDepth : InheritableAttr {
10481048
let Spellings = [CXX11<"", "nodemaxrecursiondepth", 2017>];
1049-
let Args = [UnsignedArgument<"Count">];
1049+
let Args = [ExprArgument<"Count">];
10501050
let Documentation = [Undocumented];
10511051
}
10521052

@@ -1194,7 +1194,7 @@ def HLSLHitObject : InheritableAttr {
11941194

11951195
def HLSLMaxRecords : InheritableAttr {
11961196
let Spellings = [CXX11<"", "MaxRecords", 2015>];
1197-
let Args = [IntArgument<"maxCount">];
1197+
let Args = [ExprArgument<"maxCount">];
11981198
let Documentation = [Undocumented];
11991199
}
12001200

tools/clang/include/clang/Basic/DiagnosticSemaKinds.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7744,6 +7744,8 @@ def warn_hlsl_entry_attribute_without_shader_attribute : Warning<
77447744
InGroup<HLSLEntryAttributeWithoutShaderAttrType>;
77457745
def err_hlsl_attribute_expects_float_literal : Error<
77467746
"attribute %0 must have a float literal argument">;
7747+
def err_hlsl_attribute_expects_integer_const_expr : Error<
7748+
"attribute %0 argument %1 must be integer constant expression">;
77477749
def warn_hlsl_comma_in_init : Warning<
77487750
"comma expression used where a constructor list may have been intended">,
77497751
InGroup<HLSLCommaInInit>;

tools/clang/include/clang/SPIRV/SpirvContext.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,15 @@ class SpirvContext {
456456
instructionsWithLoweredType.end();
457457
}
458458

459+
SpirvInstruction *getSpecConstant(const VarDecl *decl) {
460+
return specConstants[decl];
461+
}
462+
463+
void registerSpecConstant(const VarDecl *decl,
464+
SpirvInstruction *specConstant) {
465+
specConstants[decl] = specConstant;
466+
}
467+
459468
void registerDispatchGridIndex(const RecordDecl *decl, unsigned index) {
460469
auto iter = dispatchGridIndices.find(decl);
461470
if (iter == dispatchGridIndices.end()) {
@@ -536,6 +545,7 @@ class SpirvContext {
536545
llvm::DenseSet<FunctionType *, FunctionTypeMapInfo> functionTypes;
537546
llvm::DenseMap<unsigned, SpirvIntrinsicType *> spirvIntrinsicTypesById;
538547
llvm::SmallVector<const SpirvIntrinsicType *, 8> spirvIntrinsicTypes;
548+
llvm::MapVector<const VarDecl *, SpirvInstruction *> specConstants;
539549
const AccelerationStructureTypeNV *accelerationStructureTypeNV;
540550
const RayQueryTypeKHR *rayQueryTypeKHR;
541551

tools/clang/include/clang/Sema/SemaHLSL.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,6 @@ unsigned CaculateInitListArraySizeForHLSL(clang::Sema *sema,
160160
const clang::InitListExpr *InitList,
161161
const clang::QualType EltTy);
162162

163-
bool ContainsLongVector(clang::QualType);
164-
165163
bool IsConversionToLessOrEqualElements(clang::Sema *self,
166164
const clang::ExprResult &sourceExpr,
167165
const clang::QualType &targetType,

tools/clang/lib/AST/ExprConstant.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9450,6 +9450,19 @@ bool Expr::isIntegerConstantExpr(llvm::APSInt &Value, const ASTContext &Ctx,
94509450
return true;
94519451
}
94529452

9453+
bool Expr::isVulkanSpecConstantExpr(const ASTContext &Ctx,
9454+
APValue *Result) const {
9455+
auto *D = dyn_cast<DeclRefExpr>(this);
9456+
if (!D)
9457+
return false;
9458+
auto *V = dyn_cast<VarDecl>(D->getDecl());
9459+
if (!V || !V->hasAttr<VKConstantIdAttr>())
9460+
return false;
9461+
if (const Expr *I = V->getAnyInitializer())
9462+
return I->IgnoreParenCasts()->isCXX11ConstantExpr(Ctx, Result);
9463+
return true;
9464+
}
9465+
94539466
bool Expr::isCXX98IntegralConstantExpr(const ASTContext &Ctx) const {
94549467
return CheckICE(this, Ctx).Kind == IK_ICE;
94559468
}

tools/clang/lib/CodeGen/CGHLSLMS.cpp

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,25 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
328328
};
329329
} // namespace
330330

331+
static uint32_t
332+
getIntConstAttrArg(ASTContext &astContext, const Expr *expr,
333+
llvm::Optional<uint32_t> defaultVal = llvm::None) {
334+
if (expr) {
335+
llvm::APSInt apsInt;
336+
APValue apValue;
337+
if (expr->isIntegerConstantExpr(apsInt, astContext))
338+
return (uint32_t)apsInt.getSExtValue();
339+
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) && apValue.isInt())
340+
return (uint32_t)apValue.getInt().getSExtValue();
341+
llvm_unreachable(
342+
"Expression must be a constant expression or spec constant");
343+
}
344+
if (!defaultVal.hasValue()) {
345+
DXASSERT(defaultVal.hasValue(), "missing attribute parameter");
346+
}
347+
return defaultVal.getValue();
348+
}
349+
331350
//------------------------------------------------------------------------------
332351
//
333352
// CGMSHLSLRuntime methods.
@@ -1422,6 +1441,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
14221441
}
14231442

14241443
DiagnosticsEngine &Diags = CGM.getDiags();
1444+
ASTContext &astContext = CGM.getTypes().getContext();
14251445

14261446
std::unique_ptr<DxilFunctionProps> funcProps =
14271447
llvm::make_unique<DxilFunctionProps>();
@@ -1632,10 +1652,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16321652

16331653
// Populate numThreads
16341654
if (const HLSLNumThreadsAttr *Attr = FD->getAttr<HLSLNumThreadsAttr>()) {
1635-
1636-
funcProps->numThreads[0] = Attr->getX();
1637-
funcProps->numThreads[1] = Attr->getY();
1638-
funcProps->numThreads[2] = Attr->getZ();
1655+
funcProps->numThreads[0] = getIntConstAttrArg(astContext, Attr->getX());
1656+
funcProps->numThreads[1] = getIntConstAttrArg(astContext, Attr->getY());
1657+
funcProps->numThreads[2] = getIntConstAttrArg(astContext, Attr->getZ());
16391658

16401659
if (isEntry && !SM->IsCS() && !SM->IsMS() && !SM->IsAS()) {
16411660
unsigned DiagID = Diags.getCustomDiagID(
@@ -1808,7 +1827,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18081827

18091828
if (const auto *pAttr = FD->getAttr<HLSLNodeIdAttr>()) {
18101829
funcProps->NodeShaderID.Name = pAttr->getName().str();
1811-
funcProps->NodeShaderID.Index = pAttr->getArrayIndex();
1830+
funcProps->NodeShaderID.Index =
1831+
getIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
18121832
} else {
18131833
funcProps->NodeShaderID.Name = FD->getName().str();
18141834
funcProps->NodeShaderID.Index = 0;
@@ -1819,20 +1839,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18191839
}
18201840
if (const auto *pAttr = FD->getAttr<HLSLNodeShareInputOfAttr>()) {
18211841
funcProps->NodeShaderSharedInput.Name = pAttr->getName().str();
1822-
funcProps->NodeShaderSharedInput.Index = pAttr->getArrayIndex();
1842+
funcProps->NodeShaderSharedInput.Index =
1843+
getIntConstAttrArg(astContext, pAttr->getArrayIndex(), 0);
18231844
}
18241845
if (const auto *pAttr = FD->getAttr<HLSLNodeDispatchGridAttr>()) {
1825-
funcProps->Node.DispatchGrid[0] = pAttr->getX();
1826-
funcProps->Node.DispatchGrid[1] = pAttr->getY();
1827-
funcProps->Node.DispatchGrid[2] = pAttr->getZ();
1846+
funcProps->Node.DispatchGrid[0] =
1847+
getIntConstAttrArg(astContext, pAttr->getX());
1848+
funcProps->Node.DispatchGrid[1] =
1849+
getIntConstAttrArg(astContext, pAttr->getY());
1850+
funcProps->Node.DispatchGrid[2] =
1851+
getIntConstAttrArg(astContext, pAttr->getZ());
18281852
}
18291853
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxDispatchGridAttr>()) {
1830-
funcProps->Node.MaxDispatchGrid[0] = pAttr->getX();
1831-
funcProps->Node.MaxDispatchGrid[1] = pAttr->getY();
1832-
funcProps->Node.MaxDispatchGrid[2] = pAttr->getZ();
1854+
funcProps->Node.MaxDispatchGrid[0] =
1855+
getIntConstAttrArg(astContext, pAttr->getX());
1856+
funcProps->Node.MaxDispatchGrid[1] =
1857+
getIntConstAttrArg(astContext, pAttr->getY());
1858+
funcProps->Node.MaxDispatchGrid[2] =
1859+
getIntConstAttrArg(astContext, pAttr->getZ());
18331860
}
18341861
if (const auto *pAttr = FD->getAttr<HLSLNodeMaxRecursionDepthAttr>()) {
1835-
funcProps->Node.MaxRecursionDepth = pAttr->getCount();
1862+
funcProps->Node.MaxRecursionDepth =
1863+
getIntConstAttrArg(astContext, pAttr->getCount());
18361864
}
18371865
if (!FD->getAttr<HLSLNumThreadsAttr>()) {
18381866
// NumThreads wasn't specified.
@@ -2346,8 +2374,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23462374
NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
23472375

23482376
if (parmDecl->hasAttr<HLSLMaxRecordsAttr>()) {
2349-
node.MaxRecords =
2350-
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount();
2377+
node.MaxRecords = getIntConstAttrArg(
2378+
astContext,
2379+
parmDecl->getAttr<HLSLMaxRecordsAttr>()->getMaxCount());
23512380
}
23522381
if (parmDecl->hasAttr<HLSLGloballyCoherentAttr>())
23532382
node.Flags.SetGloballyCoherent();
@@ -2378,7 +2407,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23782407
// OutputID from attribute
23792408
if (const auto *Attr = parmDecl->getAttr<HLSLNodeIdAttr>()) {
23802409
node.OutputID.Name = Attr->getName().str();
2381-
node.OutputID.Index = Attr->getArrayIndex();
2410+
node.OutputID.Index =
2411+
getIntConstAttrArg(astContext, Attr->getArrayIndex(), 0);
23822412
} else {
23832413
node.OutputID.Name = parmDecl->getName().str();
23842414
node.OutputID.Index = 0;
@@ -2437,7 +2467,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24372467
node.MaxRecordsSharedWith = ix;
24382468
}
24392469
if (const auto *Attr = parmDecl->getAttr<HLSLMaxRecordsAttr>())
2440-
node.MaxRecords = Attr->getMaxCount();
2470+
node.MaxRecords = getIntConstAttrArg(astContext, Attr->getMaxCount());
24412471
}
24422472

24432473
if (inputPatchCount > 1) {

tools/clang/lib/SPIRV/DeclResultIdMapper.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1821,6 +1821,7 @@ DeclResultIdMapper::getCounterVarFields(const DeclaratorDecl *decl) {
18211821
void DeclResultIdMapper::registerSpecConstant(const VarDecl *decl,
18221822
SpirvInstruction *specConstant) {
18231823
specConstant->setRValue();
1824+
spvContext.registerSpecConstant(decl, specConstant);
18241825
registerVariableForDecl(decl, createDeclSpirvInfo(specConstant));
18251826
}
18261827

tools/clang/lib/SPIRV/EmitVisitor.cpp

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2522,6 +2522,24 @@ isFieldMergeWithPrevious(const StructType::FieldInfo &previous,
25222522
return previous.fieldIndex == field.fieldIndex;
25232523
}
25242524

2525+
uint32_t EmitTypeHandler::getAttrArgInstr(ASTContext &astContext,
2526+
const Expr *expr,
2527+
uint32_t defaultVal) {
2528+
if (expr) {
2529+
llvm::APSInt apsInt;
2530+
APValue apValue;
2531+
if (expr->isIntegerConstantExpr(apsInt, astContext))
2532+
return getOrCreateConstantInt(apsInt, context.getUIntType(32), false);
2533+
if (expr->isVulkanSpecConstantExpr(astContext, &apValue) &&
2534+
apValue.isInt()) {
2535+
auto *declRefExpr = dyn_cast<DeclRefExpr>(expr);
2536+
auto *decl = dyn_cast<const VarDecl>(declRefExpr->getDecl());
2537+
return getOrAssignResultId(context.getSpecConstant(decl));
2538+
}
2539+
}
2540+
return defaultVal;
2541+
}
2542+
25252543
uint32_t EmitTypeHandler::emitType(const SpirvType *type) {
25262544
// First get the decorations that would apply to this type.
25272545
bool alreadyExists = false;
@@ -2938,29 +2956,29 @@ void EmitTypeHandler::emitDecorationsForNodePayloadArrayTypes(
29382956
// Emit decorations
29392957
const ParmVarDecl *nodeDecl = npaType->getNodeDecl();
29402958
if (hlsl::IsHLSLNodeOutputType(nodeDecl->getType())) {
2941-
StringRef name = nodeDecl->getName();
2942-
unsigned index = 0;
2943-
if (auto nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
2959+
StringRef name;
2960+
llvm::Optional<unsigned> index;
2961+
if (auto *nodeID = nodeDecl->getAttr<HLSLNodeIdAttr>()) {
29442962
name = nodeID->getName();
2945-
index = nodeID->getArrayIndex();
2963+
index = getAttrArgInstr(astContext, nodeID->getArrayIndex());
2964+
} else {
2965+
name = nodeDecl->getName();
2966+
index = llvm::None;
29462967
}
29472968

29482969
auto *str = new (context) SpirvConstantString(name);
29492970
uint32_t nodeName = getOrCreateConstantString(str);
29502971
emitDecoration(id, spv::Decoration::PayloadNodeNameAMDX, {nodeName},
29512972
llvm::None, true);
2952-
if (index) {
2953-
uint32_t baseIndex = getOrCreateConstantInt(
2954-
llvm::APInt(32, index), context.getUIntType(32), false);
2955-
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX, {baseIndex},
2956-
llvm::None, true);
2973+
if (index.hasValue()) {
2974+
emitDecoration(id, spv::Decoration::PayloadNodeBaseIndexAMDX,
2975+
{index.getValue()}, llvm::None, true);
29572976
}
29582977
}
29592978

29602979
uint32_t maxRecords;
29612980
if (const auto *attr = nodeDecl->getAttr<HLSLMaxRecordsAttr>()) {
2962-
maxRecords = getOrCreateConstantInt(llvm::APInt(32, attr->getMaxCount()),
2963-
context.getUIntType(32), false);
2981+
maxRecords = getAttrArgInstr(astContext, attr->getMaxCount(), 1);
29642982
} else {
29652983
maxRecords = getOrCreateConstantInt(llvm::APInt(32, 1),
29662984
context.getUIntType(32), false);

tools/clang/lib/SPIRV/EmitVisitor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,9 @@ class EmitTypeHandler {
6767
EmitTypeHandler(const EmitTypeHandler &) = delete;
6868
EmitTypeHandler &operator=(const EmitTypeHandler &) = delete;
6969

70+
uint32_t getAttrArgInstr(ASTContext &astContext, const Expr *expr,
71+
uint32_t defaultVal = 0);
72+
7073
// Emits the instruction for the given type into the typeConstantBinary and
7174
// returns the result-id for the type. If the type has already been emitted,
7275
// it only returns its result-id.

0 commit comments

Comments
 (0)