Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 4 additions & 85 deletions source/slang/slang-emit-spirv-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,98 +74,17 @@ SpvInst* emitOpEntryPoint(
}

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpExecutionMode
template<typename T>
template<typename... Operands, typename T>
SpvInst* emitOpExecutionMode(
SpvInstParent* parent,
IRInst* inst,
const T& entryPoint,
SpvExecutionMode mode
)
{
static_assert(isSingular<T>);
return emitInst(parent, inst, SpvOpExecutionMode, entryPoint, mode);
}

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpExecutionMode
template<typename T>
SpvInst* emitOpExecutionModeLocalSize(
SpvInstParent* parent,
IRInst* inst,
const T& entryPoint,
const SpvLiteralInteger& xSize,
const SpvLiteralInteger& ySize,
const SpvLiteralInteger& zSize
SpvExecutionMode mode,
const Operands& ...ops
)
{
static_assert(isSingular<T>);
return emitInst(
parent, inst, SpvOpExecutionMode, entryPoint, SpvExecutionModeLocalSize, xSize, ySize, zSize
);
}

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpExecutionMode
template<typename T1, typename T2, typename T3, typename T4>
SpvInst* emitOpExecutionModeLocalSizeId(
SpvInstParent* parent,
IRInst* inst,
const T1& entryPoint,
const T2& xSize,
const T3& ySize,
const T4& zSize
)
{
static_assert(isSingular<T1>);
static_assert(isSingular<T2>);
static_assert(isSingular<T3>);
static_assert(isSingular<T4>);
return emitInst(
parent, inst, SpvOpExecutionMode, entryPoint, SpvExecutionModeLocalSizeId, xSize, ySize, zSize
);
}

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpExecutionMode
template<typename T>
SpvInst* emitOpExecutionModeOutputVertices(
SpvInstParent* parent,
IRInst* inst,
const T& entryPoint,
const SpvLiteralInteger& vertexCount
)
{
static_assert(isSingular<T>);
return emitInst(
parent, inst, SpvOpExecutionMode, entryPoint, SpvExecutionModeOutputVertices, vertexCount
);
}

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpExecutionMode
template<typename T>
SpvInst* emitOpExecutionModeOutputPrimitivesEXT(
SpvInstParent* parent,
IRInst* inst,
const T& entryPoint,
const SpvLiteralInteger& primitiveCount
)
{
static_assert(isSingular<T>);
return emitInst(
parent, inst, SpvOpExecutionMode, entryPoint, SpvExecutionModeOutputPrimitivesEXT, primitiveCount
);
}

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpExecutionMode
template<typename T>
SpvInst* emitOpExecutionModeInvocations(
SpvInstParent* parent,
IRInst* inst,
const T& entryPoint,
const SpvLiteralInteger& invocations
)
{
static_assert(isSingular<T>);
return emitInst(
parent, inst, SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations, invocations
);
return emitInst(parent, inst, SpvOpExecutionMode, entryPoint, mode, ops...);
}

// https://registry.khronos.org/SPIR-V/specs/unified1/SPIRV.html#OpCapability
Expand Down
91 changes: 57 additions & 34 deletions source/slang/slang-emit-spirv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2765,14 +2765,14 @@ struct SPIRVEmitContext
if (isQuad)
{
verifyComputeDerivativeGroupModifiers(this->m_sink, inst->sourceLoc, true, false, numThreadsDecor);
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), nullptr, entryPoint, SpvExecutionModeDerivativeGroupQuadsNV);
emitOpCapability(getSection(SpvLogicalSectionID::Capabilities), nullptr, SpvCapabilityComputeDerivativeGroupQuadsNV);
requireSPIRVExecutionMode(nullptr, getIRInstSpvID(entryPoint), SpvExecutionModeDerivativeGroupQuadsNV);
requireSPIRVCapability(SpvCapabilityComputeDerivativeGroupQuadsNV);
}
else
{
verifyComputeDerivativeGroupModifiers(this->m_sink, inst->sourceLoc, false, true, numThreadsDecor);
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), nullptr, entryPoint, SpvExecutionModeDerivativeGroupLinearNV);
emitOpCapability(getSection(SpvLogicalSectionID::Capabilities), nullptr, SpvCapabilityComputeDerivativeGroupLinearNV);
requireSPIRVExecutionMode(nullptr, getIRInstSpvID(entryPoint), SpvExecutionModeDerivativeGroupLinearNV);
requireSPIRVCapability(SpvCapabilityComputeDerivativeGroupLinearNV);
}
}

