1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/aot_executor_codegen.cc |
22 | * \brief AOT executor codegen |
23 | */ |
24 | |
25 | #include <tvm/ir/module.h> |
26 | #include <tvm/relay/attrs/annotation.h> |
27 | #include <tvm/relay/attrs/call.h> |
28 | #include <tvm/relay/executor.h> |
29 | #include <tvm/relay/expr_functor.h> |
30 | #include <tvm/relay/runtime.h> |
31 | #include <tvm/runtime/device_api.h> |
32 | #include <tvm/runtime/name_transforms.h> |
33 | #include <tvm/runtime/object.h> |
34 | #include <tvm/tir/analysis.h> |
35 | #include <tvm/tir/builtin.h> |
36 | #include <tvm/tir/expr.h> |
37 | #include <tvm/tir/function.h> |
38 | #include <tvm/tir/stmt.h> |
39 | #include <tvm/tir/transform.h> |
40 | #include <tvm/tir/usmp/utils.h> |
41 | |
42 | #include <algorithm> |
43 | #include <list> |
44 | #include <string> |
45 | #include <vector> |
46 | |
47 | #include "../../target/source/codegen_source_base.h" |
48 | #include "../op/annotation/annotation.h" |
49 | #include "../op/call/call.h" |
50 | #include "../op/memory/device_copy.h" |
51 | #include "../transforms/device_aware_visitors.h" |
52 | #include "./name_transforms.h" |
53 | #include "./te_compiler.h" |
54 | #include "./utils.h" |
55 | |
56 | namespace tvm { |
57 | namespace relay { |
58 | namespace backend { |
59 | |
/*! \brief Maps each Relay expression to its planned storage (ids, virtual devices, byte sizes). */
using StorageMap =
    std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;
62 | |
/**
 * This is an on demand allocator for AOT. A new temporary
 * (storage allocator identifier) is allocated for each operation.
 *
 * The visitor fills storage_device_map_ with a StorageInfo (storage ids,
 * virtual devices and byte sizes) for every expression that needs a buffer,
 * and records the storage ids / flattened tensor types of whichever
 * expression was visited last in return position.
 */
class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor {
 public:
  // No IRModule is needed: planning only inspects the expressions it visits.
  AOTOnDemandAllocator() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}

  // run the visitor on a global function.
  void Run(const Function& func) { VisitExpr(func); }

  /*! \brief Storage ids assigned to the function's return value(s). */
  std::vector<int> GetReturnIds() const { return return_ids_; }
  /*! \brief Flattened tensor types of the function's return value(s). */
  std::vector<TensorType> GetReturnTtypes() const { return return_ttypes_; }

  /*! \brief Mapping from every visited expression to its planned storage. */
  StorageMap GetStorageMap() const { return storage_device_map_; }

  using ExprVisitor::VisitExpr_;

  void VisitExpr_(const ConstantNode* op) final {
    CreateStorage(op);
    AssignReturnSid(GetRef<Expr>(op));
  }

  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
    // AOTOnDemandAllocator is run both before and after lowering, so we need to handle the case
    // where the op of the call is a generic function

    Expr func;
    Array<Expr> args;

    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
    if (call_lowered_props.lowered_func.defined()) {
      func = call_lowered_props.lowered_func;
      args = call_lowered_props.arguments;
    } else {  // Relay functions that have not been lowered and lowered extern functions
      func = call_node->op;
      args = call_node->args;
      if (call_node->op.as<GlobalVarNode>()) {  // Lowered extern function
        ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
      } else {  // Relay function which has not been lowered yet
        ICHECK(call_node->op.as<FunctionNode>())
            << "Expected the call to be to a lowered primfunc, a lowered extern function or a "
               "unlowered Relay function.";
      }
    }
    VisitExpr(func);
    // Plan storage for the call's result before visiting its arguments.
    CreateStorage(call_node);
    for (const Expr& arg : args) {
      VisitExpr(arg);
    }
    AssignReturnSid(GetRef<Expr>(call_node));
  }

  void VisitExpr_(const VarNode* op) final { AssignReturnSid(GetRef<Expr>(op)); }

  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
    if (function_nesting() > 1) {
      // do not recurse into sub functions.
      return;
    }
    if (func_node->HasNonzeroAttr(attr::kPrimitive)) {
      // No storage needed for primitive functions.
      return;
    }
    // Each parameter of the top-level function gets its own storage.
    for (const auto& param : func_node->params) {
      CreateStorage(param.get());
    }
    VisitExpr(func_node->body);
  }

  void VisitExpr_(const GlobalVarNode* op) final {
    // Do nothing.
  }

  void VisitExpr_(const OpNode* op) final {
    // Do nothing.
  }

  void VisitExpr_(const TupleNode* op) final {
    // A tuple's storage is the concatenation of its fields' storage.
    std::vector<int64_t> storage_ids;
    std::vector<VirtualDevice> virtual_devices;
    std::vector<int64_t> storage_sizes_in_bytes;
    Expr expr = GetRef<Expr>(op);
    for (Expr field : op->fields) {
      auto sid = GetStorage(field);
      storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
      virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(),
                             sid->virtual_devices.end());
      storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
                                    sid->storage_sizes_in_bytes.begin(),
                                    sid->storage_sizes_in_bytes.end());
    }
    storage_device_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes);
    AssignReturnSid(expr);
  }

  void VisitExpr_(const TupleGetItemNode* op) final {
    // Projection aliases the indexed slot of the tuple's storage; no new sid.
    Expr expr = GetRef<Expr>(op);
    auto sids = GetStorage(op->tuple);
    ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
    storage_device_map_[expr] =
        StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]},
                    {sids->storage_sizes_in_bytes[op->index]});
    AssignReturnSid(expr);
  }

  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }

  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
    // The let-bound var aliases the storage planned for its value.
    VisitExpr(value);
    StorageInfo si = GetStorage(value);
    storage_device_map_[var] = si;
  }

 private:
  /*!
   * \brief If \p e has planned storage, record its storage ids and flattened
   * tensor types as the function return. Visitors call this after processing
   * each expression, so the last call in the traversal wins.
   */
  void AssignReturnSid(Expr e) {
    if (storage_device_map_.find(e) != storage_device_map_.end()) {
      StorageInfo& sinfo = storage_device_map_[e];
      return_ids_.clear();
      for (auto sid : sinfo->storage_ids) {
        return_ids_.push_back(sid);
      }
      return_ttypes_.clear();
      return_ttypes_ = FlattenTupleType(e->checked_type());
    }
  }
  /*!
   * \brief ceil(size/word_size) to get number of words.
   * \param size The original size.
   * \param word_size The element size.
   */
  static size_t DivRoundUp(size_t size, size_t word_size) {
    return (size + word_size - 1) / word_size;
  }
  /*!
   * \brief Get the memory requirement.
   * \param ttype The (static-shape) tensor type to size.
   * \return The required memory size in bytes.
   *
   * TODO(mbs): Cf CalculateRelayExprSizeBytes in utils.cc, GetMemorySize is graph_plan_memory.cc
   */
  size_t GetMemorySizeBytes(const TensorType& ttype) {
    size_t size = 1;
    for (IndexExpr dim : ttype->shape) {
      // Only compile-time-constant, non-negative dimensions can be sized.
      const int64_t* pval = tir::as_const_int(dim);
      ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << ttype->shape;
      ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval;
      size *= static_cast<size_t>(pval[0]);
    }
    // Round the element size (bits * lanes) up to a whole number of bytes.
    size *= DivRoundUp(ttype->dtype.bits() * ttype->dtype.lanes(), 8);
    return size;
  }
  /*!
   * \brief Get the necessary storage for the expression.
   * \param expr The expression.
   * \return The corresponding token.
   */
  StorageInfo GetStorage(const Expr& expr) {
    // See through "on_device" calls.
    Expr true_expr = IgnoreOnDevice(expr);
    VisitExpr(true_expr);
    auto it = storage_device_map_.find(true_expr);
    ICHECK(it != storage_device_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " "
                                            << PrettyPrint(true_expr) << " in storage device map";
    return it->second;
  }

  /*!
   * \brief Create storage for the expression, on its inferred virtual device.
   */
  void CreateStorage(const ExprNode* op) {
    Expr expr = GetRef<Expr>(op);
    return CreateStorage(expr, GetVirtualDevice(expr));
  }

  /*!
   * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device.
   */
  void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) {
    ICHECK(!virtual_device->IsFullyUnconstrained())
        << "invalid virtual device for expr:" << std::endl
        << PrettyPrint(expr);
    std::vector<int64_t> storage_ids;
    std::vector<VirtualDevice> virtual_devices;
    std::vector<int64_t> storage_sizes_in_bytes;
    // One fresh sid per flattened tensor (a tuple gets one sid per leaf type).
    for (const auto& ttype : FlattenTupleType(expr->checked_type())) {
      storage_ids.push_back(next_available_sid_++);
      virtual_devices.push_back(virtual_device);
      storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype));
    }
    storage_device_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices),
                                            std::move(storage_sizes_in_bytes));
  }

  /*! \brief mapping of expression -> storageInfo */
  StorageMap storage_device_map_;
  /*! \brief current id of the temporary allocated */
  int next_available_sid_{0};
  /*! \brief the set of intermediate tensors that are return variables */
  std::vector<int> return_ids_;
  /*! \brief the data types of the return values */
  std::vector<TensorType> return_ttypes_;
};
266 | |
267 | /*! \brief Code generator for AOT executor */ |
268 | class AOTExecutorCodegen : public MixedModeVisitor { |
269 | protected: |
  /*! \brief Describes the type of kernel call emitted. */
  enum CallType {
    /*!
     * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions.
     *
     * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the
     * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those
     * functions are of type TVMBackendPackedCFunc.
     *
     * The following code is emitted at call sites to call a function named `func`:
     *    void* func_ptr = TVMBackendGetFuncFromEnv("func");
     *    TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes)
     *
     * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args`
     * by LowerTVMBuiltin TIR transform.
     *
     * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often,
     * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when
     * `func` is implemented in C).
     *
     * Compatible with both C++ and C runtimes, implemented with the C runtime only.
     */
    kPacked,  // Emit tir.call_packed and wrap all arguments in DLTensor.

    /*!
     * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call.
     *
     * When this type is selected, assumes all operators are implemented in functions of type
     * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of
     * downstream compilation that there is a symbol named after the 0th arg to tir::Call of
     * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target.
     *
     * The following code is emitted at call sites to call a function named `func`:
     *    func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle)
     *
     * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args`
     * by LowerTVMBuiltin TIR transform.
     *
     * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is
     * always the device context parameter when not null. At present, the implementation does not
     * support forwarding device context parameters to CPacked.
     *
     * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented
     * in the same scenarios.
     */
    kCPacked,  // Emit tir.call_cpacked and wrap all arguments in DLTensor.

    /*! \brief Directly call a function accepting the `data` arrays as args.
     *
     * When this type is selected, assumes all operators are implemented in C functions whose
     * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just the
     * `data` parameters (i.e. no DLTensor object is passed along).
     *
     * The following code is emitted at call sites to a function named `func`:
     *    func(void* arg0, void* arg1, ..., void* argN) // no resource_handle
     * -or-
     *    func(void* arg0, void* arg1, ..., void* argN, void* resource_handle) // with resource_handle
     *
     * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is
     * always the device context parameter when not null.
     *
     * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented
     * with the C runtime only.
     */
    kUnpacked,  // Emit tir.call_extern passing only the `data` part of DLTensors.
  };
336 | |
337 | /*! |
338 | * \brief Return a vector of variables that represents the sids for the given Relay Expr |
339 | */ |
340 | std::vector<tir::Var> PackSid(Expr expr) { |
341 | std::vector<tir::Var> buffer_vars; |
342 | |
343 | ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) |
344 | << "Storage map did not contain constant expr " << PrettyPrint(expr); |
345 | StorageInfo& sinfo = storage_device_map_[expr]; |
346 | |
347 | // Note that an expression can have multiple sids associated with it |
348 | // e.g., returning multiple values from a function |
349 | for (auto sid : sinfo->storage_ids) { |
350 | // Determine if an sid is an output buffer |
351 | auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); |
352 | if (output_iter != return_sid_.end()) { |
353 | int output_index = std::distance(return_sid_.begin(), output_iter); |
354 | buffer_vars.push_back(GetBufferVarForIO(input_vars_.size() + output_index)); |
355 | continue; |
356 | } |
357 | |
358 | auto sid_value = sids_table_[sid]; |
359 | buffer_vars.push_back(sid_value); |
360 | } |
361 | return buffer_vars; |
362 | } |
363 | |
364 | /*! |
365 | * brief Given an expression return the variable(s) associated with that expression |
366 | */ |
367 | std::vector<te::Var> FindExpr(Expr arg) { |
368 | auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg); |
369 | if (input_iter != input_vars_.end()) { |
370 | // Input variable |
371 | int main_index = std::distance(input_vars_.begin(), input_iter); |
372 | return {GetBufferVarForIO(main_index)}; |
373 | } else { |
374 | // Storage identifier (i.e., intermediate memory) |
375 | return PackSid(arg); |
376 | } |
377 | } |
378 | |
379 | /*! |
380 | * \brief Reverse lookup the device name in devices_ map. |
381 | * \param device_context Value in devices_ to find. |
382 | * \return Key matching device_context in devices_. |
383 | */ |
384 | std::string FindDeviceName(tir::Var device_context) { |
385 | for (std::pair<String, tir::Var> kv : devices_) { |
386 | if (kv.second->name_hint == device_context->name_hint) { |
387 | return kv.first; |
388 | } |
389 | } |
390 | ICHECK(false) << "Did not find a device name associated with " << device_context; |
391 | return "" ; |
392 | } |
393 | |
394 | void PushArgs(const Expr& expr, const std::vector<tir::Var>& sids, Array<PrimExpr>* args) { |
395 | const TupleNode* t = expr.as<TupleNode>(); |
396 | if (t != nullptr) { |
397 | CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't " |
398 | "handle this type of Relay Expr in a CallNode." ; |
399 | } |
400 | |
401 | args->insert(args->end(), sids.begin(), sids.end()); |
402 | } |
403 | |
404 | /* |
405 | * Wraps a call_extern with a tvm_check_return annotation if required otherwise |
406 | * returns the passed Call |
407 | */ |
408 | tir::Call AddCheckReturn(tir::Call existing_call) { |
409 | Array<PrimExpr> args = {tir::make_const(DataType::Int(32, 1), 0, Span()), |
410 | tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call}; |
411 | return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); |
412 | } |
413 | |
414 | /*! |
415 | * brief Create a function call |
416 | * \param call_lowered_props The lowered function and the arguments to call it with |
417 | * \param result_expr The call we got func and args from (so as to recover the storage |
418 | * ids to hold the result). |
419 | */ |
420 | void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) { |
421 | std::string func_name = call_lowered_props.lowered_func->name_hint; |
422 | tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)}; |
423 | std::vector<tir::Stmt> create_func_call_stmts; |
424 | |
425 | // Pack the inputs |
426 | for (const Expr& arg : call_lowered_props.arguments) { |
427 | if (params_by_expr_.find(arg) != params_by_expr_.end()) { |
428 | auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), |
429 | {tir::StringImm(params_by_expr_[arg])}); |
430 | // NOTE: this cast looks like a no-op, but is required for compilation downstream. |
431 | // Because DataType::Handle has default bits=64, but CodeGenC does not observe this field, |
432 | // adding this cast forces the codegen to insert the cast. In this case, a cast is required |
433 | // because param_handle is actually code-generated as `const void*`, and the `const` piece |
434 | // needs to be removed. |
435 | args.push_back(tvm::tir::Cast(DataType::Handle(32, 1), param_handle)); |
436 | } else { |
437 | auto sids = FindExpr(arg); |
438 | PushArgs(arg, sids, &args); |
439 | } |
440 | } |
441 | |
442 | // Pack the return(s) value. A call node can produce multiple outputs |
443 | auto result_expr_sid = PackSid(result_expr); |
444 | PushArgs(result_expr, result_expr_sid, &args); |
445 | |
446 | GlobalVar global_var = call_lowered_props.lowered_func; |
447 | bool has_c_device_api_context = device_contexts_.count(global_var) != 0; |
448 | tir::Var device_context; |
449 | tir::Stmt func_call; |
450 | |
451 | switch (call_type_) { |
452 | case CallType::kUnpacked: { |
453 | // call_extern calling convention with optional context |
454 | if (has_c_device_api_context) { |
455 | device_context = device_contexts_.Get(global_var).value(); |
456 | args.push_back(device_context); |
457 | } |
458 | func_call = tir::Evaluate(AddCheckReturn( |
459 | tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); |
460 | break; |
461 | } |
462 | case CallType::kCPacked: { |
463 | if (has_c_device_api_context) { |
464 | device_context = device_contexts_.Get(global_var).value(); |
465 | args.push_back(device_context); |
466 | } else { |
467 | // NOTE: LowerTVMBuiltin expects some device_context placeholder. |
468 | args.push_back(tir::make_zero(DataType::Handle())); |
469 | } |
470 | func_call = tir::Evaluate( |
471 | tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)); |
472 | create_func_call_stmts.push_back(func_call); |
473 | break; |
474 | } |
475 | case CallType::kPacked: { |
476 | // call_packed does not accept a device context. |
477 | CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context" ; |
478 | func_call = tir::Evaluate(AddCheckReturn( |
479 | tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); |
480 | create_func_call_stmts.push_back(func_call); |
481 | break; |
482 | } |
483 | default: |
484 | ICHECK(false) << "Unknown CallType: " << call_type_; |
485 | } |
486 | |
487 | ICHECK(func_call.defined()) << "Must define func_call" ; |
488 | |
489 | if (has_c_device_api_context) { |
490 | func_call = tir::SeqStmt(Array<tir::Stmt>({ |
491 | GenerateDeviceHook(device_context, "Open" ), |
492 | func_call, |
493 | GenerateDeviceHook(device_context, "Close" ), |
494 | })); |
495 | } |
496 | |
497 | tir::Stmt body = tir::SeqStmt({func_call}); |
498 | stmts_.push_back(body); |
499 | } |
500 | |
  /*!
   * \brief Copy a variable to the output. This function is mainly used in edge cases
   * when we want to return an input or a parameter.
   * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a
   * copy-on-write fashion.
   * \param out Destination pointer expression (bound as the write buffer's data var).
   * \param in Source pointer expression (bound as the read buffer's data var).
   * \param pack_input Currently unused in this body -- TODO confirm whether callers rely on it.
   * \param size Number of bytes to copy.
   */
  void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
    // Define intermediate DLTensor to load/store the data
    tir::Buffer tmp_read =
        tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read");
    tir::Buffer tmp_write =
        tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write");
    te::Var loop_idx("i", DataType::Int(32));
    auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});
    // Copy the variable from the input to the output
    // The inner Let binds tmp_read's backing pointer to `in`, and the outer LetStmt
    // binds tmp_write's backing pointer to `out`, so the loop is a plain byte-wise copy.
    tir::Stmt copy = tir::For(
        loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
        tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
    stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
  }
