1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/backend/aot/aot_lower_main.cc
22 * \brief Lower the Relay main func into an AOT TIR main func.
23 */
24#include "./aot_lower_main.h"
25
26#include <tvm/runtime/name_transforms.h>
27#include <tvm/tir/builtin.h>
28#include <tvm/tir/transform.h>
29
30#include "../../op/call/call.h"
31#include "../../op/memory/device_copy.h"
32#include "../../op/memory/memory.h"
33#include "../../transforms/device_aware_visitors.h"
34#include "../name_transforms.h"
35#include "../utils.h"
36
37namespace tvm {
38namespace relay {
39namespace backend {
40namespace aot {
41
42/*!
43 * \brief Looks at the expressions in a given function and produces an Expr to
44 * StorageInfo map by assigning one or more StorageInfos to the expressions that
45 * require storage.
46 *
47 * This pass is leveraged by AOTMainLowerer to perform an initial naive allocation
48 * for tensors in the Relay main function. The resulting storage map is then lowered
49 * into TIR allocations by AOTMainLowerer where the allocation can be subsequently
50 * optimized by later passes (e.g. USMP).
51 */
class ExprAllocator : public transform::DeviceAwareExprVisitor {
 public:
  ExprAllocator() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}

  /*! \brief Run the visitor on a global function, assigning storage to its expressions. */
  void Run(const Function& func) { VisitExpr(func); }

  /*! \brief Returns the storage IDs that hold the function's return value(s). */
  std::vector<int> GetReturnSIDs() const { return return_sids_; }

  /*! \brief Returns the Expr -> StorageInfo map built during the traversal. */
  StorageMap GetStorageMap() const { return expr_storage_map_; }

  using ExprVisitor::VisitExpr_;

  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
    Array<Expr> args;

    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
    if (call_lowered_props.lowered_func.defined()) {
      args = call_lowered_props.arguments;
    } else {  // Relay functions that have not been lowered and lowered extern functions
      args = call_node->args;
      if (call_node->op.as<GlobalVarNode>()) {  // Lowered extern function
        ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
      } else {  // Relay function which has not been lowered yet
        ICHECK(call_node->op.as<FunctionNode>())
            << "Expected the call to be to a lowered primfunc, a lowered extern function or a "
               "unlowered Relay function.";
      }
    }
    // Allocate storage for the call's result, then visit the arguments so they
    // get storage as well.
    CreateStorage(call_node);
    for (const Expr& arg : args) {
      VisitExpr(arg);
    }
    // The last-visited expression's SIDs become the return SIDs (see AssignReturnSID).
    AssignReturnSID(GetRef<Expr>(call_node));
  }

  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
    if (function_nesting() > 1) {
      // Do not recurse into sub functions.
      return;
    }
    // Allocate storage for each of the top-level function's parameters.
    for (const auto& param : func_node->params) {
      CreateStorage(param.get());
    }
    VisitExpr(func_node->body);
  }

  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
    // A let-bound var shares the storage of its bound value rather than
    // getting a fresh allocation.
    VisitExpr(value);
    StorageInfo si = GetStorage(value);
    expr_storage_map_[var] = si;
  }

  void VisitExpr_(const ConstantNode* op) final {
    CreateStorage(op);
    AssignReturnSID(GetRef<Expr>(op));
  }

  // Vars are allocated where they are bound (params or let bindings), so only
  // the return-SID bookkeeping is needed here.
  void VisitExpr_(const VarNode* op) final { AssignReturnSID(GetRef<Expr>(op)); }

  void VisitExpr_(const TupleNode* op) final {
    // A tuple's StorageInfo is the concatenation of its fields' StorageInfos
    // (no extra storage is allocated for the tuple itself).
    std::vector<int64_t> storage_ids;
    std::vector<VirtualDevice> virtual_devices;
    std::vector<int64_t> storage_sizes_in_bytes;
    Expr expr = GetRef<Expr>(op);
    for (Expr field : op->fields) {
      auto sid = GetStorage(field);
      storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
      virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(),
                             sid->virtual_devices.end());
      storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
                                    sid->storage_sizes_in_bytes.begin(),
                                    sid->storage_sizes_in_bytes.end());
    }
    expr_storage_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes);
    AssignReturnSID(expr);
  }

  void VisitExpr_(const TupleGetItemNode* op) final {
    // Project out the single entry of the tuple's StorageInfo at `index`.
    Expr expr = GetRef<Expr>(op);
    auto sids = GetStorage(op->tuple);
    ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
    expr_storage_map_[expr] =
        StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]},
                    {sids->storage_sizes_in_bytes[op->index]});
    AssignReturnSID(expr);
  }

  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "'If' is not supported."; }

 private:
  /*!
   * \brief Assign the expression's storage IDs as the return storage IDs.
   * \note This is called when visiting every expression on the understanding
   * that the returned expression will be visited last.
   */
  void AssignReturnSID(const Expr& e) {
    if (expr_storage_map_.find(e) != expr_storage_map_.end()) {
      StorageInfo& sinfo = expr_storage_map_[e];
      return_sids_.clear();
      for (auto sid : sinfo->storage_ids) {
        return_sids_.push_back(sid);
      }
    }
  }

  /*!
   * \brief Get the necessary storage for the expression.
   * \param expr The expression.
   * \return The corresponding token.
   */
  StorageInfo GetStorage(const Expr& expr) {
    // See through "on_device" calls.
    Expr true_expr = IgnoreOnDevice(expr);
    VisitExpr(true_expr);
    auto it = expr_storage_map_.find(true_expr);
    ICHECK(it != expr_storage_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " "
                                          << PrettyPrint(true_expr) << " in storage device map";
    return it->second;
  }

  /*!
   * \brief Create storage for the expression using its inferred virtual device.
   */
  void CreateStorage(const ExprNode* op) {
    Expr expr = GetRef<Expr>(op);
    return CreateStorage(expr, GetVirtualDevice(expr));
  }

  /*!
   * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device.
   */
  void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) {
    ICHECK(!virtual_device->IsFullyUnconstrained())
        << "invalid virtual device for expr:" << std::endl
        << PrettyPrint(expr);
    // One storage entry per tensor leaf of the (possibly tuple-typed) result.
    std::vector<int64_t> storage_ids;
    std::vector<VirtualDevice> virtual_devices;
    std::vector<int64_t> storage_sizes_in_bytes;
    for (const auto& ttype : FlattenTupleType(expr->checked_type())) {
      storage_ids.push_back(next_available_sid_++);
      virtual_devices.push_back(virtual_device);
      storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype->shape, ttype->dtype));
    }
    expr_storage_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices),
                                          std::move(storage_sizes_in_bytes));
  }

  /*! \brief Map between Exprs and StorageInfos */
  StorageMap expr_storage_map_;
  /*! \brief The next available storage ID to be used */
  int next_available_sid_{0};
  /*! \brief The storage IDs that correspond to return values */
  std::vector<int> return_sids_;
};
207
208std::tuple<StorageMap, std::vector<int>> CreateStorage(const Function& func) {
209 ExprAllocator expr_allocator;
210 expr_allocator.Run(func);
211 return std::make_tuple(expr_allocator.GetStorageMap(), expr_allocator.GetReturnSIDs());
212}
213
214class AOTMainLowerer : public MixedModeVisitor {
215 public:
216 AOTMainLowerer(tvm::CompilationConfig config, CallType call_type)
217 : config_(config), call_type_(call_type) {}
218
219 IRModule Lower(IRModule mod, String mod_name) {
220 VLOG_CONTEXT << "AOT";
221 IRModule lowered_mod = GetRef<IRModule>(mod.CopyOnWrite());
222
223 auto lowered_main = lowered_mod->Lookup("main");
224 auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());
225
226 // Assign StorageInfo to all the Relay exprs and get the return SIDs
227 std::tie(expr_storage_map_, return_sid_) = CreateStorage(lowered_main_func);
228
229 for (auto input : lowered_main_func->params) {
230 input_vars_.push_back(input);
231 std::string input_name = tvm::runtime::SanitizeName(input->name_hint());
232 // We don't want the compiler changing input names in the
233 // event of a sanitization collision. Therefore, enforcing
234 // the var created to use the input_name strictly.
235 CreateIOVar(input, input_name, /*use_unique_name = */ false);
236 }
237
238 // Define the storage allocator ids
239 for (auto kv : expr_storage_map_) {
240 for (auto sid : kv.second->storage_ids) {
241 // The buffer_var is created with storage_scope to be global.workspace to be serviced by
242 // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor
243 // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and
244 // should not be lowered to the stack. For more details please refer to the discussion here:
245 // https://github.com/apache/tvm/issues/9022
246 tir::Var buffer_var(MakeString("sid_", sid),
247 PointerType(PrimType(DataType::Int(8)), "global.workspace"));
248 sids_table_[sid] = buffer_var;
249 }
250 }
251
252 // Create output vars for the TIR main func
253 // If output tensor names were provided use them
254 if (auto opt = lowered_main->GetAttr<Array<String>>("output_tensor_names")) {
255 Array<String> output_tensor_names = opt.value();
256 Expr output_expr = lowered_main_func->body;
257 if (output_expr->checked_type()->IsInstance<TupleTypeNode>()) {
258 TupleType output_tuple_type = Downcast<TupleType>(output_expr->checked_type());
259 for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) {
260 // AoT Executor Codegen does not create these names,
261 // thus should be used as they are provided.
262 CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i],
263 /*use_unique_name = */ false);
264 }
265 } else {
266 // AoT Executor Codegen does not create these names,
267 // thus should be used as they are provided.
268 CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false);
269 }
270 } else {
271 // If output tensor names are not provided we will generate output(x)
272 // where x is a counter to create unique names.
273 if (lowered_main_func->body->checked_type()->IsInstance<TupleTypeNode>()) {
274 CreateIOVar(lowered_main_func->body, "output");
275 } else {
276 CreateIOVar(lowered_main_func->body, "output", /*use_unique_name = */ false);
277 }
278 }
279
280 CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts")
281 .value_or(Map<GlobalVar, String>()));
282 VisitExpr(lowered_main_func->body);
283
284 // Remove the Relay main and replace it with the lowered TIR version
285 lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));
286 auto tir_main_func = CreateMainFunc(mod_name);
287 lowered_mod->Update(GlobalVar(runtime::symbol::tvm_module_main), tir_main_func);
288 lowered_mod = tir::transform::RemoveNoOp()(lowered_mod);
289 return lowered_mod;
290 }
291
292 void VisitExpr_(const CallNode* call_node) override {
293 OnDeviceProps on_device_props = GetOnDeviceProps(call_node);
294 if (on_device_props.body.defined()) {
295 VisitExpr(on_device_props.body);
296 return;
297 }
298
299 DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
300 CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
301
302 if (device_copy_props.body.defined()) {
303 // TODO(mbs): device_copy cleaunp
304 // Suspect treating as no-op is better since already built into the StorageInfo?
305 LOG(FATAL) << "The AOT executor does not currently support device_copy";
306 }
307
308 // At this point we should only see calls of the form call_lowered(@callee, (args...)),
309 // where @callee can be a PrimFunc we've compiled or an external function supplied via
310 // some other mechanism.
311 ICHECK(call_lowered_props.lowered_func.defined())
312 << "AOT does not support calling Relay functions. Attempting to call:" << std::endl
313 << PrettyPrint(GetRef<Call>(call_node));
314 for (const auto& arg : call_lowered_props.arguments) {
315 // Evaluate the args
316 VisitExpr(arg);
317 }
318 CreateFuncCall(call_lowered_props, GetRef<Call>(call_node));
319 }
320
321 void VisitExpr_(const VarNode* op) override {
322 Expr expr = GetRef<Expr>(op);
323 StorageInfo& sinfo = expr_storage_map_[expr];
324
325 // Let bound vars refer to a value, so these should not be considered "output" vars.
326 if (let_bound_vars_.find(GetRef<Var>(op)) != let_bound_vars_.end()) {
327 return;
328 }
329
330 // If the Var node is an output node we need to copy the content of the variable to the output
331 // It's safe to check the SID here because Var StorageToken are never reallocated
332 auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
333 if (output_iter != return_sid_.end()) {
334 int output_index = std::distance(return_sid_.begin(), output_iter);
335 auto var_expr = FindExpr(expr);
336 CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0],
337 /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]);
338 }
339 }
340
341 void VisitExpr_(const ConstantNode* op) override {
342 Expr expr = GetRef<Expr>(op);
343 ICHECK(expr_storage_map_.find(expr) != expr_storage_map_.end())
344 << "Storage map did not contain constant expr " << PrettyPrint(expr);
345 StorageInfo& sinfo = expr_storage_map_[expr];
346 std::stringstream ss;
347 ss << "constant_" << constant_map_.size();
348
349 tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype))));
350 constant_map_[constant] = op;
351 auto sid = sinfo->storage_ids[0];
352 sids_table_[sid] = constant;
353
354 // If the Constant node is an output node we need to copy the content of the parameter to the
355 // output. A node can only produce a single output
356 auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid);
357 if (output_iter != return_sid_.end()) {
358 int output_index = std::distance(return_sid_.begin(), output_iter);
359 auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
360 {tir::StringImm(ss.str())});
361 CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), constant,
362 /* pack_input */ false, sinfo->storage_sizes_in_bytes[0]);
363 }
364 }
365
366 void VisitExpr_(const TupleNode* op) override {
367 for (auto field : op->fields) {
368 VisitExpr(field);
369 }
370 }
371
372 void VisitExpr_(const LetNode* op) override {
373 auto pre_visit = [this](const LetNode* op) {
374 let_bound_vars_.insert(op->var);
375 this->VisitExpr(op->value);
376 };
377 auto post_visit = [this](const LetNode* op) {
378 this->VisitExpr(op->body);
379 this->visit_counter_[op] += 1;
380 };
381 ExpandANormalForm(op, pre_visit, post_visit);
382 }
383
384 void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); }
385 void VisitExpr_(const OpNode* op) override {
386 if (GetRef<Op>(op) != CallLoweredOp() && GetRef<Op>(op) != OnDeviceOp()) {
387 LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded";
388 }
389 }
390 void VisitExpr_(const IfNode* op) override {
391 LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called";
392 }
393 void VisitExpr_(const FunctionNode* op) override {
394 ICHECK(op->GetAttr<String>(attr::kCompiler).defined())
395 << "FunctionNode only supported by custom codegen";
396 }
397 void VisitExpr_(const RefCreateNode* op) override {
398 LOG(FATAL) << "AOT executor does not support references (found RefCreateNode)";
399 }
400 void VisitExpr_(const RefReadNode* op) override {
401 LOG(FATAL) << "AOT executor does not support references (found RefReadNode)";
402 }
403 void VisitExpr_(const RefWriteNode* op) override {
404 LOG(FATAL) << "AOT executor does not support references (found RefWriteNode)";
405 }
406 void VisitExpr_(const ConstructorNode* op) override {
407 LOG(FATAL) << "AOT executor does not support ADTs (found ConstructorNode)";
408 }
409 void VisitExpr_(const MatchNode* op) override {
410 LOG(FATAL) << "AOT executor does not support matching (found MatchNode)";
411 }
412
413 private:
414 /*!
415 * \brief Create the main PrimFunc to execute the graph.
416 * \note The packed function calls don't pack their arguments. The AOT
417 * runner function needs to be legalized by the LegalizePackedCalls pass.
418 */
419 tir::PrimFunc CreateMainFunc(String mod_name) {
420 tir::Stmt body = tir::SeqStmt(stmts_);
421 // Allocate the sids
422 std::unordered_map<int, bool> allocated;
423 std::vector<std::pair<int64_t, int64_t>> sids_to_allocate;
424
425 for (auto kv : expr_storage_map_) {
426 // Only allocate sids that are needed
427 const bool is_input =
428 (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end());
429 if (is_input) {
430 continue;
431 }
432
433 for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) {
434 sids_to_allocate.push_back(
435 std::make_pair(kv.second->storage_ids[i], kv.second->storage_sizes_in_bytes[i]));
436 }
437 }
438
439 // Sort the SID allocation to make output deterministic
440 std::sort(sids_to_allocate.begin(), sids_to_allocate.end());
441
442 for (auto p : sids_to_allocate) {
443 int sid = p.first;
444 int size = p.second;
445
446 if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
447 continue;
448 }
449
450 // Make sure it hasn't already been allocated, this can happen
451 // with let-bound var/value pairs.
452 if (allocated.find(sid) != allocated.end()) {
453 continue;
454 }
455
456 allocated[sid] = constant_map_.count(sids_table_[sid]);
457
458 // TODO(giuseros): we should allocate this once outside the PrimFunc
459 // so we don't pay the price of allocation for every inference
460 if (!allocated[sid]) {
461 PointerType ptype = Downcast<PointerType>(sids_table_[sid]->type_annotation);
462 DataType element_type = Downcast<PrimType>(ptype->element_type)->dtype;
463 body = tir::Allocate(sids_table_[sid], element_type, {size}, tir::const_true(), body);
464 }
465 allocated[sid] = true;
466 }
467
468 for (auto kv : constant_map_) {
469 auto buffer_var = kv.first;
470 auto dtype = DataType(kv.second->data->dtype);
471
472 int ndim = kv.second->data->ndim;
473 Array<PrimExpr> extents;
474
475 for (int i = 0; i < ndim; i++) {
476 int shape = kv.second->data->shape[i];
477 extents.push_back(tir::make_const(DataType::Int(32), shape, Span()));
478 }
479 body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body);
480 }
481
482 // Define the PrimFunc attributes
483 Map<String, ObjectRef> dict_attrs;
484 String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_module_main);
485 dict_attrs.Set("global_symbol", run_func_name);
486 dict_attrs.Set("runner_function", Bool(true));
487 dict_attrs.Set(tvm::attr::kTarget, config_->host_target);
488 Array<tir::Var> input_vars =
489 Array<tir::Var>(main_signature_.begin(), main_signature_.begin() + input_vars_.size());
490 dict_attrs.Set("input_vars", input_vars);
491 Array<tir::Var> output_vars =
492 Array<tir::Var>(main_signature_.begin() + input_vars_.size(),
493 main_signature_.begin() + input_vars_.size() + return_sid_.size());
494 dict_attrs.Set("output_vars", output_vars);
495 Array<String> device_names;
496 for (const auto& it : devices_) {
497 device_names.push_back(it.first);
498 }
499 dict_attrs.Set("devices", device_names);
500
501 tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
502 tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
503 tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});
504
505 // Make the PrimFunc
506 return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_,
507 DictAttrs(dict_attrs));
508 }
509
510 /*!
511 * \brief Collects device context variables for passing to operators
512 */
513 void CollectDeviceVariables(const Map<GlobalVar, String>& device_contexts) {
514 Map<TargetKind, tir::Var> target_contexts;
515 TargetKindAttrMap<Bool> target_attr_map = tvm::TargetKind::GetAttrMap<Bool>("use_device_api");
516
517 for (const auto& it : device_contexts) {
518 const GlobalVar& global_var = it.first;
519 const std::string device_context_name = it.second;
520
521 Optional<TargetKind> target_kind = tvm::TargetKind::Get(device_context_name);
522 if (!target_kind || !target_attr_map.count(target_kind.value())) {
523 return;
524 }
525 if (target_attr_map[target_kind.value()]) {
526 std::string context_name = tvm::runtime::SanitizeName(device_context_name);
527 tir::Var device_context_var("device_context_" + context_name, DataType::Handle());
528
529 auto pair = target_contexts.find(target_kind.value());
530 if (pair != target_contexts.end()) {
531 device_context_var = (*pair).second;
532 } else {
533 main_signature_.push_back(device_context_var);
534 devices_.Set(context_name, device_context_var);
535 target_contexts.Set(target_kind.value(), device_context_var);
536 }
537
538 device_contexts_.Set(global_var, device_context_var);
539 }
540 }
541 }
542
543 /*!
544 * \brief Return a vector of variables that represents the sids for the given Relay Expr
545 */
546 std::vector<tir::Var> PackSid(Expr expr) {
547 std::vector<tir::Var> buffer_vars;
548
549 ICHECK(expr_storage_map_.find(expr) != expr_storage_map_.end())
550 << "Storage map did not contain constant expr " << PrettyPrint(expr);
551 StorageInfo& sinfo = expr_storage_map_[expr];
552
553 // Note that an expression can have multiple sids associated with it
554 // e.g., returning multiple values from a function
555 for (auto sid : sinfo->storage_ids) {
556 // Determine if an sid is an output buffer
557 auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid);
558 if (output_iter != return_sid_.end()) {
559 int output_index = std::distance(return_sid_.begin(), output_iter);
560 buffer_vars.push_back(GetBufferVarForIO(input_vars_.size() + output_index));
561 continue;
562 }
563
564 auto sid_value = sids_table_[sid];
565 buffer_vars.push_back(sid_value);
566 }
567 return buffer_vars;
568 }
569
570 /*!
571 * \brief Given an expression return the variable(s) associated with that expression
572 */
573 std::vector<tir::Var> FindExpr(Expr arg) {
574 auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg);
575 if (input_iter != input_vars_.end()) {
576 // Input variable
577 int main_index = std::distance(input_vars_.begin(), input_iter);
578 return {GetBufferVarForIO(main_index)};
579 } else {
580 // Storage identifier (i.e., intermediate memory)
581 return PackSid(arg);
582 }
583 }
584
585 void PushArgs(const Expr& expr, const std::vector<tir::Var>& sids, Array<PrimExpr>* args) {
586 const TupleNode* t = expr.as<TupleNode>();
587 if (t != nullptr) {
588 CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't "
589 "handle this type of Relay Expr in a CallNode.";
590 }
591
592 args->insert(args->end(), sids.begin(), sids.end());
593 }
594
595 /*!
596 * \brief Wraps a call_extern with a tvm_check_return annotation if required otherwise
597 * returns the passed Call
598 */
599 tir::Call AddCheckReturn(tir::Call existing_call) {
600 Array<PrimExpr> args = {tir::make_const(DataType::Int(32, 1), 0, Span()),
601 tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call};
602 return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args);
603 }
604
605 /*!
606 * \brief Create a function call
607 * \param call_lowered_props The lowered function and the arguments to call it with
608 * \param result_expr The call we got func and args from (so as to recover the storage
609 * ids to hold the result).
610 */
611 void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) {
612 std::string func_name = call_lowered_props.lowered_func->name_hint;
613 tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
614 std::vector<tir::Stmt> create_func_call_stmts;
615
616 // Pack the inputs
617 for (const Expr& arg : call_lowered_props.arguments) {
618 auto sids = FindExpr(arg);
619 PushArgs(arg, sids, &args);
620 }
621
622 // Pack the return(s) value. A call node can produce multiple outputs
623 auto result_expr_sid = PackSid(result_expr);
624 PushArgs(result_expr, result_expr_sid, &args);
625
626 GlobalVar global_var = call_lowered_props.lowered_func;
627 bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
628 tir::Var device_context;
629 tir::Stmt func_call;
630
631 switch (call_type_) {
632 case CallType::kUnpacked: {
633 // call_extern calling convention with optional context
634 if (has_c_device_api_context) {
635 device_context = device_contexts_.Get(global_var).value();
636 args.push_back(device_context);
637 }
638 func_call = tir::Evaluate(AddCheckReturn(
639 tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args)));
640 break;
641 }
642 case CallType::kCPacked: {
643 if (has_c_device_api_context) {
644 device_context = device_contexts_.Get(global_var).value();
645 args.push_back(device_context);
646 } else {
647 // NOTE: LowerTVMBuiltin expects some device_context placeholder.
648 args.push_back(tir::make_zero(DataType::Handle()));
649 }
650 func_call = tir::Evaluate(
651 tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args));
652 create_func_call_stmts.push_back(func_call);
653 break;
654 }
655 case CallType::kPacked: {
656 // call_packed does not accept a device context.
657 CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context";
658 func_call = tir::Evaluate(AddCheckReturn(
659 tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args)));
660 create_func_call_stmts.push_back(func_call);
661 break;
662 }
663 default:
664 ICHECK(false) << "Unknown CallType: " << call_type_;
665 }
666
667 ICHECK(func_call.defined()) << "Must define func_call";
668
669 if (has_c_device_api_context) {
670 func_call = tir::SeqStmt(Array<tir::Stmt>({
671 GenerateDeviceHook(device_context, "Open"),
672 func_call,
673 GenerateDeviceHook(device_context, "Close"),
674 }));
675 }
676
677 tir::Stmt body = tir::SeqStmt({func_call});
678 stmts_.push_back(body);
679 }
680
681 /*!
682 * \brief Copy a variable to the output. This function is mainly used in edge cases
683 * when we want to return an input or a parameter.
684 * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a
685 * copy-on-write fashion.
686 */
687 void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
688 // Define intermediate DLTensor to load/store the data
689 tir::Buffer tmp_read =
690 tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read");
691 tir::Buffer tmp_write =
692 tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write");
693 te::Var loop_idx("i", DataType::Int(32));
694 auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});
695 // Copy the variable from the input to the output
696 tir::Stmt copy = tir::For(
697 loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
698 tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
699 stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
700 }
701
702 /*!
703 * \brief Generates a call to a given hook for all Devices found for C Device API
704 * \param hook Name of hook to generate statements for
705 * \return Statement with function calls for each device
706 */
707 tir::Stmt GenerateAllDeviceHook(const String& hook) {
708 std::vector<tir::Stmt> device_hooks;
709 for (const auto& it : devices_) {
710 const String& device_name = it.first;
711 const tir::Var& context = it.second;
712 Array<String> sections = {"Device", device_name, hook};
713 String device_hook_name = ToCFunctionStyle(PrefixName(sections));
714
715 tir::Evaluate device_hook(
716 AddCheckReturn(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
717 {tvm::tir::StringImm(device_hook_name), context})));
718 device_hooks.push_back(device_hook);
719 }
720 return tir::SeqStmt(device_hooks);
721 }
722
723 /*!
724 * \brief Generates a call to a given hook for a single Device function
725 * \param context Device context to call hook on
726 * \param hook Name of hook to generate statements for
727 * \return Statement with function call to Device API
728 */
729 tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) {
730 const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) {
731 return it.second->name_hint == context->name_hint;
732 });
733 const String& device_name = (*it).first;
734 Array<String> sections = {"Device", device_name, hook};
735 String device_hook = ToCFunctionStyle(PrefixName(sections));
736
737 return tir::Evaluate(
738 AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
739 {tvm::tir::StringImm(device_hook), context})));
740 }
741
742 /*!
743 * \brief Utility function to string together different arguments
744 */
745 template <typename... Args>
746 std::string MakeString(Args const&... args) {
747 std::ostringstream ss;
748 using List = int[];
749 (void)List{0, ((void)(ss << args), 0)...};
750
751 return ss.str();
752 }
753
754 /*!
755 * \brief Access IO vars using the buffer vars and
756 * not the actual var.
757 */
758 tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; }
759
760 /*!
761 * \brief Create tir::Var for input/output while updating the buffer_maps.
762 * \param expr The expression to evaluate.
763 * \param original_name The name of the tir::Var.
764 * \param use_unique_name Whether to generate a new unique name where a name conflicts.
765 */
766 void CreateIOVar(const Expr& expr, const std::string& original_name,
767 bool use_unique_name = true) {
768 CreateIOVar(expr->checked_type(), original_name, use_unique_name);
769 }
770
771 /*!
772 * \brief Create tir::Var for input/output while updating the buffer_maps.
773 * \param expr The expression to evaluate.
774 * \param original_name The name of the tir::Var.
775 * \param use_unique_name Whether to generate a new unique name where a name conflicts.
776 */
777 void CreateIOVar(const Type& type, const std::string& original_name,
778 bool use_unique_name = true) {
779 if (type->IsInstance<TupleTypeNode>()) {
780 TupleType tuple_type = Downcast<TupleType>(type);
781 for (unsigned i = 0; i < tuple_type->fields.size(); i++) {
782 CreateIOVar(tuple_type->fields[i], original_name);
783 }
784 } else {
785 std::string name = original_name;
786 if (use_unique_name) {
787 name = GetUniqueIOVarName(original_name);
788 }
789 tir::Var var = tir::Var(name, DataType::Handle());
790 main_signature_.push_back(var);
791 auto tensor_type = type.as<TensorTypeNode>();
792 ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey();
793 DataType elem_type = tensor_type->dtype;
794 tir::Var buffer_var =
795 tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type), "global"));
796 tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0,
797 name + "_buffer", 16, 1, tir::BufferType::kDefault);
798 main_buffer_map_.Set(var, buffer);
799 }
800 }
801
802 /*!
803 * \brief Create a unique name for I/O Var
804 */
805 std::string GetUniqueIOVarName(std::string name) {
806 if (io_var_names_.find(name) == io_var_names_.end()) {
807 io_var_names_[name] = 1;
808 return name + std::to_string(io_var_names_[name] - 1);
809 } else {
810 io_var_names_[name] = io_var_names_[name] + 1;
811 return name + std::to_string(io_var_names_[name] - 1);
812 }
813 }
814
815 /*! \brief list of input expressions (i.e., variable passed by the user) */
816 std::vector<Var> input_vars_;
817 /*! \brief map of device contexts variables */
818 Map<String, tir::Var> devices_;
819 /*! \brief map of GlobalVars to C Device API contexts */
820 Map<GlobalVar, tir::Var> device_contexts_;
821 /*! \brief input and output variables belonging to the main function signature */
822 Array<tir::Var> main_signature_;
823 /*! \brief input and output variables belonging to the main function signature */
824 Map<tir::Var, tir::Buffer> main_buffer_map_;
825 /*! \brief All available targets. */
826 CompilationConfig config_;
827 /*!
828 * \brief The type of kernel call to be emitted.
829 * See CallType for more documentation.
830 */
831 CallType call_type_;
832 std::unordered_map<const tir::Var, const ConstantNode*, ObjectPtrHash, ObjectPtrEqual>
833 constant_map_;
834 /*! \brief plan memory of device result */
835 StorageMap expr_storage_map_;
836 /*! \brief mapping sid -> tir::Var */
837 std::unordered_map<int, tir::Var> sids_table_;
838 /*! \brief the set of statements that make the program */
839 std::vector<tir::Stmt> stmts_;
840 /*! \brief the list of return sids (note that the function might return more then one output */
841 std::vector<int> return_sid_;
842 /*! \brief This is per IO var name counter to aid the generating unique names */
843 std::unordered_map<std::string, int> io_var_names_;
844 /*! \brief A set of variables that are let bound. */
845 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_vars_;
846};
847
848Pass AOTLowerMain(String mod_name, tvm::CompilationConfig config, CallType call_type) {
849 runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func =
850 [=](IRModule module, transform::PassContext ctx) {
851 return AOTMainLowerer(config, call_type).Lower(module, mod_name);
852 };
853
854 return tvm::transform::CreateModulePass(pass_func, 0, "AOTLowerMain", {"InferType"});
855}
856
// Expose the pass to the Python frontend; call_type arrives as a plain int and
// is reinterpreted as the CallType enum.
TVM_REGISTER_GLOBAL("relay.backend.aot.AOTLowerMain")
    .set_body_typed([](const String& mod_name, const tvm::CompilationConfig& config,
                       int call_type) {
      return AOTLowerMain(mod_name, config, static_cast<CallType>(call_type));
    });
862
863} // namespace aot
864} // namespace backend
865} // namespace relay
866} // namespace tvm
867