Expand All @@ -2790,7 +2790,7 @@ struct SPIRVEmitContext
case kIROp_BeginFragmentShaderInterlock:
ensureExtensionDeclaration(UnownedStringSlice("SPV_EXT_fragment_shader_interlock"));
requireSPIRVCapability(SpvCapabilityFragmentShaderPixelInterlockEXT);
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), nullptr, getParentFunc(inst), SpvExecutionModePixelInterlockOrderedEXT);
requireSPIRVExecutionMode(nullptr, getIRInstSpvID(getParentFunc(inst)), SpvExecutionModePixelInterlockOrderedEXT);
result = emitOpBeginInvocationInterlockEXT(parent, inst);
break;
case kIROp_EndFragmentShaderInterlock:
Expand Down Expand Up @@ -3130,10 +3130,7 @@ struct SPIRVEmitContext
if (mode == SpvExecutionModeMax)
return;

emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes),
nullptr,
entryPoint,
mode);
requireSPIRVExecutionMode(nullptr, getIRInstSpvID(entryPoint), mode);
}

// Make user type name conform to `SPV_GOOGLE_user_type` spec.
Expand Down Expand Up @@ -3252,14 +3249,14 @@ struct SPIRVEmitContext
{
case Stage::Fragment:
//OpExecutionMode %main OriginUpperLeft
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), nullptr, dstID, SpvExecutionModeOriginUpperLeft);
requireSPIRVExecutionMode(nullptr, getIRInstSpvID(entryPoint), SpvExecutionModeOriginUpperLeft);
maybeEmitEntryPointDepthReplacingExecutionMode(entryPoint, referencedBuiltinIRVars);
for (auto decor : entryPoint->getDecorations())
{
switch (decor->getOp())
{
case kIROp_EarlyDepthStencilDecoration:
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), nullptr, dstID, SpvExecutionModeEarlyFragmentTests);
requireSPIRVExecutionMode(nullptr, getIRInstSpvID(entryPoint), SpvExecutionModeEarlyFragmentTests);
break;
default:
break;
Expand Down Expand Up @@ -3293,8 +3290,6 @@ struct SPIRVEmitContext
// [3.6. Execution Mode]: LocalSize
case kIROp_NumThreadsDecoration:
{
auto section = getSection(SpvLogicalSectionID::ExecutionModes);

// TODO: The `LocalSize` execution mode option requires
// literal values for the X,Y,Z thread-group sizes.
// There is a `LocalSizeId` variant that takes `<id>`s
Expand All @@ -3305,10 +3300,10 @@ struct SPIRVEmitContext
// in those positions in the Slang IR).
//
auto numThreads = cast<IRNumThreadsDecoration>(decoration);
emitOpExecutionModeLocalSize(
section,
requireSPIRVExecutionMode(
decoration,
dstID,
SpvExecutionModeLocalSize,
SpvLiteralInteger::from32(int32_t(numThreads->getX()->getValue())),
SpvLiteralInteger::from32(int32_t(numThreads->getY()->getValue())),
SpvLiteralInteger::from32(int32_t(numThreads->getZ()->getValue()))
Expand All @@ -3324,8 +3319,7 @@ struct SPIRVEmitContext
{
auto decor = as<IRInstanceDecoration>(decoration);
auto count = int32_t(getIntVal(decor->getCount()));
auto section = getSection(SpvLogicalSectionID::ExecutionModes);
emitOpExecutionModeInvocations(section, decoration, dstID, SpvLiteralInteger::from32(count));
requireSPIRVExecutionMode(decoration, dstID, SpvExecutionModeInvocations, SpvLiteralInteger::from32(count));
}
break;
case kIROp_TriangleInputPrimitiveTypeDecoration:
Expand All @@ -3343,31 +3337,30 @@ struct SPIRVEmitContext
switch (inputDecor->getOp())
{
case kIROp_TriangleInputPrimitiveTypeDecoration:
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), inputDecor, dstID, SpvExecutionModeTriangles);
requireSPIRVExecutionMode(inputDecor, dstID, SpvExecutionModeTriangles);
break;
case kIROp_LineInputPrimitiveTypeDecoration:
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), inputDecor, dstID, SpvExecutionModeInputLines);
requireSPIRVExecutionMode(inputDecor, dstID, SpvExecutionModeInputLines);
break;
case kIROp_LineAdjInputPrimitiveTypeDecoration:
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), inputDecor, dstID, SpvExecutionModeInputLinesAdjacency);
requireSPIRVExecutionMode(inputDecor, dstID, SpvExecutionModeInputLinesAdjacency);
break;
case kIROp_PointInputPrimitiveTypeDecoration:
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), inputDecor, dstID, SpvExecutionModeInputPoints);
requireSPIRVExecutionMode(inputDecor, dstID, SpvExecutionModeInputPoints);
break;
case kIROp_TriangleAdjInputPrimitiveTypeDecoration:
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), inputDecor, dstID, SpvExecutionModeInputTrianglesAdjacency);
requireSPIRVExecutionMode(inputDecor, dstID, SpvExecutionModeInputTrianglesAdjacency);
break;
}
}
// SPIRV requires MaxVertexCount decoration to appear before OutputTopologyDecoration,
// so we emit them here.
if (auto maxVertexCount = decoration->getParent()->findDecoration<IRMaxVertexCountDecoration>())
{
auto section = getSection(SpvLogicalSectionID::ExecutionModes);
emitOpExecutionModeOutputVertices(
section,
requireSPIRVExecutionMode(
maxVertexCount,
dstID,
SpvExecutionModeOutputVertices,
SpvLiteralInteger::from32(int32_t(getIntVal(maxVertexCount->getCount())))
);
}
Expand All @@ -3378,13 +3371,13 @@ struct SPIRVEmitContext
switch (type->getOp())
{
case kIROp_HLSLPointStreamType:
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), decoration, dstID, SpvExecutionModeOutputPoints);
requireSPIRVExecutionMode(decoration, dstID, SpvExecutionModeOutputPoints);
break;
case kIROp_HLSLLineStreamType:
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), decoration, dstID, SpvExecutionModeOutputLineStrip);
requireSPIRVExecutionMode(decoration, dstID, SpvExecutionModeOutputLineStrip);
break;
case kIROp_HLSLTriangleStreamType:
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), decoration, dstID, SpvExecutionModeOutputTriangleStrip);
requireSPIRVExecutionMode(decoration, dstID, SpvExecutionModeOutputTriangleStrip);
break;
default: SLANG_ASSERT(!"Unknown stream out type");
}
Expand Down Expand Up @@ -3433,17 +3426,17 @@ struct SPIRVEmitContext
: t == "point" ? SpvExecutionModeOutputPoints
: SpvExecutionModeMax;
SLANG_ASSERT(m != SpvExecutionModeMax);
emitOpExecutionMode(getSection(SpvLogicalSectionID::ExecutionModes), decoration, dstID, m);
requireSPIRVExecutionMode(decoration, dstID, m);
}
break;

