/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file relay/backend/utils.cc
 * \brief Relay backend utilities.
 */

#include "utils.h"

#include <tvm/relay/parser.h>
#include <tvm/relay/qnn/transform.h>
#include <tvm/runtime/ndarray.h>
#include <tvm/tir/stmt_functor.h>

#include "../../te/operation/create_primfunc.h"

namespace tvm {
namespace relay {
namespace backend {

TVM_REGISTER_NODE_TYPE(StorageInfoNode);

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<StorageInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
      const auto* node = ref.as<StorageInfoNode>();
      p->stream << "StorageInfoNode("
                << "storage_ids=[";
      for (auto id : node->storage_ids) {
        p->stream << id << ",";
      }
      p->stream << "], virtual_devices=[";
      for (const auto& virtual_device : node->virtual_devices) {
        p->stream << virtual_device << ",";
      }
      p->stream << "], storage_sizes_in_bytes=[";
      for (auto bytes : node->storage_sizes_in_bytes) {
        p->stream << bytes << ",";
      }
      p->stream << "])";
    });

StorageInfo::StorageInfo(std::vector<int64_t> storage_ids,
                         std::vector<VirtualDevice> virtual_devices,
                         std::vector<int64_t> storage_sizes_in_bytes) {
  ICHECK_EQ(storage_ids.size(), virtual_devices.size());
  ICHECK_EQ(storage_ids.size(), storage_sizes_in_bytes.size());
  auto node = make_object<StorageInfoNode>();
  node->storage_ids = std::move(storage_ids);
  node->virtual_devices = std::move(virtual_devices);
  node->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes);
  data_ = std::move(node);
}

// This is the legacy interface for devices as DLDeviceTypes (represented by integers).
TVM_REGISTER_GLOBAL("relay.ir.StorageInfo")
    .set_body_typed([](const Array<Integer>& sids, const Array<Integer>& device_types,
                       const Array<Integer>& sizes_in_bytes) {
      std::vector<int64_t> sids_v;
      sids_v.reserve(sids.size());
      for (auto s : sids) {
        sids_v.push_back(s.IntValue());
      }
      std::vector<VirtualDevice> virtual_devices_v;
      virtual_devices_v.reserve(device_types.size());
      for (const auto& device_type : device_types) {
        virtual_devices_v.emplace_back(VirtualDevice::ForDeviceType(device_type));
      }
      std::vector<int64_t> size_in_bytes_v;
      size_in_bytes_v.reserve(sizes_in_bytes.size());
      for (auto s : sizes_in_bytes) {
        size_in_bytes_v.push_back(s.IntValue());
      }
      return StorageInfo(std::move(sids_v), std::move(virtual_devices_v),
                         std::move(size_in_bytes_v));
    });

TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds").set_body_typed([](StorageInfo si) {
  Array<tvm::Integer> ids;
  for (auto id : si->storage_ids) {
    ids.push_back(id);
  }
  return ids;
});

// This is the legacy interface for devices as DLDeviceTypes (represented by integers).
TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes").set_body_typed([](StorageInfo si) {
  Array<tvm::Integer> device_types;
  for (const auto& virtual_device : si->virtual_devices) {
    device_types.push_back(virtual_device->device_type());
  }
  return device_types;
});

TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageSizes").set_body_typed([](StorageInfo si) {
  Array<tvm::Integer> storage_sizes_in_bytes;
  for (auto size : si->storage_sizes_in_bytes) {
    storage_sizes_in_bytes.push_back(size);
  }
  return storage_sizes_in_bytes;
});

TVM_REGISTER_GLOBAL("relay.ir.StorageInfoVirtualDevices").set_body_typed([](StorageInfo si) {
  Array<VirtualDevice> virtual_devices;
  for (const auto& virtual_device : si->virtual_devices) {
    virtual_devices.push_back(virtual_device);
  }
  return virtual_devices;
});

TVM_REGISTER_NODE_TYPE(StaticMemoryPlanNode);

StaticMemoryPlan::StaticMemoryPlan(Map<Expr, StorageInfo> expr_to_storage_info) {
  auto n = make_object<StaticMemoryPlanNode>();
  n->expr_to_storage_info = std::move(expr_to_storage_info);
  data_ = std::move(n);
}

TVM_REGISTER_GLOBAL("relay.ir.StaticMemoryPlan")
    .set_body_typed([](const Map<Expr, StorageInfo>& expr_to_storage_info) {
      return StaticMemoryPlan(expr_to_storage_info);
    });

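/*! \brief Integer division of \p size by \p word_size, rounded up. */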
size_t DivRoundUp(size_t size, size_t word_size) { return (size + word_size - 1) / word_size; }

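/*!
 * \brief Compute the number of bytes needed to hold a tensor of the given static shape and dtype.
 *
 * For example, a float32 tensor of shape (2, 3) needs 2 * 3 * DivRoundUp(32 * 1, 8) = 24 bytes.
 * All shape dimensions must be non-negative compile-time constants.
 */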
size_t GetMemorySizeBytes(const Array<PrimExpr>& shape, const DataType& dtype) {
  size_t size = 1;
  for (IndexExpr dim : shape) {
    const int64_t* pval = tir::as_const_int(dim);
    ICHECK(pval != nullptr) << "Cannot allocate memory for tensor with symbolic shape " << shape;
    ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative dimension " << *pval;
    size *= static_cast<size_t>(pval[0]);
  }
  size *= DivRoundUp(dtype.bits() * dtype.lanes(), 8);
  return size;
}

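/*!
 * \brief Compute the total storage size in bytes needed for a Relay type.
 *
 * Tuple types are handled recursively by summing the sizes of their fields; every leaf type must
 * be a tensor type with a static shape.
 */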
int64_t CalculateRelayExprSizeBytes(const Type& expr_type) {
  if (expr_type->IsInstance<TupleTypeNode>()) {
    auto tuple_type = Downcast<TupleType>(expr_type);
    int64_t size = 0;
    for (const auto& field : tuple_type->fields) {
      size += CalculateRelayExprSizeBytes(field);
    }
    return size;
  }
  const auto* tensor_type = expr_type.as<TensorTypeNode>();
  ICHECK(tensor_type);
  return GetMemorySizeBytes(tensor_type->shape, tensor_type->dtype);
}

TVM_REGISTER_NODE_TYPE(FunctionInfoNode);

FunctionInfo::FunctionInfo(Map<Target, Integer> workspace_sizes, Map<Target, Integer> io_sizes,
                           Map<Target, Integer> constant_sizes,
                           Map<Target, tir::PrimFunc> tir_primfuncs,
                           Map<Target, Function> relay_primfuncs) {
  ObjectPtr<FunctionInfoNode> n = make_object<FunctionInfoNode>();
  n->workspace_sizes = std::move(workspace_sizes);
  n->io_sizes = std::move(io_sizes);
  n->constant_sizes = std::move(constant_sizes);
  n->tir_primfuncs = std::move(tir_primfuncs);
  n->relay_primfuncs = std::move(relay_primfuncs);
  data_ = std::move(n);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<FunctionInfoNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* node = static_cast<const FunctionInfoNode*>(ref.get());
      p->stream << "FunctionInfoNode(\n"
                << "workspace_sizes=" << node->workspace_sizes << ",\n io_sizes=" << node->io_sizes
                << ",\n constant_sizes=" << node->constant_sizes
                << ",\n tir_primfuncs=" << node->tir_primfuncs
                << ",\n relay_primfuncs=" << node->relay_primfuncs << ")";
    });

ExecutorCodegenMetadata::ExecutorCodegenMetadata(
    Array<tir::Var> inputs, Array<TensorType> input_tensor_types, Array<String> outputs,
    Array<TensorType> output_tensor_types, Array<tir::Var> pools, Array<String> devices,
    String executor, String mod_name, String interface_api, bool unpacked_api,
    Integer workspace_alignment, Integer constant_alignment,
    Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs,
    Map<String, tir::usmp::PoolAllocation> io_pool_allocations) {
  auto n = make_object<ExecutorCodegenMetadataNode>();
  n->inputs = inputs;
  n->input_tensor_types = input_tensor_types;
  n->outputs = outputs;
  n->output_tensor_types = output_tensor_types;
  n->pools = pools;
  n->devices = devices;
  n->executor = executor;
  n->interface_api = interface_api;
  n->unpacked_api = unpacked_api;
  n->mod_name = mod_name;
  n->workspace_alignment = workspace_alignment;
  n->constant_alignment = constant_alignment;
  n->pool_inputs = pool_inputs;
  n->io_pool_allocations = io_pool_allocations;
  data_ = std::move(n);
}

TVM_REGISTER_NODE_TYPE(ExecutorCodegenMetadataNode);

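/*!
 * \brief Get the common prefix of Relay optimization passes run before target-specific lowering.
 *
 * \param is_homogeneous True when execution targets a single device; passes such as Legalize and
 *        AlterOpLayout are currently only applied in this case.
 * \param is_vm True when compiling for the Relay VM, which additionally eta-expands constructors.
 * \return The pass sequence.
 */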
Array<Pass> GetPassPrefix(bool is_homogeneous, bool is_vm) {
  Array<Pass> pass_seqs;
  // TODO(mbs): Would be nice to get spans on all diagnostics, but since they are forgotten by
  // most passes there's little utility in including this now. Plus we'd need to only do this if
  // there are no existing spans to work from.
  // pass_seqs.push_back(parser::AnnotateSpans());
  Array<runtime::String> entry_functions{"main"};
  pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions));
  pass_seqs.push_back(transform::ToBasicBlockNormalForm());
  // Run all dialect legalization passes.
  pass_seqs.push_back(relay::qnn::transform::Legalize());

  // The Legalize pass is restricted to homogeneous execution for now.
  if (is_homogeneous) {
    pass_seqs.push_back(transform::Legalize());
  }

  pass_seqs.push_back(transform::SimplifyInference());

  if (is_vm) {
    // Eta-expand to support constructors in argument position.
    pass_seqs.push_back(transform::EtaExpand(
        /* expand_constructor */ true, /* expand_global_var */ false));
  }

  // fskip tells EliminateCommonSubexpr to skip call nodes that cast to int32.
  PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) {
    Expr expr = args[0];
    // By default, do not skip the expression.
    *rv = false;
    if (const auto* call_node = expr.as<CallNode>()) {
      const auto* op_node = call_node->op.as<OpNode>();
      if (op_node != nullptr && op_node->name == "cast") {
        const auto* attrs = call_node->attrs.as<CastAttrs>();
        if (attrs != nullptr && attrs->dtype == DataType::Int(32)) {
          *rv = true;
        }
      }
    }
  });
  pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
  pass_seqs.push_back(transform::CombineParallelConv2D(3));
  pass_seqs.push_back(transform::CombineParallelDense(3));
  pass_seqs.push_back(transform::CombineParallelBatchMatmul(3));
  pass_seqs.push_back(transform::FoldConstant());
  pass_seqs.push_back(transform::FoldScaleAxis());
  pass_seqs.push_back(transform::SimplifyExpr());
  pass_seqs.push_back(transform::CanonicalizeCast());
  pass_seqs.push_back(transform::CanonicalizeOps());
  pass_seqs.push_back(transform::FlattenAtrousConv());

  // Alter layout transformation is currently only applied to homogeneous execution.
  if (is_homogeneous) {
    if (!is_vm) {
      pass_seqs.push_back(transform::InferType());
    }
    pass_seqs.push_back(transform::AlterOpLayout());
  }

  // Fast math optimizations.
  pass_seqs.push_back(transform::FastMath());
  pass_seqs.push_back(transform::FoldConstant());

  return pass_seqs;
}

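/*!
 * \brief Convert a Map<Target, IRModule> into an std::unordered_map keyed with TargetStrHash /
 * TargetStrEqual, i.e. hashed and compared by the Target's string representation.
 */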
std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual>
TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map) {
  std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> std_map;
  for (const auto& kv : input_map) {
    std_map[kv.first] = kv.second;
  }
  return std_map;
}

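/*!
 * \brief Convert a string-keyed std::unordered_map of Target to IRModule back into a TVM
 * Map<Target, IRModule>.
 */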
Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap(
    std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map) {
  Map<Target, IRModule> tvm_map;
  for (const auto& kv : input_map) {
    tvm_map.Set(kv.first, kv.second);
  }
  return tvm_map;
}

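/*!
 * \brief Forward the module's "op_weights" attribute (a Map<String, Integer>) to the
 * auto_scheduler integration via the registered
 * "auto_scheduler.relay_integration.te_compiler_update_weights" packed function.
 */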
void UpdateAutoSchedulerOpWeights(const IRModule& module) {
  const auto* te_compiler_update_weights =
      runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights");

  ICHECK(te_compiler_update_weights != nullptr)
      << "Cannot find auto_scheduler.relay_integration.te_compiler_update_weights in the registry";

  Map<String, Integer> weight_map =
      module->GetAttr<Map<String, Integer>>("op_weights", Map<String, Integer>()).value();

  (*te_compiler_update_weights)(weight_map);
}

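/*!
 * \brief Convert a static tensor shape into a plain vector of int64_t dimensions; every dimension
 * must be a compile-time constant.
 */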
std::vector<int64_t> ShapeToJSON(tvm::Array<IndexExpr> shape) {
  std::vector<int64_t> ret;
  ret.reserve(shape.size());
  for (IndexExpr dim : shape) {
    const int64_t* pval = tir::as_const_int(dim);
    ICHECK(pval != nullptr) << "Cannot serialize symbolic shape dimension " << dim;
    ret.push_back(*pval);
  }
  return ret;
}

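/*!
 * \brief Bind the given constants to the function's parameters, matching by parameter name.
 * Parameters without a matching entry in \p params are left free; it is a fatal error if a name
 * in \p params is shared by more than one function parameter.
 */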
relay::Function BindParamsByName(relay::Function func,
                                 const std::unordered_map<std::string, runtime::NDArray>& params) {
  std::unordered_map<std::string, relay::Var> name_dict;
  std::unordered_set<relay::Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
  for (const auto& arg : func->params) {
    const auto& name = arg->name_hint();
    if (name_dict.count(name)) {
      repeat_var.insert(name_dict[name]);
    } else {
      name_dict[name] = arg;
    }
  }

  std::unordered_map<relay::Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
  for (const auto& kv : params) {
    if (name_dict.count(kv.first) == 0) {
      continue;
    }
    auto arg = name_dict.at(kv.first);
    if (repeat_var.count(arg)) {
      LOG(FATAL) << "Multiple args in the function have name " << kv.first;
    }
    bind_dict[arg] = Constant(kv.second);
  }
  Expr bound_expr = relay::Bind(func, bind_dict);
  Function ret = Downcast<Function>(bound_expr);
  ICHECK(ret.defined()) << "The result of binding params is expected to be a Relay Function";
  return ret;
}

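/*!
 * \brief Bind constant parameters to the module's "main" function, in place. The overload taking
 * a Map<String, NDArray> converts it to an std::unordered_map and dispatches to the first
 * overload.
 */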
void BindParamsInModule(IRModule mod,
                        const std::unordered_map<std::string, runtime::NDArray>& params) {
  if (!params.empty()) {
    BaseFunc base_func = mod->Lookup("main");
    ICHECK(base_func->IsInstance<FunctionNode>());
    auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params);
    auto gvar = mod->GetGlobalVar("main");
    mod->Add(gvar, f);
  }
}

void BindParamsInModule(IRModule mod, Map<String, runtime::NDArray> params) {
  std::unordered_map<std::string, runtime::NDArray> params_tmp;
  for (const auto& kv : params) {
    params_tmp[kv.first] = kv.second;
  }
  BindParamsInModule(mod, params_tmp);
}

/*!
 * \brief A default converter from a TE compute graph to a TIR PrimFunc.
 * \param args The inputs/outputs of the TE compute graph.
 * \param constants The constants bound to TIR.
 * \param allow_extern_op Whether to allow extern operations in the TE graph.
 * \return The converted PrimFunc; NullOpt if the graph is not supported, e.g. it contains dynamic
 *         shapes, dynamic loop extents, or unsupported operation kinds.
 */
Optional<tir::PrimFunc> DefaultTIRConverterImpl(const Array<te::Tensor>& args,
                                                const Array<runtime::NDArray>& constants,
                                                bool allow_extern_op) {
  using namespace ::tvm::te;
  std::vector<Tensor> stack;
  std::unordered_set<const TensorNode*> visited;
  for (const Tensor& v : args) {
    for (const PrimExpr& e : v->shape) {
      // Dynamic shape is not supported for now.
      if (!e->IsInstance<IntImmNode>()) {
        return NullOpt;
      }
    }
    if (!visited.count(v.get())) {
      visited.insert(v.get());
      stack.push_back(v);
    }
  }
  while (!stack.empty()) {
    Tensor tensor = stack.back();
    stack.pop_back();
    if (tensor->op->IsInstance<PlaceholderOpNode>()) {
      // do nothing
    } else if (tensor->op->IsInstance<ComputeOpNode>() ||
               (allow_extern_op && tensor->op->IsInstance<ExternOpNode>())) {
      Array<Tensor> inputs = tensor->op->InputTensors();
      for (const Tensor& v : inputs) {
        if (!visited.count(v.get())) {
          visited.insert(v.get());
          stack.push_back(v);
        }
      }
    } else {
      return NullOpt;
    }
  }
  PrimFunc func = te::CreatePrimFuncWithConstants(args, constants, DataType::Int(64));
  bool dynamic_loop_extent = false;
  tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void {
    if (const auto* loop = obj.as<tir::ForNode>()) {
      if (!loop->extent->IsInstance<IntImmNode>()) {
        dynamic_loop_extent = true;
      }
    }
  });
  if (dynamic_loop_extent) {
    return NullOpt;
  }
  return func;
}

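// Register the default TE-to-TIR converters: "relay.backend.tir_converter.default" rejects
// extern operations, while "relay.backend.tir_converter.allow_extern" accepts them.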
TVM_REGISTER_GLOBAL("relay.backend.tir_converter.default")
    .set_body_typed([](const Array<te::Tensor>& args,
                       const Array<runtime::NDArray>& constants) -> Optional<tir::PrimFunc> {
      return DefaultTIRConverterImpl(args, constants, false);
    });

TVM_REGISTER_GLOBAL("relay.backend.tir_converter.allow_extern")
    .set_body_typed([](const Array<te::Tensor>& args,
                       const Array<runtime::NDArray>& constants) -> Optional<tir::PrimFunc> {
      return DefaultTIRConverterImpl(args, constants, true);
    });

445
446} // namespace backend
447} // namespace relay
448} // namespace tvm