1 | |
2 | /* |
3 | * Licensed to the Apache Software Foundation (ASF) under one |
4 | * or more contributor license agreements. See the NOTICE file |
5 | * distributed with this work for additional information |
6 | * regarding copyright ownership. The ASF licenses this file |
7 | * to you under the Apache License, Version 2.0 (the |
8 | * "License"); you may not use this file except in compliance |
9 | * with the License. You may obtain a copy of the License at |
10 | * |
11 | * http://www.apache.org/licenses/LICENSE-2.0 |
12 | * |
13 | * Unless required by applicable law or agreed to in writing, |
14 | * software distributed under the License is distributed on an |
15 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
16 | * KIND, either express or implied. See the License for the |
17 | * specific language governing permissions and limitations |
18 | * under the License. |
19 | */ |
20 | |
21 | /*! |
 * \file relay/backend/utils.cc
23 | * \brief Relay backend utilities. |
24 | */ |
25 | |
26 | #include "utils.h" |
27 | |
28 | #include <tvm/relay/parser.h> |
29 | #include <tvm/relay/qnn/transform.h> |
30 | #include <tvm/runtime/ndarray.h> |
31 | #include <tvm/tir/stmt_functor.h> |
32 | |
33 | #include "../../te/operation/create_primfunc.h" |
34 | |
35 | namespace tvm { |
36 | namespace relay { |
37 | namespace backend { |
38 | |
39 | TVM_REGISTER_NODE_TYPE(StorageInfoNode); |
40 | |
41 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
42 | .set_dispatch<StorageInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { |
43 | const auto* node = ref.as<StorageInfoNode>(); |
44 | p->stream << "StorageInfoNode(" |
45 | << "storage_ids=[" ; |
46 | for (auto id : node->storage_ids) { |
47 | p->stream << id << "," ; |
48 | } |
49 | p->stream << "], virtual_devices=[" ; |
50 | for (const auto& virtual_device : node->virtual_devices) { |
51 | p->stream << virtual_device << "," ; |
52 | } |
53 | p->stream << "], storage_size_in_bytes=[" ; |
54 | for (auto bytes : node->storage_sizes_in_bytes) { |
55 | p->stream << bytes << "," ; |
56 | } |
57 | p->stream << "])" ; |
58 | }); |
59 | |
60 | StorageInfo::StorageInfo(std::vector<int64_t> storage_ids, |
61 | std::vector<VirtualDevice> virtual_devices, |
62 | std::vector<int64_t> storage_sizes_in_bytes) { |
63 | ICHECK_EQ(storage_ids.size(), virtual_devices.size()); |
64 | ICHECK_EQ(storage_ids.size(), storage_sizes_in_bytes.size()); |
65 | auto node = make_object<StorageInfoNode>(); |
66 | node->storage_ids = std::move(storage_ids); |
67 | node->virtual_devices = std::move(virtual_devices); |
68 | node->storage_sizes_in_bytes = std::move(storage_sizes_in_bytes); |
69 | data_ = std::move(node); |
70 | } |
71 | |
72 | // This is the legacy interface for devices as DLDeviceTypes (represented by integers) |
73 | TVM_REGISTER_GLOBAL("relay.ir.StorageInfo" ) |
74 | .set_body_typed([](const Array<Integer>& sids, const Array<Integer>& device_types, |
75 | const Array<Integer>& sizes_in_bytes) { |
76 | std::vector<int64_t> sids_v; |
77 | sids_v.reserve(sids.size()); |
78 | for (auto s : sids) { |
79 | sids_v.push_back(s.IntValue()); |
80 | } |
81 | std::vector<VirtualDevice> virtual_devices_v; |
82 | virtual_devices_v.reserve(device_types.size()); |
83 | for (const auto& device_type : device_types) { |
84 | virtual_devices_v.emplace_back(VirtualDevice::ForDeviceType(device_type)); |
85 | } |
86 | std::vector<int64_t> size_in_bytes_v; |
87 | size_in_bytes_v.reserve(sizes_in_bytes.size()); |
88 | for (auto s : sizes_in_bytes) { |
89 | size_in_bytes_v.push_back(s.IntValue()); |
90 | } |
91 | return StorageInfo(std::move(sids_v), std::move(virtual_devices_v), |
92 | std::move(size_in_bytes_v)); |
93 | }); |
94 | |
95 | TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageIds" ).set_body_typed([](StorageInfo si) { |
96 | Array<tvm::Integer> ids; |
97 | for (auto id : si->storage_ids) { |
98 | ids.push_back(id); |
99 | } |
100 | return ids; |
101 | }); |
102 | |
103 | // This is the legacy interface for devices as DLDeviceTypes (represented by integers) |
104 | TVM_REGISTER_GLOBAL("relay.ir.StorageInfoDeviceTypes" ).set_body_typed([](StorageInfo si) { |
105 | Array<tvm::Integer> device_types; |
106 | for (const auto& virtual_device : si->virtual_devices) { |
107 | device_types.push_back(virtual_device->device_type()); |
108 | } |
109 | return device_types; |
110 | }); |
111 | |
112 | TVM_REGISTER_GLOBAL("relay.ir.StorageInfoStorageSizes" ).set_body_typed([](StorageInfo si) { |
113 | Array<tvm::Integer> storage_sizes_in_bytes; |
114 | for (auto id : si->storage_sizes_in_bytes) { |
115 | storage_sizes_in_bytes.push_back(id); |
116 | } |
117 | return storage_sizes_in_bytes; |
118 | }); |
119 | |
120 | TVM_REGISTER_GLOBAL("relay.ir.StorageInfoVirtualDevices" ).set_body_typed([](StorageInfo si) { |
121 | Array<VirtualDevice> virtual_devices; |
122 | for (auto id : si->virtual_devices) { |
123 | virtual_devices.push_back(id); |
124 | } |
125 | return virtual_devices; |
126 | }); |
127 | |
128 | TVM_REGISTER_NODE_TYPE(StaticMemoryPlanNode); |
129 | |
130 | StaticMemoryPlan::StaticMemoryPlan(Map<Expr, StorageInfo> expr_to_storage_info) { |
131 | auto n = make_object<StaticMemoryPlanNode>(); |
132 | n->expr_to_storage_info = std::move(expr_to_storage_info); |
133 | data_ = std::move(n); |
134 | } |
135 | |
136 | TVM_REGISTER_GLOBAL("relay.ir.StaticMemoryPlan" ) |
137 | .set_body_typed([](const Map<Expr, StorageInfo>& expr_to_storage_info) { |
138 | return StaticMemoryPlan(expr_to_storage_info); |
139 | }); |
140 | |
/*!
 * \brief Ceiling division: the number of `word_size`-sized words needed to hold `size` units.
 * \param size The quantity to divide (e.g. a size in bits or bytes).
 * \param word_size The divisor; must be non-zero.
 * \return ceil(size / word_size).
 * \note Written as `1 + (size - 1) / word_size` so the computation cannot wrap
 *       around for `size` values close to SIZE_MAX, unlike the classic
 *       `(size + word_size - 1) / word_size` form.
 */
