1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file relay/backend/graph_plan_memory.cc |
22 | * \brief Memory index assignment pass for executing |
23 | * the program in the graph executor. |
24 | */ |
25 | #include <tvm/relay/analysis.h> |
26 | #include <tvm/relay/attrs/annotation.h> |
27 | #include <tvm/relay/attrs/call.h> |
28 | #include <tvm/relay/expr.h> |
29 | #include <tvm/relay/expr_functor.h> |
30 | #include <tvm/relay/transform.h> |
31 | #include <tvm/runtime/container/array.h> |
32 | #include <tvm/tir/op.h> |
33 | |
34 | #include "../../runtime/texture.h" |
35 | #include "../../support/arena.h" |
36 | #include "../op/annotation/annotation.h" |
37 | #include "../op/call/call.h" |
38 | #include "../op/memory/memory.h" |
39 | #include "../transforms/device_aware_visitors.h" |
40 | #include "./token_allocator.h" |
41 | #include "./utils.h" |
42 | |
namespace tvm {
namespace relay {

// Map from integer keys to compilation targets.
// NOTE(review): keys are presumably device ids — confirm against callers.
using TargetsMap = Map<Integer, Target>;
// 2D shape descriptor used for texture-backed (2D) storage.
using Texture2DShape = runtime::Texture2DShape<int64_t>;
// Predicate: does a memory-scope string denote 2D texture storage?
constexpr auto Is2DStorage = runtime::IsTextureStorage;

using backend::StaticMemoryPlan;
using backend::StorageInfo;
using IntegerArray = Array<Integer>;
/*!
 * \brief Base class for visitors that associate StorageTokens with Relay expressions.
 *
 * Walks a function and fills \p token_map_ with one StorageToken per tensor
 * produced by each sub-expression (tuple types are flattened, so a tuple-typed
 * expression maps to several tokens). How tokens are created is left to the
 * sub-class via CreateTokenOnDevice.
 */
class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
 public:
  // No IRModule is supplied; this visitor only ever walks a single function.
  StorageAllocaBaseVisitor() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}

  // run the visitor on a global function.
  void Run(const Function& func) { VisitExpr(func); }

  using transform::DeviceAwareExprVisitor::VisitExpr_;

  // Constants hold data, so they get storage that may never be re-allocated.
  void VisitExpr_(const ConstantNode* op) final { this->CreateToken(op, false); }

