Skip to content

Commit

Permalink
[SPIR-V] Fix mesh payload for VK_EXT_mesh_shader
Browse files Browse the repository at this point in the history
The existing logic from VK_NV_mesh_shader was incorrectly adapted
for the VK_EXT_mesh_shader implementation when it comes to the handling
of mesh payloads as in/out variables. Because TaskPayloadWorkgroupEXT
must be applied to a single global OpVariable for each Task/Mesh shader,
the struct should not be flattened. Further, as far as I can tell,
Location assignment is not necessary for these input and output
variables, so the usual reason for flattening structs does not apply.

This change now removes the inner struct member global variables and
ensures the parent payload is decorated with TaskPayloadWorkgroupEXT.
Note that for amplification/task shaders, the payload variable is
created with the groupshared decl, and then its storage class needs to
be updated when that variable is used as a parameter to the DispatchMesh
call, as described in: https://docs.vulkan.org/spec/latest/proposals/proposals/VK_EXT_mesh_shader.html#_hlsl_changes

Tested with updated spirv-val from: KhronosGroup/SPIRV-Tools#5640

Fixes microsoft#5981
  • Loading branch information
sudonatalie committed Apr 12, 2024
1 parent b065a0d commit 1e03d65
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 35 deletions.
75 changes: 52 additions & 23 deletions tools/clang/lib/SPIRV/DeclResultIdMapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3582,32 +3582,61 @@ bool DeclResultIdMapper::createPayloadStageVars(
}

const auto loc = decl->getLocation();
if (!type->isStructureType()) {
StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr, type,
getLocationAndComponentCount(astContext, type));
const auto name = namePrefix.str() + "." + decl->getNameAsString();
SpirvVariable *varInstr = spvBuilder.addStageIOVar(
type, sc, name, /*isPrecise=*/false, /*isNointerp=*/false, loc);

if (!varInstr)
return false;

