1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay/backend/utils.h |
22 | * \brief Utils function for backend |
23 | */ |
24 | #ifndef TVM_RELAY_BACKEND_UTILS_H_ |
25 | #define TVM_RELAY_BACKEND_UTILS_H_ |
26 | |
27 | #include <dmlc/json.h> |
28 | #include <tvm/driver/driver_api.h> |
29 | #include <tvm/relay/executor.h> |
30 | #include <tvm/relay/expr.h> |
31 | #include <tvm/relay/expr_functor.h> |
32 | #include <tvm/relay/transform.h> |
33 | #include <tvm/relay/type.h> |
34 | #include <tvm/target/codegen.h> |
35 | #include <tvm/target/virtual_device.h> |
36 | #include <tvm/te/operation.h> |
37 | #include <tvm/tir/usmp/utils.h> |
38 | |
39 | #include <iostream> |
40 | #include <sstream> |
41 | #include <string> |
42 | #include <typeinfo> |
43 | #include <unordered_map> |
44 | #include <unordered_set> |
45 | #include <utility> |
46 | #include <vector> |
47 | |
48 | #include "../../runtime/meta_data.h" |
49 | #include "../../target/metadata.h" |
50 | #include "tvm/runtime/ndarray.h" |
51 | |
52 | namespace tvm { |
53 | namespace relay { |
54 | |
55 | namespace tec { |
56 | class TECompiler; |
57 | } |
58 | |
59 | namespace backend { |
60 | using Pass = tvm::transform::Pass; |
61 | |
/*! \brief Describes the type of kernel call emitted. */
enum CallType {
  /*!
   * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions.
   *
   * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the
   * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those
   * functions are of type TVMBackendPackedCFunc.
   *
   * The following code is emitted at call sites to call a function named `func`:
   *    void* func_ptr = TVMBackendGetFuncFromEnv("func");
   *    TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes)
   *
   * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args`
   * by LowerTVMBuiltin TIR transform.
   *
   * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often,
   * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when
   * `func` is implemented in C).
   *
   * Compatible with both C++ and C runtimes, implemented with the C runtime only.
   */
  kPacked,  // Emit tir.call_packed and wrap all arguments in DLTensor.

  /*!
   * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call.
   *
   * When this type is selected, assumes all operators are implemented in functions of type
   * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of
   * downstream compilation that there is a symbol named after the 0th arg to tir::Call of
   * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target.
   *
   * The following code is emitted at call sites to call a function named `func`:
   *    func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle)
   *
   * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args`
   * by LowerTVMBuiltin TIR transform.
   *
   * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is
   * always the device context parameter when not null. At present, the implementation does not
   * support forwarding device context parameters to CPacked.
   *
   * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented
   * in the same scenarios.
   */
  kCPacked,  // Emit tir.call_cpacked and wrap all arguments in DLTensor.

  /*! \brief Directly call a function accepting the `data` arrays as args.
   *
   * When this type is selected, assumes all operators are implemented in C functions whose
   * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just the
   * `data` parameters (i.e. no DLTensor object is passed along).
   *
   * The following code is emitted at call sites to a function named `func`:
   *    func(void* arg0, void* arg1, ..., void* argN) // no resource_handle
   *   -or-
   *    func(void* arg0, void* arg1, ..., void* argN, void* resource_handle) // with resource_handle
   *
   * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is
   * always the device context parameter when not null.
   *
   * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented
   * with the C runtime only.
   */
  kUnpacked,  // Emit tir.call_extern passing only the `data` part of DLTensors.
};
128 | |
/*!
 * \brief Structure that can be optionally used by the executor codegen
 */
class ExecutorCodegenMetadataNode : public Object {
 public:
  /*! \brief input information for the main function */
  Array<tir::Var> inputs;
  /*! \brief input tensor type information */
  Array<TensorType> input_tensor_types;
  /*! \brief output information for the main function */
  Array<String> outputs;
  /*! \brief output tensor type information */
  Array<TensorType> output_tensor_types;
  /*! \brief pool information for the main function */
  Array<tir::Var> pools;
  /*! \brief device contexts information for the main function */
  Array<String> devices;
  /*! \brief the executor to be used to run the model */
  String executor = runtime::kTvmExecutorGraph;
  /*! \brief The external API (packed or c) in use */
  String interface_api;
  /*! \brief The internal API (packed or unpacked) in use */
  bool unpacked_api;
  /*! \brief Alignment of the workspace in bytes */
  Integer workspace_alignment;
  /*! \brief Alignment of the constants in bytes */
  Integer constant_alignment;
  /*! \brief the input var names that correspond to pool_inputs */
  Optional<Map<tir::Var, tir::usmp::AllocatedPoolInfo>> pool_inputs;
  /*! \brief the I/O tensor to PoolAllocations if any */
  Map<String, tir::usmp::PoolAllocation> io_pool_allocations;

  /*! \brief name of the module being generated */
  String mod_name = "";

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("inputs", &inputs);
    v->Visit("input_tensor_types", &input_tensor_types);
    v->Visit("outputs", &outputs);
    v->Visit("output_tensor_types", &output_tensor_types);
    v->Visit("pools", &pools);
    v->Visit("devices", &devices);
    v->Visit("executor", &executor);
    v->Visit("interface_api", &interface_api);
    v->Visit("unpacked_api", &unpacked_api);
    v->Visit("workspace_alignment", &workspace_alignment);
    v->Visit("constant_alignment", &constant_alignment);
    v->Visit("pool_inputs", &pool_inputs);
    v->Visit("io_pool_allocations", &io_pool_allocations);
    v->Visit("mod_name", &mod_name);
  }

  static constexpr const char* _type_key = "MetadataObj";
  TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorCodegenMetadataNode, Object);
};
183 | |
/*!
 * \brief Managed reference to ExecutorCodegenMetadataNode.
 */
class ExecutorCodegenMetadata : public ObjectRef {
 public:
  /*!
   * \brief Constructor.
   * \param inputs Input variables of the main function.
   * \param input_tensor_types Tensor types of the inputs.
   * \param outputs Output names of the main function.
   * \param output_tensor_types Tensor types of the outputs.
   * \param pools Pool variables of the main function.
   * \param devices Device context names used by the main function.
   * \param executor The executor to be used to run the model.
   * \param mod_name Name of the module being generated.
   * \param interface_api The external API ("packed" or "c") in use.
   * \param unpacked_api Whether the internal API is unpacked.
   * \param workspace_alignment Alignment of the workspace in bytes.
   * \param constant_alignment Alignment of the constants in bytes.
   * \param pool_inputs Input vars that correspond to allocated pools.
   * \param io_pool_allocations Map of I/O tensor names to pool allocations, if any.
   */
  TVM_DLL ExecutorCodegenMetadata(Array<tir::Var> inputs, Array<TensorType> input_tensor_types,
                                  Array<String> outputs, Array<TensorType> output_tensor_types,
                                  Array<tir::Var> pools, Array<String> devices, String executor,
                                  String mod_name, String interface_api = "packed",
                                  bool unpacked_api = false, Integer workspace_alignment = 16,
                                  Integer constant_alignment = 16,
                                  Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs =
                                      Map<tir::Var, tir::usmp::AllocatedPoolInfo>(),
                                  Map<String, tir::usmp::PoolAllocation> io_pool_allocations = {});
  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecutorCodegenMetadata, ObjectRef,
                                        ExecutorCodegenMetadataNode);
};
201 | |
/*!
 * \brief The static storage information for each Tensor in the result of a Relay expression
 * (as per relay::FlattenTupleType).
 */