521 | |
  /*
   * \brief Collects device context variables for passing to operators
   *
   * For each lowered function mapped to a device-context name, resolves the name
   * to a TargetKind; kinds with the "use_device_api" attribute set get a single
   * shared `device_context_<name>` handle that is appended to the main function
   * signature and recorded in devices_ / device_contexts_.
   */
  void CollectDeviceVariables(const Map<GlobalVar, String>& device_contexts) {
    Map<TargetKind, tir::Var> target_contexts;
    TargetKindAttrMap<Bool> target_attr_map = tvm::TargetKind::GetAttrMap<Bool>("use_device_api");

    for (const auto& it : device_contexts) {
      const GlobalVar& global_var = it.first;
      const std::string device_context_name = it.second;

      Optional<TargetKind> target_kind = tvm::TargetKind::Get(device_context_name);
      // NOTE(review): this early `return` abandons all remaining device contexts as soon
      // as one name fails to resolve; a `continue` may have been intended -- confirm
      // before changing.
      if (!target_kind || !target_attr_map.count(target_kind.value())) {
        return;
      }
      if (target_attr_map[target_kind.value()]) {
        std::string context_name = tvm::runtime::SanitizeName(device_context_name);
        tir::Var device_context_var("device_context_" + context_name, DataType::Handle());

        // One context var per target kind, shared by every function on that target.
        auto pair = target_contexts.find(target_kind.value());
        if (pair != target_contexts.end()) {
          device_context_var = (*pair).second;
        } else {
          // First occurrence: expose the context as a main-function parameter.
          main_signature_.push_back(device_context_var);
          devices_.Set(context_name, device_context_var);
          target_contexts.Set(target_kind.value(), device_context_var);
        }

        device_contexts_.Set(global_var, device_context_var);
      }
    }
  }
