From 42c73c867fb59ff6eb1811246c313d0d519cef58 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Thu, 19 Sep 2024 10:33:20 -0400 Subject: [PATCH 01/13] [Validator] Check size of PSV. Check size of PSV part matches the PSVVersion. Updated DxilPipelineStateValidation::ReadOrWrite to read based on initInfo.PSVVersion. And return fail when size mismatch in RWMode::Read. Fixes #6817 --- docs/ReleaseNotes.md | 1 + .../DxilPipelineStateValidation.h | 24 +- .../DxilContainerValidation.cpp | 17 +- tools/clang/unittests/HLSL/ValidationTest.cpp | 315 ++++++++++++++++++ 4 files changed, 341 insertions(+), 16 deletions(-) diff --git a/docs/ReleaseNotes.md b/docs/ReleaseNotes.md index e9021e46b1..d2ee393261 100644 --- a/docs/ReleaseNotes.md +++ b/docs/ReleaseNotes.md @@ -23,6 +23,7 @@ The included licenses apply to the following files: Place release notes for the upcoming release below this line and remove this line upon naming this release. - The incomplete WaveMatrix implementation has been removed. +- The DXIL validator supports the string table and semantic index table in any order. ### Version 1.8.2407 diff --git a/include/dxc/DxilContainer/DxilPipelineStateValidation.h b/include/dxc/DxilContainer/DxilPipelineStateValidation.h index ed482b39a7..031cf6dc2a 100644 --- a/include/dxc/DxilContainer/DxilPipelineStateValidation.h +++ b/include/dxc/DxilContainer/DxilPipelineStateValidation.h @@ -554,6 +554,7 @@ class DxilPipelineStateValidation { uint32_t GetSize() { return Size; } RWMode GetMode() { return Mode; } + uint32_t GetOffset() { return Offset; } // Return true if size fits in remaing buffer. bool CheckBounds(size_t size); @@ -595,8 +596,9 @@ class DxilPipelineStateValidation { const PSVInitInfo &initInfo = PSVInitInfo(MAX_PSV_VERSION)); public: - bool InitFromPSV0(const void *pBits, uint32_t size) { - return ReadOrWrite(pBits, &size, RWMode::Read); + bool InitFromPSV0(const void *pBits, uint32_t size, + uint32_t PSVVersion = MAX_PSV_VERSION) { + return ReadOrWrite(pBits, &size, RWMode::Read, PSVInitInfo(PSVVersion)); } // Initialize a new buffer @@ -937,12 +939,15 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize, PSV_RETB(rw.MapValue(&m_uPSVRuntimeInfoSize, initInfo.RuntimeInfoSize())); PSV_RETB(rw.MapArray(&m_pPSVRuntimeInfo0, 1, m_uPSVRuntimeInfoSize)); - AssignDerived(&m_pPSVRuntimeInfo1, m_pPSVRuntimeInfo0, - m_uPSVRuntimeInfoSize); // failure ok - AssignDerived(&m_pPSVRuntimeInfo2, m_pPSVRuntimeInfo0, - m_uPSVRuntimeInfoSize); // failure ok - AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0, - m_uPSVRuntimeInfoSize); // failure ok + if (initInfo.PSVVersion > 0) + AssignDerived(&m_pPSVRuntimeInfo1, m_pPSVRuntimeInfo0, + m_uPSVRuntimeInfoSize); // failure ok + if (initInfo.PSVVersion > 1) + AssignDerived(&m_pPSVRuntimeInfo2, m_pPSVRuntimeInfo0, + m_uPSVRuntimeInfoSize); // failure ok + if (initInfo.PSVVersion > 2) + AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0, + m_uPSVRuntimeInfoSize); // failure ok // In RWMode::CalcSize, use temp runtime info to hold needed values from // initInfo @@ -1078,6 +1083,9 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize, if (mode == RWMode::CalcSize) { *pSize = rw.GetSize(); m_pPSVRuntimeInfo1 = nullptr; // clear ptr to tempRuntimeInfo + } else if (mode == RWMode::Read) { + if (rw.GetSize() != rw.GetOffset()) + return false; } return true; } diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index a5e68eb5e1..f653745608 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -89,7 +89,7 @@ class PSVContentVerifier { PSVContentVerifier(DxilPipelineStateValidation &PSV, DxilModule &DM, ValidationContext &ValCtx) : DM(DM), PSV(PSV), ValCtx(ValCtx) {} - void Verify(); + void Verify(unsigned ValMajor, unsigned ValMinor, unsigned PSVVersion); private: void VerifySignatures(unsigned ValMajor, unsigned ValMinor); @@ -368,11 +368,8 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM, } } -void PSVContentVerifier::Verify() { - unsigned ValMajor, ValMinor; - DM.GetValidatorVersion(ValMajor, ValMinor); - unsigned PSVVersion = hlsl::GetPSVVersion(ValMajor, ValMinor); - +void PSVContentVerifier::Verify(unsigned ValMajor, unsigned ValMinor, + unsigned PSVVersion) { PSVInitInfo PSVInfo(PSVVersion); if (PSV.GetRuntimeInfoSize() != PSVInfo.RuntimeInfoSize()) { EmitMismatchError("PSVRuntimeInfoSize", @@ -509,14 +506,18 @@ bool VerifySignatureMatches(llvm::Module *pModule, DXIL::SignatureKind SigKind, static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, uint32_t PSVSize) { + unsigned ValMajor, ValMinor; + ValCtx.DxilMod.GetValidatorVersion(ValMajor, ValMinor); + unsigned PSVVersion = hlsl::GetPSVVersion(ValMajor, ValMinor); DxilPipelineStateValidation PSV; - if (!PSV.InitFromPSV0(pPSVData, PSVSize)) { + if (!PSV.InitFromPSV0(pPSVData, PSVSize, PSVVersion)) { ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Pipeline State Validation"}); return; } + PSVContentVerifier Verifier(PSV, ValCtx.DxilMod, ValCtx); - Verifier.Verify(); + Verifier.Verify(ValMajor, ValMinor, PSVVersion); } static void VerifyFeatureInfoMatches(ValidationContext &ValCtx, diff --git a/tools/clang/unittests/HLSL/ValidationTest.cpp b/tools/clang/unittests/HLSL/ValidationTest.cpp index 339ff1fbb3..0e7f65ec2a 100644 --- a/tools/clang/unittests/HLSL/ValidationTest.cpp +++ b/tools/clang/unittests/HLSL/ValidationTest.cpp @@ -321,6 +321,8 @@ class ValidationTest : public ::testing::Test { TEST_METHOD(PSVContentValidationCS) TEST_METHOD(PSVContentValidationMS) TEST_METHOD(PSVContentValidationAS) + TEST_METHOD(WrongPSVSize) + TEST_METHOD(WrongPSVVersion) dxc::DxcDllSupport m_dllSupport; VersionSupportInfo m_ver; @@ -427,6 +429,18 @@ class ValidationTest : public ::testing::Test { pResultBlob); } + bool CompileFile(LPCWSTR fileName, LPCSTR pShaderModel, LPCWSTR *pArguments, + UINT32 argCount, IDxcBlob **pResultBlob) { + std::wstring fullPath = hlsl_test::GetPathToHlslDataFile(fileName); + CComPtr pLibrary; + CComPtr pSource; + VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary)); + VERIFY_SUCCEEDED( + pLibrary->CreateBlobFromFile(fullPath.c_str(), nullptr, &pSource)); + return CompileSource(pSource, pShaderModel, pArguments, argCount, nullptr, + 0, pResultBlob); + } + bool CompileSource(IDxcBlobEncoding *pSource, LPCSTR pShaderModel, IDxcBlob **pResultBlob) { return CompileSource(pSource, pShaderModel, nullptr, 0, nullptr, 0, @@ -6034,3 +6048,304 @@ TEST_F(ValidationTest, PSVContentValidationAS) { "Validation failed."}, /*maySucceedAnyway*/ false, /*bRegex*/ false); } + +struct SimpleContainer { + hlsl::DxilContainerHeader *Header; + std::vector PartOffsets; + std::vector PartHeaders; + SimpleContainer(void *Ptr) { + Header = (hlsl::DxilContainerHeader *)Ptr; + hlsl::DxilPartIterator pPartIter(nullptr, 0); + pPartIter = hlsl::begin(Header); + for (unsigned i = 0; i < Header->PartCount; ++i) { + VERIFY_IS_TRUE(pPartIter != hlsl::end(Header)); + PartOffsets.push_back( + (uint32_t)((const char *)(*pPartIter) - (const char *)Header)); + PartHeaders.push_back((const DxilPartHeader *)(*pPartIter)); + ++pPartIter; + } + VERIFY_ARE_EQUAL(pPartIter, hlsl::end(Header)); + } +}; + +TEST_F(ValidationTest, WrongPSVSize) { + if (!m_ver.m_InternalValidator) + if (m_ver.SkipDxilVersion(1, 8)) + return; + + CComPtr pProgram; + CompileFile(L"..\\DXC\\dumpPSV_AS.hlsl", "as_6_8", &pProgram); + + CComPtr pValidator; + CComPtr pResult; + unsigned Flags = 0; + VERIFY_SUCCEEDED( + m_dllSupport.CreateInstance(CLSID_DxcValidator, &pValidator)); + VERIFY_SUCCEEDED(pValidator->Validate(pProgram, Flags, &pResult)); + // Make sure the validation was successful. + HRESULT status; + VERIFY_IS_NOT_NULL(pResult); + VERIFY_SUCCEEDED(pResult->GetStatus(&status)); + VERIFY_SUCCEEDED(status); + + hlsl::DxilContainerHeader *pHeader; + hlsl::DxilPartIterator pPartIter(nullptr, 0); + pHeader = (hlsl::DxilContainerHeader *)pProgram->GetBufferPointer(); + // Make sure the PSV part exists. + pPartIter = + std::find_if(hlsl::begin(pHeader), hlsl::end(pHeader), + hlsl::DxilPartIsType(hlsl::DFCC_PipelineStateValidation)); + VERIFY_ARE_NOT_EQUAL(hlsl::end(pHeader), pPartIter); + + // Create a new Blob which is 16 bytes larger than the original one. + std::vector pProgram2Data(pProgram->GetBufferSize() + 16, 0); + // Copy data from the original blob part by part. + // Copy all parts to program2. + SimpleContainer Container(pProgram->GetBufferPointer()); + uint32_t PartOffsetsSize = pHeader->PartCount * sizeof(uint32_t); + uint32_t Offset = sizeof(hlsl::DxilContainerHeader) + PartOffsetsSize; + std::vector PartOffsets; + const uint32_t ExtraSize = 16; + // copy all parts to program2. + for (unsigned i = 0; i < pHeader->PartCount; ++i) { + PartOffsets.emplace_back(Offset); + const DxilPartHeader *pPartHeader = Container.PartHeaders[i]; + + // Copy part header. + memcpy(pProgram2Data.data() + Offset, pPartHeader, sizeof(DxilPartHeader)); + Offset += sizeof(DxilPartHeader); + + // Copy part content. + uint32_t *PartPtr = + const_cast((const uint32_t *)GetDxilPartData(pPartHeader)); + memcpy(pProgram2Data.data() + Offset, PartPtr, pPartHeader->PartSize); + + Offset += pPartHeader->PartSize; + + if (pPartHeader->PartFourCC == hlsl::DFCC_PipelineStateValidation) { + // Update the size of PSV part. + DxilPartHeader *pPSVPartHeader = + (DxilPartHeader *)(pProgram2Data.data() + PartOffsets.back()); + pPSVPartHeader->PartSize += ExtraSize; + Offset += ExtraSize; + } + } + // Copy header. + pHeader->ContainerSizeInBytes += ExtraSize; + memcpy(pProgram2Data.data(), pHeader, sizeof(hlsl::DxilContainerHeader)); + // Copy partOffsets. + memcpy(pProgram2Data.data() + sizeof(hlsl::DxilContainerHeader), + PartOffsets.data(), PartOffsetsSize); + + // Create a new Blob from pProgram2Data. + CComPtr pProgram2; + CComPtr pLibrary; + VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary)); + VERIFY_SUCCEEDED(pLibrary->CreateBlobWithEncodingFromPinned( + pProgram2Data.data(), pProgram2Data.size(), CP_UTF8, &pProgram2)); + + // Run validation on updated container. + CComPtr pUpdatedResult; + VERIFY_SUCCEEDED(pValidator->Validate(pProgram2, Flags, &pUpdatedResult)); + // Make sure the validation was fail. + VERIFY_IS_NOT_NULL(pUpdatedResult); + VERIFY_SUCCEEDED(pUpdatedResult->GetStatus(&status)); + VERIFY_FAILED(status); + + CheckOperationResultMsgs(pUpdatedResult, + {"Container part 'Pipeline State Validation' does " + "not match expected for module."}, + /*maySucceedAnyway*/ false, /*bRegex*/ false); +} + +TEST_F(ValidationTest, WrongPSVVersion) { + if (!m_ver.m_InternalValidator) + if (m_ver.SkipDxilVersion(1, 8)) + return; + + CComPtr pProgram60; + std::vector args; + args.emplace_back(L"-validator-version"); + args.emplace_back(L"1.0"); + CompileFile(L"..\\DXC\\dumpPSV_CS.hlsl", "cs_6_0", args.data(), args.size(), + &pProgram60); + + CComPtr pValidator; + CComPtr pResult; + unsigned Flags = DxcValidatorFlags_InPlaceEdit; + VERIFY_SUCCEEDED( + m_dllSupport.CreateInstance(CLSID_DxcValidator, &pValidator)); + VERIFY_SUCCEEDED(pValidator->Validate(pProgram60, Flags, &pResult)); + // Make sure the validation was successful. + HRESULT status; + VERIFY_IS_NOT_NULL(pResult); + VERIFY_SUCCEEDED(pResult->GetStatus(&status)); + VERIFY_SUCCEEDED(status); + + hlsl::DxilContainerHeader *pHeader60; + hlsl::DxilPartIterator pPartIter(nullptr, 0); + pHeader60 = (hlsl::DxilContainerHeader *)pProgram60->GetBufferPointer(); + // Make sure the PSV part exists. + pPartIter = + std::find_if(hlsl::begin(pHeader60), hlsl::end(pHeader60), + hlsl::DxilPartIsType(hlsl::DFCC_PipelineStateValidation)); + VERIFY_ARE_NOT_EQUAL(hlsl::end(pHeader60), pPartIter); + + CComPtr pProgram68; + + CompileFile(L"..\\DXC\\dumpPSV_CS.hlsl", "cs_6_8", &pProgram68); + CComPtr pResult2; + VERIFY_SUCCEEDED(pValidator->Validate(pProgram68, Flags, &pResult2)); + // Make sure the validation was successful. + VERIFY_IS_NOT_NULL(pResult); + VERIFY_SUCCEEDED(pResult->GetStatus(&status)); + VERIFY_SUCCEEDED(status); + + hlsl::DxilContainerHeader *pHeader68; + pHeader68 = (hlsl::DxilContainerHeader *)pProgram68->GetBufferPointer(); + // Make sure the PSV part exists. + pPartIter = + std::find_if(hlsl::begin(pHeader68), hlsl::end(pHeader68), + hlsl::DxilPartIsType(hlsl::DFCC_PipelineStateValidation)); + VERIFY_ARE_NOT_EQUAL(hlsl::end(pHeader68), pPartIter); + + // Switch the PSV part between 6.0 to 6.8. + SimpleContainer Container60(pProgram60->GetBufferPointer()); + SimpleContainer Container68(pProgram68->GetBufferPointer()); + uint32_t Container60WithPSV68Size = sizeof(hlsl::DxilContainerHeader) + + pHeader60->PartCount * sizeof(uint32_t); + uint32_t Container68WithPSV60Size = sizeof(hlsl::DxilContainerHeader) + + pHeader60->PartCount * sizeof(uint32_t); + unsigned Container68PartSkipped = 0; + for (unsigned i = 0; i < pHeader60->PartCount; ++i) { + const DxilPartHeader *pPartHeader60 = Container60.PartHeaders[i]; + const DxilPartHeader *pPartHeader68 = + Container68.PartHeaders[i + Container68PartSkipped]; + if (pPartHeader68->PartFourCC == hlsl::DFCC_ShaderHash) { + Container68PartSkipped++; + pPartHeader68 = Container68.PartHeaders[i + Container68PartSkipped]; + } + + VERIFY_ARE_EQUAL(pPartHeader60->PartFourCC, pPartHeader68->PartFourCC); + Container60WithPSV68Size += sizeof(DxilPartHeader); + Container68WithPSV60Size += sizeof(DxilPartHeader); + if (pPartHeader60->PartFourCC == hlsl::DFCC_PipelineStateValidation) { + Container60WithPSV68Size += pPartHeader68->PartSize; + Container68WithPSV60Size += pPartHeader60->PartSize; + } else { + Container60WithPSV68Size += pPartHeader60->PartSize; + Container68WithPSV60Size += pPartHeader68->PartSize; + } + } + + // Create mixed container. + std::vector pProgram60WithPSV68Data(Container60WithPSV68Size, 0); + std::vector pProgram68WithPSV60Data(Container68WithPSV60Size, 0); + + uint32_t PartOffsetsSize = pHeader60->PartCount * sizeof(uint32_t); + uint32_t Offset60 = sizeof(hlsl::DxilContainerHeader) + PartOffsetsSize; + std::vector PartOffsets60; + uint32_t Offset68 = sizeof(hlsl::DxilContainerHeader) + PartOffsetsSize; + std::vector PartOffsets68; + Container68PartSkipped = 0; + for (unsigned i = 0; i < pHeader60->PartCount; ++i) { + PartOffsets60.emplace_back(Offset60); + PartOffsets68.emplace_back(Offset68); + + const DxilPartHeader *pPartHeader60 = Container60.PartHeaders[i]; + const DxilPartHeader *pPartHeader68 = + Container68.PartHeaders[i + Container68PartSkipped]; + if (pPartHeader68->PartFourCC == hlsl::DFCC_ShaderHash) { + Container68PartSkipped++; + pPartHeader68 = Container68.PartHeaders[i + Container68PartSkipped]; + } + + if (pPartHeader60->PartFourCC == hlsl::DFCC_PipelineStateValidation) { + // Copy PSV part from 6.8 to 6.0. + memcpy(pProgram60WithPSV68Data.data() + Offset60, pPartHeader68, + sizeof(DxilPartHeader)); + Offset60 += sizeof(DxilPartHeader); + memcpy(pProgram60WithPSV68Data.data() + Offset60, + GetDxilPartData(pPartHeader68), pPartHeader68->PartSize); + Offset60 += pPartHeader68->PartSize; + // Copy PSV part from 6.0 to 6.8. + memcpy(pProgram68WithPSV60Data.data() + Offset68, pPartHeader60, + sizeof(DxilPartHeader)); + Offset68 += sizeof(DxilPartHeader); + memcpy(pProgram68WithPSV60Data.data() + Offset68, + GetDxilPartData(pPartHeader60), pPartHeader60->PartSize); + + Offset68 += pPartHeader60->PartSize; + } else { + // Copy PSV part from 6.0 to 6.0. + memcpy(pProgram60WithPSV68Data.data() + Offset60, pPartHeader60, + sizeof(DxilPartHeader)); + Offset60 += sizeof(DxilPartHeader); + memcpy(pProgram60WithPSV68Data.data() + Offset60, + GetDxilPartData(pPartHeader60), pPartHeader60->PartSize); + Offset60 += pPartHeader60->PartSize; + // Copy PSV part from 6.8 to 6.8. + memcpy(pProgram68WithPSV60Data.data() + Offset68, pPartHeader68, + sizeof(DxilPartHeader)); + Offset68 += sizeof(DxilPartHeader); + memcpy(pProgram68WithPSV60Data.data() + Offset68, + GetDxilPartData(pPartHeader68), pPartHeader68->PartSize); + Offset68 += pPartHeader68->PartSize; + } + } + + // Copy header. + VERIFY_ARE_EQUAL(Container60WithPSV68Size, Offset60); + pHeader60->ContainerSizeInBytes = Container60WithPSV68Size; + memcpy(pProgram60WithPSV68Data.data(), pHeader60, + sizeof(hlsl::DxilContainerHeader)); + VERIFY_ARE_EQUAL(Container68WithPSV60Size, Offset68); + pHeader68->ContainerSizeInBytes = Container68WithPSV60Size; + pHeader68->PartCount -= Container68PartSkipped; + memcpy(pProgram68WithPSV60Data.data(), pHeader68, + sizeof(hlsl::DxilContainerHeader)); + // Copy partOffsets. + memcpy(pProgram60WithPSV68Data.data() + sizeof(hlsl::DxilContainerHeader), + PartOffsets60.data(), PartOffsetsSize); + memcpy(pProgram68WithPSV60Data.data() + sizeof(hlsl::DxilContainerHeader), + PartOffsets68.data(), PartOffsetsSize); + + // Create a new Blob. + CComPtr pProgram60WithPSV68; + CComPtr pLibrary; + VERIFY_SUCCEEDED(m_dllSupport.CreateInstance(CLSID_DxcLibrary, &pLibrary)); + VERIFY_SUCCEEDED(pLibrary->CreateBlobWithEncodingFromPinned( + pProgram60WithPSV68Data.data(), pProgram60WithPSV68Data.size(), CP_UTF8, + &pProgram60WithPSV68)); + + // Run validation on new containers. + CComPtr p60WithPSV68Result; + VERIFY_SUCCEEDED( + pValidator->Validate(pProgram60WithPSV68, Flags, &p60WithPSV68Result)); + // Make sure the validation was fail. + VERIFY_IS_NOT_NULL(p60WithPSV68Result); + VERIFY_SUCCEEDED(p60WithPSV68Result->GetStatus(&status)); + VERIFY_FAILED(status); + CheckOperationResultMsgs(p60WithPSV68Result, + {"Container part 'Pipeline State Validation' does " + "not match expected for module."}, + /*maySucceedAnyway*/ false, /*bRegex*/ false); + + // Create a new Blob. + CComPtr pProgram68WithPSV60; + VERIFY_SUCCEEDED(pLibrary->CreateBlobWithEncodingFromPinned( + pProgram68WithPSV60Data.data(), pProgram68WithPSV60Data.size(), CP_UTF8, + &pProgram68WithPSV60)); + CComPtr p68WithPSV60Result; + VERIFY_SUCCEEDED( + pValidator->Validate(pProgram68WithPSV60, Flags, &p68WithPSV60Result)); + // Make sure the validation was fail. + VERIFY_IS_NOT_NULL(p68WithPSV60Result); + VERIFY_SUCCEEDED(p68WithPSV60Result->GetStatus(&status)); + VERIFY_FAILED(status); + CheckOperationResultMsgs( + p68WithPSV60Result, + {"DXIL container mismatch for 'PSVRuntimeInfoSize' between 'PSV0' " + "part:('24') and DXIL module:('52')"}, + /*maySucceedAnyway*/ false, /*bRegex*/ false); +} From 6ec774c1a13a45730ccbf2cefb4c645358f513ae Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Thu, 19 Sep 2024 13:27:48 -0400 Subject: [PATCH 02/13] Return the size read for PSV for validation. --- .../dxc/DxilContainer/DxilPipelineStateValidation.h | 12 +++++++----- lib/DxilValidation/DxilContainerValidation.cpp | 8 +++++++- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/include/dxc/DxilContainer/DxilPipelineStateValidation.h b/include/dxc/DxilContainer/DxilPipelineStateValidation.h index 031cf6dc2a..fb341449d4 100644 --- a/include/dxc/DxilContainer/DxilPipelineStateValidation.h +++ b/include/dxc/DxilContainer/DxilPipelineStateValidation.h @@ -596,9 +596,12 @@ class DxilPipelineStateValidation { const PSVInitInfo &initInfo = PSVInitInfo(MAX_PSV_VERSION)); public: - bool InitFromPSV0(const void *pBits, uint32_t size, - uint32_t PSVVersion = MAX_PSV_VERSION) { - return ReadOrWrite(pBits, &size, RWMode::Read, PSVInitInfo(PSVVersion)); + bool InitFromPSV0(const void *pBits, uint32_t size) { + return ReadOrWrite(pBits, &size, RWMode::Read); + } + + bool InitFromPSV0(const void *pBits, uint32_t *pSize, uint32_t PSVVersion) { + return ReadOrWrite(pBits, pSize, RWMode::Read, PSVInitInfo(PSVVersion)); } // Initialize a new buffer @@ -1084,8 +1087,7 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize, *pSize = rw.GetSize(); m_pPSVRuntimeInfo1 = nullptr; // clear ptr to tempRuntimeInfo } else if (mode == RWMode::Read) { - if (rw.GetSize() != rw.GetOffset()) - return false; + *pSize = rw.GetOffset(); } return true; } diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index f653745608..f0489445d0 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -510,7 +510,13 @@ static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, ValCtx.DxilMod.GetValidatorVersion(ValMajor, ValMinor); unsigned PSVVersion = hlsl::GetPSVVersion(ValMajor, ValMinor); DxilPipelineStateValidation PSV; - if (!PSV.InitFromPSV0(pPSVData, PSVSize, PSVVersion)) { + uint32_t PSVSizeRead = PSVSize; + if (!PSV.InitFromPSV0(pPSVData, &PSVSizeRead, PSVVersion)) { + ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, + {"Pipeline State Validation"}); + return; + } + if (PSVSizeRead != PSVSize) { ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Pipeline State Validation"}); return; From b7c81b44a605cd1d231fab4254a8b312a81e9fc2 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Thu, 19 Sep 2024 18:27:43 -0400 Subject: [PATCH 03/13] Check PSVRuntimeInfoSize early. Then use PSVInitInfo to calculate correct size of PSV. --- .../DxilPipelineStateValidation.h | 20 +++++----- lib/DxilContainer/DxilContainerAssembler.cpp | 34 +---------------- .../DxilPipelineStateValidation.cpp | 37 +++++++++++++++++++ .../DxilContainerValidation.cpp | 35 ++++++++++++++++-- tools/clang/unittests/HLSL/ValidationTest.cpp | 9 +++-- 5 files changed, 85 insertions(+), 50 deletions(-) diff --git a/include/dxc/DxilContainer/DxilPipelineStateValidation.h b/include/dxc/DxilContainer/DxilPipelineStateValidation.h index fb341449d4..5aee7bbfda 100644 --- a/include/dxc/DxilContainer/DxilPipelineStateValidation.h +++ b/include/dxc/DxilContainer/DxilPipelineStateValidation.h @@ -223,7 +223,8 @@ struct PSVStringTable { PSVStringTable() : Table(nullptr), Size(0) {} PSVStringTable(const char *table, uint32_t size) : Table(table), Size(size) {} const char *Get(uint32_t offset) const { - assert(offset < Size && Table && Table[Size - 1] == '\0'); + if (!(offset < Size && Table && Table[Size - 1] == '\0')) + return nullptr; return Table + offset; } }; @@ -341,7 +342,8 @@ struct PSVSemanticIndexTable { PSVSemanticIndexTable(const uint32_t *table, uint32_t entries) : Table(table), Entries(entries) {} const uint32_t *Get(uint32_t offset) const { - assert(offset < Entries && Table); + if (!(offset < Entries && Table)) + return nullptr; return Table + offset; } }; @@ -554,7 +556,6 @@ class DxilPipelineStateValidation { uint32_t GetSize() { return Size; } RWMode GetMode() { return Mode; } - uint32_t GetOffset() { return Offset; } // Return true if size fits in remaing buffer. bool CheckBounds(size_t size); @@ -600,10 +601,6 @@ class DxilPipelineStateValidation { return ReadOrWrite(pBits, &size, RWMode::Read); } - bool InitFromPSV0(const void *pBits, uint32_t *pSize, uint32_t PSVVersion) { - return ReadOrWrite(pBits, pSize, RWMode::Read, PSVInitInfo(PSVVersion)); - } - // Initialize a new buffer // call with null pBuffer to get required size @@ -640,7 +637,8 @@ class DxilPipelineStateValidation { _T *GetRecord(void *pRecords, uint32_t recordSize, uint32_t numRecords, uint32_t index) const { if (pRecords && index < numRecords && sizeof(_T) <= recordSize) { - assert((size_t)index * (size_t)recordSize <= UINT_MAX); + if (!((size_t)index * (size_t)recordSize <= UINT_MAX)) + return nullptr; return reinterpret_cast<_T *>(reinterpret_cast(pRecords) + (index * recordSize)); } @@ -1086,8 +1084,6 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize, if (mode == RWMode::CalcSize) { *pSize = rw.GetSize(); m_pPSVRuntimeInfo1 = nullptr; // clear ptr to tempRuntimeInfo - } else if (mode == RWMode::Read) { - *pSize = rw.GetOffset(); } return true; } @@ -1133,6 +1129,10 @@ void InitPSVSignatureElement(PSVSignatureElement0 &E, const DxilSignatureElement &SE, bool i1ToUnknownCompat); +// Setup PSVInitInfo with DxilModule. +// Note that the StringTable and PSVSemanticIndexTable are not done. +void SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM); + // Setup shader properties for PSVRuntimeInfo* with DxilModule. void SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM); void SetShaderProps(PSVRuntimeInfo1 *pInfo1, const DxilModule &DM); diff --git a/lib/DxilContainer/DxilContainerAssembler.cpp b/lib/DxilContainer/DxilContainerAssembler.cpp index d7d006bd4f..0b7f5dd467 100644 --- a/lib/DxilContainer/DxilContainerAssembler.cpp +++ b/lib/DxilContainer/DxilContainerAssembler.cpp @@ -738,30 +738,14 @@ class DxilPSVWriter : public DxilPartWriter { DxilPSVWriter(const DxilModule &mod, uint32_t PSVVersion = UINT_MAX) : m_Module(mod), m_PSVInitInfo(PSVVersion) { m_Module.GetValidatorVersion(m_ValMajor, m_ValMinor); - // Constraint PSVVersion based on validator version - uint32_t PSVVersionConstraint = hlsl::GetPSVVersion(m_ValMajor, m_ValMinor); - if (PSVVersion > PSVVersionConstraint) - m_PSVInitInfo.PSVVersion = PSVVersionConstraint; - - const ShaderModel *SM = m_Module.GetShaderModel(); - UINT uCBuffers = m_Module.GetCBuffers().size(); - UINT uSamplers = m_Module.GetSamplers().size(); - UINT uSRVs = m_Module.GetSRVs().size(); - UINT uUAVs = m_Module.GetUAVs().size(); - m_PSVInitInfo.ResourceCount = uCBuffers + uSamplers + uSRVs + uUAVs; + hlsl::SetupPSVInitInfo(m_PSVInitInfo, m_Module); + // TODO: for >= 6.2 version, create more efficient structure if (m_PSVInitInfo.PSVVersion > 0) { - m_PSVInitInfo.ShaderStage = (PSVShaderKind)SM->GetKind(); // Copy Dxil Signatures m_StringBuffer.push_back('\0'); // For empty semantic name (system value) - m_PSVInitInfo.SigInputElements = - m_Module.GetInputSignature().GetElements().size(); m_SigInputElements.resize(m_PSVInitInfo.SigInputElements); - m_PSVInitInfo.SigOutputElements = - m_Module.GetOutputSignature().GetElements().size(); m_SigOutputElements.resize(m_PSVInitInfo.SigOutputElements); - m_PSVInitInfo.SigPatchConstOrPrimElements = - m_Module.GetPatchConstOrPrimSignature().GetElements().size(); m_SigPatchConstOrPrimElements.resize( m_PSVInitInfo.SigPatchConstOrPrimElements); uint32_t i = 0; @@ -791,20 +775,6 @@ class DxilPSVWriter : public DxilPartWriter { m_PSVInitInfo.StringTable.Size = m_StringBuffer.size(); m_PSVInitInfo.SemanticIndexTable.Table = m_SemanticIndexBuffer.data(); m_PSVInitInfo.SemanticIndexTable.Entries = m_SemanticIndexBuffer.size(); - // Set up ViewID and signature dependency info - m_PSVInitInfo.UsesViewID = - m_Module.m_ShaderFlags.GetViewID() ? true : false; - m_PSVInitInfo.SigInputVectors = - m_Module.GetInputSignature().NumVectorsUsed(0); - for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) { - m_PSVInitInfo.SigOutputVectors[streamIndex] = - m_Module.GetOutputSignature().NumVectorsUsed(streamIndex); - } - m_PSVInitInfo.SigPatchConstOrPrimVectors = 0; - if (SM->IsHS() || SM->IsDS() || SM->IsMS()) { - m_PSVInitInfo.SigPatchConstOrPrimVectors = - m_Module.GetPatchConstOrPrimSignature().NumVectorsUsed(0); - } } if (!m_PSV.InitNew(m_PSVInitInfo, nullptr, &m_PSVBufferSize)) { DXASSERT(false, "PSV InitNew failed computing size!"); diff --git a/lib/DxilContainer/DxilPipelineStateValidation.cpp b/lib/DxilContainer/DxilPipelineStateValidation.cpp index 9f04c09943..66186549f2 100644 --- a/lib/DxilContainer/DxilPipelineStateValidation.cpp +++ b/lib/DxilContainer/DxilPipelineStateValidation.cpp @@ -110,6 +110,43 @@ void hlsl::InitPSVSignatureElement(PSVSignatureElement0 &E, E.DynamicMaskAndStream |= (SE.GetDynIdxCompMask()) & 0xF; } +void hlsl::SetupPSVInitInfo(PSVInitInfo &InitInfo, const DxilModule &DM) { + // Constraint PSVVersion based on validator version + unsigned ValMajor, ValMinor; + DM.GetValidatorVersion(ValMajor, ValMinor); + unsigned PSVVersionConstraint = hlsl::GetPSVVersion(ValMajor, ValMinor); + if (InitInfo.PSVVersion > PSVVersionConstraint) + InitInfo.PSVVersion = PSVVersionConstraint; + + const ShaderModel *SM = DM.GetShaderModel(); + uint32_t uCBuffers = DM.GetCBuffers().size(); + uint32_t uSamplers = DM.GetSamplers().size(); + uint32_t uSRVs = DM.GetSRVs().size(); + uint32_t uUAVs = DM.GetUAVs().size(); + InitInfo.ResourceCount = uCBuffers + uSamplers + uSRVs + uUAVs; + + if (InitInfo.PSVVersion > 0) { + InitInfo.ShaderStage = (PSVShaderKind)SM->GetKind(); + InitInfo.SigInputElements = DM.GetInputSignature().GetElements().size(); + InitInfo.SigPatchConstOrPrimElements = + DM.GetPatchConstOrPrimSignature().GetElements().size(); + InitInfo.SigOutputElements = DM.GetOutputSignature().GetElements().size(); + + // Set up ViewID and signature dependency info + InitInfo.UsesViewID = DM.m_ShaderFlags.GetViewID() ? true : false; + InitInfo.SigInputVectors = DM.GetInputSignature().NumVectorsUsed(0); + for (unsigned streamIndex = 0; streamIndex < 4; streamIndex++) { + InitInfo.SigOutputVectors[streamIndex] = + DM.GetOutputSignature().NumVectorsUsed(streamIndex); + } + InitInfo.SigPatchConstOrPrimVectors = 0; + if (SM->IsHS() || SM->IsDS() || SM->IsMS()) { + InitInfo.SigPatchConstOrPrimVectors = + DM.GetPatchConstOrPrimSignature().NumVectorsUsed(0); + } + } +} + void hlsl::SetShaderProps(PSVRuntimeInfo0 *pInfo, const DxilModule &DM) { const ShaderModel *SM = DM.GetShaderModel(); pInfo->MinimumExpectedWaveLaneCount = 0; diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index f0489445d0..43a7a7d3fd 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -84,7 +84,6 @@ class PSVContentVerifier { DxilPipelineStateValidation &PSV; ValidationContext &ValCtx; bool PSVContentValid = true; - public: PSVContentVerifier(DxilPipelineStateValidation &PSV, DxilModule &DM, ValidationContext &ValCtx) @@ -506,17 +505,45 @@ bool VerifySignatureMatches(llvm::Module *pModule, DXIL::SignatureKind SigKind, static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, uint32_t PSVSize) { + if (PSVSize < sizeof(uint32_t)) { + ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, + {"Pipeline State Validation"}); + return; + } unsigned ValMajor, ValMinor; ValCtx.DxilMod.GetValidatorVersion(ValMajor, ValMinor); unsigned PSVVersion = hlsl::GetPSVVersion(ValMajor, ValMinor); + + PSVInitInfo PSVInfo(PSVVersion); + hlsl::SetupPSVInitInfo(PSVInfo, ValCtx.DxilMod); + uint32_t PSVRuntimeInfoSize = *(uint32_t *)pPSVData; + + if (PSVRuntimeInfoSize != PSVInfo.RuntimeInfoSize()) { + ValCtx.EmitFormatError(ValidationRule::ContainerContentMatches, + {"PSVRuntimeInfoSize", "PSV0", + std::to_string(PSVRuntimeInfoSize), + std::to_string(PSVInfo.RuntimeInfoSize())}); + return; + } + DxilPipelineStateValidation PSV; - uint32_t PSVSizeRead = PSVSize; - if (!PSV.InitFromPSV0(pPSVData, &PSVSizeRead, PSVVersion)) { + if (!PSV.InitFromPSV0(pPSVData, PSVSize)) { ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Pipeline State Validation"}); return; } - if (PSVSizeRead != PSVSize) { + + PSVInfo.StringTable = PSV.GetStringTable(); + PSVInfo.SemanticIndexTable = PSV.GetSemanticIndexTable(); + uint32_t ExpectedSize = 0; + DxilPipelineStateValidation SizePSV; + if (!SizePSV.InitNew(PSVInfo, nullptr, &ExpectedSize)) { + ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, + {"Pipeline State Validation"}); + return; + } + + if (ExpectedSize != PSVSize) { ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Pipeline State Validation"}); return; diff --git a/tools/clang/unittests/HLSL/ValidationTest.cpp b/tools/clang/unittests/HLSL/ValidationTest.cpp index 0e7f65ec2a..0a8f9cb2ac 100644 --- a/tools/clang/unittests/HLSL/ValidationTest.cpp +++ b/tools/clang/unittests/HLSL/ValidationTest.cpp @@ -6326,10 +6326,11 @@ TEST_F(ValidationTest, WrongPSVVersion) { VERIFY_IS_NOT_NULL(p60WithPSV68Result); VERIFY_SUCCEEDED(p60WithPSV68Result->GetStatus(&status)); VERIFY_FAILED(status); - CheckOperationResultMsgs(p60WithPSV68Result, - {"Container part 'Pipeline State Validation' does " - "not match expected for module."}, - /*maySucceedAnyway*/ false, /*bRegex*/ false); + CheckOperationResultMsgs( + p60WithPSV68Result, + {"DXIL container mismatch for 'PSVRuntimeInfoSize' between 'PSV0' " + "part:('52') and DXIL module:('24')"}, + /*maySucceedAnyway*/ false, /*bRegex*/ false); // Create a new Blob. CComPtr pProgram68WithPSV60; From 69a78efd72f8a63268c55aa4da56458c1cda7625 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Thu, 19 Sep 2024 21:13:41 -0400 Subject: [PATCH 04/13] Make sure data in PSV part follows format before load. --- .../DxilPipelineStateValidation.h | 15 +- .../DxilContainerValidation.cpp | 209 ++++++++++++++++-- 2 files changed, 201 insertions(+), 23 deletions(-) diff --git a/include/dxc/DxilContainer/DxilPipelineStateValidation.h b/include/dxc/DxilContainer/DxilPipelineStateValidation.h index 5aee7bbfda..9fdaf715e5 100644 --- a/include/dxc/DxilContainer/DxilPipelineStateValidation.h +++ b/include/dxc/DxilContainer/DxilPipelineStateValidation.h @@ -940,15 +940,12 @@ DxilPipelineStateValidation::ReadOrWrite(const void *pBits, uint32_t *pSize, PSV_RETB(rw.MapValue(&m_uPSVRuntimeInfoSize, initInfo.RuntimeInfoSize())); PSV_RETB(rw.MapArray(&m_pPSVRuntimeInfo0, 1, m_uPSVRuntimeInfoSize)); - if (initInfo.PSVVersion > 0) - AssignDerived(&m_pPSVRuntimeInfo1, m_pPSVRuntimeInfo0, - m_uPSVRuntimeInfoSize); // failure ok - if (initInfo.PSVVersion > 1) - AssignDerived(&m_pPSVRuntimeInfo2, m_pPSVRuntimeInfo0, - m_uPSVRuntimeInfoSize); // failure ok - if (initInfo.PSVVersion > 2) - AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0, - m_uPSVRuntimeInfoSize); // failure ok + AssignDerived(&m_pPSVRuntimeInfo1, m_pPSVRuntimeInfo0, + m_uPSVRuntimeInfoSize); // failure ok + AssignDerived(&m_pPSVRuntimeInfo2, m_pPSVRuntimeInfo0, + m_uPSVRuntimeInfoSize); // failure ok + AssignDerived(&m_pPSVRuntimeInfo3, m_pPSVRuntimeInfo0, + m_uPSVRuntimeInfoSize); // failure ok // In RWMode::CalcSize, use temp runtime info to hold needed values from // initInfo diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 43a7a7d3fd..074fbc1ebf 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -84,6 +84,7 @@ class PSVContentVerifier { DxilPipelineStateValidation &PSV; ValidationContext &ValCtx; bool PSVContentValid = true; + public: PSVContentVerifier(DxilPipelineStateValidation &PSV, DxilModule &DM, ValidationContext &ValCtx) @@ -370,12 +371,6 @@ void PSVContentVerifier::VerifyEntryProperties(const ShaderModel *SM, void PSVContentVerifier::Verify(unsigned ValMajor, unsigned ValMinor, unsigned PSVVersion) { PSVInitInfo PSVInfo(PSVVersion); - if (PSV.GetRuntimeInfoSize() != PSVInfo.RuntimeInfoSize()) { - EmitMismatchError("PSVRuntimeInfoSize", - std::to_string(PSV.GetRuntimeInfoSize()), - std::to_string(PSVInfo.RuntimeInfoSize())); - return; - } if (PSV.GetBindCount() > 0 && PSV.GetResourceBindInfoSize() != PSVInfo.ResourceBindInfoSize()) { @@ -503,28 +498,214 @@ bool VerifySignatureMatches(llvm::Module *pModule, DXIL::SignatureKind SigKind, return !ValCtx.Failed; } +struct SimplePSV { + uint32_t PSVRuntimeInfoSize = 0; + uint32_t PSVNumResources = 0; + uint32_t PSVResourceBindInfoSize = 0; + uint32_t StringTableSize = 0; + char *StringTable = nullptr; + uint32_t SemanticIndexTableEntries = 0; + uint32_t *SemanticIndexTable = nullptr; + uint32_t PSVSignatureElementSize = 0; + PSVRuntimeInfo1 *RuntimeInfo1 = nullptr; + bool IsValid = true; + SimplePSV(const void *pPSVData, uint32_t PSVSize) { + uint32_t Offset = 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + PSVRuntimeInfoSize = *(uint32_t *)pPSVData; + if (PSVRuntimeInfoSize >= sizeof(PSVRuntimeInfo1)) + RuntimeInfo1 = (PSVRuntimeInfo1 *)((char *)pPSVData + Offset); + Offset += PSVRuntimeInfoSize; + if (PSVSize < Offset) { + IsValid = false; + return; + } + + PSVNumResources = *(uint32_t *)((char *)pPSVData + Offset); + Offset += 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + if (PSVNumResources > 0) { + PSVResourceBindInfoSize = *(uint32_t *)((char *)pPSVData + Offset); + Offset += 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + Offset += PSVNumResources * PSVResourceBindInfoSize; + if (PSVSize < Offset) { + IsValid = false; + return; + } + } + if (RuntimeInfo1) { + StringTableSize = *(uint32_t *)((char *)pPSVData + Offset); + Offset += 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + // Make sure StringTableSize is aligned to 4 bytes. + if ((StringTableSize & 3) != 0) { + IsValid = false; + return; + } + if (StringTableSize) { + StringTable = (char *)pPSVData + Offset; + Offset += StringTableSize; + } + SemanticIndexTableEntries = *(uint32_t *)((char *)pPSVData + Offset); + Offset += 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + if (SemanticIndexTableEntries) { + SemanticIndexTable = (uint32_t *)((char *)pPSVData + Offset); + Offset += SemanticIndexTableEntries * 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + } + if (RuntimeInfo1->SigInputElements || RuntimeInfo1->SigOutputElements || + RuntimeInfo1->SigPatchConstOrPrimElements) { + PSVSignatureElementSize = *(uint32_t *)((char *)pPSVData + Offset); + Offset += 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + uint32_t PSVNumSignatures = RuntimeInfo1->SigInputElements + + RuntimeInfo1->SigOutputElements + + RuntimeInfo1->SigPatchConstOrPrimElements; + Offset += PSVNumSignatures * PSVSignatureElementSize; + if (PSVSize < Offset) { + IsValid = false; + return; + } + } + if (RuntimeInfo1->UsesViewID) { + for (unsigned i = 0; i < DXIL::kNumOutputStreams; i++) { + uint32_t SigOutputVectors = RuntimeInfo1->SigOutputVectors[i]; + if (SigOutputVectors == 0) + continue; + Offset += sizeof(uint32_t) * + llvm::RoundUpToAlignment(SigOutputVectors, 8) / 8; + if (PSVSize < Offset) { + IsValid = false; + return; + } + } + if ((RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Hull || + RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Mesh) && + RuntimeInfo1->SigPatchConstOrPrimVectors) { + Offset += sizeof(uint32_t) * + llvm::RoundUpToAlignment( + RuntimeInfo1->SigPatchConstOrPrimVectors, 8) / + 8; + if (PSVSize < Offset) { + IsValid = false; + return; + } + } + } + + for (unsigned i = 0; i < DXIL::kNumOutputStreams; i++) { + uint32_t SigOutputVectors = RuntimeInfo1->SigOutputVectors[i]; + if (SigOutputVectors == 0) + continue; + Offset += sizeof(uint32_t) * + llvm::RoundUpToAlignment(SigOutputVectors, 8) / 8 * + RuntimeInfo1->SigInputVectors * 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + } + + if ((RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Hull || + RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Mesh) && + RuntimeInfo1->SigPatchConstOrPrimVectors && + RuntimeInfo1->SigInputVectors) { + Offset += sizeof(uint32_t) * + llvm::RoundUpToAlignment( + RuntimeInfo1->SigPatchConstOrPrimVectors, 8) / + 8 * RuntimeInfo1->SigInputVectors * 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + } + + if (RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Domain && + RuntimeInfo1->SigOutputVectors[0] && + RuntimeInfo1->SigPatchConstOrPrimVectors) { + Offset += + sizeof(uint32_t) * + llvm::RoundUpToAlignment(RuntimeInfo1->SigOutputVectors[0], 8) / 8 * + RuntimeInfo1->SigPatchConstOrPrimVectors * 4; + if (PSVSize < Offset) { + IsValid = false; + return; + } + } + } + IsValid = PSVSize == Offset; + } + bool ValidatePSVInit(PSVInitInfo PSVInfo, ValidationContext &ValCtx) { + if (PSVRuntimeInfoSize != PSVInfo.RuntimeInfoSize()) { + ValCtx.EmitFormatError(ValidationRule::ContainerContentMatches, + {"PSVRuntimeInfoSize", "PSV0", + std::to_string(PSVRuntimeInfoSize), + std::to_string(PSVInfo.RuntimeInfoSize())}); + return false; + } + if (PSVNumResources && + PSVResourceBindInfoSize != PSVInfo.ResourceBindInfoSize()) { + ValCtx.EmitFormatError(ValidationRule::ContainerContentMatches, + {"PSVResourceBindInfoSize", "PSV0", + std::to_string(PSVResourceBindInfoSize), + std::to_string(PSVInfo.ResourceBindInfoSize())}); + return false; + } + if (RuntimeInfo1 && + (RuntimeInfo1->SigInputElements || RuntimeInfo1->SigOutputElements || + RuntimeInfo1->SigPatchConstOrPrimElements) && + PSVSignatureElementSize != PSVInfo.SignatureElementSize()) { + ValCtx.EmitFormatError(ValidationRule::ContainerContentMatches, + {"PSVSignatureElementSize", "PSV0", + std::to_string(PSVSignatureElementSize), + std::to_string(PSVInfo.SignatureElementSize())}); + return false; + } + return true; + } +}; + static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, uint32_t PSVSize) { - if (PSVSize < sizeof(uint32_t)) { + SimplePSV SimplePSV(pPSVData, PSVSize); + if (!SimplePSV.IsValid) { ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Pipeline State Validation"}); return; } + unsigned ValMajor, ValMinor; ValCtx.DxilMod.GetValidatorVersion(ValMajor, ValMinor); unsigned PSVVersion = hlsl::GetPSVVersion(ValMajor, ValMinor); PSVInitInfo PSVInfo(PSVVersion); hlsl::SetupPSVInitInfo(PSVInfo, ValCtx.DxilMod); - uint32_t PSVRuntimeInfoSize = *(uint32_t *)pPSVData; - if (PSVRuntimeInfoSize != PSVInfo.RuntimeInfoSize()) { - ValCtx.EmitFormatError(ValidationRule::ContainerContentMatches, - {"PSVRuntimeInfoSize", "PSV0", - std::to_string(PSVRuntimeInfoSize), - std::to_string(PSVInfo.RuntimeInfoSize())}); + if (!SimplePSV.ValidatePSVInit(PSVInfo, ValCtx)) return; - } DxilPipelineStateValidation PSV; if (!PSV.InitFromPSV0(pPSVData, PSVSize)) { From fe34412e940de6ea02d0860baa7f760882ca8583 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Thu, 19 Sep 2024 21:23:34 -0400 Subject: [PATCH 05/13] Load PSV after check size. --- lib/DxilValidation/DxilContainerValidation.cpp | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 074fbc1ebf..8c3f2c1fe1 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -707,24 +707,26 @@ static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, if (!SimplePSV.ValidatePSVInit(PSVInfo, ValCtx)) return; - DxilPipelineStateValidation PSV; - if (!PSV.InitFromPSV0(pPSVData, PSVSize)) { + PSVInfo.StringTable = + PSVStringTable(SimplePSV.StringTable, SimplePSV.StringTableSize); + PSVInfo.SemanticIndexTable = PSVSemanticIndexTable( + SimplePSV.SemanticIndexTable, SimplePSV.SemanticIndexTableEntries); + uint32_t ExpectedSize = 0; + DxilPipelineStateValidation SizePSV; + if (!SizePSV.InitNew(PSVInfo, nullptr, &ExpectedSize)) { ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Pipeline State Validation"}); return; } - PSVInfo.StringTable = PSV.GetStringTable(); - PSVInfo.SemanticIndexTable = PSV.GetSemanticIndexTable(); - uint32_t ExpectedSize = 0; - DxilPipelineStateValidation SizePSV; - if (!SizePSV.InitNew(PSVInfo, nullptr, &ExpectedSize)) { + if (ExpectedSize != PSVSize) { ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Pipeline State Validation"}); return; } - if (ExpectedSize != PSVSize) { + DxilPipelineStateValidation PSV; + if (!PSV.InitFromPSV0(pPSVData, PSVSize)) { ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Pipeline State Validation"}); return; From cafe1d927200a3aaed8f69b91d4920342b99b359 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Fri, 20 Sep 2024 00:45:27 -0400 Subject: [PATCH 06/13] Add check for StringTable and SemanticIndexTable. --- .../DxilContainerValidation.cpp | 116 +++++++++++++++++- tools/clang/unittests/HLSL/ValidationTest.cpp | 55 ++++++++- utils/hct/hctdb.py | 5 + 3 files changed, 169 insertions(+), 7 deletions(-) diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 8c3f2c1fe1..130c6a4d97 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -33,6 +33,7 @@ #include "DxilValidationUtils.h" #include +#include #include using namespace llvm; @@ -79,16 +80,94 @@ static void emitDxilDiag(LLVMContext &Ctx, const char *str) { hlsl::dxilutil::EmitErrorOnContext(Ctx, str); } +class StringTableVerifier { + MapVector OffsetToUseCountMap; + const PSVStringTable &Table; + +public: + StringTableVerifier(const PSVStringTable &Table) : Table(Table) { + unsigned Start = 0; + for (unsigned i = 0; i < Table.Size; ++i) { + char ch = Table.Table[i]; + if (ch == '\0') { + OffsetToUseCountMap[Start] = 0; + Start = i + 1; + } + } + if (Table.Size >= 4) { + // Remove the '\0's at the end of the table added for padding. + for (unsigned i = Table.Size - 1; i > Table.Size - 4; --i) { + if (Table.Table[i] != '\0') + break; + OffsetToUseCountMap.erase(i); + } + } + } + bool MarkUse(unsigned Offset) { + auto it = OffsetToUseCountMap.find(Offset); + if (it != OffsetToUseCountMap.end()) + it->second++; + return Offset < Table.Size; + } + void Verify(ValidationContext &ValCtx) { + for (auto [Offset, UseCount] : OffsetToUseCountMap) { + if (UseCount != 0) + continue; + // DXC will always add a null-terminated string at the beginning of the + // StringTable. It is OK if it is not used. + if (Offset == 0 && Table.Table[0] == '\0') + continue; + + ValCtx.EmitFormatError(ValidationRule::ContainerUnusedItemInTable, + {"StringTable", Table.Get(Offset)}); + } + } +}; + +class SemanticIndexTableVerifier { + std::vector UseMask; + const PSVSemanticIndexTable &Table; + +public: + SemanticIndexTableVerifier(const PSVSemanticIndexTable &Table) + : Table(Table), UseMask(Table.Entries, false) {} + bool MarkUse(unsigned Offset, unsigned Size) { + if (Table.Table == nullptr) + return false; + if (Offset > Table.Entries) + return false; + if ((Offset + Size) > Table.Entries) + return false; + for (unsigned i = Offset; i < (Offset + Size); ++i) { + UseMask[i] = true; + } + return true; + } + void Verify(ValidationContext &ValCtx) { + for (unsigned i = 0; i < Table.Entries; i++) { + if (UseMask[i]) + continue; + + ValCtx.EmitFormatError(ValidationRule::ContainerUnusedItemInTable, + {"SemanticIndexTable", std::to_string(i)}); + } + } +}; + class PSVContentVerifier { DxilModule &DM; DxilPipelineStateValidation &PSV; ValidationContext &ValCtx; bool PSVContentValid = true; + StringTableVerifier StrTableVerifier; + SemanticIndexTableVerifier IndexTableVerifier; public: PSVContentVerifier(DxilPipelineStateValidation &PSV, DxilModule &DM, ValidationContext &ValCtx) - : DM(DM), PSV(PSV), ValCtx(ValCtx) {} + : DM(DM), PSV(PSV), ValCtx(ValCtx), + StrTableVerifier(PSV.GetStringTable()), + IndexTableVerifier(PSV.GetSemanticIndexTable()) {} void Verify(unsigned ValMajor, unsigned ValMinor, unsigned PSVVersion); private: @@ -113,7 +192,8 @@ class PSVContentVerifier { PSVContentValid = false; } void EmitInvalidError(StringRef Name) { - ValCtx.EmitFormatError(ValidationRule::ContainerContentInvalid, {Name}); + ValCtx.EmitFormatError(ValidationRule::ContainerContentInvalid, + {"PSV0 part", Name}); PSVContentValid = false; } template static std::string GetDump(const Ty &T) { @@ -226,6 +306,17 @@ void PSVContentVerifier::VerifySignatureElement( const DxilSignatureElement &SE, PSVSignatureElement0 *PSVSE0, const PSVStringTable &StrTab, const PSVSemanticIndexTable &IndexTab, std::string Name, bool i1ToUnknownCompat) { + bool InvalidTableAccess = false; + if (!StrTableVerifier.MarkUse(PSVSE0->SemanticName)) { + EmitInvalidError("SemanticName"); + InvalidTableAccess = true; + } + if (!IndexTableVerifier.MarkUse(PSVSE0->SemanticIndexes, PSVSE0->Rows)) { + EmitInvalidError("SemanticIndex"); + InvalidTableAccess = true; + } + if (InvalidTableAccess) + return; // Find the signature element in the set. PSVSignatureElement0 ModulePSVSE0; InitPSVSignatureElement(ModulePSVSE0, SE, i1ToUnknownCompat); @@ -412,11 +503,19 @@ void PSVContentVerifier::Verify(unsigned ValMajor, unsigned ValMinor, } // PSV2 only added NumThreadsX/Y/Z which verified in VerifyEntryProperties. if (PSVVersion > 2) { - if (DM.GetEntryFunctionName() != PSV.GetEntryFunctionName()) - EmitMismatchError("EntryFunctionName", PSV.GetEntryFunctionName(), - DM.GetEntryFunctionName()); + PSVRuntimeInfo3 *PSV3 = PSV.GetPSVRuntimeInfo3(); + if (!StrTableVerifier.MarkUse(PSV3->EntryFunctionName)) { + EmitInvalidError("EntryFunctionName"); + } else { + if (DM.GetEntryFunctionName() != PSV.GetEntryFunctionName()) + EmitMismatchError("EntryFunctionName", PSV.GetEntryFunctionName(), + DM.GetEntryFunctionName()); + } } + StrTableVerifier.Verify(ValCtx); + IndexTableVerifier.Verify(ValCtx); + if (!PSVContentValid) ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, {"Pipeline State Validation"}); @@ -707,6 +806,13 @@ static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, if (!SimplePSV.ValidatePSVInit(PSVInfo, ValCtx)) return; + if (SimplePSV.StringTable && + SimplePSV.StringTable[SimplePSV.StringTableSize - 1] != '\0') { + ValCtx.EmitFormatError(ValidationRule::ContainerContentInvalid, + {"PSV part StringTable"}); + return; + } + PSVInfo.StringTable = PSVStringTable(SimplePSV.StringTable, SimplePSV.StringTableSize); PSVInfo.SemanticIndexTable = PSVSemanticIndexTable( diff --git a/tools/clang/unittests/HLSL/ValidationTest.cpp b/tools/clang/unittests/HLSL/ValidationTest.cpp index 0a8f9cb2ac..71b137cb02 100644 --- a/tools/clang/unittests/HLSL/ValidationTest.cpp +++ b/tools/clang/unittests/HLSL/ValidationTest.cpp @@ -4650,7 +4650,7 @@ TEST_F(ValidationTest, PSVStringTableReorder) { " ComponentType: 3", " DynamicIndexMask: 0", "') and DXIL module:('PSVSignatureElement:", - " SemanticName: ", + " SemanticName: A", " SemanticIndex: 0 ", " IsAllocated: 1", " StartRow: 0", @@ -4678,7 +4678,7 @@ TEST_F(ValidationTest, PSVStringTableReorder) { " ComponentType: 3", " DynamicIndexMask: 0", "') and DXIL module:('PSVSignatureElement:", - " SemanticName: ", + " SemanticName: B", " SemanticIndex: 0 ", " IsAllocated: 1", " StartRow: 0", @@ -4693,6 +4693,8 @@ TEST_F(ValidationTest, PSVStringTableReorder) { "')", "error: DXIL container mismatch for 'EntryFunctionName' between 'PSV0' " "part:('ain') and DXIL module:('main')", + "error: In 'StringTable', 'A' is not used", + "error: In 'StringTable', 'main' is not used", "error: Container part 'Pipeline State Validation' does not match " "expected for module.", "Validation failed."}, @@ -4711,6 +4713,26 @@ TEST_F(ValidationTest, PSVStringTableReorder) { VERIFY_IS_NOT_NULL(pUpdatedResult); VERIFY_SUCCEEDED(pUpdatedResult->GetStatus(&status)); VERIFY_SUCCEEDED(status); + + // Create unused name in String table. + PSVInfo->EntryFunctionName = 2; + + // Run validation again. + CComPtr pUpdatedTableResult2; + VERIFY_SUCCEEDED( + pValidator->Validate(pProgram, Flags, &pUpdatedTableResult2)); + // Make sure the validation was fail. + VERIFY_IS_NOT_NULL(pUpdatedTableResult2); + VERIFY_SUCCEEDED(pUpdatedTableResult2->GetStatus(&status)); + VERIFY_FAILED(status); + CheckOperationResultMsgs( + pUpdatedTableResult2, + { + "error: DXIL container mismatch for 'EntryFunctionName' between " + "'PSV0' part:('A') and DXIL module:('main')", + "error: In 'StringTable', 'main' is not used", + }, + /*maySucceedAnyway*/ false, /*bRegex*/ false); } class SemanticIndexRotator { @@ -4729,6 +4751,10 @@ class SemanticIndexRotator { SignatureElements[i].SemanticIndexes = SignatureElements[i].SemanticIndexes - 1; } + void Clear(unsigned Index) { + for (unsigned i = 0; i < SignatureElements.size(); ++i) + SignatureElements[i].SemanticIndexes = Index; + } }; TEST_F(ValidationTest, PSVSemanticIndexTableReorder) { @@ -5057,6 +5083,31 @@ TEST_F(ValidationTest, PSVSemanticIndexTableReorder) { VERIFY_IS_NOT_NULL(pUpdatedResult); VERIFY_SUCCEEDED(pUpdatedResult->GetStatus(&status)); VERIFY_SUCCEEDED(status); + + // Clear SemanticIndexes. + InputRotator.Clear(UINT32_MAX); + OutputRotator.Clear(UINT32_MAX); + PatchConstOrPrimRotator.Clear(UINT32_MAX); + + // Run validation again. + CComPtr pUpdatedResult2; + VERIFY_SUCCEEDED(pValidator->Validate(pProgram, Flags, &pUpdatedResult2)); + // Make sure the validation was successful. + VERIFY_IS_NOT_NULL(pUpdatedResult2); + VERIFY_SUCCEEDED(pUpdatedResult2->GetStatus(&status)); + VERIFY_FAILED(status); + + CheckOperationResultMsgs( + pUpdatedResult2, + {"error: In 'PSV0 part', 'SemanticIndex' is not well-formed", + "error: In 'SemanticIndexTable', '0' is not used", + "error: In 'SemanticIndexTable', '2' is not used", + "error: In 'SemanticIndexTable', '3' is not used", + "error: In 'SemanticIndexTable', '4' is not used", + "error: Container part 'Pipeline State Validation' " + "does not match expected for module.", + "Validation failed."}, + /*maySucceedAnyway*/ false, /*bRegex*/ false); } struct SimplePSV { diff --git a/utils/hct/hctdb.py b/utils/hct/hctdb.py index 69d1e4293f..19220d6d1a 100644 --- a/utils/hct/hctdb.py +++ b/utils/hct/hctdb.py @@ -6933,6 +6933,11 @@ def build_valrules(self): "DXIL Container Content is well-formed", "In '%0', '%1' is not well-formed", ) + self.add_valrule_msg( + "Container.UnusedItemInTable", + "Items in Table must be used", + "In '%0', '%1' is not used", + ) self.add_valrule("Meta.Required", "Required metadata missing.") self.add_valrule_msg( "Meta.ComputeWithNode", From 9c4b132b1888833c47f18394cf794936f08ce63b Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Fri, 20 Sep 2024 01:01:57 -0400 Subject: [PATCH 07/13] Test out of bound access for StringTable. --- lib/DxilValidation/DxilContainerValidation.cpp | 2 +- tools/clang/unittests/HLSL/ValidationTest.cpp | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 130c6a4d97..233b216b68 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -81,7 +81,7 @@ static void emitDxilDiag(LLVMContext &Ctx, const char *str) { } class StringTableVerifier { - MapVector OffsetToUseCountMap; + std::unordered_map OffsetToUseCountMap; const PSVStringTable &Table; public: diff --git a/tools/clang/unittests/HLSL/ValidationTest.cpp b/tools/clang/unittests/HLSL/ValidationTest.cpp index 71b137cb02..c789451ea1 100644 --- a/tools/clang/unittests/HLSL/ValidationTest.cpp +++ b/tools/clang/unittests/HLSL/ValidationTest.cpp @@ -4715,7 +4715,7 @@ TEST_F(ValidationTest, PSVStringTableReorder) { VERIFY_SUCCEEDED(status); // Create unused name in String table. - PSVInfo->EntryFunctionName = 2; + PSVInfo->EntryFunctionName = UINT32_MAX; // Run validation again. CComPtr pUpdatedTableResult2; @@ -4728,8 +4728,7 @@ TEST_F(ValidationTest, PSVStringTableReorder) { CheckOperationResultMsgs( pUpdatedTableResult2, { - "error: DXIL container mismatch for 'EntryFunctionName' between " - "'PSV0' part:('A') and DXIL module:('main')", + "In 'PSV0 part', 'EntryFunctionName' is not well-formed", "error: In 'StringTable', 'main' is not used", }, /*maySucceedAnyway*/ false, /*bRegex*/ false); From 69d12d4b1d31bd6dfa266d9c38d8477e3123d1ec Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Fri, 20 Sep 2024 12:11:10 -0400 Subject: [PATCH 08/13] Fix clang build issues. --- .../DxilContainerValidation.cpp | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 233b216b68..08eff2b7ce 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -125,8 +125,8 @@ class StringTableVerifier { }; class SemanticIndexTableVerifier { - std::vector UseMask; const PSVSemanticIndexTable &Table; + std::vector UseMask; public: SemanticIndexTableVerifier(const PSVSemanticIndexTable &Table) @@ -602,11 +602,11 @@ struct SimplePSV { uint32_t PSVNumResources = 0; uint32_t PSVResourceBindInfoSize = 0; uint32_t StringTableSize = 0; - char *StringTable = nullptr; + const char *StringTable = nullptr; uint32_t SemanticIndexTableEntries = 0; - uint32_t *SemanticIndexTable = nullptr; + const uint32_t *SemanticIndexTable = nullptr; uint32_t PSVSignatureElementSize = 0; - PSVRuntimeInfo1 *RuntimeInfo1 = nullptr; + const PSVRuntimeInfo1 *RuntimeInfo1 = nullptr; bool IsValid = true; SimplePSV(const void *pPSVData, uint32_t PSVSize) { uint32_t Offset = 4; @@ -614,23 +614,24 @@ struct SimplePSV { IsValid = false; return; } - PSVRuntimeInfoSize = *(uint32_t *)pPSVData; + PSVRuntimeInfoSize = *(const uint32_t *)pPSVData; if (PSVRuntimeInfoSize >= sizeof(PSVRuntimeInfo1)) - RuntimeInfo1 = (PSVRuntimeInfo1 *)((char *)pPSVData + Offset); + RuntimeInfo1 = (const PSVRuntimeInfo1 *)((const char *)pPSVData + Offset); Offset += PSVRuntimeInfoSize; if (PSVSize < Offset) { IsValid = false; return; } - PSVNumResources = *(uint32_t *)((char *)pPSVData + Offset); + PSVNumResources = *(const uint32_t *)((const char *)pPSVData + Offset); Offset += 4; if (PSVSize < Offset) { IsValid = false; return; } if (PSVNumResources > 0) { - PSVResourceBindInfoSize = *(uint32_t *)((char *)pPSVData + Offset); + PSVResourceBindInfoSize = + *(const uint32_t *)((const char *)pPSVData + Offset); Offset += 4; if (PSVSize < Offset) { IsValid = false; @@ -643,7 +644,7 @@ struct SimplePSV { } } if (RuntimeInfo1) { - StringTableSize = *(uint32_t *)((char *)pPSVData + Offset); + StringTableSize = *(const uint32_t *)((const char *)pPSVData + Offset); Offset += 4; if (PSVSize < Offset) { IsValid = false; @@ -655,17 +656,19 @@ struct SimplePSV { return; } if (StringTableSize) { - StringTable = (char *)pPSVData + Offset; + StringTable = (const char *)pPSVData + Offset; Offset += StringTableSize; } - SemanticIndexTableEntries = *(uint32_t *)((char *)pPSVData + Offset); + SemanticIndexTableEntries = + *(const uint32_t *)((const char *)pPSVData + Offset); Offset += 4; if (PSVSize < Offset) { IsValid = false; return; } if (SemanticIndexTableEntries) { - SemanticIndexTable = (uint32_t *)((char *)pPSVData + Offset); + SemanticIndexTable = + (const uint32_t *)((const char *)pPSVData + Offset); Offset += SemanticIndexTableEntries * 4; if (PSVSize < Offset) { IsValid = false; @@ -674,7 +677,8 @@ struct SimplePSV { } if (RuntimeInfo1->SigInputElements || RuntimeInfo1->SigOutputElements || RuntimeInfo1->SigPatchConstOrPrimElements) { - PSVSignatureElementSize = *(uint32_t *)((char *)pPSVData + Offset); + PSVSignatureElementSize = + *(const uint32_t *)((const char *)pPSVData + Offset); Offset += 4; if (PSVSize < Offset) { IsValid = false; From 4fb11b45455f4c76f7092d49d57f49a7ce903438 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Sat, 21 Sep 2024 16:02:54 -0400 Subject: [PATCH 09/13] Code cleanup. --- .../DxilContainerValidation.cpp | 164 +++++++----------- 1 file changed, 64 insertions(+), 100 deletions(-) diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 08eff2b7ce..d1f77551bf 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -609,113 +609,76 @@ struct SimplePSV { const PSVRuntimeInfo1 *RuntimeInfo1 = nullptr; bool IsValid = true; SimplePSV(const void *pPSVData, uint32_t PSVSize) { - uint32_t Offset = 4; - if (PSVSize < Offset) { - IsValid = false; - return; - } - PSVRuntimeInfoSize = *(const uint32_t *)pPSVData; + +#define INCREMENT_POS(Size) \ + Offset += Size; \ + if (Offset > PSVSize) { \ + IsValid = false; \ + return; \ + } + + uint32_t Offset = 0; + PSVRuntimeInfoSize = GetUint32AtOffset(pPSVData, 0); + INCREMENT_POS(4); if (PSVRuntimeInfoSize >= sizeof(PSVRuntimeInfo1)) - RuntimeInfo1 = (const PSVRuntimeInfo1 *)((const char *)pPSVData + Offset); - Offset += PSVRuntimeInfoSize; - if (PSVSize < Offset) { - IsValid = false; - return; - } + RuntimeInfo1 = + (const PSVRuntimeInfo1 *)(GetPtrAtOffset(pPSVData, Offset)); + INCREMENT_POS(PSVRuntimeInfoSize); - PSVNumResources = *(const uint32_t *)((const char *)pPSVData + Offset); - Offset += 4; - if (PSVSize < Offset) { - IsValid = false; - return; - } + PSVNumResources = GetUint32AtOffset(pPSVData, Offset); + INCREMENT_POS(4); if (PSVNumResources > 0) { - PSVResourceBindInfoSize = - *(const uint32_t *)((const char *)pPSVData + Offset); - Offset += 4; - if (PSVSize < Offset) { - IsValid = false; - return; - } - Offset += PSVNumResources * PSVResourceBindInfoSize; - if (PSVSize < Offset) { - IsValid = false; - return; - } + PSVResourceBindInfoSize = GetUint32AtOffset(pPSVData, Offset); + // Increase the offset for the resource bind info size. + INCREMENT_POS(4); + // Increase the offset for the resource bind info. + INCREMENT_POS(PSVNumResources * PSVResourceBindInfoSize); } if (RuntimeInfo1) { - StringTableSize = *(const uint32_t *)((const char *)pPSVData + Offset); - Offset += 4; - if (PSVSize < Offset) { - IsValid = false; - return; - } + StringTableSize = GetUint32AtOffset(pPSVData, Offset); + INCREMENT_POS(4); // Make sure StringTableSize is aligned to 4 bytes. if ((StringTableSize & 3) != 0) { IsValid = false; return; } if (StringTableSize) { - StringTable = (const char *)pPSVData + Offset; - Offset += StringTableSize; - } - SemanticIndexTableEntries = - *(const uint32_t *)((const char *)pPSVData + Offset); - Offset += 4; - if (PSVSize < Offset) { - IsValid = false; - return; + StringTable = GetPtrAtOffset(pPSVData, Offset); + INCREMENT_POS(StringTableSize); } + SemanticIndexTableEntries = GetUint32AtOffset(pPSVData, Offset); + INCREMENT_POS(4); if (SemanticIndexTableEntries) { SemanticIndexTable = - (const uint32_t *)((const char *)pPSVData + Offset); - Offset += SemanticIndexTableEntries * 4; - if (PSVSize < Offset) { - IsValid = false; - return; - } + (const uint32_t *)(GetPtrAtOffset(pPSVData, Offset)); + INCREMENT_POS(SemanticIndexTableEntries * 4); } if (RuntimeInfo1->SigInputElements || RuntimeInfo1->SigOutputElements || RuntimeInfo1->SigPatchConstOrPrimElements) { - PSVSignatureElementSize = - *(const uint32_t *)((const char *)pPSVData + Offset); - Offset += 4; - if (PSVSize < Offset) { - IsValid = false; - return; - } + PSVSignatureElementSize = GetUint32AtOffset(pPSVData, Offset); + INCREMENT_POS(4); uint32_t PSVNumSignatures = RuntimeInfo1->SigInputElements + RuntimeInfo1->SigOutputElements + RuntimeInfo1->SigPatchConstOrPrimElements; - Offset += PSVNumSignatures * PSVSignatureElementSize; - if (PSVSize < Offset) { - IsValid = false; - return; - } + INCREMENT_POS(PSVNumSignatures * PSVSignatureElementSize); } if (RuntimeInfo1->UsesViewID) { for (unsigned i = 0; i < DXIL::kNumOutputStreams; i++) { uint32_t SigOutputVectors = RuntimeInfo1->SigOutputVectors[i]; if (SigOutputVectors == 0) continue; - Offset += sizeof(uint32_t) * - llvm::RoundUpToAlignment(SigOutputVectors, 8) / 8; - if (PSVSize < Offset) { - IsValid = false; - return; - } + uint32_t MaskSizeInBytes = + sizeof(uint32_t) * + PSVComputeMaskDwordsFromVectors(SigOutputVectors); + INCREMENT_POS(MaskSizeInBytes); } if ((RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Hull || RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Mesh) && RuntimeInfo1->SigPatchConstOrPrimVectors) { - Offset += sizeof(uint32_t) * - llvm::RoundUpToAlignment( - RuntimeInfo1->SigPatchConstOrPrimVectors, 8) / - 8; - if (PSVSize < Offset) { - IsValid = false; - return; - } + uint32_t MaskSizeInBytes = + sizeof(uint32_t) * PSVComputeMaskDwordsFromVectors( + RuntimeInfo1->SigPatchConstOrPrimVectors); + INCREMENT_POS(MaskSizeInBytes); } } @@ -723,43 +686,36 @@ struct SimplePSV { uint32_t SigOutputVectors = RuntimeInfo1->SigOutputVectors[i]; if (SigOutputVectors == 0) continue; - Offset += sizeof(uint32_t) * - llvm::RoundUpToAlignment(SigOutputVectors, 8) / 8 * - RuntimeInfo1->SigInputVectors * 4; - if (PSVSize < Offset) { - IsValid = false; - return; - } + uint32_t TableSizeInBytes = + sizeof(uint32_t) * + PSVComputeInputOutputTableDwords(RuntimeInfo1->SigInputVectors, + SigOutputVectors); + INCREMENT_POS(TableSizeInBytes); } if ((RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Hull || RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Mesh) && RuntimeInfo1->SigPatchConstOrPrimVectors && RuntimeInfo1->SigInputVectors) { - Offset += sizeof(uint32_t) * - llvm::RoundUpToAlignment( - RuntimeInfo1->SigPatchConstOrPrimVectors, 8) / - 8 * RuntimeInfo1->SigInputVectors * 4; - if (PSVSize < Offset) { - IsValid = false; - return; - } + uint32_t TableSizeInBytes = + sizeof(uint32_t) * PSVComputeInputOutputTableDwords( + RuntimeInfo1->SigInputVectors, + RuntimeInfo1->SigPatchConstOrPrimVectors); + INCREMENT_POS(TableSizeInBytes); } if (RuntimeInfo1->ShaderStage == (unsigned)DXIL::ShaderKind::Domain && RuntimeInfo1->SigOutputVectors[0] && RuntimeInfo1->SigPatchConstOrPrimVectors) { - Offset += - sizeof(uint32_t) * - llvm::RoundUpToAlignment(RuntimeInfo1->SigOutputVectors[0], 8) / 8 * - RuntimeInfo1->SigPatchConstOrPrimVectors * 4; - if (PSVSize < Offset) { - IsValid = false; - return; - } + uint32_t TableSizeInBytes = + sizeof(uint32_t) * PSVComputeInputOutputTableDwords( + RuntimeInfo1->SigPatchConstOrPrimVectors, + RuntimeInfo1->SigOutputVectors[0]); + INCREMENT_POS(TableSizeInBytes); } } IsValid = PSVSize == Offset; +#undef INCREMENT_POS } bool ValidatePSVInit(PSVInitInfo PSVInfo, ValidationContext &ValCtx) { if (PSVRuntimeInfoSize != PSVInfo.RuntimeInfoSize()) { @@ -789,6 +745,14 @@ struct SimplePSV { } return true; } + +private: + const char *GetPtrAtOffset(const void *BasePtr, uint32_t Offset) const { + return (const char *)BasePtr + Offset; + } + uint32_t GetUint32AtOffset(const void *BasePtr, uint32_t Offset) const { + return *(const uint32_t *)GetPtrAtOffset(BasePtr, Offset); + } }; static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, From 822fa98ae1d7fed382c0722250ae31ac46d53879 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Mon, 23 Sep 2024 14:12:56 -0400 Subject: [PATCH 10/13] Update release note per comment. --- docs/ReleaseNotes.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/ReleaseNotes.md b/docs/ReleaseNotes.md index d2ee393261..ebc30ba159 100644 --- a/docs/ReleaseNotes.md +++ b/docs/ReleaseNotes.md @@ -23,7 +23,7 @@ The included licenses apply to the following files: Place release notes for the upcoming release below this line and remove this line upon naming this release. - The incomplete WaveMatrix implementation has been removed. -- The DXIL validator supports the string table and semantic index table in any order. +- DXIL container validation for PSV0 part allows any content ordering inside string and semantic index tables. ### Version 1.8.2407 From 686c562dfe6c8ca711823029fbaf2be4e5add0c6 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Mon, 23 Sep 2024 14:43:42 -0400 Subject: [PATCH 11/13] Use llvm::BitVector. --- lib/DxilValidation/DxilContainerValidation.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index d1f77551bf..12eccb7504 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -24,6 +24,7 @@ #include "dxc/DXIL/DxilUtil.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/BitVector.h" #include "llvm/Bitcode/ReaderWriter.h" #include "llvm/IR/DiagnosticPrinter.h" #include "llvm/IR/Module.h" @@ -126,7 +127,7 @@ class StringTableVerifier { class SemanticIndexTableVerifier { const PSVSemanticIndexTable &Table; - std::vector UseMask; + llvm::BitVector UseMask; public: SemanticIndexTableVerifier(const PSVSemanticIndexTable &Table) From d500bc1ec815d0d7893d4c59e2aa59367fb98a02 Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Mon, 23 Sep 2024 15:02:36 -0400 Subject: [PATCH 12/13] Switch to ContainerContentInvalid for invalid PSV0 part. --- lib/DxilValidation/DxilContainerValidation.cpp | 4 ++-- tools/clang/unittests/HLSL/ValidationTest.cpp | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 12eccb7504..1719d8edf6 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -760,8 +760,8 @@ static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, uint32_t PSVSize) { SimplePSV SimplePSV(pPSVData, PSVSize); if (!SimplePSV.IsValid) { - ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, - {"Pipeline State Validation"}); + ValCtx.EmitFormatError(ValidationRule::ContainerContentInvalid, + {"DxilContainer", "PSV0 part"}); return; } diff --git a/tools/clang/unittests/HLSL/ValidationTest.cpp b/tools/clang/unittests/HLSL/ValidationTest.cpp index c789451ea1..ec83f4e260 100644 --- a/tools/clang/unittests/HLSL/ValidationTest.cpp +++ b/tools/clang/unittests/HLSL/ValidationTest.cpp @@ -6202,10 +6202,9 @@ TEST_F(ValidationTest, WrongPSVSize) { VERIFY_SUCCEEDED(pUpdatedResult->GetStatus(&status)); VERIFY_FAILED(status); - CheckOperationResultMsgs(pUpdatedResult, - {"Container part 'Pipeline State Validation' does " - "not match expected for module."}, - /*maySucceedAnyway*/ false, /*bRegex*/ false); + CheckOperationResultMsgs( + pUpdatedResult, {"In 'DxilContainer', 'PSV0 part' is not well-formed"}, + /*maySucceedAnyway*/ false, /*bRegex*/ false); } TEST_F(ValidationTest, WrongPSVVersion) { From d8a04b094cde5a1d3524f02935c01a05ef5a441a Mon Sep 17 00:00:00 2001 From: Xiang Li Date: Mon, 23 Sep 2024 18:14:42 -0400 Subject: [PATCH 13/13] Remove useless size compare and add comment from review. --- .../DxilContainerValidation.cpp | 31 ++++++------------- 1 file changed, 9 insertions(+), 22 deletions(-) diff --git a/lib/DxilValidation/DxilContainerValidation.cpp b/lib/DxilValidation/DxilContainerValidation.cpp index 1719d8edf6..2276b0d3de 100644 --- a/lib/DxilValidation/DxilContainerValidation.cpp +++ b/lib/DxilValidation/DxilContainerValidation.cpp @@ -758,23 +758,28 @@ struct SimplePSV { static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, uint32_t PSVSize) { + // SimplePSV.IsValid indicates whether the part is well-formed so that we may + // proceed with more detailed validation. SimplePSV SimplePSV(pPSVData, PSVSize); if (!SimplePSV.IsValid) { ValCtx.EmitFormatError(ValidationRule::ContainerContentInvalid, {"DxilContainer", "PSV0 part"}); return; } - + // The PSVVersion determines the size of record structures that should be + // used when writing PSV0 data, and is based on the validator version in the + // module. unsigned ValMajor, ValMinor; ValCtx.DxilMod.GetValidatorVersion(ValMajor, ValMinor); unsigned PSVVersion = hlsl::GetPSVVersion(ValMajor, ValMinor); - + // PSVInfo is used to compute the expected record size of the PSV0 part of the + // container. It uses facts from the module. PSVInitInfo PSVInfo(PSVVersion); hlsl::SetupPSVInitInfo(PSVInfo, ValCtx.DxilMod); - + // ValidatePSVInit checks that record sizes match expected for PSVVersion. if (!SimplePSV.ValidatePSVInit(PSVInfo, ValCtx)) return; - + // Ensure that the string table data is null-terminated. if (SimplePSV.StringTable && SimplePSV.StringTable[SimplePSV.StringTableSize - 1] != '\0') { ValCtx.EmitFormatError(ValidationRule::ContainerContentInvalid, @@ -782,24 +787,6 @@ static void VerifyPSVMatches(ValidationContext &ValCtx, const void *pPSVData, return; } - PSVInfo.StringTable = - PSVStringTable(SimplePSV.StringTable, SimplePSV.StringTableSize); - PSVInfo.SemanticIndexTable = PSVSemanticIndexTable( - SimplePSV.SemanticIndexTable, SimplePSV.SemanticIndexTableEntries); - uint32_t ExpectedSize = 0; - DxilPipelineStateValidation SizePSV; - if (!SizePSV.InitNew(PSVInfo, nullptr, &ExpectedSize)) { - ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, - {"Pipeline State Validation"}); - return; - } - - if (ExpectedSize != PSVSize) { - ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches, - {"Pipeline State Validation"}); - return; - } - DxilPipelineStateValidation PSV; if (!PSV.InitFromPSV0(pPSVData, PSVSize)) { ValCtx.EmitFormatError(ValidationRule::ContainerPartMatches,