class StorageInfoNode : public Object {
 public:
  // TODO(mbs): Switch from struct-of-array to array-of-struct repr throughout.
  /*! \brief The set of storage ids where the expression is stored. */
  std::vector<int64_t> storage_ids;
  /*! \brief The virtual devices these expressions are stored within. */
  std::vector<VirtualDevice> virtual_devices;
  /*! \brief The sizes of each storage element, in bytes. */
  std::vector<int64_t> storage_sizes_in_bytes;

  // TODO(@jroesch): expose the fields
  void VisitAttrs(AttrVisitor* v) {}

  static constexpr const char* _type_key = "relay.StorageInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(StorageInfoNode, Object);
};
222 | |
/*! \brief The storage information for a single expression. */
class StorageInfo : public ObjectRef {
 public:
  /*!
   * \brief Constructor.
   * \param storage_ids The storage ids where the expression is stored.
   * \param virtual_devices The virtual devices the tensors are stored within.
   * \param storage_sizes_in_bytes The size of each storage element, in bytes.
   */
  StorageInfo(std::vector<int64_t> storage_ids, std::vector<VirtualDevice> virtual_devices,
              std::vector<int64_t> storage_sizes_in_bytes);
  TVM_DEFINE_OBJECT_REF_METHODS(StorageInfo, ObjectRef, StorageInfoNode);
};
230 | |
/*!
 * \brief The result of static memory planning.
 */