554 | |
555 | /** |
556 | * \brief Generates a call to a given hook for all Devices found for C Device API |
557 | * \param Name of hook to generate statements for |
558 | * \return Statement with function calls for each device |
559 | */ |
560 | tir::Stmt GenerateAllDeviceHook(const String& hook) { |
561 | std::vector<tir::Stmt> device_hooks; |
562 | for (const auto& it : devices_) { |
563 | const String& device_name = it.first; |
564 | const tir::Var& context = it.second; |
565 | Array<String> sections = {"Device" , device_name, hook}; |
566 | String device_hook_name = ToCFunctionStyle(PrefixName(sections)); |
567 | |
568 | tir::Evaluate device_hook( |
569 | AddCheckReturn(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), |
570 | {tvm::tir::StringImm(device_hook_name), context}))); |
571 | device_hooks.push_back(device_hook); |
572 | } |
573 | return tir::SeqStmt(device_hooks); |
574 | } |
575 | |
576 | /** |
577 | * \brief Generates a call to a given hook for a single Device function |
578 | * \param Var Device context to call hook on |
579 | * \param Name of hook to generate statements for |
580 | * \return Statement with function call to Device API |
581 | */ |
582 | tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) { |
583 | const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) { |
584 | return it.second->name_hint == context->name_hint; |
585 | }); |
586 | const String& device_name = (*it).first; |
587 | Array<String> sections = {"Device" , device_name, hook}; |
588 | String device_hook = ToCFunctionStyle(PrefixName(sections)); |
589 | |
590 | return tir::Evaluate( |
591 | AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), |
592 | {tvm::tir::StringImm(device_hook), context}))); |
593 | } |
594 | |
595 | /*! |
596 | * Utility function to string together different arguments |
597 | */ |
598 | template <typename... Args> |
599 | std::string MakeString(Args const&... args) { |
600 | std::ostringstream ss; |
601 | using List = int[]; |
602 | (void)List{0, ((void)(ss << args), 0)...}; |
603 | |
604 | return ss.str(); |
605 | } |
606 | |
  /*!
   * \brief Emit the TIR call for a lowered Relay call, visiting its arguments first.
   * "on_device" wrappers are looked through; device_copy and un-lowered calls abort.
   */
  void VisitExpr_(const CallNode* call_node) override {
    OnDeviceProps on_device_props = GetOnDeviceProps(call_node);
    if (on_device_props.body.defined()) {
      // "on_device" annotations are transparent at codegen time.
      VisitExpr(on_device_props.body);
      return;
    }

    DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node);
    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);

    if (device_copy_props.body.defined()) {
      // TODO(mbs): device_copy cleanup
      // Suspect treating as no-op is better since already built into the StorageInfo?
      LOG(FATAL) << "The AOT executor does not currently support device_copy";
    }

    // At this point we should only see calls of the form call_lowered(@callee, (args...)),
    // where @callee can be a PrimFunc we've compiled or an external function supplied via
    // some other mechanism.
    ICHECK(call_lowered_props.lowered_func.defined())
        << "AOT does not support calling Relay functions. Attempting to call:" << std::endl
        << PrettyPrint(GetRef<Call>(call_node));
    for (const auto& arg : call_lowered_props.arguments) {
      // Evaluate the args
      VisitExpr(arg);
    }
    CreateFuncCall(call_lowered_props, GetRef<Call>(call_node));
  }
