1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/aot/aot_lower_main.cc |
22 | * \brief Lower the Relay main func into an AOT TIR main func. |
23 | */ |
24 | #include "./aot_lower_main.h" |
25 | |
26 | #include <tvm/runtime/name_transforms.h> |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
30 | #include "../../op/call/call.h" |
31 | #include "../../op/memory/device_copy.h" |
32 | #include "../../op/memory/memory.h" |
33 | #include "../../transforms/device_aware_visitors.h" |
34 | #include "../name_transforms.h" |
35 | #include "../utils.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | namespace backend { |
40 | namespace aot { |
41 | |
/*!
 * \brief Looks at the expressions in a given function and produces an Expr to
 * StorageInfo map by assigning one or more StorageInfos to the expressions that
 * require storage.
 *
 * This pass is leveraged by AOTMainLowerer to perform an initial naive allocation
 * for tensors in the Relay main function. The resulting storage map is then lowered
 * into TIR allocations by AOTMainLowerer where the allocation can be subsequently
 * optimized by later passes (e.g. USMP).
 */
class ExprAllocator : public transform::DeviceAwareExprVisitor {
 public:
  // No IRModule is supplied; the visitor only ever walks one function at a time.
  ExprAllocator() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}

  // run the visitor on a global function.
  void Run(const Function& func) { VisitExpr(func); }

  /*! \brief The storage IDs assigned to the value(s) the function returns. */
  std::vector<int> GetReturnSIDs() const { return return_sids_; }

  /*! \brief The Expr -> StorageInfo map built during Run(). */
  StorageMap GetStorageMap() const { return expr_storage_map_; }

  using ExprVisitor::VisitExpr_;

  /*!
   * \brief Creates storage for the result of a call and visits its arguments.
   * Handles lowered calls (call_lowered), lowered extern functions referenced by
   * GlobalVar, and Relay functions that have not been lowered yet.
   */
  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
    Array<Expr> args;

    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
    if (call_lowered_props.lowered_func.defined()) {
      args = call_lowered_props.arguments;
    } else {  // Relay functions that have not been lowered and lowered extern functions
      args = call_node->args;
      if (call_node->op.as<GlobalVarNode>()) {  // Lowered extern function
        ICHECK(!(call_node->attrs.defined())) << "Extern functions should have null attributes.";
      } else {  // Relay function which has not been lowered yet
        ICHECK(call_node->op.as<FunctionNode>())
            << "Expected the call to be to a lowered primfunc, a lowered extern function or a "
               "unlowered Relay function.";
      }
    }
    CreateStorage(call_node);
    for (const Expr& arg : args) {
      VisitExpr(arg);
    }
    // The most recently visited expression becomes the candidate return value.
    AssignReturnSID(GetRef<Expr>(call_node));
  }

  /*!
   * \brief Creates storage for the top-level function's parameters and visits
   * its body. Nested function definitions are not entered.
   */
  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
    if (function_nesting() > 1) {
      // Do not recurse into sub functions.
      return;
    }
    for (const auto& param : func_node->params) {
      CreateStorage(param.get());
    }
    VisitExpr(func_node->body);
  }

  /*!
   * \brief A let-bound var shares the storage of its bound value, so the value
   * is visited first and its StorageInfo is aliased onto the var.
   */
  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
    VisitExpr(value);
    StorageInfo si = GetStorage(value);
    expr_storage_map_[var] = si;
  }

  /*! \brief Constants get their own storage and may be the return value. */
  void VisitExpr_(const ConstantNode* op) final {
    CreateStorage(op);
    AssignReturnSID(GetRef<Expr>(op));
  }

  // Vars already received storage when their binding site was visited
  // (function param or let binding); only the return SID may need updating.
  void VisitExpr_(const VarNode* op) final { AssignReturnSID(GetRef<Expr>(op)); }

  /*!
   * \brief A tuple's StorageInfo is the concatenation of its fields' infos, in
   * field order — tuples do not get storage of their own.
   */
  void VisitExpr_(const TupleNode* op) final {
    std::vector<int64_t> storage_ids;
    std::vector<VirtualDevice> virtual_devices;
    std::vector<int64_t> storage_sizes_in_bytes;
    Expr expr = GetRef<Expr>(op);
    for (Expr field : op->fields) {
      auto sid = GetStorage(field);
      storage_ids.insert(storage_ids.end(), sid->storage_ids.begin(), sid->storage_ids.end());
      virtual_devices.insert(virtual_devices.end(), sid->virtual_devices.begin(),
                             sid->virtual_devices.end());
      storage_sizes_in_bytes.insert(storage_sizes_in_bytes.end(),
                                    sid->storage_sizes_in_bytes.begin(),
                                    sid->storage_sizes_in_bytes.end());
    }
    expr_storage_map_[expr] = StorageInfo(storage_ids, virtual_devices, storage_sizes_in_bytes);
    AssignReturnSID(expr);
  }

  /*!
   * \brief Projecting a tuple selects the index-th entry of the tuple's
   * (flattened) StorageInfo — no new storage is created.
   */
  void VisitExpr_(const TupleGetItemNode* op) final {
    Expr expr = GetRef<Expr>(op);
    auto sids = GetStorage(op->tuple);
    ICHECK_LT(static_cast<size_t>(op->index), sids->storage_ids.size());
    expr_storage_map_[expr] =
        StorageInfo({sids->storage_ids[op->index]}, {sids->virtual_devices[op->index]},
                    {sids->storage_sizes_in_bytes[op->index]});
    AssignReturnSID(expr);
  }

  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "'If' is not supported."; }

 private:
  /*!
   * \brief Assign the expression's storage IDs as the return storage IDs.
   * \note This is called when visiting every expression on the understanding
   * that the returned expression will be visited last.
   */
  void AssignReturnSID(const Expr& e) {
    if (expr_storage_map_.find(e) != expr_storage_map_.end()) {
      StorageInfo& sinfo = expr_storage_map_[e];
      return_sids_.clear();
      for (auto sid : sinfo->storage_ids) {
        return_sids_.push_back(sid);
      }
    }
  }

  /*!
   * \brief Get the necessary storage for the expression.
   * \param expr The expression.
   * \return The corresponding token.
   */
  StorageInfo GetStorage(const Expr& expr) {
    // See through "on_device" calls.
    Expr true_expr = IgnoreOnDevice(expr);
    VisitExpr(true_expr);
    auto it = expr_storage_map_.find(true_expr);
    ICHECK(it != expr_storage_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " "
                                          << PrettyPrint(true_expr) << " in storage device map";
    return it->second;
  }

  /*!
   * \brief Create storage for the expression.
   */
  void CreateStorage(const ExprNode* op) {
    Expr expr = GetRef<Expr>(op);
    return CreateStorage(expr, GetVirtualDevice(expr));
  }

  /*!
   * \brief Create storage to hold the result of evaluating \p expr in \p virtual_device.
   * One storage ID is created per tensor in the (flattened) type of \p expr.
   */
  void CreateStorage(const Expr& expr, const VirtualDevice& virtual_device) {
    ICHECK(!virtual_device->IsFullyUnconstrained())
        << "invalid virtual device for expr:" << std::endl
        << PrettyPrint(expr);
    std::vector<int64_t> storage_ids;
    std::vector<VirtualDevice> virtual_devices;
    std::vector<int64_t> storage_sizes_in_bytes;
    for (const auto& ttype : FlattenTupleType(expr->checked_type())) {
      storage_ids.push_back(next_available_sid_++);
      virtual_devices.push_back(virtual_device);
      storage_sizes_in_bytes.push_back(GetMemorySizeBytes(ttype->shape, ttype->dtype));
    }
    expr_storage_map_[expr] = StorageInfo(std::move(storage_ids), std::move(virtual_devices),
                                          std::move(storage_sizes_in_bytes));
  }

  /*! \brief Map between Exprs and StorageInfos */
  StorageMap expr_storage_map_;
  /*! \brief The next available storage ID to be used */
  int next_available_sid_{0};
  /*! \brief The storage IDs that correspond to return values */
  std::vector<int> return_sids_;
};
207 | |
208 | std::tuple<StorageMap, std::vector<int>> CreateStorage(const Function& func) { |
209 | ExprAllocator expr_allocator; |
210 | expr_allocator.Run(func); |
211 | return std::make_tuple(expr_allocator.GetStorageMap(), expr_allocator.GetReturnSIDs()); |
212 | } |
213 | |
/*!
 * \brief Lowers the Relay "main" function into an AOT TIR main function which
 * invokes the lowered operators directly using the configured calling convention.
 */
