/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file relay/backend/graph_codegen.cc
 * \brief Graph executor codegen
 */

#include <dmlc/any.h>
#include <dmlc/json.h>
#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>

#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <typeindex>
#include <unordered_map>
#include <vector>

#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../transforms/device_aware_visitors.h"
#include "./te_compiler.h"
#include "./utils.h"

namespace tvm {
namespace relay {

// TODO(@jroesch, @csullivan): declare directly elsewhere
backend::StaticMemoryPlan GraphPlanMemory(const Function& func);

namespace backend {

class GraphNode;
class GraphInputNode;
class GraphOpNode;

using IntegerArray = Array<Integer>;
using ShapeVector = std::vector<std::vector<int64_t>>;
using GraphAttrs = std::unordered_map<std::string, dmlc::any>;
using GraphObjectPtr = std::shared_ptr<GraphNode>;
using GraphInputObjectPtr = std::shared_ptr<GraphInputNode>;
using GraphOpObjectPtr = std::shared_ptr<GraphOpNode>;
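
// Note: GraphAttrs values are type-erased as dmlc::any. Only the value types handled by the
// dmlc::json::Handler specializations at the bottom of this file (std::string, int, and a few
// std::vector instantiations) can be serialized to the graph JSON.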

/*! \brief Node types */
enum GraphNodeType {
  kGraphNop,
  kGraphInputNode,
  kGraphOpNode,
};

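/*! \brief Reference to a single output of a graph node. Serialized to the graph JSON as the
 * triple [node_id, output_index, version].
 */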
class GraphNodeRef {
 public:
  GraphNodeRef() {}
  GraphNodeRef(int ident, int index, int version = 0)
      : ident_(ident), index_(index), version_(version) {}

  inline void Save(dmlc::JSONWriter* writer) const {
    writer->BeginArray();
    writer->WriteArrayItem(ident_);
    writer->WriteArrayItem(index_);
    writer->WriteArrayItem(version_);
    writer->EndArray();
  }

  inline void Load(dmlc::JSONReader* reader) { LOG(FATAL) << "Not implemented."; }

 protected:
  int ident_;
  int index_{0};
  int version_{0};
};

/*! \brief Base Node class */
class GraphNode {
 public:
  GraphNode() {}
  virtual void Save(dmlc::JSONWriter* writer) const {}
  virtual void Load(dmlc::JSONReader* reader) {}
  virtual GraphNodeType Type() const { return kGraphNop; }
  virtual ~GraphNode() {}

 public:
  int num_outputs_{1};
  std::string name_;
  GraphAttrs attrs_;
};

/*! \brief Input Node */
class GraphInputNode : public GraphNode {
 public:
  GraphInputNode() {}
  GraphInputNode(const std::string& name, const GraphAttrs& attrs) {
    name_ = name;
    attrs_ = attrs;
  }

  GraphNodeType Type() const override { return kGraphInputNode; }

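  // Serializes to, e.g. for an input named "x" (illustrative):
  //   {"op": "null", "name": "x", "inputs": []}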
  void Save(dmlc::JSONWriter* writer) const override {
    const std::string op_name{"null"};
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_name);
    writer->WriteObjectKeyValue("name", this->name_);
    writer->WriteObjectKeyValue("inputs", std::list<int>());
    writer->EndObject();
  }
  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& attrs) {
    auto ptr = std::make_shared<GraphInputNode>(name, attrs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }
};

/*! \brief Op Node */
class GraphOpNode : public GraphNode {
 public:
  GraphOpNode() {}
  GraphOpNode(const std::string& name, const GraphAttrs& nd_attrs, const std::string& op_name,
              const std::vector<GraphNodeRef>& inputs, const GraphAttrs& attrs,
              size_t num_outputs = 1) {
    name_ = name;
    attrs_ = nd_attrs;
    op_name_ = op_name;
    inputs_ = inputs;
    op_attrs_ = attrs;
    num_outputs_ = num_outputs;
    op_attrs_["func_name"] = op_name_;
    op_attrs_["flatten_data"] = std::string("0");
    op_attrs_["num_inputs"] = std::to_string(inputs_.size());
    op_attrs_["num_outputs"] = std::to_string(num_outputs_);
  }

  GraphNodeType Type() const override { return kGraphOpNode; }

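  // Serializes to, e.g. for a lowered function "tvmgen_default_fused_add" with one input
  // (name chosen for illustration only):
  //   {"op": "tvm_op", "name": "tvmgen_default_fused_add",
  //    "attrs": {"func_name": "tvmgen_default_fused_add", "flatten_data": "0",
  //              "num_inputs": "1", "num_outputs": "1"},
  //    "inputs": [[0, 0, 0]]}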
  void Save(dmlc::JSONWriter* writer) const override {
    GraphAttrs attrs = op_attrs_;
    attrs["func_name"] = this->op_name_;
    attrs["flatten_data"] = std::string("0");
    attrs["num_inputs"] = std::to_string(this->inputs_.size());
    attrs["num_outputs"] = std::to_string(this->num_outputs_);
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_type_name_);
    writer->WriteObjectKeyValue("name", name_);
    writer->WriteObjectKeyValue("attrs", attrs);
    writer->WriteObjectKeyValue("inputs", this->inputs_);
    writer->EndObject();
  }
  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& nd_attrs,
                                                  const std::string& op_name,
                                                  const std::vector<GraphNodeRef>& inputs,
                                                  const GraphAttrs& attrs, size_t num_outputs = 1) {
    auto ptr = std::make_shared<GraphOpNode>(name, nd_attrs, op_name, inputs, attrs, num_outputs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }

 public:
  std::string op_name_;
  std::vector<GraphNodeRef> inputs_;
  GraphAttrs op_attrs_;

 private:
  const std::string op_type_name_{"tvm_op"};
};

/*! \brief Code generator for the graph executor, producing a module that contains the graph
 * JSON, the lowered functions, and the parameters.
 */
class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
 public:
  GraphExecutorCodegen(runtime::Module* mod, const Array<Target>& targets)
      : mod_(mod), config_(transform::PassContext::Current(), targets) {}

  StorageInfo GetStorageInfo(const Expr& e) {
    size_t count = memory_plan_->expr_to_storage_info.count(e);
    ICHECK_GT(count, 0) << "Expr does not exist in storage plan";
    auto storage_info = memory_plan_->expr_to_storage_info[e];
    return storage_info;
  }

  LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) {
    mod_name_ = mod_name;
    VLOG_CONTEXT << "GraphExecutorCodegen";
    VLOG(1) << "compiling:" << std::endl << PrettyPrint(func);

    // TODO(mbs): Why plan memory and update workspace sizes before lowering?
    memory_plan_ = GraphPlanMemory(func);

    backend::FunctionInfo func_info;

    if (memory_plan_.defined()) {
      // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
      func_info =
          relay::tec::UpdateMainWorkspaceSize(mod, config_, memory_plan_->expr_to_storage_info);
      mod = WithAttr(mod, "main_func_info", func_info);
    }

    IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) {
      // We need to maintain the constant map for external
      // functions so we pass this processing function which
      // allows us to process each function as we lower it.
      if (func->GetAttr<String>(attr::kCompiler).defined()) {
        UpdateConstants(func, &params_);
      }

      // TODO(@areusch, @jroesch): We should refactor this to
      // execute as a further pass, instead writing data to the
      // lowering process directly.
      tec::UpdateFunctionMetadata(func, this->function_metadata_);
    })(mod);

    Optional<backend::FunctionInfo> main_func_info =
        lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");

    function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

    Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));

    // Now that we have lowered all operators to TIR code, we can proceed with compilation.
    //
    // Unfortunately we need to re-plan, as the previous results have been invalidated by
    // lowering; we will fix this in future refactors.
    memory_plan_ = GraphPlanMemory(lowered_main_func);

    // The graph planner also cannot handle planning calls to global variables, so we must remap
    // them.

    // First we convert all the parameters into input nodes.
    for (auto param : lowered_main_func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }

    heads_ = VisitExpr(lowered_main_func->body);
    std::ostringstream os;

    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);
    LoweredOutput ret;
    ret.graph_json = os.str();

    // Collect any runtime modules generated by external codegen.
    ret.external_mods =
        lowered_mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});

    // Collect any constants extracted by external codegen.
    ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
    Map<String, runtime::NDArray> const_name_to_constant =
        lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
            .value_or({});
    for (const auto& kv : const_name_to_constant) {
      VLOG(1) << "constant '" << kv.first << "' contributed by external codegen";
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }

    // Collect any constants extracted during lowering.
    for (const auto& kv : params_) {
      VLOG(1) << "constant '" << kv.first << "' contributed by TECompiler";
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }

    ret.function_metadata = std::move(function_metadata_);

    // This is the point where we separate the functions in the module by target.
    ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
    ret.metadata =
        ExecutorCodegenMetadata({} /* inputs */, {} /* input_tensor_types */, {} /* outputs */,
                                {} /* output_tensor_types */, {} /* pools */, {} /* devices */,
                                runtime::kTvmExecutorGraph /* executor */, mod_name_ /* mod_name */,
                                "packed" /* interface_api */, Bool(false) /* unpacked_api */);
    return ret;
  }

 protected:
  /*!
   * \brief Add node to graph
   *
   * \param node The node to add
   * \param expr The expression the node corresponds to
   * \return std::vector<GraphNodeRef> References to the node's outputs
   */
  std::vector<GraphNodeRef> AddNode(GraphObjectPtr node, Expr expr) {
    auto checked_type = expr->checked_type();

    auto storage_info = GetStorageInfo(expr);
    // storage
    std::vector<int64_t> storage_ids;
    for (auto v : storage_info->storage_ids) {
      storage_ids.push_back(v);
    }
    node->attrs_["storage_id"] = std::move(storage_ids);
    // type
    std::vector<int64_t> device_types;
    for (const auto& virtual_device : storage_info->virtual_devices) {
      // TODO(mbs): Keeping only the device type.
      ICHECK_GT(virtual_device->device_type(), 0);
      device_types.push_back(virtual_device->device_type());
    }
    size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0);
    if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) {
      LOG(FATAL) << "The graph contains un-annotated nodes for "
                 << "heterogeneous execution. All nodes must be "
                 << "annotated.";
    }
    if (num_unknown_devices == 0) {
      node->attrs_["device_index"] = device_types;
    }
    // storage scope
    std::vector<std::string> storage_scope;
    for (const auto& virtual_device : storage_info->virtual_devices) {
      storage_scope.push_back(std::string(virtual_device->memory_scope));
    }
    node->attrs_["storage_scope"] = std::move(storage_scope);
    auto node_id = nodes_.size();
    nodes_.push_back(node);
    // Tuple return value, flatten as tuple
    if (const auto* tuple_type = checked_type.as<TupleTypeNode>()) {
      std::vector<GraphNodeRef> ret;
      ShapeVector shape;
      std::vector<std::string> dtype;
      for (size_t i = 0; i < tuple_type->fields.size(); ++i) {
        if (const auto* typ = tuple_type->fields[i].as<TensorTypeNode>()) {
          ret.push_back(GraphNodeRef(node_id, i));
          shape.emplace_back(ShapeToJSON(typ->shape));
          dtype.emplace_back(DType2String(typ->dtype));
        } else {
          LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
        }
      }
      ICHECK_EQ(node->Type(), kGraphOpNode);
      auto op_nd = std::dynamic_pointer_cast<GraphOpNode>(node);
      op_nd->attrs_["shape"] = shape;
      op_nd->attrs_["dtype"] = dtype;
      op_nd->num_outputs_ = tuple_type->fields.size();
      return ret;
    }
    // Normal tensor return type
    if (const auto* tensor_type = checked_type.as<TensorTypeNode>()) {
      ShapeVector shape;
      std::vector<std::string> dtype;
      shape.emplace_back(ShapeToJSON(tensor_type->shape));
      dtype.emplace_back(DType2String(tensor_type->dtype));
      node->attrs_["shape"] = shape;
      node->attrs_["dtype"] = dtype;
    } else {
      LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
    }
    return {GraphNodeRef(node_id, 0)};
  }

  std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
    Expr expr = GetRef<Expr>(op);
    return var_map_[expr.get()];
  }

  std::vector<GraphNodeRef> VisitExpr_(const ConstantNode* op) override {
    Expr expr = GetRef<Expr>(op);
    size_t index = params_.size();
    std::string name = "p" + std::to_string(index);
    auto node = GraphInputNode::make_node_ptr(name, GraphAttrs());
    auto to_return = AddNode(node, expr);
    CHECK_EQ(to_return.size(), 1) << "Expected exactly 1 parameter node created";
    param_storage_ids_[name] = GetStorageInfo(expr)->storage_ids[0];
    params_[name] = op->data;
    return to_return;
  }

  std::vector<GraphNodeRef> VisitExpr_(const TupleNode* op) override {
    std::vector<GraphNodeRef> fields;
    for (auto field : op->fields) {
      auto ref_vec = VisitExpr(field);
      for (auto ref : ref_vec) {
        fields.push_back(ref);
      }
    }
    return fields;
  }

  bool ShareSameStorage(const Expr& lhs, const Expr& rhs) {
    StorageInfo lit = GetStorageInfo(lhs);
    StorageInfo rit = GetStorageInfo(rhs);
    int64_t lhs_storage_id = lit->storage_ids[0];
    int64_t rhs_storage_id = rit->storage_ids[0];
    return lhs_storage_id == rhs_storage_id;
  }

  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) {
    Call call = GetRef<Call>(call_node);
    std::vector<GraphNodeRef> inputs;
    std::string func_name;

    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
    if (device_copy_props.body.defined()) {
      // The graph executor expects to see a normal call to the undefined @__copy function.
      // The source and destination device annotations are no longer needed since they have
      // been captured in the StorageInfos for both input and output.
      // TODO(mbs): device_copy cleanup
      func_name = "__copy";
      for (const auto& n : VisitExpr(device_copy_props.body)) {
        inputs.push_back(n);
      }
    } else if (call_lowered_props.lowered_func.defined()) {
      // Extract function and arguments from the call_lowered op

      func_name = call_lowered_props.lowered_func->name_hint;

      for (const Expr& arg : call_lowered_props.arguments) {
        for (auto n : VisitExpr(arg)) {
          inputs.push_back(n);
        }
      }
      if (call_lowered_props.attrs.metadata.count("relay_attrs")) {
        if (auto relay_attrs =
                call_lowered_props.attrs.metadata["relay_attrs"].as<DictAttrsNode>()) {
          for (auto p : relay_attrs->dict) {
            if (p.second.as<StringObj>()) {
              attrs[p.first] = std::string(Downcast<String>(p.second));
            }
          }
        }
      }
      // TODO(mbs): "reshape" cleanup.
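      // A reshape-only call whose result was planned into the same storage as its first
      // argument needs no kernel launch: emit a "__nop" node so the graph executor just
      // aliases the input buffer.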
      if (IsReshapeOnly(call_lowered_props) &&
          ShareSameStorage(GetRef<Expr>(call_node), call_lowered_props.arguments[0])) {
        auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs);
        return AddNode(node, call);
      }
    } else if (!call_node->attrs.defined()) {  // Call is an extern function
      const auto* func = call_node->op.as<GlobalVarNode>();
      ICHECK(func) << "Expected the operator to be a global var, but got "
                   << call_node->op->GetTypeKey();  // getting a relay fn here, not sure why.
      func_name = func->name_hint;

      for (const Expr& arg : call_node->args) {
        for (auto n : VisitExpr(arg)) {
          inputs.push_back(n);
        }
      }
    } else {
      LOG(FATAL) << "Non-primitive-call nodes should have been transformed away.\n"
                 << "The graph executor code generator expects all calls to be call_lowered, "
                 << "but found: " << std::endl
                 << PrettyPrint(call);
    }

    // Compute the operator name; it must match the unique name that was chosen when
    // generating the kernel.
    auto op_name = name_supply_->FreshName(func_name);
    auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs);
    return AddNode(node, call);
  }

  std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
    relay::Call call = GetRef<Call>(call_node);
    OnDeviceProps props = GetOnDeviceProps(call_node);
    if (props.body.defined()) {
      // See through "on_device" calls.
      return VisitExpr(props.body);
    }
    return GraphAddCallNode(call_node, GraphAttrs());
  }

  std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
    ICHECK_EQ(var_map_.count(op->var.get()), 0);
    var_map_[op->var.get()] = VisitExpr(op->value);
    return VisitExpr(op->body);
  }
  std::vector<GraphNodeRef> VisitExpr_(const TupleGetItemNode* op) override {
    auto vtuple = VisitExpr(op->tuple);
    return {vtuple[op->index]};
  }

  std::vector<GraphNodeRef> VisitExpr_(const OpNode* op) override {
    LOG(FATAL) << "All OpNodes should have been expanded";
  }
  std::vector<GraphNodeRef> VisitExpr_(const GlobalVarNode* op) override {
    LOG(FATAL) << "All GlobalVarNodes should be removed before graph executor's Codegen is called";
  }
  std::vector<GraphNodeRef> VisitExpr_(const IfNode* op) override {
    LOG(FATAL) << "Graph executor does not support control flow (found IfNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
    ICHECK(op->GetAttr<String>(attr::kCompiler).defined())
        << "Only functions supported by custom codegen";
    return {};
  }
  std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
    LOG(FATAL) << "Graph executor does not support references (found RefCreateNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const RefReadNode* op) override {
    LOG(FATAL) << "Graph executor does not support references (found RefReadNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const RefWriteNode* op) override {
    LOG(FATAL) << "Graph executor does not support references (found RefWriteNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const ConstructorNode* op) override {
    LOG(FATAL) << "Graph executor does not support ADTs (found ConstructorNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const MatchNode* op) override {
    LOG(FATAL) << "Graph executor does not support matching (found MatchNode)";
  }
  /*!
   * \brief Generate Graph JSON
   *
   * \param writer json writer
   */
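  // The emitted JSON has the top-level shape (illustrative):
  //   {"nodes": [...], "arg_nodes": [...], "heads": [...],
  //    "attrs": {"shape": ..., "storage_id": ..., "dltype": ..., ...},
  //    "node_row_ptr": [...]}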
  void GetJSON(dmlc::JSONWriter* writer) {
    std::vector<size_t> arg_nodes;
    for (size_t i = 0; i < nodes_.size(); ++i) {
      auto node = nodes_[i];
      if (node->Type() == kGraphInputNode) {
        arg_nodes.push_back(i);
      }
    }
    size_t num_entry = 0;
    ShapeVector shapes;
    std::vector<size_t> storage_ids;
    std::vector<std::string> storage_scopes;
    std::vector<size_t> device_types;
    std::vector<std::string> dltypes;
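    // node_row_ptr[i] is the index of the first output entry of node i; e.g. nodes with
    // 1, 2, and 1 outputs yield node_row_ptr = [0, 1, 3, 4].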
    std::vector<size_t> node_row_ptr{0};
    for (auto node : nodes_) {
      const auto& shape_vec = dmlc::get<ShapeVector>(node->attrs_["shape"]);
      const auto& storage_id = dmlc::get<std::vector<int64_t>>(node->attrs_["storage_id"]);
      const auto& storage_scope =
          dmlc::get<std::vector<std::string>>(node->attrs_["storage_scope"]);
      const auto& dtype_vec = dmlc::get<std::vector<std::string>>(node->attrs_["dtype"]);

      ICHECK_EQ(node->num_outputs_, shape_vec.size());
      num_entry += node->num_outputs_;

      shapes.insert(shapes.end(), shape_vec.begin(), shape_vec.end());
      dltypes.insert(dltypes.end(), dtype_vec.begin(), dtype_vec.end());
      storage_ids.insert(storage_ids.end(), storage_id.begin(), storage_id.end());
      storage_scopes.insert(storage_scopes.end(), storage_scope.begin(), storage_scope.end());
      if (node->attrs_.count("device_index")) {
        const auto& dev_types = dmlc::get<std::vector<int64_t>>(node->attrs_["device_index"]);
        device_types.insert(device_types.end(), dev_types.begin(), dev_types.end());
      }
      node_row_ptr.push_back(num_entry);
    }

    // Verify whether storage_scope contains any non-global memory scope;
    // if there are none, it is better not to write scopes to the JSON at all.
    bool global_only_scope = true;
    for (const auto& ss : storage_scopes) {
      if (!(ss.empty() || ss == "global")) {
        global_only_scope = false;
      }
    }
    if (global_only_scope) {
      storage_scopes.clear();
    }
    writer->BeginObject();
    writer->WriteObjectKeyValue("nodes", nodes_);
    writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
    writer->WriteObjectKeyValue("heads", heads_);
    std::unordered_map<std::string, std::vector<dmlc::any>> attrs;
    attrs["shape"].emplace_back(std::string("list_shape"));
    attrs["shape"].emplace_back(shapes);
    attrs["storage_id"].emplace_back(std::string("list_int"));
    attrs["storage_id"].emplace_back(storage_ids);
    if (device_types.size()) {
      attrs["device_index"].emplace_back(std::string("list_int"));
      attrs["device_index"].emplace_back(device_types);
    }
    if (storage_scopes.size()) {
      attrs["storage_scope"].emplace_back(std::string("list_str"));
      attrs["storage_scope"].emplace_back(storage_scopes);
    }
    attrs["dltype"].emplace_back(std::string("list_str"));
    attrs["dltype"].emplace_back(dltypes);
    writer->WriteObjectKeyValue("attrs", attrs);
    writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
    writer->EndObject();
  }

 protected:
  /*! \brief nodes */
  std::vector<GraphObjectPtr> nodes_;
  /*! \brief output of graph */
  std::vector<GraphNodeRef> heads_;
  /*! \brief mod */
  runtime::Module* mod_;
  /*! \brief variable map */
  std::unordered_map<const Object*, std::vector<GraphNodeRef>> var_map_;
  /*! \brief Available targets */
  CompilationConfig config_;
  /*!
   * \brief Parameters (i.e. ConstantNodes found in the graph).
   * These are taken as inputs to the GraphExecutor.
   * params_ maps each generated parameter name to its NDArray, while param_storage_ids_
   * records the storage_id assigned to each parameter so it can be looked up at runtime.
   */
  std::unordered_map<std::string, runtime::NDArray> params_;
  std::unordered_map<std::string, int64_t> param_storage_ids_;
  /*! \brief plan memory of device result */
  StaticMemoryPlan memory_plan_;
  /*! \brief the module name we use to mangle the function names */
  String mod_name_;
  /*! \brief function metadata */
  Map<String, FunctionInfo> function_metadata_;
  /*! \brief NameSupply */
  NameSupply name_supply_ = NameSupply("");
};

class GraphExecutorCodegenModule : public runtime::ModuleNode {
 public:
  GraphExecutorCodegenModule() {}
  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
    if (name == "init") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 2) << "Expected 2 arguments: "
                                    << "runtime::Module mod and Array<Target> targets";
        void* mod = args[0];
        Array<Target> targets = args[1];
        codegen_ = std::make_shared<GraphExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod),
                                                          std::move(targets));
      });
    } else if (name == "codegen") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        IRModule mod = args[0];
        Function func = args[1];
        String mod_name = args[2];
        this->output_ = this->codegen_->Codegen(mod, func, mod_name);
      });
    } else if (name == "get_graph_json") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.graph_json; });
    } else if (name == "list_params_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Array<runtime::String> ret;
        for (const auto& kv : this->output_.params) {
          ret.push_back(kv.first);
        }
        *rv = ret;
      });
    } else if (name == "get_param_by_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        String key = args[0];
        auto it = this->output_.params.find(key);
        CHECK(it != this->output_.params.end()) << "no such parameter " << key;
        *rv = (*it).second;
      });
    } else if (name == "get_irmodule") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.lowered_funcs;
      });
    } else if (name == "get_external_modules") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.external_mods;
      });
    } else if (name == "get_devices") {
      return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = Array<String>(); });
    } else if (name == "get_executor_codegen_metadata") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.metadata; });
    } else if (name == "get_function_metadata") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.function_metadata;
      });
    } else {
      return PackedFunc([](TVMArgs args, TVMRetValue* rv) {});
    }
  }

  const char* type_key() const final { return "RelayGraphExecutorCodegenModule"; }

 private:
  std::shared_ptr<GraphExecutorCodegen> codegen_;
  LoweredOutput output_;
};

runtime::Module CreateGraphCodegenMod() {
  auto ptr = make_object<GraphExecutorCodegenModule>();
  return runtime::Module(ptr);
}

TVM_REGISTER_GLOBAL("relay.build_module._GraphExecutorCodegen")
    .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); });
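// The registration above is reachable from Python through the FFI, e.g. (illustrative):
//   f = tvm.get_global_func("relay.build_module._GraphExecutorCodegen")
//   codegen_mod = f()  # a runtime.Module wrapping GraphExecutorCodegenModule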

}  // namespace backend
}  // namespace relay
}  // namespace tvm

namespace dmlc {
namespace json {
// JSON utils
template <typename T>
inline bool SameType(const dmlc::any& data) {
  return std::type_index(data.type()) == std::type_index(typeid(T));
}

template <>
struct Handler<std::shared_ptr<tvm::relay::backend::GraphNode>> {
  inline static void Write(dmlc::JSONWriter* writer,
                           const std::shared_ptr<tvm::relay::backend::GraphNode>& data) {
    data->Save(writer);
  }
  inline static void Read(dmlc::JSONReader* reader,
                          std::shared_ptr<tvm::relay::backend::GraphNode>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};
template <>
struct Handler<std::unordered_map<std::string, dmlc::any>> {
  inline static void Write(dmlc::JSONWriter* writer,
                           const std::unordered_map<std::string, dmlc::any>& data) {
    writer->BeginObject();
    for (const auto& kv : data) {
      auto k = kv.first;
      const dmlc::any& v = kv.second;
      if (SameType<std::string>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::string>(v));
      } else if (SameType<int>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<int>(v));
      } else if (SameType<std::vector<size_t>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<size_t>>(v));
      } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::vector<int64_t>>>(v));
      } else if (SameType<std::vector<std::string>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::string>>(v));
      } else if (SameType<std::vector<dmlc::any>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<dmlc::any>>(v));
      } else {
        LOG(FATAL) << "Not supported";
      }
    }
    writer->EndObject();
  }
  inline static void Read(dmlc::JSONReader* reader,
                          std::unordered_map<std::string, dmlc::any>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};

template <>
struct Handler<std::vector<dmlc::any>> {
  inline static void Write(dmlc::JSONWriter* writer, const std::vector<dmlc::any>& data) {
    writer->BeginArray();
    for (const auto& v : data) {
      if (SameType<std::string>(v)) {
        writer->WriteArrayItem(dmlc::get<std::string>(v));
      } else if (SameType<int>(v)) {
        writer->WriteArrayItem(dmlc::get<int>(v));
      } else if (SameType<std::vector<size_t>>(v)) {
        writer->WriteArrayItem(dmlc::get<std::vector<size_t>>(v));
      } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
        writer->WriteArrayItem(dmlc::get<std::vector<std::vector<int64_t>>>(v));
      } else if (SameType<std::vector<std::string>>(v)) {
        writer->WriteArrayItem(dmlc::get<std::vector<std::string>>(v));
      } else {
        LOG(FATAL) << "Not supported";
      }
    }
    writer->EndArray();
  }
  inline static void Read(dmlc::JSONReader* reader, std::vector<dmlc::any>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};
}  // namespace json
}  // namespace dmlc
785