From 1e03d653b71e7a2fadfdab636512224fdead4cea Mon Sep 17 00:00:00 2001 From: Natalie Chouinard Date: Fri, 12 Apr 2024 18:08:21 +0000 Subject: [PATCH] [SPIR-V] Fix mesh payload for VK_EXT_mesh_shader The existing logic from VK_NV_mesh_shader was incorrectly adapted for the VK_EXT_mesh_shader implementation when it comes to the handling of mesh payloads as in/out variables. Because TaskPayloadWorkgroupEXT must be applied to a single global OpVariable for each Task/Mesh shader, the struct should not be flattened. Further, as far as I can tell, Location assignment is not necessary for these input and output variables, so the usual reason for flattening structs does not apply. This change now removes the inner struct member global variables and ensures the parent payload is decorated with TaskPayloadWorkgroupEXT. Note that for amplification/task shaders, the payload variable is created with the groupshared decl, and then its storage class needs to be updated when that variable is used as a parameter to the DispatchMesh call, as described in: https://docs.vulkan.org/spec/latest/proposals/proposals/VK_EXT_mesh_shader.html#_hlsl_changes Tested with updated spirv-val from: https://github.com/KhronosGroup/SPIRV-Tools/pull/5640 Fixes #5981 --- tools/clang/lib/SPIRV/DeclResultIdMapper.cpp | 75 +++++++++++++------ .../meshshading.ext.amplification.hlsl | 12 +-- .../meshshading.ext.triangle.mesh.hlsl | 5 +- 3 files changed, 57 insertions(+), 35 deletions(-) diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index 9482a0a213..361b3b5b56 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -3582,32 +3582,61 @@ bool DeclResultIdMapper::createPayloadStageVars( } const auto loc = decl->getLocation(); - if (!type->isStructureType()) { - StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type, - getLocationAndComponentCount(astContext, type)); - const 
auto name = namePrefix.str() + "." + decl->getNameAsString(); - SpirvVariable *varInstr = spvBuilder.addStageIOVar( - type, sc, name, /*isPrecise=*/false, /*isNointerp=*/false, loc); - - if (!varInstr) - return false; - // Even though these as user defined IO stage variables, set them as SPIR-V - // builtins in order to bypass any semantic string checks and location - // assignment. - stageVar.setIsSpirvBuiltin(); - stageVar.setSpirvInstr(varInstr); - if (stageVar.getStorageClass() == spv::StorageClass::Input || - stageVar.getStorageClass() == spv::StorageClass::Output) { - stageVar.setEntryPoint(entryFunction); + // Most struct type stage vars must be flattened, but for EXT_mesh_shaders the + // mesh payload struct should be decorated with TaskPayloadWorkgroupEXT and + // used directly as the OpEntryPoint variable. + if (!type->isStructureType() || + featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) { + + SpirvVariable *varInstr = nullptr; + + // Check whether a mesh payload module variable has already been added, as + // is the case for the groupshared payload variable parameter of + // DispatchMesh. In this case, change the storage class from Workgroup to + // TaskPayloadWorkgroupEXT. + if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) { + for (SpirvVariable *moduleVar : spvBuilder.getModule()->getVariables()) { + if (moduleVar->getAstResultType() == type) { + moduleVar->setStorageClass( + spv::StorageClass::TaskPayloadWorkgroupEXT); + varInstr = moduleVar; + } + } } - stageVars.push_back(stageVar); - if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) { - // Decorate with PerTaskNV for mesh/amplification shader payload - // variables. - spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset, - varInstr->getSourceLocation()); + // If necessary, create new stage variable for mesh payload. + if (!varInstr) { + LocationAndComponent locationAndComponentCount = + type->isStructureType() + ? 
LocationAndComponent({0, 0, false}) + : getLocationAndComponentCount(astContext, type); + StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, + type, locationAndComponentCount); + const auto name = namePrefix.str() + "." + decl->getNameAsString(); + varInstr = spvBuilder.addStageIOVar(type, sc, name, /*isPrecise=*/false, + /*isNointerp=*/false, loc); + + if (!varInstr) + return false; + + // Even though these as user defined IO stage variables, set them as + // SPIR-V builtins in order to bypass any semantic string checks and + // location assignment. + stageVar.setIsSpirvBuiltin(); + stageVar.setSpirvInstr(varInstr); + if (stageVar.getStorageClass() == spv::StorageClass::Input || + stageVar.getStorageClass() == spv::StorageClass::Output) { + stageVar.setEntryPoint(entryFunction); + } + stageVars.push_back(stageVar); + + if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) { + // Decorate with PerTaskNV for mesh/amplification shader payload + // variables. 
+ spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset, + varInstr->getSourceLocation()); + } } if (asInput) { diff --git a/tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl b/tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl index e7061aff41..bd2ccd7723 100644 --- a/tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl +++ b/tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl @@ -1,7 +1,7 @@ // RUN: %dxc -T as_6_5 -fspv-target-env=vulkan1.1spirv1.4 -E main -fcgl %s -spirv | FileCheck %s // CHECK: OpCapability MeshShadingEXT // CHECK: OpExtension "SPV_EXT_mesh_shader" -// CHECK: OpEntryPoint TaskEXT %main "main" [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %out_var_dummy %out_var_pos +// CHECK: OpEntryPoint TaskEXT %main "main" [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %pld // CHECK: OpExecutionMode %main LocalSize 128 1 1 // CHECK: OpDecorate [[drawid]] BuiltIn DrawIndex @@ -11,14 +11,12 @@ // CHECK: OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex -// CHECK: %pld = OpVariable %_ptr_Workgroup_MeshPayload Workgroup +// CHECK: %pld = OpVariable %_ptr_TaskPayloadWorkgroupEXT_MeshPayload TaskPayloadWorkgroupEXT // CHECK: [[drawid]] = OpVariable %_ptr_Input_int Input // CHECK: %gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input // CHECK: %gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input // CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input // CHECK: %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input -// CHECK: %out_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT -// CHECK: %out_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT struct MeshPayload { float dummy[10]; float4 pos; @@ -47,16 +45,12 @@ void main( // CHECK: %tid = OpFunctionParameter 
%_ptr_Function_uint // CHECK: %tig = OpFunctionParameter %_ptr_Function_uint // -// CHECK: [[a:%[0-9]+]] = OpAccessChain %_ptr_Workgroup_v4float %pld %int_1 +// CHECK: [[a:%[0-9]+]] = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_v4float %pld %int_1 // CHECK: OpStore [[a]] {{%[0-9]+}} pld.pos = float4(gtid.x, gid.y, tid, tig); // CHECK: OpControlBarrier %uint_2 %uint_2 %uint_264 // CHECK: [[e:%[0-9]+]] = OpLoad %MeshPayload %pld -// CHECK: [[f:%[0-9]+]] = OpCompositeExtract %_arr_float_uint_10 [[e]] 0 -// CHECK: OpStore %out_var_dummy [[f]] -// CHECK: [[g:%[0-9]+]] = OpCompositeExtract %v4float [[e]] 1 -// CHECK: OpStore %out_var_pos [[g]] // CHECK: [[h:%[0-9]+]] = OpLoad %int %drawId // CHECK: [[i:%[0-9]+]] = OpBitcast %uint [[h]] // CHECK: [[j:%[0-9]+]] = OpLoad %int %drawId diff --git a/tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl b/tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl index 2c0dcee65d..4534e0faac 100644 --- a/tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl +++ b/tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl @@ -1,7 +1,7 @@ // RUN: %dxc -T ms_6_5 -fspv-target-env=universal1.5 -E main -fcgl %s -spirv | FileCheck %s // CHECK: OpCapability MeshShadingEXT // CHECK: OpExtension "SPV_EXT_mesh_shader" -// CHECK: OpEntryPoint MeshEXT %main "main" %gl_ClipDistance %gl_CullDistance %in_var_dummy %in_var_pos [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %gl_Position %gl_PointSize %out_var_USER %out_var_USER_ARR %out_var_USER_MAT [[primindices:%[0-9]+]] %gl_PrimitiveID %gl_Layer %gl_ViewportIndex [[cullprim:%[0-9]+]] [[primshadingrate:%[0-9]+]] %out_var_PRIM_USER %out_var_PRIM_USER_ARR +// CHECK: OpEntryPoint MeshEXT %main "main" %gl_ClipDistance %gl_CullDistance %in_var_pld [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %gl_Position %gl_PointSize %out_var_USER %out_var_USER_ARR 
%out_var_USER_MAT [[primindices:%[0-9]+]] %gl_PrimitiveID %gl_Layer %gl_ViewportIndex [[cullprim:%[0-9]+]] [[primshadingrate:%[0-9]+]] %out_var_PRIM_USER %out_var_PRIM_USER_ARR // CHECK: OpExecutionMode %main LocalSize 128 1 1 // CHECK: OpExecutionMode %main OutputTrianglesNV // CHECK: OpExecutionMode %main OutputVertices 64 @@ -37,8 +37,7 @@ // CHECK: %gl_ClipDistance = OpVariable %_ptr_Output__arr__arr_float_uint_5_uint_64 Output // CHECK: %gl_CullDistance = OpVariable %_ptr_Output__arr__arr_float_uint_3_uint_64 Output -// CHECK: %in_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT -// CHECK: %in_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT +// CHECK: %in_var_pld = OpVariable %_ptr_TaskPayloadWorkgroupEXT_MeshPayload TaskPayloadWorkgroupEXT // CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input // CHECK: %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input // CHECK: %gl_Position = OpVariable %_ptr_Output__arr_v4float_uint_64 Output