case kIROp_VerticesDecoration:
{
const auto c = cast<IRVerticesDecoration>(decoration);
emitOpExecutionModeOutputVertices(
getSection(SpvLogicalSectionID::ExecutionModes),
requireSPIRVExecutionMode(
decoration,
dstID,
SpvExecutionModeOutputVertices,
SpvLiteralInteger::from32(int32_t(c->getMaxSize()->getValue()))
);
}
Expand All @@ -3452,10 +3445,10 @@ struct SPIRVEmitContext
case kIROp_PrimitivesDecoration:
{
const auto c = cast<IRPrimitivesDecoration>(decoration);
emitOpExecutionModeOutputPrimitivesEXT(
getSection(SpvLogicalSectionID::ExecutionModes),
requireSPIRVExecutionMode(
decoration,
dstID,
SpvExecutionModeOutputPrimitivesEXT,
SpvLiteralInteger::from32(int32_t(c->getMaxSize()->getValue()))
);
}
Expand Down Expand Up @@ -6136,7 +6129,6 @@ struct SPIRVEmitContext
}

OrderedHashSet<SpvCapability> m_capabilities;

void requireSPIRVCapability(SpvCapability capability)
{
if (m_capabilities.add(capability))
Expand All @@ -6149,6 +6141,37 @@ struct SPIRVEmitContext
}
}

