1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay/backend/graph_codegen.cc |
22 | * \brief Graph executor codegen |
23 | */ |
24 | |
25 | #include <dmlc/any.h> |
26 | #include <dmlc/json.h> |
27 | #include <tvm/ir/module.h> |
28 | #include <tvm/relay/attrs/annotation.h> |
29 | #include <tvm/relay/attrs/call.h> |
30 | #include <tvm/relay/expr_functor.h> |
31 | #include <tvm/runtime/device_api.h> |
32 | #include <tvm/runtime/object.h> |
33 | #include <tvm/tir/analysis.h> |
34 | #include <tvm/tir/function.h> |
35 | |
36 | #include <list> |
37 | #include <string> |
38 | #include <vector> |
39 | |
40 | #include "../op/annotation/annotation.h" |
41 | #include "../op/call/call.h" |
42 | #include "../op/memory/device_copy.h" |
43 | #include "../transforms/device_aware_visitors.h" |
44 | #include "./te_compiler.h" |
45 | #include "./utils.h" |
46 | |
47 | namespace tvm { |
48 | namespace relay { |
49 | |
50 | // TODO(@jroesch, @csullivan): declare directly elsewhere |
51 | backend::StaticMemoryPlan GraphPlanMemory(const Function& func); |
52 | |
53 | namespace backend { |
54 | |
55 | class GraphNode; |
56 | class GraphInputNode; |
57 | class GraphOpNode; |
58 | |
59 | using IntegerArray = Array<Integer>; |
60 | using ShapeVector = std::vector<std::vector<int64_t>>; |
61 | using GraphAttrs = std::unordered_map<std::string, dmlc::any>; |
62 | using GraphObjectPtr = std::shared_ptr<GraphNode>; |
63 | using GraphInputObjectPtr = std::shared_ptr<GraphInputNode>; |
64 | using GraphOpObjectPtr = std::shared_ptr<GraphOpNode>; |
65 | |
/*! \brief Node types understood by the graph-executor JSON format. */
enum GraphNodeType {
  kGraphNop,        // placeholder / no-op node
  kGraphInputNode,  // graph input: function parameter or extracted constant ("op": "null")
  kGraphOpNode,     // operator call node ("op": "tvm_op")
};
72 | |
73 | class GraphNodeRef { |
74 | public: |
75 | GraphNodeRef() {} |
76 | GraphNodeRef(int ident, int index, int version = 0) |
77 | : ident_(ident), index_(index), version_(version) {} |
78 | |
79 | inline void Save(dmlc::JSONWriter* writer) const { |
80 | writer->BeginArray(); |
81 | writer->WriteArrayItem(ident_); |
82 | writer->WriteArrayItem(index_); |
83 | writer->WriteArrayItem(version_); |
84 | writer->EndArray(); |
85 | } |
86 | |
87 | inline void Load(dmlc::JSONReader* reader) { LOG(FATAL) << "Not implemented." ; } |
88 | |
89 | protected: |
90 | int ident_; |
91 | int index_{0}; |
92 | int version_{0}; |
93 | }; |
94 | |
/*! \brief Base Node class shared by input and op nodes.
 *
 * Holds the attributes common to every JSON graph node: the node name, the
 * number of outputs it produces, and a free-form attribute map that the
 * codegen fills with storage ids, shapes, dtypes, etc. Subclasses override
 * Save() to emit their JSON object and Type() to identify themselves.
 */
class GraphNode {
 public:
  GraphNode() {}
  // Default Save/Load are no-ops; concrete node kinds override Save().
  virtual void Save(dmlc::JSONWriter* writer) const {}
  virtual void Load(dmlc::JSONReader* reader) {}
  virtual GraphNodeType Type() const { return kGraphNop; }
  // Virtual destructor: nodes are deleted through GraphObjectPtr (shared_ptr<GraphNode>).
  virtual ~GraphNode() {}

 public:
  /*! \brief Number of outputs this node produces (1 unless a tuple is returned). */
  int num_outputs_{1};
  /*! \brief Unique node name within the graph. */
  std::string name_;
  /*! \brief Node-level attributes ("storage_id", "shape", "dtype", ...) consumed by GetJSON. */
  GraphAttrs attrs_;
};
109 | |
/*! \brief Input Node: a graph parameter or an extracted constant.
 *
 * Serialized with "op": "null" and an empty input list, which is how the
 * graph executor recognizes an input slot.
 */
class GraphInputNode : public GraphNode {
 public:
  GraphInputNode() {}
  GraphInputNode(const std::string& name, const GraphAttrs& attrs) {
    name_ = name;
    attrs_ = attrs;
  }

  GraphNodeType Type() const override { return kGraphInputNode; }

  void Save(dmlc::JSONWriter* writer) const override {
    // Input nodes always serialize as {"op": "null", "name": ..., "inputs": []}.
    const std::string op_name{"null"};
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_name);
    writer->WriteObjectKeyValue("name", this->name_);
    writer->WriteObjectKeyValue("inputs", std::list<int>());
    writer->EndObject();
  }
  /*! \brief Factory returning the node upcast to the base pointer used by nodes_. */
  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& attrs) {
    auto ptr = std::make_shared<GraphInputNode>(name, attrs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }
};
135 | |
/*! \brief Op Node: a call to a lowered function, serialized with "op": "tvm_op". */
class GraphOpNode : public GraphNode {
 public:
  GraphOpNode() {}
  /*!
   * \brief Construct an operator node.
   * \param name Unique node name in the graph JSON.
   * \param nd_attrs Node-level attributes (storage_id, shape, dtype, ...).
   * \param op_name Name of the lowered function the executor will invoke ("func_name").
   * \param inputs References to the node outputs consumed by this op.
   * \param attrs Op-level attributes emitted under the node's "attrs" key.
   * \param num_outputs Number of outputs this op produces (default 1).
   */
  GraphOpNode(const std::string& name, const GraphAttrs& nd_attrs, const std::string& op_name,
              const std::vector<GraphNodeRef>& inputs, const GraphAttrs& attrs,
              size_t num_outputs = 1) {
    name_ = name;
    attrs_ = nd_attrs;
    op_name_ = op_name;
    inputs_ = inputs;
    op_attrs_ = attrs;
    num_outputs_ = num_outputs;
    // NOTE: these four keys are recomputed again in Save(); they are also
    // stored here so op_attrs_ is complete even before serialization.
    op_attrs_["func_name"] = op_name_;
    op_attrs_["flatten_data"] = std::string("0");
    op_attrs_["num_inputs"] = std::to_string(inputs_.size());
    op_attrs_["num_outputs"] = std::to_string(num_outputs_);
  }

  GraphNodeType Type() const override { return kGraphOpNode; }

  void Save(dmlc::JSONWriter* writer) const override {
    // Work on a copy so serialization never mutates the stored op_attrs_.
    GraphAttrs attrs = op_attrs_;
    attrs["func_name"] = this->op_name_;
    attrs["flatten_data"] = std::string("0");
    attrs["num_inputs"] = std::to_string(this->inputs_.size());
    attrs["num_outputs"] = std::to_string(this->num_outputs_);
    writer->BeginObject();
    writer->WriteObjectKeyValue("op", op_type_name_);
    writer->WriteObjectKeyValue("name", name_);
    writer->WriteObjectKeyValue("attrs", attrs);
    writer->WriteObjectKeyValue("inputs", this->inputs_);
    writer->EndObject();
  }
  /*! \brief Factory returning the node upcast to the base pointer used by nodes_. */
  static std::shared_ptr<GraphNode> make_node_ptr(const std::string& name,
                                                  const GraphAttrs& nd_attrs,
                                                  const std::string& op_name,
                                                  const std::vector<GraphNodeRef>& inputs,
                                                  const GraphAttrs& attrs, size_t num_outputs = 1) {
    auto ptr = std::make_shared<GraphOpNode>(name, nd_attrs, op_name, inputs, attrs, num_outputs);
    return std::dynamic_pointer_cast<GraphNode>(ptr);
  }

 public:
  /*! \brief Lowered function name the executor looks up at runtime. */
  std::string op_name_;
  /*! \brief References to the producing outputs of this op's arguments. */
  std::vector<GraphNodeRef> inputs_;
  /*! \brief Op-level attributes emitted under "attrs" in the JSON. */
  GraphAttrs op_attrs_;

 private:
  // Every op node serializes with the fixed op type "tvm_op".
  const std::string op_type_name_{"tvm_op"};
};
187 | |
/*! \brief Code generator for the graph executor, produces a module containing the graph JSON,
 * module, and parameters.
 *
 * Translates a lowered Relay "main" function into the graph-executor JSON format:
 * each function parameter and each constant becomes an input node, each lowered
 * call becomes a "tvm_op" node, and memory-planner results (storage ids, device
 * indices, storage scopes) are attached as node attributes. Memoization from
 * MemoizedExprTranslator guarantees each sub-expression is visited once.
 */
class GraphExecutorCodegen : public backend::MemoizedExprTranslator<std::vector<GraphNodeRef>> {
 public:
  GraphExecutorCodegen(runtime::Module* mod, const Array<Target>& targets)
      : mod_(mod), config_(transform::PassContext::Current(), targets) {}

  /*!
   * \brief Look up the memory-planner result for \p e.
   * Fails hard if GraphPlanMemory never assigned storage to this expression.
   */
  StorageInfo GetStorageInfo(const Expr& e) {
    size_t count = memory_plan_->expr_to_storage_info.count(e);
    ICHECK_GT(count, 0) << "Expr is not existing in storage plan";
    auto storage_info = memory_plan_->expr_to_storage_info[e];
    return storage_info;
  }

  /*!
   * \brief Main entry point: lower \p func, build the graph, and package the result.
   * \param mod The IRModule containing \p func as "main".
   * \param func The Relay function to compile.
   * \param mod_name Module name used to mangle generated function names.
   * \return The graph JSON, per-target lowered functions, params, and metadata.
   */
  LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) {
    mod_name_ = mod_name;
    VLOG_CONTEXT << "GraphExecutorCodegen";
    VLOG(1) << "compiling:" << std::endl << PrettyPrint(func);

    // TODO(mbs): Why plan memory and update workspace sizes before lowering?
    memory_plan_ = GraphPlanMemory(func);

    backend::FunctionInfo func_info;

    if (memory_plan_.defined()) {
      // TODO(@electriclilies, @jroesch): remove UpdateMainWorkspaceSize
      func_info =
          relay::tec::UpdateMainWorkspaceSize(mod, config_, memory_plan_->expr_to_storage_info);
      mod = WithAttr(mod, "main_func_info", func_info);
    }

    // Lower all primitive functions to TIR, processing each as it is lowered.
    IRModule lowered_mod = tec::LowerTE(mod_name_, config_, [this](BaseFunc func) {
      // We need to maintain the constant map for external
      // functions so we pass this processing function which
      // allows us to process each function as we lower it.
      if (func->GetAttr<String>(attr::kCompiler).defined()) {
        UpdateConstants(func, &params_);
      }

      // TODO(@areusch, @jroesch): We should refactor this to
      // execute as a further pass, instead writing data to the
      // lowering process directly.
      tec::UpdateFunctionMetadata(func, this->function_metadata_);
    })(mod);

    Optional<backend::FunctionInfo> main_func_info =
        lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info");

    function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info.value());

    Function lowered_main_func = Downcast<Function>(lowered_mod->Lookup("main"));

    // Now that we have lowered all operators to TIR code, we can proceed with compilation.
    //
    // We need to unfortunately re-plan as the previous results have been invalidated by lowering
    // we will fix this in future refactors.
    memory_plan_ = GraphPlanMemory(lowered_main_func);

    // The graph planner also cannot handle planning calls to global variables, so we must remap.

    // First we convert all the parameters into input nodes.
    for (auto param : lowered_main_func->params) {
      auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs());
      var_map_[param.get()] = AddNode(node_ptr, param);
    }

    // Traverse the body; heads_ holds the references to the graph's outputs.
    heads_ = VisitExpr(lowered_main_func->body);
    std::ostringstream os;

    dmlc::JSONWriter writer(&os);
    GetJSON(&writer);
    LoweredOutput ret;
    ret.graph_json = os.str();

    // Collect any runtime modules generated by external codegen.
    ret.external_mods =
        lowered_mod->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});

    // Collect any constants extracted by external codegen.
    ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
    Map<String, runtime::NDArray> const_name_to_constant =
        lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
            .value_or({});
    for (const auto& kv : const_name_to_constant) {
      VLOG(1) << "constant '" << kv.first << "' contributed by external codegen";
      // The emplace must succeed: constant names are expected to be unique.
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }

    // Collect any constants extracted during lowering.
    for (const auto& kv : params_) {
      VLOG(1) << "constant '" << kv.first << "' contributed by TECompiler";
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }

    ret.function_metadata = std::move(function_metadata_);

    // This is the point where we separate the functions in the module by target
    ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
    ret.metadata =
        ExecutorCodegenMetadata({} /* inputs */, {} /* input_tensor_types */, {} /* outputs */,
                                {} /* output_tensor_types */, {} /* pools */, {} /* devices */,
                                runtime::kTvmExecutorGraph /* executor */, mod_name_ /* mod_name */,
                                "packed" /* interface_api */, Bool(false) /* unpacked_api */);
    return ret;
  }

 protected:
  /*!
   * \brief Add a node to the graph, attaching memory-plan and type attributes.
   *
   * Copies storage ids, device indices, and storage scopes from the memory plan
   * into the node's attrs_, then records shapes/dtypes from the checked type.
   * A tuple-typed expression is flattened into one GraphNodeRef per field.
   *
   * \param node The node to append to nodes_.
   * \param expr The expression the node was generated from (source of type/storage info).
   * \return One GraphNodeRef per output of the node.
   */
  std::vector<GraphNodeRef> AddNode(GraphObjectPtr node, Expr expr) {
    auto checked_type = expr->checked_type();

    auto storage_info = GetStorageInfo(expr);
    // storage
    std::vector<int64_t> storage_ids;
    for (auto v : storage_info->storage_ids) {
      storage_ids.push_back(v);
    }
    node->attrs_["storage_id"] = std::move(storage_ids);
    // type
    std::vector<int64_t> device_types;
    for (const auto& virtual_device : storage_info->virtual_devices) {
      // TODO(mbs): Keeping only the device type.
      ICHECK_GT(virtual_device->device_type(), 0);
      device_types.push_back(virtual_device->device_type());
    }
    // Either all outputs have a known device or none do; a mix indicates an
    // incompletely annotated heterogeneous graph.
    size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0);
    if (num_unknown_devices != 0 && num_unknown_devices != device_types.size()) {
      LOG(FATAL) << "The graph contains not annotated nodes for "
                 << "heterogeneous execution. All nodes must be "
                 << "annotated.";
    }
    if (num_unknown_devices == 0) {
      node->attrs_["device_index"] = device_types;
    }
    // storage scope
    std::vector<std::string> storage_scope;
    for (const auto& virtual_device : storage_info->virtual_devices) {
      storage_scope.push_back(std::string(virtual_device->memory_scope));
    }
    node->attrs_["storage_scope"] = std::move(storage_scope);
    auto node_id = nodes_.size();
    nodes_.push_back(node);
    // Tuple return value, flatten as tuple
    if (const auto* tuple_type = checked_type.as<TupleTypeNode>()) {
      std::vector<GraphNodeRef> ret;
      ShapeVector shape;
      std::vector<std::string> dtype;
      for (size_t i = 0; i < tuple_type->fields.size(); ++i) {
        if (const auto* typ = tuple_type->fields[i].as<TensorTypeNode>()) {
          // One GraphNodeRef per tuple field, all pointing at this node.
          ret.push_back(GraphNodeRef(node_id, i));
          shape.emplace_back(ShapeToJSON(typ->shape));
          dtype.emplace_back(DType2String(typ->dtype));
        } else {
          // Nested tuples (and anything non-tensor) are unsupported by the graph executor.
          LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
        }
      }
      // Only op nodes can produce tuples; input nodes always have one output.
      ICHECK_EQ(node->Type(), kGraphOpNode);
      auto op_nd = std::dynamic_pointer_cast<GraphOpNode>(node);
      op_nd->attrs_["shape"] = shape;
      op_nd->attrs_["dtype"] = dtype;
      op_nd->num_outputs_ = tuple_type->fields.size();
      return ret;
    }
    // Normal tensor return type
    if (const auto* tensor_type = checked_type.as<TensorTypeNode>()) {
      ShapeVector shape;
      std::vector<std::string> dtype;
      shape.emplace_back(ShapeToJSON(tensor_type->shape));
      dtype.emplace_back(DType2String(tensor_type->dtype));
      node->attrs_["shape"] = shape;
      node->attrs_["dtype"] = dtype;
    } else {
      LOG(FATAL) << "type " << checked_type->GetTypeKey() << " not supported";
    }
    return {GraphNodeRef(node_id, 0)};
  }

  // Variables were mapped to their input nodes when Codegen processed the params
  // (or by the LetNode visitor below).
  std::vector<GraphNodeRef> VisitExpr_(const VarNode* op) override {
    Expr expr = GetRef<Expr>(op);
    return var_map_[expr.get()];
  }

  // Each constant becomes a named input node ("p0", "p1", ...) and its data is
  // recorded in params_ so the executor can bind it at load time.
  std::vector<GraphNodeRef> VisitExpr_(const ConstantNode* op) override {
    Expr expr = GetRef<Expr>(op);
    size_t index = params_.size();
    std::string name = "p" + std::to_string(index);
    auto node = GraphInputNode::make_node_ptr(name, GraphAttrs());
    auto to_return = AddNode(node, expr);
    CHECK_EQ(to_return.size(), 1) << "Expected exactly 1 parameter node created";
    param_storage_ids_[name] = GetStorageInfo(expr)->storage_ids[0];
    params_[name] = op->data;
    return to_return;
  }

  // Tuples are flattened: the result is the concatenation of all field refs.
  std::vector<GraphNodeRef> VisitExpr_(const TupleNode* op) override {
    std::vector<GraphNodeRef> fields;
    for (auto field : op->fields) {
      auto ref_vec = VisitExpr(field);
      for (auto ref : ref_vec) {
        fields.push_back(ref);
      }
    }
    return fields;
  }

  /*! \brief True if the memory planner assigned both expressions the same (first) storage id. */
  bool ShareSameStorage(const Expr& lhs, const Expr& rhs) {
    StorageInfo lit = GetStorageInfo(lhs);
    StorageInfo rit = GetStorageInfo(rhs);
    int64_t lhs_storage_id = lit->storage_ids[0];
    int64_t rhs_storage_id = rit->storage_ids[0];
    return lhs_storage_id == rhs_storage_id;
  }

  /*!
   * \brief Convert a call into a graph op node.
   *
   * Handles three call forms: device_copy (emitted as the special "__copy"
   * function), call_lowered (the normal case; may collapse to a "__nop" when
   * it is a reshape reusing its input's storage), and calls to extern global
   * functions (no attrs). Anything else is a codegen-invariant violation.
   */
  std::vector<GraphNodeRef> GraphAddCallNode(const CallNode* call_node, GraphAttrs attrs) {
    Call call = GetRef<Call>(call_node);
    std::vector<GraphNodeRef> inputs;
    std::string func_name;

    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
    if (device_copy_props.body.defined()) {
      // The graph executor expects to see a normal call to the undefined @__copy function.
      // The source and destination device annotations are no longer needed since they have
      // been captured in the StorageInfos for both input and output.
      // TODO(mbs): device_copy cleanup
      func_name = "__copy";
      for (const auto& n : VisitExpr(device_copy_props.body)) {
        inputs.push_back(n);
      }
    } else if (call_lowered_props.lowered_func.defined()) {
      // Extract function and arguments from the call_lowered op

      func_name = call_lowered_props.lowered_func->name_hint;

      for (const Expr& arg : call_lowered_props.arguments) {
        for (auto n : VisitExpr(arg)) {
          inputs.push_back(n);
        }
      }
      // Forward any string-valued relay attrs onto the graph node's attrs.
      if (call_lowered_props.attrs.metadata.count("relay_attrs")) {
        if (auto relay_attrs =
                call_lowered_props.attrs.metadata["relay_attrs"].as<DictAttrsNode>()) {
          for (auto p : relay_attrs->dict) {
            if (p.second.as<StringObj>()) {
              attrs[p.first] = std::string(Downcast<String>(p.second));
            }
          }
        }
      }
      // TODO(mbs): "reshape" cleanup.
      // A reshape whose output aliases its input's storage needs no kernel at
      // runtime: emit it as a "__nop" node instead.
      if (IsReshapeOnly(call_lowered_props) &&
          ShareSameStorage(GetRef<Expr>(call_node), call_lowered_props.arguments[0])) {
        auto node = GraphOpNode::make_node_ptr("reshape_nop", GraphAttrs(), "__nop", inputs, attrs);
        return AddNode(node, call);
      }
    } else if (!call_node->attrs.defined()) {  // Call is an extern function
      const auto* func = call_node->op.as<GlobalVarNode>();
      ICHECK(func) << "Expected the operator to be a global var, but got "
                   << call_node->op->GetTypeKey();  // getting a relay fn here, not sure why.
      func_name = func->name_hint;

      for (const Expr& arg : call_node->args) {
        for (auto n : VisitExpr(arg)) {
          inputs.push_back(n);
        }
      }
    } else {
      LOG(FATAL) << "Non-primitive-call nodes should have been transformed away.\n"
                 << "The graph executor code generator expects all calls to be call_lowered, "
                 << "but found: " << std::endl
                 << PrettyPrint(call);
    }

    // Compute the operator name, because we used the get unique name when generating the kernel.
    auto op_name = name_supply_->FreshName(func_name);
    auto node = GraphOpNode::make_node_ptr(op_name, GraphAttrs(), func_name, inputs, attrs);
    return AddNode(node, call);
  }

  std::vector<GraphNodeRef> VisitExpr_(const CallNode* call_node) override {
    relay::Call call = GetRef<Call>(call_node);
    OnDeviceProps props = GetOnDeviceProps(call_node);
    if (props.body.defined()) {
      // See through "on_device" calls.
      return VisitExpr(props.body);
    }
    return GraphAddCallNode(call_node, GraphAttrs());
  }

  // Let-bindings: record the bound value's refs for the variable, then visit the body.
  std::vector<GraphNodeRef> VisitExpr_(const LetNode* op) override {
    ICHECK_EQ(var_map_.count(op->var.get()), 0);
    var_map_[op->var.get()] = VisitExpr(op->value);
    return VisitExpr(op->body);
  }
  // TupleGetItem: tuples are flattened, so just index into the field refs.
  std::vector<GraphNodeRef> VisitExpr_(const TupleGetItemNode* op) override {
    auto vtuple = VisitExpr(op->tuple);
    return {vtuple[op->index]};
  }

  // The remaining node kinds are invalid input for graph-executor codegen and
  // indicate that an earlier lowering/legalization pass was skipped.
  std::vector<GraphNodeRef> VisitExpr_(const OpNode* op) override {
    LOG(FATAL) << "All OpNodes should have been expanded";
  }
  std::vector<GraphNodeRef> VisitExpr_(const GlobalVarNode* op) override {
    LOG(FATAL) << "All GlobalVarNodes should be removed before graph executor's Codegen is called";
  }
  std::vector<GraphNodeRef> VisitExpr_(const IfNode* op) override {
    LOG(FATAL) << "Graph executor does not support control flow (found IfNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const FunctionNode* op) override {
    ICHECK(op->GetAttr<String>(attr::kCompiler).defined())
        << "Only functions supported by custom codegen";
    return {};
  }
  std::vector<GraphNodeRef> VisitExpr_(const RefCreateNode* op) override {
    LOG(FATAL) << "Graph executor does not support references (found RefCreateNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const RefReadNode* op) override {
    LOG(FATAL) << "Graph executor does not support references (found RefReadNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const RefWriteNode* op) override {
    LOG(FATAL) << "Graph executor does not support references (found RefWriteNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const ConstructorNode* op) override {
    LOG(FATAL) << "Graph executor does not support ADTs (found ConstructorNode)";
  }
  std::vector<GraphNodeRef> VisitExpr_(const MatchNode* op) override {
    LOG(FATAL) << "Graph executor does not support matching (found MatchNode)";
  }
  /*!
   * \brief Generate Graph JSON
   *
   * Emits the full graph object: "nodes", "arg_nodes", "heads", per-entry
   * "attrs" (shape/storage_id/dtype and, when present, device_index and
   * storage_scope), and "node_row_ptr" mapping node index to entry offset.
   *
   * \param writer json writer
   */
  void GetJSON(dmlc::JSONWriter* writer) {
    // arg_nodes lists the indices of all input nodes.
    std::vector<size_t> arg_nodes;
    for (size_t i = 0; i < nodes_.size(); ++i) {
      auto node = nodes_[i];
      if (node->Type() == kGraphInputNode) {
        arg_nodes.push_back(i);
      }
    }
    // Flatten the per-node attribute vectors into per-entry columns, where an
    // "entry" is one node output; node_row_ptr gives each node's first entry.
    size_t num_entry = 0;
    ShapeVector shapes;
    std::vector<size_t> storage_ids;
    std::vector<std::string> storage_scopes;
    std::vector<size_t> device_types;
    std::vector<std::string> dltypes;
    std::vector<size_t> node_row_ptr{0};
    for (auto node : nodes_) {
      const auto& shape_vec = dmlc::get<ShapeVector>(node->attrs_["shape"]);
      const auto& storage_id = dmlc::get<std::vector<int64_t>>(node->attrs_["storage_id"]);
      const auto& storage_scope =
          dmlc::get<std::vector<std::string>>(node->attrs_["storage_scope"]);
      const auto& dtype_vec = dmlc::get<std::vector<std::string>>(node->attrs_["dtype"]);

      ICHECK_EQ(node->num_outputs_, shape_vec.size());
      num_entry += node->num_outputs_;

      shapes.insert(shapes.end(), shape_vec.begin(), shape_vec.end());
      dltypes.insert(dltypes.end(), dtype_vec.begin(), dtype_vec.end());
      storage_ids.insert(storage_ids.end(), storage_id.begin(), storage_id.end());
      storage_scopes.insert(storage_scopes.end(), storage_scope.begin(), storage_scope.end());
      if (node->attrs_.count("device_index")) {
        const auto& dev_types = dmlc::get<std::vector<int64_t>>(node->attrs_["device_index"]);
        device_types.insert(device_types.end(), dev_types.begin(), dev_types.end());
      }
      node_row_ptr.push_back(num_entry);
    }

    // verification if storage_scope contains any non global memory scope
    // in other case it's better not to write scopes to the JSON at all
    bool global_only_scope = true;
    for (const auto& ss : storage_scopes) {
      if (!(ss.empty() || ss == "global")) {
        global_only_scope = false;
      }
    }
    if (global_only_scope) {
      storage_scopes.clear();
    }
    writer->BeginObject();
    writer->WriteObjectKeyValue("nodes", nodes_);
    writer->WriteObjectKeyValue("arg_nodes", arg_nodes);
    writer->WriteObjectKeyValue("heads", heads_);
    // Each attr is encoded as a [type_tag, values] pair in the JSON.
    std::unordered_map<std::string, std::vector<dmlc::any>> attrs;
    attrs["shape"].emplace_back(std::string("list_shape"));
    attrs["shape"].emplace_back(shapes);
    attrs["storage_id"].emplace_back(std::string("list_int"));
    attrs["storage_id"].emplace_back(storage_ids);
    if (device_types.size()) {
      attrs["device_index"].emplace_back(std::string("list_int"));
      attrs["device_index"].emplace_back(device_types);
    }
    if (storage_scopes.size()) {
      attrs["storage_scope"].emplace_back(std::string("list_str"));
      attrs["storage_scope"].emplace_back(storage_scopes);
    }
    attrs["dltype"].emplace_back(std::string("list_str"));
    attrs["dltype"].emplace_back(dltypes);
    writer->WriteObjectKeyValue("attrs", attrs);
    writer->WriteObjectKeyValue("node_row_ptr", node_row_ptr);
    writer->EndObject();
  }

 protected:
  /*! \brief nodes */
  std::vector<GraphObjectPtr> nodes_;
  /*! \brief output of graph */
  std::vector<GraphNodeRef> heads_;
  /*! \brief mod */
  runtime::Module* mod_;
  /*! \brief variable map: VarNode* -> refs of the node(s) producing its value */
  std::unordered_map<const Object*, std::vector<GraphNodeRef>> var_map_;
  /*! \brief Available targets */
  CompilationConfig config_;
  /*!
   * \brief parameters (i.e. ConstantNodes found in the graph).
   * These are taken as inputs to the GraphExecutor.
   * Maps param name to a pair of storage_id and NDArray. At runtime, the storage_id can be
   * used to lookup the parameter.
   */
  std::unordered_map<std::string, runtime::NDArray> params_;
  /*! \brief first storage id assigned to each named parameter */
  std::unordered_map<std::string, int64_t> param_storage_ids_;
  /*! \brief plan memory of device result */
  StaticMemoryPlan memory_plan_;
  /*! \brief the module name we use to mangle the function names */
  String mod_name_;
  /*! \brief function metadata */
  Map<String, FunctionInfo> function_metadata_;
  /*! \brief NameSupply: keeps generated graph-node names unique */
  NameSupply name_supply_ = NameSupply("");
};
627 | |
/*! \brief Runtime module exposing GraphExecutorCodegen to the Python side via PackedFuncs.
 *
 * Usage protocol: call "init" with (runtime::Module*, Array<Target>), then
 * "codegen" with (IRModule, Function, String mod_name); afterwards the various
 * "get_*" functions return pieces of the cached LoweredOutput.
 */