class AOTMainLowerer : public MixedModeVisitor {
 public:
  AOTMainLowerer(tvm::CompilationConfig config, CallType call_type)
      : config_(config), call_type_(call_type) {}
218 | |
  /*!
   * \brief Lower the Relay "main" of \p mod into a TIR main function.
   * \param mod The module whose "main" is to be lowered.
   * \param mod_name The module name, used when mangling the TIR main's symbol.
   * \return The module with the Relay main removed and the TIR main added.
   */
  IRModule Lower(IRModule mod, String mod_name) {
    VLOG_CONTEXT << "AOT";
    IRModule lowered_mod = GetRef<IRModule>(mod.CopyOnWrite());

    auto lowered_main = lowered_mod->Lookup("main");
    auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

    // Assign StorageInfo to all the Relay exprs and get the return SIDs
    std::tie(expr_storage_map_, return_sid_) = CreateStorage(lowered_main_func);

    for (auto input : lowered_main_func->params) {
      input_vars_.push_back(input);
      std::string input_name = tvm::runtime::SanitizeName(input->name_hint());
      // We don't want the compiler changing input names in the
      // event of a sanitization collision. Therefore, enforcing
      // the var created to use the input_name strictly.
      CreateIOVar(input, input_name, /*use_unique_name = */ false);
    }

    // Define the storage allocator ids
    for (auto kv : expr_storage_map_) {
      for (auto sid : kv.second->storage_ids) {
        // The buffer_var is created with storage_scope to be global.workspace to be serviced by
        // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor
        // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and
        // should not be lowered to the stack. For more details please refer to the discussion here:
        // https://github.com/apache/tvm/issues/9022
        tir::Var buffer_var(MakeString("sid_", sid),
                            PointerType(PrimType(DataType::Int(8)), "global.workspace"));
        sids_table_[sid] = buffer_var;
      }
    }

    // Create output vars for the TIR main func
    // If output tensor names were provided use them
    if (auto opt = lowered_main->GetAttr<Array<String>>("output_tensor_names")) {
      Array<String> output_tensor_names = opt.value();
      Expr output_expr = lowered_main_func->body;
      if (output_expr->checked_type()->IsInstance<TupleTypeNode>()) {
        TupleType output_tuple_type = Downcast<TupleType>(output_expr->checked_type());
        for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) {
          // AoT Executor Codegen does not create these names,
          // thus should be used as they are provided.
          CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i],
                      /*use_unique_name = */ false);
        }
      } else {
        // AoT Executor Codegen does not create these names,
        // thus should be used as they are provided.
        CreateIOVar(lowered_main_func->body, output_tensor_names[0], /*use_unique_name = */ false);
      }
    } else {
      // If output tensor names are not provided we will generate output(x)
      // where x is a counter to create unique names.
      if (lowered_main_func->body->checked_type()->IsInstance<TupleTypeNode>()) {
        CreateIOVar(lowered_main_func->body, "output");
      } else {
        CreateIOVar(lowered_main_func->body, "output", /*use_unique_name = */ false);
      }
    }

    CollectDeviceVariables(lowered_mod->GetAttr<Map<GlobalVar, String>>("device_contexts")
                               .value_or(Map<GlobalVar, String>()));
    VisitExpr(lowered_main_func->body);

    // Remove the Relay main and replace it with the lowered TIR version
    lowered_mod->Remove(lowered_mod->GetGlobalVar("main"));
    auto tir_main_func = CreateMainFunc(mod_name);
    lowered_mod->Update(GlobalVar(runtime::symbol::tvm_module_main), tir_main_func);
    lowered_mod = tir::transform::RemoveNoOp()(lowered_mod);
    return lowered_mod;
  }