  void VisitExpr_(const VarNode* op) final {
    // Do nothing. Tokens for variables are created at their binding site
    // (function parameters / let bindings).
  }

  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
    if (function_nesting() > 1) {
      // do not recurse into sub functions.
      return;
    }
    if (func_node->HasNonzeroAttr(attr::kPrimitive)) {
      // No storage needed for primitive functions.
      return;
    }
    // Parameters are caller-supplied inputs; their storage must never be
    // re-allocated for another tensor.
    for (const auto& param : func_node->params) {
      CreateToken(param.get(), /*can_realloc=*/false);
    }
    // Process the function body, and make sure all result tokens are considered 'alive'.
    for (StorageToken* tok : GetToken(func_node->body)) {
      tok->ref_counter += 1;
    }
  }

  void VisitExpr_(const GlobalVarNode* op) final {
    // Do nothing. Global definitions carry no data storage.
  }

  void VisitExpr_(const OpNode* op) final {
    // Do nothing. Operators carry no data storage.
  }

  void VisitExpr_(const TupleNode* op) final {
    // A tuple's tokens are the concatenation of its fields' tokens; the tuple
    // itself allocates no extra storage.
    std::vector<StorageToken*> fields;
    for (Expr field : op->fields) {
      auto tokens = GetToken(field);
      fields.insert(fields.end(), tokens.begin(), tokens.end());
    }
    token_map_[op] = fields;
  }

  void VisitExpr_(const TupleGetItemNode* op) final {
    // Projection: alias the index-th token of the tuple; no new storage.
    const auto& tok = GetToken(op->tuple);
    ICHECK_LT(static_cast<size_t>(op->index), tok.size());
    token_map_[op] = {tok[op->index]};
  }

  void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; }

  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
    // The bound variable aliases the tokens of its value.
    token_map_[var.get()] = GetToken(value);
  }

  void PostVisitLet_(const LetNode* let_node) final {
    // A let expression evaluates to (and shares storage with) its body.
    token_map_[let_node] = GetToken(let_node->body);
  }

 protected:
  /*! \brief internal token map */
  std::unordered_map<const ExprNode*, std::vector<StorageToken*>> token_map_;
  /*! \brief empty token map */
  const std::vector<StorageToken*> no_tokens_;

  /*!
   * \brief Get the necessary token.
   * \param expr The expression.
   * \return The corresponding token.
   */
  const std::vector<StorageToken*>& GetToken(const Expr& expr) {
    this->VisitExpr(expr);
    // See through on_device calls.
    Expr real_expr = IgnoreOnDevice(expr);

    // Functions don't require data storage, represented by the empty token
    if (real_expr->checked_type().as<FuncTypeNode>()) {
      return no_tokens_;
    }
    // NOTE(review): this second visit appears redundant given the visit above;
    // presumably ExprVisitor memoizes visited nodes so it is a no-op — confirm.
    this->VisitExpr(real_expr);
    auto it = token_map_.find(real_expr.get());
    ICHECK(it != token_map_.end()) << "Expression not found in storage map:" << std::endl
                                   << PrettyPrint(real_expr);
    return it->second;
  }

  /*!
   * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding
   * the result of evaluating \p op.
   */
  void CreateToken(const ExprNode* expr_node, bool can_realloc) {
    return CreateTokenOnDevice(expr_node, GetVirtualDevice(GetRef<Expr>(expr_node)), can_realloc);
  }

  /*!
   * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding
   * the result of evaluating \p op on \p device_type.
   */
  virtual void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device,
                                   bool can_realloc) = 0;
};
162 | |
163 | /*! \brief Associate storage with every expression without any concern for sharing. */ |
164 | class StorageAllocaInit : protected StorageAllocaBaseVisitor { |
165 | public: |
166 | explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {} |
167 | |
168 | /*! \return The internal token map */ |
169 | std::unordered_map<const ExprNode*, std::vector<StorageToken*>> GetInitTokenMap( |
170 | const Function& func) { |
171 | this->Run(func); |
172 | return std::move(token_map_); |
173 | } |
174 | |
175 | protected: |
176 | using StorageAllocaBaseVisitor::VisitExpr_; |
177 | |
178 | void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device, |
179 | bool can_realloc) override { |
180 | ICHECK(!token_map_.count(op)); |
181 | std::vector<StorageToken*> tokens; |
182 | for (const auto& ttype : FlattenTupleType(op->checked_type())) { |
183 | auto* token = arena_->make<StorageToken>(); |
184 | token->ttype = ttype; |
185 | token->virtual_device = virtual_device; |
186 | tokens.push_back(token); |
187 | } |
188 | token_map_[op] = tokens; |
189 | } |
190 | |
191 | using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; |
192 | |
193 | void DeviceAwareVisitExpr_(const CallNode* call_node) final { |
194 | // create token for the call node. |
195 | CreateToken(call_node, true); |
196 | |
197 | // for each input, visit argument token. |
198 | for (Expr arg : call_node->args) { |
199 | for (StorageToken* tok : GetToken(arg)) { |
200 | tok->ref_counter += 1; |
201 | } |
202 | } |
203 | } |
204 | |
205 | private: |
206 | // allocator |
207 | support::Arena* arena_; |
208 | Map<Expr, Array<String>> node_storage_map_; |
209 | }; |
210 | |
211 | /*! \brief Associate storage with every expression, reusing storage where possible. */ |
212 | class StorageAllocator : public StorageAllocaBaseVisitor { |
213 | public: |
214 | StorageAllocator() = default; |
215 | |
216 | /*! |
217 | * \return total number of bytes allocated |
218 | */ |
219 | size_t TotalAllocBytes() const { |
220 | size_t total = 0; |
221 | for (const auto* p : data_) { |
222 | total += p->max_bytes; |
223 | } |
224 | return total; |
225 | } |
226 | |
227 | // Run storage allocation for a function. |
228 | StaticMemoryPlan Plan(const Function& func) { |
229 | VLOG_CONTEXT << "StorageAllocator" ; |
230 | VLOG(1) << "planning:" << std::endl << PrettyPrint(func); |
231 | prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); |
232 | this->Run(func); |
233 | |
234 | // The value of smap contains two integer arrays where the first array |
235 | // contains the planned storage ids and the second holds the device types. |
236 | Map<Expr, backend::StorageInfo> smap; |
237 | int num_annotated_nodes = 0; |
238 | int num_nodes = 0; |
239 | |
240 | for (const auto& kv : token_map_) { |
241 | std::vector<int64_t> storage_ids; |
242 | storage_ids.reserve(kv.second.size()); |
243 | std::vector<VirtualDevice> virtual_devices; |
244 | virtual_devices.reserve(kv.second.size()); |
245 | std::vector<int64_t> sid_sizes_byte; |
246 | sid_sizes_byte.reserve(kv.second.size()); |
247 | |
248 | for (StorageToken* tok : kv.second) { |
249 | VLOG(1) << "token: " << tok->ToString(); |
250 | if (tok->is_valid()) { |
251 | num_annotated_nodes++; |
252 | } |
253 | num_nodes++; |
254 | storage_ids.push_back(tok->storage_id); |
255 | virtual_devices.push_back(tok->virtual_device); |
256 | sid_sizes_byte.push_back(allocator_.GetMemorySize(tok)); |
257 | } |
258 | auto storage_info = backend::StorageInfo(std::move(storage_ids), std::move(virtual_devices), |
259 | std::move(sid_sizes_byte)); |
260 | smap.Set(GetRef<Expr>(kv.first), storage_info); |
261 | } |
262 | // Either all or none of the nodes should be annotated. |
263 | VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes |
264 | << std::endl; |
265 | if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { |
266 | LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes |
267 | << "expressions are assigned with virtual device types. Either all " |
268 | "or none of the expressions are expected to be annotated." ; |
269 | } |
270 | return backend::StaticMemoryPlan(smap); |
271 | } |
272 | |
273 | protected: |
274 | // override create token by getting token as prototype requirements. |
275 | void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device, |
276 | bool can_realloc) final { |
277 | ICHECK(!token_map_.count(op)); |
278 | auto it = prototype_.find(op); |
279 | ICHECK(it != prototype_.end()); |
280 | std::vector<StorageToken*> tokens; |
281 | |
282 | for (StorageToken* tok : it->second) { |
283 | ICHECK(tok->virtual_device == virtual_device); |
284 | if (can_realloc) { |
285 | tokens.push_back(allocator_.Request(tok)); |
286 | } else { |
287 | // Allocate a new token, |
288 | StorageToken* allocated_tok = allocator_.Alloc(tok); |
289 | allocated_tok->virtual_device = tok->virtual_device; |
290 | // ensure it never get de-allocated. |
291 | allocated_tok->ref_counter += 1; |
292 | tokens.push_back(allocated_tok); |
293 | } |
294 | } |
295 | token_map_[op] = tokens; |
296 | } |
297 | |
298 | // Mark op to reuse the input_token |
299 | // tie the two memories together |
300 | void ReuseInputToken(const ExprNode* op, StorageToken* input_token) { |
301 | ICHECK(!token_map_.count(op)); |
302 | auto it = prototype_.find(op); |
303 | ICHECK(it != prototype_.end()); |
304 | ICHECK_EQ(it->second.size(), 1U); |
305 | StorageToken* prototype = it->second[0]; |
306 | // add the reference counter of the output |
307 | // so the input token can only be deleted after references |
308 | // to both are expired |
309 | input_token->ref_counter += prototype->ref_counter; |
310 | // reuse the input token |
311 | token_map_[op] = {input_token}; |
312 | } |
313 | |
314 | using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_; |
315 | |
316 | // The call map |
317 | void DeviceAwareVisitExpr_(const CallNode* call_node) final { |
318 | std::vector<StorageToken*> args; |
319 | // for each input, visit argument token. |
320 | |
321 | for (const Expr& arg : call_node->args) { |
322 | // Note: GetToken skips GlobalVars and handles tuples properly, so we don't need to treat |
323 | // call_lowered specially. |
324 | for (StorageToken* tok : GetToken(arg)) { |
325 | args.push_back(tok); |
326 | } |
327 | } |
328 | |
329 | // Under the flat-memory setting. |
330 | // we can force aliasing the input and output of reshape |
331 | // to make it an nop. Note that this is not true |
332 | // for non-flat memory case. Given the current graph plan memory |
333 | // only works for flat memory case, we will go with this choice |
334 | // |
335 | // TODO(tvm-team) Update checks of flat memory enablement when we support |
336 | // opaque-nd memory planning to skip this path. |
337 | // TODO(mbs): "reshape" cleanup. |
338 | CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); |
339 | if (call_lowered_props.lowered_func.defined() && IsReshapeOnly(call_lowered_props)) { |
340 | ICHECK_EQ(call_lowered_props.arguments.size(), 1U); |
341 | ReuseInputToken(call_node, args[0]); |
342 | } else { |
343 | // create token for the call node. |
344 | CreateToken(call_node, true); |
345 | } |
346 | |
347 | // check if there is orphaned output that can be released immediately. |
348 | for (StorageToken* tok : token_map_.at(call_node)) { |
349 | allocator_.CheckForRelease(tok); |
350 | } |
351 | for (StorageToken* tok : args) { |
352 | tok->ref_counter -= 1; |
353 | allocator_.CheckForRelease(tok); |
354 | } |
355 | } |
356 | |
357 | class TokenAllocator { |
358 | public: |
359 | StorageToken* Alloc(StorageToken* proto) { |
360 | return Is2DStorage(proto) ? token_2d_.Alloc(proto, storage_ids_++) |
361 | : token_1d_.Alloc(proto, storage_ids_++); |
362 | } |
363 | StorageToken* Request(StorageToken* proto) { |
364 | StorageToken* token = |
365 | Is2DStorage(proto) ? token_2d_.Request(proto) : token_1d_.Request(proto); |
366 | return token ? token : this->Alloc(proto); |
367 | } |
368 | void CheckForRelease(StorageToken* tok) { |
369 | return Is2DStorage(tok) ? token_2d_.CheckForRelease(tok) : token_1d_.CheckForRelease(tok); |
370 | } |
371 | |
372 | size_t GetMemorySize(StorageToken* tok) { |
373 | // TODO(amalyshe): figure out who requries sizes and for what |
374 | // size in case of texture is not enough - we can return any value if it |
375 | // assumed to be used for memory allocatoion or we can return real size |
376 | // if it is just for information |
377 | return Is2DStorage(tok) ? 0 : token_1d_.GetMemorySize(tok); |
378 | } |
379 | static bool Is2DStorage(StorageToken* tok) { |
380 | return relay::Is2DStorage(tok->virtual_device->memory_scope); |
381 | } |
382 | |
383 | private: |
384 | int64_t storage_ids_{0}; |
385 | TokenAllocator1D token_1d_; |
386 | TokenAllocator2D token_2d_; |
387 | }; |
388 | |
389 | private: |
390 | // allocator |
391 | support::Arena arena_; |
392 | // scale used for rough match |
393 | // size_t match_range_{16}; |
394 | // free list of storage entry |
395 | std::multimap<size_t, StorageToken*> free_; |
396 | // all the storage resources available |
397 | std::vector<StorageToken*> data_; |
398 | /*! \brief internal prototype token map */ |
399 | std::unordered_map<const ExprNode*, std::vector<StorageToken*>> prototype_; |
400 | /*! \brief token allocator for optimizing 1d and 2d token alloc requests */ |
401 | TokenAllocator allocator_; |
402 | }; |
403 | |
404 | StaticMemoryPlan GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } |
405 | |
// Expose the pass through the global function registry under
// "relay.backend.GraphPlanMemory" (e.g. for calls from the Python frontend).
TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory);
407 | |
408 | } // namespace relay |
409 | } // namespace tvm |
410 | |