1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \brief Tiny graph executor that can run graph |
22 | * containing only tvm PackedFunc. |
23 | * \file graph_executor.h |
24 | */ |
25 | #ifndef TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_ |
26 | #define TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_ |
27 | |
28 | #include <dlpack/dlpack.h> |
29 | #include <dmlc/json.h> |
30 | #include <dmlc/memory_io.h> |
31 | #include <tvm/runtime/ndarray.h> |
32 | #include <tvm/runtime/packed_func.h> |
33 | |
34 | #include <memory> |
35 | #include <string> |
36 | #include <tuple> |
37 | #include <unordered_map> |
38 | #include <unordered_set> |
39 | #include <utility> |
40 | #include <vector> |
41 | |
42 | namespace tvm { |
43 | namespace runtime { |
44 | |
/*!
 * \brief Macro to call a TVM C API function and check that it returned 0.
 *
 *  Wrapped in do { } while (0) so the macro behaves as a single statement:
 *  it composes correctly with a trailing semicolon and inside unbraced
 *  if/else branches (the bare-brace form does not).
 */
#define TVM_CCALL(func)                       \
  do {                                        \
    int ret = (func);                         \
    ICHECK_EQ(ret, 0) << TVMGetLastError();   \
  } while (0)
51 | |
52 | /*! \brief operator attributes about tvm op */ |
53 | struct TVMOpParam { |
54 | std::string func_name; |
55 | std::unordered_map<std::string, ObjectRef> attrs; |
56 | uint32_t num_inputs; |
57 | uint32_t num_outputs; |
58 | uint32_t flatten_data; |
59 | }; |
60 | |
61 | /*! |
62 | * \brief Tiny graph executor. |
63 | * |
64 | * This runtime can be accessible in various languages via |
65 | * TVM runtime PackedFunc API. |
66 | */ |
67 | class TVM_DLL GraphExecutor : public ModuleNode { |
68 | struct OpArgs { |
69 | std::vector<DLTensor> args; |
70 | std::vector<TVMValue> arg_values; |
71 | std::vector<int> arg_tcodes; |
72 | std::vector<int64_t> shape_data; |
73 | }; |
74 | |
75 | public: |
76 | using ShapeInfo = Map<String, ObjectRef>; |
77 | using DtypeInfo = Map<String, ObjectRef>; |
78 | /*! |
79 | * \brief Get member function to front-end |
80 | * \param name The name of the function. |
81 | * \param sptr_to_self The pointer to the module node. |
82 | * \return The corresponding member function. |
83 | */ |
84 | virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self); |
85 | |
86 | /*! |
87 | * \return The type key of the executor. |
88 | */ |
89 | const char* type_key() const final { return "GraphExecutor" ; } |
90 | void Run(); |
91 | |
92 | /*! |
93 | * \brief Initialize the graph executor with graph and device. |
94 | * \param graph_json The execution graph. |
95 | * \param module The module containing the compiled functions for the host |
96 | * processor. |
97 | * \param devs The device of the host and devices where graph nodes will be |
98 | * executed on. |
99 | * \param lookup_linked_param_func If given, a PackedFunc invoked to lookup linked parameters |
100 | * by storage_id. If not given, linked parameters are looked-up using an internal implementation, |
101 | * which is not compatible with RPCModules. Default is nullptr. |
102 | */ |
103 | |
104 | void Init(const std::string& graph_json, tvm::runtime::Module module, |
105 | const std::vector<Device>& devs, const PackedFunc lookup_linked_param_func = nullptr); |
106 | |
107 | /*! |
108 | * \brief Get the input index given the name of input. |
109 | * \param name The name of the input. |
110 | * \return The index of input. |
111 | */ |
112 | int GetInputIndex(const std::string& name); |
113 | |
114 | /*! |
115 | * \brief Get the input info of Graph by parsing the input nodes. |
116 | * \return The shape and dtype tuple. |
117 | */ |
118 | std::tuple<ShapeInfo, DtypeInfo> GetInputInfo() const; |
119 | |
120 | /*! |
121 | * \brief Get the output info of Graph by parsing the output nodes. |
122 | * \return The shape and dtype tuple. |
123 | */ |
124 | std::tuple<ShapeInfo, DtypeInfo> GetOutputInfo() const; |
125 | |
126 | /*! |
127 | * \brief Get the output index given the name of output. |
128 | * \param name The name of the output. |
129 | * \return The index of output. |
130 | */ |
131 | int GetOutputIndex(const std::string& name); |
132 | |
133 | /*! |
134 | * \brief set index-th input to the graph. |
135 | * \param index The input index. |
136 | * \param data_in The input data. |
137 | */ |
138 | void SetInput(int index, DLTensor* data_in); |
139 | /*! |
140 | * \brief set index-th input to the graph without copying the data |
141 | * \param index The input index. |
142 | * \param data_ref The input data that is referred. |
143 | */ |
144 | void SetInputZeroCopy(int index, DLTensor* data_ref); |
145 | /*! |
146 | * \brief set index-th output to the graph without copying the data. |
147 | * \param index The output index. |
148 | * \param data_ref The output data that is referred. |
149 | */ |
150 | void SetOutputZeroCopy(int index, DLTensor* data_ref); |
151 | /*! |
152 | * \brief Get the number of outputs |
153 | * |
154 | * \return The number of outputs from graph. |
155 | */ |
156 | int NumOutputs() const; |
157 | /*! |
158 | * \brief Get the number of inputs |
159 | * |
160 | * \return The number of inputs to the graph. |
161 | */ |
162 | int NumInputs() const; |
163 | /*! |
164 | * \brief Return NDArray for given input index. |
165 | * \param index The input index. |
166 | * |
167 | * \return NDArray corresponding to given input node index. |
168 | */ |
169 | NDArray GetInput(int index) const; |
170 | /*! |
171 | * \brief Return NDArray for given output index. |
172 | * \param index The output index. |
173 | * |
174 | * \return NDArray corresponding to given output node index. |
175 | */ |
176 | NDArray GetOutput(int index) const; |
177 | /*! |
178 | * \brief Copy index-th output to data_out. |
179 | * \param index The output index. |
180 | * \param data_out the output data. |
181 | */ |
182 | void CopyOutputTo(int index, DLTensor* data_out); |
183 | /*! |
184 | * \brief Load parameters from binary stream |
185 | * \param strm The input stream. |
186 | */ |
187 | void LoadParams(dmlc::Stream* strm); |
188 | /*! |
189 | * \brief Load parameters from parameter blob. |
190 | * \param param_blob A binary blob of parameter. |
191 | */ |
192 | void LoadParams(const std::string& param_blob); |
193 | |
194 | /*! |
195 | * \brief Share parameters from pre-existing GraphExecutor instance. |
196 | * \param other A GraphExecutor instance, previously with |LoadParams| called with the |
197 | * identical input |param_blob|. |
198 | * \param strm The input stream. |
199 | */ |
200 | void ShareParams(const GraphExecutor& other, dmlc::Stream* strm); |
201 | |
202 | /*! |
203 | * \brief Get total number of nodes. |
204 | * \return Total number of nodes. |
205 | */ |
206 | uint32_t GetNumOfNodes() const { return static_cast<uint32_t>(nodes_.size()); } |
207 | |
208 | std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; } |
209 | |
210 | protected: |
211 | // Memory pool entry. |
212 | struct PoolEntry { |
213 | int device_type; |
214 | std::vector<int64_t> shape; |
215 | DLDataType dtype; |
216 | int param_data_entry; |
217 | NDArray linked_param; |
218 | std::string scope; |
219 | // PoolEntry(int s, int dev_type, void* pre_linked_param) : |
220 | // size(s), device_type(dev_type), pre_linked_param(std::move(pre_linked_param)) {} |
221 | }; |
222 | // Node entry |
223 | struct NodeEntry { |
224 | uint32_t node_id; |
225 | uint32_t index; |
226 | uint32_t version; |
227 | inline bool operator==(const NodeEntry& other) const { |
228 | return node_id == other.node_id && index == other.index && version == other.version; |
229 | } |
230 | // JSON Loader |
231 | void Load(dmlc::JSONReader* reader) { |
232 | reader->BeginArray(); |
233 | ICHECK(reader->NextArrayItem()) << "invalid json format" ; |
234 | reader->Read(&node_id); |
235 | ICHECK(reader->NextArrayItem()) << "invalid json format" ; |
236 | reader->Read(&index); |
237 | if (reader->NextArrayItem()) { |
238 | reader->Read(&version); |
239 | ICHECK(!reader->NextArrayItem()) << "invalid json format" ; |
240 | } else { |
241 | version = 0; |
242 | } |
243 | } |
244 | }; |
245 | // Node |
246 | struct Node { |
247 | // operator type in string |
248 | std::string op_type; |
249 | // name of the op |
250 | std::string name; |
251 | // parameters |
252 | TVMOpParam param; |
253 | // inputs |
254 | std::vector<NodeEntry> inputs; |
255 | // control deps |
256 | std::vector<uint32_t> control_deps; |
257 | |
258 | // JSON Loader |
259 | void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) { |
260 | int bitmask = 0; |
261 | std::string key, value; |
262 | reader->BeginObject(); |
263 | while (reader->NextObjectItem(&key)) { |
264 | reader->Read(&value); |
265 | if (key == "func_name" ) { |
266 | param->func_name = value; |
267 | bitmask |= 1; |
268 | } else if (key == "num_inputs" ) { |
269 | param->num_inputs = strtoul(value.c_str(), nullptr, 10); |
270 | bitmask |= 2; |
271 | } else if (key == "num_outputs" ) { |
272 | param->num_outputs = strtoul(value.c_str(), nullptr, 10); |
273 | bitmask |= 4; |
274 | } else if (key == "flatten_data" ) { |
275 | param->flatten_data = strtoul(value.c_str(), nullptr, 10); |
276 | bitmask |= 8; |
277 | } else { |
278 | param->attrs[key] = String(value); |
279 | } |
280 | } |
281 | ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "invalid format" ; |
282 | } |
283 | // JSON Loader |
284 | void Load(dmlc::JSONReader* reader) { |
285 | reader->BeginObject(); |
286 | int bitmask = 0; |
287 | std::string key; |
288 | while (reader->NextObjectItem(&key)) { |
289 | if (key == "op" ) { |
290 | reader->Read(&op_type); |
291 | bitmask |= 1; |
292 | } else if (key == "name" ) { |
293 | reader->Read(&name); |
294 | bitmask |= 2; |
295 | } else if (key == "inputs" ) { |
296 | reader->Read(&inputs); |
297 | bitmask |= 4; |
298 | } else if (key == "attr" || key == "attrs" ) { |
299 | this->LoadAttrs(reader, ¶m); |
300 | } else if (key == "control_deps" ) { |
301 | reader->Read(&control_deps); |
302 | } else { |
303 | LOG(FATAL) << "do not support key " << key; |
304 | } |
305 | } |
306 | ICHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format" ; |
307 | } |
308 | }; |
309 | struct GraphAttr { |
310 | size_t storage_num_not_alloctaed{0}; |
311 | std::vector<int> storage_id; |
312 | std::vector<int> device_index; |
313 | std::vector<std::string> dltype; |
314 | std::vector<std::string> storage_scope; |
315 | std::vector<std::vector<int64_t>> shape; |
316 | // The graph attribute fields. |
317 | void Load(dmlc::JSONReader* reader) { |
318 | reader->BeginObject(); |
319 | int bitmask = 0; |
320 | std::string key, type; |
321 | while (reader->NextObjectItem(&key)) { |
322 | if (key == "dltype" ) { |
323 | reader->BeginArray(); |
324 | ICHECK(reader->NextArrayItem()); |
325 | reader->Read(&type); |
326 | ICHECK_EQ(type, "list_str" ); |
327 | ICHECK(reader->NextArrayItem()); |
328 | reader->Read(&dltype); |
329 | ICHECK(!reader->NextArrayItem()); |
330 | bitmask |= 1; |
331 | } else if (key == "storage_id" ) { |
332 | reader->BeginArray(); |
333 | ICHECK(reader->NextArrayItem()); |
334 | reader->Read(&type); |
335 | ICHECK_EQ(type, "list_int" ); |
336 | ICHECK(reader->NextArrayItem()); |
337 | reader->Read(&storage_id); |
338 | ICHECK(!reader->NextArrayItem()); |
339 | bitmask |= 2; |
340 | } else if (key == "storage_scope" ) { |
341 | reader->BeginArray(); |
342 | ICHECK(reader->NextArrayItem()); |
343 | reader->Read(&type); |
344 | ICHECK_EQ(type, "list_str" ); |
345 | ICHECK(reader->NextArrayItem()); |
346 | reader->Read(&storage_scope); |
347 | ICHECK(!reader->NextArrayItem()); |
348 | bitmask |= 1; |
349 | } else if (key == "shape" ) { |
350 | reader->BeginArray(); |
351 | ICHECK(reader->NextArrayItem()); |
352 | reader->Read(&type); |
353 | ICHECK_EQ(type, "list_shape" ); |
354 | ICHECK(reader->NextArrayItem()); |
355 | reader->Read(&shape); |
356 | ICHECK(!reader->NextArrayItem()); |
357 | bitmask |= 4; |
358 | } else if (key == "device_index" ) { |
359 | reader->BeginArray(); |
360 | ICHECK(reader->NextArrayItem()); |
361 | reader->Read(&type); |
362 | ICHECK_EQ(type, "list_int" ); |
363 | ICHECK(reader->NextArrayItem()); |
364 | reader->Read(&device_index); |
365 | ICHECK(!reader->NextArrayItem()); |
366 | } else { |
367 | reader->BeginArray(); |
368 | ICHECK(reader->NextArrayItem()); |
369 | reader->Read(&type); |
370 | if (type == "list_int" ) { |
371 | ICHECK(reader->NextArrayItem()); |
372 | std::vector<int> temp; |
373 | reader->Read(&temp); |
374 | } else if (type == "size_t" ) { |
375 | ICHECK(reader->NextArrayItem()); |
376 | size_t temp; |
377 | reader->Read(&temp); |
378 | } else { |
379 | LOG(FATAL) << "cannot skip graph attr " << key; |
380 | } |
381 | ICHECK(!reader->NextArrayItem()); |
382 | } |
383 | } |
384 | ICHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format" ; |
385 | } |
386 | }; |
387 | // The graph attribute fields. |
388 | void Load(dmlc::JSONReader* reader) { |
389 | reader->BeginObject(); |
390 | int bitmask = 0; |
391 | std::string key; |
392 | while (reader->NextObjectItem(&key)) { |
393 | if (key == "nodes" ) { |
394 | reader->Read(&nodes_); |
395 | bitmask |= 1; |
396 | } else if (key == "arg_nodes" ) { |
397 | reader->Read(&input_nodes_); |
398 | bitmask |= 2; |
399 | } else if (key == "node_row_ptr" ) { |
400 | reader->Read(&node_row_ptr_); |
401 | bitmask |= 4; |
402 | } else if (key == "heads" ) { |
403 | reader->Read(&outputs_); |
404 | bitmask |= 8; |
405 | } else if (key == "attrs" ) { |
406 | reader->Read(&attrs_); |
407 | bitmask |= 16; |
408 | } else if (key == "metadata" ) { |
409 | break; |
410 | } else { |
411 | LOG(FATAL) << "key " << key << " is not supported" ; |
412 | } |
413 | } |
414 | ICHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format" ; |
415 | } |
416 | /*! \brief PackedFunc to lookup a linked paramter from a local Module. */ |
417 | void DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv); |
418 | /*! \brief Delete NDArray::Container with linked (i.e. static) data. */ |
419 | static void LinkedNDArrayDeleter(Object* container); |
420 | /*! \brief Setup the temporal storage */ |
421 | void SetupStorage(); |
422 | /*! \brief Setup the executors. */ |
423 | void SetupOpExecs(); |
424 | /*! |
425 | * \brief Check the legality of external DLTensor*. |
426 | * \param external The external DLTensor*. |
427 | * \param eid The data_enrty_ index. |
428 | */ |
429 | void CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const; |
430 | /*! |
431 | * \brief Create an execution function given input. |
432 | * \param attrs The node attributes. |
433 | * \param args The arguments to the functor, including inputs and outputs. |
434 | * \return The created executor. |
435 | */ |
436 | std::pair<std::function<void()>, std::shared_ptr<OpArgs>> CreateTVMOp( |
437 | const TVMOpParam& attrs, const std::vector<DLTensor>& args); |
438 | // Get node entry index. |
439 | uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; } |
440 | // Get node entry index. |
441 | uint32_t entry_id(const NodeEntry& e) const { return entry_id(e.node_id, e.index); } |
442 | // Number of node entries. |
443 | uint32_t num_node_entries() const { return node_row_ptr_.back(); } |
444 | /*! \brief The graph nodes. */ |
445 | std::vector<Node> nodes_; |
446 | /*! \brief The argument nodes. */ |
447 | std::vector<uint32_t> input_nodes_; |
448 | /*! \brief The parameter names. */ |
449 | std::unordered_set<std::string> param_names_; |
450 | /*! \brief Map of input names to input indices. */ |
451 | std::unordered_map<std::string, uint32_t> input_map_; |
452 | /*! \brief Map of output names to output indices. */ |
453 | std::unordered_map<std::string, uint32_t> output_map_; |
454 | /*! \brief Used for quick node input DLTensor* lookup given an input eid. */ |
455 | std::vector<std::vector<DLTensor*>> input_dltensors_; |
456 | /*! \brief Used for quick node output DLTensor* lookup given an output eid. */ |
457 | std::vector<std::vector<DLTensor*>> output_dltensors_; |
458 | /*! \brief Used for quick node(both model output and op input) DLTensor* lookup given an eid. */ |
459 | std::vector<std::vector<DLTensor*>> both_output_opinput_dltensors_; |
460 | /*! \brief Used for quick entry indexing. */ |
461 | std::vector<uint32_t> node_row_ptr_; |
462 | /*! \brief Output entries. */ |
463 | std::vector<NodeEntry> outputs_; |
464 | /*! \brief Additional graph attributes. */ |
465 | GraphAttr attrs_; |
466 | /*! \brief The code module that contains both host and device code. */ |
467 | tvm::runtime::Module module_; |
468 | /*! \brief Execution context of all devices including the host. */ |
469 | std::vector<Device> devices_; |
470 | /*! \brief Common storage pool for all devices. */ |
471 | std::vector<NDArray> storage_pool_; |
472 | /*! \brief Data entry of each node. */ |
473 | std::vector<NDArray> data_entry_; |
474 | /*! \brief Data alignment of each node. */ |
475 | std::vector<size_t> data_alignment_; |
476 | /*! \brief Operator on each node. */ |
477 | std::vector<std::function<void()>> op_execs_; |
478 | /*! \brief Linked parameter lookup function. */ |
479 | PackedFunc lookup_linked_param_; |
480 | /*! \brief Module's _lookup_linked_param function, used by DefaultLookupLinkedParam. */ |
481 | PackedFunc module_lookup_linked_param_; |
482 | /*! |
483 | * \brief True when module_lookup_linked_param_ is valid. |
484 | * When the module does not include linked parmeters, module_lookup_linked_param_ will be nullptr. |
485 | */ |
486 | bool module_lookup_linked_param_valid_; |
487 | }; |
488 | |
/*!
 * \brief Collect the devices encoded in a PackedFunc call's arguments.
 * \param args The packed arguments.
 * \param dev_start_arg Index of the first device-describing argument in \p args.
 * \return The devices parsed from the arguments (implementation not visible here;
 *  presumably (device_type, device_id) pairs — confirm in graph_executor.cc).
 */
std::vector<Device> GetAllDevice(const TVMArgs& args, int dev_start_arg);
490 | } // namespace runtime |
491 | } // namespace tvm |
492 | |
493 | #endif // TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_ |
494 | |