291 | |
292 | void VisitExpr_(const CallNode* call_node) override { |
293 | OnDeviceProps on_device_props = GetOnDeviceProps(call_node); |
294 | if (on_device_props.body.defined()) { |
295 | VisitExpr(on_device_props.body); |
296 | return; |
297 | } |
298 | |
299 | DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); |
300 | CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); |
301 | |
302 | if (device_copy_props.body.defined()) { |
303 | // TODO(mbs): device_copy cleaunp |
304 | // Suspect treating as no-op is better since already built into the StorageInfo? |
305 | LOG(FATAL) << "The AOT executor does not currently support device_copy" ; |
306 | } |
307 | |
308 | // At this point we should only see calls of the form call_lowered(@callee, (args...)), |
309 | // where @callee can be a PrimFunc we've compiled or an external function supplied via |
310 | // some other mechanism. |
311 | ICHECK(call_lowered_props.lowered_func.defined()) |
312 | << "AOT does not support calling Relay functions. Attempting to call:" << std::endl |
313 | << PrettyPrint(GetRef<Call>(call_node)); |
314 | for (const auto& arg : call_lowered_props.arguments) { |
315 | // Evaluate the args |
316 | VisitExpr(arg); |
317 | } |
318 | CreateFuncCall(call_lowered_props, GetRef<Call>(call_node)); |
319 | } |
320 | |
321 | void VisitExpr_(const VarNode* op) override { |
322 | Expr expr = GetRef<Expr>(op); |
323 | StorageInfo& sinfo = expr_storage_map_[expr]; |
324 | |
325 | // Let bound vars refer to a value, so these should not be considered "output" vars. |
326 | if (let_bound_vars_.find(GetRef<Var>(op)) != let_bound_vars_.end()) { |
327 | return; |
328 | } |
329 | |
330 | // If the Var node is an output node we need to copy the content of the variable to the output |
331 | // It's safe to check the SID here because Var StorageToken are never reallocated |
332 | auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); |
333 | if (output_iter != return_sid_.end()) { |
334 | int output_index = std::distance(return_sid_.begin(), output_iter); |
335 | auto var_expr = FindExpr(expr); |
336 | CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0], |
337 | /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); |
338 | } |
339 | } |
340 | |
341 | void VisitExpr_(const ConstantNode* op) override { |
342 | Expr expr = GetRef<Expr>(op); |
343 | ICHECK(expr_storage_map_.find(expr) != expr_storage_map_.end()) |
344 | << "Storage map did not contain constant expr " << PrettyPrint(expr); |
345 | StorageInfo& sinfo = expr_storage_map_[expr]; |
346 | std::stringstream ss; |
347 | ss << "constant_" << constant_map_.size(); |
348 | |
349 | tir::Var constant(ss.str(), PointerType(PrimType(DataType(op->data->dtype)))); |
350 | constant_map_[constant] = op; |
351 | auto sid = sinfo->storage_ids[0]; |
352 | sids_table_[sid] = constant; |
353 | |
354 | // If the Constant node is an output node we need to copy the content of the parameter to the |
355 | // output. A node can only produce a single output |
356 | auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); |
357 | if (output_iter != return_sid_.end()) { |
358 | int output_index = std::distance(return_sid_.begin(), output_iter); |
359 | auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), |
360 | {tir::StringImm(ss.str())}); |
361 | CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), constant, |
362 | /* pack_input */ false, sinfo->storage_sizes_in_bytes[0]); |
363 | } |
364 | } |
365 | |
366 | void VisitExpr_(const TupleNode* op) override { |
367 | for (auto field : op->fields) { |
368 | VisitExpr(field); |
369 | } |
370 | } |
371 | |
  /*!
   * \brief Visits a chain of let bindings iteratively via ExpandANormalForm,
   * recording each bound var so VisitExpr_(VarNode*) can skip it as an output.
   */
  void VisitExpr_(const LetNode* op) override {
    auto pre_visit = [this](const LetNode* op) {
      let_bound_vars_.insert(op->var);
      this->VisitExpr(op->value);
    };
    auto post_visit = [this](const LetNode* op) {
      this->VisitExpr(op->body);
      // Mark this node visited so MixedModeVisitor does not revisit it.
      this->visit_counter_[op] += 1;
    };
    ExpandANormalForm(op, pre_visit, post_visit);
  }