class StaticMemoryPlanNode : public Object {
 public:
  /*! \brief Map from each expression to its static storage information. */
  Map<Expr, StorageInfo> expr_to_storage_info;

  void VisitAttrs(AttrVisitor* v) { v->Visit("expr_to_storage_info", &expr_to_storage_info); }

  static constexpr const char* _type_key = "relay.StaticMemoryPlan";
  TVM_DECLARE_FINAL_OBJECT_INFO(StaticMemoryPlanNode, Object);
};
243 | |
/*! \brief The result of running static memory planning. */
class StaticMemoryPlan : public ObjectRef {
 public:
  /*!
   * \brief Constructor.
   * \param expr_to_storage_info Map from each expression to its storage information.
   */
  explicit StaticMemoryPlan(Map<Expr, StorageInfo> expr_to_storage_info);
  TVM_DEFINE_OBJECT_REF_METHODS(StaticMemoryPlan, ObjectRef, StaticMemoryPlanNode);
};
250 | |
/*! \brief Per-function metadata, with each field keyed by Target. */
struct FunctionInfoNode : public Object {
  /*! \brief Workspace sizes of the function, per target. */
  Map<Target, Integer> workspace_sizes;
  /*! \brief I/O sizes of the function, per target. */
  Map<Target, Integer> io_sizes;
  /*! \brief Constant sizes of the function, per target. */
  Map<Target, Integer> constant_sizes;
  /*! \brief Lowered TIR PrimFuncs, per target. */
  Map<Target, tir::PrimFunc> tir_primfuncs;
  /*! \brief Relay primitive functions, per target. */
  Map<Target, Function> relay_primfuncs;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("workspace_sizes", &workspace_sizes);
    v->Visit("io_sizes", &io_sizes);
    v->Visit("constant_sizes", &constant_sizes);
    v->Visit("tir_primfuncs", &tir_primfuncs);
    v->Visit("relay_primfuncs", &relay_primfuncs);
  }

  static constexpr const char* _type_key = "relay.backend.FunctionInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionInfoNode, Object);
};
269 | |
/*! \brief Managed reference to FunctionInfoNode. */
class FunctionInfo : public ObjectRef {
 public:
  /*!
   * \brief Constructor; see FunctionInfoNode for the meaning of each field.
   */
  FunctionInfo(Map<Target, Integer> workspace_sizes, Map<Target, Integer> io_sizes,
               Map<Target, Integer> constant_sizes, Map<Target, tir::PrimFunc> tir_primfuncs,
               Map<Target, Function> relay_primfuncs);

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FunctionInfo, ObjectRef, FunctionInfoNode);
};
278 | |
279 | /*! |
280 | * \brief Calculate the bytes of memory needed to hold a tensor of a given shape and data type. |
281 | * \param shape The shape of the tensor |
282 | * \param dtype The data type of the tensor |
283 | */ |
284 | size_t GetMemorySizeBytes(const Array<PrimExpr>& shape, const DataType& dtype); |
285 | |
286 | /*! |
287 | * \brief Calculate the storage required to store the type of relay.Expr |
288 | * |
 * \param expr_type The type of the Relay expression for which the storage is calculated
290 | */ |
291 | int64_t CalculateRelayExprSizeBytes(const Type& expr_type); |
292 | |
/*!
 * \brief Executor generator artifacts. Those artifacts are subsequently
 * used by the relay build process.
 */
