TensorFlow Lite の delegate に、eager が追加されたのは、r1.11 の時、TensorFlowのEagerモードのコードを動かすための仕組みだと思っていた。
先週からのブログでアップしていたのが、その部分のコード。
ここにきて、TensorFlowのOpをTensorFlow Lite内で動かす仕組みだということが分かった。
ここにきて、TensorFlowのOpをTensorFlow Lite内で動かす仕組みだということが分かった。
最初のコードにもちゃんと書いてあった。
// Note: this is part of TF Lite's Eager delegation code which is to be // completed soon. // This is the TF Lite op that is created by the eager delegate to handle // execution of a supported subgraph. The usual flow is that the delegate // informs the interpreter of supported nodes in a graph, and each supported // subgraph is replaced with one instance of this kernel. // // The kernel is initialized with TfLiteDelegateParams from which we retrieve // the global EagerContext and BufferMap, as well as a list of inputs and // outputs to the subgraph. Those are used to build the OpData, with a list of // TensorFlow Ops that should be executed in order (which we call an OpNode). // // For each node included in the subgraph, we query the interpreter and // retrieve the associated NodeDef, which is then used to configure the // corresponding TensorFlow/Eager Op.
そして、Opを実行するコード(r1.11)
tensorflow::Status ExecuteEagerOp(tensorflow::EagerContext* eager_context,
BufferMap* buffer_map, const string& op_name,
const tensorflow::NodeDef& nodedef,
const std::vector<int>& inputs,
const std::vector<int>& outputs) {
const tensorflow::AttrTypeMap* attr_types;
TF_RETURN_IF_ERROR(
tensorflow::AttrTypeMapForOp(op_name.c_str(), &attr_types));
第一引数が tensorflow::EagerContext* ということがポイント。このメソッドは、TensorFlow の Eagerモードで動く。
tensorflow::EagerOperation op(eager_context, op_name.c_str(), attr_types);
ほい、来たよ。op が tensorflow::EagerOperation だよ。
for (const auto& attr : nodedef.attr()) {
op.MutableAttrs()->Set(attr.first, attr.second);
}
ノードの属性を op に追加。
for (int input_index : inputs) {
if (!buffer_map->HasTensor(input_index)) {
return tensorflow::errors::Internal("Invalid tensor index ", input_index);
}
auto* handle = new tensorflow::TensorHandle(
buffer_map->GetTensor(input_index), nullptr, nullptr, nullptr);
op.AddInput(handle);
handle->Unref();
}
int num_retvals = outputs.size();
tensorflow::gtl::InlinedVector<tensorflow::TensorHandle*, 2> retvals(
num_retvals, nullptr);
入力および出力を設定。
TF_RETURN_IF_ERROR(EagerExecute(&op, &retvals, &num_retvals));
op を実行。
if (outputs.size() != num_retvals) {
return tensorflow::errors::Internal(
"Unexpected number of outputs from EagerExecute");
}
for (int i = 0; i < num_retvals; ++i) {
const tensorflow::Tensor* tensor = nullptr;
TF_RETURN_IF_ERROR(retvals[i]->Tensor(&tensor));
buffer_map->SetFromTensorFlow(outputs[i], *tensor);
retvals[i]->Unref();
}
return tensorflow::Status::OK();
}
TensorFlowで実行した出力値を獲得して、出力に代入。
Status EagerExecute(EagerOperation* op,
gtl::InlinedVector<TensorHandle*, 2>* retvals,
int* num_retvals) {
TF_RETURN_IF_ERROR(MaybeUpdateOpDevice(op));
bool op_is_local = IsLocal(op->EagerContext(), op->Device());
if (op_is_local) {
return EagerLocalExecute(op, retvals, num_retvals);
}
if (op->EagerContext()->LogDevicePlacement()) {
LOG(INFO) << "Executing op " << op->Name() << " in device "
<< op->Device()->name();
}
return EagerRemoteExecute(op, retvals->data(), num_retvals);
}
これ、TensorFlow のメソッドだよ。。。