1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/vm/compiler.cc |
22 | * \brief A compiler from relay::Module to the VM byte code. |
23 | */ |
24 | |
25 | #include "compiler.h" |
26 | |
27 | #include <tvm/driver/driver_api.h> |
28 | #include <tvm/relay/analysis.h> |
29 | #include <tvm/relay/attrs/device_copy.h> |
30 | #include <tvm/relay/attrs/memory.h> |
31 | #include <tvm/relay/error.h> |
32 | #include <tvm/relay/expr_functor.h> |
33 | #include <tvm/relay/interpreter.h> |
34 | #include <tvm/relay/parser.h> |
35 | #include <tvm/relay/qnn/transform.h> |
36 | #include <tvm/relay/runtime.h> |
37 | #include <tvm/relay/transform.h> |
38 | #include <tvm/runtime/logging.h> |
39 | #include <tvm/runtime/vm/vm.h> |
40 | #include <tvm/te/operation.h> |
41 | |
42 | #include <iostream> |
43 | #include <map> |
44 | #include <memory> |
45 | #include <string> |
46 | #include <tuple> |
47 | #include <vector> |
48 | |
49 | #include "../../../driver/internal_driver_api.h" |
50 | #include "../../../target/metadata_module.h" |
51 | #include "../../../target/source/codegen_source_base.h" |
52 | #include "../../op/annotation/annotation.h" |
53 | #include "../../op/memory/device_copy.h" |
54 | #include "../../op/op_common.h" |
55 | #include "../../transforms/device_aware_visitors.h" |
56 | #include "../../transforms/pass_utils.h" |
57 | #include "../utils.h" |
58 | #include "./compiler.h" |
59 | |
60 | namespace tvm { |
61 | namespace relay { |
62 | |
63 | namespace transform { |
64 | |
65 | Pass LambdaLift(); |
66 | Pass LabelOps(); |
67 | |
68 | Pass MemoryPlan() { |
  auto f = tvm::runtime::Registry::Get("relay.transform.MemoryPlan");
  ICHECK(f != nullptr) << "unable to load the memory planning pass";
71 | return (*f)(); |
72 | } |
73 | |
74 | Pass LiftConstants() { |
  auto f = tvm::runtime::Registry::Get("relay.transform.LiftConstants");
  ICHECK(f != nullptr) << "unable to load the constant lifting pass";
77 | return (*f)(); |
78 | } |
79 | |
80 | } // namespace transform |
81 | |
82 | namespace vm { |
83 | |
84 | using namespace tvm::runtime; |
85 | using namespace tvm::runtime::vm; |
86 | using namespace relay::transform; |
87 | |
88 | /*! \brief The host device is always stored at device index 0. */ |
89 | constexpr Index kHostDeviceIndex = 0; |
90 | |
91 | // (@jroesch): VM passes, eventually declare as passes. |
92 | bool IsClosure(const Function& func); |
93 | |
94 | // Represent a runtime object that's going to be matched by pattern match expressions |
95 | struct MatchValue { |
96 | virtual ~MatchValue() {} |
97 | }; |
98 | using MatchValuePtr = std::shared_ptr<MatchValue>; |
99 | |
100 | // A runtime object that resides in a register |
101 | struct RegisterValue : MatchValue { |
102 | // The register num |
103 | RegName register_num; |
104 | |
105 | explicit RegisterValue(RegName reg) : register_num(reg) {} |
106 | |
107 | ~RegisterValue() {} |
108 | }; |
109 | |
110 | // The value is a field of another runtime object |
111 | struct AccessField : MatchValue { |
112 | MatchValuePtr parent; |
113 | // Field index |
114 | size_t index; |
115 | // Runtime register num after compiling the access field path |
116 | RegName reg{-1}; |
117 | |
118 | AccessField(MatchValuePtr parent, size_t index) : parent(parent), index(index) {} |
119 | |
120 | ~AccessField() {} |
121 | }; |
122 | |
123 | /*! |
124 | * \brief Condition in a decision tree |
125 | */ |
126 | struct ConditionNode { |
127 | virtual ~ConditionNode() {} |
128 | }; |
129 | |
130 | using ConditionObjectPtr = std::shared_ptr<ConditionNode>; |
131 | |
132 | /*! |
133 | * \brief A var binding condition |
134 | */ |
135 | struct VarBinding : ConditionNode { |
136 | Var var; |
137 | MatchValuePtr val; |
138 | |
139 | VarBinding(Var var, MatchValuePtr val) : var(var), val(val) {} |
140 | |
141 | ~VarBinding() {} |
142 | }; |
143 | |
144 | /*! |
145 | * \brief Compare the tag of the object |
146 | */ |
147 | struct TagCompare : ConditionNode { |
148 | /*! \brief The object to be examined */ |
149 | MatchValuePtr obj; |
150 | |
151 | /*! \brief The expected tag */ |
152 | int target_tag; |
153 | |
154 | TagCompare(MatchValuePtr obj, size_t target) : obj(obj), target_tag(target) {} |
155 | |
156 | ~TagCompare() {} |
157 | }; |
158 | |
159 | using TreeObjectPtr = typename relay::TreeNode<ConditionObjectPtr>::pointer; |
160 | using TreeLeafNode = relay::TreeLeafNode<ConditionObjectPtr>; |
161 | using TreeLeafFatalNode = relay::TreeLeafFatalNode<ConditionObjectPtr>; |
162 | using TreeBranchNode = relay::TreeBranchNode<ConditionObjectPtr>; |
163 | |
164 | TreeObjectPtr BuildDecisionTreeFromPattern(MatchValuePtr data, Pattern pattern, |
165 | TreeObjectPtr then_branch, TreeObjectPtr else_branch) { |
166 | if (pattern.as<PatternWildcardNode>()) { |
    // We ignore wildcard bindings since they do not produce new vars
168 | return then_branch; |
169 | } else if (const auto* pvn = pattern.as<PatternVarNode>()) { |
170 | auto cond = std::make_shared<VarBinding>(pvn->var, data); |
171 | return TreeBranchNode::Make(cond, then_branch, else_branch); |
172 | } else if (const auto* pcn = pattern.as<PatternConstructorNode>()) { |
173 | auto tag = pcn->constructor->tag; |
174 | |
175 | size_t field_index = 0; |
176 | for (auto& p : pcn->patterns) { |
177 | auto d = std::make_shared<AccessField>(data, field_index); |
178 | then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch); |
179 | field_index++; |
180 | } |
181 | auto cond = std::make_shared<TagCompare>(data, tag); |
182 | return TreeBranchNode::Make(cond, then_branch, else_branch); |
183 | } else { |
184 | const auto* pt = pattern.as<PatternTupleNode>(); |
185 | ICHECK(pt) << "unhandled case: " << AsText(pattern, false); |
186 | size_t field_index = 0; |
187 | for (auto& p : pt->patterns) { |
188 | auto d = std::make_shared<AccessField>(data, field_index++); |
189 | then_branch = BuildDecisionTreeFromPattern(d, p, then_branch, else_branch); |
190 | } |
191 | return then_branch; |
192 | } |
193 | } |
194 | |
195 | TreeObjectPtr BuildDecisionTreeFromClause(MatchValuePtr data, Clause clause, |
196 | TreeObjectPtr else_branch) { |
197 | return BuildDecisionTreeFromPattern(data, clause->lhs, TreeLeafNode::Make(clause->rhs), |
198 | else_branch); |
199 | } |
200 | |
201 | TreeObjectPtr BuildDecisionTreeFromClauses(MatchValuePtr data, tvm::Array<Clause> clauses) { |
  // When nothing matches, the VM throws a fatal error
203 | TreeObjectPtr else_branch = TreeLeafFatalNode::Make(); |
204 | // Start from the last clause |
205 | for (auto it = clauses.rbegin(); it != clauses.rend(); ++it) { |
206 | else_branch = BuildDecisionTreeFromClause(data, *it, else_branch); |
207 | } |
208 | return else_branch; |
209 | } |
210 | |
211 | std::vector<int64_t> ToAllocTensorShape(NDArray shape) { |
212 | std::vector<int64_t> raw_shape; |
213 | if (shape->ndim == 0) { |
214 | return raw_shape; |
215 | } |
216 | ICHECK_EQ(shape->ndim, 1u); |
217 | ICHECK_EQ(shape->dtype.code, 0U) << "The dtype of constant shape must be int32 or int64, but got " |
218 | << DLDataType2String(shape->dtype); |
  ICHECK(shape->dtype.bits == 64 || shape->dtype.bits == 32)
      << "The dtype of constant shape must be int32 or int64, but got "
      << DLDataType2String(shape->dtype);
222 | |
223 | if (shape->dtype.bits == 64) { |
224 | int64_t* int_ptr = reinterpret_cast<int64_t*>(shape->data); |
225 | for (auto i = 0; i < shape->shape[0]; i++) { |
226 | raw_shape.push_back(int_ptr[i]); |
227 | } |
228 | } else { // int32 |
229 | int32_t* int_ptr = reinterpret_cast<int32_t*>(shape->data); |
230 | for (auto i = 0; i < shape->shape[0]; i++) { |
231 | raw_shape.push_back(static_cast<int64_t>(int_ptr[i])); |
232 | } |
233 | } |
234 | return raw_shape; |
235 | } |
236 | |
237 | class VMFunctionCompiler : DeviceAwareExprFunctor<void(const Expr& n)> { |
238 | public: |
239 | VMFunctionCompiler(VMCompilerContext* context, VirtualDevice host_virtual_device) |
240 | : DeviceAwareExprFunctor(context->module), |
241 | last_register_(0), |
242 | registers_num_(0), |
243 | context_(context), |
244 | host_virtual_device_(std::move(host_virtual_device)) {} |
245 | |
246 | VMFunction Compile(const GlobalVar& var, const Function& func) { |
247 | VLOG(1) << "Compiling:" << std::endl << PrettyPrint(func); |
248 | std::vector<Index> param_device_indexes; |
249 | if (IsClosure(func)) { |
250 | // After lifting we'll have functions of the form: |
251 | // fn(closure args) { fn(lifted function args) { body } } |
252 | // But we want the closure's function to be: |
      // fn(closure args, lifted function args) { body }
254 | // Do that flattening on-the-fly here. |
255 | Function inner_func = Downcast<Function>(func->body); |
256 | std::vector<Var> params; |
257 | params.reserve(func->params.size() + inner_func->params.size()); |
258 | param_device_indexes.reserve(func->params.size() + inner_func->params.size()); |
259 | for (size_t i = 0; i < func->params.size(); ++i) { |
260 | params.emplace_back(func->params[i]); |
261 | param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device())); |
262 | } |
263 | for (size_t i = 0; i < inner_func->params.size(); ++i) { |
264 | params.emplace_back(inner_func->params[i]); |
265 | |
266 | param_device_indexes.push_back(GetDeviceIndex(inner_func->params[i]->virtual_device())); |
267 | } |
268 | std::vector<TypeVar> type_params; |
269 | type_params.reserve(func->type_params.size() + inner_func->type_params.size()); |
270 | for (const auto& tyvar : func->type_params) { |
271 | type_params.push_back(tyvar); |
272 | } |
273 | for (const auto& tyvar : inner_func->type_params) { |
274 | type_params.push_back(tyvar); |
275 | } |
276 | Function flattened_func = Function(params, inner_func->body, inner_func->ret_type, |
277 | type_params, func->attrs, func->span); |
278 | flattened_func->virtual_device_ = inner_func->virtual_device(); |
279 | VisitExpr(flattened_func); |
280 | } else { |
281 | param_device_indexes.reserve(func->params.size()); |
282 | for (size_t i = 0; i < func->params.size(); ++i) { |
283 | param_device_indexes.push_back(GetDeviceIndex(func->params[i]->virtual_device())); |
284 | } |
285 | VisitExpr(func); |
286 | } |
287 | return VMFunction(var->name_hint, params_, instructions_, registers_num_, |
288 | std::move(param_device_indexes)); |
289 | } |
290 | |
291 | /*! \brief Attrs objects for each op. */ |
292 | std::map<Index, Map<String, ObjectRef>> op_attrs; |
293 | |
294 | /*! \brief Attrs objects for each callsite. */ |
295 | std::map<Index, Map<String, ObjectRef>> callsite_attrs; |
296 | |
297 | protected: |
298 | size_t NewRegister() { return registers_num_++; } |
299 | |
300 | inline void Emit(const Instruction& instr) { |
301 | size_t instruction_index = instructions_.size(); |
302 | VLOG(2) << "instruction[" << instruction_index << "] = " << instr; |
303 | ICHECK((int)instr.op < 100) << "Invalid opcode " << (int)instr.op; |
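    // Opcodes that write a destination register update last_register_ below, so the
    // enclosing expression visitors can pick up the value the instruction produced.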
304 | switch (instr.op) { |
305 | case Opcode::AllocADT: |
306 | case Opcode::AllocTensor: |
307 | case Opcode::AllocTensorReg: |
308 | case Opcode::GetField: |
309 | case Opcode::GetTag: |
310 | case Opcode::LoadConst: |
311 | case Opcode::LoadConsti: |
312 | case Opcode::Invoke: |
313 | case Opcode::AllocClosure: |
314 | case Opcode::AllocStorage: |
315 | case Opcode::ShapeOf: |
316 | case Opcode::ReshapeTensor: |
317 | case Opcode::Move: |
318 | case Opcode::InvokeClosure: |
319 | case Opcode::DeviceCopy: |
320 | last_register_ = instr.dst; |
321 | break; |
322 | case Opcode::InvokePacked: |
323 | case Opcode::If: |
324 | case Opcode::Ret: |
325 | case Opcode::Goto: |
326 | case Opcode::Fatal: |
327 | case Opcode::KillRegister: |
328 | break; |
329 | } |
330 | instructions_.push_back(instr); |
331 | } |
332 | |
333 | /*! |
334 | * \brief Returns the "device index" to represent \p virtual_device for primitives |
335 | * in emitted code. Note that the host device is always at index 0. |
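   *
   * For example, if the host is a CPU and primitives run on a single CUDA device, then
   * the CPU VirtualDevice maps to index 0 and the CUDA VirtualDevice is assigned
   * index 1 on first use.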
336 | */ |
337 | Index GetDeviceIndex(const VirtualDevice& virtual_device) { |
338 | ICHECK(!virtual_device->IsFullyUnconstrained()); |
339 | auto itr = std::find(context_->virtual_devices_.begin(), context_->virtual_devices_.end(), |
340 | virtual_device); |
341 | if (itr != context_->virtual_devices_.end()) { |
342 | return std::distance(context_->virtual_devices_.begin(), itr); |
343 | } |
344 | |
345 | ICHECK_GT(context_->virtual_devices_.size(), 0); |
346 | ICHECK_NE(virtual_device, host_virtual_device_); // the host scope is always at index 0 |
347 | |
348 | if (virtual_device->device_type() == context_->virtual_devices_.front()->device_type()) { |
349 | // It's ok if we see distinct scopes which share the host device type. This is because |
350 | // we allow the VirtualDevice for the host to be different from the VirtualDevice for |
351 | // primitive operations which both happen to be on the same device (typically CPU). |
352 | return 0; |
353 | } |
354 | |
    // Otherwise, we allow at most one VirtualDevice per device type.
356 | // TODO(mbs): This will eventually need to account for memory scopes somehow so device_copy |
357 | // instructions can do the right thing. |
358 | itr = std::find_if(context_->virtual_devices_.begin() + 1, context_->virtual_devices_.end(), |
359 | [&virtual_device](const VirtualDevice& existing_virtual_device) { |
360 | return existing_virtual_device->device_type() == |
361 | virtual_device->device_type(); |
362 | }); |
363 | CHECK(itr == context_->virtual_devices_.end()) |
364 | << "The VM does not currently support using more than one device with the same device type " |
365 | "for primitives, however the program is using the distinct scopes " |
366 | << virtual_device << " and " << *itr << " of device type " << virtual_device->device_type(); |
367 | |
368 | ICHECK(virtual_device != host_virtual_device_); |
369 | Index index = context_->virtual_devices_.size(); |
370 | VLOG(2) << "virtual_device[" << index << "] = " << virtual_device; |
371 | context_->virtual_devices_.push_back(virtual_device); |
372 | |
373 | return index; |
374 | } |
375 | |
376 | using DeviceAwareExprFunctor<void(const Expr&)>::VisitExpr_; |
377 | |
378 | void VisitExpr_(const ConstantNode* const_node) final { |
    // Record the constant and the device index it should be loaded onto.
380 | NDArray data = const_node->data; |
381 | size_t const_index = context_->constants.size(); |
382 | auto con = GetRef<Constant>(const_node); |
383 | Index device_index = GetDeviceIndex(GetVirtualDevice(con)); |
384 | VLOG(2) << "constant[" << const_index << "] on device[" << device_index << "]" ; |
385 | context_->const_device_indexes.push_back(device_index); |
386 | context_->constants.push_back(const_node->data); |
387 | Emit(Instruction::LoadConst(const_index, NewRegister())); |
388 | } |
389 | |
390 | void VisitExpr_(const VarNode* var_node) final { |
391 | auto var = GetRef<Var>(var_node); |
392 | auto reg_it = this->var_register_map_.find(var); |
393 | ICHECK(reg_it != this->var_register_map_.end()); |
394 | last_register_ = reg_it->second; |
395 | } |
396 | |
397 | void VisitExpr_(const TupleNode* tuple_node) final { |
398 | auto tuple = GetRef<Tuple>(tuple_node); |
399 | std::vector<Index> fields_registers; |
400 | |
401 | for (auto& field : tuple->fields) { |
402 | this->VisitExpr(field); |
403 | fields_registers.push_back(last_register_); |
404 | } |
405 | |
406 | // TODO(@jroesch): use correct tag |
407 | Emit(Instruction::AllocADT(0, tuple->fields.size(), fields_registers, NewRegister())); |
408 | } |
409 | |
410 | void VisitExpr_(const MatchNode* match_node) final { |
411 | auto match = GetRef<Match>(match_node); |
412 | |
413 | this->VisitExpr(match->data); |
414 | CompileMatch(match); |
415 | } |
416 | |
417 | void PreVisitLetBinding_(const Var& var, const Expr& value) final { |
418 | ICHECK(!value.as<FunctionNode>()) |
419 | << "unexpected function:" << std::endl |
420 | << PrettyPrint(value) << std::endl |
421 | << "bound to var '" << var->name_hint() << "'. Did you set opt_level = 2?" ; |
422 | VisitExpr(value); |
423 | var_register_map_.emplace(var, this->last_register_); |
424 | } |
425 | |
426 | void VisitExpr_(const TupleGetItemNode* get_node) final { |
427 | auto get = GetRef<TupleGetItem>(get_node); |
428 | this->VisitExpr(get->tuple); |
429 | auto tuple_register = last_register_; |
430 | Emit(Instruction::GetField(tuple_register, get->index, NewRegister())); |
431 | } |
432 | |
433 | void VisitExpr_(const GlobalVarNode* gvar) final { |
434 | auto var = GetRef<GlobalVar>(gvar); |
435 | auto func = context_->module->Lookup(var); |
436 | auto it = context_->global_map.find(var); |
437 | ICHECK(it != context_->global_map.end()) << PrettyPrint(var); |
438 | // Allocate closure with zero free vars |
439 | Emit(Instruction::AllocClosure(it->second, 0, {}, NewRegister())); |
440 | } |
441 | |
442 | void VisitExpr_(const IfNode* if_node) final { |
443 | this->VisitExpr(if_node->cond); |
444 | |
445 | size_t test_register = last_register_; |
446 | |
447 | this->Emit(Instruction::LoadConsti(1, NewRegister())); |
448 | auto after_cond = instructions_.size(); |
449 | auto target_register = last_register_; |
450 | this->Emit(Instruction::If(test_register, target_register, 0, 0)); |
451 | this->VisitExpr(if_node->true_branch); |
452 | |
    // This register saves the result of the if-else expression.
454 | auto merge_register = NewRegister(); |
455 | Emit(Instruction::Move(last_register_, merge_register)); |
456 | Emit(Instruction::Goto(0)); |
457 | |
    // Finally record the position immediately after the
    // true branch.
460 | auto after_true = this->instructions_.size(); |
461 | |
462 | this->VisitExpr(if_node->false_branch); |
463 | |
464 | size_t false_register = last_register_; |
465 | |
    // In the else branch, overwrite the then-branch result register.
467 | Emit(Instruction::Move(false_register, merge_register)); |
468 | // Compute the total number of instructions |
469 | // after generating false. |
470 | auto after_false = this->instructions_.size(); |
471 | |
    // Now we will compute the jump targets in order
    // to properly patch the instructions with the
    // requisite targets.
475 | |
476 | // After we emit the true body, and false body, |
477 | // we patch up the if instruction, and goto. |
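    //
    // Illustratively, the emitted code has the shape (offsets relative to the If):
    //   +0: If test_register, target_register, 1, <false_offset>
    //   +1: ...true branch..., Move -> merge_register, Goto <past the false branch>
    //       ...false branch..., Move -> merge_register
    // where <false_offset> lands on the first instruction of the false branch.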
478 | auto true_offset = 1; |
479 | auto false_offset = after_true - after_cond; |
480 | instructions_[after_cond].if_op.true_offset = true_offset; |
481 | instructions_[after_cond].if_op.false_offset = false_offset; |
482 | |
483 | // Patch the Goto. |
484 | this->instructions_[after_true - 1].pc_offset = (after_false - after_true) + 1; |
485 | |
486 | this->last_register_ = merge_register; |
487 | } |
488 | |
489 | void EmitInvokeTVMOp(const Expr& func, const Expr& inputs, const Expr& outputs, |
490 | const DictAttrs& attrs) { |
491 | std::vector<Index> argument_registers; |
492 | |
493 | const auto* global_var_node = func.as<GlobalVarNode>(); |
    ICHECK(global_var_node) << "Expecting function in invoke_tvm_op to be a global";
495 | |
496 | auto input_tuple = inputs.as<TupleNode>(); |
    ICHECK(input_tuple) << "internal error: invoke_tvm_op inputs must be a tuple, "
                        << "please file a bug in the memory manifestation pass";
499 | |
500 | auto output_tuple = outputs.as<TupleNode>(); |
    ICHECK(output_tuple) << "internal error: invoke_tvm_op outputs must be a tuple, "
                         << "please file a bug in the memory manifestation pass";
503 | |
504 | for (auto input : input_tuple->fields) { |
505 | VisitExpr(input); |
506 | argument_registers.push_back(last_register_); |
507 | } |
508 | |
509 | for (auto output : output_tuple->fields) { |
510 | ICHECK(output->IsInstance<VarNode>()) << "output should be var, found:" << std::endl |
511 | << PrettyPrint(output); |
512 | auto reg = var_register_map_.find(Downcast<Var>(output)); |
513 | ICHECK(reg != var_register_map_.end()) |
514 | << "internal error: all variables should be in the register mapping" ; |
515 | argument_registers.push_back(reg->second); |
516 | } |
517 | |
518 | Index op_index; |
519 | auto itr = context_->primitive_map.find(global_var_node->name_hint); |
520 | if (itr == context_->primitive_map.end()) { |
521 | op_index = context_->primitive_map.size(); |
522 | context_->primitive_map.emplace(global_var_node->name_hint, op_index); |
523 | } else { |
524 | op_index = itr->second; |
525 | } |
526 | |
527 | if (attrs.defined() && attrs->dict.defined()) { |
528 | // Capture the dictionary of attributes from the original primitive function so that they |
529 | // can contribute to the hash of the compiled primitive. This way we can distinguish |
530 | // primitives with the same body expression but different attributes which may arbitrarily |
531 | // influence code generation. |
532 | op_attrs[op_index] = attrs->dict; |
533 | } |
534 | |
535 | Emit(Instruction::InvokePacked(op_index, argument_registers.size(), output_tuple->fields.size(), |
536 | argument_registers)); |
537 | } |
538 | |
539 | void DeviceAwareVisitExpr_(const CallNode* call_node) final { |
540 | DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); |
541 | CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); |
542 | ICHECK(!call_lowered_props.lowered_func.defined()); |
543 | if (device_copy_props.body.defined()) { |
544 | // TODO(mbs): device_copy cleanup. |
545 | VisitExpr(device_copy_props.body); |
546 | RegName src_reg = last_register_; |
547 | Index src_index = GetDeviceIndex(device_copy_props.src_virtual_device); |
548 | Index dst_index = GetDeviceIndex(device_copy_props.dst_virtual_device); |
      // Since scopes are distinguished by targets (including any target hosts) but at runtime we
      // deal only with devices, the copy may be unnecessary.
551 | if (src_index != dst_index) { |
552 | Emit(Instruction::DeviceCopy(src_reg, src_index, dst_index, NewRegister())); |
553 | } |
554 | return; |
555 | } |
556 | |
557 | // Now we handle the case in which we are using an opaque operator used to define a |
558 | // sub-dialect, such as memory allocation operations. |
559 | if (call_node->op.as<OpNode>()) { |
560 | OpMatch<void> matcher; |
561 | matcher |
562 | .Match("vm.invoke_tvm_op" , |
563 | [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) { |
564 | ICHECK_EQ(args.size(), 3); |
565 | EmitInvokeTVMOp(args[0], args[1], args[2], Downcast<DictAttrs>(attrs)); |
566 | }) |
567 | .Match("memory.alloc_tensor" , |
568 | [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) { |
569 | ICHECK_EQ(args.size(), 3); |
570 | |
571 | // Get the attributes. |
572 | auto alloc_attrs = attrs.as<AllocTensorAttrs>(); |
                   ICHECK(alloc_attrs != nullptr) << "must be the alloc tensor attrs";
574 | auto dtype = alloc_attrs->dtype; |
575 | |
576 | // The storage will be passed dynamically. |
577 | this->VisitExpr(args[0]); |
578 | auto storage_register = last_register_; |
579 | |
                   // The offset will be passed dynamically.
581 | this->VisitExpr(args[1]); |
582 | auto offset_register = last_register_; |
583 | |
584 | // If the shape is constant then we will emit a static tensor allocation |
585 | // instruction. It may be wrapped by an on_device, but it will be on the host |
586 | // which is assumed by the alloc_tensor instruction anyway. |
587 | auto const_shape = AsIgnoringOnDevice<ConstantNode>(args[2]); |
588 | |
589 | if (const_shape) { |
590 | NDArray shape = const_shape->data; |
                     // TODO(@jroesch): we need to get an RFC done to standardize shape dtype
592 | std::vector<int64_t> raw_shape = ToAllocTensorShape(shape); |
593 | // Add context field. |
594 | Emit(Instruction::AllocTensor(storage_register, offset_register, raw_shape, |
595 | dtype, NewRegister())); |
596 | } else { |
597 | this->VisitExpr(args[2]); |
598 | auto shape_register = last_register_; |
599 | Emit(Instruction::AllocTensorReg(storage_register, offset_register, |
600 | shape_register, dtype, NewRegister())); |
601 | } |
602 | }) |
603 | .Match("memory.alloc_storage" , |
604 | [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) { |
605 | ICHECK_EQ(args.size(), 2); |
606 | // Compute the size of the allocation. |
607 | this->VisitExpr(args[0]); |
608 | auto size_register = last_register_; |
609 | |
610 | ICHECK(args[1].as<ConstantNode>()); // Always a literal. |
611 | NDArray alignment_arr = args[1].as<ConstantNode>()->data; |
                   ICHECK_EQ(alignment_arr->dtype.code, 0U)
                       << "The dtype of the constant alignment must be an integer, but got "
                       << DLDataType2String(alignment_arr->dtype);
615 | ICHECK_EQ(alignment_arr->dtype.bits, 64U); |
616 | Index alignment = reinterpret_cast<int64_t*>(alignment_arr->data)[0]; |
617 | |
618 | // Get the dtype hint from the attributes. |
619 | auto alloc_attrs = attrs.as<AllocStorageAttrs>(); |
                   ICHECK(alloc_attrs != nullptr) << "must be the AllocStorage attrs";
621 | auto dtype = alloc_attrs->dtype; |
622 | |
623 | Emit(Instruction::AllocStorage(size_register, alignment, dtype, |
624 | GetDeviceIndex(alloc_attrs->virtual_device), |
625 | NewRegister())); |
626 | }) |
627 | .Match("vm.shape_of" , |
628 | [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) { |
629 | ICHECK_EQ(args.size(), 1U); |
630 | // Get the attributes. |
631 | const auto* shape_of_attrs = attrs.as<ShapeOfAttrs>(); |
                   ICHECK(shape_of_attrs) << "Must be the shape_of attrs";
                   ICHECK_EQ(shape_of_attrs->dtype.bits(), 64)
                       << "The dtype of shape_of must be int64, but got "
                       << DLDataType2String(shape_of_attrs->dtype);
636 | this->VisitExpr(args[0]); |
637 | Emit(Instruction::ShapeOf(last_register_, NewRegister())); |
638 | }) |
639 | .Match("vm.reshape_tensor" , |
640 | [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) { |
641 | ICHECK_EQ(args.size(), 2u); |
642 | this->VisitExpr(args[0]); |
643 | auto tensor_reg = last_register_; |
644 | this->VisitExpr(args[1]); |
645 | auto shape_reg = last_register_; |
646 | Emit(Instruction::ReshapeTensor(tensor_reg, shape_reg, NewRegister())); |
647 | }) |
648 | .Match("memory.kill" , |
649 | [this](const Array<Expr>& args, const Attrs& attrs, const Array<Type>& type_arg) { |
650 | ICHECK_EQ(args.size(), 1u); |
651 | this->VisitExpr(args[0]); |
652 | Emit(Instruction::KillRegister(this->last_register_)); |
653 | }); |
654 | matcher(GetRef<Call>(call_node)); |
655 | return; |
656 | } |
657 | |
658 | // In the case it's not one of these specialized operators we will generate code |
659 | // for one of the "standard" cases. |
660 | std::vector<Index> args_registers; |
661 | |
662 | // Evaluate the call arguments. |
663 | for (auto arg : call_node->args) { |
664 | VisitExpr(arg); |
665 | args_registers.push_back(last_register_); |
666 | } |
667 | |
668 | if (const auto* global_var_node = call_node->op.as<GlobalVarNode>()) { |
669 | // In the case we are invoking a global we need to find its |
670 | // global ID, and then check whether it is closure invocation |
671 | // or whether it is a standard global, and emit the correct |
672 | // calling convention. |
673 | auto global = GetRef<GlobalVar>(global_var_node); |
674 | auto it = context_->global_map.find(global); |
675 | ICHECK(it != context_->global_map.end()) << PrettyPrint(global); |
676 | VLOG(2) << "VisitExpr_: generating invoke for " << global->name_hint |
677 | << " with func_index=" << it->second; |
678 | |
679 | // TODO(tvm-team): |
      // Think about mixed call into a global that is not a relay::Function;
      // perhaps establish it as an invariant (all functions in mod must be relay::Function).
682 | auto func = Downcast<Function>(context_->module->Lookup(global)); |
683 | |
684 | if (IsClosure(func)) { |
685 | auto arity = func->params.size(); |
686 | Emit(Instruction::AllocClosure(it->second, arity, args_registers, NewRegister())); |
687 | } else { |
688 | Emit(Instruction::Invoke(it->second, args_registers, NewRegister())); |
689 | } |
690 | } else if (const auto* constructor_node = call_node->op.as<ConstructorNode>()) { |
691 | // In the constructor case, we simply need to find its tag |
692 | // and emit a call to allocate the data structure. |
693 | auto constructor = GetRef<Constructor>(constructor_node); |
694 | Emit(Instruction::AllocADT(constructor->tag, call_node->args.size(), args_registers, |
695 | NewRegister())); |
696 | } else if (const auto* var_node = call_node->op.as<VarNode>()) { |
697 | // If we are calling a variable, it must be the case that it is a closure so we |
698 | // emit invoke closure here. |
699 | VisitExpr(GetRef<Var>(var_node)); |
700 | Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister())); |
701 | } else if (auto inner_call_node = call_node->op.as<CallNode>()) { |
702 | VisitExpr(GetRef<Call>(inner_call_node)); |
703 | Emit(Instruction::InvokeClosure(last_register_, args_registers, NewRegister())); |
704 | } else { |
705 | // Finally if there are any other cases this is a bug. |
706 | LOG(FATAL) << "internal error: unreachable code," |
707 | << "should be transformed away by previous passes:" << std::endl |
708 | << PrettyPrint(GetRef<Expr>(call_node)); |
709 | } |
710 | } |
711 | |
712 | void DeviceAwareVisitExpr_(const FunctionNode* func_node) final { |
713 | if (function_nesting() > 1) { |
714 | ICHECK(func_node->HasNonzeroAttr(attr::kPrimitive)) |
715 | << "local functions should have been removed by lambda lifting:" << std::endl |
716 | << "Program: " << AsText(GetRef<Function>(func_node), false) << std::endl |
717 | << "AST: " << GetRef<Function>(func_node); |
718 | return; |
719 | } |
720 | |
721 | // We're processing a top-level function which has possibly been rejigged to capture |
722 | // both closure and function arguments. Those functions retain their 'Closure' attribute, |
723 | // but we can just process them like any other function here. |
724 | |
725 | // Assign a register num to each parameter. |
726 | size_t i = 0; |
727 | for (auto param : func_node->params) { |
728 | auto arg_register = NewRegister(); |
729 | ICHECK_EQ(i, arg_register); |
730 | var_register_map_.insert({param, arg_register}); |
731 | params_.push_back(param->name_hint()); |
732 | ++i; |
733 | } |
734 | |
735 | VisitExpr(func_node->body); |
736 | |
737 | instructions_.push_back(Instruction::Ret(last_register_)); |
738 | } |
739 | |
740 | /*! |
741 | * \brief Compile a match value |
   * Generate bytecode that computes the value specified in val
743 | * |
744 | * \return The register number assigned for the final value |
745 | */ |
746 | RegName CompileMatchValue(MatchValuePtr val) { |
747 | if (std::dynamic_pointer_cast<RegisterValue>(val)) { |
748 | auto r = std::dynamic_pointer_cast<RegisterValue>(val); |
749 | return r->register_num; |
750 | } else { |
751 | auto path = std::dynamic_pointer_cast<AccessField>(val); |
752 | auto p = CompileMatchValue(path->parent); |
753 | Emit(Instruction::GetField(p, path->index, NewRegister())); |
754 | path->reg = last_register_; |
755 | return path->reg; |
756 | } |
757 | } |
758 | |
759 | void CompileTreeNode(TreeObjectPtr tree) { |
760 | if (auto node = std::dynamic_pointer_cast<TreeLeafNode>(tree)) { |
761 | VisitExpr(node->body); |
762 | } else if (std::dynamic_pointer_cast<TreeLeafFatalNode>(tree)) { |
763 | Emit(Instruction::Fatal()); |
764 | } else if (auto node = std::dynamic_pointer_cast<TreeBranchNode>(tree)) { |
765 | if (auto cond = std::dynamic_pointer_cast<TagCompare>(node->cond)) { |
        // For tag comparisons, generate branches
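        // Roughly: emit GetTag(obj) and LoadConsti(target_tag); the If falls through to
        // the then branch when they are equal and otherwise jumps past the trailing Goto
        // to the else branch; the Goto at the end of the then branch skips the else branch.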
767 | auto r = CompileMatchValue(cond->obj); |
768 | Emit(Instruction::GetTag(r, NewRegister())); |
769 | auto operand1 = last_register_; |
770 | Emit(Instruction::LoadConsti(cond->target_tag, NewRegister())); |
771 | auto operand2 = last_register_; |
772 | |
773 | Emit(Instruction::If(operand1, operand2, 1, 0)); |
774 | auto cond_offset = instructions_.size() - 1; |
775 | CompileTreeNode(node->then_branch); |
776 | auto if_reg = last_register_; |
777 | Emit(Instruction::Goto(1)); |
778 | auto goto_offset = instructions_.size() - 1; |
779 | CompileTreeNode(node->else_branch); |
780 | auto else_reg = last_register_; |
781 | Emit(Instruction::Move(else_reg, if_reg)); |
782 | last_register_ = if_reg; |
783 | auto else_offset = instructions_.size() - 1; |
784 | // Fixing offsets |
785 | instructions_[cond_offset].if_op.false_offset = goto_offset - cond_offset + 1; |
786 | instructions_[goto_offset].pc_offset = else_offset - goto_offset + 1; |
787 | } else { |
788 | // For other non-branch conditions, move to then_branch directly |
789 | auto var_bind = std::dynamic_pointer_cast<VarBinding>(node->cond); |
790 | var_register_map_[var_bind->var] = CompileMatchValue(var_bind->val); |
791 | CompileTreeNode(node->then_branch); |
792 | } |
793 | } |
794 | } |
795 | |
796 | /*! |
797 | * \brief Compile a pattern match expression |
   * It first converts the pattern match expression into a decision tree; each condition
   * is either a tag comparison or a variable binding. If any condition in a clause fails,
   * the decision tree moves on to check the conditions of the next clause, and so on. If no
   * clause matches the value, a fatal node is inserted.
   *
   * After the decision tree is built, we convert it into bytecode using If/Goto instructions.
804 | */ |
805 | void CompileMatch(Match match) { |
806 | auto data = std::make_shared<RegisterValue>(last_register_); |
807 | auto decision_tree = BuildDecisionTreeFromClauses(data, match->clauses); |
808 | CompileTreeNode(decision_tree); |
809 | } |
810 | |
811 | protected: |
812 | /*! \brief Store the expression a variable points to. */ |
813 | std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> expr_map_; |
814 | /*! \brief Instructions in the VMFunction. */ |
815 | std::vector<Instruction> instructions_; |
816 | /*! \brief Parameter names of the function. */ |
817 | std::vector<std::string> params_; |
818 | /*! \brief Map from var to register number. */ |
819 | std::unordered_map<Var, RegName, ObjectPtrHash, ObjectPtrEqual> var_register_map_; |
820 | /*! \brief Last used register number. */ |
821 | size_t last_register_; |
822 | /*! \brief Total number of virtual registers allocated. */ |
823 | size_t registers_num_; |
824 | /*! \brief Global shared meta data */ |
825 | VMCompilerContext* context_; |
826 | /*! \brief VirtualDevice for data and computation which must reside on a CPU. */ |
827 | VirtualDevice host_virtual_device_; |
828 | }; |
829 | |
830 | PackedFunc VMCompiler::GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) { |
831 | if (name == "lower" ) { |
832 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
833 | ICHECK_EQ(args.num_args, 2); |
834 | this->Lower(args[0], args[1]); |
835 | }); |
836 | } else if (name == "codegen" ) { |
837 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
838 | ICHECK_EQ(args.num_args, 0); |
839 | this->Codegen(); |
840 | }); |
841 | } else if (name == "get_executable" ) { |
842 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
843 | ICHECK_EQ(args.num_args, 0); |
844 | *rv = this->GetExecutable(); |
845 | }); |
846 | } else if (name == "set_params" ) { |
847 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
848 | Map<String, Constant> params = args[0]; |
849 | for (const auto& kv : params) { |
850 | this->SetParam(kv.first, kv.second->data); |
851 | } |
852 | }); |
853 | } else if (name == "get_params" ) { |
854 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
855 | Map<String, Constant> ret; |
856 | for (const auto& kv : params_) { |
857 | ret.Set(kv.first, Constant(kv.second)); |
858 | } |
859 | *rv = ret; |
860 | }); |
861 | } else if (name == "optimize" ) { |
862 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
863 | ICHECK_EQ(args.num_args, 2); |
864 | *rv = this->OptimizeModule(args[0], args[1]); |
865 | }); |
866 | } else { |
867 | LOG(FATAL) << "Unknown packed function: " << name; |
868 | } |
869 | } |
870 | |
871 | void VMCompiler::SetParam(const std::string& name, runtime::NDArray data_in) { |
872 | params_[name] = data_in; |
873 | } |
874 | |
875 | void VMCompiler::Lower(IRModule mod, const Array<Target>& raw_targets) { |
876 | VLOG_CONTEXT << "VM Lower" ; |
877 | Setup(raw_targets); |
878 | LowerImpl(std::move(mod)); |
879 | } |
880 | |
881 | IRModule VMCompiler::OptimizeModule(IRModule mod, const Array<Target>& raw_targets) { |
882 | VLOG_CONTEXT << "VM Optimize" ; |
883 | Setup(raw_targets); |
884 | return OptimizeModuleImpl(std::move(mod)); |
885 | } |
886 | |
887 | runtime::Module VMCompiler::GetExecutable() const { |
  if (exec_ == nullptr) {
    LOG(WARNING) << "No executable to return. Did you forget to call VMCompiler::Lower?";
  } else if (exec_->imports().empty()) {
    LOG(WARNING) << "Executable is empty. Did you forget to call VMCompiler::Codegen?";
  }
894 | return runtime::Module(exec_); |
895 | } |
896 | |
897 | void VMCompiler::Setup(const Array<Target>& raw_targets) { |
  ICHECK(exec_ == nullptr) << "Can't reuse VMCompiler object for multiple modules";
899 | exec_ = make_object<Executable>(); |
900 | ICHECK(!config_.defined()); |
901 | config_ = CompilationConfig(PassContext::Current(), raw_targets); |
902 | VLOG(1) << "Using compilation config:" << std::endl << config_; |
903 | |
904 | // The first device is always for the host. |
905 | CHECK(context_.virtual_devices_.empty()); |
906 | VLOG(1) << "virtual_device[0] = " << config_->host_virtual_device << " (host)" ; |
907 | context_.virtual_devices_.push_back(config_->host_virtual_device); |
908 | } |
909 | |
910 | void VMCompiler::LowerImpl(IRModule mod) { |
911 | // Run the optimizations necessary to target the VM. |
912 | context_.module = OptimizeModuleImpl(std::move(mod)); |
913 | |
914 | // Build the map from global variables bound to Functions to a global index in the |
915 | // VMFunction table. |
916 | size_t num_functions = PopulateGlobalMap(); |
917 | |
918 | // Next we get ready by allocating space for |
919 | // the global state. |
920 | exec_->functions.resize(num_functions); |
921 | |
922 | for (const auto& pair : context_.module->functions) { |
923 | auto gvar = pair.first; |
924 | if (auto* n = pair.second.as<FunctionNode>()) { |
925 | if (n->HasNonzeroAttr(attr::kExtern)) { |
926 | // Already compiled during lowering. |
927 | continue; |
928 | } |
929 | auto func = GetRef<Function>(n); |
930 | VMFunctionCompiler func_compiler(&context_, config_->host_virtual_device); |
931 | auto vm_func = func_compiler.Compile(gvar, func); |
932 | |
933 | size_t func_index = context_.global_map.at(gvar); |
934 | ICHECK(func_index < exec_->functions.size()); |
935 | exec_->functions[func_index] = vm_func; |
936 | |
937 | // update structural hashes for tvm ops |
938 | for (auto p : func_compiler.op_attrs) { |
939 | exec_->op_attrs.insert(p); |
940 | } |
941 | } |
942 | } |
943 | |
944 | // Populate virtual devices and the host device index. |
945 | for (const auto& virtual_device : context_.virtual_devices_) { |
946 | ICHECK(!virtual_device->IsFullyUnconstrained()); |
947 | ICHECK_GT(virtual_device->device_type(), 0); |
948 | // TODO(mbs): We forget the memory scope. |
949 | exec_->virtual_devices.push_back(Device{/*device_type=*/virtual_device->device_type(), |
950 | /*device_id=*/virtual_device->virtual_device_id}); |
951 | } |
952 | exec_->host_device_index = kHostDeviceIndex; |
953 | |
954 | // populate constants |
955 | for (const auto& data : context_.constants) { |
956 | exec_->constants.push_back(data); |
957 | } |
958 | |
959 | for (auto index : context_.const_device_indexes) { |
960 | exec_->const_device_indexes.push_back(index); |
961 | } |
962 | |
963 | // update global function map |
964 | for (const auto& gv : context_.global_map) { |
965 | exec_->global_map.insert({gv.first->name_hint, gv.second}); |
966 | } |
967 | |
968 | // update primitive function map |
969 | for (const auto& pair : context_.primitive_map) { |
970 | exec_->primitive_map.insert(pair); |
971 | } |
972 | |
973 | VLOG(1) << "Compiled to:" << std::endl |
974 | << "-------------------------------------------------" << std::endl |
975 | << exec_->GetVirtualDevices() // |
976 | << exec_->GetConstants() // |
977 | << exec_->GetPrimitives() // |
978 | << exec_->GetBytecode() // |
979 | << "-------------------------------------------------" ; |
980 | |
981 | if (backend::IsAutoSchedulerEnabled()) { |
982 | backend::UpdateAutoSchedulerOpWeights(context_.module); |
983 | } |
984 | } |
985 | |
986 | transform::Sequential VMCompiler::MemoryOpt(const CompilationConfig& config) { |
987 | Array<Pass> pass_seqs; |
988 | // Remove unused functions |
  Array<runtime::String> entry_functions{"main"};
990 | pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); |
991 | // Manifest the allocations. |
992 | pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); |
993 | |
994 | // Compute away possibly introduced constant computation. |
995 | pass_seqs.push_back(transform::FoldConstant()); |
996 | |
997 | // Fuse & lower any new shape functions and device_copies. |
998 | pass_seqs.push_back(FuseAndLowerOperators(config)); |
999 | |
1000 | // Manifest the allocations needed for the shape functions. |
1001 | pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); |
1002 | |
1003 | // Fuse & lower any new allocations. |
1004 | pass_seqs.push_back(FuseAndLowerOperators(config)); |
1005 | |
  // TODO(mbrookhart, jroesch, masahi): this pass is very slow, and its memory
  // reuse optimizations are incomplete. Disable it until we can rewrite it
  // in C++ and complete it.
1009 | // // Perform memory planning in order to coalesce/reduce allocations. |
1010 | // pass_seqs.push_back(transform::MemoryPlan()); |
1011 | |
1012 | // Compute away constant computation introduced by coalescing allocations. |
1013 | pass_seqs.push_back(transform::FoldConstant()); |
1014 | |
1015 | // Fuse & lower yet again |
1016 | pass_seqs.push_back(FuseAndLowerOperators(config)); |
1017 | |
1018 | // Create allocations for math introduced by dynamic region math. |
1019 | pass_seqs.push_back(transform::ManifestAlloc(config->host_virtual_device)); |
1020 | |
1021 | // Compute away possibly introduced constant computation. |
1022 | pass_seqs.push_back(transform::FoldConstant()); |
1023 | |
1024 | // Insert kills to free memory. |
1025 | pass_seqs.push_back(transform::ManifestLifetimes()); |
1026 | |
1027 | // Lift constants to the top-level of the block to simplify VM code generation. |
  // TODO(@icemelon9, @jroesch): This pass is disabled for now because some
  //                             instructions need direct access to constants.
1030 | // pass_seqs.push_back(transform::LiftConstants()); |
1031 | |
1032 | return transform::Sequential(std::move(pass_seqs)); |
1033 | } |
1034 | |
1035 | transform::Sequential VMCompiler::FuseAndLowerOperators(const CompilationConfig& config) { |
1036 | Array<Pass> pass_seqs; |
1037 | // Hoist operators to "primitive" Functions. |
1038 | pass_seqs.push_back(FuseOps()); |
1039 | // Give each "primitive" Function a hash. |
1040 | pass_seqs.push_back(LabelOps()); |
1041 | // Lower "primitive" Functions to PrimFuncs and rewrite calls. |
  pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config, [this](const BaseFunc& func) {
1043 | if (func->GetAttr<String>(attr::kCompiler).defined()) { |
1044 | backend::UpdateConstants(func, ¶ms_); |
1045 | } |
1046 | })); |
1047 | // Since lowered functions are bound in the IRModule, we can now eliminate any unused |
1048 | // let-bound functions. |
1049 | pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); |
1050 | return transform::Sequential(std::move(pass_seqs)); |
1051 | } |
1052 | |
1053 | IRModule VMCompiler::OptimizeModuleImpl(IRModule mod) { |
1054 | backend::BindParamsInModule(mod, params_); |
1055 | Array<Pass> pass_seqs = relay::backend::GetPassPrefix( |
1056 | /*is_homogeneous=*/config_->optional_homogeneous_target.defined(), /*is_vm=*/true); |
1057 | |
1058 | // Always plan devices so the remaining passes don't need to distinguish homogeneous vs |
1059 | // heterogeneous execution. |
1060 | pass_seqs.push_back(transform::PlanDevices(config_)); |
1061 | |
1062 | pass_seqs.push_back(transform::FuseOps()); |
1063 | |
1064 | // Do layout rewrite for auto-scheduler. |
1065 | transform::PassContext pass_ctx = PassContext::Current(); |
1066 | if (backend::IsAutoSchedulerEnabled() && config_->optional_homogeneous_target.defined()) { |
1067 | Pass major_pass = transform::AutoSchedulerLayoutRewrite(); |
1068 | bool enable_layout_rewrite_targets = |
1069 | config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU || |
        config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
1071 | if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) { |
1072 | With<Target> tctx(config_->optional_homogeneous_target); |
1073 | pass_seqs.push_back(major_pass); |
1074 | // Defuse ops to fold constants, then fuse them again |
1075 | pass_seqs.push_back(transform::DefuseOps()); |
1076 | pass_seqs.push_back(transform::FoldConstant()); |
1077 | pass_seqs.push_back(transform::FuseOps()); |
1078 | } |
1079 | } |
1080 | if (backend::IsMetaScheduleEnabled() && config_->optional_homogeneous_target.defined()) { |
1081 | Pass major_pass = transform::MetaScheduleLayoutRewrite(); |
1082 | bool enable_layout_rewrite_targets = |
1083 | config_->optional_homogeneous_target->GetTargetDeviceType() == kDLCPU || |
        config_->optional_homogeneous_target->GetAttr<String>("device", "") == "mali";
1085 | if (enable_layout_rewrite_targets && pass_ctx.PassEnabled(major_pass->Info())) { |
1086 | With<Target> tctx(config_->optional_homogeneous_target); |
1087 | pass_seqs.push_back(major_pass); |
1088 | // Defuse ops to fold constants, then fuse them again |
1089 | pass_seqs.push_back(transform::DefuseOps()); |
1090 | pass_seqs.push_back(transform::FoldConstant()); |
1091 | pass_seqs.push_back(transform::FuseOps()); |
1092 | } |
1093 | } |
1094 | |
1095 | pass_seqs.push_back(transform::ToANormalForm()); |
1096 | pass_seqs.push_back(transform::InferType()); |
1097 | pass_seqs.push_back(transform::LambdaLift()); |
1098 | |
1099 | // Eliminate dead-code before we lower. We don't track the purity of PrimFuncs, thus after |
1100 | // lowering all calls to lowered functions will be kept. |
1101 | pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); |
1102 | pass_seqs.push_back(transform::LabelOps()); |
1103 | |
1104 | // Lower all functions annotated as "primitive" by FuseOps. |
  pass_seqs.push_back(tec::LowerTE(/*module_name=*/"vm_mod", config_, [this](const BaseFunc& func) {
1106 | if (func->GetAttr<String>(attr::kCompiler).defined()) { |
1107 | backend::UpdateConstants(func, ¶ms_); |
1108 | } |
1109 | })); |
1110 | |
1111 | // Since lowered functions are bound in the IRModule, we can now eliminate any unused |
1112 | // let-bound functions. |
1113 | pass_seqs.push_back(DeadCodeElimination(/*inline_once=*/false)); |
1114 | |
1115 | // At this point it's possible to run PlanDevices again to pick up any additional constraints |
1116 | // introduced during lowering. However we'll not do this until more testing has been done. |
1117 | |
1118 | // Inline the functions that are lifted to the module scope. We perform this |
1119 | // pass after all other optimization passes but before the memory allocation |
1120 | // pass. This is because memory allocation pass will insert `invoke_tvm_op` |
1121 | // and we use these ops to invoke the symbols in the module generated by |
1122 | // external codegen. |
1123 | pass_seqs.push_back(transform::Inline()); |
1124 | |
1125 | pass_seqs.push_back(MemoryOpt(config_)); |
1126 | pass_seqs.push_back(transform::InferType()); |
1127 | |
1128 | transform::Sequential seq(pass_seqs); |
1129 | tvm::With<relay::transform::PassContext> ctx(pass_ctx); |
1130 | if (config_->optional_homogeneous_target.defined()) { |
1131 | With<Target> tctx(config_->optional_homogeneous_target); |
1132 | return seq(std::move(mod)); |
1133 | } else { |
1134 | return seq(std::move(mod)); |
1135 | } |
1136 | } |
1137 | |
1138 | size_t VMCompiler::PopulateGlobalMap() { |
1139 | // Allocate a VMFunction index for every Relay Function we could call. |
1140 | // Excludes PrimFuncs and externs, which are managed by the primitive_map_. |
1141 | for (const auto& kv : context_.module->functions) { |
1142 | if (const auto* function_node = kv.second.as<FunctionNode>()) { |
1143 | if (!function_node->HasNonzeroAttr(attr::kExtern)) { |
1144 | context_.global_map.emplace(kv.first, context_.global_map.size()); |
1145 | } |
1146 | } |
1147 | } |
1148 | return context_.global_map.size(); |
1149 | } |
1150 | |
1151 | void VMCompiler::Codegen() { |
1152 | VLOG_CONTEXT << "VM Codegen" ; |
1153 | if (!context_.module.defined()) { |
1154 | LOG(WARNING) << "No compiled module to codegen from. Did you forget to call VMCompiler::Lower?" ; |
1155 | return; |
1156 | } |
1157 | |
1158 | // At this point context_.module will contain only: |
1159 | // - non-external Relay functions, which we've compiled into VMFunctions. |
1160 | // - external Relay functions, which will have definitions within some external runtime module |
1161 | // in the "external_mods" attribute |
1162 | // - PrimFuncs annotated with their targets. |
1163 | // Only the PrimFuncs will appear in per_target_modules, and there may legitimately be none. |
1164 | Map<Target, IRModule> per_tvm_target_modules = tec::GetPerTargetModules(context_.module); |
1165 | for (const auto& kv : per_tvm_target_modules) { |
1166 | ICHECK(kv.first->GetTargetDeviceType() != kDLExtDev); |
1167 | } |
1168 | |
1169 | // Retrieve all external runtime modules accumulated by external codegen (both function-at-a-time |
1170 | // and IRModule-at-a-time). |
1171 | Array<runtime::Module> external_mods = |
1172 | context_.module->GetAttr<Array<runtime::Module>>(tvm::attr::kExternalMods).value_or({}); |
1173 | |
1174 | // Retrieve any constant bindings accumulated by external codegen (by IRModule-at-a-time passes). |
1175 | Map<String, runtime::NDArray> const_name_to_constant = |
1176 | context_.module->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant) |
1177 | .value_or({}); |
1178 | |
1179 | VLOG(0) << "have " << per_tvm_target_modules.size() << " targets to build, " |
1180 | << external_mods.size() << " external runtime modules, " << const_name_to_constant.size() |
1181 | << " external constants, and " << params_.size() << " local constants" ; |
1182 | |
1183 | // Any constant bindings must be merged into the overall 'params' map we've directly accumulated |
1184 | // via the TECompiler callback. |
1185 | for (const auto& kv : const_name_to_constant) { |
1186 | ICHECK_EQ(params_.count(kv.first), 0); |
1187 | params_.emplace(kv.first, kv.second); |
1188 | } |
1189 | |
1190 | runtime::Module lib; |
1191 | if (per_tvm_target_modules.empty()) { |
    // There is no function handled by TVM. We create a virtual main module
    // to make sure a DSO module will also be available.
    LOG(INFO) << "All lowered functions have been built by BYOC -- generating an empty TVM module";
    lib = codegen::CSourceModuleCreate(";", "", Array<String>{});
1196 | } else { |
1197 | lib = tvm::TIRToRuntime(per_tvm_target_modules, config_->host_target); |
1198 | } |
1199 | |
1200 | lib = |
1201 | codegen::CreateMetadataModule(params_, lib, external_mods, config_->host_target, |
1202 | Runtime::Create("cpp" ), Executor::Create("graph" ), // DNS HACK |
1203 | relay::backend::ExecutorCodegenMetadata()); |
1204 | exec_->SetLib(lib); |
1205 | } |
1206 | |
1207 | runtime::Module CreateVMCompiler() { |
1208 | auto exec = make_object<VMCompiler>(); |
1209 | return runtime::Module(std::move(exec)); |
1210 | } |
1211 | |
1212 | TVM_REGISTER_GLOBAL("relay._vm._VMCompiler" ).set_body_typed(CreateVMCompiler); |
1213 | |
1214 | } // namespace vm |
1215 | } // namespace relay |
1216 | } // namespace tvm |
1217 | |