1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/backend/annotate_used_memory.cc
22 * \brief Analyzes the used memory at the callsite of primitive functions.
23 */
24
25#include <tvm/ir/module.h>
26#include <tvm/relay/attrs/memory.h>
27#include <tvm/relay/transform.h>
28
29#include <unordered_map>
30#include <unordered_set>
31
32#include "../transforms/device_aware_visitors.h"
33#include "../transforms/pass_utils.h"
34#include "./liveness_analysis.h"
35#include "./utils.h"
36
37namespace tvm {
38namespace relay {
39namespace backend {
40
41/*!
42 * \brief Annotates the minimum required memory of each primitive function callsite by analyzing
43 * the liveness of the input/output tensors at each function callsite and calculating the total
44 * amount of memory these tensors require. This is added as a "used_memory" annotation to the
45 * function in question as a list of the number of bytes for each callsite. In addition, the
46 * containing function is annotated with an "io_used_memory" annotation which refers to the total
47 * memory required for the IO tensors.
48 *
 * Note: This pass does not support dynamic shapes; it is the user's responsibility to check that
 * this pass isn't applied where dynamic shapes may be input.
51 *
52 * A simple example:
53 *
54 * Before:
55 * \verbatim
56 * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] {
57 * let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] {
58 * nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
59 * };
60 * let %x_1 = %x_0(%input);
61 * %x_1
62 * }
63 * \endverbatim
64 *
65 * After:
66 * \verbatim
67 * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] {
68 * let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=[32]) -> Tensor[(1, 2,
69 * 2, 4), int8] {
70 * nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0])
71 * };
72 * let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input);
73 * %x_1
74 * }
75 * \endverbatim
76 *
77 * Note that in the simple example above io_used_memory and used_memory are the same since there
78 * is only one primitive function.
79 */
80class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator {
81 public:
82 AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg,
83 const transform::LivenessAnalysis& lva)
84 : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {}
85
86 /*!
87 * \brief Mutates the input function. In addition, an "io_used_memory" annotation is
88 * added to the input function which refers to the total size required for the IO
89 * tensors.
90 */
91 Function operator()(const Function& func) {
92 uint64_t io_used_memory = 0;
93
94 // Inputs
95 for (const Var& param : func->params) {
96 Type type = param->checked_type();
97 ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
98 ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
99 io_used_memory += CalculateRelayExprSizeBytes(type);
100 }
101
102 // Outputs
103 Type type = func->body->checked_type();
104 ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
105 ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
106 io_used_memory += CalculateRelayExprSizeBytes(type);
107
108 Expr new_func_body = VisitExpr(func->body);
109 Function new_func = WithFields(func, func->params, new_func_body);
110 return WithAttr(std::move(new_func), "io_used_memory",
111 tvm::IntImm(tvm::DataType::UInt(64), io_used_memory));
112 }
113
114 /*!
115 * \brief Establish which let bindings have primitive function values.
116 */
117 std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) override {
118 if (const auto* func_node = value.as<FunctionNode>()) {
119 ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive))
120 << "Expect top-level functions to be primitive.";
121 let_bound_prim_func_.insert(var);
122 }
123 return DeviceAwareExprMutator::PreVisitLetBinding_(var, value);
124 }
125
126 /*!
127 * \brief Visit let nodes and perform one of two actions depending on their value:
128 *
129 * 1. CallNode - Calculate "used_memory" annotation value at the callsite of
130 * primitive functions.
131 *
132 * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the
133 * previous analysis at the callsite.
134 *
135 */
136 Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override {
137 Var let_var = post_let_node->var;
138 Expr let_value = IgnoreOnDevice(post_let_node->value);
139
140 if (let_value->IsInstance<CallNode>()) {
141 Call callsite = Downcast<Call>(let_value);
142 if (CheckPrimitiveFunctionCall(callsite)) {
143 Var call_op = Downcast<Var>(callsite->op);
144
145 // Find all the vars that are live at the callsite. This is done by merging the
146 // in and out varset's and then removing the var that references the primitive
147 // function itself since we don't want this included in the calculation.
148 const transform::ControlFlowGraph::NodePtr cfg_node =
149 control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node));
150 transform::VarSet live_tensors = liveness_.live_in.at(cfg_node);
151 const transform::VarSet& live_out = liveness_.live_out.at(cfg_node);
152 live_tensors.insert(live_out.begin(), live_out.end());
153 live_tensors.erase(call_op);
154
155 // Calculate size of live tensors and store to allow annotation when the function
156 // gets visited.
157 uint64_t used_memory = 0;
158 for (const auto& var : live_tensors) {
159 Type type = var->checked_type();
160 ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory.";
161 ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes.";
162 used_memory += CalculateRelayExprSizeBytes(type);
163 }
164 IntImm annotation(DataType::UInt(64), used_memory);
165 used_memory_annotations_[call_op].push_back(annotation);
166 }
167 } else if (let_value->IsInstance<FunctionNode>()) {
168 Function func = Downcast<Function>(let_value);
169 ICHECK(used_memory_annotations_.find(let_var) != used_memory_annotations_.end())
170 << "Could not find used_memory value for primitive function bound at "
171 << let_var->name_hint();
172 Array<IntImm> used_memory = used_memory_annotations_[let_var];
173 used_memory_annotations_.erase(let_var);
174
175 Function new_func = WithAttr(std::move(func), "used_memory",
176 Array<IntImm>(used_memory.rbegin(), used_memory.rend()));
177 return Let(let_var, new_func, post_let_node->body, post_let_node->span);
178 }
179
180 return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node);
181 }
182
183 private:
184 /*!
185 * \brief Check if a call is a primitive function callsite.
186 */
187 bool CheckPrimitiveFunctionCall(const Call& callsite) {
188 if (const auto* var_node = callsite->op.as<VarNode>()) {
189 Var var = GetRef<Var>(var_node);
190 if (let_bound_prim_func_.find(var) != let_bound_prim_func_.end()) {
191 return true;
192 }
193 }
194 return false;
195 }
196
197 /*! \brief Control flow graph representation of the main function. */
198 transform::ControlFlowGraph control_flow_graph_;
199 /*! \brief Liveness analysis of the main function. */
200 transform::LivenessAnalysis liveness_;
201 /*! \brief Var's that reference primitive functions. */
202 std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_prim_func_;
203 /*! \brief Stores the calculated uint64 used_memory values so they can be annotated on the
204 * relevant function. */
205 std::unordered_map<Var, Array<IntImm>, ObjectPtrHash, ObjectPtrEqual> used_memory_annotations_;
206};
207
208} // namespace backend
209
210namespace transform {
211
212Pass AnnotateUsedMemory() {
213 runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod,
214 PassContext ctx) {
215 GlobalVar gv = mod->GetGlobalVar("main");
216 Function main_func = Downcast<Function>(mod->Lookup("main"));
217
218 // Perform liveness analysis to determine what tensors are 'live' at each functions callsite.
219 support::Arena arena;
220 ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, main_func);
221 UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg);
222 LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def);
223
224 auto new_main_func = backend::AnnotateUsedMemoryMutator(mod, cfg, lva)(main_func);
225 if (!new_main_func.same_as(main_func)) {
226 mod->Update(gv, new_main_func);
227 }
228 return mod;
229 };
230 return CreateModulePass(pass_func, 0, "AnnotateUsedMemory", {"ToANormalForm", "InferType"});
231}
232
// Expose the pass to the FFI so it is reachable from Python as
// relay._transform.AnnotateUsedMemory.
TVM_REGISTER_GLOBAL("relay._transform.AnnotateUsedMemory").set_body_typed(AnnotateUsedMemory);
234
235} // namespace transform
236} // namespace relay
237} // namespace tvm
238