/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/vm/compiler.cc
 * \brief A compiler from relay::Module to the VM byte code.
 */

#include "compiler.h"

#include <tvm/driver/driver_api.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/attrs/memory.h>
#include <tvm/relay/error.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/interpreter.h>
#include <tvm/relay/parser.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/relay/runtime.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/logging.h>
#include <tvm/runtime/vm/vm.h>
#include <tvm/te/operation.h>

#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "../../../driver/internal_driver_api.h"
#include "../../../target/metadata_module.h"
#include "../../../target/source/codegen_source_base.h"
#include "../../op/annotation/annotation.h"
#include "../../op/memory/device_copy.h"
#include "../../op/op_common.h"
#include "../../transforms/device_aware_visitors.h"
#include "../../transforms/pass_utils.h"
#include "../utils.h"
#include "./compiler.h"

namespace tvm {
namespace relay {

namespace transform {

Pass LambdaLift();
Pass LabelOps();

Pass MemoryPlan() {
  auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan");
  ICHECK(f != nullptr) << "unable to load the memory planning pass";
  return (*f)();
}

Pass LiftConstants() {
  auto f = tvm::runtime::Registry::Get("relay.transform.LiftConstants");
  ICHECK(f != nullptr) << "unable to load the constant lifting pass";
  return (*f)();
}

}  // namespace transform

namespace vm {

using namespace tvm::runtime;
using namespace tvm::runtime::vm;
using namespace relay::transform;

/*! \brief The host device is always stored at device index 0. */
constexpr Index kHostDeviceIndex = 0;

// (@jroesch): VM passes, eventually declare as passes.
bool IsClosure(const Function& func);

// Represent a runtime object that's going to be matched by pattern match expressions
struct MatchValue {
  virtual ~MatchValue() {}
};
using MatchValuePtr = std::shared_ptr<MatchValue>;

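// A MatchValue is either a RegisterValue (the scrutinee already sits in a register) or an
// AccessField path into an ADT; CompileMatchValue below lowers field paths to GetField
// instructions on demand.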
// A runtime object that resides in a register
struct RegisterValue : MatchValue {
  // The register num
  RegName register_num;

  explicit RegisterValue(RegName reg) : register_num(reg) {}

  ~RegisterValue() {}
};

// The value is a field of another runtime object
struct AccessField : MatchValue {
  MatchValuePtr parent;
  // Field index
  size_t index;
  // Runtime register num after compiling the access field path
  RegName reg{-1};

  AccessField(MatchValuePtr parent, size_t index) : parent(parent), index(index) {}

  ~AccessField() {}
};

/*!
 * \brief Condition in a decision tree
 */
struct ConditionNode {
  virtual ~ConditionNode() {}
};

using ConditionObjectPtr = std::shared_ptr<ConditionNode>;

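// A ConditionNode is either a VarBinding, which always succeeds and binds a pattern variable to a
// MatchValue, or a TagCompare, which checks the constructor tag of an ADT value.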
/*!
 * \brief A var binding condition
 */
struct VarBinding : ConditionNode {
  Var var;
  MatchValuePtr val;

  VarBinding(Var var, MatchValuePtr val) : var(var), val(val) {}

  ~VarBinding() {}
};

/*!
 * \brief Compare the tag of the object
 */
struct TagCompare : ConditionNode {
  /*! \brief The object to be examined */
  MatchValuePtr obj;

  /*! \brief The expected tag */
  int target_tag;

  TagCompare(MatchValuePtr obj, size_t target) : obj(obj), target_tag(target) {}

  ~TagCompare() {}
};

using TreeObjectPtr = typename relay::TreeNode<ConditionObjectPtr>::pointer;
using TreeLeafNode = relay::TreeLeafNode<ConditionObjectPtr>;
using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionObjectPtr>;
using TreeBranchNode = relay::TreeBranchNode<ConditionObjectPtr>;

TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, Pattern pattern,
                                           TreeObjectPtr then_branch, TreeObjectPtr else_branch) {
  if (pattern.as<PatternWildcardNode>()) {
    // We ignore wildcard bindings since they don't produce new vars.
    return then_branch;
  } else if (const auto* pvn = pattern.as<PatternVarNode>()) {
    auto cond = std::make_shared<VarBinding>(pvn->var, data);
    return TreeBranchNode::Make(cond, then_branch, else_branch);
  } else if (const auto* pcn = pattern.as<PatternConstructorNode>()) {
    auto tag = pcn->constructor->tag;

    size_t field_index = 0;
    for (auto& p : pcn->patterns) {
      auto d = std::make_shared<AccessField>(data, field_index);
      then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
      field_index++;
    }
    auto cond = std::make_shared<TagCompare>(data, tag);
    return TreeBranchNode::Make(cond, then_branch, else_branch);
  } else {
    const auto* pt = pattern.as<PatternTupleNode>();
    ICHECK(pt) << "unhandled case: " << AsText(pattern, false);
    size_t field_index = 0;
    for (auto& p : pt->patterns) {
      auto d = std::make_shared<AccessField>(data, field_index++);
      then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch);
    }
    return then_branch;
  }
}

TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, Clause clause,
                                          TreeObjectPtr else_branch) {
  return BuildDecisionTreeFromPattern(data, clause->lhs, TreeLeafNode::Make(clause->rhs),
                                      else_branch);
}

TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) {
  // When nothing matches, the VM throws a fatal error.
  TreeObjectPtr else_branch = TreeLeafFatalNode::Make();
  // Start from the last clause: building in reverse keeps earlier clauses nearer the root of the
  // tree, so they are tested first at runtime.
  for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) {
    else_branch = BuildDecisionTreeFromClause(data, *it, else_branch);
  }
  return else_branch;
}

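// Converts a rank-1 shape constant into a plain vector of dimensions, e.g. an int64 NDArray
// holding [2, 3] becomes {2, 3}; a rank-0 constant denotes a scalar and yields an empty shape.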
std::vector<int64_t> ToAllocTensorShape(NDArray shape) {
  std::vector<int64_t> raw_shape;
  if (shape->ndim == 0) {
    return raw_shape;
  }
  ICHECK_EQ(shape->ndim, 1u);
  ICHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got "
                                   << DLDataType2String(shape->dtype);
  ICHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
      << "The dtype of constant shape must be int32 or int64, but got "
      << DLDataType2String(shape->dtype);

  if (shape->dtype.bits == 64) {
    int64_t* int_ptr = reinterpret_cast<int64_t*>(shape->data);
    for (auto i = 0; i < shape->shape[0]; i++) {
      raw_shape.push_back(int_ptr[i]);
    }
  } else {  // int32
    int32_t* int_ptr = reinterpret_cast<int32_t*>(shape->data);
    for (auto i = 0; i < shape->shape[0]; i++) {
      raw_shape.push_back(static_cast<int64_t>(int_ptr[i]));
    }
  }
  return raw_shape;
}

class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> {
 public:
  VMFunctionCompiler(VMCompilerContext* context, VirtualDevice host_virtual_device)
      : DeviceAwareExprFunctor(context->module),
        last_register_(0),
        registers_num_(0),
        context_(context),
        host_virtual_device_(std::move(host_virtual_device)) {}

  VMFunction Compile(const GlobalVar& var, const Function& func) {
    VLOG(1) << "Compiling:" << std::endl << PrettyPrint(func);
    std::vector<Index> param_device_indexes;
    if (IsClosure(func)) {
      // After lifting we'll have functions of the form:
      //   fn(closure args) { fn(lifted function args) { body } }
      // But we want the closure's function to be:
      //   fn(closure args, lifted function args) { body }
      // Do that flattening on-the-fly here.
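      // For example (illustrative sketch in Relay-ish pseudocode), a lambda-lifted closure
      //   fn (%env) { fn (%x) { %env + %x } }
      // is compiled here as if it were the single flattened function
      //   fn (%env, %x) { %env + %x }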
      Function inner_func = Downcast<Function>(func->body);
      std::vector<Var> params;
      params.reserve(func->params.size() + inner_func->params.size());
      param_device_indexes.reserve(func->params.size() + inner_func->params.size());
      for (size_t i = 0; i < func->params.size(); ++i) {
        params.emplace_back(func->params[i]);
        param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
      }
      for (size_t i = 0; i < inner_func->params.size(); ++i) {
        params.emplace_back(inner_func->params[i]);
        param_device_indexes.push_back(GetDeviceIndex(inner_func->params[i]->virtual_device()));
      }
      std::vector<TypeVar> type_params;
      type_params.reserve(func->type_params.size() + inner_func->type_params.size());
      for (const auto& tyvar : func->type_params) {
        type_params.push_back(tyvar);
      }
      for (const auto& tyvar : inner_func->type_params) {
        type_params.push_back(tyvar);
      }
      Function flattened_func = Function(params, inner_func->body, inner_func->ret_type,
                                         type_params, func->attrs, func->span);
      flattened_func->virtual_device_ = inner_func->virtual_device();
      VisitExpr(flattened_func);
    } else {
      param_device_indexes.reserve(func->params.size());
      for (size_t i = 0; i < func->params.size(); ++i) {
        param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device()));
      }
      VisitExpr(func);
    }
    return VMFunction(var->name_hint, params_, instructions_, registers_num_,
                      std::move(param_device_indexes));
  }

  /*! \brief Attrs objects for each op. */
  std::map<Index, Map<String, ObjectRef>> op_attrs;

  /*! \brief Attrs objects for each callsite. */
  std::map<Index, Map<String, ObjectRef>> callsite_attrs;

 protected:
  size_t NewRegister() { return registers_num_++; }

  inline void Emit(const Instruction& instr) {
    size_t instruction_index = instructions_.size();
    VLOG(2) << "instruction[" << instruction_index << "] = " << instr;
    ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op;
    switch (instr.op) {
      case Opcode::AllocADT:
      case Opcode::AllocTensor:
      case Opcode::AllocTensorReg:
      case Opcode::GetField:
      case Opcode::GetTag:
      case Opcode::LoadConst:
      case Opcode::LoadConsti:
      case Opcode::Invoke:
      case Opcode::AllocClosure:
      case Opcode::AllocStorage:
      case Opcode::ShapeOf:
      case Opcode::ReshapeTensor:
      case Opcode::Move:
      case Opcode::InvokeClosure:
      case Opcode::DeviceCopy:
        last_register_ = instr.dst;
        break;
      case Opcode::InvokePacked:
      case Opcode::If:
      case Opcode::Ret:
      case Opcode::Goto:
      case Opcode::Fatal:
      case Opcode::KillRegister:
        break;
    }
    instructions_.push_back(instr);
  }

  /*!
   * \brief Returns the "device index" to represent \p virtual_device for primitives
   * in emitted code. Note that the host device is always at index 0.
   */
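  // For example, with the host on CPU and all primitives on a single GPU, index 0 is the host
  // (CPU) scope and index 1 is the GPU scope; a second distinct scope with the same non-host
  // device type would be rejected by the CHECK below.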
  Index GetDeviceIndex(const VirtualDevice& virtual_device) {
    ICHECK(!virtual_device->IsFullyUnconstrained());
    auto itr = std::find(context_->virtual_devices_.begin(), context_->virtual_devices_.end(),
                         virtual_device);
    if (itr != context_->virtual_devices_.end()) {
      return std::distance(context_->virtual_devices_.begin(), itr);
    }

    ICHECK_GT(context_->virtual_devices_.size(), 0);
    ICHECK_NE(virtual_device, host_virtual_device_);  // the host scope is always at index 0

    if (virtual_device->device_type() == context_->virtual_devices_.front()->device_type()) {
      // It's ok if we see distinct scopes which share the host device type. This is because
      // we allow the VirtualDevice for the host to be different from the VirtualDevice for
      // primitive operations which both happen to be on the same device (typically CPU).
      return 0;
    }

    // However, otherwise we allow at most one VirtualDevice per device type.
    // TODO(mbs): This will eventually need to account for memory scopes somehow so device_copy
    // instructions can do the right thing.
    itr = std::find_if(context_->virtual_devices_.begin() + 1, context_->virtual_devices_.end(),
                       [&virtual_device](const VirtualDevice& existing_virtual_device) {
                         return existing_virtual_device->device_type() ==
                                virtual_device->device_type();
                       });
    CHECK(itr == context_->virtual_devices_.end())
        << "The VM does not currently support using more than one device with the same device "
           "type for primitives; however, the program is using the distinct scopes "
        << virtual_device << " and " << *itr << " of device type " << virtual_device->device_type();

    ICHECK(virtual_device != host_virtual_device_);
    Index index = context_->virtual_devices_.size();
    VLOG(2) << "virtual_device[" << index << "] = " << virtual_device;
    context_->virtual_devices_.push_back(virtual_device);

    return index;
  }

  using DeviceAwareExprFunctor<void(const Expr&)>::VisitExpr_;

  void VisitExpr_(const ConstantNode* const_node) final {
    // Check the shape is valid
    NDArray data = const_node->data;
    size_t const_index = context_->constants.size();
    auto con = GetRef<Constant>(const_node);
    Index device_index = GetDeviceIndex(GetVirtualDevice(con));
    VLOG(2) << "constant[" << const_index << "] on device[" << device_index << "]";
    context_->const_device_indexes.push_back(device_index);
    context_->constants.push_back(const_node->data);
    Emit(Instruction::LoadConst(const_index, NewRegister()));
  }

  void VisitExpr_(const VarNode* var_node) final {
    auto var = GetRef<Var>(var_node);
    auto reg_it = this->var_register_map_.find(var);
    ICHECK(reg_it != this->var_register_map_.end());
    last_register_ = reg_it->second;
  }

  void VisitExpr_(const TupleNode* tuple_node) final {
    auto tuple = GetRef<Tuple>(tuple_node);
    std::vector<Index> fields_registers;

    for (auto& field : tuple->fields) {
      this->VisitExpr(field);
      fields_registers.push_back(last_register_);
    }

    // TODO(@jroesch): use correct tag
    Emit(Instruction::AllocADT(0, tuple->fields.size(), fields_registers, NewRegister()));
  }

  void VisitExpr_(const MatchNode* match_node) final {
    auto match = GetRef<Match>(match_node);

    this->VisitExpr(match->data);
    CompileMatch(match);
  }

  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
    ICHECK(!value.as<FunctionNode>())
        << "unexpected function:" << std::endl
        << PrettyPrint(value) << std::endl
        << "bound to var '" << var->name_hint() << "'. Did you set opt_level = 2?";
    VisitExpr(value);
    var_register_map_.emplace(var, this->last_register_);
  }

  void VisitExpr_(const TupleGetItemNode* get_node) final {
    auto get = GetRef<TupleGetItem>(get_node);
    this->VisitExpr(get->tuple);
    auto tuple_register = last_register_;
    Emit(Instruction::GetField(tuple_register, get->index, NewRegister()));
  }

  void VisitExpr_(const GlobalVarNode* gvar) final {
    auto var = GetRef<GlobalVar>(gvar);
    auto func = context_->module->Lookup(var);
    auto it = context_->global_map.find(var);
    ICHECK(it != context_->global_map.end()) << PrettyPrint(var);
    // Allocate a closure with zero free vars.
    Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister()));
  }

  void VisitExpr_(const IfNode* if_node) final {
    this->VisitExpr(if_node->cond);

    size_t test_register = last_register_;

    this->Emit(Instruction::LoadConsti(1, NewRegister()));
    auto after_cond = instructions_.size();
    auto target_register = last_register_;
    this->Emit(Instruction::If(test_register, target_register, 0, 0));
    this->VisitExpr(if_node->true_branch);

    // This register saves the result of the If-Else expression.
    auto merge_register = NewRegister();
    Emit(Instruction::Move(last_register_, merge_register));
    Emit(Instruction::Goto(0));

    // Finally store how many instructions there are in the
    // true branch.
    auto after_true = this->instructions_.size();

    this->VisitExpr(if_node->false_branch);

    size_t false_register = last_register_;

    // In the else-branch, override the then-branch register.
    Emit(Instruction::Move(false_register, merge_register));
    // Compute the total number of instructions
    // after generating the false branch.
    auto after_false = this->instructions_.size();

    // Now we will compute the jump targets in order
    // to properly patch the instructions with the
    // requisite targets.

    // After we emit the true body, and false body,
    // we patch up the if instruction, and goto.
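    // The emitted instruction layout is roughly:
    //   [after_cond]       If(test, target, true_offset, false_offset)
    //   ...                <true branch>
    //                      Move(<true result>, merge_register)
    //   [after_true - 1]   Goto(<jump past the false branch>)
    //   ...                <false branch>
    //   [after_false - 1]  Move(<false result>, merge_register)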
    auto true_offset = 1;
    auto false_offset = after_true - after_cond;
    instructions_[after_cond].if_op.true_offset = true_offset;
    instructions_[after_cond].if_op.false_offset = false_offset;

    // Patch the Goto.
    this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1;

    this->last_register_ = merge_register;
  }

  void EmitInvokeTVMOp(const Expr& func, const Expr& inputs, const Expr& outputs,
                       const DictAttrs& attrs) {
    std::vector<Index> argument_registers;

    const auto* global_var_node = func.as<GlobalVarNode>();
    ICHECK(global_var_node) << "Expecting function in invoke_tvm_op to be a global";

    auto input_tuple = inputs.as<TupleNode>();
    ICHECK(input_tuple) << "internal error: invoke_tvm_op inputs must be a tuple, "
                        << "please file a bug in the memory manifestation pass";

    auto output_tuple = outputs.as<TupleNode>();
    ICHECK(output_tuple) << "internal error: invoke_tvm_op outputs must be a tuple, "
                         << "please file a bug in the memory manifestation pass";

    for (auto input : input_tuple->fields) {
      VisitExpr(input);
      argument_registers.push_back(last_register_);
    }

    for (auto output : output_tuple->fields) {
      ICHECK(output->IsInstance<VarNode>()) << "output should be a var, found:" << std::endl
                                            << PrettyPrint(output);
      auto reg = var_register_map_.find(Downcast<Var>(output));
      ICHECK(reg != var_register_map_.end())
          << "internal error: all variables should be in the register mapping";
      argument_registers.push_back(reg->second);
    }

    Index op_index;
    auto itr = context_->primitive_map.find(global_var_node->name_hint);
    if (itr == context_->primitive_map.end()) {
      op_index = context_->primitive_map.size();
      context_->primitive_map.emplace(global_var_node->name_hint, op_index);
    } else {
      op_index = itr->second;
    }

    if (attrs.defined() && attrs->dict.defined()) {
      // Capture the dictionary of attributes from the original primitive function so that they
      // can contribute to the hash of the compiled primitive. This way we can distinguish
      // primitives with the same body expression but different attributes which may arbitrarily
      // influence code generation.
      op_attrs[op_index] = attrs->dict;
    }

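    // The packed call convention passes inputs followed by output buffers in a single register
    // list; the instruction records the total arity and how many of those registers are outputs.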
    Emit(Instruction::InvokePacked(op_index, argument_registers.size(), output_tuple->fields.size(),
                                   argument_registers));
  }

  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
    ICHECK(!call_lowered_props.lowered_func.defined());
    if (device_copy_props.body.defined()) {
      // TODO(mbs): device_copy cleanup.
      VisitExpr(device_copy_props.body);
      RegName src_reg = last_register_;
      Index src_index = GetDeviceIndex(device_copy_props.src_virtual_device);
      Index dst_index = GetDeviceIndex(device_copy_props.dst_virtual_device);
      // Since scopes are distinguished by targets (including any target hosts) but at runtime we
      // deal only with devices, the copy may be unnecessary.
      if (src_index != dst_index) {
        Emit(Instruction::DeviceCopy(src_reg, src_index, dst_index, NewRegister()));
      }
      return;
    }

    // Now we handle the case in which we are using an opaque operator used to define a
    // sub-dialect, such as memory allocation operations.
    if (call_node->op.as<OpNode>()) {
      OpMatch<void> matcher;
      matcher
          .Match("vm.invoke_tvm_op",
                 [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
                   ICHECK_EQ(args.size(), 3);
                   EmitInvokeTVMOp(args[0], args[1], args[2], Downcast<DictAttrs>(attrs));
                 })
          .Match("memory.alloc_tensor",
                 [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
                   ICHECK_EQ(args.size(), 3);

                   // Get the attributes.
                   auto alloc_attrs = attrs.as<AllocTensorAttrs>();
                   ICHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs";
                   auto dtype = alloc_attrs->dtype;

                   // The storage will be passed dynamically.
                   this->VisitExpr(args[0]);
                   auto storage_register = last_register_;

                   // The offset will be passed dynamically.
                   this->VisitExpr(args[1]);
                   auto offset_register = last_register_;

                   // If the shape is constant then we will emit a static tensor allocation
                   // instruction. It may be wrapped by an on_device, but it will be on the host
                   // which is assumed by the alloc_tensor instruction anyway.
                   auto const_shape = AsIgnoringOnDevice<ConstantNode>(args[2]);

                   if (const_shape) {
                     NDArray shape = const_shape->data;
                     // TODO(@jroesch): we need to get an RFC done to standardize the shape dtype.
                     std::vector<int64_t> raw_shape = ToAllocTensorShape(shape);
                     // Add context field.
                     Emit(Instruction::AllocTensor(storage_register, offset_register, raw_shape,
                                                   dtype, NewRegister()));
                   } else {
                     this->VisitExpr(args[2]);
                     auto shape_register = last_register_;
                     Emit(Instruction::AllocTensorReg(storage_register, offset_register,
                                                      shape_register, dtype, NewRegister()));
                   }
                 })
          .Match("memory.alloc_storage",
                 [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
                   ICHECK_EQ(args.size(), 2);
                   // Compute the size of the allocation.
                   this->VisitExpr(args[0]);
                   auto size_register = last_register_;

                   ICHECK(args[1].as<ConstantNode>());  // Always a literal.
                   NDArray alignment_arr = args[1].as<ConstantNode>()->data;
                   ICHECK_EQ(alignment_arr->dtype.code, 0U)
                       << "The dtype of the alignment constant must be int64, but got "
                       << DLDataType2String(alignment_arr->dtype);
                   ICHECK_EQ(alignment_arr->dtype.bits, 64U);
                   Index alignment = reinterpret_cast<int64_t*>(alignment_arr->data)[0];

                   // Get the dtype hint from the attributes.
                   auto alloc_attrs = attrs.as<AllocStorageAttrs>();
                   ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs";
                   auto dtype = alloc_attrs->dtype;

                   Emit(Instruction::AllocStorage(size_register, alignment, dtype,
                                                  GetDeviceIndex(alloc_attrs->virtual_device),
                                                  NewRegister()));
                 })
          .Match("vm.shape_of",
                 [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
                   ICHECK_EQ(args.size(), 1U);
                   // Get the attributes.
                   const auto* shape_of_attrs = attrs.as<ShapeOfAttrs>();
                   ICHECK(shape_of_attrs) << "Must be the shape_of attrs";
                   ICHECK_EQ(shape_of_attrs->dtype.bits(), 64)
                       << "The dtype of shape_of must be int64, but got "
                       << DLDataType2String(shape_of_attrs->dtype);
                   this->VisitExpr(args[0]);
                   Emit(Instruction::ShapeOf(last_register_, NewRegister()));
                 })
          .Match("vm.reshape_tensor",
                 [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
                   ICHECK_EQ(args.size(), 2u);
                   this->VisitExpr(args[0]);
                   auto tensor_reg = last_register_;
                   this->VisitExpr(args[1]);
                   auto shape_reg = last_register_;
                   Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister()));
                 })
          .Match("memory.kill",
                 [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) {
                   ICHECK_EQ(args.size(), 1u);
                   this->VisitExpr(args[0]);
                   Emit(Instruction::KillRegister(this->last_register_));
                 });
      matcher(GetRef<Call>(call_node));
      return;
    }

    // In the case it's not one of these specialized operators, we generate code
    // for one of the "standard" cases.
    std::vector<Index> args_registers;

    // Evaluate the call arguments.
    for (auto arg : call_node->args) {
      VisitExpr(arg);
      args_registers.push_back(last_register_);
    }

    if (const auto* global_var_node = call_node->op.as<GlobalVarNode>()) {
      // In the case we are invoking a global we need to find its
      // global ID, then check whether it is a closure invocation
      // or a standard global, and emit the correct
      // calling convention.
      auto global = GetRef<GlobalVar>(global_var_node);
      auto it = context_->global_map.find(global);
      ICHECK(it != context_->global_map.end()) << PrettyPrint(global);
      VLOG(2) << "VisitExpr_: generating invoke for " << global->name_hint
              << " with func_index=" << it->second;

      // TODO(tvm-team):
      // Think about mixed calls into globals that are not relay::Function;
      // perhaps establish it as an invariant (all functions in mod must be relay::Function).
      auto func = Downcast<Function>(context_->module->Lookup(global));

      if (IsClosure(func)) {
        auto arity = func->params.size();
        Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister()));
      } else {
        Emit(Instruction::Invoke(it->second, args_registers, NewRegister()));
      }
    } else if (const auto* constructor_node = call_node->op.as<ConstructorNode>()) {
      // In the constructor case, we simply need to find its tag
      // and emit a call to allocate the data structure.
      auto constructor = GetRef<Constructor>(constructor_node);
      Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers,
                                 NewRegister()));
    } else if (const auto* var_node = call_node->op.as<VarNode>()) {
      // If we are calling a variable, it must be the case that it is a closure, so we
      // emit an invoke closure here.
      VisitExpr(GetRef<Var>(var_node));
      Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
    } else if (auto inner_call_node = call_node->op.as<CallNode>()) {
      VisitExpr(GetRef<Call>(inner_call_node));
      Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister()));
    } else {
      // Finally, if there are any other cases this is a bug.
      LOG(FATAL) << "internal error: unreachable code, "
                 << "should be transformed away by previous passes:" << std::endl
                 << PrettyPrint(GetRef<Expr>(call_node));
    }
  }

  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
    if (function_nesting() > 1) {
      ICHECK(func_node->HasNonzeroAttr(attr::kPrimitive))
          << "local functions should have been removed by lambda lifting:" << std::endl
          << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl
          << "AST: " << GetRef<Function>(func_node);
      return;
    }

    // We're processing a top-level function which has possibly been rejigged to capture
    // both closure and function arguments. Those functions retain their 'Closure' attribute,
    // but we can just process them like any other function here.

    // Assign a register num to each parameter.
    size_t i = 0;
    for (auto param : func_node->params) {
      auto arg_register = NewRegister();
      ICHECK_EQ(i, arg_register);
      var_register_map_.insert({param, arg_register});
      params_.push_back(param->name_hint());
      ++i;
    }

    VisitExpr(func_node->body);

    instructions_.push_back(Instruction::Ret(last_register_));
  }

  /*!
   * \brief Compile a match value.
   * Generates bytecode that computes the value specified in val.
   *
   * \return The register number assigned for the final value
   */
  RegName CompileMatchValue(MatchValuePtr val) {
    if (std::dynamic_pointer_cast<RegisterValue>(val)) {
      auto r = std::dynamic_pointer_cast<RegisterValue>(val);
      return r->register_num;
    } else {
      auto path = std::dynamic_pointer_cast<AccessField>(val);
      auto p = CompileMatchValue(path->parent);
      Emit(Instruction::GetField(p, path->index, NewRegister()));
      path->reg = last_register_;
      return path->reg;
    }
  }

  void CompileTreeNode(TreeObjectPtr tree) {
    if (auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree)) {
      VisitExpr(node->body);
    } else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) {
      Emit(Instruction::Fatal());
    } else if (auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree)) {
      if (auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond)) {
        // For tag comparison, generate branches.
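        // Roughly: GetTag(obj) -> r1; LoadConsti(target_tag) -> r2; If(r1, r2, ...) guards the
        // then-branch, followed by a Goto over the else-branch; the offsets are patched below.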
        auto r = CompileMatchValue(cond->obj);
        Emit(Instruction::GetTag(r, NewRegister()));
        auto operand1 = last_register_;
        Emit(Instruction::LoadConsti(cond->target_tag, NewRegister()));
        auto operand2 = last_register_;

        Emit(Instruction::If(operand1, operand2, 1, 0));
        auto cond_offset = instructions_.size() - 1;
        CompileTreeNode(node->then_branch);
        auto if_reg = last_register_;
        Emit(Instruction::Goto(1));
        auto goto_offset = instructions_.size() - 1;
        CompileTreeNode(node->else_branch);
        auto else_reg = last_register_;
        Emit(Instruction::Move(else_reg, if_reg));
        last_register_ = if_reg;
        auto else_offset = instructions_.size() - 1;
        // Fix up the offsets.
        instructions_[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1;
        instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1;
      } else {
        // For other non-branch conditions, move to then_branch directly.
        auto var_bind = std::dynamic_pointer_cast<VarBinding>(node->cond);
        var_register_map_[var_bind->var] = CompileMatchValue(var_bind->val);
        CompileTreeNode(node->then_branch);
      }
    }
  }

  /*!
   * \brief Compile a pattern match expression.
   * It first converts the pattern match expression into a decision tree; each condition
   * is either an object tag comparison or a variable binding. If any condition fails in a clause,
   * the decision tree switches to check the conditions of the next clause, and so on. If no
   * clause matches the value, a fatal node is inserted.
   *
   * After the decision tree is built, we convert it into bytecode using If/Goto.
   */
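  // For example (illustrative sketch), `match d { Cons(x, _) => e1, Nil => e2 }` compiles to a
  // TagCompare on the Cons tag guarding a VarBinding for x and the code for e1, with the Nil
  // clause as the else branch and a Fatal leaf if neither tag matches.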
  void CompileMatch(Match match) {
    auto data = std::make_shared<RegisterValue>(last_register_);
    auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses);
    CompileTreeNode(decision_tree);
  }

 protected:
  /*! \brief Store the expression a variable points to. */
  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> expr_map_;
  /*! \brief Instructions in the VMFunction. */
  std::vector<Instruction> instructions_;
  /*! \brief Parameter names of the function. */
  std::vector<std::string> params_;
  /*! \brief Map from var to register number. */
  std::unordered_map<Var, RegName, ObjectPtrHash, ObjectPtrEqual> var_register_map_;
  /*! \brief Last used register number. */
  size_t last_register_;
  /*! \brief Total number of virtual registers allocated. */
  size_t registers_num_;
  /*! \brief Global shared metadata. */
  VMCompilerContext* context_;
  /*! \brief VirtualDevice for data and computation which must reside on a CPU. */
  VirtualDevice host_virtual_device_;
};

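// PackedFunc entry points exposed through the runtime module interface (for example to the
// Python-side VMCompiler wrapper); a typical driver calls "lower", then "codegen", then
// "get_executable".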
PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
  if (name == "lower") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      ICHECK_EQ(args.num_args, 2);
      this->Lower(args[0], args[1]);
    });
  } else if (name == "codegen") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      ICHECK_EQ(args.num_args, 0);
      this->Codegen();
    });
  } else if (name == "get_executable") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      ICHECK_EQ(args.num_args, 0);
      *rv = this->GetExecutable();
    });
  } else if (name == "set_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      Map<String, Constant> params = args[0];
      for (const auto& kv : params) {
        this->SetParam(kv.first, kv.second->data);
      }
    });
  } else if (name == "get_params") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      Map<String, Constant> ret;
      for (const auto& kv : params_) {
        ret.Set(kv.first, Constant(kv.second));
      }
      *rv = ret;
    });
  } else if (name == "optimize") {
    return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
      ICHECK_EQ(args.num_args, 2);
      *rv = this->OptimizeModule(args[0], args[1]);
    });
  } else {
    LOG(FATAL) << "Unknown packed function: " << name;
  }
}

void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) {
  params_[name] = data_in;
}

void VMCompiler::Lower(IRModule mod, const Array<Target>& raw_targets) {
  VLOG_CONTEXT << "VM Lower";
  Setup(raw_targets);
  LowerImpl(std::move(mod));
}

IRModule VMCompiler::OptimizeModule(IRModule mod, const Array<Target>& raw_targets) {
  VLOG_CONTEXT << "VM Optimize";
  Setup(raw_targets);
  return OptimizeModuleImpl(std::move(mod));
}

runtime::Module VMCompiler::GetExecutable() const {
  if (exec_ == nullptr) {
    LOG(WARNING) << "No executable to return. Did you forget to call VMCompiler::Lower?";
  }
  if (exec_->imports().empty()) {
    LOG(WARNING) << "Executable is empty. Did you forget to call VMCompiler::Codegen?";
  }
  return runtime::Module(exec_);
}

void VMCompiler::Setup(const Array<Target>& raw_targets) {
  ICHECK(exec_ == nullptr) << "Can't reuse a VMCompiler object for multiple modules";
  exec_ = make_object<Executable>();
  ICHECK(!config_.defined());
  config_ = CompilationConfig(PassContext::Current(), raw_targets);
  VLOG(1) << "Using compilation config:" << std::endl << config_;

  // The first device is always for the host.
  CHECK(context_.virtual_devices_.empty());
  VLOG(1) << "virtual_device[0] = " << config_->host_virtual_device << " (host)";
  context_.virtual_devices_.push_back(config_->host_virtual_device);
}

void VMCompiler::LowerImpl(IRModule mod) {
  // Run the optimizations necessary to target the VM.
  context_.module = OptimizeModuleImpl(std::move(mod));

  // Build the map from global variables bound to Functions to a global index in the
  // VMFunction table.
  size_t num_functions = PopulateGlobalMap();

  // Next we get ready by allocating space for
  // the global state.
  exec_->functions.resize(num_functions);

  for (const auto& pair : context_.module->functions) {
    auto gvar = pair.first;
    if (auto* n = pair.second.as<FunctionNode>()) {
      if (n->HasNonzeroAttr(attr::kExtern)) {
        // Already compiled during lowering.
        continue;
      }
      auto func = GetRef<Function>(n);
      VMFunctionCompiler func_compiler(&context_, config_->host_virtual_device);
      auto vm_func = func_compiler.Compile(gvar, func);

      size_t func_index = context_.global_map.at(gvar);
      ICHECK(func_index < exec_->functions.size());
      exec_->functions[func_index] = vm_func;

      // update structural hashes for tvm ops
      for (auto p : func_compiler.op_attrs) {
        exec_->op_attrs.insert(p);
      }
    }
  }

  // Populate virtual devices and the host device index.
  for (const auto& virtual_device : context_.virtual_devices_) {
    ICHECK(!virtual_device->IsFullyUnconstrained());
    ICHECK_GT(virtual_device->device_type(), 0);
    // TODO(mbs): We forget the memory scope.
    exec_->virtual_devices.push_back(Device{/*device_type=*/virtual_device->device_type(),
                                            /*device_id=*/virtual_device->virtual_device_id});
  }
  exec_->host_device_index = kHostDeviceIndex;

  // populate constants
  for (const auto& data : context_.constants) {
    exec_->constants.push_back(data);
  }

  for (auto index : context_.const_device_indexes) {
    exec_->const_device_indexes.push_back(index);
  }

  // update global function map
  for (const auto& gv : context_.global_map) {
    exec_->global_map.insert({gv.first->name_hint, gv.second});
  }

  // update primitive function map
  for (const auto& pair : context_.primitive_map) {
    exec_->primitive_map.insert(pair);
  }

  VLOG(1) << "Compiled to:" << std::endl
          << "-------------------------------------------------" << std::endl
          << exec_->GetVirtualDevices()  //
          << exec_->GetConstants()       //
          << exec_->GetPrimitives()      //
          << exec_->GetBytecode()        //
          << "-------------------------------------------------";

  if (backend::IsAutoSchedulerEnabled()) {
    backend::UpdateAutoSchedulerOpWeights(context_.module);
  }
}

transform::Sequential VMCompiler::MemoryOpt(const CompilationConfig& config) {
  Array<Pass> pass_seqs;
  // Remove unused functions.
  Array<runtime::String> entry_functions{"main"};
  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
  // Manifest the allocations.
  pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device));

  // Compute away possibly introduced constant computation.
  pass_seqs.push_back(transform::FoldConstant());

  // Fuse & lower any new shape functions and device_copies.
  pass_seqs.push_back(FuseAndLowerOperators(config));

  // Manifest the allocations needed for the shape functions.
  pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device));

  // Fuse & lower any new allocations.
  pass_seqs.push_back(FuseAndLowerOperators(config));

  // TODO(mbrookhart, jroesch, masahi): this pass is very slow, and is
  // incomplete to provide memory reuse optimizations. Disable it until we can
  // rewrite it in C++ and complete it.
  // // Perform memory planning in order to coalesce/reduce allocations.
  // pass_seqs.push_back(transform::MemoryPlan());

  // Compute away constant computation introduced by coalescing allocations.
  pass_seqs.push_back(transform::FoldConstant());

  // Fuse & lower yet again.
  pass_seqs.push_back(FuseAndLowerOperators(config));

  // Create allocations for math introduced by dynamic region math.
  pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device));

  // Compute away possibly introduced constant computation.
  pass_seqs.push_back(transform::FoldConstant());

  // Insert kills to free memory.
  pass_seqs.push_back(transform::ManifestLifetimes());

  // Lift constants to the top-level of the block to simplify VM code generation.
  // TODO(@icemelon9, @jroesch): Remove this pass for now because some
  // instructions need access to constants.
  // pass_seqs.push_back(transform::LiftConstants());

  return transform::Sequential(std::move(pass_seqs));
}

transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig& config) {
  Array<Pass> pass_seqs;
  // Hoist operators to "primitive" Functions.
  pass_seqs.push_back(FuseOps());
  // Give each "primitive" Function a hash.
  pass_seqs.push_back(LabelOps());
  // Lower "primitive" Functions to PrimFuncs and rewrite calls.
  pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config, [this](const BaseFunc& func) {
    if (func->GetAttr<String>(attr::kCompiler).defined()) {
      backend::UpdateConstants(func, &params_);
    }
  }));
  // Since lowered functions are bound in the IRModule, we can now eliminate any unused
  // let-bound functions.
  pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
  return transform::Sequential(std::move(pass_seqs));
}

IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) {
  backend::BindParamsInModule(mod, params_);
  Array<Pass> pass_seqs = relay::backend::GetPassPrefix(
      /*is_homogeneous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true);

  // Always plan devices so the remaining passes don't need to distinguish homogeneous vs
  // heterogeneous execution.
  pass_seqs.push_back(transform::PlanDevices(config_));

  pass_seqs.push_back(transform::FuseOps());

  // Do layout rewrite for auto-scheduler.
  transform::PassContext pass_ctx = PassContext::Current();
  if (backend::IsAutoSchedulerEnabled() && config_->optional_homogeneous_target.defined()) {
    Pass major_pass = transform::AutoSchedulerLayoutRewrite();
    bool enable_layout_rewrite_targets =
        config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU ||
        config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
    if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
      With<Target> tctx(config_->optional_homogeneous_target);
      pass_seqs.push_back(major_pass);
      // Defuse ops to fold constants, then fuse them again.
      pass_seqs.push_back(transform::DefuseOps());
      pass_seqs.push_back(transform::FoldConstant());
      pass_seqs.push_back(transform::FuseOps());
    }
  }
  if (backend::IsMetaScheduleEnabled() && config_->optional_homogeneous_target.defined()) {
    Pass major_pass = transform::MetaScheduleLayoutRewrite();
    bool enable_layout_rewrite_targets =
        config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU ||
        config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
    if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) {
      With<Target> tctx(config_->optional_homogeneous_target);
      pass_seqs.push_back(major_pass);
      // Defuse ops to fold constants, then fuse them again.
      pass_seqs.push_back(transform::DefuseOps());
      pass_seqs.push_back(transform::FoldConstant());
      pass_seqs.push_back(transform::FuseOps());
    }
  }

  pass_seqs.push_back(transform::ToANormalForm());
  pass_seqs.push_back(transform::InferType());
  pass_seqs.push_back(transform::LambdaLift());

  // Eliminate dead code before we lower. We don't track the purity of PrimFuncs, thus after
  // lowering all calls to lowered functions will be kept.
  pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));
  pass_seqs.push_back(transform::LabelOps());

  // Lower all functions annotated as "primitive" by FuseOps.
  pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_, [this](const BaseFunc& func) {
    if (func->GetAttr<String>(attr::kCompiler).defined()) {
      backend::UpdateConstants(func, &params_);
    }
  }));

  // Since lowered functions are bound in the IRModule, we can now eliminate any unused
  // let-bound functions.
  pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false));

  // At this point it's possible to run PlanDevices again to pick up any additional constraints
  // introduced during lowering. However we'll not do this until more testing has been done.

  // Inline the functions that are lifted to the module scope. We perform this
  // pass after all other optimization passes but before the memory allocation
  // pass. This is because the memory allocation pass will insert `invoke_tvm_op`
  // and we use these ops to invoke the symbols in the module generated by
  // external codegen.
  pass_seqs.push_back(transform::Inline());

  pass_seqs.push_back(MemoryOpt(config_));
  pass_seqs.push_back(transform::InferType());

  transform::Sequential seq(pass_seqs);
  tvm::With<relay::transform::PassContext> ctx(pass_ctx);
  if (config_->optional_homogeneous_target.defined()) {
    With<Target> tctx(config_->optional_homogeneous_target);
    return seq(std::move(mod));
  } else {
    return seq(std::move(mod));
  }
}

size_t VMCompiler::PopulateGlobalMap() {
  // Allocate a VMFunction index for every Relay Function we could call.
  // Excludes PrimFuncs and externs, which are managed by the primitive_map_.
  for (const auto& kv : context_.module->functions) {
    if (const auto* function_node = kv.second.as<FunctionNode>()) {
      if (!function_node->HasNonzeroAttr(attr::kExtern)) {
        context_.global_map.emplace(kv.first, context_.global_map.size());
      }
    }
  }
  return context_.global_map.size();
}

void VMCompiler::Codegen() {
  VLOG_CONTEXT << "VM Codegen";
  if (!context_.module.defined()) {
    LOG(WARNING) << "No compiled module to codegen from. Did you forget to call VMCompiler::Lower?";
    return;
  }

  // At this point context_.module will contain only:
  //  - non-external Relay functions, which we've compiled into VMFunctions.
  //  - external Relay functions, which will have definitions within some external runtime module
  //    in the "external_mods" attribute
  //  - PrimFuncs annotated with their targets.
  // Only the PrimFuncs will appear in per_target_modules, and there may legitimately be none.
  Map<Target, IRModule> per_tvm_target_modules = tec::GetPerTargetModules(context_.module);
  for (const auto& kv : per_tvm_target_modules) {
    ICHECK(kv.first->GetTargetDeviceType() != kDLExtDev);
  }

  // Retrieve all external runtime modules accumulated by external codegen (both function-at-a-time
  // and IRModule-at-a-time).
  Array<runtime::Module> external_mods =
      context_.module->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({});

  // Retrieve any constant bindings accumulated by external codegen (by IRModule-at-a-time passes).
  Map<String, runtime::NDArray> const_name_to_constant =
      context_.module->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
          .value_or({});

  VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build, "
          << external_mods.size() << " external runtime modules, " << const_name_to_constant.size()
          << " external constants, and " << params_.size() << " local constants";

  // Any constant bindings must be merged into the overall 'params' map we've directly accumulated
  // via the TECompiler callback.
  for (const auto& kv : const_name_to_constant) {
    ICHECK_EQ(params_.count(kv.first), 0);
    params_.emplace(kv.first, kv.second);
  }

  runtime::Module lib;
  if (per_tvm_target_modules.empty()) {
    // There is no function handled by TVM. We create a virtual main module
    // to make sure a DSO module will also be available.
    LOG(INFO) << "All lowered functions have been built by BYOC -- generating an empty TVM module";
    lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
  } else {
    lib = tvm::TIRToRuntime(per_tvm_target_modules, config_->host_target);
  }

  lib =
      codegen::CreateMetadataModule(params_, lib, external_mods, config_->host_target,
                                    Runtime::Create("cpp"), Executor::Create("graph"),  // DNS HACK
                                    relay::backend::ExecutorCodegenMetadata());
  exec_->SetLib(lib);
}

runtime::Module CreateVMCompiler() {
  auto exec = make_object<VMCompiler>();
  return runtime::Module(std::move(exec));
}

TVM_REGISTER_GLOBAL("relay._vm._VMCompiler").set_body_typed(CreateVMCompiler);

}  // namespace vm
}  // namespace relay
}  // namespace tvm