635 | |
  /*!
   * \brief Handle a variable appearing in return position: when its sid is one of
   * the function outputs, copy the variable's contents into the output buffer.
   */
  void VisitExpr_(const VarNode* op) override {
    Expr expr = GetRef<Expr>(op);
    // NOTE(review): operator[] default-inserts when `expr` is absent from the map --
    // presumably every Var seen here was already planned by the allocator; confirm.
    StorageInfo& sinfo = storage_device_map_[expr];

    // Let bound vars refer to a value, so these should not be considered "output" vars.
    if (let_bound_vars_.find(GetRef<Var>(op)) != let_bound_vars_.end()) {
      return;
    }

    // If the Var node is an output node we need to copy the content of the variable to the output
    // It's safe to check the SID here because Var StorageToken are never reallocated
    auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
    if (output_iter != return_sid_.end()) {
      int output_index = std::distance(return_sid_.begin(), output_iter);
      if (params_by_expr_.find(expr) != params_by_expr_.end()) {
        // The output aliases a compile-time parameter: fetch its data by name.
        auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
                                          {tir::StringImm(params_by_expr_[expr])});
        CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle,
                     /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]);
      } else {
        // The output aliases an input or an intermediate buffer.
        auto var_expr = FindExpr(expr);
        CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0],
                     /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]);
      }
    }
  }
662 | |
663 | void VisitExpr_(const ConstantNode* op) override { |
664 | Expr expr = GetRef<Expr>(op); |
665 | ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) |
666 | << "Storage map did not contain constant expr " << PrettyPrint(expr); |
667 | StorageInfo& sinfo = storage_device_map_[expr]; |
668 | std::stringstream ss; |
669 | ss << "constant_" << constant_map_.size(); |
670 | |
671 | tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype)))); |
672 | constant_map_[constant] = op; |
673 | auto sid = sinfo->storage_ids[0]; |
674 | sids_table_[sid] = constant; |
675 | |
676 | // If the Constant node is an output node we need to copy the content of the parameter to the |
677 | // output. A node can only produce a single output |
678 | auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); |
679 | if (output_iter != return_sid_.end()) { |
680 | int output_index = std::distance(return_sid_.begin(), output_iter); |
681 | auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), |
682 | {tir::StringImm(ss.str())}); |
683 | CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), constant, |
684 | /* pack_input */ false, sinfo->storage_sizes_in_bytes[0]); |
685 | } |
686 | } |
687 | |
688 | void VisitExpr_(const TupleNode* op) override { |
689 | for (auto field : op->fields) { |
690 | VisitExpr(field); |
691 | } |
692 | } |
693 | |
  /*!
   * \brief Visit a `let` chain iteratively (avoids deep recursion on long chains).
   */
  void VisitExpr_(const LetNode* op) override {
    auto pre_visit = [this](const LetNode* op) {
      // Record the bound var so VisitExpr_(VarNode) does not treat it as an output.
      let_bound_vars_.insert(op->var);
      this->VisitExpr(op->value);
    };
    auto post_visit = [this](const LetNode* op) {
      this->VisitExpr(op->body);
      // Manually bump the memoisation counter, since ExpandANormalForm bypasses the
      // usual dispatch for the Let nodes themselves.
      this->visit_counter_[op] += 1;
    };
    ExpandANormalForm(op, pre_visit, post_visit);
  }
705 | |
706 | void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } |
707 | void VisitExpr_(const OpNode* op) override { |
708 | if (GetRef<Op>(op) != CallLoweredOp() && GetRef<Op>(op) != OnDeviceOp()) { |
709 | LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded" ; |
710 | } |
711 | } |
712 | void VisitExpr_(const IfNode* op) override { |
713 | LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called" ; |
714 | } |
715 | void VisitExpr_(const FunctionNode* op) override { |
716 | ICHECK(op->GetAttr<String>(attr::kCompiler).defined()) |
717 | << "FunctionNode only supported by custom codegen" ; |
718 | } |
719 | void VisitExpr_(const RefCreateNode* op) override { |
720 | LOG(FATAL) << "AOT executor does not support references (found RefCreateNode)" ; |
721 | } |
722 | void VisitExpr_(const RefReadNode* op) override { |
723 | LOG(FATAL) << "AOT executor does not support references (found RefReadNode)" ; |
724 | } |
725 | void VisitExpr_(const RefWriteNode* op) override { |
726 | LOG(FATAL) << "AOT executor does not support references (found RefWriteNode)" ; |
727 | } |
  // Algebraic data types are not representable in the generated TIR.
  void VisitExpr_(const ConstructorNode* op) override {
    LOG(FATAL) << "AOT executor does not support ADTs (found ConstructorNode)";
  }
  // Pattern matching (like ADTs above) is not supported by AOT codegen.
  void VisitExpr_(const MatchNode* op) override {
    LOG(FATAL) << "AOT executor does not support matching (found MatchNode)";
  }
