/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file relay/backend/graph_plan_memory.cc
 * \brief Memory index assignment pass for executing
 * the program in the graph executor.
 */
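/*
 * Overview (an illustrative sketch, not normative documentation): planning
 * runs in two phases. StorageAllocaInit walks the function once and creates a
 * prototype StorageToken for every tensor value, counting how many consumers
 * each value has. StorageAllocator then walks the function again and maps
 * every prototype to a concrete storage id, reusing ids whose reference
 * counts have dropped to zero. For a hypothetical program
 *
 *   %1 = nn.relu(%x);
 *   %2 = nn.relu(%1);   // last use of %1
 *   %3 = add(%2, %2);
 *
 * the planner may assign %1 and %3 the same storage id, since %1 is dead
 * before %3 is computed (subject to matching sizes, devices, and scopes).
 */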
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/annotation.h>
#include <tvm/relay/attrs/call.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/runtime/container/array.h>
#include <tvm/tir/op.h>

#include "../../runtime/texture.h"
#include "../../support/arena.h"
#include "../op/annotation/annotation.h"
#include "../op/call/call.h"
#include "../op/memory/memory.h"
#include "../transforms/device_aware_visitors.h"
#include "./token_allocator.h"
#include "./utils.h"

namespace tvm {
namespace relay {

using TargetsMap = Map<Integer, Target>;
using Texture2DShape = runtime::Texture2DShape<int64_t>;
constexpr auto Is2DStorage = runtime::IsTextureStorage;

using backend::StaticMemoryPlan;
using backend::StorageInfo;
using IntegerArray = Array<Integer>;

class StorageAllocaBaseVisitor : public transform::DeviceAwareExprVisitor {
 public:
  StorageAllocaBaseVisitor() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}

  // run the visitor on a global function.
  void Run(const Function& func) { VisitExpr(func); }

  using transform::DeviceAwareExprVisitor::VisitExpr_;

  void VisitExpr_(const ConstantNode* op) final { this->CreateToken(op, false); }

  void VisitExpr_(const VarNode* op) final {
    // Do nothing.
  }

  void DeviceAwareVisitExpr_(const FunctionNode* func_node) final {
    if (function_nesting() > 1) {
      // do not recurse into sub functions.
      return;
    }
    if (func_node->HasNonzeroAttr(attr::kPrimitive)) {
      // No storage needed for primitive functions.
      return;
    }
    for (const auto& param : func_node->params) {
      CreateToken(param.get(), /*can_realloc=*/false);
    }
    // Process the function body, and make sure all result tokens are considered 'alive'.
    for (StorageToken* tok : GetToken(func_node->body)) {
      tok->ref_counter += 1;
    }
  }

  void VisitExpr_(const GlobalVarNode* op) final {
    // Do nothing.
  }

  void VisitExpr_(const OpNode* op) final {
    // Do nothing.
  }

  void VisitExpr_(const TupleNode* op) final {
    std::vector<StorageToken*> fields;
    for (Expr field : op->fields) {
      auto tokens = GetToken(field);
      fields.insert(fields.end(), tokens.begin(), tokens.end());
    }
    token_map_[op] = fields;
  }

  void VisitExpr_(const TupleGetItemNode* op) final {
    const auto& tok = GetToken(op->tuple);
    ICHECK_LT(static_cast<size_t>(op->index), tok.size());
    token_map_[op] = {tok[op->index]};
  }
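  // Illustration (a hypothetical program, for exposition only): for
  //   %t = (%a, %b); %u = %t.1
  // the tuple node's tokens are the concatenation {tok(%a), tok(%b)} gathered
  // above, and %u simply aliases tok(%b). Neither node creates new storage.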

  void VisitExpr_(const IfNode* op) final {
    LOG(FATAL) << "if expressions are not supported by graph memory planning";
  }

  void PreVisitLetBinding_(const Var& var, const Expr& value) final {
    token_map_[var.get()] = GetToken(value);
  }

  void PostVisitLet_(const LetNode* let_node) final {
    token_map_[let_node] = GetToken(let_node->body);
  }
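  // Illustration (hypothetical program): for `let %v = add(%x, %y); mul(%v, %v)`,
  // the pre-visit maps %v to the add call's tokens, and the let node itself
  // aliases the tokens of its body (the mul call); no new storage is created.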

 protected:
  /*! \brief internal token map */
  std::unordered_map<const ExprNode*, std::vector<StorageToken*>> token_map_;
  /*! \brief empty token list, returned for expressions that need no storage */
  const std::vector<StorageToken*> no_tokens_;

  /*!
   * \brief Get the storage tokens for the given expression.
   * \param expr The expression.
   * \return The corresponding tokens.
   */
  const std::vector<StorageToken*>& GetToken(const Expr& expr) {
    this->VisitExpr(expr);
    // See through on_device calls.
    Expr real_expr = IgnoreOnDevice(expr);

    // Functions don't require data storage, which is represented by the empty
    // token list.
    if (real_expr->checked_type().as<FuncTypeNode>()) {
      return no_tokens_;
    }
    auto it = token_map_.find(real_expr.get());
    ICHECK(it != token_map_.end()) << "Expression not found in storage map:" << std::endl
                                   << PrettyPrint(real_expr);
    return it->second;
  }
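  // Illustration (for exposition; on_device is Relay's device annotation):
  // for `%1 = on_device(%0, ...)`, GetToken(%1) returns the tokens recorded
  // for %0, since IgnoreOnDevice strips the annotation before the lookup.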

  /*!
   * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding
   * the result of evaluating \p expr_node.
   */
  void CreateToken(const ExprNode* expr_node, bool can_realloc) {
    return CreateTokenOnDevice(expr_node, GetVirtualDevice(GetRef<Expr>(expr_node)), can_realloc);
  }

  /*!
   * \brief Allocates (or reuses if \p can_realloc is true) a storage token for holding
   * the result of evaluating \p op on \p virtual_device.
   */
  virtual void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device,
                                   bool can_realloc) = 0;
};

/*! \brief Associate storage with every expression without any concern for sharing. */
class StorageAllocaInit : protected StorageAllocaBaseVisitor {
 public:
  explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {}

  /*! \return The internal token map */
  std::unordered_map<const ExprNode*, std::vector<StorageToken*>> GetInitTokenMap(
      const Function& func) {
    this->Run(func);
    return std::move(token_map_);
  }

 protected:
  using StorageAllocaBaseVisitor::VisitExpr_;

  void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device,
                           bool can_realloc) override {
    ICHECK(!token_map_.count(op));
    std::vector<StorageToken*> tokens;
    for (const auto& ttype : FlattenTupleType(op->checked_type())) {
      auto* token = arena_->make<StorageToken>();
      token->ttype = ttype;
      token->virtual_device = virtual_device;
      tokens.push_back(token);
    }
    token_map_[op] = tokens;
  }

  using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_;

  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
    // Create a token for the result of the call node.
    CreateToken(call_node, true);

    // Increment the reference count of each argument's tokens.
    for (Expr arg : call_node->args) {
      for (StorageToken* tok : GetToken(arg)) {
        tok->ref_counter += 1;
      }
    }
  }

 private:
  // allocator
  support::Arena* arena_;
};
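// Illustration (hypothetical call, for exposition only): after
// StorageAllocaInit visits
//   %2 = concatenate((%0, %0))
// the prototype token of %0 has ref_counter == 2, since %0 appears as an
// argument twice. StorageAllocator uses these counts below to decide when a
// token's storage can be released for reuse.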

/*! \brief Associate storage with every expression, reusing storage where possible. */
class StorageAllocator : public StorageAllocaBaseVisitor {
 public:
  StorageAllocator() = default;

  /*!
   * \return total number of bytes allocated
   */
  size_t TotalAllocBytes() const {
    size_t total = 0;
    for (const auto* p : data_) {
      total += p->max_bytes;
    }
    return total;
  }

  // Run storage allocation for a function.
  StaticMemoryPlan Plan(const Function& func) {
    VLOG_CONTEXT << "StorageAllocator";
    VLOG(1) << "planning:" << std::endl << PrettyPrint(func);
    prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
    this->Run(func);

    // smap maps each expression to its StorageInfo, which records the planned
    // storage ids, the virtual devices, and the sizes in bytes of its outputs.
    Map<Expr, backend::StorageInfo> smap;
    int num_annotated_nodes = 0;
    int num_nodes = 0;

    for (const auto& kv : token_map_) {
      std::vector<int64_t> storage_ids;
      storage_ids.reserve(kv.second.size());
      std::vector<VirtualDevice> virtual_devices;
      virtual_devices.reserve(kv.second.size());
      std::vector<int64_t> sid_sizes_byte;
      sid_sizes_byte.reserve(kv.second.size());

      for (StorageToken* tok : kv.second) {
        VLOG(1) << "token: " << tok->ToString();
        if (tok->is_valid()) {
          num_annotated_nodes++;
        }
        num_nodes++;
        storage_ids.push_back(tok->storage_id);
        virtual_devices.push_back(tok->virtual_device);
        sid_sizes_byte.push_back(allocator_.GetMemorySize(tok));
      }
      auto storage_info = backend::StorageInfo(std::move(storage_ids), std::move(virtual_devices),
                                               std::move(sid_sizes_byte));
      smap.Set(GetRef<Expr>(kv.first), storage_info);
    }
    // Either all or none of the nodes should be annotated.
    VLOG(1) << "num annotated nodes / num_nodes: " << num_annotated_nodes << " / " << num_nodes
            << std::endl;
    if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
      LOG(FATAL) << num_annotated_nodes << " out of " << num_nodes
                 << " expressions are annotated with virtual devices. Either all "
                    "or none of the expressions are expected to be annotated.";
    }
    return backend::StaticMemoryPlan(smap);
  }

 protected:
  // Override CreateTokenOnDevice, using the prototype tokens gathered by
  // StorageAllocaInit as the requirements for each new token.
  void CreateTokenOnDevice(const ExprNode* op, const VirtualDevice& virtual_device,
                           bool can_realloc) final {
    ICHECK(!token_map_.count(op));
    auto it = prototype_.find(op);
    ICHECK(it != prototype_.end());
    std::vector<StorageToken*> tokens;

    for (StorageToken* tok : it->second) {
      ICHECK(tok->virtual_device == virtual_device);
      if (can_realloc) {
        tokens.push_back(allocator_.Request(tok));
      } else {
        // Allocate a new token.
        StorageToken* allocated_tok = allocator_.Alloc(tok);
        allocated_tok->virtual_device = tok->virtual_device;
        // Bump the reference count so the token is never deallocated.
        allocated_tok->ref_counter += 1;
        tokens.push_back(allocated_tok);
      }
    }
    token_map_[op] = tokens;
  }
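  // For exposition: parameters and constants are created with
  // can_realloc == false, so each gets a fresh storage id plus an extra
  // reference and is never returned to the free list. Intermediate call
  // results use can_realloc == true and may land in storage released by
  // tokens that died earlier.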

  // Mark op as reusing input_token's storage, tying the two memories together.
  void ReuseInputToken(const ExprNode* op, StorageToken* input_token) {
    ICHECK(!token_map_.count(op));
    auto it = prototype_.find(op);
    ICHECK(it != prototype_.end());
    ICHECK_EQ(it->second.size(), 1U);
    StorageToken* prototype = it->second[0];
    // Add the output's reference count to the input token, so the input token
    // is only released once references to both have expired.
    input_token->ref_counter += prototype->ref_counter;
    // Reuse the input token.
    token_map_[op] = {input_token};
  }
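  // Sketch of the aliasing effect (assuming a reshape-only primitive): for
  //   %1 = reshape(%0)   // %0 holds token t
  // token_map_[%1] becomes {t}, and t->ref_counter grows by %1's prototype
  // count, so t stays live until every consumer of %0 or %1 has been visited.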

  using StorageAllocaBaseVisitor::DeviceAwareVisitExpr_;

  // Plan memory for a call: gather argument tokens, alias reshape-only calls
  // onto their input, and release any storage whose last use is this call.
  void DeviceAwareVisitExpr_(const CallNode* call_node) final {
    std::vector<StorageToken*> args;
    // For each input, gather the argument tokens.
    for (const Expr& arg : call_node->args) {
      // Note: GetToken skips GlobalVars and handles tuples properly, so we don't need to treat
      // call_lowered specially.
      for (StorageToken* tok : GetToken(arg)) {
        args.push_back(tok);
      }
    }

    // Under the flat-memory setting we can force aliasing of the input and
    // output of reshape, making it a no-op. Note that this does not hold in
    // the non-flat memory case. Since the current graph memory planner only
    // works with flat memory, we go with this choice.
    //
    // TODO(tvm-team) Update checks of flat memory enablement when we support
    // opaque-nd memory planning to skip this path.
    // TODO(mbs): "reshape" cleanup.
    CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node);
    if (call_lowered_props.lowered_func.defined() && IsReshapeOnly(call_lowered_props)) {
      ICHECK_EQ(call_lowered_props.arguments.size(), 1U);
      ReuseInputToken(call_node, args[0]);
    } else {
      // Create a token for the result of the call node.
      CreateToken(call_node, true);
    }

    // Check if there are orphaned outputs that can be released immediately.
    for (StorageToken* tok : token_map_.at(call_node)) {
      allocator_.CheckForRelease(tok);
    }
    for (StorageToken* tok : args) {
      tok->ref_counter -= 1;
      allocator_.CheckForRelease(tok);
    }
  }
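  // Worked example (hypothetical): for %2 = add(%0, %1), where this call is
  // the last use of %0, the loop above drops %0's ref_counter to 0 and
  // CheckForRelease returns its storage to the free list, so a later request
  // of compatible size and scope can reuse it.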

  /*!
   * \brief Dispatches token requests to the 1-d (flat) or 2-d (texture)
   * allocator based on the token's memory scope.
   */
  class TokenAllocator {
   public:
    StorageToken* Alloc(StorageToken* proto) {
      return Is2DStorage(proto) ? token_2d_.Alloc(proto, storage_ids_++)
                                : token_1d_.Alloc(proto, storage_ids_++);
    }
    StorageToken* Request(StorageToken* proto) {
      StorageToken* token =
          Is2DStorage(proto) ? token_2d_.Request(proto) : token_1d_.Request(proto);
      return token ? token : this->Alloc(proto);
    }
    void CheckForRelease(StorageToken* tok) {
      return Is2DStorage(tok) ? token_2d_.CheckForRelease(tok) : token_1d_.CheckForRelease(tok);
    }

    size_t GetMemorySize(StorageToken* tok) {
      // TODO(amalyshe): figure out who requires sizes and for what. A flat
      // size is not meaningful for textures: we could return any value if it
      // is only used for memory allocation, or the real size if it is purely
      // informational.
      return Is2DStorage(tok) ? 0 : token_1d_.GetMemorySize(tok);
    }
    static bool Is2DStorage(StorageToken* tok) {
      return relay::Is2DStorage(tok->virtual_device->memory_scope);
    }

   private:
    int64_t storage_ids_{0};
    TokenAllocator1D token_1d_;
    TokenAllocator2D token_2d_;
  };
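  // For exposition: dispatch is driven purely by the token's memory scope.
  // With the scope strings used by the OpenCL texture runtime (illustrative,
  // not exhaustive):
  //   memory_scope == ""                 -> token_1d_ (flat buffer)
  //   memory_scope == "global.texture"   -> token_2d_ (2-d texture pool)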

 private:
  // allocator
  support::Arena arena_;
  // all the storage resources available
  std::vector<StorageToken*> data_;
  /*! \brief internal prototype token map */
  std::unordered_map<const ExprNode*, std::vector<StorageToken*>> prototype_;
  /*! \brief token allocator for optimizing 1d and 2d token alloc requests */
  TokenAllocator allocator_;
};

StaticMemoryPlan GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); }

TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory);
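// A minimal usage sketch (illustrative; assumes the StaticMemoryPlan field
// names declared in backend/utils.h):
//   relay::Function func = ...;  // the function to plan, after lowering
//   backend::StaticMemoryPlan plan = GraphPlanMemory(func);
//   for (const auto& kv : plan->expr_to_storage_info) {
//     // kv.second holds the storage ids, virtual devices, and byte sizes.
//   }
// The registration above also exposes the same entry point as the packed
// function "relay.backend.GraphPlanMemory".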

}  // namespace relay
}  // namespace tvm