/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file graph_executor.cc
 */
#include "graph_executor.h"

#include <tvm/runtime/container/map.h>
#include <tvm/runtime/container/string.h>
#include <tvm/runtime/data_type.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/runtime/packed_func.h>
#include <tvm/runtime/profiling.h>
#include <tvm/runtime/registry.h>
#include <tvm/runtime/serializer.h>

#include <algorithm>
#include <functional>
#include <memory>
#include <numeric>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../file_utils.h"
#include "../texture.h"

namespace tvm {
namespace runtime {
namespace details {
inline size_t GetDataAlignment(const DLTensor& arr) {
  size_t align = (arr.dtype.bits / 8) * arr.dtype.lanes;
  if (align < kAllocAlignment) return kAllocAlignment;
  return align;
}
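// Worked example (illustrative): for a plain float32 tensor, bits / 8 * lanes
// is 4 bytes, which is below kAllocAlignment (64 bytes at the time of writing;
// see tvm/runtime/device_api.h), so the requirement is clamped up to
// kAllocAlignment. Only wide vector dtypes such as int8x128 exceed the clamp.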
constexpr auto Is2DStorage = IsTextureStorage;
}  // namespace details

/*!
 * \brief Run all the operations one by one.
 */
void GraphExecutor::Run() {
  // Execute each prepared op closure in topological order; null entries are nops.
  for (size_t i = 0; i < op_execs_.size(); ++i) {
    if (op_execs_[i]) op_execs_[i]();
  }
}

/*!
 * \brief Initialize the graph executor with graph and device.
 * \param graph_json The execution graph.
 * \param module The module containing the compiled functions for the host
 * processor.
 * \param devs The devices on which the graph nodes will be executed; the first
 * entry is the host (fallback) device.
 * \param lookup_linked_param_func Linked parameter lookup function. Default is nullptr.
 */
void GraphExecutor::Init(const std::string& graph_json, tvm::runtime::Module module,
                         const std::vector<Device>& devs,
                         const PackedFunc lookup_linked_param_func) {
  std::istringstream is(graph_json);
  dmlc::JSONReader reader(&is);
  this->Load(&reader);
  module_ = module;
  devices_ = devs;
  lookup_linked_param_ = lookup_linked_param_func;
  if (lookup_linked_param_ == nullptr) {
    lookup_linked_param_ = PackedFunc(
        [this](TVMArgs args, TVMRetValue* rv) { this->DefaultLookupLinkedParam(args, rv); });
  }
  this->SetupStorage();
  this->SetupOpExecs();
  for (size_t i = 0; i < input_nodes_.size(); i++) {
    const uint32_t nid = input_nodes_[i];
    std::string& name = nodes_[nid].name;
    input_map_[name] = i;
  }
  for (size_t i = 0; i < outputs_.size(); i++) {
    const uint32_t nid = outputs_[i].node_id;
    std::string& name = nodes_[nid].name;
    std::stringstream ss;
    ss << name << ":" << i;
    output_map_[ss.str()] = i;
  }
}

/*!
 * \brief Get the input index given the name of input.
 * \param name The name of the input.
 * \return The index of input.
 */
int GraphExecutor::GetInputIndex(const std::string& name) {
  auto it = input_map_.find(name);
  if (it != input_map_.end()) {
    return it->second;
  }
  return -1;
}

/*!
 * \brief Get the input info of Graph by parsing the input nodes.
 * \return The shape and dtype tuple.
 */
std::tuple<GraphExecutor::ShapeInfo, GraphExecutor::DtypeInfo> GraphExecutor::GetInputInfo() const {
  GraphExecutor::ShapeInfo shape_dict;
  GraphExecutor::DtypeInfo dtype_dict;
  for (uint32_t nid : input_nodes_) {
    CHECK_LT(nid, nodes_.size());
    std::string name = nodes_[nid].name;
    if (param_names_.find(name) == param_names_.end()) {
      CHECK_LT(nid, attrs_.shape.size());
      auto shape = attrs_.shape[nid];
      shape_dict.Set(name, ShapeTuple(shape));
      CHECK_LT(nid, attrs_.dltype.size());
      auto dtype = attrs_.dltype[nid];
      dtype_dict.Set(name, String(dtype));
    }
  }
  return std::make_tuple(shape_dict, dtype_dict);
}
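// For a typical image model whose only non-parameter input is "data", the
// returned pair would look like (illustrative values):
//   shape_dict: {"data": [1, 3, 224, 224]}
//   dtype_dict: {"data": "float32"}
// Parameter inputs (weights already bound via the params blob) are filtered
// out by the param_names_ check above.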

/*!
 * \brief Get the output info of Graph by parsing the output nodes.
 * \return The shape and dtype tuple.
 */
std::tuple<GraphExecutor::ShapeInfo, GraphExecutor::DtypeInfo> GraphExecutor::GetOutputInfo()
    const {
  GraphExecutor::ShapeInfo shape_dict;
  GraphExecutor::DtypeInfo dtype_dict;
  for (auto out : outputs_) {
    uint32_t nid = out.node_id;
    CHECK_LT(nid, nodes_.size());
    std::string name = nodes_[nid].name;
    CHECK_LT(nid, attrs_.shape.size());
    auto shape = attrs_.shape[nid];
    shape_dict.Set(name, ShapeTuple(shape));
    CHECK_LT(nid, attrs_.dltype.size());
    auto dtype = attrs_.dltype[nid];
    dtype_dict.Set(name, String(dtype));
  }
  return std::make_tuple(shape_dict, dtype_dict);
}

/*!
 * \brief Get the output index given the name of output.
 * \param name The name of the output.
 * \return The index of output.
 */
int GraphExecutor::GetOutputIndex(const std::string& name) {
  auto it = output_map_.find(name);
  if (it != output_map_.end()) {
    return it->second;
  }
  return -1;
}
/*!
 * \brief set index-th input to the graph.
 * \param index The input index.
 * \param data_in The input data.
 */
void GraphExecutor::SetInput(int index, DLTensor* data_in) {
  ICHECK_LT(static_cast<size_t>(index), input_nodes_.size());
  uint32_t eid = this->entry_id(input_nodes_[index], 0);
  data_entry_[eid].CopyFrom(data_in);
}
/*!
 * \brief Check the legality of external DLTensor*.
 * \param external The external DLTensor*.
 * \param eid The data_entry_ index.
 */
void GraphExecutor::CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const {
  const DLTensor* internal = data_entry_[eid].operator->();

  ICHECK_EQ(data_alignment_[eid], details::GetDataAlignment(*external));
  ICHECK_EQ(reinterpret_cast<size_t>(static_cast<char*>(external->data) + external->byte_offset) %
                kAllocAlignment,
            0);
  ICHECK_EQ(internal->ndim, static_cast<size_t>(external->ndim));
  ICHECK_EQ(internal->device.device_type, external->device.device_type);
  ICHECK_EQ(internal->device.device_id, external->device.device_id);
  for (auto i = 0; i < external->ndim; ++i) {
    ICHECK_EQ(internal->shape[i], external->shape[i]);
  }
}
/*!
 * \brief set index-th input to the graph without copying the data.
 * \param index The input index.
 * \param data_ref The input data that is referred.
 */
void GraphExecutor::SetInputZeroCopy(int index, DLTensor* data_ref) {
  ICHECK_LT(static_cast<size_t>(index), input_nodes_.size());
  uint32_t eid = this->entry_id(input_nodes_[index], 0);
  // check the consistency of input
  CheckExternalDLTensor(data_ref, eid);
  // Update the data pointer for each argument of each op
  for (DLTensor* t : input_dltensors_[eid]) {
    t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
  }
}
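// A minimal sketch of preparing an external buffer for SetInputZeroCopy; the
// names `exec`, `buf`, and the shape values are illustrative, not part of this
// file. Per CheckExternalDLTensor above, the buffer must match the internal
// entry's shape, ndim, and device exactly, and data + byte_offset must be
// kAllocAlignment-aligned:
//
//   int64_t shape[4] = {1, 3, 224, 224};
//   DLTensor ext;
//   ext.data = buf;                 // a kAllocAlignment-aligned allocation
//   ext.device = {kDLCPU, 0};
//   ext.ndim = 4;
//   ext.dtype = {kDLFloat, 32, 1};  // code, bits, lanes
//   ext.shape = shape;
//   ext.strides = nullptr;          // compact layout assumed
//   ext.byte_offset = 0;
//   exec->SetInputZeroCopy(exec->GetInputIndex("data"), &ext);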
/*!
 * \brief set index-th output to the graph without copying the data.
 * \param index The output index.
 * \param data_ref The output data that is referred.
 */
void GraphExecutor::SetOutputZeroCopy(int index, DLTensor* data_ref) {
  ICHECK_LT(static_cast<size_t>(index), outputs_.size());
  ICHECK_LT(static_cast<size_t>(index), output_dltensors_.size());
  const NodeEntry& output_node = outputs_[index];
  uint32_t output_node_eid = this->entry_id(output_node);

  // check the consistency of output
  CheckExternalDLTensor(data_ref, output_node_eid);

  // Update the data pointer for output op
  for (DLTensor* t : output_dltensors_[output_node_eid]) {
    t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
  }

  // Update the input of the op connected to the output
  for (DLTensor* t : both_output_opinput_dltensors_[output_node_eid]) {
    t->data = static_cast<char*>(data_ref->data) + data_ref->byte_offset;
  }
}
/*!
 * \brief Get the number of outputs
 *
 * \return The number of outputs from graph.
 */
int GraphExecutor::NumOutputs() const { return outputs_.size(); }
/*!
 * \brief Get the number of inputs
 *
 * \return The number of inputs to the graph.
 */
int GraphExecutor::NumInputs() const { return input_nodes_.size(); }
/*!
 * \brief Return NDArray for given input index.
 * \param index The input index.
 *
 * \return NDArray corresponding to given input node index.
 */
NDArray GraphExecutor::GetInput(int index) const {
  ICHECK_LT(static_cast<size_t>(index), input_nodes_.size());
  uint32_t eid = this->entry_id(input_nodes_[index], 0);
  return data_entry_[eid];
}
/*!
 * \brief Return NDArray for given output index.
 * \param index The output index.
 *
 * \return NDArray corresponding to given output node index.
 */
NDArray GraphExecutor::GetOutput(int index) const {
  ICHECK_LT(static_cast<size_t>(index), outputs_.size());
  uint32_t eid = this->entry_id(outputs_[index]);
  return data_entry_[eid];
}
/*!
 * \brief Copy index-th output to data_out.
 * \param index The output index.
 * \param data_out the output data.
 */
void GraphExecutor::CopyOutputTo(int index, DLTensor* data_out) {
  ICHECK_LT(static_cast<size_t>(index), outputs_.size());
  uint32_t eid = this->entry_id(outputs_[index]);

  // Check the shapes to avoid copying between arrays that have the same total
  // size but different dimensions.
  const NDArray& data = data_entry_[eid];
  ICHECK_EQ(data->ndim, data_out->ndim);
  for (int32_t j = 0; j < data->ndim; ++j) {
    ICHECK_EQ(data->shape[j], data_out->shape[j]);
  }

  data_entry_[eid].CopyTo(data_out);
}

/*!
 * \brief Load parameters from parameter blob.
 * \param param_blob A binary blob of parameter.
 */
void GraphExecutor::LoadParams(const std::string& param_blob) {
  dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
  this->LoadParams(&strm);
}

void GraphExecutor::LoadParams(dmlc::Stream* strm) {
  Map<String, NDArray> params = ::tvm::runtime::LoadParams(strm);
  for (auto& p : params) {
    param_names_.insert(p.first);
    int in_idx = GetInputIndex(p.first);
    if (in_idx < 0) continue;
    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
    data_entry_[eid].CopyFrom(p.second);
  }
}
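// Typical use (illustrative sketch): the blob is the byte format produced by
// tvm::runtime::SaveParams (tvm.runtime.save_param_dict on the Python side)
// and is handed to this executor either directly or through the "load_params"
// packed function exposed by GetFunction below. `ReadFileToString` here is a
// hypothetical helper, not part of this file:
//
//   std::string blob = ReadFileToString("deploy.params");
//   exec->LoadParams(blob);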

void GraphExecutor::ShareParams(const GraphExecutor& other, dmlc::Stream* strm) {
  uint64_t header, reserved;
  ICHECK(strm->Read(&header)) << "Invalid parameters file format";
  ICHECK(header == kTVMNDArrayListMagic) << "Invalid parameters file format";
  ICHECK(strm->Read(&reserved)) << "Invalid parameters file format";
  std::vector<std::string> names;
  ICHECK(strm->Read(&names)) << "Invalid parameters file format";
  uint64_t sz;
  strm->Read(&sz);
  size_t size = static_cast<size_t>(sz);
  ICHECK(size == names.size()) << "Invalid parameters file format";
  for (size_t i = 0; i < size; ++i) {
    int in_idx = GetInputIndex(names[i]);
    if (in_idx < 0) continue;
    uint32_t eid = this->entry_id(input_nodes_[in_idx], 0);
    ICHECK_LT(eid, data_entry_.size());
    ICHECK_EQ(data_entry_[eid].use_count(), 1);
    data_entry_[eid] = other.GetInput(GetInputIndex(names[i]));
    ICHECK_GT(data_entry_[eid].use_count(), 1);
    const DLTensor* tmp = data_entry_[eid].operator->();
    data_alignment_[eid] = details::GetDataAlignment(*tmp);
  }
  this->SetupOpExecs();
}

void GraphExecutor::LinkedNDArrayDeleter(Object* container) {
  // container is the NDArray::Container which needs to get deleted.
  // The data member points to global const memory, so it does not need deleting.
  delete static_cast<NDArray::Container*>(container);
}

void GraphExecutor::DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv) {
  Module mod = args[0];
  int64_t storage_id = args[1];
  DLTensor* template_tensor = args[2];
  Device dev = args[3];
  // Get the pre-linked parameter lookup function, if it was generated. When the
  // lookup function is nullptr, no linked params are present.
  if (!module_lookup_linked_param_valid_) {
    module_lookup_linked_param_ =
        mod.GetFunction(::tvm::runtime::symbol::tvm_lookup_linked_param, true);
  }
  if (module_lookup_linked_param_ == nullptr) {
    *rv = nullptr;
    return;
  }

  TVMRetValue opaque_handle = module_lookup_linked_param_(storage_id);
  if (opaque_handle.type_code() == kTVMNullptr) {
    *rv = nullptr;
    return;
  }

  std::vector<int64_t> shape_vec{template_tensor->shape,
                                 template_tensor->shape + template_tensor->ndim};

  auto* container = new NDArray::Container(static_cast<void*>(opaque_handle), shape_vec,
                                           template_tensor->dtype, dev);
  container->SetDeleter(GraphExecutor::LinkedNDArrayDeleter);
  *rv = NDArray(GetObjectPtr<Object>(container));
}

void GraphExecutor::SetupStorage() {
  // Grab saved optimization plan from graph.
  std::vector<DLDataType> vtype;
  for (const std::string& s_type : attrs_.dltype) {
    vtype.push_back(tvm::runtime::String2DLDataType(s_type));
  }

  // Size and device type of each storage pool entry.
  std::vector<PoolEntry> pool_entry;
  // Find the maximum space size.
  for (size_t i = 0; i < attrs_.shape.size(); ++i) {
    int storage_id = attrs_.storage_id[i];
    std::string storage_scope = attrs_.storage_scope.empty() ? "" : attrs_.storage_scope[i];
    // Use the fallback device if no device index is available.
    int device_type = static_cast<int>(devices_[0].device_type);
    if (!attrs_.device_index.empty()) {
      device_type = attrs_.device_index[i];
    }

    uint32_t sid = static_cast<uint32_t>(storage_id);
    if (sid >= pool_entry.size()) {
      pool_entry.resize(sid + 1, {-1, {0}, {}});
    } else {
      ICHECK(pool_entry[sid].device_type == -1 || pool_entry[sid].device_type == device_type)
          << "The same pool entry cannot be assigned to multiple devices";
    }
    TVMRetValue lookup_rv;
    {
      std::vector<int64_t> shape_vec{attrs_.shape[i].begin(), attrs_.shape[i].end()};
      DLTensor template_tensor{nullptr,  Device{kDLCPU, 0}, static_cast<int>(shape_vec.size()),
                               vtype[i], shape_vec.data(),  nullptr,
                               0};
      lookup_rv = lookup_linked_param_(module_, sid, &template_tensor, devices_[0]);
    }
    if (lookup_rv.type_code() != kTVMNullptr) {
      pool_entry[sid].linked_param = lookup_rv;
    }
    pool_entry[sid].param_data_entry = i;
    pool_entry[sid].device_type = device_type;
    pool_entry[sid].scope = storage_scope;

    DLDataType t = vtype[i];
    if (!details::Is2DStorage(storage_scope)) {
      size_t size = 1;
      for (int64_t sz : attrs_.shape[i]) {
        size *= static_cast<size_t>(sz);
      }
      size_t bits = t.bits * t.lanes;
      ICHECK(bits % 8U == 0U || bits == 1U || bits == 4U);
      int64_t bytes = ((bits + 7U) / 8U) * size;
      pool_entry[sid].shape[0] = std::max(pool_entry[sid].shape[0], bytes);
      pool_entry[sid].dtype = DLDataType{kDLFloat, 32, 1};
    } else {
      if (pool_entry[sid].shape.size() == 1) {
        pool_entry[sid].shape.resize(3, 0);
      }
      size_t axis = runtime::DefaultTextureLayoutSeparator(attrs_.shape[i].size(), storage_scope);
      auto shape = ApplyTexture2DFlattening<int64_t>(attrs_.shape[i], attrs_.shape[i].size(), axis);
      pool_entry[sid].shape[0] = std::max(pool_entry[sid].shape[0], shape.height);
      pool_entry[sid].shape[1] = std::max(pool_entry[sid].shape[1], shape.width);
      CHECK(pool_entry[sid].shape[2] == 0 || pool_entry[sid].shape[2] == shape.channel)
          << pool_entry[sid].shape[2] << " != " << shape.channel
          << ", texture channel length must be consistent within a storage pool";
      pool_entry[sid].shape[2] = shape.channel;
      CHECK(pool_entry[sid].dtype.bits == 0 || TypeEqual(pool_entry[sid].dtype, t))
          << DLDataType2String(pool_entry[sid].dtype) << " != " << DLDataType2String(t)
          << ", pool entries for 2d texture allocations must be of the same type;"
          << " downstream error from memory planner likely";
      pool_entry[sid].dtype = t;
    }
  }

  // Allocate the space.
  for (const auto& pit : pool_entry) {
    // This for loop is very fast since there are usually only a couple of
    // devices available on the same hardware.
    const auto& cit = std::find_if(devices_.begin(), devices_.end(), [&pit](const Device& d) {
      return pit.device_type == static_cast<int>(d.device_type);
    });
    Device dev = cit == devices_.end() ? devices_[0] : *cit;
    if (pit.linked_param.defined()) {
      storage_pool_.push_back(pit.linked_param);
    } else {
      std::vector<int64_t> shape = pit.shape;
      if (shape.size() == 1) {
        // 1D pools are sized in bytes but allocated as float32 elements, so
        // round the byte count up to a multiple of four.
        shape[0] = (shape[0] + 3) / 4;
      }
      Optional<String> mem_scope;
      if (!pit.scope.empty()) {
        mem_scope = String(pit.scope);
      }
      storage_pool_.push_back(NDArray::Empty(shape, pit.dtype, dev, mem_scope));
    }
  }

  // Assign the pooled entries. A unified memory pool is used to simplify
  // memory assignment for each node entry. The allocated memory on each device
  // is mapped to this pool.
  data_entry_.resize(num_node_entries());
  data_alignment_.resize(num_node_entries());
  for (size_t i = 0; i < data_entry_.size(); ++i) {
    int storage_id = attrs_.storage_id[i];
    ICHECK_LT(static_cast<size_t>(storage_id), storage_pool_.size());
    data_entry_[i] = storage_pool_[storage_id].CreateView(attrs_.shape[i], vtype[i]);

    const DLTensor* tmp = data_entry_[i].operator->();
    data_alignment_[i] = details::GetDataAlignment(*tmp);
  }
}
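// Worked example of the 1D pool sizing above (illustrative numbers): two
// float32 entries with shapes [1, 3, 224, 224] (602112 bytes) and [1, 1000]
// (4000 bytes) that share storage_id 0 produce a pool entry of
// max(602112, 4000) = 602112 bytes, allocated as (602112 + 3) / 4 = 150528
// float32 elements; each entry then becomes a CreateView over that buffer
// with its own shape and dtype.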

void GraphExecutor::SetupOpExecs() {
  op_execs_.resize(this->GetNumOfNodes());
  input_dltensors_.resize(num_node_entries());
  output_dltensors_.resize(num_node_entries());
  both_output_opinput_dltensors_.resize(num_node_entries());
  std::unordered_set<uint32_t> input_node_eids;
  for (size_t i = 0; i < input_nodes_.size(); i++) {
    uint32_t nid = input_nodes_[i];
    input_node_eids.insert(entry_id(nid, 0));
  }
  std::unordered_set<uint32_t> output_node_eids;
  for (size_t i = 0; i < outputs_.size(); i++) {
    output_node_eids.insert(entry_id(outputs_[i]));
  }

  // setup the array and requirements.
  for (uint32_t nid = 0; nid < this->GetNumOfNodes(); ++nid) {
    const auto& inode = nodes_[nid];
    if (inode.op_type == "null") continue;
    std::vector<DLTensor> args;
    for (const auto& e : inode.inputs) {
      uint32_t eid = this->entry_id(e);
      args.push_back(*(data_entry_[eid].operator->()));
    }
    for (uint32_t index = 0; index < inode.param.num_outputs; ++index) {
      uint32_t eid = this->entry_id(nid, index);
      args.push_back(*(data_entry_[eid].operator->()));
    }
    ICHECK(inode.op_type == "tvm_op") << "Can only take tvm_op as op";

    std::shared_ptr<OpArgs> op_args = nullptr;
    std::tie(op_execs_[nid], op_args) = CreateTVMOp(inode.param, args);

    for (size_t i = 0; i < inode.inputs.size(); i++) {
      uint32_t input_eid = this->entry_id(inode.inputs[i]);
      // check if op input is model input
      if (input_node_eids.count(input_eid) > 0) {
        input_dltensors_[input_eid].push_back(
            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
      }
      // check if any model output is the input of the op
      if (output_node_eids.count(input_eid) > 0) {
        both_output_opinput_dltensors_[input_eid].push_back(
            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
      }
    }

    for (uint32_t i = inode.inputs.size(); i < inode.inputs.size() + inode.param.num_outputs; ++i) {
      uint32_t output_eid = this->entry_id(nid, i - inode.inputs.size());
      // check if op output is model output
      if (output_node_eids.count(output_eid) > 0) {
        output_dltensors_[output_eid].push_back(
            static_cast<DLTensor*>(op_args->arg_values[i].v_handle));
      }
    }
  }
}
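// Note on the bookkeeping above: if one graph input (entry id eid) feeds
// several ops, input_dltensors_[eid] ends up holding a pointer to the
// corresponding DLTensor argument slot of every consumer. SetInputZeroCopy
// can then retarget all of those slots to an external buffer in a single
// pass, without rebuilding any of the op closures.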

std::pair<std::function<void()>, std::shared_ptr<GraphExecutor::OpArgs>> GraphExecutor::CreateTVMOp(
    const TVMOpParam& param, const std::vector<DLTensor>& args) {
  std::shared_ptr<GraphExecutor::OpArgs> arg_ptr = std::make_shared<GraphExecutor::OpArgs>();
  // setup address.
  arg_ptr->args = args;
  if (param.flatten_data) {
    arg_ptr->shape_data.resize(arg_ptr->args.size());
  }
  for (size_t i = 0; i < arg_ptr->args.size(); ++i) {
    TVMValue v;
    DLTensor* t = &arg_ptr->args[i];
    v.v_handle = t;
    arg_ptr->arg_values.push_back(v);
    arg_ptr->arg_tcodes.push_back(kTVMDLTensorHandle);
    if (param.flatten_data) {
      // Accumulate in int64_t to avoid overflow for large tensors.
      arg_ptr->shape_data[i] =
          std::accumulate(t->shape, t->shape + t->ndim, int64_t{1}, std::multiplies<int64_t>());
      t->ndim = 1;
      t->shape = &(arg_ptr->shape_data[i]);
    }
  }

  if (param.func_name == "__nop") {
    return {[]() {}, arg_ptr};
  } else if (param.func_name == "__copy") {
    // Perform cross device data copy.
    // Directly copy data from the input to the output.
    // TODO(mbs): device_copy cleanup.
    auto fexec = [arg_ptr]() {
      DLTensor* from = static_cast<DLTensor*>(arg_ptr->arg_values[0].v_handle);
      DLTensor* to = static_cast<DLTensor*>(arg_ptr->arg_values[1].v_handle);
      TVM_CCALL(TVMArrayCopyFromTo(from, to, nullptr));
    };
    return {fexec, arg_ptr};
  }

  // Get compiled function from the module that contains both host and device
  // code.
  tvm::runtime::PackedFunc pf = module_.GetFunction(param.func_name, true);
  ICHECK(pf != nullptr) << "no such function in module: " << param.func_name;

  auto fexec = [arg_ptr, pf]() {
    TVMRetValue rv;
    TVMArgs targs(arg_ptr->arg_values.data(), arg_ptr->arg_tcodes.data(),
                  static_cast<int>(arg_ptr->arg_values.size()));
    pf.CallPacked(targs, &rv);
  };
  return {fexec, arg_ptr};
}
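// Example of the flatten_data path above (illustrative): an argument of shape
// [2, 3, 4] is rewritten in the op's private DLTensor copy to ndim = 1 and
// shape = [24] before the packed function is called. Only the copies held in
// arg_ptr->args are mutated; the executor's data_entry_ arrays keep their
// original shapes.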

PackedFunc GraphExecutor::GetFunction(const std::string& name,
                                      const ObjectPtr<Object>& sptr_to_self) {
  // Return member functions during query.
  if (name == "set_input") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      if (String::CanConvertFrom(args[0])) {
        int in_idx = this->GetInputIndex(args[0].operator String());
        if (in_idx >= 0) this->SetInput(in_idx, args[1]);
      } else {
        this->SetInput(args[0], args[1]);
      }
    });
  } else if (name == "set_input_zero_copy") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      if (String::CanConvertFrom(args[0])) {
        int in_idx = this->GetInputIndex(args[0].operator String());
        if (in_idx >= 0) this->SetInputZeroCopy(in_idx, args[1]);
      } else {
        this->SetInputZeroCopy(args[0], args[1]);
      }
    });
  } else if (name == "set_output_zero_copy") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      if (String::CanConvertFrom(args[0])) {
        int out_idx = this->GetOutputIndex(args[0].operator String());
        if (out_idx >= 0) this->SetOutputZeroCopy(out_idx, args[1]);
      } else {
        this->SetOutputZeroCopy(args[0], args[1]);
      }
    });
  } else if (name == "get_output") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      if (args.num_args == 2) {
        this->CopyOutputTo(args[0], args[1]);
      } else {
        int out_idx = -1;
        if (String::CanConvertFrom(args[0])) {
          for (size_t i = 0; i < outputs_.size(); i++) {
            std::string& name = nodes_[outputs_[i].node_id].name;
            if (args[0].operator String() == name) {
              out_idx = i;
            }
          }
          CHECK(out_idx != -1) << "Invalid output node: " << args[0].operator String();
        } else {
          out_idx = args[0];
        }
        *rv = this->GetOutput(out_idx);
      }
    });
  } else if (name == "get_input") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      int in_idx = 0;
      if (String::CanConvertFrom(args[0])) {
        in_idx = this->GetInputIndex(args[0].operator String());
      } else {
        in_idx = args[0];
      }
      if (in_idx >= 0) {
        *rv = this->GetInput(in_idx);
      }
    });
  } else if (name == "get_num_outputs") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumOutputs(); });
  } else if (name == "get_num_inputs") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->NumInputs(); });
  } else if (name == "run") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { this->Run(); });
  } else if (name == "run_from_inputs") {
    return PackedFunc(
        [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
          CHECK(args.size() % 2 == 0)
              << "Number of arguments to run_from_inputs must be an even number of key-value "
                 "pairs";
          Device host{static_cast<DLDeviceType>(args[0].operator int()), args[1].operator int()};
          for (int i = 2; i < args.size(); i += 2) {
            if (String::CanConvertFrom(args[i])) {
              int in_idx = this->GetInputIndex(args[i].operator String());
              if (in_idx >= 0) {
                this->SetInput(in_idx, args[i + 1]);
              } else {
                LOG(FATAL) << args[i].operator String() << " is not a valid input name";
              }
            } else {
              this->SetInput(args[i], args[i + 1]);
            }
          }
          this->Run();
          Array<NDArray> outputs;
          for (int i = 0; i < this->NumOutputs(); i++) {
            NDArray out = this->GetOutput(i);
            NDArray a = NDArray::Empty(out.Shape(), out.DataType(), host);
            a.CopyFrom(out);
            outputs.push_back(a);
          }
          *rv = outputs;
        });
  } else if (name == "load_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      this->LoadParams(args[0].operator std::string());
    });
  } else if (name == "share_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      const auto& module = args[0].operator Module();
      ICHECK_EQ(module.operator->()->type_key(), std::string("GraphExecutor"));
      const auto& param_blob = args[1].operator std::string();
      dmlc::MemoryStringStream strm(const_cast<std::string*>(&param_blob));
      this->ShareParams(dynamic_cast<const GraphExecutor&>(*module.operator->()), &strm);
    });
  } else if (name == "get_input_index") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      CHECK(String::CanConvertFrom(args[0])) << "Input key is not a string";
      *rv = this->GetInputIndex(args[0].operator String());
    });
  } else if (name == "get_input_info") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      auto [shape_info, dtype_info] = this->GetInputInfo();
      Map<String, ObjectRef> input_info;
      input_info.Set("shape", shape_info);
      input_info.Set("dtype", dtype_info);
      *rv = input_info;
    });
  } else if (name == "get_output_info") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      auto [shape_info, dtype_info] = this->GetOutputInfo();
      Map<String, ObjectRef> output_info;
      output_info.Set("shape", shape_info);
      output_info.Set("dtype", dtype_info);
      *rv = output_info;
    });
  } else {
    return PackedFunc();
  }
}
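// A minimal end-to-end sketch of driving this executor through the packed
// functions above from C++. `graph_json`, `lib`, and `input_array` are
// placeholders for the serialized graph, the compiled runtime Module, and a
// prepared NDArray respectively:
//
//   const PackedFunc* create = Registry::Get("tvm.graph_executor.create");
//   Module gmod = (*create)(graph_json, lib, static_cast<int>(kDLCPU), 0);
//   PackedFunc set_input = gmod.GetFunction("set_input");
//   PackedFunc run = gmod.GetFunction("run");
//   PackedFunc get_output = gmod.GetFunction("get_output");
//   set_input("data", input_array);
//   run();
//   NDArray out = get_output(0);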

Module GraphExecutorCreate(const std::string& sym_json, const tvm::runtime::Module& m,
                           const std::vector<Device>& devs,
                           const PackedFunc lookup_linked_param_func) {
  auto exec = make_object<GraphExecutor>();
  exec->Init(sym_json, m, devs, lookup_linked_param_func);
  return Module(exec);
}

// Get all devices for the host and other runtime devices.
std::vector<Device> GetAllDevice(const TVMArgs& args, int dev_start_arg) {
  // Reserve the first item as the fallback device.
  std::vector<Device> ret;
  Device dev;
  for (int i = dev_start_arg; i < args.num_args; i += 2) {
    int dev_type = args[i];
    dev.device_type = static_cast<DLDeviceType>(dev_type);
    dev.device_id = args[i + 1];
    ret.push_back(dev);
  }
  return ret;
}
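// Devices arrive as flat (device_type, device_id) pairs. For example
// (illustrative), the tail of a heterogeneous create call might be
//   ..., kDLCPU, 0, kDLOpenCL, 0
// which GetAllDevice turns into {{kDLCPU, 0}, {kDLOpenCL, 0}}, the first
// entry serving as the fallback device.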

// The 4-argument version is currently reserved to keep support for calling
// from tvm4j and JavaScript, since they don't have heterogeneous
// execution support yet. For heterogeneous execution, at least 5 arguments will
// be passed in. The third one is the number of devices.
// Eventually, we will probably only pass Device for all the languages.
TVM_REGISTER_GLOBAL("tvm.graph_executor.create").set_body([](TVMArgs args, TVMRetValue* rv) {
  ICHECK_GE(args.num_args, 4) << "The expected number of arguments for graph_executor.create is "
                                 "at least 4, but it has "
                              << args.num_args;
  PackedFunc lookup_linked_param_func;
  int dev_start_arg = 2;
  if (args[2].type_code() == kTVMPackedFuncHandle) {
    lookup_linked_param_func = args[2];
    dev_start_arg++;
  }
  const auto& devices = GetAllDevice(args, dev_start_arg);
  *rv = GraphExecutorCreate(args[0], args[1], devices, lookup_linked_param_func);
});
}  // namespace runtime
}  // namespace tvm