/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file graph_executor.cc
 */
#include "graph_executor.h"

#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../file_utils.h"
#include "../texture.h"

namespace tvm {
namespace runtime {
namespace details {
inline size_t GetDataAlignment(const DLTensor& arr) {
  size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
  if (align < kAllocAlignment) return kAllocAlignment;
  return align;
}
constexpr auto Is2DStorage = IsTextureStorage;
}  // namespace details

/*!
 * \brief Run all the operations one by one.
 */
void GraphExecutor::Run() {
  // Execute the pre-built op closures in topological order. Entries are
  // unset (empty std::function) for "null" nodes such as inputs/parameters.
  for (size_t i = 0; i < op_execs_.size(); ++i) {
    if (op_execs_[i]) op_execs_[i]();
  }
}

/*!
 * \brief Initialize the graph executor with graph and device.
 * \param graph_json The execution graph.
 * \param module The module containing the compiled functions for the host
 *  processor.
 * \param devs The devices on which to execute the graph: the first entry is
 *  the host (fallback) device, followed by any additional devices that graph
 *  nodes are assigned to.
 * \param lookup_linked_param_func Linked parameter lookup function. Default is nullptr.
 */
void GraphExecutor::Init(const std::string& graph_json, tvm::runtime::Module module,
                         const std::vector<Device>& devs,
                         const PackedFunc lookup_linked_param_func) {
  std::istringstream is(graph_json);
  dmlc::JSONReader reader(&is);
  this->Load(&reader);
  module_ = module;
  devices_ = devs;
  lookup_linked_param_ = lookup_linked_param_func;
  if (lookup_linked_param_ == nullptr) {
    lookup_linked_param_ = PackedFunc(
        [this](TVMArgs args, TVMRetValue* rv) { this->DefaultLookupLinkedParam(args, rv); });
  }
  this->SetupStorage();
  this->SetupOpExecs();
  for (size_t i = 0; i < input_nodes_.size(); i++) {
    const uint32_t nid = input_nodes_[i];
    std::string& name = nodes_[nid].name;
    input_map_[name] = i;
  }
  for (size_t i = 0; i < outputs_.size(); i++) {
    const uint32_t nid = outputs_[i].node_id;
    std::string& name = nodes_[nid].name;
    std::stringstream ss;
    ss << name << ":" << i;
    output_map_[ss.str()] = i;
  }
}
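
// Usage note (a minimal sketch, not part of this file's API surface): Init is
// normally reached through the "tvm.graph_executor.create" global registered
// at the bottom of this file, e.g. from C++:
//
//   tvm::runtime::Module mod = tvm::runtime::Module::LoadFromFile("net.so");
//   const auto* create = tvm::runtime::Registry::Get("tvm.graph_executor.create");
//   tvm::runtime::Module gexec =
//       (*create)(graph_json, mod, static_cast<int>(kDLCPU), 0);
//
// Here "net.so" and graph_json are placeholders for a compiled module and its
// serialized execution graph.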

/*!
 * \brief Get the input index given the name of the input.
 * \param name The name of the input.
 * \return The index of the input, or -1 if not found.
 */
int GraphExecutor::GetInputIndex(const std::string& name) {
  auto it = input_map_.find(name);
  if (it != input_map_.end()) {
    return it->second;
  }
  return -1;
}

/*!
 * \brief Get the input info of the graph by parsing the input nodes.
 * \return The shape and dtype tuple.
 */
std::tuple<GraphExecutor::ShapeInfo, GraphExecutor::DtypeInfo> GraphExecutor::GetInputInfo() const {
  GraphExecutor::ShapeInfo shape_dict;
  GraphExecutor::DtypeInfo dtype_dict;
  for (uint32_t nid : input_nodes_) {
    CHECK_LE(nid, nodes_.size());
    std::string name = nodes_[nid].name;
    if (param_names_.find(name) == param_names_.end()) {
      CHECK_LE(nid, attrs_.shape.size());
      auto shape = attrs_.shape[nid];
      shape_dict.Set(name, ShapeTuple(shape));
      CHECK_LE(nid, attrs_.dltype.size());
      auto dtype = attrs_.dltype[nid];
      dtype_dict.Set(name, String(dtype));
    }
  }
  return std::make_tuple(shape_dict, dtype_dict);
}

/*!
 * \brief Get the output info of the graph by parsing the output nodes.
 * \return The shape and dtype tuple.
 */
std::tuple<GraphExecutor::ShapeInfo, GraphExecutor::DtypeInfo> GraphExecutor::GetOutputInfo()
    const {
  GraphExecutor::ShapeInfo shape_dict;
  GraphExecutor::DtypeInfo dtype_dict;
  for (auto out : outputs_) {
    uint32_t nid = out.node_id;
    CHECK_LE(nid, nodes_.size());
    std::string name = nodes_[nid].name;
    CHECK_LE(nid, attrs_.shape.size());
    auto shape = attrs_.shape[nid];
    shape_dict.Set(name, ShapeTuple(shape));
    CHECK_LE(nid, attrs_.dltype.size());
    auto dtype = attrs_.dltype[nid];
    dtype_dict.Set(name, String(dtype));
  }
  return std::make_tuple(shape_dict, dtype_dict);
}

/*!
 * \brief Get the output index given the name of the output.
 * \param name The name of the output.
 * \return The index of the output, or -1 if not found.
 */
int GraphExecutor::GetOutputIndex(const std::string& name) {
  auto it = output_map_.find(name);
  if (it != output_map_.end()) {
    return it->second;
  }
  return -1;
}
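
// Note: output_map_ is keyed as "<node name>:<output index>" (see Init above),
// so lookups take the qualified form, e.g. GetOutputIndex("out:0") where "out"
// is a hypothetical output node name.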
/*!
 * \brief Set the index-th input to the graph.
 * \param index The input index.
 * \param data_in The input data.
 */
void GraphExecutor::SetInput(int index, DLTensor* data_in) {
  ICHECK_LT(static_cast<size_t>(index), input_nodes_.size());
  uint32_t eid = this->entry_id(input_nodes_[index], 0);
  data_entry_[eid].CopyFrom(data_in);
}
/*!
 * \brief Check the legality of an external DLTensor*.
 * \param external The external DLTensor*.
 * \param eid The data_entry_ index.
 */
void GraphExecutor::CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const {
  const DLTensor* internal = data_entry_[eid].operator->();

  ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*external));
  ICHECK_EQ(reinterpret_cast<size_t>(static_cast<char*>(external->data) + external->byte_offset) %
                kAllocAlignment,
            0);
  ICHECK_EQ(internal->ndim, static_cast<size_t>(external->ndim));
  ICHECK_EQ(internal->device.device_type, external->device.device_type);
  ICHECK_EQ(internal->device.device_id, external->device.device_id);
  for (auto i = 0; i < external->ndim; ++i) {
    ICHECK_EQ(internal->shape[i], external->shape[i]);
  }
}
/*!
 * \brief Set the index-th input to the graph without copying the data.
 * \param index The input index.
 * \param data_ref The input data that is referred.
 */
void GraphExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) {
  ICHECK_LT(static_cast<size_t>(index), input_nodes_.size());
  uint32_t eid = this->entry_id(input_nodes_[index], 0);
  // check the consistency of input
  CheckExternalDLTensor(data_ref, eid);
  // Update the data pointer for each argument of each op
  for (DLTensor* t : input_dltensors_[eid]) {
    t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
  }
}
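
// Zero-copy caveat: after SetInputZeroCopy, subsequent Run() calls read
// directly from the caller-owned buffer, so that buffer must stay alive and
// keep its address until it is rebound or the executor is no longer used.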
/*!
 * \brief Set the index-th output to the graph without copying the data.
 * \param index The output index.
 * \param data_ref The output data that is referred.
 */
void GraphExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) {
  ICHECK_LT(static_cast<size_t>(index), outputs_.size());
  ICHECK_LT(static_cast<size_t>(index), output_dltensors_.size());
  const NodeEntry& output_node = outputs_[index];
  uint32_t output_node_eid = this->entry_id(output_node);

  // check the consistency of output
  CheckExternalDLTensor(data_ref, output_node_eid);

  // Update the data pointer for the op producing this output
  for (DLTensor* t : output_dltensors_[output_node_eid]) {
    t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
  }

  // Update the input of any op that also consumes this output entry
  for (DLTensor* t : both_output_opinput_dltensors_[output_node_eid]) {
    t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
  }
}
/*!
 * \brief Get the number of outputs.
 *
 * \return The number of outputs from graph.
 */
int GraphExecutor::NumOutputs() const { return outputs_.size(); }
/*!
 * \brief Get the number of inputs.
 *
 * \return The number of inputs to the graph.
 */
int GraphExecutor::NumInputs() const { return input_nodes_.size(); }
/*!
 * \brief Return NDArray for given input index.
 * \param index The input index.
 *
 * \return NDArray corresponding to given input node index.
 */
NDArray GraphExecutor::GetInput(int index) const {
  ICHECK_LT(static_cast<size_t>(index), input_nodes_.size());
  uint32_t eid = this->entry_id(input_nodes_[index], 0);
  return data_entry_[eid];
}
/*!
 * \brief Return NDArray for given output index.
 * \param index The output index.
 *
 * \return NDArray corresponding to given output node index.
 */
NDArray GraphExecutor::GetOutput(int index) const {
  ICHECK_LT(static_cast<size_t>(index), outputs_.size());
  uint32_t eid = this->entry_id(outputs_[index]);
  return data_entry_[eid];
}
/*!
 * \brief Copy the index-th output to data_out.
 * \param index The output index.
 * \param data_out The output data.
 */
void GraphExecutor::CopyOutputTo(int index, DLTensor* data_out) {
  ICHECK_LT(static_cast<size_t>(index), outputs_.size());
  uint32_t eid = this->entry_id(outputs_[index]);

  // Check shapes explicitly so a tensor of different dimensions but the same
  // total size is rejected rather than silently accepted.
  const NDArray& data = data_entry_[eid];
  ICHECK_EQ(data->ndim, data_out->ndim);
  for (int32_t j = 0; j < data->ndim; ++j) {
    ICHECK_EQ(data->shape[j], data_out->shape[j]);
  }

  data_entry_[eid].CopyTo(data_out);
}

/*!
 * \brief Load parameters from a parameter blob.
 * \param param_blob A binary blob of parameters.
 */
void GraphExecutor::LoadParams(const std::string& param_blob) {
  dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
  this->LoadParams(&strm);
}

void GraphExecutor::LoadParams(dmlc::Stream* strm) {
  Map<String, NDArray> params = ::tvm::runtime::LoadParams(strm);
  for (auto& p : params) {
    param_names_.insert(p.first);
    int in_idx = GetInputIndex(p.first);
    if (in_idx < 0) continue;
    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
    data_entry_[eid].CopyFrom(p.second);
  }
}
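
// The blob format consumed above is the standard TVM parameter dictionary;
// it is typically produced by tvm::runtime::SaveParams (or the corresponding
// Python-side save utilities) and handed in through the "load_params" packed
// function exposed in GetFunction below.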

void GraphExecutor::ShareParams(const GraphExecutor& other, dmlc::Stream* strm) {
  uint64_t header, reserved;
  ICHECK(strm->Read(&header)) << "Invalid parameters file format";
  ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
  ICHECK(strm->Read(&reserved)) << "Invalid parameters file format";
  std::vector<std::string> names;
  ICHECK(strm->Read(&names)) << "Invalid parameters file format";
  uint64_t sz;
  strm->Read(&sz);
  size_t size = static_cast<size_t>(sz);
  ICHECK(size == names.size()) << "Invalid parameters file format";
  for (size_t i = 0; i < size; ++i) {
    int in_idx = GetInputIndex(names[i]);
    if (in_idx < 0) continue;
    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
    ICHECK_LT(eid, data_entry_.size());
    ICHECK_EQ(data_entry_[eid].use_count(), 1);
    data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
    ICHECK_GT(data_entry_[eid].use_count(), 1);
    const DLTensor* tmp = data_entry_[eid].operator->();
    data_alignment_[eid] = details::GetDataAlignment(*tmp);
  }
  this->SetupOpExecs();
}

void GraphExecutor::LinkedNDArrayDeleter(Object* container) {
  // container is the NDArray::Container which needs to get deleted.
  // The data member points to global const memory, so it does not need deleting.
  delete static_cast<NDArray::Container*>(container);
}

void GraphExecutor::DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv) {
  Module mod = args[0];
  int64_t storage_id = args[1];
  DLTensor* template_tensor = args[2];
  Device dev = args[3];
  // Lazily fetch the pre-linked parameter lookup function, if it was generated.
  // When the lookup function is nullptr, no linked params are present.
  if (!module_lookup_linked_param_valid_) {
    module_lookup_linked_param_ =
        mod.GetFunction(::tvm::runtime::symbol::tvm_lookup_linked_param, true);
  }
  if (module_lookup_linked_param_ == nullptr) {
    *rv = nullptr;
    return;
  }

  TVMRetValue opaque_handle = module_lookup_linked_param_(storage_id);
  if (opaque_handle.type_code() == kTVMNullptr) {
    *rv = nullptr;
    return;
  }

  std::vector<int64_t> shape_vec{template_tensor->shape,
                                 template_tensor->shape + template_tensor->ndim};

  auto* container = new NDArray::Container(static_cast<void*>(opaque_handle), shape_vec,
                                           template_tensor->dtype, dev);
  container->SetDeleter(GraphExecutor::LinkedNDArrayDeleter);
  *rv = NDArray(GetObjectPtr<Object>(container));
}
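
// In the linked-params flow, parameter data is compiled into the module as
// constant memory and retrieved by storage id through the generated
// ::tvm::runtime::symbol::tvm_lookup_linked_param function; the NDArray
// created here merely wraps that memory without owning it (see
// LinkedNDArrayDeleter above).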

void GraphExecutor::SetupStorage() {
  // Grab saved optimization plan from graph.
  std::vector<DLDataType> vtype;
  for (const std::string& s_type : attrs_.dltype) {
    vtype.push_back(tvm::runtime::String2DLDataType(s_type));
  }

  // Size and device type of each storage pool entry.
  std::vector<PoolEntry> pool_entry;
  // Find the maximum space size.
  for (size_t i = 0; i < attrs_.shape.size(); ++i) {
    int storage_id = attrs_.storage_id[i];
    std::string storage_scope = attrs_.storage_scope.empty() ? "" : attrs_.storage_scope[i];
    // Use the fallback device if no device index is available.
    int device_type = static_cast<int>(devices_[0].device_type);
    if (!attrs_.device_index.empty()) {
      device_type = attrs_.device_index[i];
    }

    uint32_t sid = static_cast<uint32_t>(storage_id);
    if (sid >= pool_entry.size()) {
      pool_entry.resize(sid + 1, {-1, {0}, {}});
    } else {
      ICHECK(pool_entry[sid].device_type == -1 || pool_entry[sid].device_type == device_type)
          << "The same pool entry cannot be assigned to multiple devices";
    }
    TVMRetValue lookup_rv;
    {
      std::vector<int64_t> shape_vec{attrs_.shape[i].begin(), attrs_.shape[i].end()};
      DLTensor template_tensor{nullptr,  Device{kDLCPU, 0}, static_cast<int>(shape_vec.size()),
                               vtype[i], shape_vec.data(),  nullptr,
                               0};
      lookup_rv = lookup_linked_param_(module_, sid, &template_tensor, devices_[0]);
    }
    if (lookup_rv.type_code() != kTVMNullptr) {
      pool_entry[sid].linked_param = lookup_rv;
    }
    pool_entry[sid].param_data_entry = i;
    pool_entry[sid].device_type = device_type;
    pool_entry[sid].scope = storage_scope;

    DLDataType t = vtype[i];
    if (!details::Is2DStorage(storage_scope)) {
      size_t size = 1;
      for (int64_t sz : attrs_.shape[i]) {
        size *= static_cast<size_t>(sz);
      }
      size_t bits = t.bits * t.lanes;
      ICHECK(bits % 8U == 0U || bits == 1U || bits == 4U);
      int64_t bytes = ((bits + 7U) / 8U) * size;
      pool_entry[sid].shape[0] = std::max(pool_entry[sid].shape[0], bytes);
      pool_entry[sid].dtype = DLDataType{kDLFloat, 32, 1};
    } else {
      if (pool_entry[sid].shape.size() == 1) {
        pool_entry[sid].shape.resize(3, 0);
      }
      size_t axis = runtime::DefaultTextureLayoutSeparator(attrs_.shape[i].size(), storage_scope);
      auto shape = ApplyTexture2DFlattening<int64_t>(attrs_.shape[i], attrs_.shape[i].size(), axis);
      pool_entry[sid].shape[0] = std::max(pool_entry[sid].shape[0], shape.height);
      pool_entry[sid].shape[1] = std::max(pool_entry[sid].shape[1], shape.width);
      CHECK(pool_entry[sid].shape[2] == 0 || pool_entry[sid].shape[2] == shape.channel)
          << pool_entry[sid].shape[2] << " != " << shape.channel
          << ", texture channel length must be consistent within a storage pool";
      pool_entry[sid].shape[2] = shape.channel;
      CHECK(pool_entry[sid].dtype.bits == 0 || TypeEqual(pool_entry[sid].dtype, t))
          << DLDataType2String(pool_entry[sid].dtype) << " != " << DLDataType2String(t)
          << ", pool entry for 2d texture allocations must be of the same type;"
          << " downstream error from memory planner likely";
      pool_entry[sid].dtype = t;
    }
  }

  // Allocate the space.
  for (const auto& pit : pool_entry) {
    // This for loop is very fast since there are usually only a couple of
    // devices available on the same hardware.
    const auto& cit = std::find_if(devices_.begin(), devices_.end(), [&pit](const Device& d) {
      return pit.device_type == static_cast<int>(d.device_type);
    });
    Device dev = cit == devices_.end() ? devices_[0] : *cit;
    if (pit.linked_param.defined()) {
      storage_pool_.push_back(pit.linked_param);
    } else {
      std::vector<int64_t> shape = pit.shape;
      if (shape.size() == 1) {
        // Linear pool sizes are accumulated in bytes but allocated as
        // float32 elements, hence the division by 4 (rounding up).
        shape[0] = (shape[0] + 3) / 4;
      }
      Optional<String> mem_scope;
      if (!pit.scope.empty()) {
        mem_scope = String(pit.scope);
      }
      storage_pool_.push_back(NDArray::Empty(shape, pit.dtype, dev, mem_scope));
    }
  }

  // Assign the pooled entries. A unified memory pool is used to simplify
  // memory assignment for each node entry. The allocated memory on each device
  // is mapped to this pool.
  data_entry_.resize(num_node_entries());
  data_alignment_.resize(num_node_entries());
  for (size_t i = 0; i < data_entry_.size(); ++i) {
    int storage_id = attrs_.storage_id[i];
    ICHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
    data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);

    const DLTensor* tmp = data_entry_[i].operator->();
    data_alignment_[i] = details::GetDataAlignment(*tmp);
  }
}
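
// Storage-pool intuition: entries that share a storage_id alias one pool
// buffer sized to the largest tenant. For example, if attrs_.storage_id were
// {0, 1, 0}, data entries 0 and 2 would be views over pool buffer 0 while
// entry 1 gets its own buffer; the shape/dtype of each view comes from the
// per-entry attributes, not from the pool.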

void GraphExecutor::SetupOpExecs() {
  op_execs_.resize(this->GetNumOfNodes());
  input_dltensors_.resize(num_node_entries());
  output_dltensors_.resize(num_node_entries());
  both_output_opinput_dltensors_.resize(num_node_entries());
  std::unordered_set<uint32_t> input_node_eids;
  for (size_t i = 0; i < input_nodes_.size(); i++) {
    uint32_t nid = input_nodes_[i];
    input_node_eids.insert(entry_id(nid, 0));
  }
  std::unordered_set<uint32_t> output_node_eids;
  for (size_t i = 0; i < outputs_.size(); i++) {
    output_node_eids.insert(entry_id(outputs_[i]));
  }

  // Build an executable closure (and its argument record) for every op node.
  for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
    const auto& inode = nodes_[nid];
    if (inode.op_type == "null") continue;
    std::vector<DLTensor> args;
    for (const auto& e : inode.inputs) {
      uint32_t eid = this->entry_id(e);
      args.push_back(*(data_entry_[eid].operator->()));
    }
    for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
      uint32_t eid = this->entry_id(nid, index);
      args.push_back(*(data_entry_[eid].operator->()));
    }
    ICHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";

    std::shared_ptr<OpArgs> op_args = nullptr;
    std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args);

    for (size_t i = 0; i < inode.inputs.size(); i++) {
      uint32_t input_eid = this->entry_id(inode.inputs[i]);
      // check if op input is model input
      if (input_node_eids.count(input_eid) > 0) {
        input_dltensors_[input_eid].push_back(
            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
      }
      // check if any model output is the input of the op
      if (output_node_eids.count(input_eid) > 0) {
        both_output_opinput_dltensors_[input_eid].push_back(
            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
      }
    }

    for (uint32_t i = inode.inputs.size(); i < inode.inputs.size() + inode.param.num_outputs; ++i) {
      uint32_t output_eid = this->entry_id(nid, i - inode.inputs.size());
      // check if op output is model output
      if (output_node_eids.count(output_eid) > 0) {
        output_dltensors_[output_eid].push_back(
            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
      }
    }
  }
}
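
// The DLTensor pointers cached above in input_dltensors_, output_dltensors_
// and both_output_opinput_dltensors_ point into the per-op argument records
// owned by op_args; SetInputZeroCopy/SetOutputZeroCopy patch those pointers
// in place, which is what makes zero-copy rebinding possible without
// rebuilding the op closures.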

std::pair<std::function<void()>, std::shared_ptr<GraphExecutor::OpArgs>> GraphExecutor::CreateTVMOp(
    const TVMOpParam& param, const std::vector<DLTensor>& args) {
  std::shared_ptr<GraphExecutor::OpArgs> arg_ptr = std::make_shared<GraphExecutor::OpArgs>();
  // setup address.
  arg_ptr->args = args;
  if (param.flatten_data) {
    arg_ptr->shape_data.resize(arg_ptr->args.size());
  }
  for (size_t i = 0; i < arg_ptr->args.size(); ++i) {
    TVMValue v;
    DLTensor* t = &arg_ptr->args[i];
    v.v_handle = t;
    arg_ptr->arg_values.push_back(v);
    arg_ptr->arg_tcodes.push_back(kTVMDLTensorHandle);
    if (param.flatten_data) {
      arg_ptr->shape_data[i] =
          std::accumulate(t->shape, t->shape + t->ndim, 1, std::multiplies<int64_t>());
      t->ndim = 1;
      t->shape = &(arg_ptr->shape_data[i]);
    }
  }

  if (param.func_name == "__nop") {
    return {[]() {}, arg_ptr};
  } else if (param.func_name == "__copy") {
    // Perform cross device data copy.
    // Directly copy data from the input to the output.
    // TODO(mbs): device_copy cleanup.
    auto fexec = [arg_ptr]() {
      DLTensor* from = static_cast<DLTensor*>(arg_ptr->arg_values[0].v_handle);
      DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle);
      TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr));
    };
    return {fexec, arg_ptr};
  }

  // Get compiled function from the module that contains both host and device
  // code.
  tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, true);
  ICHECK(pf != nullptr) << "no such function in module: " << param.func_name;

  auto fexec = [arg_ptr, pf]() {
    TVMRetValue rv;
    TVMArgs targs(arg_ptr->arg_values.data(), arg_ptr->arg_tcodes.data(),
                  static_cast<int>(arg_ptr->arg_values.size()));
    pf.CallPacked(targs, &rv);
  };
  return {fexec, arg_ptr};
}
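
// Lifetime note: each returned closure captures arg_ptr by value, so the
// TVMValue argument records (and any flattened shape data) stay alive for as
// long as the executor holds the closure in op_execs_.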

PackedFunc GraphExecutor::GetFunction(const std::string& name,
                                      const ObjectPtr<Object>& sptr_to_self) {
  // Return member functions during query.
  if (name == "set_input") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      if (String::CanConvertFrom(args[0])) {
        int in_idx = this->GetInputIndex(args[0].operator String());
        if (in_idx >= 0) this->SetInput(in_idx, args[1]);
      } else {
        this->SetInput(args[0], args[1]);
      }
    });
  } else if (name == "set_input_zero_copy") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      if (String::CanConvertFrom(args[0])) {
        int in_idx = this->GetInputIndex(args[0].operator String());
        if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
      } else {
        this->SetInputZeroCopy(args[0], args[1]);
      }
    });
  } else if (name == "set_output_zero_copy") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      if (String::CanConvertFrom(args[0])) {
        int out_idx = this->GetOutputIndex(args[0].operator String());
        if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]);
      } else {
        this->SetOutputZeroCopy(args[0], args[1]);
      }
    });
  } else if (name == "get_output") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      if (args.num_args == 2) {
        this->CopyOutputTo(args[0], args[1]);
      } else {
        int out_idx = -1;
        if (String::CanConvertFrom(args[0])) {
          for (size_t i = 0; i < outputs_.size(); i++) {
            std::string& name = nodes_[outputs_[i].node_id].name;
            if (args[0].operator String() == name) {
              out_idx = i;
            }
          }
          CHECK(out_idx != -1) << "Invalid output node: " << args[0].operator String();
        } else {
          out_idx = args[0];
        }
        *rv = this->GetOutput(out_idx);
      }
    });
  } else if (name == "get_input") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      int in_idx = 0;
      if (String::CanConvertFrom(args[0])) {
        in_idx = this->GetInputIndex(args[0].operator String());
      } else {
        in_idx = args[0];
      }
      if (in_idx >= 0) {
        *rv = this->GetInput(in_idx);
      }
    });
  } else if (name == "get_num_outputs") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
  } else if (name == "get_num_inputs") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); });
  } else if (name == "run") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); });
  } else if (name == "run_from_inputs") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
          CHECK(args.size() % 2 == 0)
              << "Number of arguments to run_from_inputs must be an even number of key-value "
                 "pairs";
          Device host{static_cast<DLDeviceType>(args[0].operator int()), args[1].operator int()};
          for (int i = 2; i < args.size(); i += 2) {
            if (String::CanConvertFrom(args[i])) {
              int in_idx = this->GetInputIndex(args[i].operator String());
              if (in_idx >= 0) {
                this->SetInput(in_idx, args[i + 1]);
              } else {
                LOG(FATAL) << args[i].operator String() << " is not a valid input name";
              }
            } else {
              this->SetInput(args[i], args[i + 1]);
            }
          }
          this->Run();
          Array<NDArray> outputs;
          for (int i = 0; i < this->NumOutputs(); i++) {
            NDArray out = this->GetOutput(i);
            NDArray a = NDArray::Empty(out.Shape(), out.DataType(), host);
            a.CopyFrom(out);
            outputs.push_back(a);
          }
          *rv = outputs;
        });
  } else if (name == "load_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      this->LoadParams(args[0].operator std::string());
    });
  } else if (name == "share_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      const auto& module = args[0].operator Module();
      ICHECK_EQ(module.operator->()->type_key(), std::string("GraphExecutor"));
      const auto& param_blob = args[1].operator std::string();
      dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
      this->ShareParams(dynamic_cast<const GraphExecutor&>(*module.operator->()), &strm);
    });
  } else if (name == "get_input_index") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
      *rv = this->GetInputIndex(args[0].operator String());
    });
  } else if (name == "get_input_info") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      auto [shape_info, dtype_info] = this->GetInputInfo();
      Map<String, ObjectRef> input_info;
      input_info.Set("shape", shape_info);
      input_info.Set("dtype", dtype_info);
      *rv = input_info;
    });
  } else if (name == "get_output_info") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      auto [shape_info, dtype_info] = this->GetOutputInfo();
      Map<String, ObjectRef> output_info;
      output_info.Set("shape", shape_info);
      output_info.Set("dtype", dtype_info);
      *rv = output_info;
    });
  } else {
    return PackedFunc();
  }
}
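
// Example (a sketch of driving the packed-function interface from C++; the
// variable names gexec and input_nd are illustrative only):
//
//   tvm::runtime::PackedFunc set_input = gexec.GetFunction("set_input");
//   tvm::runtime::PackedFunc run = gexec.GetFunction("run");
//   tvm::runtime::PackedFunc get_output = gexec.GetFunction("get_output");
//   set_input("data", input_nd);  // by name, or set_input(0, input_nd)
//   run();
//   tvm::runtime::NDArray out = get_output(0);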

Module GraphExecutorCreate(const std::string& sym_json, const tvm::runtime::Module& m,
                           const std::vector<Device>& devs,
                           const PackedFunc lookup_linked_param_func) {
  auto exec = make_object<GraphExecutor>();
  exec->Init(sym_json, m, devs, lookup_linked_param_func);
  return Module(exec);
}

// Get all devices for the host and other runtime devices.
std::vector<Device> GetAllDevice(const TVMArgs& args, int dev_start_arg) {
  // Reserve the first item as the fallback device.
  std::vector<Device> ret;
  Device dev;
  for (int i = dev_start_arg; i < args.num_args; i += 2) {
    int dev_type = args[i];
    dev.device_type = static_cast<DLDeviceType>(dev_type);
    dev.device_id = args[i + 1];
    ret.push_back(dev);
  }
  return ret;
}
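
// For example, a trailing argument list of the form (kDLCPU, 0) yields a
// single fallback device, while (kDLCPU, 0, kDLCUDA, 0) sets up heterogeneous
// execution with the first (device_type, device_id) pair as the fallback
// (device-type names here are illustrative).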

// The 4-argument version is currently reserved to keep support for calling
// from tvm4j and javascript, since they don't have heterogeneous
// execution support yet. For heterogeneous execution, at least 5 arguments
// will be passed in. The third one is the number of devices.
// Eventually, we will probably pass only Device for all the languages.
TVM_REGISTER_GLOBAL("tvm.graph_executor.create").set_body([](TVMArgs args, TVMRetValue* rv) {
  ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_executor.create is "
                                 "at least 4, but it has "
                              << args.num_args;
  PackedFunc lookup_linked_param_func;
  int dev_start_arg = 2;
  if (args[2].type_code() == kTVMPackedFuncHandle) {
    lookup_linked_param_func = args[2];
    dev_start_arg++;
  }
  const auto& devices = GetAllDevice(args, dev_start_arg);
  *rv = GraphExecutorCreate(args[0], args[1], devices, lookup_linked_param_func);
});
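
// Note the optional third argument above: when args[2] is a PackedFunc it is
// taken as the linked-parameter lookup function and the (device_type,
// device_id) pairs start at args[3]; otherwise the device pairs start at
// args[2].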
}  // namespace runtime
}  // namespace tvm