struct LoweredOutput {
  /*! \brief The serialized graph, if any (JSON). */
  std::string graph_json;
  /*! \brief Lowered functions, grouped by the target they were lowered for. */
  Map<Target, IRModule> lowered_funcs;
  /*! \brief Runtime modules produced by external codegens. */
  Array<tvm::runtime::Module> external_mods;
  /*! \brief Per-function metadata, keyed by function name. */
  Map<String, FunctionInfo> function_metadata;
  /*!
   * \brief Map from constant names (allocated by the codegen as constants are encountered)
   * to the constant's value.
   */
  std::unordered_map<std::string, tvm::runtime::NDArray> params;
  /*! \brief Optional metadata produced by the executor codegen. */
  ExecutorCodegenMetadata metadata;
};
309 | |
/*!
 * \brief Hash functor that maps an enum (or integral) value to its underlying integer.
 *
 * Works around a GCC 5 bug that prevents maps keyed by enums from compiling.
 * Can be removed once the i386 GCC version is increased.
 */
struct EnumClassHash {
  template <typename Enum>
  std::size_t operator()(Enum value) const {
    return static_cast<std::size_t>(value);
  }
};
320 | |
/*!
 * \brief A helper to expand the params by adding the ones used in a given expression.
 */
struct ConstantUpdater : public ExprVisitor {
 public:
  /*!
   * \brief Constructor.
   * \param symbol Prefix used when naming the constants found during the visit.
   * \param params Map to be filled with (name, constant value) entries; not owned.
   */
  ConstantUpdater(const std::string& symbol,
                  std::unordered_map<std::string, runtime::NDArray>* params)
      : symbol_(symbol), params_(params) {}

  /*! \brief Record each visited constant under the name "<symbol>_const_<i>". */
  void VisitExpr_(const ConstantNode* cn) final {
    std::string name = symbol_ + "_const_" + std::to_string(const_idx_++);
    VLOG(1) << "binding '" << name << "' to constant of type " << PrettyPrint(cn->checked_type());
    (*params_)[name] = cn->data;
  }

 private:
  /*! \brief Running index used to make each constant name unique. */
  int const_idx_{0};
  /*! \brief Name prefix for the recorded constants. */
  std::string symbol_;
  /*! \brief Output map of constant name to value; owned by the caller. */
  std::unordered_map<std::string, runtime::NDArray>* params_;
};
341 | |
342 | /*! |
343 | * \brief A function to update the params with constants found in an external function. |
344 | * \param func The function from which to get the constant params. |
345 | * \param params The params to update with the constants. |
346 | */ |
347 | inline void UpdateConstants(BaseFunc func, |
348 | std::unordered_map<std::string, runtime::NDArray>* params) { |
349 | VLOG_CONTEXT << "UpdateConstants" ; |
350 | VLOG(1) << "updating constants for:" << std::endl << PrettyPrint(func); |
351 | auto codegen = func->GetAttr<String>(attr::kCompiler); |
352 | ICHECK(codegen.defined()) << "No external codegen is set" ; |
353 | std::string codegen_name = codegen.value(); |
354 | const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol); |
355 | std::string symbol = std::string(name_node.value()); |
356 | std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater" ; |
357 | // Get the constant updater for the external codegen |
358 | auto pf = tvm::runtime::Registry::Get(const_update_name); |
359 | // If the backend hasn't registered a constant updater, use a default one |
360 | if (pf == nullptr) { |
361 | ConstantUpdater const_visit(symbol, params); |
362 | const_visit(func); |
363 | } else { |
364 | Map<String, tvm::runtime::NDArray> constants = (*pf)(func, symbol); |
365 | for (const auto& it : constants) { |
366 | std::string const_name(it.first); |
367 | // Constant names should begin this the compiler name (to avoid conflicts) |
368 | ICHECK(const_name.find(codegen_name) == 0) |
369 | << "External constant names must start with compiler name" ; |
370 | (*params)[const_name] = it.second; |
371 | } |
372 | } |
373 | for (const auto& pair : *params) { |
374 | VLOG(1) << "Constants: " << pair.first << " = " << PrettyPrint(pair.second); |
375 | } |
376 | } |
377 | |
/*!
 * \brief A simple wrapper around ExprFunctor for a single argument case.
 * The result of visit is memoized.
 */
