Updated spirv-tools.

This commit is contained in:
Бранимир Караџић 2022-07-29 21:41:54 -07:00
parent a0766b3205
commit 758d02db1e
27 changed files with 736 additions and 525 deletions

View File

@ -1 +1 @@
"v2022.3-dev", "SPIRV-Tools v2022.3-dev 05862b9695f45c8a8810d69c52ced4440446c1b7"
"v2022.3-dev", "SPIRV-Tools v2022.3-dev 484a5f0c25b53584a6b7fce0702a6bb580072d81"

View File

@ -13,23 +13,11 @@
// limitations under the License.
#include "source/opt/def_use_manager.h"
#include "source/util/make_unique.h"
namespace spvtools {
namespace opt {
namespace analysis {
// Don't compact before we have a reasonable number of ids allocated (~32kb).
static const size_t kCompactThresholdMinTotalIds = (8 * 1024);
// Compact when fewer than this fraction of the storage is used (should be 2^n
// for performance).
static const size_t kCompactThresholdFractionFreeIds = 8;
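// A sketch (not part of the diff) of the policy these two constants encode,
// as checked below in EraseUseRecordsOfOperandIds(); ShouldCompact is a
// hypothetical name for illustration. With 16384 nodes allocated, compaction
// triggers once fewer than 16384 / 8 == 2048 nodes are still in use.
static bool ShouldCompact(size_t total_nodes, size_t used_nodes) {
  return total_nodes > kCompactThresholdMinTotalIds &&
         used_nodes < total_nodes / kCompactThresholdFractionFreeIds;
}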
DefUseManager::DefUseManager() {
use_pool_ = MakeUnique<UseListPool>();
used_id_pool_ = MakeUnique<UsedIdListPool>();
}
void DefUseManager::AnalyzeInstDef(Instruction* inst) {
const uint32_t def_id = inst->result_id();
if (def_id != 0) {
@ -46,15 +34,15 @@ void DefUseManager::AnalyzeInstDef(Instruction* inst) {
}
void DefUseManager::AnalyzeInstUse(Instruction* inst) {
// It might have existed before.
EraseUseRecordsOfOperandIds(inst);
// Create entry for the given instruction. Note that the instruction may
// not have any in-operands. In such cases, we still need an entry for those
// instructions so that this manager later knows it has seen the instruction.
UsedIdList& used_ids =
inst_to_used_id_.insert({inst, UsedIdList(used_id_pool_.get())})
.first->second;
auto* used_ids = &inst_to_used_ids_[inst];
if (used_ids->size()) {
EraseUseRecordsOfOperandIds(inst);
used_ids = &inst_to_used_ids_[inst];
}
used_ids->clear(); // It might have existed before.
for (uint32_t i = 0; i < inst->NumOperands(); ++i) {
switch (inst->GetOperand(i).type) {
@ -66,17 +54,8 @@ void DefUseManager::AnalyzeInstUse(Instruction* inst) {
uint32_t use_id = inst->GetSingleWordOperand(i);
Instruction* def = GetDef(use_id);
assert(def && "Definition is not registered.");
// Add to inst's use records
used_ids.push_back(use_id);
// Add to the users, taking care to avoid adding duplicates. We know
// the duplicate for this instruction will always be at the tail.
UseList& list = inst_to_users_.insert({def, UseList(use_pool_.get())})
.first->second;
if (list.empty() || list.back() != inst) {
list.push_back(inst);
}
id_to_users_.insert(UserEntry{def, inst});
used_ids->push_back(use_id);
} break;
default:
break;
@ -115,6 +94,23 @@ const Instruction* DefUseManager::GetDef(uint32_t id) const {
return iter->second;
}
DefUseManager::IdToUsersMap::const_iterator DefUseManager::UsersBegin(
const Instruction* def) const {
return id_to_users_.lower_bound(
UserEntry{const_cast<Instruction*>(def), nullptr});
}
bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter,
const IdToUsersMap::const_iterator& cached_end,
const Instruction* inst) const {
return (iter != cached_end && iter->def == inst);
}
bool DefUseManager::UsersNotEnd(const IdToUsersMap::const_iterator& iter,
const Instruction* inst) const {
return UsersNotEnd(iter, id_to_users_.end(), inst);
}
bool DefUseManager::WhileEachUser(
const Instruction* def, const std::function<bool(Instruction*)>& f) const {
// Ensure that |def| has been registered.
@ -122,11 +118,9 @@ bool DefUseManager::WhileEachUser(
"Definition is not registered.");
if (!def->HasResultId()) return true;
auto iter = inst_to_users_.find(def);
if (iter != inst_to_users_.end()) {
for (Instruction* user : iter->second) {
if (!f(user)) return false;
}
auto end = id_to_users_.end();
for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) {
if (!f(iter->user)) return false;
}
return true;
}
@ -157,15 +151,14 @@ bool DefUseManager::WhileEachUse(
"Definition is not registered.");
if (!def->HasResultId()) return true;
auto iter = inst_to_users_.find(def);
if (iter != inst_to_users_.end()) {
for (Instruction* user : iter->second) {
for (uint32_t idx = 0; idx != user->NumOperands(); ++idx) {
const Operand& op = user->GetOperand(idx);
if (op.type != SPV_OPERAND_TYPE_RESULT_ID && spvIsIdType(op.type)) {
if (def->result_id() == op.words[0]) {
if (!f(user, idx)) return false;
}
auto end = id_to_users_.end();
for (auto iter = UsersBegin(def); UsersNotEnd(iter, end, def); ++iter) {
Instruction* user = iter->user;
for (uint32_t idx = 0; idx != user->NumOperands(); ++idx) {
const Operand& op = user->GetOperand(idx);
if (op.type != SPV_OPERAND_TYPE_RESULT_ID && spvIsIdType(op.type)) {
if (def->result_id() == op.words[0]) {
if (!f(user, idx)) return false;
}
}
}
@ -237,18 +230,17 @@ void DefUseManager::AnalyzeDefUse(Module* module) {
}
void DefUseManager::ClearInst(Instruction* inst) {
if (inst_to_used_id_.find(inst) != inst_to_used_id_.end()) {
auto iter = inst_to_used_ids_.find(inst);
if (iter != inst_to_used_ids_.end()) {
EraseUseRecordsOfOperandIds(inst);
uint32_t const result_id = inst->result_id();
if (result_id != 0) {
// For each using instruction, remove result_id from their used ids.
auto iter = inst_to_users_.find(inst);
if (iter != inst_to_users_.end()) {
for (Instruction* use : iter->second) {
inst_to_used_id_.at(use).remove_first(result_id);
}
inst_to_users_.erase(iter);
if (inst->result_id() != 0) {
// Remove all uses of this inst.
auto users_begin = UsersBegin(inst);
auto end = id_to_users_.end();
auto new_end = users_begin;
for (; UsersNotEnd(new_end, end, inst); ++new_end) {
}
id_to_users_.erase(users_begin, new_end);
id_to_def_.erase(inst->result_id());
}
}
@ -257,48 +249,16 @@ void DefUseManager::ClearInst(Instruction* inst) {
void DefUseManager::EraseUseRecordsOfOperandIds(const Instruction* inst) {
// Go through all ids used by this instruction, remove this instruction's
// uses of them.
auto iter = inst_to_used_id_.find(inst);
if (iter != inst_to_used_id_.end()) {
const UsedIdList& used_ids = iter->second;
for (uint32_t def_id : used_ids) {
auto def_iter = inst_to_users_.find(GetDef(def_id));
if (def_iter != inst_to_users_.end()) {
def_iter->second.remove_first(const_cast<Instruction*>(inst));
}
}
inst_to_used_id_.erase(inst);
// If we're using only a fraction of the space in used_ids_, compact storage
// to prevent memory usage from being unbounded.
if (used_id_pool_->total_nodes() > kCompactThresholdMinTotalIds &&
used_id_pool_->used_nodes() <
used_id_pool_->total_nodes() / kCompactThresholdFractionFreeIds) {
CompactStorage();
auto iter = inst_to_used_ids_.find(inst);
if (iter != inst_to_used_ids_.end()) {
for (auto use_id : iter->second) {
id_to_users_.erase(
UserEntry{GetDef(use_id), const_cast<Instruction*>(inst)});
}
inst_to_used_ids_.erase(iter);
}
}
void DefUseManager::CompactStorage() {
CompactUseRecords();
CompactUsedIds();
}
void DefUseManager::CompactUseRecords() {
std::unique_ptr<UseListPool> new_pool = MakeUnique<UseListPool>();
for (auto& iter : inst_to_users_) {
iter.second.move_nodes(new_pool.get());
}
use_pool_ = std::move(new_pool);
}
void DefUseManager::CompactUsedIds() {
std::unique_ptr<UsedIdListPool> new_pool = MakeUnique<UsedIdListPool>();
for (auto& iter : inst_to_used_id_) {
iter.second.move_nodes(new_pool.get());
}
used_id_pool_ = std::move(new_pool);
}
bool CompareAndPrintDifferences(const DefUseManager& lhs,
const DefUseManager& rhs) {
bool same = true;
@ -317,52 +277,34 @@ bool CompareAndPrintDifferences(const DefUseManager& lhs,
same = false;
}
for (const auto& l : lhs.inst_to_used_id_) {
std::set<uint32_t> ul, ur;
lhs.ForEachUse(l.first,
[&ul](Instruction*, uint32_t id) { ul.insert(id); });
rhs.ForEachUse(l.first,
[&ur](Instruction*, uint32_t id) { ur.insert(id); });
if (ul.size() != ur.size()) {
printf(
"Diff in inst_to_used_id_: different number of used ids (%zu != %zu)",
ul.size(), ur.size());
same = false;
} else if (ul != ur) {
printf("Diff in inst_to_used_id_: different used ids\n");
same = false;
if (lhs.id_to_users_ != rhs.id_to_users_) {
for (auto p : lhs.id_to_users_) {
if (rhs.id_to_users_.count(p) == 0) {
printf("Diff in id_to_users: missing value in rhs\n");
}
}
}
for (const auto& r : rhs.inst_to_used_id_) {
auto iter_l = lhs.inst_to_used_id_.find(r.first);
if (r.second.empty() &&
!(iter_l == lhs.inst_to_used_id_.end() || iter_l->second.empty())) {
printf("Diff in inst_to_used_id_: unexpected instr in rhs\n");
same = false;
for (auto p : rhs.id_to_users_) {
if (lhs.id_to_users_.count(p) == 0) {
printf("Diff in id_to_users: missing value in lhs\n");
}
}
same = false;
}
for (const auto& l : lhs.inst_to_users_) {
std::set<Instruction*> ul, ur;
lhs.ForEachUser(l.first, [&ul](Instruction* use) { ul.insert(use); });
rhs.ForEachUser(l.first, [&ur](Instruction* use) { ur.insert(use); });
if (ul.size() != ur.size()) {
printf("Diff in inst_to_users_: different number of users (%zu != %zu)",
ul.size(), ur.size());
same = false;
} else if (ul != ur) {
printf("Diff in inst_to_users_: different users\n");
same = false;
if (lhs.inst_to_used_ids_ != rhs.inst_to_used_ids_) {
for (auto p : lhs.inst_to_used_ids_) {
if (rhs.inst_to_used_ids_.count(p.first) == 0) {
printf("Diff in inst_to_used_ids: missing value in rhs\n");
}
}
}
for (const auto& r : rhs.inst_to_users_) {
auto iter_l = lhs.inst_to_users_.find(r.first);
if (r.second.empty() &&
!(iter_l == lhs.inst_to_users_.end() || iter_l->second.empty())) {
printf("Diff in inst_to_users_: unexpected instr in rhs\n");
same = false;
for (auto p : rhs.inst_to_used_ids_) {
if (lhs.inst_to_used_ids_.count(p.first) == 0) {
printf("Diff in inst_to_used_ids: missing value in lhs\n");
}
}
same = false;
}
return same;
}

View File

@ -21,7 +21,6 @@
#include "source/opt/instruction.h"
#include "source/opt/module.h"
#include "source/util/pooled_linked_list.h"
#include "spirv-tools/libspirv.hpp"
namespace spvtools {
@ -50,6 +49,50 @@ inline bool operator<(const Use& lhs, const Use& rhs) {
return lhs.operand_index < rhs.operand_index;
}
// Definition should never be null. User can be null; however, such an entry
// should be used only for searching (e.g. all users of a particular definition)
// and never stored in a container.
struct UserEntry {
Instruction* def;
Instruction* user;
};
inline bool operator==(const UserEntry& lhs, const UserEntry& rhs) {
return lhs.def == rhs.def && lhs.user == rhs.user;
}
// Orders UserEntry for use in associative containers (i.e. less than ordering).
//
// The definition of a UserEntry is treated as the major key and the users as
// the minor key so that all the users of a particular definition are
// consecutive in a container.
//
// A null user always compares less than a real user. This is done to provide
// easy values to search for the beginning of the users of a particular
// definition (i.e. using {def, nullptr}).
struct UserEntryLess {
bool operator()(const UserEntry& lhs, const UserEntry& rhs) const {
// If lhs.def and rhs.def are both null, fall through to checking the
// second entries.
if (!lhs.def && rhs.def) return true;
if (lhs.def && !rhs.def) return false;
// If neither definition is null, then compare unique ids.
if (lhs.def && rhs.def) {
if (lhs.def->unique_id() < rhs.def->unique_id()) return true;
if (rhs.def->unique_id() < lhs.def->unique_id()) return false;
}
// Return false on equality.
if (!lhs.user && !rhs.user) return false;
if (!lhs.user) return true;
if (!rhs.user) return false;
// If neither user is null then compare unique ids.
return lhs.user->unique_id() < rhs.user->unique_id();
}
};
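// Illustration (a sketch, not from the diff, using only what is defined
// above and assuming <set> and <functional> are available): the ordering lets
// a flat std::set serve as a multimap from definitions to users, which is
// exactly what UsersBegin()/WhileEachUser() in the .cpp above rely on.
using IdToUsersSketch = std::set<UserEntry, UserEntryLess>;
inline void ForEachUserSketch(const IdToUsersSketch& id_to_users,
                              Instruction* def,
                              const std::function<void(Instruction*)>& f) {
  // {def, nullptr} sorts before every real entry for |def|, so lower_bound
  // lands on |def|'s first user; stop when the def key changes.
  auto it = id_to_users.lower_bound(UserEntry{def, nullptr});
  for (; it != id_to_users.end() && it->def == def; ++it) f(it->user);
}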
// A class for analyzing and managing defs and uses in a Module.
class DefUseManager {
public:
@ -59,7 +102,7 @@ class DefUseManager {
// will be communicated to the outside via the given message |consumer|. This
// instance only keeps a reference to the |consumer|, so the |consumer| should
// outlive this instance.
DefUseManager(Module* module) : DefUseManager() { AnalyzeDefUse(module); }
DefUseManager(Module* module) { AnalyzeDefUse(module); }
DefUseManager(const DefUseManager&) = delete;
DefUseManager(DefUseManager&&) = delete;
@ -171,36 +214,35 @@ class DefUseManager {
// uses.
void UpdateDefUse(Instruction* inst);
// Compacts any internal storage to save memory.
void CompactStorage();
private:
using UseList = spvtools::utils::PooledLinkedList<Instruction*>;
using UseListPool = spvtools::utils::PooledLinkedListNodes<Instruction*>;
// Stores linked lists of Instructions using a def.
using InstToUsersMap = std::unordered_map<const Instruction*, UseList>;
using IdToUsersMap = std::set<UserEntry, UserEntryLess>;
using InstToUsedIdsMap =
std::unordered_map<const Instruction*, std::vector<uint32_t>>;
using UsedIdList = spvtools::utils::PooledLinkedList<uint32_t>;
using UsedIdListPool = spvtools::utils::PooledLinkedListNodes<uint32_t>;
// Stores mapping from instruction to their UsedIdRange.
using InstToUsedIdMap = std::unordered_map<const Instruction*, UsedIdList>;
// Returns the first location at which {|def|, nullptr} could be inserted
// into the users map without violating ordering.
IdToUsersMap::const_iterator UsersBegin(const Instruction* def) const;
DefUseManager();
// Returns true if |iter| has not reached the end of |def|'s users.
//
// In the first version |iter| is compared against the end of the map for
// validity before other checks. In the second version, |iter| is compared
// against |cached_end| for validity before other checks. This allows caching
// the map's end which is a performance improvement on some platforms.
bool UsersNotEnd(const IdToUsersMap::const_iterator& iter,
const Instruction* def) const;
bool UsersNotEnd(const IdToUsersMap::const_iterator& iter,
const IdToUsersMap::const_iterator& cached_end,
const Instruction* def) const;
// Analyzes the defs and uses in the given |module| and populates data
// structures in this class. Does nothing if |module| is nullptr.
void AnalyzeDefUse(Module* module);
// Removes unused entries in used_records_ and used_ids_.
void CompactUseRecords();
void CompactUsedIds();
IdToDefMap id_to_def_; // Mapping from ids to their definitions
InstToUsersMap inst_to_users_; // Map from def to uses.
std::unique_ptr<UseListPool> use_pool_;
std::unique_ptr<UsedIdListPool> used_id_pool_;
InstToUsedIdMap inst_to_used_id_; // Map from instruction to used ids.
IdToDefMap id_to_def_; // Mapping from ids to their definitions
IdToUsersMap id_to_users_; // Mapping from ids to their users
// Mapping from instructions to the ids used in the instruction.
InstToUsedIdsMap inst_to_used_ids_;
};
} // namespace analysis

View File

@ -187,6 +187,8 @@ bool IrLoader::AddInstruction(const spv_parsed_instruction_t* inst) {
module_->AddExtInstImport(std::move(spv_inst));
} else if (opcode == SpvOpMemoryModel) {
module_->SetMemoryModel(std::move(spv_inst));
} else if (opcode == SpvOpSamplerImageAddressingModeNV) {
module_->SetSampledImageAddressMode(std::move(spv_inst));
} else if (opcode == SpvOpEntryPoint) {
module_->AddEntryPoint(std::move(spv_inst));
} else if (opcode == SpvOpExecutionMode ||

View File

@ -118,8 +118,6 @@ class MergeReturnPass : public MemPass {
StructuredControlState(Instruction* break_merge, Instruction* merge)
: break_merge_(break_merge), current_merge_(merge) {}
StructuredControlState(const StructuredControlState&) = default;
bool InBreakable() const { return break_merge_; }
bool InStructuredFlow() const { return CurrentMergeId() != 0; }

View File

@ -90,6 +90,8 @@ void Module::ForEachInst(const std::function<void(Instruction*)>& f,
DELEGATE(extensions_);
DELEGATE(ext_inst_imports_);
if (memory_model_) memory_model_->ForEachInst(f, run_on_debug_line_insts);
if (sampled_image_address_mode_)
sampled_image_address_mode_->ForEachInst(f, run_on_debug_line_insts);
DELEGATE(entry_points_);
DELEGATE(execution_modes_);
DELEGATE(debugs1_);
@ -114,6 +116,9 @@ void Module::ForEachInst(const std::function<void(const Instruction*)>& f,
if (memory_model_)
static_cast<const Instruction*>(memory_model_.get())
->ForEachInst(f, run_on_debug_line_insts);
if (sampled_image_address_mode_)
static_cast<const Instruction*>(sampled_image_address_mode_.get())
->ForEachInst(f, run_on_debug_line_insts);
for (auto& i : entry_points_) DELEGATE(i);
for (auto& i : execution_modes_) DELEGATE(i);
for (auto& i : debugs1_) DELEGATE(i);

View File

@ -83,6 +83,9 @@ class Module {
// Set the memory model for this module.
inline void SetMemoryModel(std::unique_ptr<Instruction> m);
// Set the sampled image addressing mode for this module.
inline void SetSampledImageAddressMode(std::unique_ptr<Instruction> m);
// Appends an entry point instruction to this module.
inline void AddEntryPoint(std::unique_ptr<Instruction> e);
@ -158,12 +161,20 @@ class Module {
inline IteratorRange<inst_iterator> ext_inst_imports();
inline IteratorRange<const_inst_iterator> ext_inst_imports() const;
// Return the memory model instruction contained inthis module.
// Return the memory model instruction contained in this module.
inline Instruction* GetMemoryModel() { return memory_model_.get(); }
inline const Instruction* GetMemoryModel() const {
return memory_model_.get();
}
// Return the sampled image address mode instruction contained in this module.
inline Instruction* GetSampledImageAddressMode() {
return sampled_image_address_mode_.get();
}
inline const Instruction* GetSampledImageAddressMode() const {
return sampled_image_address_mode_.get();
}
// There are several kinds of debug instructions, according to where they can
// appear in the logical layout of a module:
// - Section 7a: OpString, OpSourceExtension, OpSource, OpSourceContinued
@ -288,6 +299,8 @@ class Module {
InstructionList ext_inst_imports_;
// A module only has one memory model instruction.
std::unique_ptr<Instruction> memory_model_;
// A module can only have one optional sampled image addressing mode
std::unique_ptr<Instruction> sampled_image_address_mode_;
InstructionList entry_points_;
InstructionList execution_modes_;
InstructionList debugs1_;
@ -326,6 +339,10 @@ inline void Module::SetMemoryModel(std::unique_ptr<Instruction> m) {
memory_model_ = std::move(m);
}
inline void Module::SetSampledImageAddressMode(std::unique_ptr<Instruction> m) {
sampled_image_address_mode_ = std::move(m);
}
inline void Module::AddEntryPoint(std::unique_ptr<Instruction> e) {
entry_points_.push_back(std::move(e));
}

View File

@ -62,28 +62,29 @@ spv_result_t advanceLine(spv_text text, spv_position position) {
// parameters, it is the user's responsibility to ensure these are non-null.
spv_result_t advance(spv_text text, spv_position position) {
// NOTE: Consume white space, otherwise don't advance.
if (position->index >= text->length) return SPV_END_OF_STREAM;
switch (text->str[position->index]) {
case '\0':
return SPV_END_OF_STREAM;
case ';':
if (spv_result_t error = advanceLine(text, position)) return error;
return advance(text, position);
case ' ':
case '\t':
case '\r':
position->column++;
position->index++;
return advance(text, position);
case '\n':
position->column = 0;
position->line++;
position->index++;
return advance(text, position);
default:
break;
while (true) {
if (position->index >= text->length) return SPV_END_OF_STREAM;
switch (text->str[position->index]) {
case '\0':
return SPV_END_OF_STREAM;
case ';':
if (spv_result_t error = advanceLine(text, position)) return error;
continue;
case ' ':
case '\t':
case '\r':
position->column++;
position->index++;
continue;
case '\n':
position->column = 0;
position->line++;
position->index++;
continue;
default:
return SPV_SUCCESS;
}
}
return SPV_SUCCESS;
}
// Fetches the next word from the given text stream starting from the given

View File

@ -209,9 +209,10 @@ std::istream& operator>>(std::istream& is, FloatProxy<T>& value) {
// be the default for any non-specialized type.
template <typename T>
struct HexFloatTraits {
// Integer type that can store this hex-float.
// Integer type that can store the bit representation of this hex-float.
using uint_type = void;
// Signed integer type that can store this hex-float.
// Signed integer type that can store the bit representation of this
// hex-float.
using int_type = void;
// The numerical type that this HexFloat represents.
using underlying_type = void;
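// For orientation, a sketch of the shape a concrete specialization takes.
// The three typedefs mirror those documented above; assuming anything beyond
// them (the real float specialization also carries bit-layout constants not
// shown here).
template <>
struct HexFloatTraits<FloatProxy<float>> {
  using uint_type = uint32_t;  // a float's bit pattern fits in 32 bits
  using int_type = int32_t;
  using underlying_type = FloatProxy<float>;
};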
@ -958,9 +959,15 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
// This "looks" like a hex-float so treat it as one.
bool seen_p = false;
bool seen_dot = false;
// The mantissa bits, without the most significant 1 bit, and with the
// most recently read bits in the least significant positions.
uint_type fraction = 0;
// The number of mantissa bits that have been read, including the leading 1
// bit that is not written into 'fraction'.
uint_type fraction_index = 0;
uint_type fraction = 0;
// TODO(dneto): handle overflow and underflow
int_type exponent = HF::exponent_bias;
// Strip off leading zeros so we don't have to special-case them later.
@ -968,11 +975,13 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
is.get();
}
bool is_denorm =
true; // Assume denorm "representation" until we hear otherwise.
// NB: This does not mean the value is actually denorm,
// it just means that it was written 0.
// Does the mantissa, as written, have non-zero digits to the left of
// the decimal point? Assume no until proven otherwise.
bool has_integer_part = false;
bool bits_written = false; // Stays false until we write a bit.
// Scan the mantissa hex digits until we see a '.' or the 'p' that
// starts the exponent.
while (!seen_p && !seen_dot) {
// Handle characters that are left of the fractional part.
if (next_char == '.') {
@ -980,9 +989,8 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
} else if (next_char == 'p') {
seen_p = true;
} else if (::isxdigit(next_char)) {
// We know this is not denormalized since we have stripped all leading
// zeroes and we are not a ".".
is_denorm = false;
// We have stripped all leading zeroes and we have not yet seen a ".".
has_integer_part = true;
int number = get_nibble_from_character(next_char);
for (int i = 0; i < 4; ++i, number <<= 1) {
uint_type write_bit = (number & 0x8) ? 0x1 : 0x0;
@ -993,8 +1001,12 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
fraction |
static_cast<uint_type>(
write_bit << (HF::top_bit_left_shift - fraction_index++)));
// TODO(dneto): Avoid overflow. Testing would require
// parameterization.
exponent = static_cast<int_type>(exponent + 1);
}
// Since this is updated after setting fraction bits, this effectively
// drops the leading 1 bit.
bits_written |= write_bit != 0;
}
} else {
@ -1018,10 +1030,12 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
for (int i = 0; i < 4; ++i, number <<= 1) {
uint_type write_bit = (number & 0x8) ? 0x01 : 0x00;
bits_written |= write_bit != 0;
if (is_denorm && !bits_written) {
if ((!has_integer_part) && !bits_written) {
// Handle modifying the exponent here; this way we can handle
// an arbitrary number of hex values without overflowing our
// integer.
// TODO(dneto): Handle underflow. Testing would require extra
// parameterization.
exponent = static_cast<int_type>(exponent - 1);
} else {
fraction = static_cast<uint_type>(
@ -1043,25 +1057,40 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
// Finished reading the part preceding 'p'.
// In hex floats syntax, the binary exponent is required.
bool seen_sign = false;
bool seen_exponent_sign = false;
int8_t exponent_sign = 1;
bool seen_written_exponent_digits = false;
// The magnitude of the exponent, as written, or the sentinel value to signal
// overflow.
int_type written_exponent = 0;
// A sentinel value signalling overflow of the magnitude of the written
// exponent. We'll assume that -written_exponent_overflow is valid for the
// type. Later we may add 1 or subtract 1 from the adjusted exponent, so leave
// room for an extra 1.
const int_type written_exponent_overflow =
std::numeric_limits<int_type>::max() - 1;
while (true) {
if (!seen_written_exponent_digits &&
(next_char == '-' || next_char == '+')) {
if (seen_sign) {
if (seen_exponent_sign) {
is.setstate(std::ios::failbit);
return is;
}
seen_sign = true;
seen_exponent_sign = true;
exponent_sign = (next_char == '-') ? -1 : 1;
} else if (::isdigit(next_char)) {
seen_written_exponent_digits = true;
// Hex-floats express their exponent as decimal.
written_exponent = static_cast<int_type>(written_exponent * 10);
written_exponent =
static_cast<int_type>(written_exponent + (next_char - '0'));
int_type digit =
static_cast<int_type>(static_cast<int_type>(next_char) - '0');
if (written_exponent >= (written_exponent_overflow - digit) / 10) {
// The exponent is very big. Saturate rather than overflow the
// signed integer, which would be undefined behaviour.
written_exponent = written_exponent_overflow;
} else {
written_exponent = static_cast<int_type>(
static_cast<int_type>(written_exponent * 10) + digit);
}
} else {
break;
}
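// The guard above is a saturating multiply-accumulate. A self-contained
// sketch (not from the diff), assuming int_type is int32_t as it is for
// 32-bit floats; AccumulateExponentDigit is a hypothetical name. Requires
// <cstdint> and <limits>.
int32_t AccumulateExponentDigit(int32_t written_exponent, char next_char) {
  const int32_t overflow = std::numeric_limits<int32_t>::max() - 1;
  const int32_t digit = next_char - '0';
  // Saturate instead of overflowing a signed integer (undefined behaviour).
  if (written_exponent >= (overflow - digit) / 10) return overflow;
  return written_exponent * 10 + digit;
}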
@ -1075,10 +1104,29 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
}
written_exponent = static_cast<int_type>(written_exponent * exponent_sign);
exponent = static_cast<int_type>(exponent + written_exponent);
// Now fold the exponent bias into the written exponent, updating exponent.
// But avoid undefined behaviour that would result from overflowing int_type.
if (written_exponent >= 0 && exponent >= 0) {
// Saturate up to written_exponent_overflow.
if (written_exponent_overflow - exponent > written_exponent) {
exponent = static_cast<int_type>(written_exponent + exponent);
} else {
exponent = written_exponent_overflow;
}
} else if (written_exponent < 0 && exponent < 0) {
// Saturate down to -written_exponent_overflow.
if (written_exponent_overflow + exponent > -written_exponent) {
exponent = static_cast<int_type>(written_exponent + exponent);
} else {
exponent = static_cast<int_type>(-written_exponent_overflow);
}
} else {
// They're of opposing sign, so it's safe to add.
exponent = static_cast<int_type>(written_exponent + exponent);
}
bool is_zero = is_denorm && (fraction == 0);
if (is_denorm && !is_zero) {
bool is_zero = (!has_integer_part) && (fraction == 0);
if ((!has_integer_part) && !is_zero) {
fraction = static_cast<uint_type>(fraction << 1);
exponent = static_cast<int_type>(exponent - 1);
} else if (is_zero) {
@ -1095,7 +1143,7 @@ std::istream& operator>>(std::istream& is, HexFloat<T, Traits>& value) {
const int_type max_exponent =
SetBits<uint_type, 0, HF::num_exponent_bits>::get;
// Handle actual denorm numbers
// Handle denorm numbers
while (exponent < 0 && !is_zero) {
fraction = static_cast<uint_type>(fraction >> 1);
exponent = static_cast<int_type>(exponent + 1);
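#include <cassert>
#include <cstdlib>

// Sketch (not from the diff): the hex-float grammar read above closely
// matches C99 hexadecimal float literals, which std::strtof also parses,
// so expected values can be sanity-checked independently.
int main() {
  // 0x1.8p1: binary mantissa 1.1 times 2^1, i.e. 1.5 * 2 == 3.0.
  assert(std::strtof("0x1.8p1", nullptr) == 3.0f);
  // A leading-zero spelling exercises the has_integer_part == false path:
  // 0x0.8p0 == 0.5.
  assert(std::strtof("0x0.8p0", nullptr) == 0.5f);
  return 0;
}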

View File

@ -1,236 +0,0 @@
// Copyright (c) 2021 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef SOURCE_UTIL_POOLED_LINKED_LIST_H_
#define SOURCE_UTIL_POOLED_LINKED_LIST_H_
#include <cstdint>
#include <vector>
namespace spvtools {
namespace utils {
// Shared storage of nodes for PooledLinkedList.
template <typename T>
class PooledLinkedListNodes {
public:
struct Node {
Node(T e, int32_t n = -1) : element(e), next(n) {}
T element = {};
int32_t next = -1;
};
PooledLinkedListNodes() = default;
PooledLinkedListNodes(const PooledLinkedListNodes&) = delete;
PooledLinkedListNodes& operator=(const PooledLinkedListNodes&) = delete;
PooledLinkedListNodes(PooledLinkedListNodes&& that) {
*this = std::move(that);
}
PooledLinkedListNodes& operator=(PooledLinkedListNodes&& that) {
vec_ = std::move(that.vec_);
free_nodes_ = that.free_nodes_;
return *this;
}
size_t total_nodes() { return vec_.size(); }
size_t free_nodes() { return free_nodes_; }
size_t used_nodes() { return total_nodes() - free_nodes(); }
private:
template <typename ListT>
friend class PooledLinkedList;
Node& at(int32_t index) { return vec_[index]; }
const Node& at(int32_t index) const { return vec_[index]; }
int32_t insert(T element) {
int32_t index = int32_t(vec_.size());
vec_.emplace_back(element);
return index;
}
std::vector<Node> vec_;
size_t free_nodes_ = 0;
};
// Implements a linked-list where list nodes come from a shared pool. This is
// meant to be used in scenarios where it is desirable to avoid many small
// allocations.
//
// Instead of pointers, the list uses indices to allow the underlying storage
// to be modified without needing to modify the list. When removing elements
// from the list, nodes are not deleted or recycled: to reclaim unused space,
// perform a sequence of |move_nodes| operations into a temporary pool, which
// then is moved into the old pool.
//
// This does *not* attempt to implement a full stl-compatible interface.
template <typename T>
class PooledLinkedList {
public:
using NodePool = PooledLinkedListNodes<T>;
using Node = typename NodePool::Node;
PooledLinkedList() = delete;
PooledLinkedList(NodePool* nodes) : nodes_(nodes) {}
// Shared iterator implementation (for iterator and const_iterator).
template <typename ElementT, typename PoolT>
class iterator_base {
public:
iterator_base(const iterator_base& i)
: nodes_(i.nodes_), index_(i.index_) {}
iterator_base& operator++() {
index_ = nodes_->at(index_).next;
return *this;
}
iterator_base& operator=(const iterator_base& i) {
nodes_ = i.nodes_;
index_ = i.index_;
return *this;
}
ElementT& operator*() const { return nodes_->at(index_).element; }
ElementT* operator->() const { return &nodes_->at(index_).element; }
friend inline bool operator==(const iterator_base& lhs,
const iterator_base& rhs) {
return lhs.nodes_ == rhs.nodes_ && lhs.index_ == rhs.index_;
}
friend inline bool operator!=(const iterator_base& lhs,
const iterator_base& rhs) {
return lhs.nodes_ != rhs.nodes_ || lhs.index_ != rhs.index_;
}
// Define the standard iterator types needed so this class can be
// used with <algorithm>.
using iterator_category = std::forward_iterator_tag;
using difference_type = std::ptrdiff_t;
using value_type = ElementT;
using pointer = ElementT*;
using const_pointer = const ElementT*;
using reference = ElementT&;
using const_reference = const ElementT&;
using size_type = size_t;
private:
friend PooledLinkedList;
iterator_base(PoolT* pool, int32_t index) : nodes_(pool), index_(index) {}
PoolT* nodes_;
int32_t index_ = -1;
};
using iterator = iterator_base<T, std::vector<Node>>;
using const_iterator = iterator_base<const T, const std::vector<Node>>;
bool empty() const { return head_ == -1; }
T& front() { return nodes_->at(head_).element; }
T& back() { return nodes_->at(tail_).element; }
const T& front() const { return nodes_->at(head_).element; }
const T& back() const { return nodes_->at(tail_).element; }
iterator begin() { return iterator(&nodes_->vec_, head_); }
iterator end() { return iterator(&nodes_->vec_, -1); }
const_iterator begin() const { return const_iterator(&nodes_->vec_, head_); }
const_iterator end() const { return const_iterator(&nodes_->vec_, -1); }
// Inserts |element| at the back of the list.
void push_back(T element) {
int32_t new_tail = nodes_->insert(element);
if (head_ == -1) {
head_ = new_tail;
tail_ = new_tail;
} else {
nodes_->at(tail_).next = new_tail;
tail_ = new_tail;
}
}
// Removes the first occurrence of |element| from the list.
// Returns whether |element| was removed.
bool remove_first(T element) {
int32_t* prev_next = &head_;
for (int32_t prev_index = -1, index = head_; index != -1; /**/) {
auto& node = nodes_->at(index);
if (node.element == element) {
// Snip the node out of the list, optionally fixing up the tail pointer.
if (tail_ == index) {
assert(node.next == -1);
tail_ = prev_index;
}
*prev_next = node.next;
nodes_->free_nodes_++;
return true;
} else {
prev_next = &node.next;
}
prev_index = index;
index = node.next;
}
return false;
}
// Returns the PooledLinkedListNodes that owns this list's nodes.
NodePool* pool() { return nodes_; }
// Moves the nodes in this list into |new_pool|, providing a way to compact
// storage and reclaim unused space.
//
// Upon completing a sequence of |move_nodes| calls, you must ensure you
// retain ownership of the new storage your lists point to. Example usage:
//
// unique_ptr<NodePool> new_pool = ...;
// for (PooledLinkedList& list : lists) {
// list.move_to(new_pool);
// }
// my_pool_ = std::move(new_pool);
void move_nodes(NodePool* new_pool) {
// Be sure to construct the list in the same order, instead of simply
// doing a sequence of push_backs.
int32_t prev_entry = -1;
int32_t nodes_freed = 0;
for (int32_t index = head_; index != -1; nodes_freed++) {
const auto& node = nodes_->at(index);
int32_t this_entry = new_pool->insert(node.element);
index = node.next;
if (prev_entry == -1) {
head_ = this_entry;
} else {
new_pool->at(prev_entry).next = this_entry;
}
prev_entry = this_entry;
}
tail_ = prev_entry;
// Update our old pool's free count now that we're a member of the new pool.
nodes_->free_nodes_ += nodes_freed;
nodes_ = new_pool;
}
private:
NodePool* nodes_;
int32_t head_ = -1;
int32_t tail_ = -1;
};
} // namespace utils
} // namespace spvtools
#endif // SOURCE_UTIL_POOLED_LINKED_LIST_H_
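// For reference, a minimal usage sketch (not from the diff) of the container
// this deleted header provided, using only the API shown above.
void PooledListSketch() {
  spvtools::utils::PooledLinkedListNodes<uint32_t> pool;
  spvtools::utils::PooledLinkedList<uint32_t> list(&pool);
  list.push_back(1);
  list.push_back(2);
  list.push_back(3);
  list.remove_first(2);  // the node stays orphaned in |pool|
  // Here pool.used_nodes() == 2 and pool.free_nodes() == 1; reclaim the hole
  // by relinking into fresh storage:
  spvtools::utils::PooledLinkedListNodes<uint32_t> compacted;
  list.move_nodes(&compacted);
}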

View File

@ -293,6 +293,11 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
<< "Missing OpFunctionEnd at end of module.";
if (vstate->HasCapability(SpvCapabilityBindlessTextureNV) &&
!vstate->has_samplerimage_variable_address_mode_specified())
return vstate->diag(SPV_ERROR_INVALID_LAYOUT, nullptr)
<< "Missing required OpSamplerImageAddressingModeNV instruction.";
// Catch undefined forward references before performing further checks.
if (auto error = ValidateForwardDecls(*vstate)) return error;
@ -345,6 +350,7 @@ spv_result_t ValidateBinaryUsingContextAndValidationState(
if (auto error = NonUniformPass(*vstate, &instruction)) return error;
if (auto error = LiteralsPass(*vstate, &instruction)) return error;
if (auto error = RayQueryPass(*vstate, &instruction)) return error;
}
// Validate the preconditions involving adjacent instructions. e.g. SpvOpPhi

View File

@ -197,6 +197,9 @@ spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst);
/// Validates correctness of miscellaneous instructions.
spv_result_t MiscPass(ValidationState_t& _, const Instruction* inst);
/// Validates correctness of ray query instructions.
spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst);
/// Calculates the reachability of basic blocks.
void ReachabilityPass(ValidationState_t& _);

View File

@ -66,7 +66,8 @@ spv_result_t ValidatePhi(ValidationState_t& _, const Instruction* inst) {
assert(type_inst);
const SpvOp type_opcode = type_inst->opcode();
if (!_.options()->before_hlsl_legalization) {
if (!_.options()->before_hlsl_legalization &&
!_.HasCapability(SpvCapabilityBindlessTextureNV)) {
if (type_opcode == SpvOpTypeSampledImage ||
(_.HasCapability(SpvCapabilityShader) &&
(type_opcode == SpvOpTypeImage || type_opcode == SpvOpTypeSampler))) {

View File

@ -534,6 +534,24 @@ spv_result_t ConversionPass(ValidationState_t& _, const Instruction* inst) {
break;
}
case SpvOpConvertUToAccelerationStructureKHR: {
if (!_.IsAccelerationStructureType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Result Type to be a Acceleration Structure: "
<< spvOpcodeString(opcode);
}
const uint32_t input_type = _.GetOperandTypeId(inst, 2);
if (!input_type || !_.IsUnsigned64BitHandle(input_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected 64-bit uint scalar or 2-component 32-bit uint "
"vector as input: "
<< spvOpcodeString(opcode);
}
break;
}
default:
break;
}

View File

@ -190,6 +190,13 @@ uint32_t getBaseAlignment(uint32_t member_id, bool roundUp,
// Minimal alignment is byte-aligned.
uint32_t baseAlignment = 1;
switch (inst->opcode()) {
case SpvOpTypeSampledImage:
case SpvOpTypeSampler:
case SpvOpTypeImage:
if (vstate.HasCapability(SpvCapabilityBindlessTextureNV))
return baseAlignment = vstate.samplerimage_variable_address_mode() / 8;
assert(0);
return 0;
case SpvOpTypeInt:
case SpvOpTypeFloat:
baseAlignment = words[2] / 8;
@ -219,6 +226,7 @@ uint32_t getBaseAlignment(uint32_t member_id, bool roundUp,
baseAlignment =
componentAlignment * (num_columns == 3 ? 4 : num_columns);
}
if (roundUp) baseAlignment = align(baseAlignment, 16u);
} break;
case SpvOpTypeArray:
case SpvOpTypeRuntimeArray:
@ -256,6 +264,13 @@ uint32_t getScalarAlignment(uint32_t type_id, ValidationState_t& vstate) {
const auto inst = vstate.FindDef(type_id);
const auto& words = inst->words();
switch (inst->opcode()) {
case SpvOpTypeSampledImage:
case SpvOpTypeSampler:
case SpvOpTypeImage:
if (vstate.HasCapability(SpvCapabilityBindlessTextureNV))
return vstate.samplerimage_variable_address_mode() / 8;
assert(0);
return 0;
case SpvOpTypeInt:
case SpvOpTypeFloat:
return words[2] / 8;
@ -296,6 +311,13 @@ uint32_t getSize(uint32_t member_id, const LayoutConstraints& inherited,
const auto inst = vstate.FindDef(member_id);
const auto& words = inst->words();
switch (inst->opcode()) {
case SpvOpTypeSampledImage:
case SpvOpTypeSampler:
case SpvOpTypeImage:
if (vstate.HasCapability(SpvCapabilityBindlessTextureNV))
return vstate.samplerimage_variable_address_mode() / 8;
assert(0);
return 0;
case SpvOpTypeInt:
case SpvOpTypeFloat:
return words[2] / 8;
@ -638,7 +660,8 @@ bool hasDecoration(uint32_t id, SpvDecoration decoration,
}
// Returns true if all ids of given type have a specified decoration.
bool checkForRequiredDecoration(uint32_t struct_id, SpvDecoration decoration,
bool checkForRequiredDecoration(uint32_t struct_id,
std::function<bool(SpvDecoration)> checker,
SpvOp type, ValidationState_t& vstate) {
const auto& members = getStructMembers(struct_id, vstate);
for (size_t memberIdx = 0; memberIdx < members.size(); memberIdx++) {
@ -646,10 +669,10 @@ bool checkForRequiredDecoration(uint32_t struct_id, SpvDecoration decoration,
if (type != vstate.FindDef(id)->opcode()) continue;
bool found = false;
for (auto& dec : vstate.id_decorations(id)) {
if (decoration == dec.dec_type()) found = true;
if (checker(dec.dec_type())) found = true;
}
for (auto& dec : vstate.id_decorations(struct_id)) {
if (decoration == dec.dec_type() &&
if (checker(dec.dec_type()) &&
(int)memberIdx == dec.struct_member_index()) {
found = true;
}
@ -659,7 +682,7 @@ bool checkForRequiredDecoration(uint32_t struct_id, SpvDecoration decoration,
}
}
for (auto id : getStructMembers(struct_id, SpvOpTypeStruct, vstate)) {
if (!checkForRequiredDecoration(id, decoration, type, vstate)) {
if (!checkForRequiredDecoration(id, checker, type, vstate)) {
return false;
}
}
@ -1201,30 +1224,48 @@ spv_result_t CheckDecorationsOfBuffers(ValidationState_t& vstate) {
return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id))
<< "Structure id " << id << " decorated as " << deco_str
<< " must not use GLSLPacked decoration.";
} else if (!checkForRequiredDecoration(id, SpvDecorationArrayStride,
SpvOpTypeArray, vstate)) {
} else if (!checkForRequiredDecoration(
id,
[](SpvDecoration d) {
return d == SpvDecorationArrayStride;
},
SpvOpTypeArray, vstate)) {
return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id))
<< "Structure id " << id << " decorated as " << deco_str
<< " must be explicitly laid out with ArrayStride "
"decorations.";
} else if (!checkForRequiredDecoration(id,
SpvDecorationMatrixStride,
SpvOpTypeMatrix, vstate)) {
} else if (!checkForRequiredDecoration(
id,
[](SpvDecoration d) {
return d == SpvDecorationMatrixStride;
},
SpvOpTypeMatrix, vstate)) {
return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id))
<< "Structure id " << id << " decorated as " << deco_str
<< " must be explicitly laid out with MatrixStride "
"decorations.";
} else if (!checkForRequiredDecoration(
id,
[](SpvDecoration d) {
return d == SpvDecorationRowMajor ||
d == SpvDecorationColMajor;
},
SpvOpTypeMatrix, vstate)) {
return vstate.diag(SPV_ERROR_INVALID_ID, vstate.FindDef(id))
<< "Structure id " << id << " decorated as " << deco_str
<< " must be explicitly laid out with RowMajor or "
"ColMajor decorations.";
} else if (blockRules &&
(SPV_SUCCESS != (recursive_status = checkLayout(
id, sc_str, deco_str, true,
scalar_block_layout, 0,
constraints, vstate)))) {
(SPV_SUCCESS !=
(recursive_status = checkLayout(
id, sc_str, deco_str, true, scalar_block_layout, 0,
constraints, vstate)))) {
return recursive_status;
} else if (bufferRules &&
(SPV_SUCCESS != (recursive_status = checkLayout(
id, sc_str, deco_str, false,
scalar_block_layout, 0,
constraints, vstate)))) {
(SPV_SUCCESS !=
(recursive_status = checkLayout(
id, sc_str, deco_str, false, scalar_block_layout,
0, constraints, vstate)))) {
return recursive_status;
}
}

View File

@ -927,7 +927,7 @@ spv_result_t ValidateTypeSampledImage(ValidationState_t& _,
return SPV_SUCCESS;
}
bool IsAllowedSampledImageOperand(SpvOp opcode) {
bool IsAllowedSampledImageOperand(SpvOp opcode, ValidationState_t& _) {
switch (opcode) {
case SpvOpSampledImage:
case SpvOpImageSampleImplicitLod:
@ -950,6 +950,9 @@ bool IsAllowedSampledImageOperand(SpvOp opcode) {
case SpvOpImageSparseDrefGather:
case SpvOpCopyObject:
return true;
case SpvOpStore:
if (_.HasCapability(SpvCapabilityBindlessTextureNV)) return true;
return false;
default:
return false;
}
@ -1035,7 +1038,7 @@ spv_result_t ValidateSampledImage(ValidationState_t& _,
<< _.getIdName(consumer_instr->id()) << "'.";
}
if (!IsAllowedSampledImageOperand(consumer_opcode)) {
if (!IsAllowedSampledImageOperand(consumer_opcode, _)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< "Result <id> from OpSampledImage instruction must not appear "
"as operand for Op"

View File

@ -483,6 +483,22 @@ spv_result_t InstructionPass(ValidationState_t& _, const Instruction* inst) {
if (auto error = LimitCheckNumVars(_, inst->id(), storage_class)) {
return error;
}
} else if (opcode == SpvOpSamplerImageAddressingModeNV) {
if (!_.HasCapability(SpvCapabilityBindlessTextureNV)) {
return _.diag(SPV_ERROR_MISSING_EXTENSION, inst)
<< "OpSamplerImageAddressingModeNV supported only with extension "
"SPV_NV_bindless_texture";
}
uint32_t bitwidth = inst->GetOperandAs<uint32_t>(0);
if (_.samplerimage_variable_address_mode() != 0) {
return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
<< "OpSamplerImageAddressingModeNV should only be provided once";
}
if (bitwidth != 32 && bitwidth != 64) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "OpSamplerImageAddressingModeNV bitwidth should be 64 or 32";
}
_.set_samplerimage_variable_address_mode(bitwidth);
}
if (auto error = ReservedCheck(_, inst)) return error;

View File

@ -363,6 +363,7 @@ spv_result_t ModuleLayoutPass(ValidationState_t& _, const Instruction* inst) {
case kLayoutExtensions:
case kLayoutExtInstImport:
case kLayoutMemoryModel:
case kLayoutSamplerImageAddressMode:
case kLayoutEntryPoint:
case kLayoutExecutionMode:
case kLayoutDebug1:

View File

@ -170,6 +170,16 @@ spv_result_t LogicalsPass(ValidationState_t& _, const Instruction* inst) {
break;
}
case SpvOpTypeSampledImage:
case SpvOpTypeImage:
case SpvOpTypeSampler: {
if (!_.HasCapability(SpvCapabilityBindlessTextureNV))
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Using image/sampler with OpSelect requires capability "
<< "BindlessTextureNV";
break;
}
case SpvOpTypeVector: {
dimension = type_inst->word(3);
break;

View File

@ -980,6 +980,7 @@ spv_result_t ValidateStore(ValidationState_t& _, const Instruction* inst) {
}
if (_.HasDecoration(base_type->id(), SpvDecorationBlock)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< _.VkErrorID(6925)
<< "In the Vulkan environment, cannot store to Uniform Blocks";
}
}

View File

@ -59,10 +59,7 @@ spv_result_t ValidateShaderClock(ValidationState_t& _,
// a vector of two components of 32-bit unsigned integer type
const uint32_t result_type = inst->type_id();
if (!(_.IsUnsignedIntScalarType(result_type) &&
_.GetBitWidth(result_type) == 64) &&
!(_.IsUnsignedIntVectorType(result_type) &&
_.GetDimension(result_type) == 2 && _.GetBitWidth(result_type) == 32)) {
if (!_.IsUnsigned64BitHandle(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst) << "Expected Value to be a "
"vector of two components"
" of unsigned integer"

View File

@ -0,0 +1,271 @@
// Copyright (c) 2022 The Khronos Group Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Validates ray query instructions from SPV_KHR_ray_query
#include "source/opcode.h"
#include "source/val/instruction.h"
#include "source/val/validate.h"
#include "source/val/validation_state.h"
namespace spvtools {
namespace val {
namespace {
spv_result_t ValidateRayQueryPointer(ValidationState_t& _,
const Instruction* inst,
uint32_t ray_query_index) {
const uint32_t ray_query_id = inst->GetOperandAs<uint32_t>(ray_query_index);
auto variable = _.FindDef(ray_query_id);
if (!variable || (variable->opcode() != SpvOpVariable &&
variable->opcode() != SpvOpFunctionParameter)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ray Query must be a memory object declaration";
}
auto pointer = _.FindDef(variable->GetOperandAs<uint32_t>(0));
if (!pointer || pointer->opcode() != SpvOpTypePointer) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ray Query must be a pointer";
}
auto type = _.FindDef(pointer->GetOperandAs<uint32_t>(2));
if (!type || type->opcode() != SpvOpTypeRayQueryKHR) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ray Query must be a pointer to OpTypeRayQueryKHR";
}
return SPV_SUCCESS;
}
spv_result_t ValidateIntersectionId(ValidationState_t& _,
const Instruction* inst,
uint32_t intersection_index) {
const uint32_t intersection_id =
inst->GetOperandAs<uint32_t>(intersection_index);
const uint32_t intersection_type = _.GetTypeId(intersection_id);
const SpvOp intersection_opcode = _.GetIdOpcode(intersection_id);
if (!_.IsIntScalarType(intersection_type) ||
_.GetBitWidth(intersection_type) != 32 ||
!spvOpcodeIsConstant(intersection_opcode)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "expected Intersection ID to be a constant 32-bit int scalar";
}
return SPV_SUCCESS;
}
} // namespace
spv_result_t RayQueryPass(ValidationState_t& _, const Instruction* inst) {
const SpvOp opcode = inst->opcode();
const uint32_t result_type = inst->type_id();
switch (opcode) {
case SpvOpRayQueryInitializeKHR: {
if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
if (_.GetIdOpcode(_.GetOperandTypeId(inst, 1)) !=
SpvOpTypeAccelerationStructureKHR) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Expected Acceleration Structure to be of type "
"OpTypeAccelerationStructureKHR";
}
const uint32_t ray_flags = _.GetOperandTypeId(inst, 2);
if (!_.IsIntScalarType(ray_flags) || _.GetBitWidth(ray_flags) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ray Flags must be a 32-bit int scalar";
}
const uint32_t cull_mask = _.GetOperandTypeId(inst, 3);
if (!_.IsIntScalarType(cull_mask) || _.GetBitWidth(cull_mask) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Cull Mask must be a 32-bit int scalar";
}
const uint32_t ray_origin = _.GetOperandTypeId(inst, 4);
if (!_.IsFloatVectorType(ray_origin) || _.GetDimension(ray_origin) != 3 ||
_.GetBitWidth(ray_origin) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ray Origin must be a 32-bit float 3-component vector";
}
const uint32_t ray_tmin = _.GetOperandTypeId(inst, 5);
if (!_.IsFloatScalarType(ray_tmin) || _.GetBitWidth(ray_tmin) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ray TMin must be a 32-bit float scalar";
}
const uint32_t ray_direction = _.GetOperandTypeId(inst, 6);
if (!_.IsFloatVectorType(ray_direction) ||
_.GetDimension(ray_direction) != 3 ||
_.GetBitWidth(ray_direction) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ray Direction must be a 32-bit float 3-component vector";
}
const uint32_t ray_tmax = _.GetOperandTypeId(inst, 7);
if (!_.IsFloatScalarType(ray_tmax) || _.GetBitWidth(ray_tmax) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Ray TMax must be a 32-bit float scalar";
}
break;
}
case SpvOpRayQueryTerminateKHR:
case SpvOpRayQueryConfirmIntersectionKHR: {
if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
break;
}
case SpvOpRayQueryGenerateIntersectionKHR: {
if (auto error = ValidateRayQueryPointer(_, inst, 0)) return error;
const uint32_t hit_t_id = _.GetOperandTypeId(inst, 1);
if (!_.IsFloatScalarType(hit_t_id) || _.GetBitWidth(hit_t_id) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "Hit T must be a 32-bit float scalar";
}
break;
}
case SpvOpRayQueryGetIntersectionFrontFaceKHR:
case SpvOpRayQueryProceedKHR:
case SpvOpRayQueryGetIntersectionCandidateAABBOpaqueKHR: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (!_.IsBoolScalarType(result_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "expected Result Type to be bool scalar type";
}
if (opcode == SpvOpRayQueryGetIntersectionFrontFaceKHR) {
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
}
break;
}
case SpvOpRayQueryGetIntersectionTKHR:
case SpvOpRayQueryGetRayTMinKHR: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (!_.IsFloatScalarType(result_type) ||
_.GetBitWidth(result_type) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "expected Result Type to be 32-bit float scalar type";
}
if (opcode == SpvOpRayQueryGetIntersectionTKHR) {
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
}
break;
}
case SpvOpRayQueryGetIntersectionTypeKHR:
case SpvOpRayQueryGetIntersectionInstanceCustomIndexKHR:
case SpvOpRayQueryGetIntersectionInstanceIdKHR:
case SpvOpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
case SpvOpRayQueryGetIntersectionGeometryIndexKHR:
case SpvOpRayQueryGetIntersectionPrimitiveIndexKHR:
case SpvOpRayQueryGetRayFlagsKHR: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (!_.IsIntScalarType(result_type) || _.GetBitWidth(result_type) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "expected Result Type to be 32-bit int scalar type";
}
if (opcode != SpvOpRayQueryGetRayFlagsKHR) {
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
}
break;
}
case SpvOpRayQueryGetIntersectionObjectRayDirectionKHR:
case SpvOpRayQueryGetIntersectionObjectRayOriginKHR:
case SpvOpRayQueryGetWorldRayDirectionKHR:
case SpvOpRayQueryGetWorldRayOriginKHR: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (!_.IsFloatVectorType(result_type) ||
_.GetDimension(result_type) != 3 ||
_.GetBitWidth(result_type) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "expected Result Type to be 32-bit float 3-component "
"vector type";
}
if (opcode == SpvOpRayQueryGetIntersectionObjectRayDirectionKHR ||
opcode == SpvOpRayQueryGetIntersectionObjectRayOriginKHR) {
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
}
break;
}
case SpvOpRayQueryGetIntersectionBarycentricsKHR: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
if (!_.IsFloatVectorType(result_type) ||
_.GetDimension(result_type) != 2 ||
_.GetBitWidth(result_type) != 32) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "expected Result Type to be 32-bit float 2-component "
"vector type";
}
break;
}
case SpvOpRayQueryGetIntersectionObjectToWorldKHR:
case SpvOpRayQueryGetIntersectionWorldToObjectKHR: {
if (auto error = ValidateRayQueryPointer(_, inst, 2)) return error;
if (auto error = ValidateIntersectionId(_, inst, 3)) return error;
uint32_t num_rows = 0;
uint32_t num_cols = 0;
uint32_t col_type = 0;
uint32_t component_type = 0;
if (!_.GetMatrixTypeInfo(result_type, &num_rows, &num_cols, &col_type,
&component_type)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "expected matrix type as Result Type";
}
if (num_cols != 4) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "expected Result Type matrix to have a Column Count of 4";
}
if (!_.IsFloatScalarType(component_type) ||
_.GetBitWidth(result_type) != 32 || num_rows != 3) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< "expected Result Type matrix to have a Column Type of "
"3-component 32-bit float vectors";
}
break;
}
default:
break;
}
return SPV_SUCCESS;
}
} // namespace val
} // namespace spvtools

View File

@ -220,30 +220,23 @@ spv_result_t ValidateMemoryScope(ValidationState_t& _, const Instruction* inst,
// Vulkan Specific rules
if (spvIsVulkanEnv(_.context()->target_env)) {
if (value == SpvScopeCrossDevice) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(4638) << spvOpcodeString(opcode)
<< ": in Vulkan environment, Memory Scope cannot be CrossDevice";
}
// Vulkan 1.0 specific rules
if (_.context()->target_env == SPV_ENV_VULKAN_1_0 &&
value != SpvScopeDevice && value != SpvScopeWorkgroup &&
value != SpvScopeInvocation) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(4638) << spvOpcodeString(opcode)
<< ": in Vulkan 1.0 environment Memory Scope is limited to "
<< "Device, Workgroup and Invocation";
}
// Vulkan 1.1 specific rules
if ((_.context()->target_env == SPV_ENV_VULKAN_1_1 ||
_.context()->target_env == SPV_ENV_VULKAN_1_2) &&
value != SpvScopeDevice && value != SpvScopeWorkgroup &&
if (value != SpvScopeDevice && value != SpvScopeWorkgroup &&
value != SpvScopeSubgroup && value != SpvScopeInvocation &&
value != SpvScopeShaderCallKHR) {
value != SpvScopeShaderCallKHR && value != SpvScopeQueueFamily) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< _.VkErrorID(4638) << spvOpcodeString(opcode)
<< ": in Vulkan 1.1 and 1.2 environment Memory Scope is limited "
<< "to Device, Workgroup, Invocation, and ShaderCall";
<< ": in Vulkan environment Memory Scope is limited to Device, "
"QueueFamily, Workgroup, ShaderCallKHR, Subgroup, or "
"Invocation";
} else if (_.context()->target_env == SPV_ENV_VULKAN_1_0 &&
value == SpvScopeSubgroup &&
!_.HasCapability(SpvCapabilitySubgroupBallotKHR) &&
!_.HasCapability(SpvCapabilitySubgroupVoteKHR)) {
return _.diag(SPV_ERROR_INVALID_DATA, inst)
<< spvOpcodeString(opcode)
<< ": in Vulkan 1.0 environment Memory Scope is can not be "
"Subgroup without SubgroupBallotKHR or SubgroupVoteKHR "
"declared";
}
if (value == SpvScopeShaderCallKHR) {

View File

@ -306,35 +306,6 @@ spv_result_t ValidateTypeRuntimeArray(ValidationState_t& _,
return SPV_SUCCESS;
}
bool ContainsOpaqueType(ValidationState_t& _, const Instruction* str) {
const size_t elem_type_index = 1;
uint32_t elem_type_id;
Instruction* elem_type;
if (spvOpcodeIsBaseOpaqueType(str->opcode())) {
return true;
}
switch (str->opcode()) {
case SpvOpTypeArray:
case SpvOpTypeRuntimeArray:
elem_type_id = str->GetOperandAs<uint32_t>(elem_type_index);
elem_type = _.FindDef(elem_type_id);
return ContainsOpaqueType(_, elem_type);
case SpvOpTypeStruct:
for (size_t member_type_index = 1;
member_type_index < str->operands().size(); ++member_type_index) {
auto member_type_id = str->GetOperandAs<uint32_t>(member_type_index);
auto member_type = _.FindDef(member_type_id);
if (ContainsOpaqueType(_, member_type)) return true;
}
break;
default:
break;
}
return false;
}
spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) {
const uint32_t struct_id = inst->GetOperandAs<uint32_t>(0);
for (size_t member_type_index = 1;
@ -425,8 +396,21 @@ spv_result_t ValidateTypeStruct(ValidationState_t& _, const Instruction* inst) {
_.RegisterStructTypeWithBuiltInMember(struct_id);
}
const auto isOpaqueType = [&_](const Instruction* opaque_inst) {
auto opcode = opaque_inst->opcode();
if (_.HasCapability(SpvCapabilityBindlessTextureNV) &&
(opcode == SpvOpTypeImage || opcode == SpvOpTypeSampler ||
opcode == SpvOpTypeSampledImage)) {
return false;
} else if (spvOpcodeIsBaseOpaqueType(opcode)) {
return true;
}
return false;
};
if (spvIsVulkanEnv(_.context()->target_env) &&
!_.options()->before_hlsl_legalization && ContainsOpaqueType(_, inst)) {
!_.options()->before_hlsl_legalization &&
_.ContainsType(inst->id(), isOpaqueType)) {
return _.diag(SPV_ERROR_INVALID_ID, inst)
<< _.VkErrorID(4667) << "In "
<< spvLogStringForEnv(_.context()->target_env)

View File

@ -90,6 +90,8 @@ ModuleLayoutSection InstructionLayoutSection(
if (current_section == kLayoutFunctionDeclarations)
return kLayoutFunctionDeclarations;
return kLayoutFunctionDefinitions;
case SpvOpSamplerImageAddressingModeNV:
return kLayoutSamplerImageAddressMode;
default:
break;
}
@ -161,6 +163,7 @@ ValidationState_t::ValidationState_t(const spv_const_context ctx,
addressing_model_(SpvAddressingModelMax),
memory_model_(SpvMemoryModelMax),
pointer_size_and_alignment_(0),
sampler_image_addressing_mode_(0),
in_function_(false),
num_of_warnings_(0),
max_num_of_warnings_(max_warnings) {
@ -473,6 +476,15 @@ void ValidationState_t::set_memory_model(SpvMemoryModel mm) {
SpvMemoryModel ValidationState_t::memory_model() const { return memory_model_; }
void ValidationState_t::set_samplerimage_variable_address_mode(
uint32_t bit_width) {
sampler_image_addressing_mode_ = bit_width;
}
uint32_t ValidationState_t::samplerimage_variable_address_mode() const {
return sampler_image_addressing_mode_;
}
spv_result_t ValidationState_t::RegisterFunction(
uint32_t id, uint32_t ret_type_id, SpvFunctionControlMask function_control,
uint32_t function_type_id) {
@ -965,6 +977,11 @@ bool ValidationState_t::GetPointerTypeInfo(uint32_t id, uint32_t* data_type,
return true;
}
bool ValidationState_t::IsAccelerationStructureType(uint32_t id) const {
const Instruction* inst = FindDef(id);
return inst && inst->opcode() == SpvOpTypeAccelerationStructureKHR;
}
bool ValidationState_t::IsCooperativeMatrixType(uint32_t id) const {
const Instruction* inst = FindDef(id);
return inst && inst->opcode() == SpvOpTypeCooperativeMatrixNV;
@ -985,6 +1002,13 @@ bool ValidationState_t::IsUnsignedIntCooperativeMatrixType(uint32_t id) const {
return IsUnsignedIntScalarType(FindDef(id)->word(2));
}
// Either a 32-bit 2-component uint vector or a 64-bit uint scalar
bool ValidationState_t::IsUnsigned64BitHandle(uint32_t id) const {
return ((IsUnsignedIntScalarType(id) && GetBitWidth(id) == 64) ||
(IsUnsignedIntVectorType(id) && GetDimension(id) == 2 &&
GetBitWidth(id) == 32));
}
spv_result_t ValidationState_t::CooperativeMatrixShapesMatch(
const Instruction* inst, uint32_t m1, uint32_t m2) {
const auto m1_type = FindDef(m1);
@ -1951,6 +1975,8 @@ std::string ValidationState_t::VkErrorID(uint32_t id,
return VUID_WRAP(VUID-StandaloneSpirv-Uniform-06807);
case 6808:
return VUID_WRAP(VUID-StandaloneSpirv-PushConstant-06808);
case 6925:
return VUID_WRAP(VUID-StandaloneSpirv-Uniform-06925);
default:
return ""; // unknown id
}

View File

@ -44,19 +44,20 @@ namespace val {
/// of the SPIRV spec for additional details of the order. The enumerant values
/// are in the same order as the vector returned by GetModuleOrder
enum ModuleLayoutSection {
kLayoutCapabilities, /// < Section 2.4 #1
kLayoutExtensions, /// < Section 2.4 #2
kLayoutExtInstImport, /// < Section 2.4 #3
kLayoutMemoryModel, /// < Section 2.4 #4
kLayoutEntryPoint, /// < Section 2.4 #5
kLayoutExecutionMode, /// < Section 2.4 #6
kLayoutDebug1, /// < Section 2.4 #7 > 1
kLayoutDebug2, /// < Section 2.4 #7 > 2
kLayoutDebug3, /// < Section 2.4 #7 > 3
kLayoutAnnotations, /// < Section 2.4 #8
kLayoutTypes, /// < Section 2.4 #9
kLayoutFunctionDeclarations, /// < Section 2.4 #10
kLayoutFunctionDefinitions /// < Section 2.4 #11
kLayoutCapabilities, /// < Section 2.4 #1
kLayoutExtensions, /// < Section 2.4 #2
kLayoutExtInstImport, /// < Section 2.4 #3
kLayoutMemoryModel, /// < Section 2.4 #4
kLayoutSamplerImageAddressMode, /// < Section 2.4 #5
kLayoutEntryPoint, /// < Section 2.4 #6
kLayoutExecutionMode, /// < Section 2.4 #7
kLayoutDebug1, /// < Section 2.4 #8 > 1
kLayoutDebug2, /// < Section 2.4 #8 > 2
kLayoutDebug3, /// < Section 2.4 #8 > 3
kLayoutAnnotations, /// < Section 2.4 #9
kLayoutTypes, /// < Section 2.4 #10
kLayoutFunctionDeclarations, /// < Section 2.4 #11
kLayoutFunctionDefinitions /// < Section 2.4 #12
};
/// This class manages the state of the SPIR-V validation as it is being parsed.
@ -360,6 +361,20 @@ class ValidationState_t {
/// Returns the memory model of this module, or Simple if uninitialized.
SpvMemoryModel memory_model() const;
/// Sets the bit width for sampler/image type variables. If not set, they are
/// considered opaque.
void set_samplerimage_variable_address_mode(uint32_t bit_width);
/// Get the addressing mode currently set. If 0, the addressing mode is
/// invalid and sampler/image type variables must be considered opaque. This
/// mode is only valid after the instruction has been read.
uint32_t samplerimage_variable_address_mode() const;
/// Returns true if an OpSamplerImageAddressingModeNV instruction was found.
bool has_samplerimage_variable_address_mode_specified() const {
return sampler_image_addressing_mode_ != 0;
}
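// A sketch (not from the diff) of how a validation pass consumes this state,
// mirroring the bindless alignment logic in validate_decorations.cpp above;
// BindlessHandleBytes is a hypothetical name. With a 64-bit addressing mode a
// bindless sampler/image member occupies an 8-byte handle (64 / 8 == 8), with
// a 32-bit mode, 4 bytes.
inline uint32_t BindlessHandleBytes(const ValidationState_t& vstate) {
  return vstate.samplerimage_variable_address_mode() / 8;
}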
const AssemblyGrammar& grammar() const { return grammar_; }
/// Inserts the instruction into the list of ordered instructions in the file.
@ -592,10 +607,12 @@ class ValidationState_t {
bool IsBoolVectorType(uint32_t id) const;
bool IsBoolScalarOrVectorType(uint32_t id) const;
bool IsPointerType(uint32_t id) const;
bool IsAccelerationStructureType(uint32_t id) const;
bool IsCooperativeMatrixType(uint32_t id) const;
bool IsFloatCooperativeMatrixType(uint32_t id) const;
bool IsIntCooperativeMatrixType(uint32_t id) const;
bool IsUnsignedIntCooperativeMatrixType(uint32_t id) const;
bool IsUnsigned64BitHandle(uint32_t id) const;
// Returns true if |id| is a type id that contains |type| (or integer or
// floating point type) of |width| bits.
@ -862,7 +879,10 @@ class ValidationState_t {
// have the same pointer size (for physical pointer types).
uint32_t pointer_size_and_alignment_;
/// NOTE: See corresponding getter functions
/// bit width of sampler/image type variables. Valid values are 32 and 64
uint32_t sampler_image_addressing_mode_;
/// NOTE: See corresponding getter functions
bool in_function_;
/// The state of optional features. These are determined by capabilities

View File

@ -127,6 +127,7 @@ project "spirv-opt"
path.join(SPIRV_TOOLS, "source/val/validate_mode_setting.cpp"),
path.join(SPIRV_TOOLS, "source/val/validate_non_uniform.cpp"),
path.join(SPIRV_TOOLS, "source/val/validate_primitives.cpp"),
path.join(SPIRV_TOOLS, "source/val/validate_ray_query.cpp"),
path.join(SPIRV_TOOLS, "source/val/validate_scopes.cpp"),
path.join(SPIRV_TOOLS, "source/val/validate_small_type_uses.cpp"),
path.join(SPIRV_TOOLS, "source/val/validate_type.cpp"),