1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file relay/backend/utils.h
22 * \brief Utils function for backend
23 */
24#ifndef TVM_RELAY_BACKEND_UTILS_H_
25#define TVM_RELAY_BACKEND_UTILS_H_
26
27#include <dmlc/json.h>
28#include <tvm/driver/driver_api.h>
29#include <tvm/relay/executor.h>
30#include <tvm/relay/expr.h>
31#include <tvm/relay/expr_functor.h>
32#include <tvm/relay/transform.h>
33#include <tvm/relay/type.h>
34#include <tvm/target/codegen.h>
35#include <tvm/target/virtual_device.h>
36#include <tvm/te/operation.h>
37#include <tvm/tir/usmp/utils.h>
38
39#include <iostream>
40#include <sstream>
41#include <string>
42#include <typeinfo>
43#include <unordered_map>
44#include <unordered_set>
45#include <utility>
46#include <vector>
47
48#include "../../runtime/meta_data.h"
49#include "../../target/metadata.h"
50#include "tvm/runtime/ndarray.h"
51
52namespace tvm {
53namespace relay {
54
55namespace tec {
56class TECompiler;
57}
58
59namespace backend {
60using Pass = tvm::transform::Pass;
61
/*! \brief Describes the type of kernel call emitted. */
enum CallType {
  /*!
   * \brief Emit PackedFunc calls bound just-in-time using TVMBackend* functions.
   *
   * When this type is selected, assumes all operators must be called via TVMFuncCall. Given the
   * implementation of TVMFuncCall in the C++ runtime, this in practice implies that those
   * functions are of type TVMBackendPackedCFunc.
   *
   * The following code is emitted at call sites to call a function named `func`:
   *    void* func_ptr = TVMBackendGetFuncFromEnv("func");
   *    TVMFuncCall(func_ptr, values, tcodes, num_args, ret_values, ret_tcodes)
   *
   * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args`
   * by LowerTVMBuiltin TIR transform.
   *
   * If `resource_handle` is passed to `func`, it is determined by TVMFuncCall (often,
   * `resource_handle` is registered with the C++ runtime to provide a `this` equivalent when
   * `func` is implemented in C).
   *
   * Compatible with both C++ and C runtimes, implemented with the C runtime only.
   */
  kPacked,  // Emit tir.call_packed and wrap all arguments in DLTensor.

  /*!
   * \brief Directly call a TVMBackendPackedCFunc named according to the tir::Call.
   *
   * When this type is selected, assumes all operators are implemented in functions of type
   * `TVMBackendPackedCFunc` and should be called directly. That is, presumes at the time of
   * downstream compilation that there is a symbol named after the 0th arg to tir::Call of
   * type `TVMBackendPackedCFunc`. This situation should occur when target_host == target.
   *
   * The following code is emitted at call sites to call a function named `func`:
   *    func(values, tcodes, num_args, ret_values, ret_tcodes, resource_handle)
   *
   * The arguments given to the tir::Call node are encoded into `values`, `tcodes`, and `num_args`
   * by LowerTVMBuiltin TIR transform.
   *
   * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is
   * always the device context parameter when not null. At present, the implementation does not
   * support forwarding device context parameters to CPacked.
   *
   * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented
   * in the same scenarios.
   */
  kCPacked,  // Emit tir.call_cpacked and wrap all arguments in DLTensor.

  /*! \brief Directly call a function accepting the `data` arrays as args.
   *
   * When this type is selected, assumes all operators are implemented in C functions whose
   * arguments are 1-to-1 with those in the tir::Call. DLTensor arguments are encoded as just the
   * `data` parameters (i.e. no DLTensor object is passed along).
   *
   * The following code is emitted at call sites to a function named `func`:
   *    func(void* arg0, void* arg1, ..., void* argN) // no resource_handle
   *   -or-
   *    func(void* arg0, void* arg1, ..., void* argN, void* resource_handle) // with resource_handle
   *
   * `resource_handle` is encoded as the final argument to the tir::Call node. In practice, it is
   * always the device context parameter when not null.
   *
   * Compatible with the C runtime and C++ runtime (so long as target_host == target). Implemented
   * with the C runtime only.
   */
  kUnpacked,  // Emit tir.call_extern passing only the `data` part of DLTensors.
};
128
129/*!
130 * \brief Structure that can be optionally used by the executor codegen
131 */
132class ExecutorCodegenMetadataNode : public Object {
133 public:
134 /*! \brief input information for the main function */
135 Array<tir::Var> inputs;
136 /*! \brief input tensor type information */
137 Array<TensorType> input_tensor_types;
138 /*! \brief output information for the main function */
139 Array<String> outputs;
140 /*! \brief output tensor type information */
141 Array<TensorType> output_tensor_types;
142 /*! \brief pool information for the main function */
143 Array<tir::Var> pools;
144 /*! \brief device contexts information for the main function */
145 Array<String> devices;
146 /*! \brief the executor to be used to run the model */
147 String executor = runtime::kTvmExecutorGraph;
148 /*! \brief The external API (packed or c) in use */
149 String interface_api;
150 /*! \brief The internal API (packed or unpacked) in use */
151 bool unpacked_api;
152 /*! \brief Alginment of the workspace in bytes */
153 Integer workspace_alignment;
154 /*! \brief Alginment of the constants in bytes */
155 Integer constant_alignment;
156 /*! \brief the input var names that correspond to pool_inputs */
157 Optional<Map<tir::Var, tir::usmp::AllocatedPoolInfo>> pool_inputs;
158 /*! \brief the I/O tensor to PoolAllocations if any*/
159 Map<String, tir::usmp::PoolAllocation> io_pool_allocations;
160
161 String mod_name = "";
162
163 void VisitAttrs(tvm::AttrVisitor* v) {
164 v->Visit("inputs", &inputs);
165 v->Visit("input_tensor_types", &input_tensor_types);
166 v->Visit("outputs", &outputs);
167 v->Visit("output_tensor_types", &output_tensor_types);
168 v->Visit("pools", &pools);
169 v->Visit("devices", &devices);
170 v->Visit("executor", &executor);
171 v->Visit("interface_api", &interface_api);
172 v->Visit("unpacked_api", &unpacked_api);
173 v->Visit("workspace_alignment", &workspace_alignment);
174 v->Visit("constant_alignment", &constant_alignment);
175 v->Visit("pool_inputs", &pool_inputs);
176 v->Visit("io_pool_allocations", &io_pool_allocations);
177 v->Visit("mod_name", &mod_name);
178 }
179
180 static constexpr const char* _type_key = "MetadataObj";
181 TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorCodegenMetadataNode, Object);
182};
183
184/*!
185 * \brief Managed reference to ExecutorCodegenMetadataNode.
186 */
187class ExecutorCodegenMetadata : public ObjectRef {
188 public:
189 TVM_DLL ExecutorCodegenMetadata(Array<tir::Var> inputs, Array<TensorType> input_tensor_types,
190 Array<String> outputs, Array<TensorType> output_tensor_types,
191 Array<tir::Var> pools, Array<String> devices, String executor,
192 String mod_name, String interface_api = "packed",
193 bool unpacked_api = false, Integer workspace_alignment = 16,
194 Integer constant_alignment = 16,
195 Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs =
196 Map<tir::Var, tir::usmp::AllocatedPoolInfo>(),
197 Map<String, tir::usmp::PoolAllocation> io_pool_allocations = {});
198 TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecutorCodegenMetadata, ObjectRef,
199 ExecutorCodegenMetadataNode);
200};
201
202/*!
203 * \brief The static storage information for each Tensor in the result of a Relay expression
204 * (as per relay::FlattenTupleType).
205 */
206class StorageInfoNode : public Object {
207 public:
208 // TODO(mbs): Switch from struct-of-array to array-of-struct repr throughout.
209 /*! \brief The set of storage ids where the expression is stored. */
210 std::vector<int64_t> storage_ids;
211 /* \brief The virtual devices these expressions are stored within. */
212 std::vector<VirtualDevice> virtual_devices;
213 /* \brief The sizes of each storage element, in bytes. */
214 std::vector<int64_t> storage_sizes_in_bytes;
215
216 // TODO(@jroesch): expose the fields
217 void VisitAttrs(AttrVisitor* v) {}
218
219 static constexpr const char* _type_key = "relay.StorageInfo";
220 TVM_DECLARE_FINAL_OBJECT_INFO(StorageInfoNode, Object);
221};
222
/*! \brief The storage information for a single expression. */
class StorageInfo : public ObjectRef {
 public:
  /*!
   * \brief Construct the storage info for an expression.
   * \param storage_ids The storage ids where the expression is stored.
   * \param virtual_devices The virtual devices, parallel to \p storage_ids.
   * \param storage_sizes_in_bytes The size of each storage element, in bytes.
   */
  StorageInfo(std::vector<int64_t> storage_ids, std::vector<VirtualDevice> virtual_devices,
              std::vector<int64_t> storage_sizes_in_bytes);
  TVM_DEFINE_OBJECT_REF_METHODS(StorageInfo, ObjectRef, StorageInfoNode);
};
230
231/*!
232 * \brief The result of static memory planning.
233 */
234class StaticMemoryPlanNode : public Object {
235 public:
236 Map<Expr, StorageInfo> expr_to_storage_info;
237
238 void VisitAttrs(AttrVisitor* v) { v->Visit("expr_to_storage_info", &expr_to_storage_info); }
239
240 static constexpr const char* _type_key = "relay.StaticMemoryPlan";
241 TVM_DECLARE_FINAL_OBJECT_INFO(StaticMemoryPlanNode, Object);
242};
243
/*! \brief The result of running static memory planning. */
class StaticMemoryPlan : public ObjectRef {
 public:
  /*!
   * \brief Construct a static memory plan.
   * \param expr_to_storage_info Map from each planned expression to its storage info.
   */
  explicit StaticMemoryPlan(Map<Expr, StorageInfo> expr_to_storage_info);
  TVM_DEFINE_OBJECT_REF_METHODS(StaticMemoryPlan, ObjectRef, StaticMemoryPlanNode);
};
250
/*! \brief Per-target size and function metadata collected for a compiled function. */
struct FunctionInfoNode : public Object {
  /*! \brief Workspace size required by the function, per target. */
  Map<Target, Integer> workspace_sizes;
  /*! \brief I/O size of the function, per target. */
  Map<Target, Integer> io_sizes;
  /*! \brief Constant size used by the function, per target. */
  Map<Target, Integer> constant_sizes;
  /*! \brief Lowered TIR PrimFuncs, per target. */
  Map<Target, tir::PrimFunc> tir_primfuncs;
  /*! \brief Source Relay primitive functions, per target. */
  Map<Target, Function> relay_primfuncs;

  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("workspace_sizes", &workspace_sizes);
    v->Visit("io_sizes", &io_sizes);
    v->Visit("constant_sizes", &constant_sizes);
    v->Visit("tir_primfuncs", &tir_primfuncs);
    v->Visit("relay_primfuncs", &relay_primfuncs);
  }

  static constexpr const char* _type_key = "relay.backend.FunctionInfo";
  TVM_DECLARE_FINAL_OBJECT_INFO(FunctionInfoNode, Object);
};
269
/*! \brief Managed reference to FunctionInfoNode. */
class FunctionInfo : public ObjectRef {
 public:
  /*!
   * \brief Construct function metadata from the per-target maps.
   * \param workspace_sizes Workspace size required by the function, per target.
   * \param io_sizes I/O size of the function, per target.
   * \param constant_sizes Constant size used by the function, per target.
   * \param tir_primfuncs Lowered TIR PrimFuncs, per target.
   * \param relay_primfuncs Source Relay primitive functions, per target.
   */
  FunctionInfo(Map<Target, Integer> workspace_sizes, Map<Target, Integer> io_sizes,
               Map<Target, Integer> constant_sizes, Map<Target, tir::PrimFunc> tir_primfuncs,
               Map<Target, Function> relay_primfuncs);

  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(FunctionInfo, ObjectRef, FunctionInfoNode);
};
278
279/*!
280 * \brief Calculate the bytes of memory needed to hold a tensor of a given shape and data type.
281 * \param shape The shape of the tensor
282 * \param dtype The data type of the tensor
283 */
284size_t GetMemorySizeBytes(const Array<PrimExpr>& shape, const DataType& dtype);
285
286/*!
287 * \brief Calculate the storage required to store the type of relay.Expr
288 *
289 * \param func The relay expr for which the storage is calculated
290 */
291int64_t CalculateRelayExprSizeBytes(const Type& expr_type);
292
293/*!
294 * \brief Executor generator artifacts. Those artifacts are subsequently
295 * used by the relay build process.
296 */
297struct LoweredOutput {
298 std::string graph_json;
299 Map<Target, IRModule> lowered_funcs;
300 Array<tvm::runtime::Module> external_mods;
301 Map<String, FunctionInfo> function_metadata;
302 /*!
303 * \brief Map from constant names (allocated by the codegen as constants are encountered)
304 * to the constant's value.
305 */
306 std::unordered_map<std::string, tvm::runtime::NDArray> params;
307 ExecutorCodegenMetadata metadata;
308};
309
310/*!
311 * \brief This class is needed to avoid a GCC 5 bug that prevents maps containing enums from being
312 compiled. If i386 GCC version is increased, we can remove it.
313 */
314struct EnumClassHash {
315 template <typename T>
316 std::size_t operator()(T t) const {
317 return static_cast<std::size_t>(t);
318 }
319};
320
321/*!
322 * \brief A helper to expand the params by adding the ones used in a given expression.
323 */
324struct ConstantUpdater : public ExprVisitor {
325 public:
326 ConstantUpdater(const std::string& symbol,
327 std::unordered_map<std::string, runtime::NDArray>* params)
328 : symbol_(symbol), params_(params) {}
329
330 void VisitExpr_(const ConstantNode* cn) final {
331 std::string name = symbol_ + "_const_" + std::to_string(const_idx_++);
332 VLOG(1) << "binding '" << name << "' to constant of type " << PrettyPrint(cn->checked_type());
333 (*params_)[name] = cn->data;
334 }
335
336 private:
337 int const_idx_{0};
338 std::string symbol_;
339 std::unordered_map<std::string, runtime::NDArray>* params_;
340};
341
342/*!
343 * \brief A function to update the params with constants found in an external function.
344 * \param func The function from which to get the constant params.
345 * \param params The params to update with the constants.
346 */
347inline void UpdateConstants(BaseFunc func,
348 std::unordered_map<std::string, runtime::NDArray>* params) {
349 VLOG_CONTEXT << "UpdateConstants";
350 VLOG(1) << "updating constants for:" << std::endl << PrettyPrint(func);
351 auto codegen = func->GetAttr<String>(attr::kCompiler);
352 ICHECK(codegen.defined()) << "No external codegen is set";
353 std::string codegen_name = codegen.value();
354 const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
355 std::string symbol = std::string(name_node.value());
356 std::string const_update_name = "relay.ext." + codegen_name + ".constant_updater";
357 // Get the constant updater for the external codegen
358 auto pf = tvm::runtime::Registry::Get(const_update_name);
359 // If the backend hasn't registered a constant updater, use a default one
360 if (pf == nullptr) {
361 ConstantUpdater const_visit(symbol, params);
362 const_visit(func);
363 } else {
364 Map<String, tvm::runtime::NDArray> constants = (*pf)(func, symbol);
365 for (const auto& it : constants) {
366 std::string const_name(it.first);
367 // Constant names should begin this the compiler name (to avoid conflicts)
368 ICHECK(const_name.find(codegen_name) == 0)
369 << "External constant names must start with compiler name";
370 (*params)[const_name] = it.second;
371 }
372 }
373 for (const auto& pair : *params) {
374 VLOG(1) << "Constants: " << pair.first << " = " << PrettyPrint(pair.second);
375 }
376}
377
378/*!
379 * \brief A simple wrapper around ExprFunctor for a single argument case.
380 * The result of visit is memoized.
381 */
382template <typename OutputType>
383class MemoizedExprTranslator : public ::tvm::relay::ExprFunctor<OutputType(const Expr&)> {
384 using BaseFunctor = ::tvm::relay::ExprFunctor<OutputType(const Expr&)>;
385
386 public:
387 /*! \brief virtual destructor */
388 virtual ~MemoizedExprTranslator() {}
389
390 /*!
391 * \brief The memoized call.
392 * \param n The expression node.
393 * \return The result of the call
394 */
395 virtual OutputType VisitExpr(const Expr& n) {
396 ICHECK(n.defined());
397 auto it = memo_.find(n);
398 if (it != memo_.end()) {
399 return it->second;
400 }
401 auto res = BaseFunctor::VisitExpr(n);
402 memo_[n] = res;
403 return res;
404 }
405
406 protected:
407 /*! \brief Internal map used for memoization. */
408 std::unordered_map<Expr, OutputType, ObjectPtrHash, ObjectPtrEqual> memo_;
409};
410
411/*!
412 * \brief Get the Packed Func
413 *
414 * \param func_name
415 * \return const PackedFunc*
416 */
417inline const PackedFunc* GetPackedFunc(const std::string& func_name) {
418 return tvm::runtime::Registry::Get(func_name);
419}
420
421/*!
422 * \brief Get a typed packed function.
423 *
424 * \param func_name
425 * \return const PackedFunc*
426 */
427template <typename R, typename... Args>
428inline const runtime::TypedPackedFunc<R(Args...)> GetTypedPackedFunc(const std::string& func_name) {
429 auto* pf = GetPackedFunc(func_name);
430 ICHECK(pf != nullptr) << "can not find packed function";
431 return runtime::TypedPackedFunc<R(Args...)>(*pf);
432}
433
434/*!
435 * \brief Extract shape from an IndexExpr array to std::vector<int64_t>
436 *
437 * \param shape The shape in Array
438 * \return The converted shape in std::vector<int64_t>
439 */
440inline std::vector<int64_t> GetIntShape(const Array<IndexExpr>& shape) {
441 std::vector<int64_t> ret;
442 for (const auto& dim : shape) {
443 const int64_t* pval = tir::as_const_int(dim);
444 ret.push_back(pval ? *pval : -1);
445 }
446 return ret;
447}
448
449/*!
450 * \brief Convert type to string
451 *
452 * \param typ
453 * \return std::string string format of type
454 */
455inline std::string DType2String(const tvm::DataType dtype) {
456 std::ostringstream os;
457 if (dtype.is_float()) {
458 os << "float";
459 } else if (dtype.is_int()) {
460 os << "int";
461 } else if (dtype.is_uint()) {
462 os << "uint";
463 } else if (dtype.is_bfloat16()) {
464 os << "bfloat";
465 } else if ((*GetPackedFunc("runtime._datatype_get_type_registered"))(dtype.code())) {
466 os << "custom["
467 << (*GetPackedFunc("runtime._datatype_get_type_name"))(dtype.code()).operator std::string()
468 << "]";
469 } else {
470 LOG(FATAL) << "Unknown type with code " << static_cast<unsigned>(dtype.code());
471 }
472 os << dtype.bits();
473 return os.str();
474}
475
476/*!
477 * \brief Bind params to function by using name
478 * \param func Relay function
479 * \param params params dict
480 * \return relay::Function
481 */
482relay::Function BindParamsByName(relay::Function func,
483 const std::unordered_map<std::string, runtime::NDArray>& params);
484
485/*!
486 * \brief Bind params to the main function in Relay module, using BindParamsByName
487 * \param mod Relay module
488 * \param params params dict
489 */
490void BindParamsInModule(IRModule mod,
491 const std::unordered_map<std::string, runtime::NDArray>& params);
492
493void BindParamsInModule(IRModule mod, Map<String, runtime::NDArray> params);
494
495/*!
496 * \brief Extract the shape from a Relay tensor type.
497 * \param type The provided type.
498 * \return The extracted shape in a list.
499 */
500inline std::vector<int> GetShape(const Type& type) {
501 const auto* ttype = type.as<TensorTypeNode>();
502 ICHECK(ttype) << "Expect TensorTypeNode";
503 std::vector<int> shape;
504 for (size_t i = 0; i < ttype->shape.size(); ++i) {
505 auto* val = ttype->shape[i].as<IntImmNode>();
506 ICHECK(val);
507 shape.push_back(val->value);
508 }
509 return shape;
510}
511
512/*!
513 * \brief Check if a call has the provided name.
514 * \param call A Relay call node.
515 * \param op_name The name of the expected call.
516 * \return true if the call's name is equivalent to the given name. Otherwise,
517 * false.
518 */
519inline bool IsOp(const CallNode* call, const std::string& op_name) {
520 const auto* op_node = call->op.as<OpNode>();
521 ICHECK(op_node) << "Expects a single op.";
522 Op op = GetRef<Op>(op_node);
523 return op == Op::Get(op_name);
524}
525
526/*!
527 * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d))
528 * \param call A Relay call node. Typically nn.relu when called the first time.
529 * \param depth The number of calls before the root op, counting from current_call.
530 * \param expected_op_names The names of ops in this fused call. Example: {"nn.conv2d", "add",
531 * "nn.relu"}
532 * \return A CallNode corresponding to the root op, whose name is expected_op_names[0]
533 */
534inline const CallNode* GetRootCall(const CallNode* current_call, int depth,
535 const std::vector<std::string>& expected_op_names) {
536 ICHECK(current_call && depth >= 0 && static_cast<size_t>(depth) < expected_op_names.size() &&
537 IsOp(current_call, expected_op_names[depth]));
538
539 if (depth == 0) {
540 return current_call;
541 }
542
543 ICHECK_GT(current_call->args.size(), 0);
544 size_t valid_node_idx = 0;
545 while (valid_node_idx < current_call->args.size() &&
546 current_call->args[valid_node_idx].as<VarNode>()) {
547 valid_node_idx++;
548 }
549 while (valid_node_idx < current_call->args.size() &&
550 !(IsOp(current_call->args[valid_node_idx].as<CallNode>(), expected_op_names[depth - 1]))) {
551 valid_node_idx++;
552 }
553 const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
554 return GetRootCall(next_call, depth - 1, expected_op_names);
555}
556
557/*!
558 * \brief Retrieve the "root" op nested inside a fused call, such as conv2d in relu(add(conv2d))
559 * Unlike the previous definition, it does not verify operator names of intermediate nodes. Instead,
560 * it recursively visit child nodes until it finds a call node with the given op_name.
561 * \param call A Relay call node.
562 * \param op_name The name of an op to look for, such as ""nn.conv2d".
563 * \return A CallNode corresponding to the root op with the given op_name
564 */
565inline const CallNode* GetRootCall(const CallNode* current_call, const std::string& op_name) {
566 if (current_call == nullptr) return nullptr;
567 if (IsOp(current_call, op_name)) return current_call;
568
569 ICHECK_GT(current_call->args.size(), 0);
570
571 const auto* next_call = current_call->args[0].as<CallNode>();
572 return GetRootCall(next_call, op_name);
573}
574
575/*!
576 * \brief Retrieve the expected "root" op nested inside a fused call, such as conv2d in
577 * relu(add(conv2d))
578 * \param call A Relay call node. Typically nn.relu when called the first time.
579 * \param max_depth The maximum number of calls before the root op, counting from current_call.
580 * \param op_name The name of expected "root" op in this fused call.
581 * \return A CallNode corresponding to the root op
582 */
583inline const CallNode* GetRootCall(const CallNode* current_call, int max_depth,
584 const std::string& op_name) {
585 ICHECK(current_call && max_depth >= 0);
586
587 if (max_depth == 0) {
588 ICHECK(current_call && IsOp(current_call, op_name));
589 return current_call;
590 }
591 if (IsOp(current_call, op_name)) {
592 return current_call;
593 }
594
595 ICHECK_GT(current_call->args.size(), 0);
596
597 size_t valid_node_idx = 0;
598 while (valid_node_idx < current_call->args.size() &&
599 current_call->args[valid_node_idx].as<VarNode>()) {
600 valid_node_idx++;
601 }
602
603 const auto* next_call = current_call->args[valid_node_idx].as<CallNode>();
604 return GetRootCall(next_call, max_depth - 1, op_name);
605}
606
607/*!
608 * \brief Get the external symbol of the Relay function name.
609 *
610 * \param func The provided function.
611 * \return An external symbol.
612 */
613inline std::string GetExtSymbol(const Function& func) {
614 const auto name_node = func->GetAttr<String>(tvm::attr::kGlobalSymbol);
615 ICHECK(name_node.defined()) << "Fail to retrieve external symbol.";
616 return std::string(name_node.value());
617}
618
619/*!
620 * \brief Return whether the auto scheduler is enabled in the pass context.
621 */
622inline bool IsAutoSchedulerEnabled() {
623 return transform::PassContext::Current()
624 ->GetConfig<Bool>("relay.backend.use_auto_scheduler", Bool(false))
625 .value();
626}
627
628/*!
629 * \brief Return whether the meta schedule is enabled in the pass context.
630 */
631inline bool IsMetaScheduleEnabled() {
632 return transform::PassContext::Current()
633 ->GetConfig<Bool>("relay.backend.use_meta_schedule", Bool(false))
634 .value();
635}
636
637/*! \brief Consider MetaSchedule's dispatch option. */
638inline int UseMetaScheduleDispatch() {
639 return transform::PassContext::Current()
640 ->GetConfig<Integer>("relay.backend.use_meta_schedule_dispatch", Integer(0))
641 .value()
642 ->value;
643}
644/*!
645 * \brief Method in TECompiler to convert TE compute to scheduleable TIR
646 * \param args The arguments of the TE compute
647 * \param constants The constants used in AllocateConst
648 * \return NullOpt if conversion fails; Otherwise the converted TIR
649 * \note This method could be further used as a task filtering mechanism in task extraction
650 */
651using FTECompilerTIRConverter = runtime::TypedPackedFunc< //
652 Optional<tir::PrimFunc>( //
653 const Array<te::Tensor>& args, //
654 const Array<runtime::NDArray>& constants)>;
655
656/*! \brief Return a task filter for AutoTIR according to `relay.backend.tir_converter` */
657inline FTECompilerTIRConverter GetTIRConverter() {
658 String name = transform::PassContext::Current()
659 ->GetConfig<String>("relay.backend.tir_converter", "default")
660 .value();
661 const PackedFunc* f = runtime::Registry::Get("relay.backend.tir_converter." + name);
662 ICHECK(f != nullptr) << "IndexError: Cannot find TIR converter: " << name;
663 return FTECompilerTIRConverter(*f);
664}
665
666/*! \brief Converts a PrimFunc to IRModule. */
667inline IRModule PrimFuncToIRModule(tir::PrimFunc f) {
668 f = WithAttrs(f, Map<String, ObjectRef>{
669 {tvm::attr::kGlobalSymbol, String("main")},
670 {tvm::tir::attr::kNoAlias, Bool(1)},
671 });
672 return IRModule({{GlobalVar("main"), f}});
673}
674
675/*!
676 * \brief Get the sequence of Relay optimization passes based on backend type.
677 * The prefix of the Relay passes almost overlaps between the vm and graph backend, with some slight
678 * difference. This function unifies the shared optimization pass prefix between vm and graph
679 * runtime, and returns the pass prefix given the backend type.
680 *
681 * \param is_homogeneous True if all primitives are to be executed on the same device and target.
682 * \param is_vm True if passes are to be used for the vm executor.
683 * \return An array of passes.
684 */
685Array<Pass> GetPassPrefix(bool is_homogeneous, bool is_vm);
686
687/*! \brief Target hash function */
688struct TargetStrHash {
689 /*!
690 * \brief Calculate the hash code of a Target based on the string value of the Target KIND.
691 Note that this hash should NOT be used in new usecases, equality of targets based on their
692 value is not well-defined.
693 This will be removed when maps from Targets to IRModules are removed from the codebase.
694 * \param target The Target to hash
695 * \return String hash of the target
696 */
697 size_t operator()(const Target& target) const {
698 std::string s(target->kind->name);
699 return String::HashBytes(s.c_str(), s.size());
700 }
701};
702
703/*! \brief Target equality function based on the string value of Target
704Note that this equality function should NOT be used in new usecases, equality of targets based on
705their value is not well-defined. This will be removed when maps from Targets to IRModules are
706removed from the codebase.*/
707struct TargetStrEqual {
708 /*!
709 * \brief Check if the two Targets are equal
710 * \param target One Target
711 * \param other_target The other Target
712 * \return String equality of the targets
713 */
714 const bool operator()(const Target& target, const Target& other_target) const {
715 TargetStrHash target_hash = TargetStrHash();
716 return target_hash(target) == target_hash(other_target);
717 }
718};
719
720/*!
721 * \brief Convert a Map<Target, IRModule> to std::unordered_map<Target, IRmodule, TargetStrHash,
722 * TargetStrEqual> Target equality is currently based on pointer equality, which is a problem since
723 * we have a lot of Map<Target, IRModule> in the codebase. This function converts the map to a
724 * version that is keyed based on string value of the Target instead. Note that once we remove
725 * Map<Target, IRModule>, this function will be removed.
726 * \param input_map The map to convert
727 * \return The converted map
728 */
729std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
730TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map);
731
732/*!
733 * \brief Convert a std::unordered_map<Target, IRmodule, TargetStrHash, TargetStrEqual> to
734 * Map<Target, IRModule> This function is a helper that undoes TargetModuleMapToTargetStr. Note that
735 * once we remove Map<Target, IRModule>, this function will be removed.
736 * \param input_map The map to convert
737 * \return The converted map
738 */
739Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap(
740 std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map);
741
742/*!
743 * \brief Call "weight update callback" to communicate op weights seen during Relay module
744 * lowering back to the auto scheduler.
745 * Op weights refer to the number of times each distinct op/workload appears in a given module.
746 * It is called "use_count" in TECompiler.
747 * \param IRModule after lowering by LowerTEPass.
748 */
749void UpdateAutoSchedulerOpWeights(const IRModule& module);
750
751/*!
752 * \brief Extract shape from expr to vector<int64_t>
753 *
754 * \param shape
755 * \return std::vector<int64_t>
756 */
757std::vector<int64_t> ShapeToJSON(tvm::Array<IndexExpr> shape);
758
759} // namespace backend
760} // namespace relay
761} // namespace tvm
762
763#endif // TVM_RELAY_BACKEND_UTILS_H_
764