383 | |
  // Projection needs no code of its own; only the tuple producer is lowered.
  void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); }
  // At this stage only the call_lowered and on_device operators may remain.
  void VisitExpr_(const OpNode* op) override {
    if (GetRef<Op>(op) != CallLoweredOp() && GetRef<Op>(op) != OnDeviceOp()) {
      LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded";
    }
  }
390 | void VisitExpr_(const IfNode* op) override { |
391 | LOG(FATAL) << "All GlobalVarNodes should be removed before AOT executor's Codegen is called" ; |
392 | } |
  // Inline FunctionNodes are only legal when destined for an external codegen.
  void VisitExpr_(const FunctionNode* op) override {
    ICHECK(op->GetAttr<String>(attr::kCompiler).defined())
        << "FunctionNode only supported by custom codegen";
  }
  // References and ADTs are not representable in the AOT TIR main; fail loudly.
  void VisitExpr_(const RefCreateNode* op) override {
    LOG(FATAL) << "AOT executor does not support references (found RefCreateNode)";
  }
  void VisitExpr_(const RefReadNode* op) override {
    LOG(FATAL) << "AOT executor does not support references (found RefReadNode)";
  }
  void VisitExpr_(const RefWriteNode* op) override {
    LOG(FATAL) << "AOT executor does not support references (found RefWriteNode)";
  }
  void VisitExpr_(const ConstructorNode* op) override {
    LOG(FATAL) << "AOT executor does not support ADTs (found ConstructorNode)";
  }
  void VisitExpr_(const MatchNode* op) override {
    LOG(FATAL) << "AOT executor does not support matching (found MatchNode)";
  }
