Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions tools/clang/unittests/HLSLExec/LongVectorOps.def
Original file line number Diff line number Diff line change
Expand Up @@ -193,4 +193,6 @@ OP_LOAD_AND_STORE_SB(LoadAndStore_RD_SB_SRV, "RootDescriptor_SRV")
#undef OP_LOAD_AND_STORE
#undef OP_LOAD_AND_STORE_DEFINES

OP_DEFAULT(Wave, WaveActiveSum, 1, "WaveActiveSum", "")

#undef OP
4 changes: 4 additions & 0 deletions tools/clang/unittests/HLSLExec/LongVectorTestData.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@ struct HLSLHalf_t {

Val = DirectX::PackedVector::XMConvertFloatToHalf(F);
}
HLSLHalf_t(const uint32_t U) {
float F = static_cast<float>(U);
Val = DirectX::PackedVector::XMConvertFloatToHalf(F);
}

// PackedVector::HALF is a uint16. Make sure we don't ever accidentally
// convert one of these to a HLSLHalf_t by arithmetically converting it to a
Expand Down
168 changes: 147 additions & 21 deletions tools/clang/unittests/HLSLExec/LongVectors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,10 +313,10 @@ static WEX::Common::String getInputValueSetName(size_t Index) {
return ValueSetName;
}