// Even though these as user defined IO stage variables, set them as SPIR-V
// builtins in order to bypass any semantic string checks and location
// assignment.
stageVar.setIsSpirvBuiltin();
stageVar.setSpirvInstr(varInstr);
if (stageVar.getStorageClass() == spv::StorageClass::Input ||
stageVar.getStorageClass() == spv::StorageClass::Output) {
stageVar.setEntryPoint(entryFunction);
// Most struct type stage vars must be flattened, but for EXT_mesh_shaders the
// mesh payload struct should be decorated with TaskPayloadWorkgroupEXT and
// used directly as the OpEntryPoint variable.
if (!type->isStructureType() ||
featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {

SpirvVariable *varInstr = nullptr;

// Check whether a mesh payload module variable has already been added, as
// is the case for the groupshared payload variable parameter of
// DispatchMesh. In this case, change the storage class from Workgroup to
// TaskPayloadWorkgroupEXT.
if (featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
for (SpirvVariable *moduleVar : spvBuilder.getModule()->getVariables()) {
if (moduleVar->getAstResultType() == type) {
moduleVar->setStorageClass(
spv::StorageClass::TaskPayloadWorkgroupEXT);
varInstr = moduleVar;
}
}
}
stageVars.push_back(stageVar);

if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
// Decorate with PerTaskNV for mesh/amplification shader payload
// variables.
spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
varInstr->getSourceLocation());
// If necessary, create new stage variable for mesh payload.
if (!varInstr) {
LocationAndComponent locationAndComponentCount =
type->isStructureType()
? LocationAndComponent({0, 0, false})
: getLocationAndComponentCount(astContext, type);
StageVar stageVar(sigPoint, /*semaInfo=*/{}, /*builtinAttr=*/nullptr,
type, locationAndComponentCount);
const auto name = namePrefix.str() + "." + decl->getNameAsString();
varInstr = spvBuilder.addStageIOVar(type, sc, name, /*isPrecise=*/false,
/*isNointerp=*/false, loc);

if (!varInstr)
return false;

// Even though these as user defined IO stage variables, set them as
// SPIR-V builtins in order to bypass any semantic string checks and
// location assignment.
stageVar.setIsSpirvBuiltin();
stageVar.setSpirvInstr(varInstr);
if (stageVar.getStorageClass() == spv::StorageClass::Input ||
stageVar.getStorageClass() == spv::StorageClass::Output) {
stageVar.setEntryPoint(entryFunction);
}
stageVars.push_back(stageVar);

if (!featureManager.isExtensionEnabled(Extension::EXT_mesh_shader)) {
// Decorate with PerTaskNV for mesh/amplification shader payload
// variables.
spvBuilder.decoratePerTaskNV(varInstr, payloadMemOffset,
varInstr->getSourceLocation());
}
}

if (asInput) {
Expand Down
12 changes: 3 additions & 9 deletions tools/clang/test/CodeGenSPIRV/meshshading.ext.amplification.hlsl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: %dxc -T as_6_5 -fspv-target-env=vulkan1.1spirv1.4 -E main -fcgl %s -spirv | FileCheck %s
// CHECK: OpCapability MeshShadingEXT
// CHECK: OpExtension "SPV_EXT_mesh_shader"
// CHECK: OpEntryPoint TaskEXT %main "main" [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %out_var_dummy %out_var_pos
// CHECK: OpEntryPoint TaskEXT %main "main" [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %pld
// CHECK: OpExecutionMode %main LocalSize 128 1 1

// CHECK: OpDecorate [[drawid]] BuiltIn DrawIndex
Expand All @@ -11,14 +11,12 @@
// CHECK: OpDecorate %gl_LocalInvocationIndex BuiltIn LocalInvocationIndex


// CHECK: %pld = OpVariable %_ptr_Workgroup_MeshPayload Workgroup
// CHECK: %pld = OpVariable %_ptr_TaskPayloadWorkgroupEXT_MeshPayload TaskPayloadWorkgroupEXT
// CHECK: [[drawid]] = OpVariable %_ptr_Input_int Input
// CHECK: %gl_LocalInvocationID = OpVariable %_ptr_Input_v3uint Input
// CHECK: %gl_WorkGroupID = OpVariable %_ptr_Input_v3uint Input
// CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
// CHECK: %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
// CHECK: %out_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT
// CHECK: %out_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT
struct MeshPayload {
float dummy[10];
float4 pos;
Expand Down Expand Up @@ -47,16 +45,12 @@ void main(
// CHECK: %tid = OpFunctionParameter %_ptr_Function_uint
// CHECK: %tig = OpFunctionParameter %_ptr_Function_uint
//
// CHECK: [[a:%[0-9]+]] = OpAccessChain %_ptr_Workgroup_v4float %pld %int_1
// CHECK: [[a:%[0-9]+]] = OpAccessChain %_ptr_TaskPayloadWorkgroupEXT_v4float %pld %int_1
// CHECK: OpStore [[a]] {{%[0-9]+}}
pld.pos = float4(gtid.x, gid.y, tid, tig);

// CHECK: OpControlBarrier %uint_2 %uint_2 %uint_264
// CHECK: [[e:%[0-9]+]] = OpLoad %MeshPayload %pld
// CHECK: [[f:%[0-9]+]] = OpCompositeExtract %_arr_float_uint_10 [[e]] 0
// CHECK: OpStore %out_var_dummy [[f]]
// CHECK: [[g:%[0-9]+]] = OpCompositeExtract %v4float [[e]] 1
// CHECK: OpStore %out_var_pos [[g]]
// CHECK: [[h:%[0-9]+]] = OpLoad %int %drawId
// CHECK: [[i:%[0-9]+]] = OpBitcast %uint [[h]]
// CHECK: [[j:%[0-9]+]] = OpLoad %int %drawId
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// RUN: %dxc -T ms_6_5 -fspv-target-env=universal1.5 -E main -fcgl %s -spirv | FileCheck %s
// CHECK: OpCapability MeshShadingEXT
// CHECK: OpExtension "SPV_EXT_mesh_shader"
// CHECK: OpEntryPoint MeshEXT %main "main" %gl_ClipDistance %gl_CullDistance %in_var_dummy %in_var_pos [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %gl_Position %gl_PointSize %out_var_USER %out_var_USER_ARR %out_var_USER_MAT [[primindices:%[0-9]+]] %gl_PrimitiveID %gl_Layer %gl_ViewportIndex [[cullprim:%[0-9]+]] [[primshadingrate:%[0-9]+]] %out_var_PRIM_USER %out_var_PRIM_USER_ARR
// CHECK: OpEntryPoint MeshEXT %main "main" %gl_ClipDistance %gl_CullDistance %in_var_pld [[drawid:%[0-9]+]] %gl_LocalInvocationID %gl_WorkGroupID %gl_GlobalInvocationID %gl_LocalInvocationIndex %gl_Position %gl_PointSize %out_var_USER %out_var_USER_ARR %out_var_USER_MAT [[primindices:%[0-9]+]] %gl_PrimitiveID %gl_Layer %gl_ViewportIndex [[cullprim:%[0-9]+]] [[primshadingrate:%[0-9]+]] %out_var_PRIM_USER %out_var_PRIM_USER_ARR
// CHECK: OpExecutionMode %main LocalSize 128 1 1
// CHECK: OpExecutionMode %main OutputTrianglesNV
// CHECK: OpExecutionMode %main OutputVertices 64
Expand Down Expand Up @@ -37,8 +37,7 @@

// CHECK: %gl_ClipDistance = OpVariable %_ptr_Output__arr__arr_float_uint_5_uint_64 Output
// CHECK: %gl_CullDistance = OpVariable %_ptr_Output__arr__arr_float_uint_3_uint_64 Output
// CHECK: %in_var_dummy = OpVariable %_ptr_TaskPayloadWorkgroupEXT__arr_float_uint_10 TaskPayloadWorkgroupEXT
// CHECK: %in_var_pos = OpVariable %_ptr_TaskPayloadWorkgroupEXT_v4float TaskPayloadWorkgroupEXT
// CHECK: %in_var_pld = OpVariable %_ptr_TaskPayloadWorkgroupEXT_MeshPayload TaskPayloadWorkgroupEXT
// CHECK: %gl_GlobalInvocationID = OpVariable %_ptr_Input_v3uint Input
// CHECK: %gl_LocalInvocationIndex = OpVariable %_ptr_Input_uint Input
// CHECK: %gl_Position = OpVariable %_ptr_Output__arr_v4float_uint_64 Output
Expand Down

0 comments on commit 1e03d65

Please sign in to comment.