Updated spirv-cross.

Бранимир Караџић 2023-07-01 08:37:04 -07:00
parent 31f48d7de8
commit ec6fe83232
2 changed files with 86 additions and 39 deletions
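
With this update, SPIRV-Cross's MSL backend handles the legacy SPV_KHR_shader_ballot and SPV_KHR_subgroup_vote opcodes (OpSubgroupBallotKHR, OpSubgroupFirstInvocationKHR, OpSubgroupReadInvocationKHR, OpSubgroupAllKHR, OpSubgroupAnyKHR, OpSubgroupAllEqualKHR). Unlike the newer OpGroupNonUniform* instructions, these older opcodes carry no explicit Scope operand, so fixed operand indices (ops[2] for the scope, ops[3]/ops[4] for the arguments) no longer work for every case; the diff threads a running op_idx cursor through emit_subgroup_op instead. A minimal sketch of that decoding idea, with hypothetical Op/Decoded types standing in for the real SPIRV-Cross machinery:

#include <cstdint>

// Hypothetical, simplified model of the operand-cursor change; not the
// actual SPIRV-Cross sources.
enum class Op { GroupNonUniformBallot, SubgroupBallotKHR };
enum Scope : uint32_t { ScopeSubgroup = 3 };

struct Decoded
{
	uint32_t result_type; // ops[0]
	uint32_t id;          // ops[1]
	Scope scope;
	uint32_t first_arg;   // index of the first value operand
};

Decoded decode_subgroup_op(Op op, const uint32_t *ops)
{
	uint32_t op_idx = 0;
	Decoded d = {};
	d.result_type = ops[op_idx++];
	d.id = ops[op_idx++];
	if (op == Op::SubgroupBallotKHR)
		d.scope = ScopeSubgroup; // legacy op: scope is implicitly the subgroup
	else
		d.scope = static_cast<Scope>(ops[op_idx++]); // simplified: the real code resolves a constant id
	d.first_arg = op_idx; // ops[2] for the legacy op, ops[3] otherwise
	return d;
}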

3rdparty/spirv-cross/spirv_hlsl.cpp

@@ -6056,6 +6056,11 @@ void CompilerHLSL::emit_instruction(const Instruction &instruction)
break;
}
+ case OpAtomicFAddEXT:
+ case OpAtomicFMinEXT:
+ case OpAtomicFMaxEXT:
+ SPIRV_CROSS_THROW("Floating-point atomics are not supported in HLSL.");
case OpAtomicCompareExchange:
case OpAtomicExchange:
case OpAtomicISub:
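
The hunk above makes the HLSL backend reject the floating-point atomic opcodes from SPV_EXT_shader_atomic_float, for which HLSL offers no equivalent; compilation now fails with a clear error rather than falling through to the integer-atomic cases. A hedged sketch of how a caller would observe this (the cross_compile wrapper is illustrative, not bgfx code):

#include <spirv_hlsl.hpp>

#include <cstdio>
#include <string>
#include <vector>

// Illustrative wrapper: SPIRV_CROSS_THROW raises spirv_cross::CompilerError
// in the default (exception-enabled) build of SPIRV-Cross.
std::string cross_compile(std::vector<uint32_t> spirv_words)
{
	spirv_cross::CompilerHLSL hlsl(std::move(spirv_words));
	try
	{
		return hlsl.compile();
	}
	catch (const spirv_cross::CompilerError &e)
	{
		// e.g. "Floating-point atomics are not supported in HLSL."
		std::fprintf(stderr, "%s\n", e.what());
		return std::string();
	}
}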

3rdparty/spirv-cross/spirv_msl.cpp

@@ -9137,6 +9137,16 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
break;
}
+ // Legacy sub-group stuff ...
+ case OpSubgroupBallotKHR:
+ case OpSubgroupFirstInvocationKHR:
+ case OpSubgroupReadInvocationKHR:
+ case OpSubgroupAllKHR:
+ case OpSubgroupAnyKHR:
+ case OpSubgroupAllEqualKHR:
+ emit_subgroup_op(instruction);
+ break;
// SPV_INTEL_shader_integer_functions2
case OpUCountLeadingZerosINTEL:
MSL_UFOP(clz);
@@ -15149,6 +15159,10 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
case OpGroupNonUniformBallotFindLSB:
case OpGroupNonUniformBallotFindMSB:
case OpGroupNonUniformBallotBitCount:
+ case OpSubgroupBallotKHR:
+ case OpSubgroupAllKHR:
+ case OpSubgroupAnyKHR:
+ case OpSubgroupAllEqualKHR:
if (!msl_options.supports_msl_version(2, 2))
SPIRV_CROSS_THROW("Ballot ops on iOS requires Metal 2.2 and up.");
break;
@@ -15159,6 +15173,7 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
case OpGroupNonUniformShuffleDown:
case OpGroupNonUniformQuadSwap:
case OpGroupNonUniformQuadBroadcast:
+ case OpSubgroupReadInvocationKHR:
break;
}
}
@@ -15174,14 +15189,31 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
case OpGroupNonUniformShuffleXor:
case OpGroupNonUniformShuffleUp:
case OpGroupNonUniformShuffleDown:
+ case OpSubgroupReadInvocationKHR:
break;
}
}
- uint32_t result_type = ops[0];
- uint32_t id = ops[1];
+ uint32_t op_idx = 0;
+ uint32_t result_type = ops[op_idx++];
+ uint32_t id = ops[op_idx++];
- auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
+ Scope scope;
+ switch (op)
+ {
+ case OpSubgroupBallotKHR:
+ case OpSubgroupFirstInvocationKHR:
+ case OpSubgroupReadInvocationKHR:
+ case OpSubgroupAllKHR:
+ case OpSubgroupAnyKHR:
+ case OpSubgroupAllEqualKHR:
+ // These earlier instructions don't have the scope operand.
+ scope = ScopeSubgroup;
+ break;
+ default:
+ scope = static_cast<Scope>(evaluate_constant_u32(ops[op_idx++]));
+ break;
+ }
if (scope != ScopeSubgroup)
SPIRV_CROSS_THROW("Only subgroup scope is supported.");
@@ -15195,47 +15227,50 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
break;
case OpGroupNonUniformBroadcast:
- emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBroadcast");
+ case OpSubgroupReadInvocationKHR:
+ emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupBroadcast");
break;
case OpGroupNonUniformBroadcastFirst:
- emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBroadcastFirst");
+ case OpSubgroupFirstInvocationKHR:
+ emit_unary_func_op(result_type, id, ops[op_idx], "spvSubgroupBroadcastFirst");
break;
case OpGroupNonUniformBallot:
- emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
+ case OpSubgroupBallotKHR:
+ emit_unary_func_op(result_type, id, ops[op_idx], "spvSubgroupBallot");
break;
case OpGroupNonUniformInverseBallot:
- emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
+ emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
break;
case OpGroupNonUniformBallotBitExtract:
- emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
+ emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupBallotBitExtract");
break;
case OpGroupNonUniformBallotFindLSB:
- emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB");
+ emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB");
break;
case OpGroupNonUniformBallotFindMSB:
- emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB");
+ emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB");
break;
case OpGroupNonUniformBallotBitCount:
{
- auto operation = static_cast<GroupOperation>(ops[3]);
+ auto operation = static_cast<GroupOperation>(ops[op_idx++]);
switch (operation)
{
case GroupOperationReduce:
- emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_size_id, "spvSubgroupBallotBitCount");
+ emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_size_id, "spvSubgroupBallotBitCount");
break;
case GroupOperationInclusiveScan:
- emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
+ emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_invocation_id_id,
"spvSubgroupBallotInclusiveBitCount");
break;
case GroupOperationExclusiveScan:
- emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
+ emit_binary_func_op(result_type, id, ops[op_idx], builtin_subgroup_invocation_id_id,
"spvSubgroupBallotExclusiveBitCount");
break;
default:
@@ -15245,57 +15280,60 @@ void CompilerMSL::emit_subgroup_op(const Instruction &i)
}
case OpGroupNonUniformShuffle:
- emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffle");
+ emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupShuffle");
break;
case OpGroupNonUniformShuffleXor:
- emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleXor");
+ emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupShuffleXor");
break;
case OpGroupNonUniformShuffleUp:
- emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleUp");
+ emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupShuffleUp");
break;
case OpGroupNonUniformShuffleDown:
- emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleDown");
+ emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvSubgroupShuffleDown");
break;
case OpGroupNonUniformAll:
+ case OpSubgroupAllKHR:
if (msl_options.use_quadgroup_operation())
- emit_unary_func_op(result_type, id, ops[3], "quad_all");
+ emit_unary_func_op(result_type, id, ops[op_idx], "quad_all");
else
- emit_unary_func_op(result_type, id, ops[3], "simd_all");
+ emit_unary_func_op(result_type, id, ops[op_idx], "simd_all");
break;
case OpGroupNonUniformAny:
+ case OpSubgroupAnyKHR:
if (msl_options.use_quadgroup_operation())
- emit_unary_func_op(result_type, id, ops[3], "quad_any");
+ emit_unary_func_op(result_type, id, ops[op_idx], "quad_any");
else
- emit_unary_func_op(result_type, id, ops[3], "simd_any");
+ emit_unary_func_op(result_type, id, ops[op_idx], "simd_any");
break;
case OpGroupNonUniformAllEqual:
- emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
+ case OpSubgroupAllEqualKHR:
+ emit_unary_func_op(result_type, id, ops[op_idx], "spvSubgroupAllEqual");
break;
// clang-format off
#define MSL_GROUP_OP(op, msl_op) \
case OpGroupNonUniform##op: \
{ \
- auto operation = static_cast<GroupOperation>(ops[3]); \
+ auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
if (operation == GroupOperationReduce) \
- emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
+ emit_unary_func_op(result_type, id, ops[op_idx], "simd_" #msl_op); \
else if (operation == GroupOperationInclusiveScan) \
- emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
+ emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_inclusive_" #msl_op); \
else if (operation == GroupOperationExclusiveScan) \
- emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
+ emit_unary_func_op(result_type, id, ops[op_idx], "simd_prefix_exclusive_" #msl_op); \
else if (operation == GroupOperationClusteredReduce) \
{ \
/* Only cluster sizes of 4 are supported. */ \
- uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
+ uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
if (cluster_size != 4) \
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
- emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
+ emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
} \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
@@ -15311,9 +15349,9 @@ case OpGroupNonUniform##op: \
#define MSL_GROUP_OP(op, msl_op) \
case OpGroupNonUniform##op: \
{ \
- auto operation = static_cast<GroupOperation>(ops[3]); \
+ auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
if (operation == GroupOperationReduce) \
- emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
+ emit_unary_func_op(result_type, id, ops[op_idx], "simd_" #msl_op); \
else if (operation == GroupOperationInclusiveScan) \
SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
else if (operation == GroupOperationExclusiveScan) \
@@ -15321,10 +15359,10 @@ case OpGroupNonUniform##op: \
else if (operation == GroupOperationClusteredReduce) \
{ \
/* Only cluster sizes of 4 are supported. */ \
- uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
+ uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
if (cluster_size != 4) \
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
- emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
+ emit_unary_func_op(result_type, id, ops[op_idx], "quad_" #msl_op); \
} \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
@@ -15334,9 +15372,9 @@ case OpGroupNonUniform##op: \
#define MSL_GROUP_OP_CAST(op, msl_op, type) \
case OpGroupNonUniform##op: \
{ \
- auto operation = static_cast<GroupOperation>(ops[3]); \
+ auto operation = static_cast<GroupOperation>(ops[op_idx++]); \
if (operation == GroupOperationReduce) \
- emit_unary_func_op_cast(result_type, id, ops[4], "simd_" #msl_op, type, type); \
+ emit_unary_func_op_cast(result_type, id, ops[op_idx], "simd_" #msl_op, type, type); \
else if (operation == GroupOperationInclusiveScan) \
SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
else if (operation == GroupOperationExclusiveScan) \
@@ -15344,10 +15382,10 @@ case OpGroupNonUniform##op: \
else if (operation == GroupOperationClusteredReduce) \
{ \
/* Only cluster sizes of 4 are supported. */ \
- uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
+ uint32_t cluster_size = evaluate_constant_u32(ops[op_idx + 1]); \
if (cluster_size != 4) \
SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
- emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
+ emit_unary_func_op_cast(result_type, id, ops[op_idx], "quad_" #msl_op, type, type); \
} \
else \
SPIRV_CROSS_THROW("Invalid group operation."); \
@@ -15371,11 +15409,11 @@ case OpGroupNonUniform##op: \
#undef MSL_GROUP_OP_CAST
case OpGroupNonUniformQuadSwap:
- emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadSwap");
+ emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvQuadSwap");
break;
case OpGroupNonUniformQuadBroadcast:
- emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadBroadcast");
+ emit_binary_func_op(result_type, id, ops[op_idx], ops[op_idx + 1], "spvQuadBroadcast");
break;
default:
@@ -16644,12 +16682,15 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
}
case OpGroupNonUniformBroadcast:
+ case OpSubgroupReadInvocationKHR:
return SPVFuncImplSubgroupBroadcast;
case OpGroupNonUniformBroadcastFirst:
+ case OpSubgroupFirstInvocationKHR:
return SPVFuncImplSubgroupBroadcastFirst;
case OpGroupNonUniformBallot:
+ case OpSubgroupBallotKHR:
return SPVFuncImplSubgroupBallot;
case OpGroupNonUniformInverseBallot:
@@ -16666,6 +16707,7 @@ CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op o
return SPVFuncImplSubgroupBallotBitCount;
case OpGroupNonUniformAllEqual:
+ case OpSubgroupAllEqualKHR:
return SPVFuncImplSubgroupAllEqual;
case OpGroupNonUniformShuffle:
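
As the supports_msl_version(2, 2) check above implies, shaders that use these ballot-style ops have to be compiled with MSL 2.2 or newer as the target. A sketch using SPIRV-Cross's public options API (illustrative, not how bgfx drives it):

#include <spirv_msl.hpp>

#include <string>
#include <vector>

// Illustrative: raise the MSL target so ballot-style subgroup ops pass the
// supports_msl_version(2, 2) check in emit_subgroup_op.
std::string cross_compile_msl(std::vector<uint32_t> spirv_words)
{
	spirv_cross::CompilerMSL msl(std::move(spirv_words));
	spirv_cross::CompilerMSL::Options opts = msl.get_msl_options();
	opts.set_msl_version(2, 2); // Metal Shading Language 2.2
	msl.set_msl_options(opts);
	return msl.compile();
}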