734 | |
  // Create the main PrimFunc to execute the graph. Please note that
  // the packed function calls don't pack their arguments. The AOT
  // runner function needs to be legalized by the LegalizePackedCalls pass.
  //
  // The body is built inside-out: the already-collected statements (stmts_)
  // are wrapped first in Allocate nodes for intermediate sids, then in
  // AllocateConst nodes for embedded constants, then in device hook calls.
  //
  // NOTE(review): relay_params is currently unused in this body — confirm
  // whether it can be dropped from the signature.
  tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) {
    tir::Stmt body = tir::SeqStmt(stmts_);
    // Allocate the sids
    std::unordered_map<int, bool> allocated;

    for (auto kv : storage_device_map_) {
      // Only allocate sids that are needed
      const bool is_input =
          (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end());
      const bool is_param = (params_by_expr_.find(kv.first) != params_by_expr_.end());
      // Inputs and params are provided from outside; no allocation needed.
      if (is_input || is_param) {
        continue;
      }

      for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) {
        int size = kv.second->storage_sizes_in_bytes[i];
        int sid = kv.second->storage_ids[i];

        // Outputs are backed by caller-supplied buffers; skip them too.
        if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) {
          continue;
        }

        // Make sure it hasn't already been allocated, this can happen
        // with let-bound var/value pairs.
        if (allocated.find(sid) != allocated.end()) {
          continue;
        }

        // Constants are materialized via AllocateConst below, not a plain Allocate.
        allocated[sid] = constant_map_.count(sids_table_[sid]);

        // TODO(giuseros): we should allocate this once outside the PrimFunc
        // so we don't pay the price of allocation for every inference
        if (!allocated[sid]) {
          PointerType ptype = Downcast<PointerType>(sids_table_[sid]->type_annotation);
          DataType element_type = Downcast<PrimType>(ptype->element_type)->dtype;
          body = tir::Allocate(sids_table_[sid], element_type, {size}, tir::const_true(), body);
        }
        allocated[sid] = true;
      }
    }

    // Wrap the body in an AllocateConst node for every embedded constant.
    for (auto kv : constant_map_) {
      auto buffer_var = kv.first;
      auto dtype = DataType(kv.second->data->dtype);

      int ndim = kv.second->data->ndim;
      Array<PrimExpr> extents;

      for (int i = 0; i < ndim; i++) {
        int shape = kv.second->data->shape[i];
        extents.push_back(tir::make_const(DataType::Int(32), shape, Span()));
      }
      body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body);
    }

    // Define the PrimFunc attributes
    Map<String, ObjectRef> dict_attrs;
    String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_module_main);
    dict_attrs.Set("global_symbol", run_func_name);
    dict_attrs.Set("runner_function", Bool(true));
    dict_attrs.Set(tvm::attr::kTarget, config_->host_target);

    // Bracket the body with per-device Activate/Deactivate hook calls.
    tir::Stmt device_activations = GenerateAllDeviceHook("Activate");
    tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate");
    tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations});

    // Make the PrimFunc
    return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_,
                         DictAttrs(dict_attrs));
  }
