今回が最後。
モデルの実行部分。
def test_resnet():
# Load the model
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)),
'resnet18')
classes = 1000
device = 'cpu'
model = DLRModel(model_path, device)
# Run the model
image = np.load(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'dog.npy')).astype(np.float32)
#flatten within a input array
input_data = {'data': image}
probabilities = model.run(input_data) #need to be a list of input arrays matching input names
assert probabilities[0].argmax() == 111
の
probabilities = model.run(input_data) #need to be a list of input arrays matching input namesの run メソッド。
def run(self, input_values):
out = []
# set input(s)
if isinstance(input_values, (np.ndarray, np.generic)):
# Treelite model or single input tvm/treelite model.
# Treelite has a dummy input name 'data'.
if self.input_names:
self._set_input(self.input_names[0], input_values)
elif isinstance(input_values, dict):
# TVM model
for key, value in input_values.items():
if self.input_names and key not in self.input_names:
raise ValueError("%s is not a valid input name." % key)
self._set_input(key, value)
else:
raise ValueError("input_values must be of type dict (tvm model) " +
"or a np.ndarray/generic (representing treelite models)")
# run model
self._run()
# get output
for i in range(self.num_outputs):
ith_out = self._get_output(i)
out.append(ith_out)
return out
で、_run メソッドを呼んでいます。
def _run(self):
"""A light wrapper to call run in the DLR backend."""
self._check_call(self.lib.RunDLRModel(byref(self.handle)))
RunDLRModelメソッドを呼んでいますね。
extern "C" int RunDLRModel(DLRModelHandle *handle) {
API_BEGIN();
static_cast<DLRModel *>(*handle)->Run();
API_END();
}
DLRModelクラス の Runメソッドを呼んでいます。
void DLRModel::Run() {
if (backend_ == DLRBackend::kTVM) {
// get the function from the module(run it)
tvm::runtime::PackedFunc run = tvm_module_->GetFunction("run");
run();
} else if (backend_ == DLRBackend::kTREELITE) {
// NOTE: Assume batch size is 1. However, Treelite internally can support
// arbitrary batch size
size_t out_result_size;
CHECK_EQ(TreelitePredictorPredictInst(treelite_model_, treelite_input_.get(),
0, treelite_output_.get(),
&out_result_size), 0)
<< TreeliteGetLastError();
}
}
にて、TVMモデルの時は、GetFunctionメソッドにてrunメソッド呼んでいますね。
} else if (name == "run") {
return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
this->Run();
});
}
最後は、GraphRuntime::Runメソッドを実行。
void GraphRuntime::Run() {
// setup the array and requirements.
for (size_t i = 0; i < op_execs_.size(); ++i) {
if (op_execs_[i]) op_execs_[i]();
}
}
ひとつづつ、オペを実行しているだけですね。