371 lines
16 KiB
C++
371 lines
16 KiB
C++
// Copyright (c) 2019 Google LLC
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "source/fuzz/force_render_red.h"
|
|
|
|
#include "source/fuzz/fact_manager.h"
|
|
#include "source/fuzz/instruction_descriptor.h"
|
|
#include "source/fuzz/protobufs/spirvfuzz_protobufs.h"
|
|
#include "source/fuzz/transformation_replace_constant_with_uniform.h"
|
|
#include "source/fuzz/uniform_buffer_element_descriptor.h"
|
|
#include "source/opt/build_module.h"
|
|
#include "source/opt/ir_context.h"
|
|
#include "source/opt/types.h"
|
|
#include "source/util/make_unique.h"
|
|
#include "tools/util/cli_consumer.h"
|
|
|
|
#include <algorithm>
|
|
#include <utility>
|
|
|
|
namespace spvtools {
|
|
namespace fuzz {
|
|
|
|
namespace {
|
|
|
|
// Helper method to find the fragment shader entry point, complaining if there
|
|
// is no shader or if there is no fragment entry point.
|
|
opt::Function* FindFragmentShaderEntryPoint(opt::IRContext* ir_context,
|
|
MessageConsumer message_consumer) {
|
|
// Check that this is a fragment shader
|
|
bool found_capability_shader = false;
|
|
for (auto& capability : ir_context->capabilities()) {
|
|
assert(capability.opcode() == SpvOpCapability);
|
|
if (capability.GetSingleWordInOperand(0) == SpvCapabilityShader) {
|
|
found_capability_shader = true;
|
|
break;
|
|
}
|
|
}
|
|
if (!found_capability_shader) {
|
|
message_consumer(
|
|
SPV_MSG_ERROR, nullptr, {},
|
|
"Forcing of red rendering requires the Shader capability.");
|
|
return nullptr;
|
|
}
|
|
|
|
opt::Instruction* fragment_entry_point = nullptr;
|
|
for (auto& entry_point : ir_context->module()->entry_points()) {
|
|
if (entry_point.GetSingleWordInOperand(0) == SpvExecutionModelFragment) {
|
|
fragment_entry_point = &entry_point;
|
|
break;
|
|
}
|
|
}
|
|
if (fragment_entry_point == nullptr) {
|
|
message_consumer(SPV_MSG_ERROR, nullptr, {},
|
|
"Forcing of red rendering requires an entry point with "
|
|
"the Fragment execution model.");
|
|
return nullptr;
|
|
}
|
|
|
|
for (auto& function : *ir_context->module()) {
|
|
if (function.result_id() ==
|
|
fragment_entry_point->GetSingleWordInOperand(1)) {
|
|
return &function;
|
|
}
|
|
}
|
|
assert(
|
|
false &&
|
|
"A valid module must have a function associate with each entry point.");
|
|
return nullptr;
|
|
}
|
|
|
|
// Helper method to check that there is a single vec4 output variable and get a
|
|
// pointer to it.
|
|
opt::Instruction* FindVec4OutputVariable(opt::IRContext* ir_context,
|
|
MessageConsumer message_consumer) {
|
|
opt::Instruction* output_variable = nullptr;
|
|
for (auto& inst : ir_context->types_values()) {
|
|
if (inst.opcode() == SpvOpVariable &&
|
|
inst.GetSingleWordInOperand(0) == SpvStorageClassOutput) {
|
|
if (output_variable != nullptr) {
|
|
message_consumer(SPV_MSG_ERROR, nullptr, {},
|
|
"Only one output variable can be handled at present; "
|
|
"found multiple.");
|
|
return nullptr;
|
|
}
|
|
output_variable = &inst;
|
|
// Do not break, as we want to check for multiple output variables.
|
|
}
|
|
}
|
|
if (output_variable == nullptr) {
|
|
message_consumer(SPV_MSG_ERROR, nullptr, {},
|
|
"No output variable to which to write red was found.");
|
|
return nullptr;
|
|
}
|
|
|
|
auto output_variable_base_type = ir_context->get_type_mgr()
|
|
->GetType(output_variable->type_id())
|
|
->AsPointer()
|
|
->pointee_type()
|
|
->AsVector();
|
|
if (!output_variable_base_type ||
|
|
output_variable_base_type->element_count() != 4 ||
|
|
!output_variable_base_type->element_type()->AsFloat()) {
|
|
message_consumer(SPV_MSG_ERROR, nullptr, {},
|
|
"The output variable must have type vec4.");
|
|
return nullptr;
|
|
}
|
|
|
|
return output_variable;
|
|
}
|
|
|
|
// Helper to get the ids of float constants 0.0 and 1.0, creating them if
|
|
// necessary.
|
|
std::pair<uint32_t, uint32_t> FindOrCreateFloatZeroAndOne(
|
|
opt::IRContext* ir_context, opt::analysis::Float* float_type) {
|
|
float one = 1.0;
|
|
uint32_t one_as_uint;
|
|
memcpy(&one_as_uint, &one, sizeof(float));
|
|
std::vector<uint32_t> zero_bytes = {0};
|
|
std::vector<uint32_t> one_bytes = {one_as_uint};
|
|
auto constant_zero = ir_context->get_constant_mgr()->RegisterConstant(
|
|
MakeUnique<opt::analysis::FloatConstant>(float_type, zero_bytes));
|
|
auto constant_one = ir_context->get_constant_mgr()->RegisterConstant(
|
|
MakeUnique<opt::analysis::FloatConstant>(float_type, one_bytes));
|
|
auto constant_zero_id = ir_context->get_constant_mgr()
|
|
->GetDefiningInstruction(constant_zero)
|
|
->result_id();
|
|
auto constant_one_id = ir_context->get_constant_mgr()
|
|
->GetDefiningInstruction(constant_one)
|
|
->result_id();
|
|
return std::pair<uint32_t, uint32_t>(constant_zero_id, constant_one_id);
|
|
}
|
|
|
|
std::unique_ptr<TransformationReplaceConstantWithUniform>
|
|
MakeConstantUniformReplacement(opt::IRContext* ir_context,
|
|
const FactManager& fact_manager,
|
|
uint32_t constant_id,
|
|
uint32_t greater_than_instruction,
|
|
uint32_t in_operand_index) {
|
|
return MakeUnique<TransformationReplaceConstantWithUniform>(
|
|
MakeIdUseDescriptor(constant_id,
|
|
MakeInstructionDescriptor(greater_than_instruction,
|
|
SpvOpFOrdGreaterThan, 0),
|
|
in_operand_index),
|
|
fact_manager.GetUniformDescriptorsForConstant(ir_context, constant_id)[0],
|
|
ir_context->TakeNextId(), ir_context->TakeNextId());
|
|
}
|
|
|
|
} // namespace
|
|
|
|
bool ForceRenderRed(
|
|
const spv_target_env& target_env, const std::vector<uint32_t>& binary_in,
|
|
const spvtools::fuzz::protobufs::FactSequence& initial_facts,
|
|
std::vector<uint32_t>* binary_out) {
|
|
auto message_consumer = spvtools::utils::CLIMessageConsumer;
|
|
spvtools::SpirvTools tools(target_env);
|
|
if (!tools.IsValid()) {
|
|
message_consumer(SPV_MSG_ERROR, nullptr, {},
|
|
"Failed to create SPIRV-Tools interface; stopping.");
|
|
return false;
|
|
}
|
|
|
|
// Initial binary should be valid.
|
|
if (!tools.Validate(&binary_in[0], binary_in.size())) {
|
|
message_consumer(SPV_MSG_ERROR, nullptr, {},
|
|
"Initial binary is invalid; stopping.");
|
|
return false;
|
|
}
|
|
|
|
// Build the module from the input binary.
|
|
std::unique_ptr<opt::IRContext> ir_context = BuildModule(
|
|
target_env, message_consumer, binary_in.data(), binary_in.size());
|
|
assert(ir_context);
|
|
|
|
// Set up a fact manager with any given initial facts.
|
|
FactManager fact_manager;
|
|
for (auto& fact : initial_facts.fact()) {
|
|
fact_manager.AddFact(fact, ir_context.get());
|
|
}
|
|
|
|
auto entry_point_function =
|
|
FindFragmentShaderEntryPoint(ir_context.get(), message_consumer);
|
|
auto output_variable =
|
|
FindVec4OutputVariable(ir_context.get(), message_consumer);
|
|
if (entry_point_function == nullptr || output_variable == nullptr) {
|
|
return false;
|
|
}
|
|
|
|
opt::analysis::Float temp_float_type(32);
|
|
opt::analysis::Float* float_type = ir_context->get_type_mgr()
|
|
->GetRegisteredType(&temp_float_type)
|
|
->AsFloat();
|
|
std::pair<uint32_t, uint32_t> zero_one_float_ids =
|
|
FindOrCreateFloatZeroAndOne(ir_context.get(), float_type);
|
|
|
|
// Make the new exit block
|
|
auto new_exit_block_id = ir_context->TakeNextId();
|
|
{
|
|
auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
|
|
new_exit_block_id,
|
|
opt::Instruction::OperandList());
|
|
auto new_exit_block = MakeUnique<opt::BasicBlock>(std::move(label));
|
|
new_exit_block->AddInstruction(MakeUnique<opt::Instruction>(
|
|
ir_context.get(), SpvOpReturn, 0, 0, opt::Instruction::OperandList()));
|
|
entry_point_function->AddBasicBlock(std::move(new_exit_block));
|
|
}
|
|
|
|
// Make the new entry block
|
|
{
|
|
auto label = MakeUnique<opt::Instruction>(ir_context.get(), SpvOpLabel, 0,
|
|
ir_context->TakeNextId(),
|
|
opt::Instruction::OperandList());
|
|
auto new_entry_block = MakeUnique<opt::BasicBlock>(std::move(label));
|
|
|
|
// Make an instruction to construct vec4(1.0, 0.0, 0.0, 1.0), representing
|
|
// the colour red.
|
|
opt::Operand zero_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.first}};
|
|
opt::Operand one_float = {SPV_OPERAND_TYPE_ID, {zero_one_float_ids.second}};
|
|
opt::Instruction::OperandList op_composite_construct_operands = {
|
|
one_float, zero_float, zero_float, one_float};
|
|
auto temp_vec4 = opt::analysis::Vector(float_type, 4);
|
|
auto vec4_id = ir_context->get_type_mgr()->GetId(&temp_vec4);
|
|
auto red = MakeUnique<opt::Instruction>(
|
|
ir_context.get(), SpvOpCompositeConstruct, vec4_id,
|
|
ir_context->TakeNextId(), op_composite_construct_operands);
|
|
auto red_id = red->result_id();
|
|
new_entry_block->AddInstruction(std::move(red));
|
|
|
|
// Make an instruction to store red into the output color.
|
|
opt::Operand variable_to_store_into = {SPV_OPERAND_TYPE_ID,
|
|
{output_variable->result_id()}};
|
|
opt::Operand value_to_be_stored = {SPV_OPERAND_TYPE_ID, {red_id}};
|
|
opt::Instruction::OperandList op_store_operands = {variable_to_store_into,
|
|
value_to_be_stored};
|
|
new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
|
|
ir_context.get(), SpvOpStore, 0, 0, op_store_operands));
|
|
|
|
// We are going to attempt to construct 'false' as an expression of the form
|
|
// 'literal1 > literal2'. If we succeed, we will later replace each literal
|
|
// with a uniform of the same value - we can only do that replacement once
|
|
// we have added the entry block to the module.
|
|
std::unique_ptr<TransformationReplaceConstantWithUniform>
|
|
first_greater_then_operand_replacement = nullptr;
|
|
std::unique_ptr<TransformationReplaceConstantWithUniform>
|
|
second_greater_then_operand_replacement = nullptr;
|
|
uint32_t id_guaranteed_to_be_false = 0;
|
|
|
|
opt::analysis::Bool temp_bool_type;
|
|
opt::analysis::Bool* registered_bool_type =
|
|
ir_context->get_type_mgr()
|
|
->GetRegisteredType(&temp_bool_type)
|
|
->AsBool();
|
|
|
|
auto float_type_id = ir_context->get_type_mgr()->GetId(float_type);
|
|
auto types_for_which_uniforms_are_known =
|
|
fact_manager.GetTypesForWhichUniformValuesAreKnown();
|
|
|
|
// Check whether we have any float uniforms.
|
|
if (std::find(types_for_which_uniforms_are_known.begin(),
|
|
types_for_which_uniforms_are_known.end(),
|
|
float_type_id) != types_for_which_uniforms_are_known.end()) {
|
|
// We have at least one float uniform; let's see whether we have at least
|
|
// two.
|
|
auto available_constants =
|
|
fact_manager.GetConstantsAvailableFromUniformsForType(
|
|
ir_context.get(), float_type_id);
|
|
if (available_constants.size() > 1) {
|
|
// Grab the float constants associated with the first two known float
|
|
// uniforms.
|
|
auto first_constant =
|
|
ir_context->get_constant_mgr()
|
|
->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
|
|
available_constants[0]))
|
|
->AsFloatConstant();
|
|
auto second_constant =
|
|
ir_context->get_constant_mgr()
|
|
->GetConstantFromInst(ir_context->get_def_use_mgr()->GetDef(
|
|
available_constants[1]))
|
|
->AsFloatConstant();
|
|
|
|
// Now work out which of the two constants is larger than the other.
|
|
uint32_t larger_constant_index = 0;
|
|
uint32_t smaller_constant_index = 0;
|
|
if (first_constant->GetFloat() > second_constant->GetFloat()) {
|
|
larger_constant_index = 0;
|
|
smaller_constant_index = 1;
|
|
} else if (first_constant->GetFloat() < second_constant->GetFloat()) {
|
|
larger_constant_index = 1;
|
|
smaller_constant_index = 0;
|
|
}
|
|
|
|
// Only proceed with these constants if they have turned out to be
|
|
// distinct.
|
|
if (larger_constant_index != smaller_constant_index) {
|
|
// We are in a position to create 'false' as 'literal1 > literal2', so
|
|
// reserve an id for this computation; this id will end up being
|
|
// guaranteed to be 'false'.
|
|
id_guaranteed_to_be_false = ir_context->TakeNextId();
|
|
|
|
auto smaller_constant = available_constants[smaller_constant_index];
|
|
auto larger_constant = available_constants[larger_constant_index];
|
|
|
|
opt::Instruction::OperandList greater_than_operands = {
|
|
{SPV_OPERAND_TYPE_ID, {smaller_constant}},
|
|
{SPV_OPERAND_TYPE_ID, {larger_constant}}};
|
|
new_entry_block->AddInstruction(MakeUnique<opt::Instruction>(
|
|
ir_context.get(), SpvOpFOrdGreaterThan,
|
|
ir_context->get_type_mgr()->GetId(registered_bool_type),
|
|
id_guaranteed_to_be_false, greater_than_operands));
|
|
|
|
first_greater_then_operand_replacement =
|
|
MakeConstantUniformReplacement(ir_context.get(), fact_manager,
|
|
smaller_constant,
|
|
id_guaranteed_to_be_false, 0);
|
|
second_greater_then_operand_replacement =
|
|
MakeConstantUniformReplacement(ir_context.get(), fact_manager,
|
|
larger_constant,
|
|
id_guaranteed_to_be_false, 1);
|
|
}
|
|
}
|
|
}
|
|
|
|
if (id_guaranteed_to_be_false == 0) {
|
|
auto constant_false = ir_context->get_constant_mgr()->RegisterConstant(
|
|
MakeUnique<opt::analysis::BoolConstant>(registered_bool_type, false));
|
|
id_guaranteed_to_be_false = ir_context->get_constant_mgr()
|
|
->GetDefiningInstruction(constant_false)
|
|
->result_id();
|
|
}
|
|
|
|
opt::Operand false_condition = {SPV_OPERAND_TYPE_ID,
|
|
{id_guaranteed_to_be_false}};
|
|
opt::Operand then_block = {SPV_OPERAND_TYPE_ID,
|
|
{entry_point_function->entry()->id()}};
|
|
opt::Operand else_block = {SPV_OPERAND_TYPE_ID, {new_exit_block_id}};
|
|
opt::Instruction::OperandList op_branch_conditional_operands = {
|
|
false_condition, then_block, else_block};
|
|
new_entry_block->AddInstruction(
|
|
MakeUnique<opt::Instruction>(ir_context.get(), SpvOpBranchConditional,
|
|
0, 0, op_branch_conditional_operands));
|
|
|
|
entry_point_function->InsertBasicBlockBefore(
|
|
std::move(new_entry_block), entry_point_function->entry().get());
|
|
|
|
for (auto& replacement : {first_greater_then_operand_replacement.get(),
|
|
second_greater_then_operand_replacement.get()}) {
|
|
if (replacement) {
|
|
assert(replacement->IsApplicable(ir_context.get(), fact_manager));
|
|
replacement->Apply(ir_context.get(), &fact_manager);
|
|
}
|
|
}
|
|
}
|
|
|
|
// Write out the module as a binary.
|
|
ir_context->module()->ToBinary(binary_out, false);
|
|
return true;
|
|
}
|
|
|
|
} // namespace fuzz
|
|
} // namespace spvtools
|