template <typename OutputType>
class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor<OutputType(const Expr&)> {
  using BaseFunctor = ::tvm::relay::ExprFunctor<OutputType(const Expr&)>;

 public:
  /*! \brief virtual destructor */
  virtual ~MemoizedExprTranslator() {}

  /*!
   * \brief The memoized call.
   * \param n The expression node.
   * \return The result of the call
   */
  virtual OutputType VisitExpr(const Expr& n) {
    ICHECK(n.defined());
    // Return the cached result if this exact expression object was visited before.
    auto it = memo_.find(n);
    if (it != memo_.end()) {
      return it->second;
    }
    auto res = BaseFunctor::VisitExpr(n);
    memo_[n] = res;
    return res;
  }

 protected:
  /*! \brief Internal map used for memoization, keyed on object identity
   * (ObjectPtrHash/ObjectPtrEqual), not structural equality. */
  std::unordered_map<Expr, OutputType, ObjectPtrHash, ObjectPtrEqual> memo_;
};
410 | |
/*!
 * \brief Get a PackedFunc from the global function registry.
 *
 * \param func_name The registered name of the function.
 * \return Pointer to the registered PackedFunc (nullptr if not registered; see the
 *         null check performed by GetTypedPackedFunc).
 */
inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
  return tvm::runtime::Registry::Get(func_name);
}
420 | |
421 | /*! |
422 | * \brief Get a typed packed function. |
423 | * |
424 | * \param func_name |
425 | * \return const PackedFunc* |
426 | */ |
427 | template <typename R, typename... Args> |
428 | inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) { |
429 | auto* pf = GetPackedFunc(func_name); |
430 | ICHECK(pf != nullptr) << "can not find packed function" ; |
431 | return runtime::TypedPackedFunc<R(Args...)>(*pf); |
432 | } |
433 | |
434 | /*! |
435 | * \brief Extract shape from an IndexExpr array to std::vector<int64_t> |
436 | * |
437 | * \param shape The shape in Array |
438 | * \return The converted shape in std::vector<int64_t> |
439 | */ |
440 | inline std::vector<int64_t> GetIntShape(const Array<IndexExpr>& shape) { |
441 | std::vector<int64_t> ret; |
442 | for (const auto& dim : shape) { |
443 | const int64_t* pval = tir::as_const_int(dim); |
444 | ret.push_back(pval ? *pval : -1); |
445 | } |
446 | return ret; |
447 | } |
448 | |
/*!
 * \brief Convert a DataType to its string representation, e.g. "float32" or "int8".
 *
 * \param dtype The data type to convert.
 * \return std::string string format of type
 */
