1313// ===----------------------------------------------------------------------===//
1414
1515#include " CGHLSLRuntime.h"
16- #include " Address.h"
1716#include " CGDebugInfo.h"
1817#include " CodeGenFunction.h"
1918#include " CodeGenModule.h"
19+ #include " HLSLBufferLayoutBuilder.h"
2020#include " TargetInfo.h"
2121#include " clang/AST/ASTContext.h"
2222#include " clang/AST/Attrs.inc"
2626#include " clang/AST/Type.h"
2727#include " clang/Basic/TargetOptions.h"
2828#include " clang/Frontend/FrontendDiagnostic.h"
29+ #include " llvm/ADT/ScopeExit.h"
2930#include " llvm/ADT/SmallString.h"
3031#include " llvm/ADT/SmallVector.h"
3132#include " llvm/Frontend/HLSL/RootSignatureMetadata.h"
4344#include < cstdint>
4445#include < optional>
4546
47+ #define DEBUG_TYPE " cghlslruntime"
48+
4649using namespace clang ;
4750using namespace CodeGen ;
4851using namespace clang ::hlsl;
@@ -265,9 +268,9 @@ CGHLSLRuntime::convertHLSLSpecificType(const Type *T,
265268 assert (T->isHLSLSpecificType () && " Not an HLSL specific type!" );
266269
267270 // Check if the target has a specific translation for this type first.
268- if (llvm::Type *TargetTy =
271+ if (llvm::Type *LayoutTy =
269272 CGM.getTargetCodeGenInfo ().getHLSLType (CGM, T, Packoffsets))
270- return TargetTy ;
273+ return LayoutTy ;
271274
272275 llvm_unreachable (" Generic handling of HLSL types is not supported." );
273276}
@@ -284,10 +287,8 @@ void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
284287
285288 // get the layout struct from constant buffer target type
286289 llvm::Type *BufType = BufGV->getValueType ();
287- llvm::Type *BufLayoutType =
288- cast<llvm::TargetExtType>(BufType)->getTypeParameter (0 );
289290 llvm::StructType *LayoutStruct = cast<llvm::StructType>(
290- cast<llvm::TargetExtType>(BufLayoutType )->getTypeParameter (0 ));
291+ cast<llvm::TargetExtType>(BufType )->getTypeParameter (0 ));
291292
292293 // Start metadata list associating the buffer global variable with its
293294 // constatns
@@ -326,6 +327,9 @@ void CGHLSLRuntime::emitBufferGlobalsAndMetadata(const HLSLBufferDecl *BufDecl,
326327 continue ;
327328 }
328329
330+ if (CGM.getTargetCodeGenInfo ().isHLSLPadding (*ElemIt))
331+ ++ElemIt;
332+
329333 assert (ElemIt != LayoutStruct->element_end () &&
330334 " number of elements in layout struct does not match" );
331335 llvm::Type *LayoutType = *ElemIt++;
@@ -423,12 +427,11 @@ void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *BufDecl) {
423427 if (BufDecl->hasValidPackoffset ())
424428 fillPackoffsetLayout (BufDecl, Layout);
425429
426- llvm::TargetExtType *TargetTy =
427- cast<llvm::TargetExtType>(convertHLSLSpecificType (
428- ResHandleTy, BufDecl->hasValidPackoffset () ? &Layout : nullptr ));
430+ llvm::Type *LayoutTy = convertHLSLSpecificType (
431+ ResHandleTy, BufDecl->hasValidPackoffset () ? &Layout : nullptr );
429432 llvm::GlobalVariable *BufGV = new GlobalVariable (
430- TargetTy , /* isConstant*/ false ,
431- GlobalValue::LinkageTypes::ExternalLinkage, PoisonValue::get (TargetTy ),
433+ LayoutTy , /* isConstant*/ false ,
434+ GlobalValue::LinkageTypes::ExternalLinkage, PoisonValue::get (LayoutTy ),
432435 llvm::formatv (" {0}{1}" , BufDecl->getName (),
433436 BufDecl->isCBuffer () ? " .cb" : " .tb" ),
434437 GlobalValue::NotThreadLocal);
@@ -454,7 +457,7 @@ void CGHLSLRuntime::addRootSignature(
454457 SignatureDecl->getRootElements (), nullptr , M);
455458}
456459
457- llvm::TargetExtType *
460+ llvm::StructType *
458461CGHLSLRuntime::getHLSLBufferLayoutType (const RecordType *StructType) {
459462 const auto Entry = LayoutTypes.find (StructType);
460463 if (Entry != LayoutTypes.end ())
@@ -463,7 +466,7 @@ CGHLSLRuntime::getHLSLBufferLayoutType(const RecordType *StructType) {
463466}
464467
465468void CGHLSLRuntime::addHLSLBufferLayoutType (const RecordType *StructType,
466- llvm::TargetExtType *LayoutTy) {
469+ llvm::StructType *LayoutTy) {
467470 assert (getHLSLBufferLayoutType (StructType) == nullptr &&
468471 " layout type for this struct already exist" );
469472 LayoutTypes[StructType] = LayoutTy;
@@ -997,3 +1000,139 @@ std::optional<LValue> CGHLSLRuntime::emitResourceArraySubscriptExpr(
9971000 }
9981001 return CGF.MakeAddrLValue (TmpVar, ResultTy, AlignmentSource::Decl);
9991002}
1003+
1004+ namespace {
1005+ class HLSLBufferCopyEmitter {
1006+ CodeGenFunction &CGF;
1007+ Address DestPtr;
1008+ Address SrcPtr;
1009+ llvm::Type *LayoutTy = nullptr ;
1010+
1011+ SmallVector<llvm::Value *> CurStoreIndices;
1012+ SmallVector<llvm::Value *> CurLoadIndices;
1013+
1014+ void emitCopyAtIndices (llvm::Type *FieldTy, unsigned StoreIndex,
1015+ unsigned LoadIndex) {
1016+ CurStoreIndices.push_back (llvm::ConstantInt::get (CGF.SizeTy , StoreIndex));
1017+ CurLoadIndices.push_back (llvm::ConstantInt::get (CGF.SizeTy , LoadIndex));
1018+ auto RestoreIndices = llvm::make_scope_exit ([&]() {
1019+ CurStoreIndices.pop_back ();
1020+ CurLoadIndices.pop_back ();
1021+ });
1022+
1023+ if (processArray (FieldTy))
1024+ return ;
1025+ if (processBufferLayoutArray (FieldTy))
1026+ return ;
1027+ if (processStruct (FieldTy))
1028+ return ;
1029+
1030+ // We have a scalar or vector element - emit a copy.
1031+ CharUnits Align = CharUnits::fromQuantity (
1032+ CGF.CGM .getDataLayout ().getABITypeAlign (FieldTy));
1033+ Address SrcGEP = RawAddress (
1034+ CGF.Builder .CreateInBoundsGEP (LayoutTy, SrcPtr.getBasePointer (),
1035+ CurLoadIndices, " cbuf.src" ),
1036+ FieldTy, Align, SrcPtr.isKnownNonNull ());
1037+ Address DestGEP = CGF.Builder .CreateInBoundsGEP (
1038+ DestPtr, CurStoreIndices, FieldTy, Align, " cbuf.dest" );
1039+ llvm::Value *Load = CGF.Builder .CreateLoad (SrcGEP, " cbuf.load" );
1040+ CGF.Builder .CreateStore (Load, DestGEP);
1041+ }
1042+
1043+ bool processArray (llvm::Type *FieldTy) {
1044+ auto *AT = dyn_cast<llvm::ArrayType>(FieldTy);
1045+ if (!AT)
1046+ return false ;
1047+
1048+ // If we have an array then there isn't any padding
1049+ // between elements. We just need to copy each element over.
1050+ for (unsigned I = 0 , E = AT->getNumElements (); I < E; ++I)
1051+ emitCopyAtIndices (AT->getElementType (), I, I);
1052+ return true ;
1053+ }
1054+
1055+ bool processBufferLayoutArray (llvm::Type *FieldTy) {
1056+ auto *ST = dyn_cast<llvm::StructType>(FieldTy);
1057+ if (!ST || ST->getNumElements () != 2 )
1058+ return false ;
1059+
1060+ auto *PaddedEltsTy = dyn_cast<llvm::ArrayType>(ST->getElementType (0 ));
1061+ if (!PaddedEltsTy)
1062+ return false ;
1063+
1064+ auto *PaddedTy = dyn_cast<llvm::StructType>(PaddedEltsTy->getElementType ());
1065+ if (!PaddedTy || PaddedTy->getNumElements () != 2 )
1066+ return false ;
1067+
1068+ if (!CGF.CGM .getTargetCodeGenInfo ().isHLSLPadding (
1069+ PaddedTy->getElementType (1 )))
1070+ return false ;
1071+
1072+ llvm::Type *ElementTy = ST->getElementType (1 );
1073+ if (PaddedTy->getElementType (0 ) != ElementTy)
1074+ return false ;
1075+
1076+ // All but the last of the logical array elements are in the padded array.
1077+ unsigned NumElts = PaddedEltsTy->getNumElements () + 1 ;
1078+
1079+ // Add an extra indirection to the load for the struct and walk the
1080+ // array prefix.
1081+ CurLoadIndices.push_back (llvm::ConstantInt::get (CGF.SizeTy , 0 ));
1082+ for (unsigned I = 0 ; I < NumElts - 1 ; ++I) {
1083+ // We need to copy the element itself, without the padding.
1084+ CurLoadIndices.push_back (llvm::ConstantInt::get (CGF.SizeTy , I));
1085+ emitCopyAtIndices (ElementTy, I, 0 );
1086+ CurLoadIndices.pop_back ();
1087+ }
1088+ CurLoadIndices.pop_back ();
1089+
1090+ // Now copy the last element.
1091+ emitCopyAtIndices (ElementTy, NumElts - 1 , 1 );
1092+
1093+ return true ;
1094+ }
1095+
1096+ bool processStruct (llvm::Type *FieldTy) {
1097+ auto *ST = dyn_cast<llvm::StructType>(FieldTy);
1098+ if (!ST)
1099+ return false ;
1100+
1101+ unsigned Skipped = 0 ;
1102+ for (unsigned I = 0 , E = ST->getNumElements (); I < E; ++I) {
1103+ llvm::Type *ElementTy = ST->getElementType (I);
1104+ if (CGF.CGM .getTargetCodeGenInfo ().isHLSLPadding (ElementTy))
1105+ ++Skipped;
1106+ else
1107+ emitCopyAtIndices (ElementTy, I, I + Skipped);
1108+ }
1109+ return true ;
1110+ }
1111+
1112+ public:
1113+ HLSLBufferCopyEmitter (CodeGenFunction &CGF, Address DestPtr, Address SrcPtr)
1114+ : CGF(CGF), DestPtr(DestPtr), SrcPtr(SrcPtr) {}
1115+
1116+ bool emitCopy (QualType CType) {
1117+ LayoutTy = HLSLBufferLayoutBuilder (CGF.CGM ).layOutType (CType);
1118+
1119+ LLVM_DEBUG ({
1120+ dbgs () << " Emitting copy of " ;
1121+ LayoutTy->print (dbgs ());
1122+ dbgs () << " \n " ;
1123+ });
1124+
1125+ // If we don't have an aggregate, we can just fall back to normal memcpy.
1126+ if (!LayoutTy->isAggregateType ())
1127+ return false ;
1128+
1129+ emitCopyAtIndices (LayoutTy, 0 , 0 );
1130+ return true ;
1131+ }
1132+ };
1133+ } // namespace
1134+
1135+ bool CGHLSLRuntime::emitBufferCopy (CodeGenFunction &CGF, Address DestPtr,
1136+ Address SrcPtr, QualType CType) {
1137+ return HLSLBufferCopyEmitter (CGF, DestPtr, SrcPtr).emitCopy (CType);
1138+ }
0 commit comments