1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/backend/annotate_used_memory.cc |
22 | * \brief Analyzes the used memory at the callsite of primitive functions. |
23 | */ |
24 | |
25 | #include <tvm/ir/module.h> |
26 | #include <tvm/relay/attrs/memory.h> |
27 | #include <tvm/relay/transform.h> |
28 | |
29 | #include <unordered_map> |
30 | #include <unordered_set> |
31 | |
32 | #include "../transforms/device_aware_visitors.h" |
33 | #include "../transforms/pass_utils.h" |
34 | #include "./liveness_analysis.h" |
35 | #include "./utils.h" |
36 | |
37 | namespace tvm { |
38 | namespace relay { |
39 | namespace backend { |
40 | |
41 | /*! |
42 | * \brief Annotates the minimum required memory of each primitive function callsite by analyzing |
43 | * the liveness of the input/output tensors at each function callsite and calculating the total |
44 | * amount of memory these tensors require. This is added as a "used_memory" annotation to the |
45 | * function in question as a list of the number of bytes for each callsite. In addition, the |
46 | * containing function is annotated with an "io_used_memory" annotation which refers to the total |
47 | * memory required for the IO tensors. |
48 | * |
49 | * Note: This pass does not support dynamic shapes, it is the users responsibility to check this |
50 | * pass isn't applied where dynamic shapes may be input. |
51 | * |
52 | * A simple example: |
53 | * |
54 | * Before: |
55 | * \verbatim |
56 | * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] { |
57 | * let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] { |
58 | * nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) |
59 | * }; |
60 | * let %x_1 = %x_0(%input); |
61 | * %x_1 |
62 | * } |
63 | * \endverbatim |
64 | * |
65 | * After: |
66 | * \verbatim |
67 | * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] { |
68 | * let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=[32]) -> Tensor[(1, 2, |
69 | * 2, 4), int8] { |
70 | * nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) |
71 | * }; |
72 | * let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input); |
73 | * %x_1 |
74 | * } |
75 | * \endverbatim |
76 | * |
77 | * Note that in the simple example above io_used_memory and used_memory are the same since there |
78 | * is only one primitive function. |
79 | */ |
80 | class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { |
81 | public: |
82 | AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg, |
83 | const transform::LivenessAnalysis& lva) |
84 | : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {} |
85 | |
86 | /*! |
87 | * \brief Mutates the input function. In addition, an "io_used_memory" annotation is |
88 | * added to the input function which refers to the total size required for the IO |
89 | * tensors. |
90 | */ |
91 | Function operator()(const Function& func) { |
92 | uint64_t io_used_memory = 0; |
93 | |
94 | // Inputs |
95 | for (const Var& param : func->params) { |
96 | Type type = param->checked_type(); |
97 | ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory." ; |
98 | ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes." ; |
99 | io_used_memory += CalculateRelayExprSizeBytes(type); |
100 | } |
101 | |
102 | // Outputs |
103 | Type type = func->body->checked_type(); |
104 | ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory." ; |
105 | ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes." ; |
106 | io_used_memory += CalculateRelayExprSizeBytes(type); |
107 | |
108 | Expr new_func_body = VisitExpr(func->body); |
109 | Function new_func = WithFields(func, func->params, new_func_body); |
110 | return WithAttr(std::move(new_func), "io_used_memory" , |
111 | tvm::IntImm(tvm::DataType::UInt(64), io_used_memory)); |
112 | } |
113 | |
114 | /*! |
115 | * \brief Establish which let bindings have primitive function values. |
116 | */ |
117 | std::pair<Var, Expr> PreVisitLetBinding_(const Var& var, const Expr& value) override { |
118 | if (const auto* func_node = value.as<FunctionNode>()) { |
119 | ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive)) |
120 | << "Expect top-level functions to be primitive." ; |
121 | let_bound_prim_func_.insert(var); |
122 | } |
123 | return DeviceAwareExprMutator::PreVisitLetBinding_(var, value); |
124 | } |
125 | |
126 | /*! |
127 | * \brief Visit let nodes and perform one of two actions depending on their value: |
128 | * |
129 | * 1. CallNode - Calculate "used_memory" annotation value at the callsite of |
130 | * primitive functions. |
131 | * |
132 | * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the |
133 | * previous analysis at the callsite. |
134 | * |
135 | */ |
136 | Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override { |
137 | Var let_var = post_let_node->var; |
138 | Expr let_value = IgnoreOnDevice(post_let_node->value); |
139 | |
140 | if (let_value->IsInstance<CallNode>()) { |
141 | Call callsite = Downcast<Call>(let_value); |
142 | if (CheckPrimitiveFunctionCall(callsite)) { |
143 | Var call_op = Downcast<Var>(callsite->op); |
144 | |
145 | // Find all the vars that are live at the callsite. This is done by merging the |
146 | // in and out varset's and then removing the var that references the primitive |
147 | // function itself since we don't want this included in the calculation. |
148 | const transform::ControlFlowGraph::NodePtr cfg_node = |
149 | control_flow_graph_.let_map.at(GetRef<Let>(pre_let_node)); |
150 | transform::VarSet live_tensors = liveness_.live_in.at(cfg_node); |
151 | const transform::VarSet& live_out = liveness_.live_out.at(cfg_node); |
152 | live_tensors.insert(live_out.begin(), live_out.end()); |
153 | live_tensors.erase(call_op); |
154 | |
155 | // Calculate size of live tensors and store to allow annotation when the function |
156 | // gets visited. |
157 | uint64_t used_memory = 0; |
158 | for (const auto& var : live_tensors) { |
159 | Type type = var->checked_type(); |
160 | ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory." ; |
161 | ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes." ; |
162 | used_memory += CalculateRelayExprSizeBytes(type); |
163 | } |
164 | IntImm annotation(DataType::UInt(64), used_memory); |
165 | used_memory_annotations_[call_op].push_back(annotation); |
166 | } |
167 | } else if (let_value->IsInstance<FunctionNode>()) { |
168 | Function func = Downcast<Function>(let_value); |
169 | ICHECK(used_memory_annotations_.find(let_var) != used_memory_annotations_.end()) |
170 | << "Could not find used_memory value for primitive function bound at " |
171 | << let_var->name_hint(); |
172 | Array<IntImm> used_memory = used_memory_annotations_[let_var]; |
173 | used_memory_annotations_.erase(let_var); |
174 | |
175 | Function new_func = WithAttr(std::move(func), "used_memory" , |
176 | Array<IntImm>(used_memory.rbegin(), used_memory.rend())); |
177 | return Let(let_var, new_func, post_let_node->body, post_let_node->span); |
178 | } |
179 | |
180 | return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); |
181 | } |
182 | |
183 | private: |
184 | /*! |
185 | * \brief Check if a call is a primitive function callsite. |
186 | */ |
187 | bool CheckPrimitiveFunctionCall(const Call& callsite) { |
188 | if (const auto* var_node = callsite->op.as<VarNode>()) { |
189 | Var var = GetRef<Var>(var_node); |
190 | if (let_bound_prim_func_.find(var) != let_bound_prim_func_.end()) { |
191 | return true; |
192 | } |
193 | } |
194 | return false; |
195 | } |
196 | |
197 | /*! \brief Control flow graph representation of the main function. */ |
198 | transform::ControlFlowGraph control_flow_graph_; |
199 | /*! \brief Liveness analysis of the main function. */ |
200 | transform::LivenessAnalysis liveness_; |
201 | /*! \brief Var's that reference primitive functions. */ |
202 | std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_prim_func_; |
203 | /*! \brief Stores the calculated uint64 used_memory values so they can be annotated on the |
204 | * relevant function. */ |
205 | std::unordered_map<Var, Array<IntImm>, ObjectPtrHash, ObjectPtrEqual> used_memory_annotations_; |
206 | }; |
207 | |
208 | } // namespace backend |
209 | |
210 | namespace transform { |
211 | |
212 | Pass AnnotateUsedMemory() { |
213 | runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func = [=](IRModule mod, |
214 | PassContext ctx) { |
215 | GlobalVar gv = mod->GetGlobalVar("main" ); |
216 | Function main_func = Downcast<Function>(mod->Lookup("main" )); |
217 | |
218 | // Perform liveness analysis to determine what tensors are 'live' at each functions callsite. |
219 | support::Arena arena; |
220 | ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, main_func); |
221 | UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); |
222 | LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); |
223 | |
224 | auto new_main_func = backend::AnnotateUsedMemoryMutator(mod, cfg, lva)(main_func); |
225 | if (!new_main_func.same_as(main_func)) { |
226 | mod->Update(gv, new_main_func); |
227 | } |
228 | return mod; |
229 | }; |
230 | return CreateModulePass(pass_func, 0, "AnnotateUsedMemory" , {"ToANormalForm" , "InferType" }); |
231 | } |
232 | |
233 | TVM_REGISTER_GLOBAL("relay._transform.AnnotateUsedMemory" ).set_body_typed(AnnotateUsedMemory); |
234 | |
235 | } // namespace transform |
236 | } // namespace relay |
237 | } // namespace tvm |
238 | |