std::string getCompilerOptionsString(const Operation &Operation,
const DataType &OpDataType,
const DataType &OutDataType,
size_t VectorSize) {
std::string getCompilerOptionsString(
const Operation &Operation, const DataType &OpDataType,
const DataType &OutDataType, size_t VectorSize,
std::optional<std::string> AdditionalOptions = std::nullopt) {
std::stringstream CompilerOptions;

if (OpDataType.Is16Bit || OutDataType.Is16Bit)
Expand All @@ -337,6 +337,9 @@ std::string getCompilerOptionsString(const Operation &Operation,

CompilerOptions << " -DBASIC_OP_TYPE=0x" << std::hex << Operation.Arity;

if (AdditionalOptions)
CompilerOptions << " " << AdditionalOptions.value();

return CompilerOptions.str();
}

Expand Down Expand Up @@ -387,7 +390,8 @@ template <typename OUT_TYPE, typename T>
std::optional<std::vector<OUT_TYPE>>
runTest(ID3D12Device *D3DDevice, bool VerboseLogging,
const Operation &Operation, const InputSets<T> &Inputs,
size_t ExpectedOutputSize) {
size_t ExpectedOutputSize,
std::optional<std::string> AdditionalCompilerOptions) {
DXASSERT_NOMSG(Inputs.size() == Operation.Arity);

if (VerboseLogging) {
Expand All @@ -403,8 +407,9 @@ runTest(ID3D12Device *D3DDevice, bool VerboseLogging,

// We have to construct the string outside of the lambda. Otherwise it's
// cleaned up when the lambda finishes executing but before the shader runs.
std::string CompilerOptionsString = getCompilerOptionsString(
Operation, OpDataType, OutDataType, Inputs[0].size());
std::string CompilerOptionsString =
getCompilerOptionsString(Operation, OpDataType, OutDataType,
Inputs[0].size(), AdditionalCompilerOptions);

dxc::SpecificDllLoader DxilDllLoader;
CComPtr<IStream> TestXML;
Expand Down Expand Up @@ -570,13 +575,15 @@ struct ValidationConfig {
};

template <typename T, typename OUT_TYPE>
void runAndVerify(ID3D12Device *D3DDevice, bool VerboseLogging,
const Operation &Operation, const InputSets<T> &Inputs,
const std::vector<OUT_TYPE> &Expected,
const ValidationConfig &ValidationConfig) {
void runAndVerify(
ID3D12Device *D3DDevice, bool VerboseLogging, const Operation &Operation,
const InputSets<T> &Inputs, const std::vector<OUT_TYPE> &Expected,
const ValidationConfig &ValidationConfig,
std::optional<std::string> AdditionalCompilerOptions = std::nullopt) {

std::optional<std::vector<OUT_TYPE>> Actual = runTest<OUT_TYPE>(
D3DDevice, VerboseLogging, Operation, Inputs, Expected.size());
std::optional<std::vector<OUT_TYPE>> Actual =
runTest<OUT_TYPE>(D3DDevice, VerboseLogging, Operation, Inputs,
Expected.size(), AdditionalCompilerOptions);

// If the test didn't run, don't verify anything.
if (!Actual)
Expand Down Expand Up @@ -1253,6 +1260,19 @@ FLOAT_SPECIAL_OP(OpType::IsInf, (std::isinf(A)));
FLOAT_SPECIAL_OP(OpType::IsNan, (std::isnan(A)));
#undef FLOAT_SPECIAL_OP

//
// Wave Ops
//

#define WAVE_ACTIVE_OP(OP, IMPL) \
template <typename T> struct Op<OP, T, 1> : DefaultValidation<T> { \
T operator()(T A, T WaveSize) { return IMPL; } \
};

WAVE_ACTIVE_OP(OpType::WaveActiveSum, (A * WaveSize));

#undef WAVE_ACTIVE_OP

//
// dispatchTest
//
Expand Down Expand Up @@ -1296,9 +1316,25 @@ template <OpType OP, typename T> struct ExpectedBuilder {
}
};

template <OpType OP, typename T> struct WaveOpExpectedBuilder {

static auto buildExpected(Op<OP, T, 1> Op, const InputSets<T> &Inputs,
UINT WaveSize) {
DXASSERT_NOMSG(Inputs.size() == 1);
const T WaveSizeT = static_cast<T>(WaveSize);

std::vector<decltype(Op(T(), WaveSizeT))> Expected;
Expected.reserve(Inputs[0].size());

for (size_t I = 0; I < Inputs[0].size(); ++I)
Expected.push_back(Op(Inputs[0][I], WaveSizeT));

return Expected;
}
};

template <typename T, OpType OP>
void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
size_t OverrideInputSize) {
std::vector<size_t> getInputSizesToTest(size_t OverrideInputSize) {
std::vector<size_t> InputVectorSizes;
const std::array<size_t, 8> DefaultInputSizes = {3, 5, 16, 17,
35, 100, 256, 1024};
Expand All @@ -1319,8 +1355,17 @@ void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
InputVectorSizes.push_back(MaxInputSize);
}

constexpr const Operation &Operation = getOperation(OP);
return InputVectorSizes;
}

template <typename T, OpType OP>
void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
size_t OverrideInputSize) {

const std::vector<size_t> InputVectorSizes =
getInputSizesToTest<T, OP>(OverrideInputSize);

constexpr const Operation &Operation = getOperation(OP);
Op<OP, T, Operation.Arity> Op;

for (size_t VectorSize : InputVectorSizes) {
Expand All @@ -1334,6 +1379,32 @@ void dispatchTest(ID3D12Device *D3DDevice, bool VerboseLogging,
}
}

template <typename T, OpType OP>
void dispatchWaveOpTest(ID3D12Device *D3DDevice, bool VerboseLogging,
size_t OverrideInputSize, UINT WaveSize) {

const std::vector<size_t> InputVectorSizes =
getInputSizesToTest<T, OP>(OverrideInputSize);

constexpr const Operation &Operation = getOperation(OP);
Op<OP, T, Operation.Arity> Op;

const std::string AdditionalCompilerOptions =
"-DWAVE_SIZE=" + std::to_string(WaveSize) +
" -DNUMTHREADS_X=" + std::to_string(WaveSize);

for (size_t VectorSize : InputVectorSizes) {
std::vector<std::vector<T>> Inputs =
buildTestInputs<T>(VectorSize, Operation.InputSets, Operation.Arity);

auto Expected =
WaveOpExpectedBuilder<OP, T>::buildExpected(Op, Inputs, WaveSize);

runAndVerify(D3DDevice, VerboseLogging, Operation, Inputs, Expected,
Op.ValidationConfig, AdditionalCompilerOptions);
}
}

} // namespace LongVector

using namespace LongVector;
Expand All @@ -1342,6 +1413,14 @@ using namespace LongVector;
#define HLK_TEST(Op, DataType) \
TEST_METHOD(Op##_##DataType) { runTest<DataType, OpType::Op>(); }

#define HLK_WAVEOP_TEST(Op, DataType) \
TEST_METHOD(Op##_##DataType) { \
BEGIN_TEST_METHOD_PROPERTIES() \
TEST_METHOD_PROPERTY(L"Priority", L"2") \
END_TEST_METHOD_PROPERTIES() \
runWaveOpTest<DataType, OpType::Op>(); \
}

class DxilConf_SM69_Vectorized {
public:
BEGIN_TEST_CLASS(DxilConf_SM69_Vectorized)
Expand Down Expand Up @@ -1405,6 +1484,9 @@ class DxilConf_SM69_Vectorized {
WEX::TestExecution::RuntimeParameters::TryGetValue(L"InputSize",
OverrideInputSize);

WEX::TestExecution::RuntimeParameters::TryGetValue(L"WaveLaneCount",
OverrideWaveLaneCount);

bool IsRITP = false;
WEX::TestExecution::RuntimeParameters::TryGetValue(L"RITP", IsRITP);

Expand All @@ -1428,16 +1510,47 @@ class DxilConf_SM69_Vectorized {
return true;
}

template <typename T, OpType OP> void runTest() {
WEX::TestExecution::SetVerifyOutput verifySettings(
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);

TEST_METHOD_SETUP(methodSetup) {
// It's possible a previous test case caused a device removal. If it did we
// need to try and create a new device.
if (!D3DDevice || D3DDevice->GetDeviceRemovedReason() != S_OK)
if (!D3DDevice || D3DDevice->GetDeviceRemovedReason() != S_OK) {
hlsl_test::LogCommentFmt(
L"Device was lost: Attempting to create a new D3D12 device.");
VERIFY_IS_TRUE(
createDevice(&D3DDevice, ExecTestUtils::D3D_SHADER_MODEL_6_9, false));
}

return true;
}

template <typename T, OpType OP> void runWaveOpTest() {
WEX::TestExecution::SetVerifyOutput VerifySettings(
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);

UINT WaveSize = 0;

if (OverrideWaveLaneCount > 0) {
WaveSize = OverrideWaveLaneCount;
hlsl_test::LogCommentFmt(
L"Using overridden WaveLaneCount of %d for this test.", WaveSize);
} else {
D3D12_FEATURE_DATA_D3D12_OPTIONS1 WaveOpts;
VERIFY_SUCCEEDED(D3DDevice->CheckFeatureSupport(
D3D12_FEATURE_D3D12_OPTIONS1, &WaveOpts, sizeof(WaveOpts)));

WaveSize = WaveOpts.WaveLaneCountMin;
}

DXASSERT_NOMSG(WaveSize > 0);
DXASSERT((WaveSize & (WaveSize - 1)) == 0, "must be a power of 2");

dispatchWaveOpTest<T, OP>(D3DDevice, VerboseLogging, OverrideInputSize,
WaveSize);
}

template <typename T, OpType OP> void runTest() {
WEX::TestExecution::SetVerifyOutput verifySettings(
WEX::TestExecution::VerifyOutputSettings::LogOnlyFailures);
dispatchTest<T, OP>(D3DDevice, VerboseLogging, OverrideInputSize);
}

Expand Down Expand Up @@ -2052,9 +2165,22 @@ class DxilConf_SM69_Vectorized {
HLK_TEST(LoadAndStore_RD_SB_SRV, double);
HLK_TEST(LoadAndStore_RD_SB_UAV, double);

HLK_WAVEOP_TEST(WaveActiveSum, int16_t);
HLK_WAVEOP_TEST(WaveActiveSum, int32_t);
HLK_WAVEOP_TEST(WaveActiveSum, int64_t);

HLK_WAVEOP_TEST(WaveActiveSum, uint16_t);
HLK_WAVEOP_TEST(WaveActiveSum, uint32_t);
HLK_WAVEOP_TEST(WaveActiveSum, uint64_t);

HLK_WAVEOP_TEST(WaveActiveSum, HLSLHalf_t);
HLK_WAVEOP_TEST(WaveActiveSum, float);
HLK_WAVEOP_TEST(WaveActiveSum, double);

private:
bool Initialized = false;
bool VerboseLogging = false;
size_t OverrideInputSize = 0;
UINT OverrideWaveLaneCount = 0;
CComPtr<ID3D12Device> D3DDevice;
};
15 changes: 14 additions & 1 deletion tools/clang/unittests/HLSLExec/ShaderOpArith.xml
Original file line number Diff line number Diff line change
Expand Up @@ -4101,7 +4101,20 @@ void MSMain(uint GID : SV_GroupIndex,
}
#endif

[numthreads(1,1,1)]
#ifdef NUMTHREADS_X
#define NUMTHREADS_ATTR [numthreads(NUMTHREADS_X, 1, 1)]
#else
#define NUMTHREADS_ATTR [numthreads(1, 1, 1)]
#endif

#ifdef WAVE_SIZE
#define WAVE_SIZE_ATTR [WaveSize(WAVE_SIZE)]
#else
#define WAVE_SIZE_ATTR
#endif

WAVE_SIZE_ATTR
NUMTHREADS_ATTR
void main(uint GI : SV_GroupIndex) {

#ifdef FUNC_SHUFFLE_VECTOR
Expand Down