実装は、src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.cppの HalideSubgraphExtraction::run_on_function メソッドです。
引用 bool runtime::cpu::pass::HalideSubgraphExtraction::run_on_function( std::shared_ptr<ngraph::Function> function) { list<shared_ptr<Node>> worklist; auto results = function->get_results(); // Artificial limitation if (results.size() > 1) { return false; } if (function->get_result()->get_element_type() != element::f32) { return false; } for (const auto& result : results) { worklist.emplace_back(result); }
ここまでで、出力、つまり、戻り値の数が1であることを確認後、戻り値の方が f32 であることも確認。最後に、
戻り値を worklist に保存。
戻り値を worklist に保存。
unordered_set<shared_ptr<Node>> ops; list<shared_ptr<Node>> ordered_ops; while (!worklist.empty()) { const auto& node = worklist.front(); if (!halide::skiplist.count(TI(*node))) { if (halide::whitelist.count(TI(*node))) { ops.emplace(node); ordered_ops.emplace_back(node); } else { break; } } const auto& args = node->get_arguments(); for (const auto& arg : args) { worklist.emplace_back(arg); } worklist.pop_front(); }
NodeVector liveins;
for (const auto& op : ops)
{
const auto& args = op->get_arguments();
for (const auto& arg : args)
{
if (!ops.count(arg))
{
liveins.emplace_back(arg);
}
}
}
ordered_ops.reverse();
if (ordered_ops.size() > 1)
{
auto subgraph = make_shared<cpu::op::HalideOp>(liveins,
ordered_ops,
function->get_result()->get_element_type(),
function->get_result()->get_shape());
replace_node(function->get_result()->get_argument(0), subgraph);
return true;
}
else
{
return false;
}
}
ops と ordered_ops から cpu::op::HalideOp を使って、subgraph を生成します。
作った、replace_node を使って、生成した subgraph に置き換えます。
作った、replace_node を使って、生成した subgraph に置き換えます。
replace_node は、/src/ngraph/runtime/cpu/pass/halide_subgraph_extraction.cppで次のように定義されています。
void Function::replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl)
{
ngraph::replace_node(old, repl);
}
nraph::replace_node を呼んでいます。
void ngraph::replace_node(std::shared_ptr<Node> target, std::shared_ptr<Node> replacement) { if (target->is_output()) { throw ngraph_error("Result nodes cannot be replaced."); } if (target->get_users().empty()) { throw ngraph_error("replacing an unreachable node"); } // Fix input/output descriptors assert(target->get_outputs().size() == replacement->get_outputs().size());
置き換える側のノードをチェックをしてから下記のように置き換えていきます。
// For each of target's output O with replacement output O_rep:
// For each O's connected downstream input I:
// Change I's connected upstream output to O_rep
for (size_t i = 0; i < target->get_outputs().size(); i++)
{
auto& target_output = target->get_outputs().at(i);
std::set<ngraph::descriptor::Input*> copy_inputs{begin(target_output.get_inputs()),
end(target_output.get_inputs())};
for (auto input : copy_inputs)
{
input->replace_output(replacement->get_outputs().at(i));
}
}
}
test/halide.cppの下記のコードを見てみましょう。
TEST(halide, halide_subgraph)
{
Shape shape{8};
auto A = make_shared<op::Parameter>(element::f32, shape);
auto B = make_shared<op::Parameter>(element::f32, shape);
auto C = make_shared<op::Parameter>(element::f32, shape);
auto D = make_shared<op::Parameter>(element::f32, shape);
auto relu = make_shared<op::Relu>((A + B) * C);
auto f = make_shared<Function>(relu + D, ParameterVector{A, B, C, D});
auto backend = runtime::Backend::create("CPU");
shared_ptr<runtime::Tensor> a = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> b = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> c = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> d = backend->create_tensor(element::f32, shape);
shared_ptr<runtime::Tensor> result = backend->create_tensor(element::f32, shape);
vector<float> data{-1, 4, -2, 5, 1, 5, 7, 9};
copy_data(a, data);
copy_data(b, data);
copy_data(c, data);
copy_data(d, data);
vector<float> expected{1, 36, 6, 55, 3, 55, 105, 171};
backend->call_with_validate(backend->compile(f), {result}, {a, b, c, d});
EXPECT_TRUE(test::all_close(read_vector<float>(result), expected, 1.0e-4f, 1.0e-4f));
}
auto relu = make_shared<op::Relu>((A + B) * C);
と
auto f = make_shared<Function>(relu + D, ParameterVector{A, B, C, D});
の
relu + D
のどこかが HalideのSubgraph として構築されるのでしょうね。明日は、cpu::op::HalideOp を見ていきます。