808 | |
809 | /*! |
810 | * \brief Access IO vars using the buffer vars and |
811 | * not the actual var. |
812 | */ |
813 | tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; } |
814 | |
815 | /*! |
816 | * \brief Create tir::Var for input/output while updating the buffer_maps. |
817 | * |
818 | * \param expr The expression to evaluate. |
819 | * \param original_name The name of the tir::Var. |
820 | * \param use_unique_name Whether to generate a new unique name where a name conflicts. |
821 | */ |
822 | void CreateIOVar(const Expr& expr, const std::string& original_name, |
823 | bool use_unique_name = true) { |
824 | CreateIOVar(expr->checked_type(), original_name, use_unique_name); |
825 | } |
826 | |
827 | /*! |
828 | * \brief Create tir::Var for input/output while updating the buffer_maps. |
829 | * |
830 | * \param expr The expression to evaluate. |
831 | * \param original_name The name of the tir::Var. |
832 | * \param use_unique_name Whether to generate a new unique name where a name conflicts. |
833 | */ |
834 | void CreateIOVar(const Type& type, const std::string& original_name, |
835 | bool use_unique_name = true) { |
836 | if (type->IsInstance<TupleTypeNode>()) { |
837 | TupleType tuple_type = Downcast<TupleType>(type); |
838 | for (unsigned i = 0; i < tuple_type->fields.size(); i++) { |
839 | CreateIOVar(tuple_type->fields[i], original_name); |
840 | } |
841 | } else { |
842 | std::string name = original_name; |
843 | if (use_unique_name) { |
844 | name = GetUniqueIOVarName(original_name); |
845 | } |
846 | tir::Var var = tir::Var(name, DataType::Handle()); |
847 | main_signature_.push_back(var); |
848 | auto tensor_type = type.as<TensorTypeNode>(); |
849 | ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey(); |
850 | DataType elem_type = tensor_type->dtype; |
851 | tir::Var buffer_var = |
852 | tir::Var(name + "_buffer_var" , PointerType(PrimType(elem_type), "global" )); |
853 | tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0, |
854 | name + "_buffer" , 16, 1, tir::BufferType::kDefault); |
855 | main_buffer_map_.Set(var, buffer); |
856 | io_tensor_types_.Set(var, Downcast<TensorType>(type)); |
857 | } |
858 | } |
859 | |
860 | /*! |
861 | * \brief Create a unique name for I/O Var |
862 | */ |
863 | std::string GetUniqueIOVarName(std::string name) { |
864 | if (io_var_names_.find(name) == io_var_names_.end()) { |
865 | io_var_names_[name] = 1; |
866 | return name; |
867 | } else { |
868 | io_var_names_[name] = io_var_names_[name] + 1; |
869 | return name + std::to_string(io_var_names_[name]); |
870 | } |
871 | } |
872 | |
873 | /*! |
874 | * \brief Calculate workspace sizes for PrimFuncs in the IRModule |
875 | */ |
876 | Map<String, FunctionInfo> CalculateWorkspaceSizes( |
877 | const IRModule& lowered_mod, const Map<String, FunctionInfo>& function_metadata) { |
878 | Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(lowered_mod); |
879 | Map<String, FunctionInfo> updated_function_metadata; |
880 | for (const auto& kv : lowered_mod->functions) { |
881 | GlobalVar global_var = kv.first; |
882 | BaseFunc base_func = kv.second; |
883 | if (base_func->IsInstance<tir::PrimFuncNode>()) { |
884 | tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(base_func); |
885 | Target tgt = pfunc->GetAttr<Target>(tvm::attr::kTarget).value(); |
886 | const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment); |
887 | if (function_metadata.count(global_var->name_hint)) { |
888 | updated_function_metadata.Set(global_var->name_hint, |
889 | function_metadata[global_var->name_hint]); |
890 | updated_function_metadata[global_var->name_hint]->workspace_sizes.Set(tgt, ws); |
891 | } else { |
892 | FunctionInfo finfo{{{tgt, ws}}, {}, {}, {{tgt, pfunc}}, {}}; |
893 | updated_function_metadata.Set(global_var->name_hint, finfo); |
894 | } |
895 | } |
896 | } |
897 | return updated_function_metadata; |
898 | } |
899 | |
900 | /*! |
901 | * \brief Run USMP to plan memory for lowered IRModule. |
902 | */ |
903 | IRModule PlanMemoryWithUSMP(const IRModule& mod) { |
904 | VLOG(1) << "Planning memory with USMP for module:" << std::endl << PrettyPrint(mod); |
905 | Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); |
906 | IRModule lowered_mod = mod->ShallowCopy(); |
907 | lowered_mod = tir::transform::UnifiedStaticMemoryPlanner()(lowered_mod); |
908 | function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_); |
909 | Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos = |
910 | lowered_mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs); |
911 | backend::FunctionInfo main_func_info = |
912 | lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info" ).value(); |
913 | main_func_info->workspace_sizes.clear(); |
914 | if (allocated_pool_infos) { |
915 | for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { |
916 | for (const auto& tgt : allocated_pool_info->pool_info->targets) { |
917 | VLOG(1) << "USMP requires target " << tgt->ToDebugString() << " to have pool size " |
918 | << allocated_pool_info->allocated_size->value; |
919 | size_t size = allocated_pool_info->allocated_size->value; |
920 | if (allocated_pool_info->pool_info->IsInstance<ConstantPoolInfoNode>()) { |
921 | size += main_func_info->constant_sizes.count(tgt) |
922 | ? main_func_info->constant_sizes[tgt]->value |
923 | : 0; |
924 | main_func_info->constant_sizes.Set(tgt, size); |
925 | } else if (allocated_pool_info->pool_info->IsInstance<WorkspacePoolInfoNode>()) { |
926 | size += main_func_info->workspace_sizes.count(tgt) |
927 | ? main_func_info->workspace_sizes[tgt]->value |
928 | : 0; |
929 | main_func_info->workspace_sizes.Set(tgt, size); |
930 | } else { |
931 | LOG(FATAL) << "Unknown pool type: " << allocated_pool_info->pool_info->GetTypeKey(); |
932 | } |
933 | } |
934 | } |
935 | } |
936 | function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info); |
937 | return lowered_mod; |
938 | } |
939 | |
940 | /*! |
941 | * \brief Run StorageRewrite to plan memory for lowered IRModule. |
942 | */ |
943 | IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) { |
944 | Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod); |
945 | IRModule lowered_mod = mod->ShallowCopy(); |
946 | function_metadata_ = CalculateWorkspaceSizes(lowered_mod, function_metadata_); |
947 | // Running StorageRewrite just on the main function |
948 | tir::PrimFunc tir_main_func = |
949 | Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); |
950 | IRModule main_func_mod; |
951 | main_func_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), |
952 | tir_main_func); |
953 | main_func_mod = tir::transform::StorageRewrite()(main_func_mod); |
954 | lowered_mod->Update(lowered_mod->GetGlobalVar(::tvm::runtime::symbol::tvm_module_main), |
955 | main_func_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); |
956 | tir_main_func = |
957 | Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main)); |
958 | // Use the PrimFunc to calculate the workspace required to service the allocates |
959 | Integer main_workspace_size_bytes = |
960 | CalculateWorkspaceBytes(tir_main_func, workspace_byte_alignment); |
961 | backend::FunctionInfo main_func_info = |
962 | lowered_mod->GetAttr<backend::FunctionInfo>("main_func_info" ).value(); |
963 | main_func_info->workspace_sizes.Set(config_->host_target, main_workspace_size_bytes); |
964 | function_metadata_.Set(runtime::symbol::tvm_module_main, main_func_info); |
965 | return lowered_mod; |
966 | } |
967 | |
968 | /*! |
969 | * \brief Gets module workspace alignment from supplied executor or defaults to 16 |
970 | */ |
971 | Integer GetModuleWorkspaceByteAlignment(const IRModule& mod) { |
972 | Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value(); |
973 | return executor_config->GetAttr<Integer>("workspace-byte-alignment" ).value_or(16); |
974 | } |
975 | |
976 | /*! |
977 | * \brief Gets module constant alignment from supplied executor or defaults to 16 |
978 | */ |
979 | Integer GetModuleConstantByteAlignment(const IRModule& mod) { |
980 | Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value(); |
981 | return executor_config->GetAttr<Integer>("constant-byte-alignment" ).value_or(16); |
982 | } |
983 | |
 protected:
  /*! \brief the runtime module being built (not owned by this codegen) */
  runtime::Module* mod_;
  /*! \brief list of input expressions (i.e., variable passed by the user) */
  std::vector<Var> input_vars_;
  /*! \brief map of device contexts variables */
  Map<String, tir::Var> devices_;
  /*! \brief map of GlobalVars to C Device API contexts */
  Map<GlobalVar, tir::Var> device_contexts_;
  /*! \brief input and output variables belonging to the main function signature */
  Array<tir::Var> main_signature_;
  /*! \brief buffers backing the input and output variables of the main function signature */
  Map<tir::Var, tir::Buffer> main_buffer_map_;
  /*! \brief maps input and output variables to TensorType which describe them */
  Map<tir::Var, TensorType> io_tensor_types_;
  /*! \brief All available targets. */
  CompilationConfig config_;
  /*!
   * \brief The type of kernel call to be emitted.
   * See CallType for more documentation.
   */
  CallType call_type_;

  /*!
   * \brief parameters (i.e. ConstantNodes found in the graph).
   * These are taken as inputs to the GraphRuntime.
   * Maps param name to a pair of storage_id and NDArray. At runtime, the storage_id can be
   * used to lookup the parameter.
   */
  std::unordered_map<std::string, runtime::NDArray> params_;
  /*! \brief mapping between expression and parameters */
  Map<Expr, String> params_by_expr_;
  /*! \brief mapping between parameter names ("p0", "p1", etc..) and storage identifiers*/
  std::unordered_map<std::string, int64_t> param_storage_ids_;
  /*! \brief mapping between an allocated buffer var and the constant embedded in it */
  std::unordered_map<const tir::Var, const ConstantNode*, ObjectPtrHash, ObjectPtrEqual>
      constant_map_;

  /*! \brief plan memory of device result */
  StorageMap storage_device_map_;
  /*! \brief mapping sid -> tir::Var */
  std::unordered_map<int, tir::Var> sids_table_;
  /*! \brief lowered funcs */
  Map<String, FunctionInfo> function_metadata_;
  /*! \brief the set of statements that make the program */
  std::vector<tir::Stmt> stmts_;
  /*! \brief the list of return sids (note that the function might return more than one output) */
  std::vector<int> return_sid_;
  /*! \brief This is per IO var name counter to aid the generating unique names */
  std::unordered_map<std::string, int> io_var_names_;
  /*! \brief A set of variables that are let bound. */
  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_vars_;
1035 | |
1036 | public: |
  /*!
   * \brief Construct the codegen for a runtime module and a set of targets.
   * \param mod The runtime module being built (stored, not owned).
   * \param targets All available targets, used to build the CompilationConfig.
   */
  AOTExecutorCodegen(runtime::Module* mod, const Array<Target>& targets)
      : mod_(mod), config_(transform::PassContext::Current(), targets) {}