412 | |
413 | private: |
414 | /*! |
415 | * \brief Create the main PrimFunc to execute the graph. |
416 | * \note The packed function calls don't pack their arguments. The AOT |
417 | * runner function needs to be legalized by the LegalizePackedCalls pass. |
418 | */ |
419 | tir::PrimFunc CreateMainFunc(String mod_name) { |
420 | tir::Stmt body = tir::SeqStmt(stmts_); |
421 | // Allocate the sids |
422 | std::unordered_map<int, bool> allocated; |
423 | std::vector<std::pair<int64_t, int64_t>> sids_to_allocate; |
424 | |
425 | for (auto kv : expr_storage_map_) { |
426 | // Only allocate sids that are needed |
427 | const bool is_input = |
428 | (std::find(input_vars_.begin(), input_vars_.end(), kv.first) != input_vars_.end()); |
429 | if (is_input) { |
430 | continue; |
431 | } |
432 | |
433 | for (unsigned int i = 0; i < kv.second->storage_ids.size(); i++) { |
434 | sids_to_allocate.push_back( |
435 | std::make_pair(kv.second->storage_ids[i], kv.second->storage_sizes_in_bytes[i])); |
436 | } |
437 | } |
438 | |
439 | // Sort the SID allocation to make output deterministic |
440 | std::sort(sids_to_allocate.begin(), sids_to_allocate.end()); |
441 | |
442 | for (auto p : sids_to_allocate) { |
443 | int sid = p.first; |
444 | int size = p.second; |
445 | |
446 | if (std::find(return_sid_.begin(), return_sid_.end(), sid) != return_sid_.end()) { |
447 | continue; |
448 | } |
449 | |
450 | // Make sure it hasn't already been allocated, this can happen |
451 | // with let-bound var/value pairs. |
452 | if (allocated.find(sid) != allocated.end()) { |
453 | continue; |
454 | } |
455 | |
456 | allocated[sid] = constant_map_.count(sids_table_[sid]); |
457 | |
458 | // TODO(giuseros): we should allocate this once outside the PrimFunc |
459 | // so we don't pay the price of allocation for every inference |
460 | if (!allocated[sid]) { |
461 | PointerType ptype = Downcast<PointerType>(sids_table_[sid]->type_annotation); |
462 | DataType element_type = Downcast<PrimType>(ptype->element_type)->dtype; |
463 | body = tir::Allocate(sids_table_[sid], element_type, {size}, tir::const_true(), body); |
464 | } |
465 | allocated[sid] = true; |
466 | } |
467 | |
468 | for (auto kv : constant_map_) { |
469 | auto buffer_var = kv.first; |
470 | auto dtype = DataType(kv.second->data->dtype); |
471 | |
472 | int ndim = kv.second->data->ndim; |
473 | Array<PrimExpr> extents; |
474 | |
475 | for (int i = 0; i < ndim; i++) { |
476 | int shape = kv.second->data->shape[i]; |
477 | extents.push_back(tir::make_const(DataType::Int(32), shape, Span())); |
478 | } |
479 | body = tir::AllocateConst(buffer_var, dtype, extents, kv.second->data, body); |
480 | } |
481 | |
482 | // Define the PrimFunc attributes |
483 | Map<String, ObjectRef> dict_attrs; |
484 | String run_func_name = runtime::get_name_mangled(mod_name, runtime::symbol::tvm_module_main); |
485 | dict_attrs.Set("global_symbol" , run_func_name); |
486 | dict_attrs.Set("runner_function" , Bool(true)); |
487 | dict_attrs.Set(tvm::attr::kTarget, config_->host_target); |
488 | Array<tir::Var> input_vars = |
489 | Array<tir::Var>(main_signature_.begin(), main_signature_.begin() + input_vars_.size()); |
490 | dict_attrs.Set("input_vars" , input_vars); |
491 | Array<tir::Var> output_vars = |
492 | Array<tir::Var>(main_signature_.begin() + input_vars_.size(), |
493 | main_signature_.begin() + input_vars_.size() + return_sid_.size()); |
494 | dict_attrs.Set("output_vars" , output_vars); |
495 | Array<String> device_names; |
496 | for (const auto& it : devices_) { |
497 | device_names.push_back(it.first); |
498 | } |
499 | dict_attrs.Set("devices" , device_names); |
500 | |
501 | tir::Stmt device_activations = GenerateAllDeviceHook("Activate" ); |
502 | tir::Stmt device_deactivations = GenerateAllDeviceHook("Deactivate" ); |
503 | tir::Stmt final_body = tir::SeqStmt({device_activations, body, device_deactivations}); |
504 | |
505 | // Make the PrimFunc |
506 | return tir::PrimFunc(main_signature_, final_body, VoidType(), main_buffer_map_, |
507 | DictAttrs(dict_attrs)); |
508 | } |
509 | |
  /*!
   * \brief Collects device context variables for passing to operators
   * \param device_contexts Map from an operator's GlobalVar to the target-kind
   * name of the device context it must be invoked with.
   */
  void CollectDeviceVariables(const Map<GlobalVar, String>& device_contexts) {
    Map<TargetKind, tir::Var> target_contexts;
    TargetKindAttrMap<Bool> target_attr_map = tvm::TargetKind::GetAttrMap<Bool>("use_device_api");

    for (const auto& it : device_contexts) {
      const GlobalVar& global_var = it.first;
      const std::string device_context_name = it.second;

      Optional<TargetKind> target_kind = tvm::TargetKind::Get(device_context_name);
      if (!target_kind || !target_attr_map.count(target_kind.value())) {
        // NOTE(review): this abandons ALL remaining entries rather than
        // skipping only this one (`return`, not `continue`) — confirm intended.
        return;
      }
      if (target_attr_map[target_kind.value()]) {
        std::string context_name = tvm::runtime::SanitizeName(device_context_name);
        tir::Var device_context_var("device_context_" + context_name, DataType::Handle());

        // One context var is shared by all operators of the same target kind.
        auto pair = target_contexts.find(target_kind.value());
        if (pair != target_contexts.end()) {
          device_context_var = (*pair).second;
        } else {
          main_signature_.push_back(device_context_var);
          devices_.Set(context_name, device_context_var);
          target_contexts.Set(target_kind.value(), device_context_var);
        }

        device_contexts_.Set(global_var, device_context_var);
      }
    }
  }
