diff --git a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp index 9482a0a213..361b3b5b56 100644 --- a/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp +++ b/tools/clang/lib/SPIRV/DeclResultIdMapper.cpp @@ -3582,32 +3582,61 @@ bool DeclResultIdMapper::createPayloadStageVars( } const auto loc = decl->getLocation(); - if (!type->isStructureType()) { - StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type, - getLocationAndComponentCount(astContext, type)); - const auto name = namePrefix.str() + "." + decl->getNameAsString(); - SpirvVariable *varInstr = spvBuilder.addStageIOVar( - type, sc, name, /*isPrecise=*/false, /*isNointerp=*/false, loc); - - if (!varInstr) - return false; - // Even though these as user defined IO stage variables, set them as SPIR-V - // builtins in order to bypass any semantic string checks and location - // assignment. - stageVar.setIsSpirvBuiltin(); - stageVar.setSpirvInstr(varInstr); - if (stageVar.getStorageClass() == spv::StorageClass::Input || - stageVar.getStorageClass() == spv::StorageClass::Output) { - stageVar.setEntryPoint(entryFunction); + // Most struct type stage vars must be flattened, but for EXT_mesh_shaders the + // mesh payload struct should be decorated with TaskPayloadWorkgroupEXT and + // used directly as the OpEntryPoint variable. + if (!type->isStructureType() || + featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) { + + SpirvVariable *varInstr = nullptr; + + // Check whether a mesh payload module variable has already been added, as + // is the case for the groupshared payload variable parameter of + // DispatchMesh. In this case, change the storage class from Workgroup to + // TaskPayloadWorkgroupEXT. + if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) { + for (SpirvVariable *moduleVar : spvBuilder.getModule()->getVariables()) { + if (moduleVar->getAstResultType() == type) { + moduleVar->setStorageClass( + spv::StorageClass::TaskPayloadWorkgroupEXT); + varInstr = moduleVar; + } + } } - stageVars.push_back(stageVar); - if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) { - // Decorate with PerTaskNV for mesh/amplification shader payload - // variables. - spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset, - varInstr->getSourceLocation()); + // If necessary, create new stage variable for mesh payload. + if (!varInstr) { + LocationAndComponent locationAndComponentCount = + type->isStructureType() + ? LocationAndComponent({0, 0, false}) + : getLocationAndComponentCount(astContext, type); + StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, + type, locationAndComponentCount); + const auto name = namePrefix.str() + "." + decl->getNameAsString(); + varInstr = spvBuilder.addStageIOVar(type, sc, name, /*isPrecise=*/false, + /*isNointerp=*/false, loc); + + if (!varInstr) + return false; + + // Even though these as user defined IO stage variables, set them as + // SPIR-V builtins in order to bypass any semantic string checks and + // location assignment. + stageVar.setIsSpirvBuiltin(); + stageVar.setSpirvInstr(varInstr); + if (stageVar.getStorageClass() == spv::StorageClass::Input || + stageVar.getStorageClass() == spv::StorageClass::Output) { + stageVar.setEntryPoint(entryFunction); + } + stageVars.push_back(stageVar); + + if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) { + // Decorate with PerTaskNV for mesh/amplification shader payload + // variables. + spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset, + varInstr->getSourceLocation()); + } } if (asInput) { diff --git a/tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl b/tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl index e7061aff41..bd2ccd7723 100644 --- a/tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl +++ b/tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl @@ -1,7 +1,7 @@ // RUN: %dxc -T as_6_5 -fspv-target-env=vulkan1.1spirv1.4 -E main -fcgl %s -spirv | FileCheck %s // CHECK: OpCapability MeshShadingEXT // CHECK: OpExtension "SPV_EXT_mesh_shader" -// CHECK: OpEntryPoint TaskEXT %main "main" [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %out_var_dummy %out_var_pos +// CHECK: OpEntryPoint TaskEXT %main "main" [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %pld // CHECK: OpExecutionMode %main LocalSize 128 1 1 // CHECK: OpDecorate [[drawid]] BuiltIn DrawIndex @@ -11,14 +11,12 @@ // CHECK: OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex -// CHECK: %pld = OpVariable %_ptr_Workgroup_MeshPayload Workgroup +// CHECK: %pld = OpVariable %_ptr_TaskPayloadWorkgroupEXT_MeshPayload TaskPayloadWorkgroupEXT // CHECK: [[drawid]] = OpVariable %_ptr_Input_int Input // CHECK: %gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input // CHECK: %gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input // CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input // CHECK: %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input -// CHECK: %out_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT -// CHECK: %out_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT struct MeshPayload { float dummy[10]; float4 pos; @@ -47,16 +45,12 @@ void main( // CHECK: %tid = OpFunctionParameter %_ptr_Function_uint // CHECK: %tig = OpFunctionParameter %_ptr_Function_uint // -// CHECK: [[a:%[0-9]+]] = OpAccessChain %_ptr_Workgroup_v4float %pld %int_1 +// CHECK: [[a:%[0-9]+]] = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_v4float %pld %int_1 // CHECK: OpStore [[a]] {{%[0-9]+}} pld.pos = float4(gtid.x, gid.y, tid, tig); // CHECK: OpControlBarrier %uint_2 %uint_2 %uint_264 // CHECK: [[e:%[0-9]+]] = OpLoad %MeshPayload %pld -// CHECK: [[f:%[0-9]+]] = OpCompositeExtract %_arr_float_uint_10 [[e]] 0 -// CHECK: OpStore %out_var_dummy [[f]] -// CHECK: [[g:%[0-9]+]] = OpCompositeExtract %v4float [[e]] 1 -// CHECK: OpStore %out_var_pos [[g]] // CHECK: [[h:%[0-9]+]] = OpLoad %int %drawId // CHECK: [[i:%[0-9]+]] = OpBitcast %uint [[h]] // CHECK: [[j:%[0-9]+]] = OpLoad %int %drawId diff --git a/tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl b/tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl index 2c0dcee65d..4534e0faac 100644 --- a/tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl +++ b/tools/clang/test/CodeGenSPIRV/meshshading.ext.triangle.mesh.hlsl @@ -1,7 +1,7 @@ // RUN: %dxc -T ms_6_5 -fspv-target-env=universal1.5 -E main -fcgl %s -spirv | FileCheck %s // CHECK: OpCapability MeshShadingEXT // CHECK: OpExtension "SPV_EXT_mesh_shader" -// CHECK: OpEntryPoint MeshEXT %main "main" %gl_ClipDistance %gl_CullDistance %in_var_dummy %in_var_pos [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %gl_Position %gl_PointSize %out_var_USER %out_var_USER_ARR %out_var_USER_MAT [[primindices:%[0-9]+]] %gl_PrimitiveID %gl_Layer %gl_ViewportIndex [[cullprim:%[0-9]+]] [[primshadingrate:%[0-9]+]] %out_var_PRIM_USER %out_var_PRIM_USER_ARR +// CHECK: OpEntryPoint MeshEXT %main "main" %gl_ClipDistance %gl_CullDistance %in_var_pld [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %gl_Position %gl_PointSize %out_var_USER %out_var_USER_ARR %out_var_USER_MAT [[primindices:%[0-9]+]] %gl_PrimitiveID %gl_Layer %gl_ViewportIndex [[cullprim:%[0-9]+]] [[primshadingrate:%[0-9]+]] %out_var_PRIM_USER %out_var_PRIM_USER_ARR // CHECK: OpExecutionMode %main LocalSize 128 1 1 // CHECK: OpExecutionMode %main OutputTrianglesNV // CHECK: OpExecutionMode %main OutputVertices 64 @@ -37,8 +37,7 @@ // CHECK: %gl_ClipDistance = OpVariable %_ptr_Output__arr__arr_float_uint_5_uint_64 Output // CHECK: %gl_CullDistance = OpVariable %_ptr_Output__arr__arr_float_uint_3_uint_64 Output -// CHECK: %in_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT -// CHECK: %in_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT +// CHECK: %in_var_pld = OpVariable %_ptr_TaskPayloadWorkgroupEXT_MeshPayload TaskPayloadWorkgroupEXT // CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input // CHECK: %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input // CHECK: %gl_Position = OpVariable %_ptr_Output__arr_v4float_uint_64 Output