1039 | |
  /*!
   * \brief Compile the module into a TIR main function plus per-target lowered
   * functions for an AOT build.
   * \param mod The IRModule to compile; must carry kRuntime and kExecutor attributes.
   * \param func The relay main function (consulted for "output_tensor_names").
   * \param mod_name Module name, used to mangle the generated entrypoint symbol.
   * \return LoweredOutput holding params, per-target IRModules, external
   *         runtime modules, function metadata and executor metadata.
   */
  LoweredOutput Codegen(IRModule mod, relay::Function func, String mod_name) {
    VLOG_CONTEXT << "AOT";

    Runtime runtime_config = mod->GetAttr<Runtime>(tvm::attr::kRuntime).value();
    Integer workspace_byte_alignment = GetModuleWorkspaceByteAlignment(mod);

    Executor executor_config = mod->GetAttr<Executor>(tvm::attr::kExecutor).value();
    std::string interface_api =
        executor_config->GetAttr<String>("interface-api").value_or("packed");
    bool unpacked_api = executor_config->GetAttr<Bool>("unpacked-api").value_or(Bool(false));

    // Validate choice of unpacked_api and use_call_cpacked_
    if (runtime_config->name == kTvmRuntimeCrt) {
      if (unpacked_api == true) {
        call_type_ = CallType::kUnpacked;
      } else if (unpacked_api == false && interface_api == "packed") {
        call_type_ = CallType::kCPacked;
      } else {
        CHECK(interface_api == "packed" || unpacked_api == true)
            << "Either need interface_api == \"packed\" (got: " << interface_api
            << ") or unpacked-api == true (got: " << unpacked_api << ") when targeting c runtime";
        ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api
                      << ", unpacked-api=" << unpacked_api;
      }
    } else if (runtime_config->name == kTvmRuntimeCpp) {
      if (unpacked_api == false && interface_api == "packed") {
        call_type_ = CallType::kCPacked;
      } else {
        CHECK(static_cast<bool>(unpacked_api) == false && interface_api == "packed")
            << "Need unpacked-api == false (got: " << unpacked_api
            << ") and interface-api == \"packed\" (got: " << interface_api
            << ") when targeting c++ runtime";
        ICHECK(false) << "Unhandled executor option config: interface-api=" << interface_api
                      << ", unpacked-api=" << unpacked_api;
      }
    } else {
      ICHECK(false) << "runtime_config (" << runtime_config->name
                    << ") is not one of the expected values";
    }

    // Lowering below expects A-normal form with per-binding memory annotations.
    mod = transform::ToANormalForm()(mod);
    mod = transform::InferType()(mod);
    mod = transform::AnnotateUsedMemory()(mod);

    IRModule lowered_mod =
        tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) {
          // We need to maintain the constant map for external
          // functions so we pass this processing function which
          // allows us to process each function as we lower it.
          if (func->GetAttr<String>(attr::kCompiler).defined()) {
            UpdateConstants(func, &params_);
          }

          // TODO(@areusch, @jroesch): We should refactor this to
          // execute as a further pass, instead writing data to the
          // lowering process directly.
          tec::UpdateFunctionMetadata(func, this->function_metadata_, workspace_byte_alignment);
        })(mod);

    transform::PassContext pass_ctx = transform::PassContext::Current();
    bool enable_remove_reshapes =
        pass_ctx->GetConfig<Bool>("relay.remove_standalone_reshapes.enable", Bool(true)).value();
    if (enable_remove_reshapes) {
      lowered_mod = transform::RemoveStandaloneReshapes()(lowered_mod);
    }
    auto lowered_main = lowered_mod->Lookup("main");
    auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

    // Post-lowering storage map for writing main func
    AOTOnDemandAllocator final_aot_allocator;
    final_aot_allocator.Run(lowered_main_func);
    storage_device_map_ = final_aot_allocator.GetStorageMap();

    // TODO(@electriclilies, @jroesch, @Mousius): remove UpdateMainWorkspaceSize
    StaticMemoryPlan memory_plan(storage_device_map_);
    backend::FunctionInfo func_info =
        tec::UpdateMainWorkspaceSize(lowered_mod, config_, memory_plan->expr_to_storage_info);
    lowered_mod = WithAttr(lowered_mod, "main_func_info", func_info);

    // Register every main-function parameter as an input I/O var.
    for (auto input : lowered_main_func->params) {
      input_vars_.push_back(input);
      std::string input_name = SanitizeName(input->name_hint());
      // We dont want the compiler changing input names in the
      // event of a sanitization collision. Therefore, enforcing
      // the var created to use the input_name strictly.
      CreateIOVar(input, input_name, /*use_unique_name = */ false);
    }

    // Define the storage allocator ids
    for (auto kv : storage_device_map_) {
      for (auto sid : kv.second->storage_ids) {
        // The buffer_var is created with storage_scope to be global.workspace to be serviced by
        // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor
        // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and
        // should not be lowered to the stack. For more details please refer to the discussion here:
        // https://github.com/apache/tvm/issues/9022
        te::Var buffer_var(MakeString("sid_", sid),
                           PointerType(PrimType(DataType::Int(8)), "global.workspace"));
        sids_table_[sid] = buffer_var;
      }
    }

    // Retrieve the return sids
    return_sid_ = final_aot_allocator.GetReturnIds();
    // Insert outputs to main func signature
    // If output tensor names were provided use them
    if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
      Array<String> output_tensor_names = opt.value();
      Expr output_expr = lowered_main_func->body;
      if (output_expr->checked_type()->IsInstance<TupleTypeNode>()) {
        TupleType output_tuple_type = Downcast<TupleType>(output_expr->checked_type());
        for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) {
          // AoT Executor Codegen does not create these names,
          // thus should be used as they are provided.
          CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i],
                      /*use_unique_name = */ false);
        }
      } else {
        // AoT Executor Codegen does not create these names,
        // thus should be used as they are provided.
        CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false);
      }
    } else {
      // If output tensor names are not provided we will generate output(x)
      // where x is a counter to create unique names.
      CreateIOVar(lowered_main_func->body, "output");
    }

    CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts").value());
    // Walk the lowered main body, emitting TIR statements into stmts_.
    VisitExpr(lowered_main_func->body);

    // Create the runner function. Please note that the function is not legal yet
    // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
    // to run the LegalizePackedCalls pass.
    LoweredOutput ret;

    // Collect any constants extracted by external codegen.
    ret.params = std::unordered_map<std::string, tvm::runtime::NDArray>();
    Map<String, runtime::NDArray> const_name_to_constant =
        lowered_mod->GetAttr<Map<String, runtime::NDArray>>(tvm::attr::kConstNameToConstant)
            .value_or({});
    for (const auto& kv : const_name_to_constant) {
      // emplace returns false on duplicate names, which would mean two sources
      // produced the same constant name — treat that as a codegen error.
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }

    // Collect any constants extracted during lowering.
    for (const auto& kv : params_) {
      ICHECK(ret.params.emplace(kv.first, kv.second).second);
    }

    // AoT Executor codegen works completely on TIR beyond this point, hence removing relay main
    // function and replacing it with its TIR version. We should try to make this a Pass.
    lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));
    auto tir_main_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
    // Extract additional information around main TIR PrimFunc arguments
    Array<String> devices = ListDevices();
    // The main signature is laid out as [inputs..., outputs..., devices...].
    const auto main_func_params_end_iterator =
        tir_main_func->params.begin() + tir_main_func->params.size();
    const auto outputs_begin_iterator =
        main_func_params_end_iterator - return_sid_.size() - devices.size();
    Array<tir::Var> inputs = Array<tir::Var>(tir_main_func->params.begin(), outputs_begin_iterator);
    Array<TensorType> input_tensor_types;
    for (auto i : inputs) {
      input_tensor_types.push_back(io_tensor_types_[i]);
    }
    Array<tir::Var> outputs =
        Array<tir::Var>(outputs_begin_iterator, main_func_params_end_iterator - devices.size());

    lowered_mod->Update(GlobalVar(::tvm::runtime::symbol::tvm_module_main), tir_main_func);
    // Parallel for loops are not supported in AoT codegen.
    lowered_mod = tir::transform::ConvertForLoopsToSerial()(lowered_mod);

    bool enable_usmp = pass_ctx->GetConfig<Bool>(kUSMPEnableOption, Bool(false)).value();
    if (enable_usmp) {
      lowered_mod = PlanMemoryWithUSMP(lowered_mod);
    } else {
      lowered_mod = PlanMemoryWithStorageRewrite(lowered_mod);
    }
    ret.function_metadata = std::move(function_metadata_);

    // Legalize AOT if needed. This means that all the packed calls
    // need to be wrapped in TVMValues (unless unpacked_api is set)
    if (call_type_ == CallType::kCPacked || call_type_ == CallType::kPacked) {
      auto pack_calls = tir::transform::LegalizePackedCalls();
      lowered_mod = pack_calls(lowered_mod);
    }

    // Collect any runtime modules generated by external codegen.
    ret.external_mods =
        lowered_mod->GetAttr<Array<tvm::runtime::Module>>(tvm::attr::kExternalMods).value_or({});

    // This is the point where we separate the functions in the module by target
    VLOG(1) << "lowered module:" << std::endl << PrettyPrint(lowered_mod);
    ret.lowered_funcs = tec::GetPerTargetModules(lowered_mod);
    VLOG(1) << "per-target modules:";
    for (const auto& kv : ret.lowered_funcs) {
      VLOG(1) << "target:" << std::endl
              << kv.first->ToDebugString() << std::endl
              << "maps to:" << std::endl
              << PrettyPrint(kv.second);
    }

    // Extract USMP metadata to pass onto metadata sources
    Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_var_info;
    std::vector<tir::Var> pool_vars;
    tir_main_func =
        Downcast<tir::PrimFunc>(lowered_mod->Lookup(::tvm::runtime::symbol::tvm_module_main));
    Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
        tir_main_func->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
    if (allocated_pool_infos) {
      for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) {
        int pool_var_index = allocated_pool_info->pool_var_idx.value()->value;
        pool_vars.push_back(tir_main_func->params[pool_var_index]);
        pool_var_info.Set(tir_main_func->params[pool_var_index], allocated_pool_info);
      }
    }
    Map<String, tir::usmp::PoolAllocation> io_pool_allocations =
        lowered_mod
            ->GetAttr<Map<String, tir::usmp::PoolAllocation>>(tvm::attr::kIOTensorPoolAllocations)
            .value_or({});

    std::vector<String> output_var_names;
    if (auto opt = func->GetAttr<Array<String>>("output_tensor_names")) {
      Array<String> output_tensor_names = opt.value();
      for (size_t i = 0; i < output_tensor_names.size(); ++i) {
        output_var_names.push_back(output_tensor_names[i]);
      }
    }

    // If output names have not been specified then generate default output names
    if (output_var_names.size() == 0) {
      if (return_sid_.size() == 1) {
        output_var_names.push_back(String("output"));
      } else {
        for (size_t i = 0; i < return_sid_.size(); ++i) {
          output_var_names.push_back(String("output" + std::to_string(i)));
        }
      }
    }

    Array<TensorType> output_tensor_types{final_aot_allocator.GetReturnTtypes()};

    ret.metadata = ExecutorCodegenMetadata(
        inputs, input_tensor_types, output_var_names, output_tensor_types, pool_vars, devices,
        runtime::kTvmExecutorAot, mod_name, interface_api, unpacked_api,
        GetModuleWorkspaceByteAlignment(mod), GetModuleConstantByteAlignment(mod), pool_var_info,
        io_pool_allocations);
    return ret;
  }