542 | |
543 | /*! |
544 | * \brief Return a vector of variables that represents the sids for the given Relay Expr |
545 | */ |
546 | std::vector<tir::Var> PackSid(Expr expr) { |
547 | std::vector<tir::Var> buffer_vars; |
548 | |
549 | ICHECK(expr_storage_map_.find(expr) != expr_storage_map_.end()) |
550 | << "Storage map did not contain constant expr " << PrettyPrint(expr); |
551 | StorageInfo& sinfo = expr_storage_map_[expr]; |
552 | |
553 | // Note that an expression can have multiple sids associated with it |
554 | // e.g., returning multiple values from a function |
555 | for (auto sid : sinfo->storage_ids) { |
556 | // Determine if an sid is an output buffer |
557 | auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sid); |
558 | if (output_iter != return_sid_.end()) { |
559 | int output_index = std::distance(return_sid_.begin(), output_iter); |
560 | buffer_vars.push_back(GetBufferVarForIO(input_vars_.size() + output_index)); |
561 | continue; |
562 | } |
563 | |
564 | auto sid_value = sids_table_[sid]; |
565 | buffer_vars.push_back(sid_value); |
566 | } |
567 | return buffer_vars; |
568 | } |
569 | |
570 | /*! |
571 | * \brief Given an expression return the variable(s) associated with that expression |
572 | */ |
573 | std::vector<tir::Var> FindExpr(Expr arg) { |
574 | auto input_iter = std::find(input_vars_.begin(), input_vars_.end(), arg); |
575 | if (input_iter != input_vars_.end()) { |
576 | // Input variable |
577 | int main_index = std::distance(input_vars_.begin(), input_iter); |
578 | return {GetBufferVarForIO(main_index)}; |
579 | } else { |
580 | // Storage identifier (i.e., intermediate memory) |
581 | return PackSid(arg); |
582 | } |
583 | } |
584 | |
585 | void PushArgs(const Expr& expr, const std::vector<tir::Var>& sids, Array<PrimExpr>* args) { |
586 | const TupleNode* t = expr.as<TupleNode>(); |
587 | if (t != nullptr) { |
588 | CHECK_EQ(sids.size(), t->fields.size()) << "Relay tuple does not map 1:1 into TIR; AOT can't " |
589 | "handle this type of Relay Expr in a CallNode." ; |
590 | } |
591 | |
592 | args->insert(args->end(), sids.begin(), sids.end()); |
593 | } |
594 | |
595 | /*! |
596 | * \brief Wraps a call_extern with a tvm_check_return annotation if required otherwise |
597 | * returns the passed Call |
598 | */ |
599 | tir::Call AddCheckReturn(tir::Call existing_call) { |
600 | Array<PrimExpr> args = {tir::make_const(DataType::Int(32, 1), 0, Span()), |
601 | tir::make_const(DataType::Int(32, 1), -1, Span()), existing_call}; |
602 | return tir::Call(DataType::Int(32), tir::builtin::tvm_check_return(), args); |
603 | } |
604 | |
605 | /*! |
606 | * \brief Create a function call |
607 | * \param call_lowered_props The lowered function and the arguments to call it with |
608 | * \param result_expr The call we got func and args from (so as to recover the storage |
609 | * ids to hold the result). |
610 | */ |
611 | void CreateFuncCall(CallLoweredProps call_lowered_props, const Expr& result_expr) { |
612 | std::string func_name = call_lowered_props.lowered_func->name_hint; |
613 | tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)}; |
614 | std::vector<tir::Stmt> create_func_call_stmts; |
615 | |
616 | // Pack the inputs |
617 | for (const Expr& arg : call_lowered_props.arguments) { |
618 | auto sids = FindExpr(arg); |
619 | PushArgs(arg, sids, &args); |
620 | } |
621 | |
622 | // Pack the return(s) value. A call node can produce multiple outputs |
623 | auto result_expr_sid = PackSid(result_expr); |
624 | PushArgs(result_expr, result_expr_sid, &args); |
625 | |
626 | GlobalVar global_var = call_lowered_props.lowered_func; |
627 | bool has_c_device_api_context = device_contexts_.count(global_var) != 0; |
628 | tir::Var device_context; |
629 | tir::Stmt func_call; |
630 | |
631 | switch (call_type_) { |
632 | case CallType::kUnpacked: { |
633 | // call_extern calling convention with optional context |
634 | if (has_c_device_api_context) { |
635 | device_context = device_contexts_.Get(global_var).value(); |
636 | args.push_back(device_context); |
637 | } |
638 | func_call = tir::Evaluate(AddCheckReturn( |
639 | tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), args))); |
640 | break; |
641 | } |
642 | case CallType::kCPacked: { |
643 | if (has_c_device_api_context) { |
644 | device_context = device_contexts_.Get(global_var).value(); |
645 | args.push_back(device_context); |
646 | } else { |
647 | // NOTE: LowerTVMBuiltin expects some device_context placeholder. |
648 | args.push_back(tir::make_zero(DataType::Handle())); |
649 | } |
650 | func_call = tir::Evaluate( |
651 | tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_cpacked(), args)); |
652 | create_func_call_stmts.push_back(func_call); |
653 | break; |
654 | } |
655 | case CallType::kPacked: { |
656 | // call_packed does not accept a device context. |
657 | CHECK(!has_c_device_api_context) << "CallType::kPacked does not accept a device context" ; |
658 | func_call = tir::Evaluate(AddCheckReturn( |
659 | tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::tvm_call_packed(), args))); |
660 | create_func_call_stmts.push_back(func_call); |
661 | break; |
662 | } |
663 | default: |
664 | ICHECK(false) << "Unknown CallType: " << call_type_; |
665 | } |
666 | |
667 | ICHECK(func_call.defined()) << "Must define func_call" ; |
668 | |
669 | if (has_c_device_api_context) { |
670 | func_call = tir::SeqStmt(Array<tir::Stmt>({ |
671 | GenerateDeviceHook(device_context, "Open" ), |
672 | func_call, |
673 | GenerateDeviceHook(device_context, "Close" ), |
674 | })); |
675 | } |
676 | |
677 | tir::Stmt body = tir::SeqStmt({func_call}); |
678 | stmts_.push_back(body); |
679 | } |
680 | |
  /*!
   * \brief Copy a variable to the output. This function is mainly used in edge cases
   * when we want to return an input or a parameter.
   * TODO(giuseros): we should try to avoid unnecessary copy to the output, e.g., in a
   * copy-on-write fashion.
   * \param out Buffer var (handle) of the destination buffer.
   * \param in Buffer var (handle) of the source buffer.
   * \param pack_input Unused by this implementation.
   * \param size Number of bytes to copy.
   */
  void CopyToOutput(PrimExpr out, PrimExpr in, bool pack_input, size_t size) {
    // Define intermediate DLTensor to load/store the data
    tir::Buffer tmp_read =
        tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_read");
    tir::Buffer tmp_write =
        tir::decl_buffer({IntImm(DataType::UInt(64), size)}, DataType::UInt(8), "tmp_write");
    te::Var loop_idx("i", DataType::Int(32));
    auto retval_i = tir::BufferLoad(tmp_read, {loop_idx});
    // Copy the variable from the input to the output
    // The inner Let binds tmp_read's backing var to `in`, and the outer LetStmt
    // binds tmp_write's backing var to `out`, so the loop copies byte-by-byte
    // from `in` to `out`.
    tir::Stmt copy = tir::For(
        loop_idx, 0, tir::make_const(DataType::Int(32, 1), size, Span()), tir::ForKind::kSerial,
        tir::BufferStore(tmp_write, tir::Let(tmp_read->data, in, retval_i), {loop_idx}));
    stmts_.push_back(tir::LetStmt(tmp_write->data, out, copy));
  }
