Updated spirv-cross.

Бранимир Караџић 2024-08-30 20:21:18 -07:00
parent 7cda7c988f
commit ec4220ae44
9 changed files with 361 additions and 105 deletions

View File

@ -1850,6 +1850,11 @@ const SmallVector<SPIRBlock::Case> &Compiler::get_case_list(const SPIRBlock &blo
const auto &type = get<SPIRType>(constant->constant_type);
width = type.width;
}
else if (const auto *op = maybe_get<SPIRConstantOp>(block.condition))
{
const auto &type = get<SPIRType>(op->basetype);
width = type.width;
}
else if (const auto *var = maybe_get<SPIRVariable>(block.condition))
{
const auto &type = get<SPIRType>(var->basetype);

View File

@ -516,6 +516,10 @@ spvc_result spvc_compiler_options_set_uint(spvc_compiler_options options, spvc_c
case SPVC_COMPILER_OPTION_HLSL_FLATTEN_MATRIX_VERTEX_INPUT_SEMANTICS:
options->hlsl.flatten_matrix_vertex_input_semantics = value != 0;
break;
case SPVC_COMPILER_OPTION_HLSL_USE_ENTRY_POINT_NAME:
options->hlsl.use_entry_point_name = value != 0;
break;
#endif
#if SPIRV_CROSS_C_API_MSL
@ -1355,6 +1359,34 @@ spvc_result spvc_compiler_msl_add_resource_binding(spvc_compiler compiler,
#endif
}
spvc_result spvc_compiler_msl_add_resource_binding_2(spvc_compiler compiler,
const spvc_msl_resource_binding_2 *binding)
{
#if SPIRV_CROSS_C_API_MSL
if (compiler->backend != SPVC_BACKEND_MSL)
{
compiler->context->report_error("MSL function used on a non-MSL backend.");
return SPVC_ERROR_INVALID_ARGUMENT;
}
auto &msl = *static_cast<CompilerMSL *>(compiler->compiler.get());
MSLResourceBinding bind;
bind.binding = binding->binding;
bind.desc_set = binding->desc_set;
bind.stage = static_cast<spv::ExecutionModel>(binding->stage);
bind.msl_buffer = binding->msl_buffer;
bind.msl_texture = binding->msl_texture;
bind.msl_sampler = binding->msl_sampler;
bind.count = binding->count;
msl.add_msl_resource_binding(bind);
return SPVC_SUCCESS;
#else
(void)binding;
compiler->context->report_error("MSL function used on a non-MSL backend.");
return SPVC_ERROR_INVALID_ARGUMENT;
#endif
}
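A minimal usage sketch of the new binding API (hedged: assumes an existing spvc_compiler created with the MSL backend; the field values are illustrative):

spvc_msl_resource_binding_2 binding;
spvc_msl_resource_binding_init_2(&binding); /* MSL slot defaults; count starts at 0 */
binding.stage = SpvExecutionModelVertex;
binding.desc_set = 0;
binding.binding = 1;
binding.count = 4;      /* array size hint; the field the old struct lacked */
binding.msl_buffer = 3; /* maps to [[buffer(3)]] */
if (spvc_compiler_msl_add_resource_binding_2(compiler, &binding) != SPVC_SUCCESS)
    ; /* SPVC_ERROR_INVALID_ARGUMENT on non-MSL backends */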
spvc_result spvc_compiler_msl_add_dynamic_buffer(spvc_compiler compiler, unsigned desc_set, unsigned binding, unsigned index)
{
#if SPIRV_CROSS_C_API_MSL
@ -2811,6 +2843,22 @@ void spvc_msl_resource_binding_init(spvc_msl_resource_binding *binding)
#endif
}
void spvc_msl_resource_binding_init_2(spvc_msl_resource_binding_2 *binding)
{
#if SPIRV_CROSS_C_API_MSL
MSLResourceBinding binding_default;
binding->desc_set = binding_default.desc_set;
binding->binding = binding_default.binding;
binding->msl_buffer = binding_default.msl_buffer;
binding->msl_texture = binding_default.msl_texture;
binding->msl_sampler = binding_default.msl_sampler;
binding->stage = static_cast<SpvExecutionModel>(binding_default.stage);
binding->count = 0;
#else
memset(binding, 0, sizeof(*binding));
#endif
}
void spvc_hlsl_resource_binding_init(spvc_hlsl_resource_binding *binding)
{
#if SPIRV_CROSS_C_API_HLSL

View File

@ -40,7 +40,7 @@ extern "C" {
/* Bumped if ABI or API breaks backwards compatibility. */
#define SPVC_C_API_VERSION_MAJOR 0
/* Bumped if APIs or enumerations are added in a backwards compatible way. */
#define SPVC_C_API_VERSION_MINOR 60
#define SPVC_C_API_VERSION_MINOR 62
/* Bumped if internal implementation details change. */
#define SPVC_C_API_VERSION_PATCH 0
@ -380,7 +380,8 @@ typedef struct spvc_msl_shader_interface_var_2
*/
SPVC_PUBLIC_API void spvc_msl_shader_interface_var_init_2(spvc_msl_shader_interface_var_2 *var);
/* Maps to C++ API. */
/* Maps to C++ API.
* Deprecated. Use spvc_msl_resource_binding_2. */
typedef struct spvc_msl_resource_binding
{
SpvExecutionModel stage;
@ -391,11 +392,24 @@ typedef struct spvc_msl_resource_binding
unsigned msl_sampler;
} spvc_msl_resource_binding;
typedef struct spvc_msl_resource_binding_2
{
SpvExecutionModel stage;
unsigned desc_set;
unsigned binding;
unsigned count;
unsigned msl_buffer;
unsigned msl_texture;
unsigned msl_sampler;
} spvc_msl_resource_binding_2;
/*
* Initializes the resource binding struct.
* The defaults are non-zero.
* Deprecated: Use spvc_msl_resource_binding_init_2.
*/
SPVC_PUBLIC_API void spvc_msl_resource_binding_init(spvc_msl_resource_binding *binding);
SPVC_PUBLIC_API void spvc_msl_resource_binding_init_2(spvc_msl_resource_binding_2 *binding);
#define SPVC_MSL_PUSH_CONSTANT_DESC_SET (~(0u))
#define SPVC_MSL_PUSH_CONSTANT_BINDING (0)
@ -730,6 +744,8 @@ typedef enum spvc_compiler_option
SPVC_COMPILER_OPTION_MSL_AGX_MANUAL_CUBE_GRAD_FIXUP = 88 | SPVC_COMPILER_OPTION_MSL_BIT,
SPVC_COMPILER_OPTION_MSL_FORCE_FRAGMENT_WITH_SIDE_EFFECTS_EXECUTION = 89 | SPVC_COMPILER_OPTION_MSL_BIT,
SPVC_COMPILER_OPTION_HLSL_USE_ENTRY_POINT_NAME = 90 | SPVC_COMPILER_OPTION_HLSL_BIT,
SPVC_COMPILER_OPTION_INT_MAX = 0x7fffffff
} spvc_compiler_option;
@ -836,8 +852,11 @@ SPVC_PUBLIC_API spvc_bool spvc_compiler_msl_needs_patch_output_buffer(spvc_compi
SPVC_PUBLIC_API spvc_bool spvc_compiler_msl_needs_input_threadgroup_mem(spvc_compiler compiler);
SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_vertex_attribute(spvc_compiler compiler,
const spvc_msl_vertex_attribute *attrs);
/* Deprecated; use spvc_compiler_msl_add_resource_binding_2(). */
SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_resource_binding(spvc_compiler compiler,
const spvc_msl_resource_binding *binding);
SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_resource_binding_2(spvc_compiler compiler,
const spvc_msl_resource_binding_2 *binding);
/* Deprecated; use spvc_compiler_msl_add_shader_input_2(). */
SPVC_PUBLIC_API spvc_result spvc_compiler_msl_add_shader_input(spvc_compiler compiler,
const spvc_msl_shader_interface_var *input);

View File

@ -783,6 +783,8 @@ uint32_t ParsedIR::get_member_decoration(TypeID id, uint32_t index, Decoration d
return dec.stream;
case DecorationSpecId:
return dec.spec_id;
case DecorationMatrixStride:
return dec.matrix_stride;
case DecorationIndex:
return dec.index;
default:

View File

@ -5213,7 +5213,8 @@ string CompilerGLSL::to_enclosed_unpacked_expression(uint32_t id, bool register_
string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expression_read)
{
auto &type = expression_type(id);
if (type.pointer && should_dereference(id))
if (is_pointer(type) && should_dereference(id))
return dereference_expression(type, to_enclosed_expression(id, register_expression_read));
else
return to_expression(id, register_expression_read);
@ -5222,7 +5223,7 @@ string CompilerGLSL::to_dereferenced_expression(uint32_t id, bool register_expre
string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression_read)
{
auto &type = expression_type(id);
if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
return address_of_expression(to_enclosed_expression(id, register_expression_read));
else
return to_unpacked_expression(id, register_expression_read);
@ -5231,7 +5232,7 @@ string CompilerGLSL::to_pointer_expression(uint32_t id, bool register_expression
string CompilerGLSL::to_enclosed_pointer_expression(uint32_t id, bool register_expression_read)
{
auto &type = expression_type(id);
if (type.pointer && expression_is_lvalue(id) && !should_dereference(id))
if (is_pointer(type) && expression_is_lvalue(id) && !should_dereference(id))
return address_of_expression(to_enclosed_expression(id, register_expression_read));
else
return to_enclosed_unpacked_expression(id, register_expression_read);
@ -10286,7 +10287,40 @@ string CompilerGLSL::access_chain_internal(uint32_t base, const uint32_t *indice
}
else
{
append_index(index, is_literal, true);
if (flags & ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT)
{
SPIRType tmp_type(OpTypeInt);
tmp_type.basetype = SPIRType::UInt64;
tmp_type.width = 64;
tmp_type.vecsize = 1;
tmp_type.columns = 1;
TypeID ptr_type_id = expression_type_id(base);
const SPIRType &ptr_type = get<SPIRType>(ptr_type_id);
const SPIRType &pointee_type = get_pointee_type(ptr_type);
// This only runs in native pointer backends.
// Can replace reinterpret_cast with a backend string if ever needed.
// We expect this to count as a de-reference.
// This leaks some MSL details, but feels slightly overkill to
// add yet another virtual interface just for this.
auto intptr_expr = join("reinterpret_cast<", type_to_glsl(tmp_type), ">(", expr, ")");
intptr_expr += join(" + ", to_enclosed_unpacked_expression(index), " * ",
get_decoration(ptr_type_id, DecorationArrayStride));
if (flags & ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT)
{
is_packed = true;
expr = join("*reinterpret_cast<device packed_", type_to_glsl(pointee_type),
" *>(", intptr_expr, ")");
}
else
{
expr = join("*reinterpret_cast<", type_to_glsl(ptr_type), ">(", intptr_expr, ")");
}
}
else
append_index(index, is_literal, true);
}
if (type->basetype == SPIRType::ControlPointArray)
@ -10706,6 +10740,11 @@ string CompilerGLSL::to_flattened_struct_member(const string &basename, const SP
return ret;
}
uint32_t CompilerGLSL::get_physical_type_stride(const SPIRType &) const
{
SPIRV_CROSS_THROW("Invalid to call get_physical_type_stride on a backend without native pointer support.");
}
string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32_t count, const SPIRType &target_type,
AccessChainMeta *meta, bool ptr_chain)
{
@ -10755,7 +10794,27 @@ string CompilerGLSL::access_chain(uint32_t base, const uint32_t *indices, uint32
{
AccessChainFlags flags = ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT;
if (ptr_chain)
{
flags |= ACCESS_CHAIN_PTR_CHAIN_BIT;
// PtrAccessChain could get complicated.
TypeID type_id = expression_type_id(base);
if (backend.native_pointers && has_decoration(type_id, DecorationArrayStride))
{
// If there is a mismatch we have to go via 64-bit pointer arithmetic :'(
// Packed hacks only get us so far, and are not designed to deal with pointers to
// arbitrary values. They do work for structs, though.
auto &pointee_type = get_pointee_type(get<SPIRType>(type_id));
uint32_t physical_stride = get_physical_type_stride(pointee_type);
uint32_t requested_stride = get_decoration(type_id, DecorationArrayStride);
if (physical_stride != requested_stride)
{
flags |= ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT;
if (is_vector(pointee_type))
flags |= ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT;
}
}
}
return access_chain_internal(base, indices, count, flags, meta);
}
}
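For illustration (a sketch; the exact spelling depends on the expression involved), a PtrAccessChain ptr[i] whose declared SPIR-V ArrayStride of 12 on a float3 pointee disagrees with MSL's physical stride of 16 would now lower to 64-bit pointer arithmetic rather than a plain index:

// hypothetical emitted MSL; the result is flagged via is_packed
*reinterpret_cast<device packed_float3 *>(reinterpret_cast<ulong>(ptr) + i * 12)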

View File

@ -66,7 +66,9 @@ enum AccessChainFlagBits
ACCESS_CHAIN_SKIP_REGISTER_EXPRESSION_READ_BIT = 1 << 3,
ACCESS_CHAIN_LITERAL_MSB_FORCE_ID = 1 << 4,
ACCESS_CHAIN_FLATTEN_ALL_MEMBERS_BIT = 1 << 5,
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6
ACCESS_CHAIN_FORCE_COMPOSITE_BIT = 1 << 6,
ACCESS_CHAIN_PTR_CHAIN_POINTER_ARITH_BIT = 1 << 7,
ACCESS_CHAIN_PTR_CHAIN_CAST_TO_SCALAR_BIT = 1 << 8
};
typedef uint32_t AccessChainFlags;
@ -753,6 +755,10 @@ protected:
std::string access_chain_internal(uint32_t base, const uint32_t *indices, uint32_t count, AccessChainFlags flags,
AccessChainMeta *meta);
// Only meaningful on backends with physical pointer support ala MSL.
// Relevant for PtrAccessChain / BDA.
virtual uint32_t get_physical_type_stride(const SPIRType &type) const;
spv::StorageClass get_expression_effective_storage_class(uint32_t ptr);
virtual bool access_chain_needs_stage_io_builtin_translation(uint32_t base);

View File

@ -849,11 +849,25 @@ void CompilerHLSL::emit_builtin_inputs_in_struct()
case BuiltInSubgroupLeMask:
case BuiltInSubgroupGtMask:
case BuiltInSubgroupGeMask:
case BuiltInBaseVertex:
case BuiltInBaseInstance:
// Handled specially.
break;
case BuiltInBaseVertex:
if (hlsl_options.shader_model >= 68)
{
type = "uint";
semantic = "SV_StartVertexLocation";
}
break;
case BuiltInBaseInstance:
if (hlsl_options.shader_model >= 68)
{
type = "uint";
semantic = "SV_StartInstanceLocation";
}
break;
case BuiltInHelperInvocation:
if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
SPIRV_CROSS_THROW("Helper Invocation input is only supported in PS 5.0 or higher.");
@ -1231,7 +1245,7 @@ void CompilerHLSL::emit_builtin_variables()
case BuiltInVertexIndex:
case BuiltInInstanceIndex:
type = "int";
if (hlsl_options.support_nonzero_base_vertex_base_instance)
if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68)
base_vertex_info.used = true;
break;
@ -1353,7 +1367,7 @@ void CompilerHLSL::emit_builtin_variables()
}
});
if (base_vertex_info.used)
if (base_vertex_info.used && hlsl_options.shader_model < 68)
{
string binding_info;
if (base_vertex_info.explicit_binding)
@ -3136,23 +3150,39 @@ void CompilerHLSL::emit_hlsl_entry_point()
case BuiltInVertexIndex:
case BuiltInInstanceIndex:
// D3D semantics are uint, but shader wants int.
if (hlsl_options.support_nonzero_base_vertex_base_instance)
if (hlsl_options.support_nonzero_base_vertex_base_instance || hlsl_options.shader_model >= 68)
{
if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
if (hlsl_options.shader_model >= 68)
{
if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
statement(builtin, " = int(stage_input.", builtin, " + stage_input.gl_BaseInstanceARB);");
else
statement(builtin, " = int(stage_input.", builtin, " + stage_input.gl_BaseVertexARB);");
}
else
statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
{
if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
else
statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
}
}
else
statement(builtin, " = int(stage_input.", builtin, ");");
break;
case BuiltInBaseVertex:
statement(builtin, " = SPIRV_Cross_BaseVertex;");
if (hlsl_options.shader_model >= 68)
statement(builtin, " = stage_input.gl_BaseVertexARB;");
else
statement(builtin, " = SPIRV_Cross_BaseVertex;");
break;
case BuiltInBaseInstance:
statement(builtin, " = SPIRV_Cross_BaseInstance;");
if (hlsl_options.shader_model >= 68)
statement(builtin, " = stage_input.gl_BaseInstanceARB;");
else
statement(builtin, " = SPIRV_Cross_BaseInstance;");
break;
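A hedged sketch of the SM 6.8 output this produces for a vertex shader reading gl_VertexIndex (struct layout and semantics follow the statements above; exact naming may differ):

struct SPIRV_Cross_Input
{
    uint gl_VertexIndex : SV_VertexID;
    uint gl_BaseVertexARB : SV_StartVertexLocation;
};

// In the generated entry point; no SPIRV_Cross_BaseVertex uniform is required:
gl_VertexIndex = int(stage_input.gl_VertexIndex + stage_input.gl_BaseVertexARB);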
case BuiltInInstanceId:
@ -6714,6 +6744,15 @@ string CompilerHLSL::compile()
if (need_subpass_input)
active_input_builtins.set(BuiltInFragCoord);
// Need to offset by BaseVertex/BaseInstance in SM 6.8+.
if (hlsl_options.shader_model >= 68)
{
if (active_input_builtins.get(BuiltInVertexIndex))
active_input_builtins.set(BuiltInBaseVertex);
if (active_input_builtins.get(BuiltInInstanceIndex))
active_input_builtins.set(BuiltInBaseInstance);
}
uint32_t pass_count = 0;
do
{

View File

@ -1361,14 +1361,14 @@ void CompilerMSL::emit_entry_point_declarations()
if (is_array(type))
{
if (!type.array[type.array.size() - 1])
SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
is_using_builtin_array = true;
statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, true), name,
type_to_array_glsl(type, var_id), " =");
uint32_t array_size = to_array_size_literal(type);
uint32_t array_size = get_resource_array_size(type, var_id);
if (array_size == 0)
SPIRV_CROSS_THROW("Size of runtime array with dynamic offset could not be determined from resource bindings.");
begin_scope();
for (uint32_t i = 0; i < array_size; i++)
@ -1576,8 +1576,7 @@ string CompilerMSL::compile()
preprocess_op_codes();
build_implicit_builtins();
if (needs_manual_helper_invocation_updates() &&
(active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
if (needs_manual_helper_invocation_updates() && needs_helper_invocation)
{
string builtin_helper_invocation = builtin_to_glsl(BuiltInHelperInvocation, StorageClassInput);
string discard_expr = join(builtin_helper_invocation, " = true, discard_fragment()");
@ -1721,7 +1720,7 @@ void CompilerMSL::preprocess_op_codes()
(is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
(need_subpass_input_ms && !msl_options.use_framebuffer_fetch_subpasses))))
needs_sample_id = true;
if (preproc.needs_helper_invocation)
if (preproc.needs_helper_invocation || active_input_builtins.get(BuiltInHelperInvocation))
needs_helper_invocation = true;
// OpKill is removed by the parser, so we need to identify those by inspecting
@ -2058,8 +2057,7 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
}
case OpDemoteToHelperInvocation:
if (needs_manual_helper_invocation_updates() &&
(active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
if (needs_manual_helper_invocation_updates() && needs_helper_invocation)
added_arg_ids.insert(builtin_helper_invocation_id);
break;
@ -2112,7 +2110,7 @@ void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::
}
if (needs_manual_helper_invocation_updates() && b.terminator == SPIRBlock::Kill &&
(active_input_builtins.get(BuiltInHelperInvocation) || needs_helper_invocation))
needs_helper_invocation)
added_arg_ids.insert(builtin_helper_invocation_id);
// TODO: Add all other operations which can affect memory.
@ -4803,7 +4801,7 @@ bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32
return false;
}
if (!mbr_type.array.empty())
if (is_array(mbr_type))
{
// If we have an array type, array stride must match exactly with SPIR-V.
@ -5615,6 +5613,10 @@ void CompilerMSL::emit_custom_templates()
// otherwise they will cause problems when linked together in a single Metallib.
void CompilerMSL::emit_custom_functions()
{
// Use when outputting overloaded functions to cover different address spaces.
static const char *texture_addr_spaces[] = { "device", "constant", "thread" };
static uint32_t texture_addr_space_count = sizeof(texture_addr_spaces) / sizeof(char*);
if (spv_function_implementations.count(SPVFuncImplArrayCopyMultidim))
spv_function_implementations.insert(SPVFuncImplArrayCopy);
@ -6264,54 +6266,62 @@ void CompilerMSL::emit_custom_functions()
break;
case SPVFuncImplGatherConstOffsets:
statement("// Wrapper function that processes a texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, "
"Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement("switch (c)");
begin_scope();
// Work around texture::gather() requiring its component parameter to be a constant expression
statement("case component::x:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
statement(" break;");
statement("case component::y:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
statement(" break;");
statement("case component::z:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
statement(" break;");
statement("case component::w:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
statement(" break;");
end_scope();
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
// Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
for (uint32_t i = 0; i < texture_addr_space_count; i++)
{
statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherConstOffsets(const ", texture_addr_spaces[i], " Tex<T>& t, sampler s, "
"Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement("switch (c)");
begin_scope();
// Work around texture::gather() requiring its component parameter to be a constant expression
statement("case component::x:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::x);");
statement(" break;");
statement("case component::y:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::y);");
statement(" break;");
statement("case component::z:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::z);");
statement(" break;");
statement("case component::w:");
statement(" rslts[i] = t.gather(s, spvForward<Tp>(params)..., coffsets[i], component::w);");
statement(" break;");
end_scope();
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
}
break;
case SPVFuncImplGatherCompareConstOffsets:
statement("// Wrapper function that processes a texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherCompareConstOffsets(const thread Tex<T>& t, sampler s, "
"Toff coffsets, Tp... params)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement(" rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
// Because we are passing a texture reference, we have to output an overloaded version of this function for each address space.
for (uint32_t i = 0; i < texture_addr_space_count; i++)
{
statement("// Wrapper function that processes a ", texture_addr_spaces[i], " texture gather with a constant offset array.");
statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
"typename Toff, typename... Tp>");
statement("inline vec<T, 4> spvGatherCompareConstOffsets(const ", texture_addr_spaces[i], " Tex<T>& t, sampler s, "
"Toff coffsets, Tp... params)");
begin_scope();
statement("vec<T, 4> rslts[4];");
statement("for (uint i = 0; i < 4; i++)");
begin_scope();
statement(" rslts[i] = t.gather_compare(s, spvForward<Tp>(params)..., coffsets[i]);");
end_scope();
// Pull all values from the i0j0 component of each gather footprint
statement("return vec<T, 4>(rslts[0].w, rslts[1].w, rslts[2].w, rslts[3].w);");
end_scope();
statement("");
}
break;
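Concretely, each loop emits three otherwise-identical overloads (spvGatherConstOffsets shown; spvGatherCompareConstOffsets is analogous), so textures living in device or constant memory, e.g. inside argument buffers, bind without an address-space mismatch:

inline vec<T, 4> spvGatherConstOffsets(const device Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c);
inline vec<T, 4> spvGatherConstOffsets(const constant Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c);
inline vec<T, 4> spvGatherConstOffsets(const thread Tex<T>& t, sampler s, Toff coffsets, component c, Tp... params) METAL_CONST_ARG(c);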
case SPVFuncImplSubgroupBroadcast:
@ -9246,18 +9256,40 @@ void CompilerMSL::emit_instruction(const Instruction &instruction)
uint32_t coord_id = ops[3];
emit_uninitialized_temporary_expression(result_type, id);
std::string coord_expr = to_expression(coord_id);
auto sampler_expr = to_sampler_expression(image_id);
auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);
const SPIRType &image_type = expression_type(image_id);
const SPIRType &coord_type = expression_type(coord_id);
switch (image_type.image.dim)
{
case Dim1D:
if (!msl_options.texture_1D_as_2D)
SPIRV_CROSS_THROW("ImageQueryLod is not supported on 1D textures.");
[[fallthrough]];
case Dim2D:
if (coord_type.vecsize > 2)
coord_expr = enclose_expression(coord_expr) + ".xy";
break;
case DimCube:
case Dim3D:
if (coord_type.vecsize > 3)
coord_expr = enclose_expression(coord_expr) + ".xyz";
break;
default:
SPIRV_CROSS_THROW("Bad image type given to OpImageQueryLod");
}
// TODO: It is unclear if calculate_clamped_lod also conditionally rounds
// the reported LOD based on the sampler. NEAREST miplevel should
// round the LOD, but LINEAR miplevel should not round.
// Let's hope this does not become an issue ...
statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
to_expression(coord_id), ");");
coord_expr, ");");
statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
to_expression(coord_id), ");");
coord_expr, ");");
register_control_dependent_expression(id);
break;
}
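An illustrative result (temporary and sampler names are hypothetical) for a query LOD on a cube texture whose coordinate carries an extra component, which texture::calculate_*_lod would otherwise reject:

_42.x = uTex.calculate_clamped_lod(uTexSmplr, vCoord.xyz);
_42.y = uTex.calculate_unclamped_lod(uTexSmplr, vCoord.xyz);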
@ -12167,21 +12199,26 @@ string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_
string CompilerMSL::to_sampler_expression(uint32_t id)
{
auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
auto expr = to_expression(combined ? combined->image : VariableID(id));
auto index = expr.find_first_of('[');
if (combined && combined->sampler)
return to_expression(combined->sampler);
uint32_t samp_id = 0;
if (combined)
samp_id = combined->sampler;
uint32_t expr_id = combined ? uint32_t(combined->image) : id;
if (index == string::npos)
return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
else
// Constexpr samplers are declared as local variables,
// so exclude any qualifier names on the image expression.
if (auto *var = maybe_get_backing_variable(expr_id))
{
auto image_expr = expr.substr(0, index);
auto array_expr = expr.substr(index);
return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
uint32_t img_id = var->basevariable ? var->basevariable : VariableID(var->self);
if (find_constexpr_sampler(img_id))
return Compiler::to_name(img_id) + sampler_name_suffix;
}
auto img_expr = to_expression(expr_id);
auto index = img_expr.find_first_of('[');
if (index == string::npos)
return img_expr + sampler_name_suffix;
else
return img_expr.substr(0, index) + sampler_name_suffix + img_expr.substr(index);
}
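For example (a sketch assuming the default "Smplr" suffix): with argument buffers, an image expression like spvDescriptorSet0.uTex used to yield spvDescriptorSet0.uTexSmplr, which is invalid when uTex is remapped to a constexpr sampler, since that sampler is a plain local variable:

// before: qualified name, does not resolve to the constexpr sampler local
spvDescriptorSet0.uTexSmplr
// after: unqualified local name
uTexSmplr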
string CompilerMSL::to_swizzle_expression(uint32_t id)
@ -13176,7 +13213,10 @@ string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bo
addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
}
return join(decoration_flags_signal_volatile(flags) ? "volatile " : "", addr_space);
if (decoration_flags_signal_volatile(flags) && 0 != strcmp(addr_space, "thread"))
return join("volatile ", addr_space);
else
return addr_space;
}
const char *CompilerMSL::to_restrict(uint32_t id, bool space)
@ -13602,7 +13642,13 @@ string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
claimed_bindings.set(buffer_binding);
ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);
ep_args += get_argument_address_space(var) + " ";
if (recursive_inputs.count(type.self))
ep_args += string("void* ") + to_restrict(id, true) + to_name(id) + "_vp";
else
ep_args += type_to_glsl(type) + "& " + to_restrict(id, true) + to_name(id);
ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
@ -14040,7 +14086,7 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
".spvBufferSizeConstants", "[",
convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
convert_to_string(get_metal_resource_index(var, SPIRType::UInt)), "];");
}
else
{
@ -14053,7 +14099,8 @@ void CompilerMSL::fix_up_shader_inputs_outputs()
}
}
if (msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
if (!msl_options.argument_buffers &&
msl_options.replace_recursive_inputs && type_contains_recursion(type) &&
(var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer))
{
@ -17026,13 +17073,21 @@ uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type,
return msl_size;
}
uint32_t CompilerMSL::get_physical_type_stride(const SPIRType &type) const
{
// This should only be relevant for plain types such as scalars and vectors?
// If we're pointing to a struct, it will recursively pick up packed/row-major state.
return get_declared_type_size_msl(type, false, false);
}
// Returns the byte size of a struct member.
uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
// Pointers take 8 bytes each
// Match both pointer and array-of-pointer here.
if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
{
uint32_t type_size = 8 * (type.vecsize == 3 ? 4 : type.vecsize);
uint32_t type_size = 8;
// Work our way through potentially layered arrays,
// stopping when we hit a pointer that is not also an array.
@ -17107,9 +17162,10 @@ uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t
// Returns the byte alignment of a type.
uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
{
// Pointers aligns on multiples of 8 bytes
// Pointers align on multiples of 8 bytes.
// Deliberately ignore array-ness here. It's not relevant for alignment.
if (type.pointer && type.storage == StorageClassPhysicalStorageBuffer)
return 8 * (type.vecsize == 3 ? 4 : type.vecsize);
return 8;
switch (type.basetype)
{
@ -18134,6 +18190,13 @@ void CompilerMSL::emit_argument_buffer_aliased_descriptor(const SPIRVariable &al
}
else
{
// This alias may already have been used to emit an entry point declaration. If there is a mismatch, we need a recompile.
// Moving this code to run earlier would conflict as well,
// because we need the qualified alias for the base resource,
// so forcing a recompile until things sync up is the least invasive method for now.
if (ir.meta[aliased_var.self].decoration.qualified_alias != name)
force_recompile();
// This will get wrapped in a separate temporary when a spvDescriptorArray wrapper is emitted.
set_qualified_name(aliased_var.self, name);
}
@ -18158,6 +18221,7 @@ void CompilerMSL::analyze_argument_buffers()
string name;
SPIRType::BaseType basetype;
uint32_t index;
uint32_t plane_count;
uint32_t plane;
uint32_t overlapping_var_id;
};
@ -18208,14 +18272,14 @@ void CompilerMSL::analyze_argument_buffers()
{
uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
resources_in_set[desc_set].push_back(
{ &var, to_name(var_id), SPIRType::Image, image_resource_index, i, 0 });
{ &var, to_name(var_id), SPIRType::Image, image_resource_index, plane_count, i, 0 });
}
if (type.image.dim != DimBuffer && !constexpr_sampler)
{
uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
resources_in_set[desc_set].push_back(
{ &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0, 0 });
{ &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 1, 0, 0 });
}
}
else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
@ -18231,14 +18295,14 @@ void CompilerMSL::analyze_argument_buffers()
uint32_t resource_index = get_metal_resource_index(var, type.basetype);
resources_in_set[desc_set].push_back(
{ &var, to_name(var_id), type.basetype, resource_index, 0, 0 });
{ &var, to_name(var_id), type.basetype, resource_index, 1, 0, 0 });
// Emulate texture2D atomic operations
if (atomic_image_vars_emulated.count(var.self))
{
uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
resources_in_set[desc_set].push_back(
{ &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 0, 0 });
{ &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 1, 0, 0 });
}
}
@ -18286,7 +18350,7 @@ void CompilerMSL::analyze_argument_buffers()
set_decoration(var_id, DecorationDescriptorSet, desc_set);
set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
resources_in_set[desc_set].push_back(
{ &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0, 0 });
{ &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 1, 0, 0 });
}
if (set_needs_buffer_sizes[desc_set])
@ -18297,7 +18361,7 @@ void CompilerMSL::analyze_argument_buffers()
set_decoration(var_id, DecorationDescriptorSet, desc_set);
set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
resources_in_set[desc_set].push_back(
{ &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0, 0 });
{ &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 1, 0, 0 });
}
}
}
@ -18309,7 +18373,7 @@ void CompilerMSL::analyze_argument_buffers()
uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
add_resource_name(var_id);
resources_in_set[desc_set].push_back(
{ &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0, 0 });
{ &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 1, 0, 0 });
}
for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
@ -18340,7 +18404,8 @@ void CompilerMSL::analyze_argument_buffers()
else
buffer_type.storage = StorageClassUniform;
set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
auto buffer_type_name = join("spvDescriptorSetBuffer", desc_set);
set_name(type_id, buffer_type_name);
auto &ptr_type = set<SPIRType>(ptr_type_id, OpTypePointer);
ptr_type = buffer_type;
@ -18350,8 +18415,9 @@ void CompilerMSL::analyze_argument_buffers()
ptr_type.parent_type = type_id;
uint32_t buffer_variable_id = next_id;
set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
auto &buffer_var = set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
auto buffer_name = join("spvDescriptorSet", desc_set);
set_name(buffer_variable_id, buffer_name);
// Ids must be emitted in ID order.
stable_sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
@ -18386,7 +18452,7 @@ void CompilerMSL::analyze_argument_buffers()
// If needed, synthesize and add padding members.
// member_index and next_arg_buff_index are incremented when padding members are added.
if (msl_options.pad_argument_buffer_resources && resource.overlapping_var_id == 0)
if (msl_options.pad_argument_buffer_resources && resource.plane == 0 && resource.overlapping_var_id == 0)
{
auto rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index);
while (resource.index > next_arg_buff_index)
@ -18432,7 +18498,7 @@ void CompilerMSL::analyze_argument_buffers()
// Adjust the number of slots consumed by current member itself.
// Use the count value from the app, instead of the shader, in case the
// shader is only accessing part, or even one element, of the array.
next_arg_buff_index += rez_bind.count;
next_arg_buff_index += resource.plane_count * rez_bind.count;
}
string mbr_name = ensure_valid_name(resource.name, "m");
@ -18559,6 +18625,16 @@ void CompilerMSL::analyze_argument_buffers()
set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationOverlappingBinding);
member_index++;
}
if (msl_options.replace_recursive_inputs && type_contains_recursion(buffer_type))
{
recursive_inputs.insert(type_id);
auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
auto addr_space = get_argument_address_space(buffer_var);
entry_func.fixup_hooks_in.push_back([this, addr_space, buffer_name, buffer_type_name]() {
statement(addr_space, " auto& ", buffer_name, " = *(", addr_space, " ", buffer_type_name, "*)", buffer_name, "_vp;");
});
}
}
}
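A sketch of the generated entry point when a descriptor-set buffer type recurses (stage, return type, and address space are illustrative; the hook above rebuilds the typed reference from the void pointer):

fragment float4 main0(constant void* spvDescriptorSet0_vp [[buffer(0)]])
{
    constant auto& spvDescriptorSet0 = *(constant spvDescriptorSetBuffer0*)spvDescriptorSet0_vp;
    // ... the body then uses spvDescriptorSet0 as usual
}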

View File

@ -1028,6 +1028,8 @@ protected:
uint32_t get_physical_tess_level_array_size(spv::BuiltIn builtin) const;
uint32_t get_physical_type_stride(const SPIRType &type) const override;
// MSL packing rules. These compute the effective packing rules as observed by the MSL compiler in the MSL output.
// These values can change depending on various extended decorations which control packing rules.
// We need to make these rules match up with SPIR-V declared rules.