1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \brief Tiny graph executor that can run graph
22 * containing only tvm PackedFunc.
23 * \file graph_executor.h
24 */
25#ifndef TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
26#define TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
27
28#include <dlpack/dlpack.h>
29#include <dmlc/json.h>
30#include <dmlc/memory_io.h>
31#include <tvm/runtime/ndarray.h>
32#include <tvm/runtime/packed_func.h>
33
34#include <memory>
35#include <string>
36#include <tuple>
37#include <unordered_map>
38#include <unordered_set>
39#include <utility>
40#include <vector>
41
42namespace tvm {
43namespace runtime {
44
/*!
 * \brief Macro to invoke a TVM C API call and abort with the last TVM error on failure.
 *
 * Wrapped in `do { } while (0)` so the macro behaves as a single statement:
 * the previous bare `{ ... }` form broke when used as the body of an
 * `if`/`else` without braces (the trailing `;` after the block ended the
 * `if`, orphaning the `else`).
 */
#define TVM_CCALL(func)                       \
  do {                                        \
    int ret = (func);                         \
    ICHECK_EQ(ret, 0) << TVMGetLastError();   \
  } while (0)
51
/*! \brief operator attributes about tvm op */
struct TVMOpParam {
  /*! \brief Name of the PackedFunc implementing this op. */
  std::string func_name;
  /*! \brief Extra JSON attributes not covered by the fields below, stored as String values. */
  std::unordered_map<std::string, ObjectRef> attrs;
  /*! \brief Number of input tensors consumed by the op. */
  uint32_t num_inputs;
  /*! \brief Number of output tensors produced by the op. */
  uint32_t num_outputs;
  /*! \brief Nonzero when tensor data should be flattened before the call — TODO confirm exact semantics in the .cc. */
  uint32_t flatten_data;
};
60
61/*!
62 * \brief Tiny graph executor.
63 *
64 * This runtime can be accessible in various languages via
65 * TVM runtime PackedFunc API.
66 */
67class TVM_DLL GraphExecutor : public ModuleNode {
68 struct OpArgs {
69 std::vector<DLTensor> args;
70 std::vector<TVMValue> arg_values;
71 std::vector<int> arg_tcodes;
72 std::vector<int64_t> shape_data;
73 };
74
75 public:
76 using ShapeInfo = Map<String, ObjectRef>;
77 using DtypeInfo = Map<String, ObjectRef>;
78 /*!
79 * \brief Get member function to front-end
80 * \param name The name of the function.
81 * \param sptr_to_self The pointer to the module node.
82 * \return The corresponding member function.
83 */
84 virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
85
86 /*!
87 * \return The type key of the executor.
88 */
89 const char* type_key() const final { return "GraphExecutor"; }
90 void Run();
91
92 /*!
93 * \brief Initialize the graph executor with graph and device.
94 * \param graph_json The execution graph.
95 * \param module The module containing the compiled functions for the host
96 * processor.
97 * \param devs The device of the host and devices where graph nodes will be
98 * executed on.
99 * \param lookup_linked_param_func If given, a PackedFunc invoked to lookup linked parameters
100 * by storage_id. If not given, linked parameters are looked-up using an internal implementation,
101 * which is not compatible with RPCModules. Default is nullptr.
102 */
103
104 void Init(const std::string& graph_json, tvm::runtime::Module module,
105 const std::vector<Device>& devs, const PackedFunc lookup_linked_param_func = nullptr);
106
107 /*!
108 * \brief Get the input index given the name of input.
109 * \param name The name of the input.
110 * \return The index of input.
111 */
112 int GetInputIndex(const std::string& name);
113
114 /*!
115 * \brief Get the input info of Graph by parsing the input nodes.
116 * \return The shape and dtype tuple.
117 */
118 std::tuple<ShapeInfo, DtypeInfo> GetInputInfo() const;
119
120 /*!
121 * \brief Get the output info of Graph by parsing the output nodes.
122 * \return The shape and dtype tuple.
123 */
124 std::tuple<ShapeInfo, DtypeInfo> GetOutputInfo() const;
125
126 /*!
127 * \brief Get the output index given the name of output.
128 * \param name The name of the output.
129 * \return The index of output.
130 */
131 int GetOutputIndex(const std::string& name);
132
133 /*!
134 * \brief set index-th input to the graph.
135 * \param index The input index.
136 * \param data_in The input data.
137 */
138 void SetInput(int index, DLTensor* data_in);
139 /*!
140 * \brief set index-th input to the graph without copying the data
141 * \param index The input index.
142 * \param data_ref The input data that is referred.
143 */
144 void SetInputZeroCopy(int index, DLTensor* data_ref);
145 /*!
146 * \brief set index-th output to the graph without copying the data.
147 * \param index The output index.
148 * \param data_ref The output data that is referred.
149 */
150 void SetOutputZeroCopy(int index, DLTensor* data_ref);
151 /*!
152 * \brief Get the number of outputs
153 *
154 * \return The number of outputs from graph.
155 */
156 int NumOutputs() const;
157 /*!
158 * \brief Get the number of inputs
159 *
160 * \return The number of inputs to the graph.
161 */
162 int NumInputs() const;
163 /*!
164 * \brief Return NDArray for given input index.
165 * \param index The input index.
166 *
167 * \return NDArray corresponding to given input node index.
168 */
169 NDArray GetInput(int index) const;
170 /*!
171 * \brief Return NDArray for given output index.
172 * \param index The output index.
173 *
174 * \return NDArray corresponding to given output node index.
175 */
176 NDArray GetOutput(int index) const;
177 /*!
178 * \brief Copy index-th output to data_out.
179 * \param index The output index.
180 * \param data_out the output data.
181 */
182 void CopyOutputTo(int index, DLTensor* data_out);
183 /*!
184 * \brief Load parameters from binary stream
185 * \param strm The input stream.
186 */
187 void LoadParams(dmlc::Stream* strm);
188 /*!
189 * \brief Load parameters from parameter blob.
190 * \param param_blob A binary blob of parameter.
191 */
192 void LoadParams(const std::string& param_blob);
193
194 /*!
195 * \brief Share parameters from pre-existing GraphExecutor instance.
196 * \param other A GraphExecutor instance, previously with |LoadParams| called with the
197 * identical input |param_blob|.
198 * \param strm The input stream.
199 */
200 void ShareParams(const GraphExecutor& other, dmlc::Stream* strm);
201
202 /*!
203 * \brief Get total number of nodes.
204 * \return Total number of nodes.
205 */
206 uint32_t GetNumOfNodes() const { return static_cast<uint32_t>(nodes_.size()); }
207
208 std::string GetNodeName(uint32_t nid) const { return nodes_[nid].name; }
209
210 protected:
211 // Memory pool entry.
212 struct PoolEntry {
213 int device_type;
214 std::vector<int64_t> shape;
215 DLDataType dtype;
216 int param_data_entry;
217 NDArray linked_param;
218 std::string scope;
219 // PoolEntry(int s, int dev_type, void* pre_linked_param) :
220 // size(s), device_type(dev_type), pre_linked_param(std::move(pre_linked_param)) {}
221 };
222 // Node entry
223 struct NodeEntry {
224 uint32_t node_id;
225 uint32_t index;
226 uint32_t version;
227 inline bool operator==(const NodeEntry& other) const {
228 return node_id == other.node_id && index == other.index && version == other.version;
229 }
230 // JSON Loader
231 void Load(dmlc::JSONReader* reader) {
232 reader->BeginArray();
233 ICHECK(reader->NextArrayItem()) << "invalid json format";
234 reader->Read(&node_id);
235 ICHECK(reader->NextArrayItem()) << "invalid json format";
236 reader->Read(&index);
237 if (reader->NextArrayItem()) {
238 reader->Read(&version);
239 ICHECK(!reader->NextArrayItem()) << "invalid json format";
240 } else {
241 version = 0;
242 }
243 }
244 };
245 // Node
246 struct Node {
247 // operator type in string
248 std::string op_type;
249 // name of the op
250 std::string name;
251 // parameters
252 TVMOpParam param;
253 // inputs
254 std::vector<NodeEntry> inputs;
255 // control deps
256 std::vector<uint32_t> control_deps;
257
258 // JSON Loader
259 void LoadAttrs(dmlc::JSONReader* reader, TVMOpParam* param) {
260 int bitmask = 0;
261 std::string key, value;
262 reader->BeginObject();
263 while (reader->NextObjectItem(&key)) {
264 reader->Read(&value);
265 if (key == "func_name") {
266 param->func_name = value;
267 bitmask |= 1;
268 } else if (key == "num_inputs") {
269 param->num_inputs = strtoul(value.c_str(), nullptr, 10);
270 bitmask |= 2;
271 } else if (key == "num_outputs") {
272 param->num_outputs = strtoul(value.c_str(), nullptr, 10);
273 bitmask |= 4;
274 } else if (key == "flatten_data") {
275 param->flatten_data = strtoul(value.c_str(), nullptr, 10);
276 bitmask |= 8;
277 } else {
278 param->attrs[key] = String(value);
279 }
280 }
281 ICHECK_EQ(bitmask, 1 | 2 | 4 | 8) << "invalid format";
282 }
283 // JSON Loader
284 void Load(dmlc::JSONReader* reader) {
285 reader->BeginObject();
286 int bitmask = 0;
287 std::string key;
288 while (reader->NextObjectItem(&key)) {
289 if (key == "op") {
290 reader->Read(&op_type);
291 bitmask |= 1;
292 } else if (key == "name") {
293 reader->Read(&name);
294 bitmask |= 2;
295 } else if (key == "inputs") {
296 reader->Read(&inputs);
297 bitmask |= 4;
298 } else if (key == "attr" || key == "attrs") {
299 this->LoadAttrs(reader, &param);
300 } else if (key == "control_deps") {
301 reader->Read(&control_deps);
302 } else {
303 LOG(FATAL) << "do not support key " << key;
304 }
305 }
306 ICHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format";
307 }
308 };
309 struct GraphAttr {
310 size_t storage_num_not_alloctaed{0};
311 std::vector<int> storage_id;
312 std::vector<int> device_index;
313 std::vector<std::string> dltype;
314 std::vector<std::string> storage_scope;
315 std::vector<std::vector<int64_t>> shape;
316 // The graph attribute fields.
317 void Load(dmlc::JSONReader* reader) {
318 reader->BeginObject();
319 int bitmask = 0;
320 std::string key, type;
321 while (reader->NextObjectItem(&key)) {
322 if (key == "dltype") {
323 reader->BeginArray();
324 ICHECK(reader->NextArrayItem());
325 reader->Read(&type);
326 ICHECK_EQ(type, "list_str");
327 ICHECK(reader->NextArrayItem());
328 reader->Read(&dltype);
329 ICHECK(!reader->NextArrayItem());
330 bitmask |= 1;
331 } else if (key == "storage_id") {
332 reader->BeginArray();
333 ICHECK(reader->NextArrayItem());
334 reader->Read(&type);
335 ICHECK_EQ(type, "list_int");
336 ICHECK(reader->NextArrayItem());
337 reader->Read(&storage_id);
338 ICHECK(!reader->NextArrayItem());
339 bitmask |= 2;
340 } else if (key == "storage_scope") {
341 reader->BeginArray();
342 ICHECK(reader->NextArrayItem());
343 reader->Read(&type);
344 ICHECK_EQ(type, "list_str");
345 ICHECK(reader->NextArrayItem());
346 reader->Read(&storage_scope);
347 ICHECK(!reader->NextArrayItem());
348 bitmask |= 1;
349 } else if (key == "shape") {
350 reader->BeginArray();
351 ICHECK(reader->NextArrayItem());
352 reader->Read(&type);
353 ICHECK_EQ(type, "list_shape");
354 ICHECK(reader->NextArrayItem());
355 reader->Read(&shape);
356 ICHECK(!reader->NextArrayItem());
357 bitmask |= 4;
358 } else if (key == "device_index") {
359 reader->BeginArray();
360 ICHECK(reader->NextArrayItem());
361 reader->Read(&type);
362 ICHECK_EQ(type, "list_int");
363 ICHECK(reader->NextArrayItem());
364 reader->Read(&device_index);
365 ICHECK(!reader->NextArrayItem());
366 } else {
367 reader->BeginArray();
368 ICHECK(reader->NextArrayItem());
369 reader->Read(&type);
370 if (type == "list_int") {
371 ICHECK(reader->NextArrayItem());
372 std::vector<int> temp;
373 reader->Read(&temp);
374 } else if (type == "size_t") {
375 ICHECK(reader->NextArrayItem());
376 size_t temp;
377 reader->Read(&temp);
378 } else {
379 LOG(FATAL) << "cannot skip graph attr " << key;
380 }
381 ICHECK(!reader->NextArrayItem());
382 }
383 }
384 ICHECK_EQ(bitmask, 1 | 2 | 4) << "invalid format";
385 }
386 };
387 // The graph attribute fields.
388 void Load(dmlc::JSONReader* reader) {
389 reader->BeginObject();
390 int bitmask = 0;
391 std::string key;
392 while (reader->NextObjectItem(&key)) {
393 if (key == "nodes") {
394 reader->Read(&nodes_);
395 bitmask |= 1;
396 } else if (key == "arg_nodes") {
397 reader->Read(&input_nodes_);
398 bitmask |= 2;
399 } else if (key == "node_row_ptr") {
400 reader->Read(&node_row_ptr_);
401 bitmask |= 4;
402 } else if (key == "heads") {
403 reader->Read(&outputs_);
404 bitmask |= 8;
405 } else if (key == "attrs") {
406 reader->Read(&attrs_);
407 bitmask |= 16;
408 } else if (key == "metadata") {
409 break;
410 } else {
411 LOG(FATAL) << "key " << key << " is not supported";
412 }
413 }
414 ICHECK_EQ(bitmask, 1 | 2 | 4 | 8 | 16) << "invalid format";
415 }
416 /*! \brief PackedFunc to lookup a linked paramter from a local Module. */
417 void DefaultLookupLinkedParam(TVMArgs args, TVMRetValue* rv);
418 /*! \brief Delete NDArray::Container with linked (i.e. static) data. */
419 static void LinkedNDArrayDeleter(Object* container);
420 /*! \brief Setup the temporal storage */
421 void SetupStorage();
422 /*! \brief Setup the executors. */
423 void SetupOpExecs();
424 /*!
425 * \brief Check the legality of external DLTensor*.
426 * \param external The external DLTensor*.
427 * \param eid The data_enrty_ index.
428 */
429 void CheckExternalDLTensor(const DLTensor* external, uint32_t eid) const;
430 /*!
431 * \brief Create an execution function given input.
432 * \param attrs The node attributes.
433 * \param args The arguments to the functor, including inputs and outputs.
434 * \return The created executor.
435 */
436 std::pair<std::function<void()>, std::shared_ptr<OpArgs>> CreateTVMOp(
437 const TVMOpParam& attrs, const std::vector<DLTensor>& args);
438 // Get node entry index.
439 uint32_t entry_id(uint32_t nid, uint32_t index) const { return node_row_ptr_[nid] + index; }
440 // Get node entry index.
441 uint32_t entry_id(const NodeEntry& e) const { return entry_id(e.node_id, e.index); }
442 // Number of node entries.
443 uint32_t num_node_entries() const { return node_row_ptr_.back(); }
444 /*! \brief The graph nodes. */
445 std::vector<Node> nodes_;
446 /*! \brief The argument nodes. */
447 std::vector<uint32_t> input_nodes_;
448 /*! \brief The parameter names. */
449 std::unordered_set<std::string> param_names_;
450 /*! \brief Map of input names to input indices. */
451 std::unordered_map<std::string, uint32_t> input_map_;
452 /*! \brief Map of output names to output indices. */
453 std::unordered_map<std::string, uint32_t> output_map_;
454 /*! \brief Used for quick node input DLTensor* lookup given an input eid. */
455 std::vector<std::vector<DLTensor*>> input_dltensors_;
456 /*! \brief Used for quick node output DLTensor* lookup given an output eid. */
457 std::vector<std::vector<DLTensor*>> output_dltensors_;
458 /*! \brief Used for quick node(both model output and op input) DLTensor* lookup given an eid. */
459 std::vector<std::vector<DLTensor*>> both_output_opinput_dltensors_;
460 /*! \brief Used for quick entry indexing. */
461 std::vector<uint32_t> node_row_ptr_;
462 /*! \brief Output entries. */
463 std::vector<NodeEntry> outputs_;
464 /*! \brief Additional graph attributes. */
465 GraphAttr attrs_;
466 /*! \brief The code module that contains both host and device code. */
467 tvm::runtime::Module module_;
468 /*! \brief Execution context of all devices including the host. */
469 std::vector<Device> devices_;
470 /*! \brief Common storage pool for all devices. */
471 std::vector<NDArray> storage_pool_;
472 /*! \brief Data entry of each node. */
473 std::vector<NDArray> data_entry_;
474 /*! \brief Data alignment of each node. */
475 std::vector<size_t> data_alignment_;
476 /*! \brief Operator on each node. */
477 std::vector<std::function<void()>> op_execs_;
478 /*! \brief Linked parameter lookup function. */
479 PackedFunc lookup_linked_param_;
480 /*! \brief Module's _lookup_linked_param function, used by DefaultLookupLinkedParam. */
481 PackedFunc module_lookup_linked_param_;
482 /*!
483 * \brief True when module_lookup_linked_param_ is valid.
484 * When the module does not include linked parmeters, module_lookup_linked_param_ will be nullptr.
485 */
486 bool module_lookup_linked_param_valid_;
487};
488
/*!
 * \brief Collect the devices passed in a packed-call argument list.
 * \param args The packed-call arguments.
 * \param dev_start_arg Index of the first device argument within \p args.
 * \return The devices extracted from \p args.
 * \note Declaration only — presumably parses (device_type, device_id) pairs
 *       starting at \p dev_start_arg; confirm against the definition in the .cc.
 */
std::vector<Device> GetAllDevice(const TVMArgs& args, int dev_start_arg);
490} // namespace runtime
491} // namespace tvm
492
493#endif // TVM_RUNTIME_GRAPH_EXECUTOR_GRAPH_EXECUTOR_H_
494