1289 | |
1290 | /*! |
1291 | * \brief Get list of devices found |
1292 | * \return List of devices |
1293 | */ |
1294 | Array<String> ListDevices() { |
1295 | std::vector<String> device_names(devices_.size()); |
1296 | std::transform(devices_.begin(), devices_.end(), device_names.begin(), |
1297 | [](const auto& it) -> String { return it.first; }); |
1298 | return device_names; |
1299 | } |
};  // class AOTExecutorCodegen
1301 | |
1302 | class AOTExecutorCodegenModule : public runtime::ModuleNode { |
1303 | public: |
1304 | AOTExecutorCodegenModule() {} |
1305 | virtual PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self) { |
1306 | if (name == "init" ) { |
1307 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
1308 | ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " |
1309 | << "runtime::Module mod and Array<Target> targets" ; |
1310 | void* mod = args[0]; |
1311 | Array<Target> targets = args[1]; |
1312 | init(mod, targets); |
1313 | }); |
1314 | } else if (name == "codegen" ) { |
1315 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
1316 | IRModule mod = args[0]; |
1317 | Function func = args[1]; |
1318 | String mod_name = args[2]; |
1319 | this->output_ = this->codegen_->Codegen(mod, func, mod_name); |
1320 | }); |
1321 | } else if (name == "list_params_name" ) { |
1322 | return PackedFunc( |
1323 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = list_params_name(); }); |
1324 | } else if (name == "get_param_by_name" ) { |
1325 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
1326 | String key = args[0]; |
1327 | *rv = get_param_by_name(key); |
1328 | }); |
1329 | } else if (name == "get_irmodule" ) { |
1330 | return PackedFunc( |
1331 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_irmodule(); }); |
1332 | } else if (name == "get_external_modules" ) { |
1333 | return PackedFunc( |
1334 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = get_external_modules(); }); |
1335 | } else if (name == "get_function_metadata" ) { |
1336 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
1337 | *rv = this->output_.function_metadata; |
1338 | }); |
1339 | } else if (name == "get_devices" ) { |
1340 | return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { |
1341 | *rv = this->codegen_->ListDevices(); |
1342 | }); |
1343 | } else if (name == "get_executor_codegen_metadata" ) { |
1344 | return PackedFunc( |
1345 | [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; }); |
1346 | } else { |
1347 | return PackedFunc([](TVMArgs args, TVMRetValue* rv) {}); |
1348 | } |
1349 | } |
1350 | |
1351 | const char* type_key() const final { return "RelayGraphRuntimeCodegenModule" ; } |
1352 | |
1353 | private: |
1354 | void init(void* mod, const Array<Target>& targets) { |
1355 | codegen_ = |
1356 | std::make_shared<AOTExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod), targets); |
1357 | } |
1358 | |
1359 | Array<runtime::String> list_params_name() { |
1360 | Array<runtime::String> ret; |
1361 | for (const auto& kv : this->output_.params) { |
1362 | ret.push_back(kv.first); |
1363 | } |
1364 | return ret; |
1365 | } |
1366 | |
1367 | runtime::NDArray get_param_by_name(String key) { |
1368 | auto it = this->output_.params.find(key); |
1369 | CHECK(it != this->output_.params.end()) << "no such parameter " << key; |
1370 | return (*it).second; |
1371 | } |
1372 | |
1373 | Array<tvm::runtime::Module> get_external_modules() { return output_.external_mods; } |
1374 | |
1375 | Map<Target, IRModule> get_irmodule() { return this->output_.lowered_funcs; } |
1376 | |
1377 | std::shared_ptr<AOTExecutorCodegen> codegen_; |
1378 | LoweredOutput output_; |
1379 | }; |
1380 | |
1381 | runtime::Module CreateAOTExecutorCodegenMod() { |
1382 | auto ptr = make_object<AOTExecutorCodegenModule>(); |
1383 | return runtime::Module(ptr); |
1384 | } |
1385 | |
// Expose the factory to the FFI so the Python side can instantiate the
// AOT executor codegen module.
TVM_REGISTER_GLOBAL("relay.build_module._AOTExecutorCodegen")
    .set_body([](TVMArgs args, TVMRetValue* rv) { *rv = CreateAOTExecutorCodegenMod(); });
1388 | |
1389 | } // namespace backend |
1390 | } // namespace relay |
1391 | } // namespace tvm |
1392 | |