/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file src/relay/backend/aot_executor_codegen.cc
 * \brief AOT executor codegen
 */

#include <tvm/ir/module.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/executor.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/runtime.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/name_transforms.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt.h>
#include <tvm/tir/transform.h>
#include <tvm/tir/usmp/utils.h>

#include <algorithm>
#include <list>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../target/source/codegen_source_base.h"
#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/device_copy.h"
#include "../transforms/device_aware_visitors.h"
#include "./name_transforms.h"
#include "./te_compiler.h"
#include "./utils.h"

namespace tvm {
namespace relay {
namespace backend {

using StorageMap =
    std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;

/*!
 * \brief An on-demand allocator for AOT. A new temporary
 * (storage allocator identifier) is allocated for each operation.
 */
class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
 public:
  AOTOnDemandAllocator() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}

  // Run the visitor on a global function.
  void Run(const Function& func) { VisitExpr(func); }

  std::vector<int> GetReturnIds() const { return return_ids_; }
  std::vector<TensorType> GetReturnTtypes() const { return return_ttypes_; }

  StorageMap GetStorageMap() const { return storage_device_map_; }

  using ExprVisitor::VisitExpr_;

  void VisitExpr_(const ConstantNode* op) final {
    CreateStorage(op);
    AssignReturnSid(GetRef<Expr>(op));
  }

  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
    // AOTOnDemandAllocator is run both before and after lowering, so we need to handle the case
    // where the op of the call is a generic function.

    Expr func;
    Array<Expr> args;

    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
    if (call_lowered_props.lowered_func.defined()) {
      func = call_lowered_props.lowered_func;
      args = call_lowered_props.arguments;
    } else {  // Relay functions that have not been lowered and lowered extern functions
      func = call_node->op;
      args = call_node->args;
      if (call_node->op.as<GlobalVarNode>()) {  // Lowered extern function
        ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
      } else {  // Relay function which has not been lowered yet
        ICHECK(call_node->op.as<FunctionNode>())
            << "Expected the call to be to a lowered primfunc, a lowered extern function or an "
               "unlowered Relay function.";
      }
    }
    VisitExpr(func);
    CreateStorage(call_node);
    for (const Expr& arg : args) {
      VisitExpr(arg);
    }
    AssignReturnSid(GetRef<Expr>(call_node));
  }

  void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef<Expr>(op)); }

  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
    if (function_nesting() > 1) {
      // Do not recurse into sub functions.
      return;
    }
    if (func_node->HasNonzeroAttr(attr::kPrimitive)) {
      // No storage needed for primitive functions.
      return;
    }
    for (const auto& param : func_node->params) {
      CreateStorage(param.get());
    }
    VisitExpr(func_node->body);
  }

  void VisitExpr_(const GlobalVarNode* op) final {
    // Do nothing.
  }

  void VisitExpr_(const OpNode* op) final {
    // Do nothing.
  }

  void VisitExpr_(const TupleNode* op) final {
    std::vector<int64_t> storage_ids;
    std::vector<VirtualDevice> virtual_devices;
    std::vector<int64_t> storage_sizes_in_bytes;
    Expr expr = GetRef<Expr>(op);
    for (Expr field : op->fields) {
      auto sid = GetStorage(field);
      storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
      virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(),
                             sid->virtual_devices.end());
      storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
                                    sid->storage_sizes_in_bytes.begin(),
                                    sid->storage_sizes_in_bytes.end());
    }
    storage_device_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes);
    AssignReturnSid(expr);
  }

  void VisitExpr_(const TupleGetItemNode* op) final {
    Expr expr = GetRef<Expr>(op);
    auto sids = GetStorage(op->tuple);
    ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
    storage_device_map_[expr] =
        StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]},
                    {sids->storage_sizes_in_bytes[op->index]});
    AssignReturnSid(expr);
  }

  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }

  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
    VisitExpr(value);
    StorageInfo si = GetStorage(value);
    storage_device_map_[var] = si;
  }

 private:
  void AssignReturnSid(Expr e) {
    if (storage_device_map_.find(e) != storage_device_map_.end()) {
      StorageInfo& sinfo = storage_device_map_[e];
      return_ids_.clear();
      for (auto sid : sinfo->storage_ids) {
        return_ids_.push_back(sid);
      }
      return_ttypes_.clear();
      return_ttypes_ = FlattenTupleType(e->checked_type());
    }
  }
  /*!
   * \brief ceil(size/word_size) to get number of words.
   * \param size The original size.
   * \param word_size The word size.
   * \return The number of words needed to hold \p size.
   */
  static size_t DivRoundUp(size_t size, size_t word_size) {
    return (size + word_size - 1) / word_size;
  }
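  // For example, DivRoundUp(10, 4) == 3: ten bytes occupy three 4-byte words.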
  /*!
   * \brief Get the memory requirement.
   * \param ttype The tensor type whose memory requirement is computed.
   * \return The required memory size in bytes.
   *
   * TODO(mbs): Cf CalculateRelayExprSizeBytes in utils.cc, GetMemorySize in graph_plan_memory.cc
   */
  size_t GetMemorySizeBytes(const TensorType& ttype) {
    size_t size = 1;
    for (IndexExpr dim : ttype->shape) {
      const int64_t* pval = tir::as_const_int(dim);
      ICHECK(pval != nullptr) << "Cannot allocate memory for symbolic tensor shape "
                              << ttype->shape;
      ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative dimension " << *pval;
      size *= static_cast<size_t>(pval[0]);
    }
    size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
    return size;
  }
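  // Illustrative example: GetMemorySizeBytes for a float32 tensor of shape (2, 3) yields
  // 2 * 3 * DivRoundUp(32 * 1, 8) = 24 bytes.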
  /*!
   * \brief Get the necessary storage for the expression.
   * \param expr The expression.
   * \return The corresponding storage info.
   */
  StorageInfo GetStorage(const Expr& expr) {
    // See through "on_device" calls.
    Expr true_expr = IgnoreOnDevice(expr);
    VisitExpr(true_expr);
    auto it = storage_device_map_.find(true_expr);
    ICHECK(it != storage_device_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " "
                                            << PrettyPrint(true_expr) << " in storage device map";
    return it->second;
  }

  /*!
   * \brief Create storage for the expression.
   */
  void CreateStorage(const ExprNode* op) {
    Expr expr = GetRef<Expr>(op);
    return CreateStorage(expr, GetVirtualDevice(expr));
  }

  /*!
   * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device.
   */
  void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) {
    ICHECK(!virtual_device->IsFullyUnconstrained())
        << "invalid virtual device for expr:" << std::endl
        << PrettyPrint(expr);
    std::vector<int64_t> storage_ids;
    std::vector<VirtualDevice> virtual_devices;
    std::vector<int64_t> storage_sizes_in_bytes;
    for (const auto& ttype : FlattenTupleType(expr->checked_type())) {
      storage_ids.push_back(next_available_sid_++);
      virtual_devices.push_back(virtual_device);
      storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype));
    }
    storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices),
                                            std::move(storage_sizes_in_bytes));
  }
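  // Illustrative example: an expr of type (Tensor[(2, 2), float32], Tensor[(4,), int8])
  // flattens to two tensor types and receives two fresh sids with sizes 16 and 4 bytes.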

  /*! \brief mapping of expression -> StorageInfo */
  StorageMap storage_device_map_;
  /*! \brief current id of the temporary allocated */
  int next_available_sid_{0};
  /*! \brief the set of intermediate tensors that are return variables */
  std::vector<int> return_ids_;
  /*! \brief the data types of the return values */
  std::vector<TensorType> return_ttypes_;
};
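
// Typical usage (a minimal sketch; mirrors how Codegen() below drives the allocator):
//   AOTOnDemandAllocator allocator;
//   allocator.Run(main_func);                        // assign sids to every sub-expression
//   StorageMap storage = allocator.GetStorageMap();  // expr -> StorageInfo
//   std::vector<int> return_sids = allocator.GetReturnIds();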

/*! \brief Code generator for AOT executor */
class AOTExecutorCodegen : public MixedModeVisitor {
 protected:
  /*! \brief Describes the type of kernel call emitted. */
  enum CallType {
    /*!
     * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions.
     *
     * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the
     * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those
     * functions are of type TVMBackendPackedCFunc.
     *
     * The following code is emitted at call sites to call a function named `func`:
     *   void* func_ptr = TVMBackendGetFuncFromEnv("func");
     *   TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes)
     *
     * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and
     * `num_args` by the LowerTVMBuiltin TIR transform.
     *
     * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often,
     * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when
     * `func` is implemented in C).
     *
     * Compatible with both C++ and C runtimes, implemented with the C runtime only.
     */
    kPacked,  // Emit tir.call_packed and wrap all arguments in DLTensor.

    /*!
     * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call.
     *
     * When this type is selected, assumes all operators are implemented in functions of type
     * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of
     * downstream compilation that there is a symbol named after the 0th arg to tir::Call of
     * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target.
     *
     * The following code is emitted at call sites to call a function named `func`:
     *   func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle)
     *
     * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and
     * `num_args` by the LowerTVMBuiltin TIR transform.
     *
     * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it
     * is always the device context parameter when not null. At present, the implementation does
     * not support forwarding device context parameters to CPacked.
     *
     * Compatible with the C runtime and C++ runtime (so long as target_host == target).
     * Implemented in the same scenarios.
     */
    kCPacked,  // Emit tir.call_cpacked and wrap all arguments in DLTensor.

    /*! \brief Directly call a function accepting the `data` arrays as args.
     *
     * When this type is selected, assumes all operators are implemented in C functions whose
     * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just
     * the `data` parameters (i.e. no DLTensor object is passed along).
     *
     * The following code is emitted at call sites to a function named `func`:
     *   func(void* arg0, void* arg1, ..., void* argN)  // no resource_handle
     *   -or-
     *   func(void* arg0, void* arg1, ..., void* argN, void* resource_handle)  // with handle
     *
     * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it
     * is always the device context parameter when not null.
     *
     * Compatible with the C runtime and C++ runtime (so long as target_host == target).
     * Implemented with the C runtime only.
     */
    kUnpacked,  // Emit tir.call_extern passing only the `data` part of DLTensors.
  };
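
  // To make the three conventions concrete, a call to a lowered function "fused_add" with one
  // input sid and one output would be emitted roughly as follows (illustrative sketch only):
  //   kPacked:   tvm_call_packed("fused_add", sid_1, output)    // later wrapped in TVMValues
  //   kCPacked:  tvm_call_cpacked("fused_add", sid_1, output, resource_handle)
  //   kUnpacked: call_extern("fused_add", sid_1, output)        // raw data pointers only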

  /*!
   * \brief Return a vector of variables that represents the sids for the given Relay Expr
   */
  std::vector<tir::Var> PackSid(Expr expr) {
    std::vector<tir::Var> buffer_vars;

    ICHECK(storage_device_map_.find(expr) != storage_device_map_.end())
        << "Storage map did not contain expr " << PrettyPrint(expr);
    StorageInfo& sinfo = storage_device_map_[expr];

    // Note that an expression can have multiple sids associated with it,
    // e.g., when returning multiple values from a function.
    for (auto sid : sinfo->storage_ids) {
      // Determine if an sid is an output buffer.
      auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid);
      if (output_iter != return_sid_.end()) {
        int output_index = std::distance(return_sid_.begin(), output_iter);
        buffer_vars.push_back(GetBufferVarForIO(input_vars_.size() + output_index));
        continue;
      }

      auto sid_value = sids_table_[sid];
      buffer_vars.push_back(sid_value);
    }
    return buffer_vars;
  }
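  // Illustrative example: if `expr` carries sids {3, 4} and sid 4 is one of the return sids,
  // the result is {sid_3, <output buffer var>}, the latter taken from the main signature.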

  /*!
   * \brief Given an expression, return the variable(s) associated with that expression.
   */
  std::vector<te::Var> FindExpr(Expr arg) {
    auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg);
    if (input_iter != input_vars_.end()) {
      // Input variable
      int main_index = std::distance(input_vars_.begin(), input_iter);
      return {GetBufferVarForIO(main_index)};
    } else {
      // Storage identifier (i.e., intermediate memory)
      return PackSid(arg);
    }
  }

  /*!
   * \brief Reverse lookup the device name in devices_ map.
   * \param device_context Value in devices_ to find.
   * \return Key matching device_context in devices_.
   */
  std::string FindDeviceName(tir::Var device_context) {
    for (std::pair<String, tir::Var> kv : devices_) {
      if (kv.second->name_hint == device_context->name_hint) {
        return kv.first;
      }
    }
    ICHECK(false) << "Did not find a device name associated with " << device_context;
    return "";
  }

  void PushArgs(const Expr& expr, const std::vector<tir::Var>& sids, Array<PrimExpr>* args) {
    const TupleNode* t = expr.as<TupleNode>();
    if (t != nullptr) {
      CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't "
                                                 "handle this type of Relay Expr in a CallNode.";
    }

    args->insert(args->end(), sids.begin(), sids.end());
  }

  /*!
   * \brief Wraps a call_extern with a tvm_check_return annotation if required, otherwise
   * returns the passed Call.
   */
  tir::Call AddCheckReturn(tir::Call existing_call) {
    Array<PrimExpr> args = {tir::make_const(DataType::Int(32, 1), 0, Span()),
                            tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call};
    return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args);
  }
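  // Illustrative example: AddCheckReturn(call_extern("foo", ...)) produces
  //   tvm_check_return(0, -1, call_extern("foo", ...))
  // i.e. the call is expected to return 0, and -1 is returned to the caller on mismatch.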

  /*!
   * \brief Create a function call
   * \param call_lowered_props The lowered function and the arguments to call it with
   * \param result_expr The call we got func and args from (so as to recover the storage
   * ids to hold the result).
   */
  void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) {
    std::string func_name = call_lowered_props.lowered_func->name_hint;
    tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
    std::vector<tir::Stmt> create_func_call_stmts;

    // Pack the inputs
    for (const Expr& arg : call_lowered_props.arguments) {
      if (params_by_expr_.find(arg) != params_by_expr_.end()) {
        auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
                                           {tir::StringImm(params_by_expr_[arg])});
        // NOTE: this cast looks like a no-op, but is required for compilation downstream.
        // Because DataType::Handle has default bits=64, but CodeGenC does not observe this field,
        // adding this cast forces the codegen to insert the cast. In this case, a cast is required
        // because param_handle is actually code-generated as `const void*`, and the `const` piece
        // needs to be removed.
        args.push_back(tvm::tir::Cast(DataType::Handle(32, 1), param_handle));
      } else {
        auto sids = FindExpr(arg);
        PushArgs(arg, sids, &args);
      }
    }

    // Pack the return(s) value. A call node can produce multiple outputs.
    auto result_expr_sid = PackSid(result_expr);
    PushArgs(result_expr, result_expr_sid, &args);

    GlobalVar global_var = call_lowered_props.lowered_func;
    bool has_c_device_api_context = device_contexts_.count(global_var) != 0;
    tir::Var device_context;
    tir::Stmt func_call;

    switch (call_type_) {
      case CallType::kUnpacked: {
        // call_extern calling convention with optional context
        if (has_c_device_api_context) {
          device_context = device_contexts_.Get(global_var).value();
          args.push_back(device_context);
        }
        func_call = tir::Evaluate(AddCheckReturn(
            tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args)));
        break;
      }
      case CallType::kCPacked: {
        if (has_c_device_api_context) {
          device_context = device_contexts_.Get(global_var).value();
          args.push_back(device_context);
        } else {
          // NOTE: LowerTVMBuiltin expects some device_context placeholder.
          args.push_back(tir::make_zero(DataType::Handle()));
        }
        func_call = tir::Evaluate(
            tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args));
        create_func_call_stmts.push_back(func_call);
        break;
      }
      case CallType::kPacked: {
        // call_packed does not accept a device context.
        CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context";
        func_call = tir::Evaluate(AddCheckReturn(
            tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args)));
        create_func_call_stmts.push_back(func_call);
        break;
      }
      default:
        ICHECK(false) << "Unknown CallType: " << call_type_;
    }

    ICHECK(func_call.defined()) << "Must define func_call";

    if (has_c_device_api_context) {
      func_call = tir::SeqStmt(Array<tir::Stmt>({
          GenerateDeviceHook(device_context, "Open"),
          func_call,
          GenerateDeviceHook(device_context, "Close"),
      }));
    }

    tir::Stmt body = tir::SeqStmt({func_call});
    stmts_.push_back(body);
  }
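  // Illustrative sketch of the statements produced for kUnpacked when `func` uses the C Device
  // API (the concrete hook names depend on the device name; see GenerateDeviceHook):
  //   <device Open hook>(device_context);
  //   tvm_check_return(0, -1, call_extern("func", args..., device_context));
  //   <device Close hook>(device_context);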

  /*!
   * \brief Copy a variable to the output. This function is mainly used in edge cases
   * when we want to return an input or a parameter.
   * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a
   * copy-on-write fashion.
   */
  void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
    // Define intermediate DLTensor to load/store the data
    tir::Buffer tmp_read =
        tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read");
    tir::Buffer tmp_write =
        tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write");
    te::Var loop_idx("i", DataType::Int(32));
    auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});
    // Copy the variable from the input to the output
    tir::Stmt copy = tir::For(
        loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
        tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
    stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
  }
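  // The emitted TIR amounts to a byte-wise copy loop (illustrative pseudo-TIR):
  //   let tmp_write = out in
  //     for (i, 0, size) {
  //       tmp_write[i] = (let tmp_read = in in tmp_read[i])
  //     }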

  /*!
   * \brief Collects device context variables for passing to operators
   */
  void CollectDeviceVariables(const Map<GlobalVar, String>& device_contexts) {
    Map<TargetKind, tir::Var> target_contexts;
    TargetKindAttrMap<Bool> target_attr_map = tvm::TargetKind::GetAttrMap<Bool>("use_device_api");

    for (const auto& it : device_contexts) {
      const GlobalVar& global_var = it.first;
      const std::string device_context_name = it.second;

      Optional<TargetKind> target_kind = tvm::TargetKind::Get(device_context_name);
      if (!target_kind || !target_attr_map.count(target_kind.value())) {
        return;
      }
      if (target_attr_map[target_kind.value()]) {
        std::string context_name = tvm::runtime::SanitizeName(device_context_name);
        tir::Var device_context_var("device_context_" + context_name, DataType::Handle());

        auto pair = target_contexts.find(target_kind.value());
        if (pair != target_contexts.end()) {
          device_context_var = (*pair).second;
        } else {
          main_signature_.push_back(device_context_var);
          devices_.Set(context_name, device_context_var);
          target_contexts.Set(target_kind.value(), device_context_var);
        }

        device_contexts_.Set(global_var, device_context_var);
      }
    }
  }

  /*!
   * \brief Generates a call to a given hook for all Devices found for C Device API
   * \param hook Name of hook to generate statements for
   * \return Statement with function calls for each device
   */
  tir::Stmt GenerateAllDeviceHook(const String& hook) {
    std::vector<tir::Stmt> device_hooks;
    for (const auto& it : devices_) {
      const String& device_name = it.first;
      const tir::Var& context = it.second;
      Array<String> sections = {"Device", device_name, hook};
      String device_hook_name = ToCFunctionStyle(PrefixName(sections));

      tir::Evaluate device_hook(
          AddCheckReturn(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
                                        {tvm::tir::StringImm(device_hook_name), context})));
      device_hooks.push_back(device_hook);
    }
    return tir::SeqStmt(device_hooks);
  }
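  // Illustrative example: a device named "ethos_u" with hook "Activate" produces a call to
  // call_extern("TVMDeviceEthosUActivate", device_context_ethos_u), wrapped in
  // tvm_check_return via AddCheckReturn.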

  /*!
   * \brief Generates a call to a given hook for a single Device function
   * \param context Device context to call hook on
   * \param hook Name of hook to generate statements for
   * \return Statement with function call to Device API
   */
  tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) {
    const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) {
      return it.second->name_hint == context->name_hint;
    });
    const String& device_name = (*it).first;
    Array<String> sections = {"Device", device_name, hook};
    String device_hook = ToCFunctionStyle(PrefixName(sections));

    return tir::Evaluate(
        AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(),
                                 {tvm::tir::StringImm(device_hook), context})));
  }

  /*!
   * \brief Utility function to string together different arguments
   */
  template <typename... Args>
  std::string MakeString(Args const&... args) {
    std::ostringstream ss;
    using List = int[];
    (void)List{0, ((void)(ss << args), 0)...};

    return ss.str();
  }
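  // For example, MakeString("sid_", 3) == "sid_3"; the pack expansion streams each argument
  // into the ostringstream in order.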

  void VisitExpr_(const CallNode* call_node) override {
    OnDeviceProps on_device_props = GetOnDeviceProps(call_node);
    if (on_device_props.body.defined()) {
      VisitExpr(on_device_props.body);
      return;
    }

    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

    if (device_copy_props.body.defined()) {
      // TODO(mbs): device_copy cleanup
      // Suspect treating as no-op is better since already built into the StorageInfo?
      LOG(FATAL) << "The AOT executor does not currently support device_copy";
    }

    // At this point we should only see calls of the form call_lowered(@callee, (args...)),
    // where @callee can be a PrimFunc we've compiled or an external function supplied via
    // some other mechanism.
    ICHECK(call_lowered_props.lowered_func.defined())
        << "AOT does not support calling Relay functions. Attempting to call:" << std::endl
        << PrettyPrint(GetRef<Call>(call_node));
    for (const auto& arg : call_lowered_props.arguments) {
      // Evaluate the args
      VisitExpr(arg);
    }
    CreateFuncCall(call_lowered_props, GetRef<Call>(call_node));
  }

  void VisitExpr_(const VarNode* op) override {
    Expr expr = GetRef<Expr>(op);
    StorageInfo& sinfo = storage_device_map_[expr];

    // Let-bound vars refer to a value, so these should not be considered "output" vars.
    if (let_bound_vars_.find(GetRef<Var>(op)) != let_bound_vars_.end()) {
      return;
    }

    // If the Var node is an output node, we need to copy the content of the variable to the
    // output. It's safe to check the SID here because Var StorageTokens are never reallocated.
    auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
    if (output_iter != return_sid_.end()) {
      int output_index = std::distance(return_sid_.begin(), output_iter);
      if (params_by_expr_.find(expr) != params_by_expr_.end()) {
        auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
                                           {tir::StringImm(params_by_expr_[expr])});
        CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle,
                     /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]);
      } else {
        auto var_expr = FindExpr(expr);
        CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0],
                     /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]);
      }
    }
  }

  void VisitExpr_(const ConstantNode* op) override {
    Expr expr = GetRef<Expr>(op);
    ICHECK(storage_device_map_.find(expr) != storage_device_map_.end())
        << "Storage map did not contain constant expr " << PrettyPrint(expr);
    StorageInfo& sinfo = storage_device_map_[expr];
    std::stringstream ss;
    ss << "constant_" << constant_map_.size();

    tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype))));
    constant_map_[constant] = op;
    auto sid = sinfo->storage_ids[0];
    sids_table_[sid] = constant;

    // If the Constant node is an output node, we need to copy the content of the parameter to the
    // output. A node can only produce a single output.
    auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid);
    if (output_iter != return_sid_.end()) {
      int output_index = std::distance(return_sid_.begin(), output_iter);
      auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
                                         {tir::StringImm(ss.str())});
      CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), constant,
                   /* pack_input */ false, sinfo->storage_sizes_in_bytes[0]);
    }
  }

  void VisitExpr_(const TupleNode* op) override {
    for (auto field : op->fields) {
      VisitExpr(field);
    }
  }

  void VisitExpr_(const LetNode* op) override {
    auto pre_visit = [this](const LetNode* op) {
      let_bound_vars_.insert(op->var);
      this->VisitExpr(op->value);
    };
    auto post_visit = [this](const LetNode* op) {
      this->VisitExpr(op->body);
      this->visit_counter_[op] += 1;
    };
    ExpandANormalForm(op, pre_visit, post_visit);
  }

  void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); }
  void VisitExpr_(const OpNode* op) override {
    if (GetRef<Op>(op) != CallLoweredOp() && GetRef<Op>(op) != OnDeviceOp()) {
      LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded";
    }
  }
  void VisitExpr_(const IfNode* op) override {
    LOG(FATAL) << "The AOT executor does not support If-then-else (found IfNode)";
  }
  void VisitExpr_(const FunctionNode* op) override {
    ICHECK(op->GetAttr<String>(attr::kCompiler).defined())
        << "FunctionNode only supported by custom codegen";
  }
  void VisitExpr_(const RefCreateNode* op) override {
    LOG(FATAL) << "AOT executor does not support references (found RefCreateNode)";
  }
  void VisitExpr_(const RefReadNode* op) override {
    LOG(FATAL) << "AOT executor does not support references (found RefReadNode)";
  }
  void VisitExpr_(const RefWriteNode* op) override {
    LOG(FATAL) << "AOT executor does not support references (found RefWriteNode)";
  }
  void VisitExpr_(const ConstructorNode* op) override {
    LOG(FATAL) << "AOT executor does not support ADTs (found ConstructorNode)";
  }
  void VisitExpr_(const MatchNode* op) override {
    LOG(FATAL) << "AOT executor does not support matching (found MatchNode)";
  }

  // Create the main PrimFunc to execute the graph. Please note that
  // the packed function calls don't pack their arguments. The AOT
  // runner function needs to be legalized by the LegalizePackedCalls pass.
  tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) {
    tir::Stmt body = tir::SeqStmt(stmts_);
    // Allocate the sids
    std::unordered_map<int, bool> allocated;

    for (auto kv : storage_device_map_) {
      // Only allocate sids that are needed
      const bool is_input =
          (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end());
      const bool is_param = (params_by_expr_.find(kv.first) != params_by_expr_.end());
      if (is_input || is_param) {
        continue;
      }

      for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) {
        int size = kv.second->storage_sizes_in_bytes[i];
        int sid = kv.second->storage_ids[i];

        if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
          continue;
        }

        // Make sure it hasn't already been allocated; this can happen
        // with let-bound var/value pairs.
        if (allocated.find(sid) != allocated.end()) {
          continue;
        }

        allocated[sid] = constant_map_.count(sids_table_[sid]);

        // TODO(giuseros): we should allocate this once outside the PrimFunc
        // so we don't pay the price of allocation for every inference
        if (!allocated[sid]) {
          PointerType ptype = Downcast<PointerType>(sids_table_[sid]->type_annotation);
          DataType element_type = Downcast<PrimType>(ptype->element_type)->dtype;
          body = tir::Allocate(sids_table_[sid], element_type, {size}, tir::const_true(), body);
        }
        allocated[sid] = true;
      }
    }

    for (auto kv : constant_map_) {
      auto buffer_var = kv.first;
      auto dtype = DataType(kv.second->data->dtype);

      int ndim = kv.second->data->ndim;
      Array<PrimExpr> extents;

      for (int i = 0; i < ndim; i++) {
        int shape = kv.second->data->shape[i];
        extents.push_back(tir::make_const(DataType::Int(32), shape, Span()));
      }
      body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body);
    }

    // Define the PrimFunc attributes
    Map<String, ObjectRef> dict_attrs;
    String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_module_main);
    dict_attrs.Set("global_symbol", run_func_name);
    dict_attrs.Set("runner_function", Bool(true));
    dict_attrs.Set(tvm::attr::kTarget, config_->host_target);

    tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
    tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
    tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});

    // Make the PrimFunc
    return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_,
                         DictAttrs(dict_attrs));
  }
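  // Shape of the generated function, as an illustrative sketch (names are placeholders):
  //   PrimFunc <mod_name>___tvm_main__(inputs..., outputs..., device contexts...) {
  //     <Activate hooks for all devices>
  //     allocate_const constant_0, ...;  // embedded parameters
  //     allocate sid_1, sid_2, ...;      // intermediates not tied to IO or constants
  //     <operator call statements collected in stmts_>
  //     <Deactivate hooks for all devices>
  //   }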

  /*!
   * \brief Access IO vars using the buffer vars and
   * not the actual var.
   */
  tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; }

  /*!
   * \brief Create tir::Var for input/output while updating the buffer_maps.
   *
   * \param expr The expression to evaluate.
   * \param original_name The name of the tir::Var.
   * \param use_unique_name Whether to generate a new unique name where a name conflicts.
   */
  void CreateIOVar(const Expr& expr, const std::string& original_name,
                   bool use_unique_name = true) {
    CreateIOVar(expr->checked_type(), original_name, use_unique_name);
  }

  /*!
   * \brief Create tir::Var for input/output while updating the buffer_maps.
   *
   * \param type The type of the expression to evaluate.
   * \param original_name The name of the tir::Var.
   * \param use_unique_name Whether to generate a new unique name where a name conflicts.
   */
  void CreateIOVar(const Type& type, const std::string& original_name,
                   bool use_unique_name = true) {
    if (type->IsInstance<TupleTypeNode>()) {
      TupleType tuple_type = Downcast<TupleType>(type);
      for (unsigned i = 0; i < tuple_type->fields.size(); i++) {
        CreateIOVar(tuple_type->fields[i], original_name);
      }
    } else {
      std::string name = original_name;
      if (use_unique_name) {
        name = GetUniqueIOVarName(original_name);
      }
      tir::Var var = tir::Var(name, DataType::Handle());
      main_signature_.push_back(var);
      auto tensor_type = type.as<TensorTypeNode>();
      ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey();
      DataType elem_type = tensor_type->dtype;
      tir::Var buffer_var =
          tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type), "global"));
      tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0,
                                       name + "_buffer", 16, 1, tir::BufferType::kDefault);
      main_buffer_map_.Set(var, buffer);
      io_tensor_types_.Set(var, Downcast<TensorType>(type));
    }
  }

  /*!
   * \brief Create a unique name for I/O Var
   */
  std::string GetUniqueIOVarName(std::string name) {
    if (io_var_names_.find(name) == io_var_names_.end()) {
      io_var_names_[name] = 1;
      return name;
    } else {
      io_var_names_[name] = io_var_names_[name] + 1;
      return name + std::to_string(io_var_names_[name]);
    }
  }
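  // For example, successive calls with "output" return "output", "output2", "output3", ...;
  // the first occurrence keeps the plain name, later ones append the incremented counter.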

  /*!
   * \brief Calculate workspace sizes for PrimFuncs in the IRModule
   */
  Map<String, FunctionInfo> CalculateWorkspaceSizes(
      const IRModule& lowered_mod, const Map<String, FunctionInfo>& function_metadata) {
    Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(lowered_mod);
    Map<String, FunctionInfo> updated_function_metadata;
    for (const auto& kv : lowered_mod->functions) {
      GlobalVar global_var = kv.first;
      BaseFunc base_func = kv.second;
      if (base_func->IsInstance<tir::PrimFuncNode>()) {
        tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(base_func);
        Target tgt = pfunc->GetAttr<Target>(tvm::attr::kTarget).value();
        const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment);
        if (function_metadata.count(global_var->name_hint)) {
          updated_function_metadata.Set(global_var->name_hint,
                                        function_metadata[global_var->name_hint]);
          updated_function_metadata[global_var->name_hint]->workspace_sizes.Set(tgt, ws);
        } else {
          FunctionInfo finfo{{{tgt, ws}}, {}, {}, {{tgt, pfunc}}, {}};
          updated_function_metadata.Set(global_var->name_hint, finfo);
        }
      }
    }
    return updated_function_metadata;
  }

  /*!
   * \brief Run USMP to plan memory for lowered IRModule.
   */
  IRModule PlanMemoryWithUSMP(const IRModule& mod) {
    VLOG(1) << "Planning memory with USMP for module:" << std::endl << PrettyPrint(mod);
    Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod);
    IRModule lowered_mod = mod->ShallowCopy();
    lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod);
    function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_);
    Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
        lowered_mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
    backend::FunctionInfo main_func_info =
        lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info").value();
    main_func_info->workspace_sizes.clear();
    if (allocated_pool_infos) {
      for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) {
        for (const auto& tgt : allocated_pool_info->pool_info->targets) {
          VLOG(1) << "USMP requires target " << tgt->ToDebugString() << " to have pool size "
                  << allocated_pool_info->allocated_size->value;
          size_t size = allocated_pool_info->allocated_size->value;
          if (allocated_pool_info->pool_info->IsInstance<ConstantPoolInfoNode>()) {
            size += main_func_info->constant_sizes.count(tgt)
                        ? main_func_info->constant_sizes[tgt]->value
                        : 0;
            main_func_info->constant_sizes.Set(tgt, size);
          } else if (allocated_pool_info->pool_info->IsInstance<WorkspacePoolInfoNode>()) {
            size += main_func_info->workspace_sizes.count(tgt)
                        ? main_func_info->workspace_sizes[tgt]->value
                        : 0;
            main_func_info->workspace_sizes.Set(tgt, size);
          } else {
            LOG(FATAL) << "Unknown pool type: " << allocated_pool_info->pool_info->GetTypeKey();
          }
        }
      }
    }
    function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info);
    return lowered_mod;
  }

  /*!
   * \brief Run StorageRewrite to plan memory for lowered IRModule.
   */
  IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) {
    Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod);
    IRModule lowered_mod = mod->ShallowCopy();
    function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_);
    // Running StorageRewrite just on the main function
    tir::PrimFunc tir_main_func =
        Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
    IRModule main_func_mod;
    main_func_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main),
                          tir_main_func);
    main_func_mod = tir::transform::StorageRewrite()(main_func_mod);
    lowered_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main),
                        main_func_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
    tir_main_func =
        Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
    // Use the PrimFunc to calculate the workspace required to service the allocates
    Integer main_workspace_size_bytes =
        CalculateWorkspaceBytes(tir_main_func, workspace_byte_alignment);
    backend::FunctionInfo main_func_info =
        lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info").value();
    main_func_info->workspace_sizes.Set(config_->host_target, main_workspace_size_bytes);
    function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info);
    return lowered_mod;
  }

  /*!
   * \brief Gets module workspace alignment from supplied executor or defaults to 16
   */
  Integer GetModuleWorkspaceByteAlignment(const IRModule& mod) {
    Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
    return executor_config->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
  }

  /*!
   * \brief Gets module constant alignment from supplied executor or defaults to 16
   */
  Integer GetModuleConstantByteAlignment(const IRModule& mod) {
    Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
    return executor_config->GetAttr<Integer>("constant-byte-alignment").value_or(16);
  }

 protected:
  /*! \brief mod */
  runtime::Module* mod_;
  /*! \brief list of input expressions (i.e., variables passed by the user) */
  std::vector<Var> input_vars_;
  /*! \brief map of device context variables */
  Map<String, tir::Var> devices_;
  /*! \brief map of GlobalVars to C Device API contexts */
  Map<GlobalVar, tir::Var> device_contexts_;
  /*! \brief input and output variables belonging to the main function signature */
  Array<tir::Var> main_signature_;
  /*! \brief input and output buffers belonging to the main function signature */
  Map<tir::Var, tir::Buffer> main_buffer_map_;
  /*! \brief maps input and output variables to the TensorTypes which describe them */
  Map<tir::Var, TensorType> io_tensor_types_;
  /*! \brief All available targets. */
  CompilationConfig config_;
  /*!
   * \brief The type of kernel call to be emitted.
   * See CallType for more documentation.
   */
  CallType call_type_;

  /*!
   * \brief parameters (i.e. ConstantNodes found in the graph).
   * Maps param name to its NDArray; at runtime the name is used by the
   * lookup_param builtin to retrieve the parameter.
   */
  std::unordered_map<std::string, runtime::NDArray> params_;
  /*! \brief mapping between expression and parameters */
  Map<Expr, String> params_by_expr_;
  /*! \brief mapping between parameter names ("p0", "p1", etc..) and storage identifiers */
  std::unordered_map<std::string, int64_t> param_storage_ids_;
  /*! \brief mapping between constant buffer vars and the constants they hold */
  std::unordered_map<const tir::Var, const ConstantNode*, ObjectPtrHash, ObjectPtrEqual>
      constant_map_;

  /*! \brief plan memory of device result */
  StorageMap storage_device_map_;
  /*! \brief mapping sid -> tir::Var */
  std::unordered_map<int, tir::Var> sids_table_;
  /*! \brief lowered funcs */
  Map<String, FunctionInfo> function_metadata_;
  /*! \brief the set of statements that make the program */
  std::vector<tir::Stmt> stmts_;
  /*! \brief the list of return sids (note that the function might return more than one output) */
  std::vector<int> return_sid_;
  /*! \brief per-I/O-var name counter to aid generating unique names */
  std::unordered_map<std::string, int> io_var_names_;
  /*! \brief A set of variables that are let bound. */
  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_vars_;

 public:
  AOTExecutorCodegen(runtime::Module* mod, const Array<Target>& targets)
      : mod_(mod), config_(transform::PassContext::Current(), targets) {}

  LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) {
    VLOG_CONTEXT << "AOT";

    Runtime runtime_config = mod->GetAttr<Runtime>(tvm::attr::kRuntime).value();
    Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod);

    Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
    std::string interface_api =
        executor_config->GetAttr<String>("interface-api").value_or("packed");
    bool unpacked_api = executor_config->GetAttr<Bool>("unpacked-api").value_or(Bool(false));

    // Validate the choice of unpacked_api and interface_api.
    if (runtime_config->name == kTvmRuntimeCrt) {
      if (unpacked_api == true) {
        call_type_ = CallType::kUnpacked;
      } else if (unpacked_api == false && interface_api == "packed") {
        call_type_ = CallType::kCPacked;
      } else {
        CHECK(interface_api == "packed" || unpacked_api == true)
            << "Either need interface_api == \"packed\" (got: " << interface_api
            << ") or unpacked-api == true (got: " << unpacked_api << ") when targeting c runtime";
        ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api
                      << ", unpacked-api=" << unpacked_api;
      }
    } else if (runtime_config->name == kTvmRuntimeCpp) {
      if (unpacked_api == false && interface_api == "packed") {
        call_type_ = CallType::kCPacked;
      } else {
        CHECK(static_cast<bool>(unpacked_api) == false && interface_api == "packed")
            << "Need unpacked-api == false (got: " << unpacked_api
            << ") and interface-api == \"packed\" (got: " << interface_api
            << ") when targeting c++ runtime";
        ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api
                      << ", unpacked-api=" << unpacked_api;
      }
    } else {
      ICHECK(false) << "runtime_config (" << runtime_config->name
                    << ") is not one of the expected values";
    }
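
    // Summarizing the combinations accepted by the checks above:
    //   crt runtime + unpacked-api=true                 -> kUnpacked
    //   crt runtime + unpacked-api=false + "packed" API -> kCPacked
    //   c++ runtime + unpacked-api=false + "packed" API -> kCPacked
    // Every other combination is rejected.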

    mod = transform::ToANormalForm()(mod);
    mod = transform::InferType()(mod);
    mod = transform::AnnotateUsedMemory()(mod);

    IRModule lowered_mod =
        tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) {
          // We need to maintain the constant map for external
          // functions, so we pass this processing function, which
          // allows us to process each function as we lower it.
          if (func->GetAttr<String>(attr::kCompiler).defined()) {
            UpdateConstants(func, &params_);
          }

          // TODO(@areusch, @jroesch): We should refactor this to
          // execute as a further pass, instead writing data to the
          // lowering process directly.
          tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
        })(mod);

    transform::PassContext pass_ctx = transform::PassContext::Current();
    bool enable_remove_reshapes =
        pass_ctx->GetConfig<Bool>("relay.remove_standalone_reshapes.enable", Bool(true)).value();
    if (enable_remove_reshapes) {
      lowered_mod = transform::RemoveStandaloneReshapes()(lowered_mod);
    }
    auto lowered_main = lowered_mod->Lookup("main");
    auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

    // Post-lowering storage map for writing main func
    AOTOnDemandAllocator final_aot_allocator;
    final_aot_allocator.Run(lowered_main_func);
    storage_device_map_ = final_aot_allocator.GetStorageMap();

    // TODO(@electriclilies, @jroesch, @Mousius): remove UpdateMainWorkspaceSize
    StaticMemoryPlan memory_plan(storage_device_map_);
    backend::FunctionInfo func_info =
        tec::UpdateMainWorkspaceSize(lowered_mod, config_, memory_plan->expr_to_storage_info);
    lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info);

    for (auto input : lowered_main_func->params) {
      input_vars_.push_back(input);
      std::string input_name = SanitizeName(input->name_hint());
      // We don't want the compiler changing input names in the
      // event of a sanitization collision. Therefore, we enforce that
      // the created var uses input_name exactly.
      CreateIOVar(input, input_name, /*use_unique_name = */ false);
    }

    // Define the storage allocator ids
    for (auto kv : storage_device_map_) {
      for (auto sid : kv.second->storage_ids) {
        // The buffer_var is created with storage_scope global.workspace so it is explicitly
        // serviced by TVMBackendAllocWorkspace (TVMBAW) calls. The reasoning is that executor
        // allocations should be serviced by TVMBAW calls, as the data could be accessed by many
        // devices and should not be lowered to the stack. For more details please refer to the
        // discussion here: https://github.com/apache/tvm/issues/9022
        te::Var buffer_var(MakeString("sid_", sid),
                           PointerType(PrimType(DataType::Int(8)), "global.workspace"));
        sids_table_[sid] = buffer_var;
      }
    }

    // Retrieve the return sids
    return_sid_ = final_aot_allocator.GetReturnIds();
    // Insert outputs to main func signature.
    // If output tensor names were provided, use them.
    if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
      Array<String> output_tensor_names = opt.value();
      Expr output_expr = lowered_main_func->body;
      if (output_expr->checked_type()->IsInstance<TupleTypeNode>()) {
        TupleType output_tuple_type = Downcast<TupleType>(output_expr->checked_type());
        for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) {
          // AoT Executor Codegen does not create these names,
          // thus they should be used as they are provided.
          CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i],
                      /*use_unique_name = */ false);
        }
      } else {
        // AoT Executor Codegen does not create these names,
        // thus they should be used as they are provided.
        CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false);
      }
    } else {
      // If output tensor names are not provided, we will generate output(x)
      // where x is a counter to create unique names.
      CreateIOVar(lowered_main_func->body, "output");
    }

    CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value());
    VisitExpr(lowered_main_func->body);

    // Create the runner function. Please note that the function is not legal yet
    // because the packed calls' arguments are not wrapped in TVMValues. To make this happen we
    // need to run the LegalizePackedCalls pass.
    LoweredOutput ret;

    // Collect any constants extracted by external codegen.
    ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
    Map<String, runtime::NDArray> const_name_to_constant =
        lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
            .value_or({});
    for (const auto& kv : const_name_to_constant) {
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }

    // Collect any constants extracted during lowering.
    for (const auto& kv : params_) {
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }

    // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main
    // function and replacing it with its TIR version. We should try to make this a Pass.
    lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));
    auto tir_main_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
    // Extract additional information around main TIR PrimFunc arguments
    Array<String> devices = ListDevices();
    const auto main_func_params_end_iterator =
        tir_main_func->params.begin() + tir_main_func->params.size();
    const auto outputs_begin_iterator =
        main_func_params_end_iterator - return_sid_.size() - devices.size();
    Array<tir::Var> inputs = Array<tir::Var>(tir_main_func->params.begin(), outputs_begin_iterator);
    Array<TensorType> input_tensor_types;
    for (auto i : inputs) {
      input_tensor_types.push_back(io_tensor_types_[i]);
    }
    Array<tir::Var> outputs =
        Array<tir::Var>(outputs_begin_iterator, main_func_params_end_iterator - devices.size());
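
    // The main signature is laid out as [inputs..., outputs..., device contexts...], so the
    // iterator arithmetic above peels outputs and device contexts off the tail. For example,
    // with two inputs, one output and one device context, params = {in0, in1, out0, ctx0}.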

    lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), tir_main_func);
    // Parallel for loops are not supported in AoT codegen.
    lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod);

    bool enable_usmp = pass_ctx->GetConfig<Bool>(kUSMPEnableOption, Bool(false)).value();
    if (enable_usmp) {
      lowered_mod = PlanMemoryWithUSMP(lowered_mod);
    } else {
      lowered_mod = PlanMemoryWithStorageRewrite(lowered_mod);
    }
    ret.function_metadata = std::move(function_metadata_);

    // Legalize AOT if needed. This means that all the packed calls
    // need to be wrapped in TVMValues (unless unpacked_api is set).
    if (call_type_ == CallType::kCPacked || call_type_ == CallType::kPacked) {
      auto pack_calls = tir::transform::LegalizePackedCalls();
      lowered_mod = pack_calls(lowered_mod);
    }

    // Collect any runtime modules generated by external codegen.
    ret.external_mods =
        lowered_mod->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods).value_or({});

    // This is the point where we separate the functions in the module by target.
    VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);
    ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
    VLOG(1) << "per-target modules:";
    for (const auto& kv : ret.lowered_funcs) {
      VLOG(1) << "target:" << std::endl
              << kv.first->ToDebugString() << std::endl
              << "maps to:" << std::endl
              << PrettyPrint(kv.second);
    }

    // Extract USMP metadata to pass onto metadata sources
    Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
    std::vector<tir::Var> pool_vars;
    tir_main_func =
        Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
    Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
        tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
    if (allocated_pool_infos) {
      for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) {
        int pool_var_index = allocated_pool_info->pool_var_idx.value()->value;
        pool_vars.push_back(tir_main_func->params[pool_var_index]);
        pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info);
      }
    }
    Map<String, tir::usmp::PoolAllocation> io_pool_allocations =
        lowered_mod
            ->GetAttr<Map<String, tir::usmp::PoolAllocation>>(tvm::attr::kIOTensorPoolAllocations)
            .value_or({});

    std::vector<String> output_var_names;
    if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
      Array<String> output_tensor_names = opt.value();
      for (size_t i = 0; i < output_tensor_names.size(); ++i) {
        output_var_names.push_back(output_tensor_names[i]);
      }
    }

    // If output names have not been specified, then generate default output names.
    if (output_var_names.size() == 0) {
      if (return_sid_.size() == 1) {
        output_var_names.push_back(String("output"));
      } else {
        for (size_t i = 0; i < return_sid_.size(); ++i) {
          output_var_names.push_back(String("output" + std::to_string(i)));
        }
      }
    }

    Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes()};

    ret.metadata = ExecutorCodegenMetadata(
        inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices,
        runtime::kTvmExecutorAot, mod_name, interface_api, unpacked_api,
        GetModuleWorkspaceByteAlignment(mod), GetModuleConstantByteAlignment(mod), pool_var_info,
        io_pool_allocations);
    return ret;
  }

  /*!
   * \brief Get list of devices found
   * \return List of devices
   */
  Array<String> ListDevices() {
    std::vector<String> device_names(devices_.size());
    std::transform(devices_.begin(), devices_.end(), device_names.begin(),
                   [](const auto& it) -> String { return it.first; });
    return device_names;
  }
};  // class AOTExecutorCodegen

class AOTExecutorCodegenModule : public runtime::ModuleNode {
 public:
  AOTExecutorCodegenModule() {}
  virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) {
    if (name == "init") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        ICHECK_EQ(args.num_args, 2) << "The expected arguments are: "
                                    << "runtime::Module mod and Array<Target> targets";
        void* mod = args[0];
        Array<Target> targets = args[1];
        init(mod, targets);
      });
    } else if (name == "codegen") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        IRModule mod = args[0];
        Function func = args[1];
        String mod_name = args[2];
        this->output_ = this->codegen_->Codegen(mod, func, mod_name);
      });
    } else if (name == "list_params_name") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = list_params_name(); });
    } else if (name == "get_param_by_name") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        String key = args[0];
        *rv = get_param_by_name(key);
      });
    } else if (name == "get_irmodule") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); });
    } else if (name == "get_external_modules") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_external_modules(); });
    } else if (name == "get_function_metadata") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->output_.function_metadata;
      });
    } else if (name == "get_devices") {
      return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) {
        *rv = this->codegen_->ListDevices();
      });
    } else if (name == "get_executor_codegen_metadata") {
      return PackedFunc(
          [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; });
    } else {
      return PackedFunc([](TVMArgs args, TVMRetValue* rv) {});
    }
  }

  const char* type_key() const final { return "RelayGraphRuntimeCodegenModule"; }

 private:
  void init(void* mod, const Array<Target>& targets) {
    codegen_ =
        std::make_shared<AOTExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod), targets);
  }

  Array<runtime::String> list_params_name() {
    Array<runtime::String> ret;
    for (const auto& kv : this->output_.params) {
      ret.push_back(kv.first);
    }
    return ret;
  }

  runtime::NDArray get_param_by_name(String key) {
    auto it = this->output_.params.find(key);
    CHECK(it != this->output_.params.end()) << "no such parameter " << key;
    return (*it).second;
  }

  Array<tvm::runtime::Module> get_external_modules() { return output_.external_mods; }

  Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; }

  std::shared_ptr<AOTExecutorCodegen> codegen_;
  LoweredOutput output_;
};

runtime::Module CreateAOTExecutorCodegenMod() {
  auto ptr = make_object<AOTExecutorCodegenModule>();
  return runtime::Module(ptr);
}

TVM_REGISTER_GLOBAL("relay.build_module._AOTExecutorCodegen")
    .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateAOTExecutorCodegenMod(); });
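
// Typical driver sequence (illustrative): the build pipeline fetches the global
// "relay.build_module._AOTExecutorCodegen" and invokes the packed functions exposed by
// AOTExecutorCodegenModule in order: "init" (module + targets), then "codegen" (IRModule,
// main function, module name), and finally the "get_*" accessors for the generated artifacts.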

}  // namespace backend
}  // namespace relay
}  // namespace tvm