@@ -328,6 +328,19 @@ class CGMSHLSLRuntime : public CGHLSLRuntime {
328328};
329329} // namespace
330330
331+ static uint32_t GetIntConstAttrArg (ASTContext &astContext, const Expr *expr,
332+ uint32_t defaultVal = 0 ) {
333+ if (expr) {
334+ llvm::APSInt apsInt;
335+ APValue apValue;
336+ if (expr->isIntegerConstantExpr (apsInt, astContext))
337+ return (uint32_t )apsInt.getSExtValue ();
338+ if (expr->isVulkanSpecConstantExpr (astContext, &apValue) && apValue.isInt ())
339+ return (uint32_t )apValue.getInt ().getSExtValue ();
340+ }
341+ return defaultVal;
342+ }
343+
331344// ------------------------------------------------------------------------------
332345//
333346// CGMSHLSLRuntime methods.
@@ -1422,6 +1435,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
14221435 }
14231436
14241437 DiagnosticsEngine &Diags = CGM.getDiags ();
1438+ ASTContext &astContext = CGM.getTypes ().getContext ();
14251439
14261440 std::unique_ptr<DxilFunctionProps> funcProps =
14271441 llvm::make_unique<DxilFunctionProps>();
@@ -1632,10 +1646,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
16321646
16331647 // Populate numThreads
16341648 if (const HLSLNumThreadsAttr *Attr = FD->getAttr <HLSLNumThreadsAttr>()) {
1635-
1636- funcProps->numThreads [0 ] = Attr->getX ();
1637- funcProps->numThreads [1 ] = Attr->getY ();
1638- funcProps->numThreads [2 ] = Attr->getZ ();
1649+ funcProps->numThreads [0 ] = GetIntConstAttrArg (astContext, Attr->getX (), 1 );
1650+ funcProps->numThreads [1 ] = GetIntConstAttrArg (astContext, Attr->getY (), 1 );
1651+ funcProps->numThreads [2 ] = GetIntConstAttrArg (astContext, Attr->getZ (), 1 );
16391652
16401653 if (isEntry && !SM->IsCS () && !SM->IsMS () && !SM->IsAS ()) {
16411654 unsigned DiagID = Diags.getCustomDiagID (
@@ -1808,7 +1821,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18081821
18091822 if (const auto *pAttr = FD->getAttr <HLSLNodeIdAttr>()) {
18101823 funcProps->NodeShaderID .Name = pAttr->getName ().str ();
1811- funcProps->NodeShaderID .Index = pAttr->getArrayIndex ();
1824+ funcProps->NodeShaderID .Index =
1825+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
18121826 } else {
18131827 funcProps->NodeShaderID .Name = FD->getName ().str ();
18141828 funcProps->NodeShaderID .Index = 0 ;
@@ -1819,20 +1833,28 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
18191833 }
18201834 if (const auto *pAttr = FD->getAttr <HLSLNodeShareInputOfAttr>()) {
18211835 funcProps->NodeShaderSharedInput .Name = pAttr->getName ().str ();
1822- funcProps->NodeShaderSharedInput .Index = pAttr->getArrayIndex ();
1836+ funcProps->NodeShaderSharedInput .Index =
1837+ GetIntConstAttrArg (astContext, pAttr->getArrayIndex (), 0 );
18231838 }
18241839 if (const auto *pAttr = FD->getAttr <HLSLNodeDispatchGridAttr>()) {
1825- funcProps->Node .DispatchGrid [0 ] = pAttr->getX ();
1826- funcProps->Node .DispatchGrid [1 ] = pAttr->getY ();
1827- funcProps->Node .DispatchGrid [2 ] = pAttr->getZ ();
1840+ funcProps->Node .DispatchGrid [0 ] =
1841+ GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1842+ funcProps->Node .DispatchGrid [1 ] =
1843+ GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1844+ funcProps->Node .DispatchGrid [2 ] =
1845+ GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
18281846 }
18291847 if (const auto *pAttr = FD->getAttr <HLSLNodeMaxDispatchGridAttr>()) {
1830- funcProps->Node .MaxDispatchGrid [0 ] = pAttr->getX ();
1831- funcProps->Node .MaxDispatchGrid [1 ] = pAttr->getY ();
1832- funcProps->Node .MaxDispatchGrid [2 ] = pAttr->getZ ();
1848+ funcProps->Node .MaxDispatchGrid [0 ] =
1849+ GetIntConstAttrArg (astContext, pAttr->getX (), 1 );
1850+ funcProps->Node .MaxDispatchGrid [1 ] =
1851+ GetIntConstAttrArg (astContext, pAttr->getY (), 1 );
1852+ funcProps->Node .MaxDispatchGrid [2 ] =
1853+ GetIntConstAttrArg (astContext, pAttr->getZ (), 1 );
18331854 }
18341855 if (const auto *pAttr = FD->getAttr <HLSLNodeMaxRecursionDepthAttr>()) {
1835- funcProps->Node .MaxRecursionDepth = pAttr->getCount ();
1856+ funcProps->Node .MaxRecursionDepth =
1857+ GetIntConstAttrArg (astContext, pAttr->getCount (), 0 );
18361858 }
18371859 if (!FD->getAttr <HLSLNumThreadsAttr>()) {
18381860 // NumThreads wasn't specified.
@@ -2346,8 +2368,9 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23462368 NodeInputRecordParams[ArgIt].MetadataIdx = NodeInputParamIdx++;
23472369
23482370 if (parmDecl->hasAttr <HLSLMaxRecordsAttr>()) {
2349- node.MaxRecords =
2350- parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount ();
2371+ node.MaxRecords = GetIntConstAttrArg (
2372+ astContext,
2373+ parmDecl->getAttr <HLSLMaxRecordsAttr>()->getMaxCount (), 1 );
23512374 }
23522375 if (parmDecl->hasAttr <HLSLGloballyCoherentAttr>())
23532376 node.Flags .SetGloballyCoherent ();
@@ -2378,7 +2401,8 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
23782401 // OutputID from attribute
23792402 if (const auto *Attr = parmDecl->getAttr <HLSLNodeIdAttr>()) {
23802403 node.OutputID .Name = Attr->getName ().str ();
2381- node.OutputID .Index = Attr->getArrayIndex ();
2404+ node.OutputID .Index =
2405+ GetIntConstAttrArg (astContext, Attr->getArrayIndex (), 0 );
23822406 } else {
23832407 node.OutputID .Name = parmDecl->getName ().str ();
23842408 node.OutputID .Index = 0 ;
@@ -2437,7 +2461,7 @@ void CGMSHLSLRuntime::AddHLSLFunctionInfo(Function *F, const FunctionDecl *FD) {
24372461 node.MaxRecordsSharedWith = ix;
24382462 }
24392463 if (const auto *Attr = parmDecl->getAttr <HLSLMaxRecordsAttr>())
2440- node.MaxRecords = Attr->getMaxCount ();
2464+ node.MaxRecords = GetIntConstAttrArg (astContext, Attr->getMaxCount (), 0 );
24412465 }
24422466
24432467 if (inputPatchCount > 1 ) {
0 commit comments