class GraphExecutorCodegenModule : public runtime::ModuleNode {
 public:
  GraphExecutorCodegenModule() {}
  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
    if (name == "init") {
      // Construct the codegen from a host module pointer and the target list.
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: "
                                    << "runtime::Module mod and Array<Target> targets";
        void* mod = args[0];
        Array<Target> targets = args[1];
        codegen_ = std::make_shared<GraphExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod),
                                                          std::move(targets));
      });
    } else if (name == "codegen") {
      // Run codegen on (IRModule, main Function, module name) and cache the output.
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        IRModule mod = args[0];
        Function func = args[1];
        String mod_name = args[2];
        this->output_ = this->codegen_->Codegen(mod, func, mod_name);
      });
    } else if (name == "get_graph_json") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.graph_json; });
    } else if (name == "list_params_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        Array<runtime::String> ret;
        for (const auto& kv : this->output_.params) {
          ret.push_back(kv.first);
        }
        *rv = ret;
      });
    } else if (name == "get_param_by_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        String key = args[0];
        auto it = this->output_.params.find(key);
        CHECK(it != this->output_.params.end()) << "no such parameter " << key;
        *rv = (*it).second;
      });
    } else if (name == "get_irmodule") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.lowered_funcs;
      });
    } else if (name == "get_external_modules") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.external_mods;
      });
    } else if (name == "get_devices") {
      // The graph executor codegen reports no explicit device list.
      return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = Array<String>(); });
    } else if (name == "get_executor_codegen_metadata") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.metadata; });
    } else if (name == "get_function_metadata") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.function_metadata;
      });
    } else {
      // Unknown names return a no-op function rather than failing.
      return PackedFunc([](TVMArgs args, TVMRetValue* rv) {});
    }
  }

  const char* type_key() const final { return "RelayGraphExecutorCodegenModule"; }

 private:
  /*! \brief The codegen instance, created by "init". */
  std::shared_ptr<GraphExecutorCodegen> codegen_;
  /*! \brief Result of the last "codegen" call, served by the "get_*" functions. */
  LoweredOutput output_;
};
694 | |
695 | runtime::Module CreateGraphCodegenMod() { |
696 | auto ptr = make_object<GraphExecutorCodegenModule>(); |
697 | return runtime::Module(ptr); |
698 | } |
699 | |
// Exposed to the Python side as relay.build_module._GraphExecutorCodegen;
// each call returns a fresh codegen module (see GraphExecutorCodegenModule).
TVM_REGISTER_GLOBAL("relay.build_module._GraphExecutorCodegen")
    .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateGraphCodegenMod(); });