size_t DivRoundUp(size_t size, size_t word_size) {
  return size == 0 ? 0 : 1 + (size - 1) / word_size;
}
142 | |
143 | size_t GetMemorySizeBytes(const Array<PrimExpr>& shape, const DataType& dtype) { |
144 | size_t size = 1; |
145 | for (IndexExpr dim : shape) { |
146 | const int64_t* pval = tir::as_const_int(dim); |
147 | ICHECK(pval != nullptr) << "Cannot allocate memory symbolic tensor shape " << shape; |
148 | ICHECK_GE(*pval, 0) << "Cannot allocate memory for tensor with negative shape" << *pval; |
149 | size *= static_cast<size_t>(pval[0]); |
150 | } |
151 | size *= DivRoundUp(dtype.bits() * dtype.lanes(), 8); |
152 | return size; |
153 | } |
154 | |
155 | int64_t CalculateRelayExprSizeBytes(const Type& expr_type) { |
156 | if (expr_type->IsInstance<TupleTypeNode>()) { |
157 | auto tuple_type = Downcast<TupleType>(expr_type); |
158 | int64_t size = 0; |
159 | for (const auto& field : tuple_type->fields) { |
160 | size += CalculateRelayExprSizeBytes(field); |
161 | } |
162 | return size; |
163 | } |
164 | auto tensor_type = expr_type.as<TensorTypeNode>(); |
165 | ICHECK(tensor_type); |
166 | auto shape = tensor_type->shape; |
167 | return GetMemorySizeBytes(tensor_type->shape, tensor_type->dtype); |
168 | } |
169 | |
170 | TVM_REGISTER_NODE_TYPE(FunctionInfoNode); |
171 | |
172 | FunctionInfo::FunctionInfo(Map<Target, Integer> workspace_sizes, Map<Target, Integer> io_sizes, |
173 | Map<Target, Integer> constant_sizes, |
174 | Map<Target, tir::PrimFunc> tir_primfuncs, |
175 | Map<Target, Function> relay_primfuncs) { |
176 | ObjectPtr<FunctionInfoNode> n = make_object<FunctionInfoNode>(); |
177 | n->workspace_sizes = std::move(workspace_sizes); |
178 | n->io_sizes = std::move(io_sizes); |
179 | n->constant_sizes = std::move(constant_sizes); |
180 | n->tir_primfuncs = std::move(tir_primfuncs); |
181 | n->relay_primfuncs = std::move(relay_primfuncs); |
182 | data_ = std::move(n); |
183 | } |
184 | |
185 | TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) |
186 | .set_dispatch<FunctionInfoNode>([](const ObjectRef& ref, ReprPrinter* p) { |
187 | auto* node = static_cast<const FunctionInfoNode*>(ref.get()); |
188 | p->stream << "FunctionInfoNode(\n" |
189 | << "workspace_sizes=" << node->workspace_sizes << ",\n io_sizes=" << node->io_sizes |
190 | << ",\n constant_sizes=" << node->constant_sizes |
191 | << ",\n tir_primfuncs=" << node->tir_primfuncs |
192 | << ",\n relay_primfuncs=" << node->relay_primfuncs << ")" ; |
193 | }); |
194 | |
195 | ExecutorCodegenMetadata::ExecutorCodegenMetadata( |
196 | Array<tir::Var> inputs, Array<TensorType> input_tensor_types, Array<String> outputs, |
197 | Array<TensorType> output_tensor_types, Array<tir::Var> pools, Array<String> devices, |
198 | String executor, String mod_name, String interface_api, bool unpacked_api, |
199 | Integer workspace_alignment, Integer constant_alignment, |
200 | Map<tir::Var, tir::usmp::AllocatedPoolInfo> pool_inputs, |
201 | Map<String, tir::usmp::PoolAllocation> io_pool_allocations) { |
202 | auto n = make_object<ExecutorCodegenMetadataNode>(); |
203 | n->inputs = inputs; |
204 | n->input_tensor_types = input_tensor_types; |
205 | n->outputs = outputs; |
206 | n->output_tensor_types = output_tensor_types; |
207 | n->pools = pools; |
208 | n->devices = devices; |
209 | n->executor = executor; |
210 | n->interface_api = interface_api; |
211 | n->unpacked_api = unpacked_api; |
212 | n->mod_name = mod_name; |
213 | n->workspace_alignment = workspace_alignment; |
214 | n->constant_alignment = constant_alignment; |
215 | n->pool_inputs = pool_inputs; |
216 | n->io_pool_allocations = io_pool_allocations; |
217 | data_ = std::move(n); |
218 | } |
219 | |
220 | TVM_REGISTER_NODE_TYPE(ExecutorCodegenMetadataNode); |
221 | |
222 | Array<Pass> GetPassPrefix(bool is_homogeneous, bool is_vm) { |
223 | Array<Pass> pass_seqs; |
224 | // TODO(mbs): Would be nice to get spans on all diagnostics, but since they arg forgotton |
225 | // by most passes there's little utility in including this now. Plus we'd need to only do |
226 | // this if there's no existing spans to work from. |
227 | // pass_seqs.push_back(parser::AnnotateSpans()); |
228 | Array<runtime::String> entry_functions{"main" }; |
229 | pass_seqs.push_back(transform::RemoveUnusedFunctions(entry_functions)); |
230 | pass_seqs.push_back(transform::ToBasicBlockNormalForm()); |
231 | // Run all dialect legalization passes. |
232 | pass_seqs.push_back(relay::qnn::transform::Legalize()); |
233 | |
234 | // Legalize pass is restricted to homogeneous execution for now. |
235 | if (is_homogeneous) { |
236 | pass_seqs.push_back(transform::Legalize()); |
237 | } |
238 | |
239 | pass_seqs.push_back(transform::SimplifyInference()); |
240 | |
241 | if (is_vm) { |
242 | // eta expand to support constructors in argument position |
243 | pass_seqs.push_back(transform::EtaExpand( |
244 | /* expand_constructor */ true, /* expand_global_var */ false)); |
245 | } |
246 | |
247 | PackedFunc fskip = PackedFunc([](TVMArgs args, TVMRetValue* rv) { |
248 | Expr expr = args[0]; |
249 | if (auto* call_node = expr.as<CallNode>()) { |
250 | auto op_node = call_node->op.as<OpNode>(); |
251 | if (op_node->name == "cast" ) { |
252 | auto attrs = call_node->attrs.as<CastAttrs>(); |
253 | if (attrs->dtype == DataType::Int(32)) { |
254 | *rv = true; |
255 | } |
256 | } |
257 | } |
258 | *rv = false; |
259 | }); |
260 | pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip)); |
261 | pass_seqs.push_back(transform::CombineParallelConv2D(3)); |
262 | pass_seqs.push_back(transform::CombineParallelDense(3)); |
263 | pass_seqs.push_back(transform::CombineParallelBatchMatmul(3)); |
264 | pass_seqs.push_back(transform::FoldConstant()); |
265 | pass_seqs.push_back(transform::FoldScaleAxis()); |
266 | pass_seqs.push_back(transform::SimplifyExpr()); |
267 | pass_seqs.push_back(transform::CanonicalizeCast()); |
268 | pass_seqs.push_back(transform::CanonicalizeOps()); |
269 | pass_seqs.push_back(transform::FlattenAtrousConv()); |
270 | |
271 | // Alter layout transformation is currently only applied to homogeneous execution. |
272 | if (is_homogeneous) { |
273 | if (!is_vm) { |
274 | pass_seqs.push_back(transform::InferType()); |
275 | } |
276 | pass_seqs.push_back(transform::AlterOpLayout()); |
277 | } |
278 | |
279 | // Fast math optimizations. |
280 | pass_seqs.push_back(transform::FastMath()); |
281 | pass_seqs.push_back(transform::FoldConstant()); |
282 | |
283 | return pass_seqs; |
284 | } |
285 | |
286 | std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> |
287 | TargetModuleMapToTargetStrModuleMap(Map<Target, IRModule> input_map) { |
288 | std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> std_map; |
289 | for (auto kv : input_map) { |
290 | std_map[kv.first] = kv.second; |
291 | } |
292 | return std_map; |
293 | } |
294 | |
295 | Map<Target, IRModule> TargetStrModuleMapToTargetModuleMap( |
296 | std::unordered_map<Target, IRModule, TargetStrHash, TargetStrEqual> input_map) { |
297 | Map<Target, IRModule> tvm_map; |
298 | for (auto kv : input_map) { |
299 | tvm_map.Set(kv.first, kv.second); |
300 | } |
301 | return tvm_map; |
302 | } |
303 | |
304 | void UpdateAutoSchedulerOpWeights(const IRModule& module) { |
305 | const auto* te_compiler_update_weights = |
306 | runtime::Registry::Get("auto_scheduler.relay_integration.te_compiler_update_weights" ); |
307 | |
308 | ICHECK(te_compiler_update_weights != nullptr) |
309 | << "auto_scheduler.relay_integration.te_compiler_update_weights" ; |
310 | |
311 | Map<String, Integer> weight_map = |
312 | module->GetAttr<Map<String, Integer>>("op_weights" , Map<String, Integer>()).value(); |
313 | |
314 | (*te_compiler_update_weights)(weight_map); |
315 | } |
316 | |
317 | std::vector<int64_t> ShapeToJSON(tvm::Array<IndexExpr> shape) { |
318 | std::vector<int64_t> ret; |
319 | for (IndexExpr dim : shape) { |
320 | const int64_t* pval = tir::as_const_int(dim); |
321 | ret.push_back(*pval); |
322 | } |
323 | return ret; |
324 | } |
325 | |
326 | relay::Function BindParamsByName(relay::Function func, |
327 | const std::unordered_map<std::string, runtime::NDArray>& params) { |
328 | std::unordered_map<std::string, relay::Var> name_dict; |
329 | std::unordered_set<relay::Var, ObjectPtrHash, ObjectPtrEqual> repeat_var; |
330 | for (auto arg : func->params) { |
331 | const auto& name = arg->name_hint(); |
332 | if (name_dict.count(name)) { |
333 | repeat_var.insert(name_dict[name]); |
334 | } else { |
335 | name_dict[name] = arg; |
336 | } |
337 | } |
338 | |
339 | std::unordered_map<relay::Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict; |
340 | for (auto& kv : params) { |
341 | if (name_dict.count(kv.first) == 0) { |
342 | continue; |
343 | } |
344 | auto arg = name_dict.at(kv.first); |
345 | if (repeat_var.count(arg)) { |
346 | LOG(FATAL) << "Multiple args in the function have name " << kv.first; |
347 | } |
348 | bind_dict[arg] = Constant(kv.second); |
349 | } |
350 | Expr bound_expr = relay::Bind(func, bind_dict); |
351 | Function ret = Downcast<Function>(bound_expr); |
352 | ICHECK(ret.defined()) << "The returning type is expected to be a Relay Function." |
353 | << "\n" ; |
354 | return ret; |
355 | } |
356 | |
357 | void BindParamsInModule(IRModule mod, |
358 | const std::unordered_map<std::string, runtime::NDArray>& params) { |
359 | if (!params.empty()) { |
360 | BaseFunc base_func = mod->Lookup("main" ); |
361 | ICHECK(base_func->IsInstance<FunctionNode>()); |
362 | auto f = relay::backend::BindParamsByName(Downcast<Function>(base_func), params); |
363 | auto gvar = mod->GetGlobalVar("main" ); |
364 | mod->Add(gvar, f); |
365 | } |
366 | } |
367 | |
368 | void BindParamsInModule(IRModule mod, Map<String, runtime::NDArray> params) { |
369 | std::unordered_map<std::string, runtime::NDArray> params_tmp; |
370 | for (const auto& kv : params) { |
371 | params_tmp[kv.first] = kv.second; |
372 | } |
373 | BindParamsInModule(mod, params_tmp); |
374 | } |
375 | |
376 | /*! |
377 | * \brief A default TE compute to TIR compute. |
378 | * \param args The inputs/outputs of the TE compute graph. |
379 | * \param constants The constants bound to TIR |
380 | * \param allow_extern_op Whether to allow extern operation in TE. |
381 | * \return The TIR converted; NullOpt if not supported (dynamic shape) |
382 | */ |
383 | Optional<tir::PrimFunc> DefaultTIRConverterImpl(const Array<te::Tensor>& args, |
384 | const Array<runtime::NDArray>& constants, |
385 | bool allow_extern_op) { |
386 | using namespace ::tvm::te; |
387 | std::vector<Tensor> stack; |
388 | std::unordered_set<const TensorNode*> visited; |
389 | for (const Tensor& v : args) { |
390 | for (const PrimExpr& e : v->shape) { |
391 | // Dynamic shape is not supported for now |
392 | if (!e->IsInstance<IntImmNode>()) { |
393 | return NullOpt; |
394 | } |
395 | } |
396 | if (!visited.count(v.get())) { |
397 | visited.insert(v.get()); |
398 | stack.push_back(v); |
399 | } |
400 | } |
401 | while (!stack.empty()) { |
402 | Tensor tensor = stack.back(); |
403 | stack.pop_back(); |
404 | if (tensor->op->IsInstance<PlaceholderOpNode>()) { |
405 | // do nothing |
406 | } else if (tensor->op->IsInstance<ComputeOpNode>() || |
407 | (allow_extern_op && tensor->op->IsInstance<ExternOpNode>())) { |
408 | Array<Tensor> inputs = tensor->op->InputTensors(); |
409 | for (const Tensor& v : inputs) { |
410 | if (!visited.count(v.get())) { |
411 | visited.insert(v.get()); |
412 | stack.push_back(v); |
413 | } |
414 | } |
415 | } else { |
416 | return NullOpt; |
417 | } |
418 | } |
419 | PrimFunc func = te::CreatePrimFuncWithConstants(args, constants, DataType::Int(64)); |
420 | bool dynamic_loop_extent = false; |
421 | tir::PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void { |
422 | if (const auto* loop = obj.as<tir::ForNode>()) { |
423 | if (!loop->extent->IsInstance<IntImmNode>()) { |
424 | dynamic_loop_extent = true; |
425 | } |
426 | } |
427 | }); |
428 | if (dynamic_loop_extent) { |
429 | return NullOpt; |
430 | } |
431 | return func; |
432 | } |
433 | |
434 | TVM_REGISTER_GLOBAL("relay.backend.tir_converter.default" ) |
435 | .set_body_typed([](const Array<te::Tensor>& args, |
436 | const Array<runtime::NDArray>& constants) -> Optional<tir::PrimFunc> { |
437 | return DefaultTIRConverterImpl(args, constants, false); |
438 | }); |
439 | |
440 | TVM_REGISTER_GLOBAL("relay.backend.tir_converter.allow_extern" ) |
441 | .set_body_typed([](const Array<te::Tensor>& args, |
442 | const Array<runtime::NDArray>& constants) -> Optional<tir::PrimFunc> { |
443 | return DefaultTIRConverterImpl(args, constants, true); |
444 | }); |
445 | |
446 | } // namespace backend |
447 | } // namespace relay |
448 | } // namespace tvm |
449 | |