inline std::string DType2String(const tvm::DataType dtype) {
  std::ostringstream os;
  if (dtype.is_float()) {
    os << "float";
  } else if (dtype.is_int()) {
    os << "int";
  } else if (dtype.is_uint()) {
    os << "uint";
  } else if (dtype.is_bfloat16()) {
    os << "bfloat";
  } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
    // Custom datatype registered at runtime: format as "custom[<name>]<bits>".
    os << "custom["
       << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()
       << "]";
  } else {
    LOG(FATAL) << "Unknown type with code " << static_cast<unsigned>(dtype.code());
  }
  // Append the bit width, e.g. "float" + "32" -> "float32".
  os << dtype.bits();
  return os.str();
}
475 | |
476 | /*! |
477 | * \brief Bind params to function by using name |
478 | * \param func Relay function |
479 | * \param params params dict |
480 | * \return relay::Function |
481 | */ |
482 | relay::Function BindParamsByName(relay::Function func, |
483 | const std::unordered_map<std::string, runtime::NDArray>& params); |
484 | |
485 | /*! |
486 | * \brief Bind params to the main function in Relay module, using BindParamsByName |
487 | * \param mod Relay module |
488 | * \param params params dict |
489 | */ |
490 | void BindParamsInModule(IRModule mod, |
491 | const std::unordered_map<std::string, runtime::NDArray>& params); |
492 | |
493 | void BindParamsInModule(IRModule mod, Map<String, runtime::NDArray> params); |
494 | |
495 | /*! |
496 | * \brief Extract the shape from a Relay tensor type. |
497 | * \param type The provided type. |
498 | * \return The extracted shape in a list. |
499 | */ |
500 | inline std::vector<int> GetShape(const Type& type) { |
501 | const auto* ttype = type.as<TensorTypeNode>(); |
502 | ICHECK(ttype) << "Expect TensorTypeNode" ; |
503 | std::vector<int> shape; |
504 | for (size_t i = 0; i < ttype->shape.size(); ++i) { |
505 | auto* val = ttype->shape[i].as<IntImmNode>(); |
506 | ICHECK(val); |
507 | shape.push_back(val->value); |
508 | } |
509 | return shape; |
510 | } |
511 | |
512 | /*! |
513 | * \brief Check if a call has the provided name. |
514 | * \param call A Relay call node. |
515 | * \param op_name The name of the expected call. |
516 | * \return true if the call's name is equivalent to the given name. Otherwise, |
517 | * false. |
518 | */ |
519 | inline bool IsOp(const CallNode* call, const std::string& op_name) { |
520 | const auto* op_node = call->op.as<OpNode>(); |
521 | ICHECK(op_node) << "Expects a single op." ; |
522 | Op op = GetRef<Op>(op_node); |
523 | return op == Op::Get(op_name); |
524 | } |
525 | |
526 | /*! |
527 | * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d)) |
528 | * \param call A Relay call node. Typically nn.relu when called the first time. |
529 | * \param depth The number of calls before the root op, counting from current_call. |
530 | * \param expected_op_names The names of ops in this fused call. Example: {"nn.conv2d", "add", |
531 | * "nn.relu"} |
532 | * \return A CallNode corresponding to the root op, whose name is expected_op_names[0] |
533 | */ |
534 | inline const CallNode* GetRootCall(const CallNode* current_call, int depth, |
535 | const std::vector<std::string>& expected_op_names) { |
536 | ICHECK(current_call && depth >= 0 && static_cast<size_t>(depth) < expected_op_names.size() && |
537 | IsOp(current_call, expected_op_names[depth])); |
538 | |
539 | if (depth == 0) { |
540 | return current_call; |
541 | } |
542 | |
543 | ICHECK_GT(current_call->args.size(), 0); |
544 | size_t valid_node_idx = 0; |
545 | while (valid_node_idx < current_call->args.size() && |
546 | current_call->args[valid_node_idx].as<VarNode>()) { |
547 | valid_node_idx++; |
548 | } |
549 | while (valid_node_idx < current_call->args.size() && |
550 | !(IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))) { |
551 | valid_node_idx++; |
552 | } |
553 | const auto* next_call = current_call->args[valid_node_idx].as<CallNode>(); |
554 | return GetRootCall(next_call, depth - 1, expected_op_names); |
555 | } |
556 | |
557 | /*! |
558 | * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d)) |
559 | * Unlike the previous definition, it does not verify operator names of intermediate nodes. Instead, |
560 | * it recursively visit child nodes until it finds a call node with the given op_name. |
561 | * \param call A Relay call node. |
562 | * \param op_name The name of an op to look for, such as ""nn.conv2d". |
563 | * \return A CallNode corresponding to the root op with the given op_name |
564 | */ |
565 | inline const CallNode* GetRootCall(const CallNode* current_call, const std::string& op_name) { |
566 | if (current_call == nullptr) return nullptr; |
567 | if (IsOp(current_call, op_name)) return current_call; |
568 | |
569 | ICHECK_GT(current_call->args.size(), 0); |
570 | |
571 | const auto* next_call = current_call->args[0].as<CallNode>(); |
572 | return GetRootCall(next_call, op_name); |
573 | } |
574 | |
575 | /*! |
576 | * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in |
577 | * relu(add(conv2d)) |
578 | * \param call A Relay call node. Typically nn.relu when called the first time. |
579 | * \param max_depth The maximum number of calls before the root op, counting from current_call. |
580 | * \param op_name The name of expected "root" op in this fused call. |
581 | * \return A CallNode corresponding to the root op |
582 | */ |
583 | inline const CallNode* GetRootCall(const CallNode* current_call, int max_depth, |
584 | const std::string& op_name) { |
585 | ICHECK(current_call && max_depth >= 0); |
586 | |
587 | if (max_depth == 0) { |
588 | ICHECK(current_call && IsOp(current_call, op_name)); |
589 | return current_call; |
590 | } |
591 | if (IsOp(current_call, op_name)) { |
592 | return current_call; |
593 | } |
594 | |
595 | ICHECK_GT(current_call->args.size(), 0); |
596 | |
597 | size_t valid_node_idx = 0; |
598 | while (valid_node_idx < current_call->args.size() && |
599 | current_call->args[valid_node_idx].as<VarNode>()) { |
600 | valid_node_idx++; |
601 | } |
602 | |
603 | const auto* next_call = current_call->args[valid_node_idx].as<CallNode>(); |
604 | return GetRootCall(next_call, max_depth - 1, op_name); |
605 | } |
606 | |
607 | /*! |
608 | * \brief Get the external symbol of the Relay function name. |
609 | * |
610 | * \param func The provided function. |
611 | * \return An external symbol. |
612 | */ |
613 | inline std::string GetExtSymbol(const Function& func) { |
614 | const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol); |
615 | ICHECK(name_node.defined()) << "Fail to retrieve external symbol." ; |
616 | return std::string(name_node.value()); |
617 | } |
618 | |
619 | /*! |
620 | * \brief Return whether the auto scheduler is enabled in the pass context. |
621 | */ |
622 | inline bool IsAutoSchedulerEnabled() { |
623 | return transform::PassContext::Current() |
624 | ->GetConfig<Bool>("relay.backend.use_auto_scheduler" , Bool(false)) |
625 | .value(); |
626 | } |
627 | |
628 | /*! |
629 | * \brief Return whether the meta schedule is enabled in the pass context. |
630 | */ |
631 | inline bool IsMetaScheduleEnabled() { |
632 | return transform::PassContext::Current() |
633 | ->GetConfig<Bool>("relay.backend.use_meta_schedule" , Bool(false)) |
634 | .value(); |
635 | } |
636 | |
637 | /*! \brief Consider MetaSchedule's dispatch option. */ |
638 | inline int UseMetaScheduleDispatch() { |
639 | return transform::PassContext::Current() |
640 | ->GetConfig<Integer>("relay.backend.use_meta_schedule_dispatch" , Integer(0)) |
641 | .value() |
642 | ->value; |
643 | } |
/*!
 * \brief Method in TECompiler to convert TE compute to schedulable TIR
 * \param args The arguments of the TE compute
 * \param constants The constants used in AllocateConst
 * \return NullOpt if conversion fails; Otherwise the converted TIR
 * \note This method could be further used as a task filtering mechanism in task extraction
 */