702 | |
703 | } // namespace backend |
704 | } // namespace relay |
705 | } // namespace tvm |
706 | |
707 | namespace dmlc { |
708 | namespace json { |
709 | // JSON utils |
710 | template <typename T> |
711 | inline bool SameType(const dmlc::any& data) { |
712 | return std::type_index(data.type()) == std::type_index(typeid(T)); |
713 | } |
714 | |
// JSON serialization for graph nodes: delegate to the node's virtual Save(),
// so input nodes and op nodes each emit their own object shape.
template <>
struct Handler<std::shared_ptr<tvm::relay::backend::GraphNode>> {
  inline static void Write(dmlc::JSONWriter* writer,
                           const std::shared_ptr<tvm::relay::backend::GraphNode>& data) {
    data->Save(writer);
  }
  // The graph JSON is write-only in this translation unit.
  inline static void Read(dmlc::JSONReader* reader,
                          std::shared_ptr<tvm::relay::backend::GraphNode>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};
// JSON serialization for attribute maps (GraphAttrs): values are dmlc::any, so
// dispatch on the stored type. Only the types actually produced by the codegen
// above are supported; anything else is a programming error.
template <>
struct Handler<std::unordered_map<std::string, dmlc::any>> {
  inline static void Write(dmlc::JSONWriter* writer,
                           const std::unordered_map<std::string, dmlc::any>& data) {
    writer->BeginObject();
    for (const auto& kv : data) {
      auto k = kv.first;
      const dmlc::any& v = kv.second;
      if (SameType<std::string>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::string>(v));
      } else if (SameType<int>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<int>(v));
      } else if (SameType<std::vector<size_t>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<size_t>>(v));
      } else if (SameType<std::vector<std::vector<int64_t>>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::vector<int64_t>>>(v));
      } else if (SameType<std::vector<std::string>>(v)) {
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<std::string>>(v));
      } else if (SameType<std::vector<dmlc::any>>(v)) {
        // Nested heterogeneous lists, e.g. the [type_tag, values] attr pairs.
        writer->WriteObjectKeyValue(k, dmlc::get<std::vector<dmlc::any>>(v));
      } else {
        LOG(FATAL) << "Not supported";
      }
    }
    writer->EndObject();
  }
  // The graph JSON is write-only in this translation unit.
  inline static void Read(dmlc::JSONReader* reader,
                          std::unordered_map<std::string, dmlc::any>* data) {
    LOG(FATAL) << "Not implemented.";
  }
};
757 | |
758 | template <> |
759 | struct Handler<std::vector<dmlc::any>> { |
760 | inline static void Write(dmlc::JSONWriter* writer, const std::vector<dmlc::any>& data) { |
761 | writer->BeginArray(); |
762 | for (const auto& v : data) { |
763 | if (SameType<std::string>(v)) { |
764 | writer->WriteArrayItem(dmlc::get<std::string>(v)); |
765 | } else if (SameType<int>(v)) { |
766 | writer->WriteArrayItem(dmlc::get<int>(v)); |
767 | } else if (SameType<std::vector<size_t>>(v)) { |
768 | writer->WriteArrayItem(dmlc::get<std::vector<size_t>>(v)); |
769 | } else if (SameType<std::vector<std::vector<int64_t>>>(v)) { |
770 | writer->WriteArrayItem(dmlc::get<std::vector<std::vector<int64_t>>>(v)); |
771 | } else if (SameType<std::vector<std::string>>(v)) { |
772 | writer->WriteArrayItem(dmlc::get<std::vector<std::string>>(v)); |
773 | } else { |
774 | LOG(FATAL) << "Not supported" ; |
775 | } |
776 | } |
777 | writer->EndArray(); |
778 | } |
779 | inline static void Read(dmlc::JSONReader* reader, std::vector<dmlc::any>* data) { |
780 | LOG(FATAL) << "Not implemented." ; |
781 | } |
782 | }; |
783 | } // namespace json |
784 | } // namespace dmlc |
785 | |