@@ -328,6 +328,25 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
328328};
329329} // namespace
330330
331+ static uint32_t
332+ getIntConstAttrArg (ASTContext &astContext, const Expr *expr,
333+ llvm::Optional<uint32_t > defaultVal = llvm::None) {
334+ if (expr) {
335+ llvm::APSInt apsInt;
336+ APValue apValue;
337+ if (expr->isIntegerConstantExpr (apsInt, astContext))
338+ return (uint32_t )apsInt.getSExtValue ();
339+ if (expr->isVulkanSpecConstantExpr (astContext, &apValue) && apValue.isInt ())
340+ return (uint32_t )apValue.getInt ().getSExtValue ();
341+ llvm_unreachable (
342+ " Expression must be a constant expression or spec constant" );
343+ }
344+ if (!defaultVal.hasValue ()) {
345+ DXASSERT (defaultVal.hasValue (), " missing attribute parameter" );
346+ }
347+ return defaultVal.getValue ();
348+ }
349+
331350// ------------------------------------------------------------------------------
332351//
333352// CGMSHLSLRuntime methods.
@@ -1422,6 +1441,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
14221441 }
14231442
14241443 DiagnosticsEngine &Diags = CGM.getDiags ();
1444+ ASTContext &astContext = CGM.getTypes ().getContext ();
14251445
14261446 std::unique_ptr<DxilFunctionProps> funcProps =
14271447 llvm::make_unique<DxilFunctionProps>();
@@ -1632,10 +1652,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16321652
16331653 // Populate numThreads
16341654 if (const HLSLNumThreadsAttr *Attr = FD->getAttr <HLSLNumThreadsAttr>()) {
1635-
1636- funcProps->numThreads [0 ] = Attr->getX ();
1637- funcProps->numThreads [1 ] = Attr->getY ();
1638- funcProps->numThreads [2 ] = Attr->getZ ();
1655+ funcProps->numThreads [0 ] = getIntConstAttrArg (astContext, Attr->getX ());
1656+ funcProps->numThreads [1 ] = getIntConstAttrArg (astContext, Attr->getY ());
1657+ funcProps->numThreads [2 ] = getIntConstAttrArg (astContext, Attr->getZ ());
16391658
16401659 if (isEntry && !SM->IsCS () && !SM->IsMS () && !SM->IsAS ()) {
16411660 unsigned DiagID = Diags.getCustomDiagID (
@@ -1808,7 +1827,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18081827
18091828 if (const auto *pAttr = FD->getAttr <HLSLNodeIdAttr>()) {
18101829 funcProps->NodeShaderID .Name = pAttr->getName ().str ();
1811- funcProps->NodeShaderID .Index = pAttr->getArrayIndex ();
1830+ funcProps->NodeShaderID .Index =
1831+ getIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
18121832 } else {
18131833 funcProps->NodeShaderID .Name = FD->getName ().str ();
18141834 funcProps->NodeShaderID .Index = 0 ;
@@ -1819,20 +1839,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18191839 }
18201840 if (const auto *pAttr = FD->getAttr <HLSLNodeShareInputOfAttr>()) {
18211841 funcProps->NodeShaderSharedInput .Name = pAttr->getName ().str ();
1822- funcProps->NodeShaderSharedInput .Index = pAttr->getArrayIndex ();
1842+ funcProps->NodeShaderSharedInput .Index =
1843+ getIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
18231844 }
18241845 if (const auto *pAttr = FD->getAttr <HLSLNodeDispatchGridAttr>()) {
1825- funcProps->Node .DispatchGrid [0 ] = pAttr->getX ();
1826- funcProps->Node .DispatchGrid [1 ] = pAttr->getY ();
1827- funcProps->Node .DispatchGrid [2 ] = pAttr->getZ ();
1846+ funcProps->Node .DispatchGrid [0 ] =
1847+ getIntConstAttrArg (astContext, pAttr->getX ());
1848+ funcProps->Node .DispatchGrid [1 ] =
1849+ getIntConstAttrArg (astContext, pAttr->getY ());
1850+ funcProps->Node .DispatchGrid [2 ] =
1851+ getIntConstAttrArg (astContext, pAttr->getZ ());
18281852 }
18291853 if (const auto *pAttr = FD->getAttr <HLSLNodeMaxDispatchGridAttr>()) {
1830- funcProps->Node .MaxDispatchGrid [0 ] = pAttr->getX ();
1831- funcProps->Node .MaxDispatchGrid [1 ] = pAttr->getY ();
1832- funcProps->Node .MaxDispatchGrid [2 ] = pAttr->getZ ();
1854+ funcProps->Node .MaxDispatchGrid [0 ] =
1855+ getIntConstAttrArg (astContext, pAttr->getX ());
1856+ funcProps->Node .MaxDispatchGrid [1 ] =
1857+ getIntConstAttrArg (astContext, pAttr->getY ());
1858+ funcProps->Node .MaxDispatchGrid [2 ] =
1859+ getIntConstAttrArg (astContext, pAttr->getZ ());
18331860 }
18341861 if (const auto *pAttr = FD->getAttr <HLSLNodeMaxRecursionDepthAttr>()) {
1835- funcProps->Node .MaxRecursionDepth = pAttr->getCount ();
1862+ funcProps->Node .MaxRecursionDepth =
1863+ getIntConstAttrArg (astContext, pAttr->getCount ());
18361864 }
18371865 if (!FD->getAttr <HLSLNumThreadsAttr>()) {
18381866 // NumThreads wasn't specified.
@@ -2346,8 +2374,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23462374 NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
23472375
23482376 if (parmDecl->hasAttr <HLSLMaxRecordsAttr>()) {
2349- node.MaxRecords =
2350- parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount ();
2377+ node.MaxRecords = getIntConstAttrArg (
2378+ astContext,
2379+ parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount ());
23512380 }
23522381 if (parmDecl->hasAttr <HLSLGloballyCoherentAttr>())
23532382 node.Flags .SetGloballyCoherent ();
@@ -2378,7 +2407,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23782407 // OutputID from attribute
23792408 if (const auto *Attr = parmDecl->getAttr <HLSLNodeIdAttr>()) {
23802409 node.OutputID .Name = Attr->getName ().str ();
2381- node.OutputID .Index = Attr->getArrayIndex ();
2410+ node.OutputID .Index =
2411+ getIntConstAttrArg (astContext, Attr->getArrayIndex (), 0 );
23822412 } else {
23832413 node.OutputID .Name = parmDecl->getName ().str ();
23842414 node.OutputID .Index = 0 ;
@@ -2437,7 +2467,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24372467 node.MaxRecordsSharedWith = ix;
24382468 }
24392469 if (const auto *Attr = parmDecl->getAttr <HLSLMaxRecordsAttr>())
2440- node.MaxRecords = Attr->getMaxCount ();
2470+ node.MaxRecords = getIntConstAttrArg (astContext, Attr->getMaxCount () );
24412471 }
24422472
24432473 if (inputPatchCount > 1 ) {
0 commit comments