701 | |
702 | /*! |
703 | * \brief Generates a call to a given hook for all Devices found for C Device API |
704 | * \param hook Name of hook to generate statements for |
705 | * \return Statement with function calls for each device |
706 | */ |
707 | tir::Stmt GenerateAllDeviceHook(const String& hook) { |
708 | std::vector<tir::Stmt> device_hooks; |
709 | for (const auto& it : devices_) { |
710 | const String& device_name = it.first; |
711 | const tir::Var& context = it.second; |
712 | Array<String> sections = {"Device" , device_name, hook}; |
713 | String device_hook_name = ToCFunctionStyle(PrefixName(sections)); |
714 | |
715 | tir::Evaluate device_hook( |
716 | AddCheckReturn(tvm::tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), |
717 | {tvm::tir::StringImm(device_hook_name), context}))); |
718 | device_hooks.push_back(device_hook); |
719 | } |
720 | return tir::SeqStmt(device_hooks); |
721 | } |
722 | |
723 | /*! |
724 | * \brief Generates a call to a given hook for a single Device function |
725 | * \param context Device context to call hook on |
726 | * \param hook Name of hook to generate statements for |
727 | * \return Statement with function call to Device API |
728 | */ |
729 | tir::Stmt GenerateDeviceHook(const tir::Var& context, const String& hook) { |
730 | const auto& it = std::find_if(std::begin(devices_), std::end(devices_), [&](const auto& it) { |
731 | return it.second->name_hint == context->name_hint; |
732 | }); |
733 | const String& device_name = (*it).first; |
734 | Array<String> sections = {"Device" , device_name, hook}; |
735 | String device_hook = ToCFunctionStyle(PrefixName(sections)); |
736 | |
737 | return tir::Evaluate( |
738 | AddCheckReturn(tir::Call(DataType::Int(32), tvm::tir::builtin::call_extern(), |
739 | {tvm::tir::StringImm(device_hook), context}))); |
740 | } |
741 | |
742 | /*! |
743 | * \brief Utility function to string together different arguments |
744 | */ |
745 | template <typename... Args> |
746 | std::string MakeString(Args const&... args) { |
747 | std::ostringstream ss; |
748 | using List = int[]; |
749 | (void)List{0, ((void)(ss << args), 0)...}; |
750 | |
751 | return ss.str(); |
752 | } |
753 | |
754 | /*! |
755 | * \brief Access IO vars using the buffer vars and |
756 | * not the actual var. |
757 | */ |
758 | tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; } |
759 | |
760 | /*! |
761 | * \brief Create tir::Var for input/output while updating the buffer_maps. |
762 | * \param expr The expression to evaluate. |
763 | * \param original_name The name of the tir::Var. |
764 | * \param use_unique_name Whether to generate a new unique name where a name conflicts. |
765 | */ |
766 | void CreateIOVar(const Expr& expr, const std::string& original_name, |
767 | bool use_unique_name = true) { |
768 | CreateIOVar(expr->checked_type(), original_name, use_unique_name); |
769 | } |
770 | |
771 | /*! |
772 | * \brief Create tir::Var for input/output while updating the buffer_maps. |
773 | * \param expr The expression to evaluate. |
774 | * \param original_name The name of the tir::Var. |
775 | * \param use_unique_name Whether to generate a new unique name where a name conflicts. |
776 | */ |
777 | void CreateIOVar(const Type& type, const std::string& original_name, |
778 | bool use_unique_name = true) { |
779 | if (type->IsInstance<TupleTypeNode>()) { |
780 | TupleType tuple_type = Downcast<TupleType>(type); |
781 | for (unsigned i = 0; i < tuple_type->fields.size(); i++) { |
782 | CreateIOVar(tuple_type->fields[i], original_name); |
783 | } |
784 | } else { |
785 | std::string name = original_name; |
786 | if (use_unique_name) { |
787 | name = GetUniqueIOVarName(original_name); |
788 | } |
789 | tir::Var var = tir::Var(name, DataType::Handle()); |
790 | main_signature_.push_back(var); |
791 | auto tensor_type = type.as<TensorTypeNode>(); |
792 | ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey(); |
793 | DataType elem_type = tensor_type->dtype; |
794 | tir::Var buffer_var = |
795 | tir::Var(name + "_buffer_var" , PointerType(PrimType(elem_type), "global" )); |
796 | tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0, |
797 | name + "_buffer" , 16, 1, tir::BufferType::kDefault); |
798 | main_buffer_map_.Set(var, buffer); |
799 | } |
800 | } |
801 | |
802 | /*! |
803 | * \brief Create a unique name for I/O Var |
804 | */ |
805 | std::string GetUniqueIOVarName(std::string name) { |
806 | if (io_var_names_.find(name) == io_var_names_.end()) { |
807 | io_var_names_[name] = 1; |
808 | return name + std::to_string(io_var_names_[name] - 1); |
809 | } else { |
810 | io_var_names_[name] = io_var_names_[name] + 1; |
811 | return name + std::to_string(io_var_names_[name] - 1); |
812 | } |
813 | } |
814 | |
/*! \brief list of input expressions (i.e., variables passed by the user) */
std::vector<Var> input_vars_;
/*! \brief map of device names to their device context variables */
Map<String, tir::Var> devices_;
/*! \brief map of GlobalVars to C Device API contexts */
Map<GlobalVar, tir::Var> device_contexts_;
/*! \brief input and output variables belonging to the main function signature */
Array<tir::Var> main_signature_;
/*! \brief map from I/O vars in the main signature to their backing buffers */
Map<tir::Var, tir::Buffer> main_buffer_map_;
/*! \brief All available targets. */
CompilationConfig config_;
/*!
 * \brief The type of kernel call to be emitted.
 * See CallType for more documentation.
 */
