1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/backend/aot/create_function_metadata.cc
22 * \brief Create FunctionInfo metadata from a lowered TIR module.
23 */
24#include "./create_function_metadata.h"
25
26#include <tvm/ir/expr.h>
27#include <tvm/ir/module.h>
28#include <tvm/runtime/container/array.h>
29#include <tvm/runtime/container/map.h>
30#include <tvm/runtime/container/string.h>
31#include <tvm/runtime/data_type.h>
32#include <tvm/runtime/module.h>
33#include <tvm/target/target_kind.h>
34#include <tvm/tir/analysis.h>
35#include <tvm/tir/function.h>
36#include <tvm/tir/op.h>
37#include <tvm/tir/usmp/utils.h>
38
39#include "../utils.h"
40
41namespace tvm {
42namespace relay {
43namespace backend {
44namespace aot {
45
46/*!
47 * \brief Calculate FunctionInfo for all the PrimFuncs in a module.
48 */
49Map<String, backend::FunctionInfo> CalculateFunctionInfos(const IRModule& mod,
50 Integer workspace_byte_alignment,
51 Integer constant_byte_alignment) {
52 Map<String, backend::FunctionInfo> function_metadata;
53 for (const auto& kv : mod->functions) {
54 GlobalVar global_var = kv.first;
55 BaseFunc base_func = kv.second;
56 if (base_func->IsInstance<tir::PrimFuncNode>()) {
57 tir::PrimFunc pfunc = Downcast<tir::PrimFunc>(base_func);
58 Optional<Target> tgt_opt = pfunc->GetAttr<Target>(tvm::attr::kTarget);
59 ICHECK(tgt_opt) << "Target must be defined for all primfuncs.";
60 Target tgt = tgt_opt.value();
61 // Determine the size of input/output buffers
62 auto params = pfunc->params;
63 int64_t total_io_bytes = 0;
64 for (const auto& param : params) {
65 if (pfunc->buffer_map.find(param) != pfunc->buffer_map.end()) {
66 auto buffer = pfunc->buffer_map[param];
67 total_io_bytes += GetMemorySizeBytes(buffer->shape, buffer->dtype);
68 }
69 }
70 const auto& ws = CalculateWorkspaceBytes(pfunc, workspace_byte_alignment);
71 const auto& cs = CalculateConstantBytes(pfunc, constant_byte_alignment);
72 backend::FunctionInfo finfo{
73 {{tgt, ws}}, {{tgt, total_io_bytes}}, {{tgt, cs}}, {{tgt, pfunc}}, {}};
74 function_metadata.Set(global_var->name_hint, finfo);
75 }
76 }
77 return function_metadata;
78}
79
80Map<String, backend::FunctionInfo> CreateFunctionMetadata(const IRModule& mod,
81 Integer workspace_byte_alignment,
82 Integer constant_byte_alignment) {
83 // First calculate the FunctionInfos from the buffers that are explicitly allocated
84 auto function_metadata =
85 CalculateFunctionInfos(mod, workspace_byte_alignment, constant_byte_alignment);
86 // Now adjust the FunctionInfo for the main func to also include PoolInfo allocations
87 // made by the USMP.
88 Optional<Array<tir::usmp::AllocatedPoolInfo>> allocated_pool_infos =
89 mod->GetAttr<Array<tir::usmp::AllocatedPoolInfo>>(tvm::attr::kPoolArgs);
90 backend::FunctionInfo main_func_info =
91 function_metadata.Get(runtime::symbol::tvm_module_main).value();
92 if (allocated_pool_infos) {
93 for (const tir::usmp::AllocatedPoolInfo& allocated_pool_info : allocated_pool_infos.value()) {
94 for (const auto& tgt : allocated_pool_info->pool_info->targets) {
95 VLOG(1) << "USMP requires target " << tgt->ToDebugString() << " to have pool size "
96 << allocated_pool_info->allocated_size->value;
97 size_t size = allocated_pool_info->allocated_size->value;
98 if (allocated_pool_info->pool_info->IsInstance<ConstantPoolInfoNode>()) {
99 size += main_func_info->constant_sizes.count(tgt)
100 ? main_func_info->constant_sizes[tgt]->value
101 : 0;
102 main_func_info->constant_sizes.Set(tgt, size);
103 } else if (allocated_pool_info->pool_info->IsInstance<WorkspacePoolInfoNode>()) {
104 size += main_func_info->workspace_sizes.count(tgt)
105 ? main_func_info->workspace_sizes[tgt]->value
106 : 0;
107 main_func_info->workspace_sizes.Set(tgt, size);
108 } else {
109 LOG(FATAL) << "Unknown pool type: " << allocated_pool_info->pool_info->GetTypeKey();
110 }
111 }
112 }
113 }
114 function_metadata.Set(runtime::symbol::tvm_module_main, main_func_info);
115 return function_metadata;
116}
117
// Register CreateFunctionMetadata under the global name
// "relay.backend.aot.CreateFunctionMetadata", using the function itself as the typed body.
TVM_REGISTER_GLOBAL("relay.backend.aot.CreateFunctionMetadata")
    .set_body_typed(CreateFunctionMetadata);
120
121} // namespace aot
122} // namespace backend
123} // namespace relay
124} // namespace tvm
125