Dictionary<SpvWord, OrderedHashSet<SpvExecutionMode>> m_executionModes;
template<typename... Operands>
void requireSPIRVExecutionMode(IRInst* parentInst, SpvWord entryPoint, SpvExecutionMode executionMode, const Operands& ...ops)
{
if (m_executionModes[entryPoint].add(executionMode))
{
emitOpExecutionMode(
getSection(SpvLogicalSectionID::ExecutionModes),
parentInst,
entryPoint,
executionMode,
ops...
);
Copy link
Collaborator

@jkwak-work jkwak-work May 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can replace this lines to

emitInst(getSection(SpvLogicalSectionID::ExecutionModes),
    parentInst,
    SpvOpExecutionMode,
    entryPoint,
    executionMode,
    ops...);

and remove the function emitOpExecutionMode().
Because the function is used only once here.
And the body is not doing anything much.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I can remove it for now.
If someone else has a reason for the design to keep the wrapper, it should be trivial to add back anyways.

}
}

template<typename T1, typename T2, typename T3>
SpvInst* emitOpExecutionModeLocalSizeId(
IRInst* inst,
SpvWord entryPoint,
const T1& xSize,
const T2& ySize,
const T3& zSize
)
{
static_assert(isSingular<T1>);
static_assert(isSingular<T2>);
static_assert(isSingular<T3>);
requireSPIRVExecutionMode(inst, entryPoint, SpvExecutionModeLocalSizeId, xSize, ySize, zSize);
}

SPIRVEmitContext(IRModule* module, TargetProgram* program, DiagnosticSink* sink)
: SPIRVEmitSharedContext(module, program, sink)
, m_irModule(module)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@
// CHECK_GLSL_LINEAR_C: layout(derivative_group_linearNV)

//TEST:SIMPLE(filecheck=CHECK_SPV_QUAD_C): -allow-glsl -stage compute -entry computeMain -target spirv -DQUAD -DCOMPUTE
// CHECK_SPV_QUAD_C: DerivativeGroupQuadsNV
// CHECK_SPV_QUAD_C: "SPV_NV_compute_shader_derivatives"
// CHECK_SPV_QUAD_C-COUNT-1: DerivativeGroupQuadsNV
// CHECK_SPV_QUAD_C-COUNT-1: "SPV_NV_compute_shader_derivatives"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think -COUNT-1 will do what you want.
You will need to do,

// CHECK_SPV_QUAD_C: DerivativeGroupQuadsNV
// CHECK_SPV_QUAD_C-NOT: DerivativeGroupQuadsNV
// CHECK_SPV_QUAD_C: "SPV_NV_compute_shader_derivatives"
// CHECK_SPV_QUAD_C-NOT: "SPV_NV_compute_shader_derivatives"

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I checked-- You are correct.
Thank-you for catching this.


//TEST:SIMPLE(filecheck=CHECK_SPV_LINEAR_C): -allow-glsl -stage compute -entry computeMain -target spirv -DLINEAR -DCOMPUTE
// CHECK_SPV_LINEAR_C: DerivativeGroupLinearNV
// CHECK_SPV_LINEAR_C: "SPV_NV_compute_shader_derivatives"
// CHECK_SPV_LINEAR_C-COUNT-1: DerivativeGroupLinearNV
// CHECK_SPV_LINEAR_C-COUNT-1: "SPV_NV_compute_shader_derivatives"

//TEST:SIMPLE(filecheck=CHECK_HLSL_C): -allow-glsl -stage compute -entry computeMain -target hlsl -DCOMPUTE
// CHECK_HLSL_C: computeMain(
Expand Down