// Copyright (c) 2018 Google LLC.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
#include "source/opt/licm_pass.h"
|
|
|
|
#include <queue>
|
|
#include <utility>
|
|
|
|
#include "source/opt/module.h"
|
|
#include "source/opt/pass.h"
|
|
|
|
namespace spvtools {
|
|
namespace opt {
|
|
|
|
// Pass entry point. All of the work is delegated to ProcessIRContext(),
// which walks every function of the current module.
Pass::Status LICMPass::Process() {
  return ProcessIRContext();
}
|
|
|
|
// Runs LICM over every function of the current module.
//
// Functions are visited in the module's iteration order. Per-function results
// are merged with CombineStatus(). Once the accumulated status is
// Status::Failure, no further functions are processed.
Pass::Status LICMPass::ProcessIRContext() {
  Status result = Status::SuccessWithoutChange;
  for (Function& function : *get_module()) {
    if (result == Status::Failure) {
      break;
    }
    result = CombineStatus(result, ProcessFunction(&function));
  }
  return result;
}
|
|
|
|
// Runs LICM over the loops of |f|.
//
// Only loops that are not nested are handed to ProcessLoop() from here.
// Nested loops are reached through ProcessLoop()'s own recursion into the
// children of each loop. The walk stops once the accumulated status is
// Status::Failure.
Pass::Status LICMPass::ProcessFunction(Function* f) {
  Status result = Status::SuccessWithoutChange;
  LoopDescriptor& loops = *context()->GetLoopDescriptor(f);
  for (Loop& loop : loops) {
    if (result == Status::Failure) {
      break;
    }
    if (!loop.IsNested()) {
      result = CombineStatus(result, ProcessLoop(&loop, f));
    }
  }
  return result;
}
|
|
|
|
// Hoists loop-invariant instructions out of |loop| after first recursing into
// every loop nested inside it.
//
// Nested loops are processed before |loop| itself. Presumably this lets code
// hoisted into an inner loop's pre-header be considered again at this level
// (TODO confirm the pre-header is always owned by |loop|).
//
// The blocks of |loop| are visited starting at the header. For each visited
// block, AnalyseAndHoistFromBB() appends the block's dominator-tree children
// that lie inside |loop| to a worklist. The worklist is walked until
// exhausted or until a failure is reported.
//
// Returns Status::Failure as soon as any nested loop or block fails to hoist.
// Otherwise returns the CombineStatus() of all per-block results.
Pass::Status LICMPass::ProcessLoop(Loop* loop, Function* f) {
  Status status = Status::SuccessWithoutChange;

  // Process all nested loops first.
  for (auto nl = loop->begin(); nl != loop->end() && status != Status::Failure;
       ++nl) {
    Loop* nested_loop = *nl;
    status = CombineStatus(status, ProcessLoop(nested_loop, f));
  }

  // A nested loop failed. Stop instead of continuing to rewrite the IR of this
  // loop. This matches the worklist loop below and the early exits in
  // ProcessFunction() and ProcessIRContext().
  if (status == Status::Failure) {
    return status;
  }

  // Worklist of blocks to visit. AnalyseAndHoistFromBB() appends to it while
  // it is being walked, so it is traversed by index: push_back may reallocate
  // and would invalidate iterators or references into the vector.
  std::vector<BasicBlock*> loop_bbs{};
  status = CombineStatus(
      status,
      AnalyseAndHoistFromBB(loop, f, loop->GetHeaderBlock(), &loop_bbs));

  for (size_t i = 0; i < loop_bbs.size() && status != Status::Failure; ++i) {
    // Copy the pointer out before the call, which may grow |loop_bbs|.
    BasicBlock* bb = loop_bbs[i];
    status =
        CombineStatus(status, AnalyseAndHoistFromBB(loop, f, bb, &loop_bbs));
  }

  return status;
}
|
|
|
|
// Visits |bb| during the walk over |loop| and queues follow-up blocks.
//
// If the loop descriptor of |f| maps |bb| to |loop| itself, every instruction
// of |bb| that |loop| says should be hoisted is moved into the loop's
// pre-header via HoistInstruction(). Blocks mapped to another (nested) loop
// are not hoisted from here.
//
// Separately, each dominator-tree child of |bb| that lies inside |loop| is
// appended to |loop_bbs| for the caller to visit later.
//
// Returns Status::Failure if an instruction could not be hoisted.
// Returns Status::SuccessWithChange if at least one instruction was moved.
// Otherwise returns Status::SuccessWithoutChange.
Pass::Status LICMPass::AnalyseAndHoistFromBB(
    Loop* loop, Function* f, BasicBlock* bb,
    std::vector<BasicBlock*>* loop_bbs) {
  bool hoisted_any = false;

  if (IsImmediatelyContainedInLoop(loop, f, bb)) {
    // The callback returns false only when HoistInstruction() fails. A false
    // result from WhileEachInst() is reported as Status::Failure.
    const bool completed = bb->WhileEachInst(
        [this, loop, &hoisted_any](Instruction* inst) {
          if (!loop->ShouldHoistInstruction(this->context(), inst)) {
            return true;
          }
          if (!HoistInstruction(loop, inst)) {
            return false;
          }
          hoisted_any = true;
          return true;
        },
        false);
    if (!completed) {
      return Status::Failure;
    }
  }

  DominatorTree& dom_tree = context()->GetDominatorAnalysis(f)->GetDomTree();
  for (DominatorTreeNode* child : *dom_tree.GetTreeNode(bb)) {
    BasicBlock* child_bb = child->bb_;
    if (loop->IsInsideLoop(child_bb)) {
      loop_bbs->push_back(child_bb);
    }
  }

  return hoisted_any ? Status::SuccessWithChange : Status::SuccessWithoutChange;
}
|
|
|
|
// Returns true if the loop descriptor of |f| maps |bb| to exactly |loop|.
// Presumably the descriptor maps a block to the innermost loop containing it,
// so a block owned by a loop nested inside |loop| yields false (TODO confirm).
bool LICMPass::IsImmediatelyContainedInLoop(Loop* loop, Function* f,
                                            BasicBlock* bb) {
  Loop* owner = (*context()->GetLoopDescriptor(f))[bb->id()];
  return owner == loop;
}
|
|
|
|
// Moves |inst| into the pre-header of |loop|, obtained from
// Loop::GetOrCreatePreHeaderBlock().
//
// Returns false, without moving |inst|, if no pre-header is returned.
// Otherwise returns true.
//
// |inst| is inserted before the pre-header's last instruction. If an
// OpLoopMerge or OpSelectionMerge immediately precedes that last instruction,
// |inst| goes before the merge instead. SPIR-V requires a merge instruction to
// sit directly before the block's branch. The instruction-to-block mapping is
// updated to point at the pre-header.
bool LICMPass::HoistInstruction(Loop* loop, Instruction* inst) {
  // TODO(1841): Handle failure to create pre-header.
  BasicBlock* pre_header = loop->GetOrCreatePreHeaderBlock();
  if (pre_header == nullptr) return false;

  Instruction* insert_before = &*pre_header->tail();
  if (Instruction* prev = insert_before->PreviousNode()) {
    const auto op = prev->opcode();
    if (op == SpvOpLoopMerge || op == SpvOpSelectionMerge) {
      insert_before = prev;
    }
  }

  inst->InsertBefore(insert_before);
  context()->set_instr_block(inst, pre_header);
  return true;
}
|
|
|
|
} // namespace opt
|
|
} // namespace spvtools
|