CallType call_type_;
/*! \brief map of tir::Vars to Relay constant nodes (keys/values per the declaration;
 * population not visible in this chunk) */
std::unordered_map<const tir::Var, const ConstantNode*, ObjectPtrHash, ObjectPtrEqual>
    constant_map_;
/*! \brief plan memory of device result */
StorageMap expr_storage_map_;
/*! \brief mapping sid -> tir::Var */
std::unordered_map<int, tir::Var> sids_table_;
/*! \brief the set of statements that make the program */
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more than one output) */
std::vector<int> return_sid_;
/*! \brief This is per IO var name counter to aid the generating unique names */
std::unordered_map<std::string, int> io_var_names_;
/*! \brief A set of variables that are let bound. */
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_vars_;
846 | }; |
847 | |
848 | Pass AOTLowerMain(String mod_name, tvm::CompilationConfig config, CallType call_type) { |
849 | runtime::TypedPackedFunc<IRModule(IRModule, transform::PassContext)> pass_func = |
850 | [=](IRModule module, transform::PassContext ctx) { |
851 | return AOTMainLowerer(config, call_type).Lower(module, mod_name); |
852 | }; |
853 | |
854 | return tvm::transform::CreateModulePass(pass_func, 0, "AOTLowerMain" , {"InferType" }); |
855 | } |
856 | |
857 | TVM_REGISTER_GLOBAL("relay.backend.aot.AOTLowerMain" ) |
858 | .set_body_typed([](const String& mod_name, const tvm::CompilationConfig& config, |
859 | int call_type) { |
860 | return AOTLowerMain(mod_name, config, static_cast<CallType>(call_type)); |
861 | }); |
862 | |
863 | } // namespace aot |
864 | } // namespace backend |
865 | } // namespace relay |
866 | } // namespace tvm |
867 | |