xla_primitive_callableメソッドの
def xla_primitive_callable(prim, *abstract_args, **kwargs):
  """Build and compile the XLA computation for ``prim`` and wrap it for calling.

  Args:
    prim: the primitive whose computation is built.
    *abstract_args: abstract argument values; each one is converted with
      ``xla_shape``.
    **kwargs: forwarded unchanged to ``primitive_computation``.

  Returns:
    A ``partial`` of ``execute_compiled_primitive`` already bound to the
    compiled computation and the result handler; calling it with concrete
    arguments executes the primitive.
  """
  # The shapes are consumed twice below (unpacked into primitive_computation
  # and passed whole to Compile). A one-shot iterator, which is what the
  # builtin ``map`` returns on Python 3, would be empty on the second use,
  # so build a real list up front.
  shapes = [xla_shape(abstract_arg) for abstract_arg in abstract_args]
  built_c = primitive_computation(prim, *shapes, **kwargs)
  result_shape = xla_shape_to_result_shape(built_c.GetReturnValueShape())
  handle_result = result_handler(result_shape)
  compiled = built_c.Compile(shapes, xb.get_compile_options())
  return partial(execute_compiled_primitive, compiled, handle_result)
最後に、execute_compiled_primitiveメソッドなるものがあります。
def execute_compiled_primitive(compiled, result_handler, *args):
  """Execute a compiled primitive on ``args`` and post-process the output.

  Args:
    compiled: object exposing ``Execute(buffers, flag)``.
    result_handler: callable applied to whatever ``Execute`` returns.
    *args: concrete argument values; each is passed through
      ``canonicalize_pyval_dtype`` and then ``device_put``.

  Returns:
    The value produced by ``result_handler``.
  """
  input_bufs = []
  for arg in args:
    input_bufs.append(device_put(canonicalize_pyval_dtype(arg)))
  # The second argument to Execute is the negation of core.skip_checks.
  raw_result = compiled.Execute(input_bufs, not core.skip_checks)
  return result_handler(raw_result)
compiled.Executeメソッドは、local_computation_builder.ccにて、次のように定義されています。
// Runs `executable_` with the given argument buffers on the device that
// replica 0 maps to, and returns the result wrapped in a heap-allocated
// LocalShapedBuffer.
//
// Any failure (either resolving the device ordinal or running the
// executable) is reported as an InternalError carrying the underlying
// status text.
//
// NOTE(review): the returned pointer comes from a bare `new`; presumably the
// caller takes ownership -- confirm against the call sites.
StatusOr<LocalShapedBuffer*> CompiledLocalComputation::Execute(
    absl::Span<LocalShapedBuffer* const> argument_handles) {
  TF_ASSIGN_OR_RETURN(LocalClient * client, GetOrCreateLocalClient());

  StatusOr<int> ordinal_or = client->ReplicaNumberToDeviceOrdinal(0);
  StatusOr<ScopedShapedBuffer> run_result;
  if (ordinal_or.ok()) {
    const int device_ordinal = ordinal_or.ValueOrDie();
    VLOG(3) << "Replica 0 mapped to device ordinal for execution: "
            << device_ordinal;

    // Collect the underlying shaped buffers of every argument handle.
    std::vector<const ShapedBuffer*> inputs;
    inputs.reserve(argument_handles.size());
    for (auto& arg_handle : argument_handles) {
      inputs.push_back(arg_handle->shaped_buffer());
    }

    // NOTE(review): ConsumeValueOrDie presumably aborts if AssignDevices
    // fails rather than returning an error -- confirm this is intended.
    DeviceAssignment assignment =
        client->backend()
            .computation_placer()
            ->AssignDevices(1, /*computation_count=*/1)
            .ConsumeValueOrDie();

    ExecutableRunOptions run_options;
    run_options.set_device_ordinal(device_ordinal);
    run_options.set_allocator(client->backend().memory_allocator());
    run_options.set_intra_op_thread_pool(
        client->backend().eigen_intra_op_thread_pool_device());
    run_options.set_device_assignment(&assignment);

    run_result = executable_->Run(inputs, run_options);
  } else {
    // No device for replica 0: carry that status into the common error path.
    run_result = ordinal_or.status();
  }

  if (run_result.ok()) {
    return new LocalShapedBuffer(std::move(run_result).ValueOrDie());
  }
  return InternalError(
      "Failed running replica 0 (other replicas may have failed as well): "
      "%s.",
      run_result.status().ToString());
}
executable_->Run にて、実行しています。
xla_primitive_callableメソッドは、apply_primitiveメソッドから呼ばれています。
def apply_primitive(prim, *args, **kwargs):
  """Compile ``prim`` for the given arguments and run it on them.

  Args:
    prim: the primitive to apply.
    *args: concrete argument values; each is passed through ``abstractify``
      to obtain what ``xla_primitive_callable`` is called with.
    **kwargs: forwarded unchanged to ``xla_primitive_callable``.

  Returns:
    Whatever the callable returned by ``xla_primitive_callable`` produces
    when invoked with ``args``.
  """
  runner = xla_primitive_callable(prim, *map(abstractify, args), **kwargs)
  return runner(*args)
standard_primitiveメソッド、どうやらこれが基本的なものらしいですね。
def standard_primitive(shape_rule, dtype_rule, name, translation_rule=None):
  """Create a ``Primitive`` called ``name`` and register its standard rules.

  Args:
    shape_rule: bound as the first argument of ``standard_abstract_eval``.
    dtype_rule: bound as the second argument of ``standard_abstract_eval``.
    name: name given to the new ``Primitive``; also bound into
      ``standard_translate`` when no translation rule is supplied.
    translation_rule: optional entry for ``xla.translations``; when falsy,
      ``partial(standard_translate, name)`` is registered instead.

  Returns:
    The newly created primitive.
  """
  primitive = Primitive(name)
  # Implementation rule: xla.apply_primitive with this primitive bound.
  primitive.def_impl(partial(xla.apply_primitive, primitive))
  primitive.def_abstract_eval(
      partial(standard_abstract_eval, shape_rule, dtype_rule))
  if not translation_rule:
    translation_rule = partial(standard_translate, name)
  xla.translations[primitive] = translation_rule
  return primitive