1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/aot/create_function_metadata.cc |
22 | * \brief Create FunctionInfo metadata from a lowered TIR module. |
23 | */ |
24 | #include "./create_function_metadata.h" |
25 | |
26 | #include <tvm/ir/expr.h> |
27 | #include <tvm/ir/module.h> |
28 | #include <tvm/runtime/container/array.h> |
29 | #include <tvm/runtime/container/map.h> |
30 | #include <tvm/runtime/container/string.h> |
31 | #include <tvm/runtime/data_type.h> |
32 | #include <tvm/runtime/module.h> |
33 | #include <tvm/target/target_kind.h> |
34 | #include <tvm/tir/analysis.h> |
35 | #include <tvm/tir/function.h> |
36 | #include <tvm/tir/op.h> |
37 | #include <tvm/tir/usmp/utils.h> |
38 | |
39 | #include "../utils.h" |
40 | |
41 | namespace tvm { |
42 | namespace relay { |
43 | namespace backend { |
44 | namespace aot { |
45 | |
46 | /*! |
47 | * \brief Calculate FunctionInfo for all the PrimFuncs in a module. |
48 | */ |
49 | Map<String, backend::FunctionInfo> CalculateFunctionInfos(const IRModule& mod, |
50 | Integer workspace_byte_alignment, |
51 | Integer constant_byte_alignment) { |
52 | Map<String, backend::FunctionInfo> function_metadata; |
53 | for (const auto& kv : mod->functions) { |
54 | GlobalVar global_var = kv.first; |
55 | BaseFunc base_func = kv.second; |
56 | if (base_func->IsInstance<tir::PrimFuncNode>()) { |
57 | tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(base_func); |
58 | Optional<Target> tgt_opt = pfunc->GetAttr<Target>(tvm::attr::kTarget); |
59 | ICHECK(tgt_opt) << "Target must be defined for all primfuncs." ; |
60 | Target tgt = tgt_opt.value(); |
61 | // Determine the size of input/output buffers |
62 | auto params = pfunc->params; |
63 | int64_t total_io_bytes = 0; |
64 | for (const auto& param : params) { |
65 | if (pfunc->buffer_map.find(param) != pfunc->buffer_map.end()) { |
66 | auto buffer = pfunc->buffer_map[param]; |
67 | total_io_bytes += GetMemorySizeBytes(buffer->shape, buffer->dtype); |
68 | } |
69 | } |
70 | const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment); |
71 | const auto& cs = CalculateConstantBytes(pfunc, constant_byte_alignment); |
72 | backend::FunctionInfo finfo{ |
73 | {{tgt, ws}}, {{tgt, total_io_bytes}}, {{tgt, cs}}, {{tgt, pfunc}}, {}}; |
74 | function_metadata.Set(global_var->name_hint, finfo); |
75 | } |
76 | } |
77 | return function_metadata; |
78 | } |
79 | |
80 | Map<String, backend::FunctionInfo> CreateFunctionMetadata(const IRModule& mod, |
81 | Integer workspace_byte_alignment, |
82 | Integer constant_byte_alignment) { |
83 | // First calculate the FunctionInfos from the buffers that are explicitly allocated |
84 | auto function_metadata = |
85 | CalculateFunctionInfos(mod, workspace_byte_alignment, constant_byte_alignment); |
86 | // Now adjust the FunctionInfo for the main func to also include PoolInfo allocations |
87 | // made by the USMP. |
88 | Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos = |
89 | mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs); |
90 | backend::FunctionInfo main_func_info = |
91 | function_metadata.Get(runtime::symbol::tvm_module_main).value(); |
92 | if (allocated_pool_infos) { |
93 | for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) { |
94 | for (const auto& tgt : allocated_pool_info->pool_info->targets) { |
95 | VLOG(1) << "USMP requires target " << tgt->ToDebugString() << " to have pool size " |
96 | << allocated_pool_info->allocated_size->value; |
97 | size_t size = allocated_pool_info->allocated_size->value; |
98 | if (allocated_pool_info->pool_info->IsInstance<ConstantPoolInfoNode>()) { |
99 | size += main_func_info->constant_sizes.count(tgt) |
100 | ? main_func_info->constant_sizes[tgt]->value |
101 | : 0; |
102 | main_func_info->constant_sizes.Set(tgt, size); |
103 | } else if (allocated_pool_info->pool_info->IsInstance<WorkspacePoolInfoNode>()) { |
104 | size += main_func_info->workspace_sizes.count(tgt) |
105 | ? main_func_info->workspace_sizes[tgt]->value |
106 | : 0; |
107 | main_func_info->workspace_sizes.Set(tgt, size); |
108 | } else { |
109 | LOG(FATAL) << "Unknown pool type: " << allocated_pool_info->pool_info->GetTypeKey(); |
110 | } |
111 | } |
112 | } |
113 | } |
114 | function_metadata.Set(runtime::symbol::tvm_module_main, main_func_info); |
115 | return function_metadata; |
116 | } |
117 | |
118 | TVM_REGISTER_GLOBAL("relay.backend.aot.CreateFunctionMetadata" ) |
119 | .set_body_typed(CreateFunctionMetadata); |
120 | |
121 | } // namespace aot |
122 | } // namespace backend |
123 | } // namespace relay |
124 | } // namespace tvm |
125 | |