using FTECompilerTIRConverter = runtime::TypedPackedFunc<  //
    Optional<tir::PrimFunc>(                               //
        const Array<te::Tensor>& args,                     //
        const Array<runtime::NDArray>& constants)>;
655 | |
656 | /*! \brief Return a task filter for AutoTIR according to `relay.backend.tir_converter` */ |
657 | inline FTECompilerTIRConverter GetTIRConverter() { |
658 | String name = transform::PassContext::Current() |
659 | ->GetConfig<String>("relay.backend.tir_converter" , "default" ) |
660 | .value(); |
661 | const PackedFunc* f = runtime::Registry::Get("relay.backend.tir_converter." + name); |
662 | ICHECK(f != nullptr) << "IndexError: Cannot find TIR converter: " << name; |
663 | return FTECompilerTIRConverter(*f); |
664 | } |
665 | |
666 | /*! \brief Converts a PrimFunc to IRModule. */ |
667 | inline IRModule PrimFuncToIRModule(tir::PrimFunc f) { |
668 | f = WithAttrs(f, Map<String, ObjectRef>{ |
669 | {tvm::attr::kGlobalSymbol, String("main" )}, |
670 | {tvm::tir::attr::kNoAlias, Bool(1)}, |
671 | }); |
672 | return IRModule({{GlobalVar("main" ), f}}); |
673 | } |
674 | |
675 | /*! |
676 | * \brief Get the sequence of Relay optimization passes based on backend type. |
677 | * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight |
678 | * difference. This function unifies the shared optimization pass prefix between vm and graph |
679 | * runtime, and returns the pass prefix given the backend type. |
680 | * |
681 | * \param is_homogeneous True if all primitives are to be executed on the same device and target. |
682 | * \param is_vm True if passes are to be used for the vm executor. |
683 | * \return An array of passes. |
684 | */ |
685 | Array<Pass> GetPassPrefix(bool is_homogeneous, bool is_vm); |
686 | |
687 | /*! \brief Target hash function */ |
688 | struct TargetStrHash { |
689 | /*! |
690 | * \brief Calculate the hash code of a Target based on the string value of the Target KIND. |
691 | Note that this hash should NOT be used in new usecases, equality of targets based on their |
692 | value is not well-defined. |
693 | This will be removed when maps from Targets to IRModules are removed from the codebase. |
694 | * \param target The Target to hash |
695 | * \return String hash of the target |
696 | */ |
697 | size_t operator()(const Target& target) const { |
698 | std::string s(target->kind->name); |
699 | return String::HashBytes(s.c_str(), s.size()); |
700 | } |
701 | }; |
702 | |
703 | /*! \brief Target equality function based on the string value of Target |
704 | Note that this equality function should NOT be used in new usecases, equality of targets based on |
705 | their value is not well-defined. This will be removed when maps from Targets to IRModules are |
706 | removed from the codebase.*/ |
707 | struct TargetStrEqual { |
708 | /*! |
709 | * \brief Check if the two Targets are equal |
710 | * \param target One Target |
711 | * \param other_target The other Target |
712 | * \return String equality of the targets |
713 | */ |
714 | const bool operator()(const Target& target, const Target& other_target) const { |
715 | TargetStrHash target_hash = TargetStrHash(); |
716 | return target_hash(target) == target_hash(other_target); |
717 | } |
718 | }; |
719 | |
720 | /*! |
 * \brief Convert a Map<Target, IRModule> to std::unordered_map<Target, IRModule, TargetStrHash,
 * TargetStrEqual>. Target equality is currently based on pointer equality, which is a problem since
723 | * we have a lot of Map<Target, IRModule> in the codebase. This function converts the map to a |
724 | * version that is keyed based on string value of the Target instead. Note that once we remove |
725 | * Map<Target, IRModule>, this function will be removed. |
726 | * \param input_map The map to convert |
727 | * \return The converted map |
728 | */ |
729 | std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> |
730 | TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map); |
731 | |
732 | /*! |
 * \brief Convert a std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> to
734 | * Map<Target, IRModule> This function is a helper that undoes TargetModuleMapToTargetStr. Note that |
735 | * once we remove Map<Target, IRModule>, this function will be removed. |
736 | * \param input_map The map to convert |
737 | * \return The converted map |
738 | */ |
739 | Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap( |
740 | std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map); |
741 | |
742 | /*! |
743 | * \brief Call "weight update callback" to communicate op weights seen during Relay module |
744 | * lowering back to the auto scheduler. |
745 | * Op weights refer to the number of times each distinct op/workload appears in a given module. |
746 | * It is called "use_count" in TECompiler. |
 * \param module The IRModule after lowering by LowerTEPass.
748 | */ |
749 | void UpdateAutoSchedulerOpWeights(const IRModule& module); |
750 | |
751 | /*! |
752 | * \brief Extract shape from expr to vector<int64_t> |
753 | * |
 * \param shape The shape to be converted.
755 | * \return std::vector<int64_t> |
756 | */ |
757 | std::vector<int64_t> ShapeToJSON(tvm::Array<IndexExpr> shape); |
758 | |
759 | } // namespace backend |
760 | } // namespace relay |
761 | } // namespace tvm |
762 | |
763 | #endif // TVM_RELAY_BACKEND_UTILS_H_ |
764 | |