/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file annotate_texture_storage.cc
 * \brief Collection of target-specific relay passes related to storage scope information.
 *
 * - CollectStorageInfo returns a mapping from relay expr
 *   to a map of storage scopes for each call argument.
 *   These scopes are used during memory planning as well
 *   as downstream when doing codegen and in the graph runtime when doing runtime dataspace
 *   allocations.
 *
 * - AnnotateMemoryScope calls *target.CollectStorageInfo for every target represented
 *   in the graph and rewrites the graph, modifying or inserting VirtualDevices with the
 *   memory_scope collected by CollectStorageInfo.
 */
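
// Note on the layout of the collected storage map (illustrative, based on how
// StorageInfo::GetStorageMap below builds it):
//
//   expr -> { Expr()        -> ["global.texture"],        // scope of expr's own output(s)
//             consumer_call -> ["global.texture-nhwc"] }  // scope a particular consumer expects
//
// The empty Expr() key holds the default scope of the expression itself, while keys that are
// concrete calls record the scope required by that specific consumer; the rewriter uses the
// latter to decide where a device_copy needs to be inserted.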

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/expr.h>

#include <algorithm>
#include <iterator>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "../op/memory/device_copy.h"
#include "../op/memory/memory.h"
#include "../transforms/device_aware_visitors.h"

namespace tvm {
namespace relay {
namespace {

/**
 * @brief Analyzes the graph and returns a mapping from expressions to their desired memory scopes
 */
class StorageInfo : private transform::DeviceAwareExprVisitor {
 public:
  StorageInfo() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}

  static Map<Expr, Map<Expr, Array<String>>> GetStorageMap(const Expr& expr) {
    StorageInfo storage_info;
    storage_info.VisitExpr(expr);
    storage_info.LegalizeProducerStorage();
    Map<Expr, Map<Expr, Array<String>>> storage_map = storage_info.accept_textures_;
    for (auto& kv : storage_info.storage_scope_) {
      std::vector<String> storage_scopes;
      std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(storage_scopes));
      Map<Expr, Array<String>> ent;
      ent.Set(Expr(), Array<String>{storage_scopes});
      storage_map.Set(GetRef<Expr>(kv.first), ent);
    }

    // Fill the input arguments with "global" scope to handle the PlanDevices algorithm, which
    // propagates virtual devices from outputs to inputs. At the same time, outputs must remain
    // unconstrained to avoid inserting a useless device_copy.
    for (const auto& cs : storage_info.consumer_storage_scopes_) {
      // A record in consumer_storage_scopes_ means the consumer potentially deals with
      // textures, so it is safe to mark this expr with "global" scope even without
      // verifying the scopes of the consumer's outputs.
      if (storage_info.CanConsumeTextures(cs.second) &&
          storage_map.find(GetRef<Expr>(cs.first)) == storage_map.end()) {
        Map<Expr, Array<String>> ent;
        ent.Set(Expr(), Array<String>{"global"});
        storage_map.Set(GetRef<Expr>(cs.first), ent);
      }
    }

    // The initial algorithm maps only the outputs of each expression, which is not enough: the
    // VirtualDevice of function variables must also be updated to get proper codegen. Add the
    // vars to storage_map.
    for (const auto& a : storage_info.args_to_vars_) {
      if (storage_map.count(a.first)) {
        for (const auto& v : a.second) {
          if (storage_info.buffers_params.find(v) != storage_info.buffers_params.end()) {
            Map<Expr, Array<String>> ent;
            ent.Set(Expr(), Array<String>{"global"});
            storage_map.Set(v, ent);
          } else {
            storage_map.Set(v, storage_map[a.first]);
            if (storage_map[a.first][Expr()][0] == "global" &&
                storage_info.accept_textures_.count(v)) {
              Map<Expr, Array<String>> ent;
              ent.Set(Expr(), storage_info.accept_textures_[v][Expr()]);
              storage_map.Set(v, ent);
              for (const auto& calls : storage_info.accept_textures_[v]) {
                if (calls.first != Expr()) {
                  if (storage_map.count(a.first)) {
                    Map<Expr, Array<String>> ent_call = storage_map[a.first];
                    ent_call.Set(calls.first, calls.second);
                    storage_map.Set(a.first, ent_call);
                  } else {
                    Map<Expr, Array<String>> ent_call;
                    ent_call.Set(calls.first, calls.second);
                    storage_map.Set(a.first, ent_call);
                  }
                }
              }
            }
          }
        }
      }
    }
    return storage_map;
  }

 private:
  using transform::DeviceAwareExprVisitor::VisitExpr_;

  void Visit(const Expr& expr) {
    // Pre-order traversal to enable upward propagation
    // of consumer storage scopes to producers when desirable.
    if (const auto* fn = expr.as<FunctionNode>()) {
      this->VisitExpr(fn->body);
      for (const auto& param : fn->params) {
        this->VisitExpr(param);
      }
    } else {
      this->VisitExpr(expr);
    }
  }

  void VisitExpr_(const VarNode* vn) final { ApplyConsumerScopeToInputs(vn); }

  void VisitExpr_(const ConstantNode* cn) final { ApplyConsumerScopeToInputs(cn); }

  void DeviceAwareVisitExpr_(const CallNode* call) final {
    // Check the contents of this primitive function
    if (const auto* fn = call->op.as<FunctionNode>()) {
      if (fn->HasNonzeroAttr(attr::kPrimitive)) {
        primitive_supports_texture_ = false;
        Visit(call->op);
        if (primitive_supports_texture_) {
          if (call->checked_type().as<TensorTypeNode>()) {
            std::string scope = "global.texture";
            if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
              scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
            }
            storage_scope_[call].push_back(scope);
          } else {
            const auto* tuple_type = call->type_as<TupleTypeNode>();
            ICHECK(tuple_type);
            // TODO(csullivan): Add support for mixed output storage scope.
            // In the current Adreno storage planner all outputs of a
            // primitive function are assumed to be of the same storage
            // type. This should be easy to extend in the future.
            for (size_t i = 0; i < tuple_type->fields.size(); i++) {
              storage_scope_[call].push_back("global.texture");
            }
          }
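          // For convolutions, the weights are the second call argument (index 1).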
          const int weights_pos = 1;
          for (size_t i = 0; i < fn->params.size(); i++) {
            args_to_vars_[call->args[i]].push_back(fn->params[i]);
            // Record arguments that can be converted to textures.
            for (const auto& ttype : FlattenTupleType(fn->params[i]->checked_type())) {
              std::string scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
              if (expr_attrib.as<Conv2DAttrs>() || expr_attrib.as<Conv2DWinogradAttrs>()) {
                if ((i == weights_pos) && !ttype->dtype.is_float16() &&
                    CanUseBuffers(call->args[i], ttype->shape, fn->attrs)) {
                  buffers_params.insert(fn->params[i]);
                  buffers_args.insert(call->args[i]);
                  scope = "global";
                }
              }
              if (scope.find("global.texture") != std::string::npos) {
                if (accept_textures_.count(fn->params[i])) {
                  Map<Expr, Array<String>> ent = accept_textures_[fn->params[i]];
                  ent.Set(GetRef<Expr>(call), Array<String>{scope});
                  ent.Set(Expr(), Array<String>{scope});
                  accept_textures_.Set(fn->params[i], ent);
                } else {
                  Map<Expr, Array<String>> ent;
                  ent.Set(GetRef<Expr>(call), Array<String>{scope});
                  ent.Set(Expr(), Array<String>{scope});
                  accept_textures_.Set(fn->params[i], ent);
                }
              }
            }
          }
        }
        // Add consumer storage scope information for call arguments
        for (auto& arg : call->args) {
          if (storage_scope_.count(call)) {
            ICHECK(!HasMixedStorageOutputs(call))
                << "Mixed output storage scopes are not currently supported";
            consumer_storage_scopes_[arg.operator->()].push_back("global.texture");
          } else {
            consumer_storage_scopes_[arg.operator->()].push_back("global");
          }
        }
      }
    }
    if (!primitive_supports_texture_) {
      expr_attrib = call->attrs;
      primitive_supports_texture_ = SupportsTextureStorage(call);
    }

    for (auto& arg : call->args) {
      if (buffers_args.find(arg) == buffers_args.end()) {
        Visit(arg);
      }
    }
    // All callees that support textures have already been recorded in storage_scope_. Now
    // verify whether this call actually expects a texture; if it does not, remove the entry
    // from storage_scope_, since storage_scope_ is initially filled only based on the
    // knowledge that a function is able to work with textures, not that a texture is
    // necessarily expected by the callee.
    for (auto& arg : call->args) {
      if (consumer_storage_scopes_.count(arg.operator->()) &&
          GetConsumerScope(consumer_storage_scopes_[arg.operator->()]) != "global.texture") {
        storage_scope_.erase(arg.operator->());
      }
    }
  }

  /**
   * Returns the name of the memory scope that can fit a tensor of the required shape.
   *
   * The scope is "global" if the tensor does not satisfy the current flattening rules for
   * textures (a texture currently has to be a 5-d tensor whose last dimension equals 4).
   *
   * The packing layout inside the texture scope (the part after the dash) is determined by
   * the shape itself. Hardware can have limits on the texture spatial dimensions, and we
   * must not exceed these sizes. In addition to fitting the hardware limits, we want a
   * balanced packing where the final spatial sizes of the texture do not differ too much.
   * @param shape shape to be analyzed
   * @param vd VirtualDevice of the tensor whose memory scope is being determined
   * @return string representing the memory scope, either "global" or "global.texture-layout"
   */
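  // Worked example (illustrative, assuming the default texture_spatial_limit of 16384):
  // for an NHWC4c tensor of shape [1, 56, 56, 8, 4] the candidate 2-d splits are
  //   1*56*56 = 3136 vs. 8           -> diff 3128, plain "global.texture"
  //   1*56    = 56   vs. 56*8 = 448  -> diff 392,  "global.texture-nhwc"
  //   1              vs. 56*56*8     -> 25088 exceeds the limit, rejected
  // so the most balanced split that fits the limit wins and Scope() returns
  // "global.texture-nhwc".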
  std::string Scope(Array<PrimExpr> shape, const VirtualDevice& vd) {
    // Currently we support only textures made from 5-d tensors.
    // The 5-d requirement is not a limitation of textures in general; it is a limitation of
    // how we represent memory scopes/layouts and flatten textures in TIR.
    if (vd != VirtualDevice::FullyUnconstrained() && shape.size() == 5 &&
        shape[4].as<IntImmNode>()->value == 4) {
      std::map<int, std::string> diffs;
      int limit =
          vd->target->GetAttr<Integer>("texture_spatial_limit").value_or(Integer(16384))->value;
      int a0 = shape[0].as<IntImmNode>()->value;
      int a1 = shape[1].as<IntImmNode>()->value;
      int a2 = shape[2].as<IntImmNode>()->value;
      int a3 = shape[3].as<IntImmNode>()->value;

      int d3l = a0 * a1 * a2;
      int d3r = a3;
      int diff3 = d3l > d3r ? d3l - d3r : d3r - d3l;
      if (d3l < limit && d3r < limit) diffs[diff3] = "";

      int d2l = a0 * a1;
      int d2r = a2 * a3;
      int diff2 = d2l > d2r ? d2l - d2r : d2r - d2l;
      if (d2l < limit && d2r < limit) diffs[diff2] = "nhwc";

      int d1l = a0;
      int d1r = a1 * a2 * a3;
      int diff1 = d1l > d1r ? d1l - d1r : d1r - d1l;
      if (d1l < limit && d1r < limit) diffs[diff1] = "weight";
      if (!diffs.empty()) {
        std::string scope = "global.texture";
        if (!diffs.begin()->second.empty()) {
          scope += ("-" + diffs.begin()->second);
        }
        return scope;
      }
    }
    return "global";
  }

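  /*! \brief Propagates the consumer scope to an input (Var or Constant): a texture scope is
   * only propagated when the input itself is RGBA-vectorizable (its innermost dimension is 4);
   * otherwise the consumer's "global" scope is applied.
   */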
  void ApplyConsumerScopeToInputs(const ExprNode* expr) {
    std::string scope;
    auto consumer_scopes_it = consumer_storage_scopes_.find(expr);
    if (consumer_scopes_it != consumer_storage_scopes_.end()) {
      std::string consumer_scope = GetConsumerScope(consumer_scopes_it->second);
      ICHECK(!storage_scope_.count(expr))
          << "Already propagated consumer scopes to input: " << GetRef<Expr>(expr);

      bool expr_is_rgba_vectorizable = false;
      if (const auto* ttype = expr->checked_type().as<TensorTypeNode>()) {
        scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(expr)));
        if (scope != "global") {
          auto inner_dim = ttype->shape.back().as<IntImmNode>();
          if (inner_dim && inner_dim->value == 4) {
            expr_is_rgba_vectorizable = true;
          }
        }
      }

      // Only propagate a texture scope from consumers to the input expr if
      // the shape of the input expr is RGBA-vectorizable.
      if (consumer_scope.find("global.texture") != std::string::npos) {
        if (expr_is_rgba_vectorizable) {
          storage_scope_[expr].push_back(scope);
        }
      } else {
        storage_scope_[expr].push_back(consumer_scope);
      }
    }
  }

  void LegalizeProducerStorage() {
    for (auto& kv : consumer_storage_scopes_) {
      const ExprNode* producer = kv.first;
      std::string legal_scope = GetConsumerScope(kv.second);
      if (storage_scope_.count(producer)) {
        ICHECK(!HasMixedStorageOutputs(producer))
            << "Mixed output storage scopes are not currently supported";
        if (storage_scope_[producer][0].find(legal_scope) == std::string::npos) {
          for (size_t i = 0; i < storage_scope_[producer].size(); i++) {
            // Only support uniform storage scope across all outputs for now
            storage_scope_[producer][i] = legal_scope;
          }
        }
      }
    }
  }

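  /*! \brief Returns "global.texture" only if every consumer requires a texture scope;
   * otherwise (including when there are no consumers) falls back to "global".
   * For example, {"global.texture-nhwc", "global.texture"} -> "global.texture",
   * while {"global.texture", "global"} -> "global".
   */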
  std::string GetConsumerScope(const std::vector<std::string>& consumer_scopes) const {
    if (!consumer_scopes.size()) {
      return "global";
    }
    std::string texture_tag = "global.texture";
    for (auto& consumer_scope : consumer_scopes) {
      if (consumer_scope.find(texture_tag) == std::string::npos) {
        return "global";
      }
    }
    return texture_tag;
  }

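  /*! \brief Returns true if at least one consumer scope starts with "global.texture",
   * i.e. at least one consumer is able to work with textures.
   */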
  bool CanConsumeTextures(const std::vector<std::string>& consumer_scopes) const {
    std::string texture_tag = "global.texture";
    for (auto& consumer_scope : consumer_scopes) {
      if (consumer_scope.find(texture_tag) == 0) {
        return true;
      }
    }
    return false;
  }

  bool HasMixedStorageOutputs(const ExprNode* expr) {
    if (storage_scope_.count(expr)) {
      std::string ref_scope = storage_scope_[expr][0];
      for (std::string& scope : storage_scope_[expr]) {
        if (scope != ref_scope) {
          return true;
        }
      }
    }
    return false;
  }

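  /*! \brief Checks whether an operation can work with texture storage, based on its
   * attributes. For example (as encoded below), conv2d qualifies when its layouts are
   * vectorized by 4 (e.g. data_layout "NCHW4c" with kernel_layout "OIHW4o"), and
   * injective/reduction ops qualify when their output is a 5-d tensor.
   */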
  bool SupportsTextureStorage(const CallNode* call) const {
    bool supports_texture_storage = false;
    // Only entry ops of the function need to be verified (calls whose arguments are all
    // variables), since one of the entry ops defines the main schedule.
    for (const auto& arg : call->args) {
      if (!arg.as<VarNode>()) {
        return false;
      }
    }
    if (auto attrs = call->attrs.as<Conv2DAttrs>()) {
      if (attrs->data_layout == "NCHW4c" && attrs->kernel_layout == "OIHW4o") {
        supports_texture_storage = true;
      } else if (attrs->data_layout == "NHWC4c" &&
                 (attrs->kernel_layout == "HWOI4o" || attrs->kernel_layout == "HWIO4o" ||
                  attrs->kernel_layout == "OIHW4o")) {
        supports_texture_storage = true;
      }
    } else if (auto attrs = call->attrs.as<Conv2DWinogradAttrs>()) {
      if ((attrs->data_layout == "NCHW4c" || attrs->data_layout == "NHWC4c") &&
          (attrs->kernel_layout == "OIHW4o" || attrs->kernel_layout == "HWIO4o")) {
        supports_texture_storage = true;
      }
    } else if (auto attrs = call->attrs.as<GlobalPool2DAttrs>()) {
      if (attrs->layout == "NCHW4c") {
        supports_texture_storage = true;
      }
    } else if (auto attrs = call->attrs.as<MaxPool2DAttrs>()) {
      if (attrs->layout == "NCHW4c") {
        supports_texture_storage = true;
      }
    } else if (auto attrs = call->attrs.as<AvgPool2DAttrs>()) {
      if (attrs->layout == "NCHW4c") {
        supports_texture_storage = true;
      }
    } else if (const OpNode* opnode = call->op.as<OpNode>()) {
      auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
      auto pattern = fpattern[GetRef<Op>(opnode)];
      if (pattern <= kCommReduce) {
        if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
          if (ttype->shape.size() == 5) {
            supports_texture_storage = true;
          }
        }
      }
    }

    return supports_texture_storage;
  }

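  /*! \brief Decides whether convolution weights should stay in a buffer instead of a texture:
   * constant 5-d weights whose spatial kernel dimensions (per the given kernel_layout) both
   * differ from 1 are kept in "global" buffer scope.
   */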
  bool CanUseBuffers(const Expr param, const Array<PrimExpr> shape,
                     const tvm::DictAttrs param_attrs) const {
    bool use_buffer = false;
    if (param.as<ConstantNode>() && shape.size() == 5) {
      auto kernel_layout = param_attrs.GetAttr<String>("kernel_layout");
      if (kernel_layout == "HWOI4o" || kernel_layout == "HWIO4o") {
        int a0 = shape[0].as<IntImmNode>()->value;
        int a1 = shape[1].as<IntImmNode>()->value;
        if (a0 != 1 && a1 != 1) {
          use_buffer = true;
        }
      } else if (kernel_layout == "OIHW4o") {
        int a2 = shape[2].as<IntImmNode>()->value;
        int a3 = shape[3].as<IntImmNode>()->value;
        if (a2 != 1 && a3 != 1) {
          use_buffer = true;
        }
      }
    }
    return use_buffer;
  }

  /*! \brief Temporary state marking whether the currently visited primitive
   * function supports texture storage scope */
  bool primitive_supports_texture_ = false;
  /*! \brief Storage scope mapping for each output of an expr */
  std::unordered_map<const ExprNode*, std::vector<std::string>> storage_scope_;
  /*! \brief Output storage scopes used by consumers of the expr key */
  std::unordered_map<const ExprNode*, std::vector<std::string>> consumer_storage_scopes_;
  /*! \brief Mapping from call arguments to function variables */
  std::unordered_map<Expr, std::vector<Var>, ObjectPtrHash, ObjectPtrEqual> args_to_vars_;
  /*! \brief Mapping of arguments that can be converted to textures */
  Map<Expr, Map<Expr, Array<String>>> accept_textures_;
  /*! \brief Main attributes of the currently analyzed expression */
  tvm::Attrs expr_attrib;
  /*! \brief Parameters filtered out of storage_map so that they use buffers */
  std::unordered_set<Expr, ObjectPtrHash> buffers_params;
  /*! \brief Arguments in the expression that will use buffers */
  std::unordered_set<Expr, ObjectPtrHash> buffers_args;
};

}  // namespace

/**
 * @brief Rewrites the memory_scope part of the VirtualDevices of expressions determined
 * by the StorageInfo analysis pass
 *
 * Currently this workflow supports analysis and rewriting of VirtualDevices for
 * Constants and function Variables
 */
class RewriteVDStorageScopes : public transform::DeviceAwareExprMutator {
  using VarMap = std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>;

 public:
  using transform::DeviceAwareExprMutator::VisitExpr_;

  explicit RewriteVDStorageScopes(const Map<Expr, Map<Expr, Array<String>>>& storage_scope)
      : transform::DeviceAwareExprMutator(Optional<IRModule>()), storage_scope_(storage_scope) {}

  Function Rewrite(const Expr& expr) { return Downcast<Function>(Mutate(expr)); }

  Expr VisitExpr_(const VarNode* vn) final {
    if (storage_scope_.find(GetRef<Expr>(vn)) != storage_scope_.end() &&
        storage_scope_[GetRef<Expr>(vn)].find(Expr()) != storage_scope_[GetRef<Expr>(vn)].end() &&
        storage_scope_[GetRef<Expr>(vn)][Expr()][0] != "global") {
      Var c = Var(vn->vid, vn->type_annotation, vn->span);
      auto virtual_device = GetVirtualDevice(GetRef<Expr>(vn));
      c->virtual_device_ =
          VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                        virtual_device->target, storage_scope_[GetRef<Expr>(vn)][Expr()][0]);
      return std::move(c);
    }
    return GetRef<Var>(vn);
  }

  Expr VisitExpr_(const ConstantNode* vn) final {
    if (storage_scope_.find(GetRef<Expr>(vn)) != storage_scope_.end() &&
        storage_scope_[GetRef<Expr>(vn)].find(Expr()) != storage_scope_[GetRef<Expr>(vn)].end()) {
      Expr c = Constant(vn->data, vn->span);
      auto virtual_device = GetVirtualDevice(GetRef<Expr>(vn));
      c = OnDevice(
          c,
          VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                        virtual_device->target, storage_scope_[GetRef<Expr>(vn)][Expr()][0]),
          true);
      return c;
    }
    return GetRef<Constant>(vn);
  }

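  /*! \brief Rewrites a call, inserting a device_copy for any argument for which the analysis
   * recorded a call-specific scope (the argument's entry in storage_scope_ keyed by this call),
   * so that the data ends up in the memory scope expected by this particular consumer.
   */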
  Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
    // We need to duplicate ExprMutator::VisitExpr_ to correct the argument scopes and to
    // insert device_copy.
    auto new_op = this->Mutate(call_node->op);

    tvm::Array<Type> ty_args;
    ty_args.reserve(call_node->type_args.size());

    for (auto ty_arg : call_node->type_args) {
      auto new_ty_arg = this->VisitType(ty_arg);
      ty_args.push_back(new_ty_arg);
    }

    tvm::Array<Expr> call_args;
    call_args.reserve(call_node->args.size());
    for (auto arg : call_node->args) {
      auto new_arg = this->Mutate(arg);
      // Check whether we need to insert a device_copy for this argument.
      if (storage_scope_.count(arg) && storage_scope_[arg].count(GetRef<Expr>(call_node))) {
        auto virtual_device = GetVirtualDevice(GetRef<Expr>(call_node));
        VirtualDevice virtual_device_from =
            VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                          virtual_device->target, virtual_device->memory_scope);
        VirtualDevice virtual_device_to =
            VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                          virtual_device->target, storage_scope_[arg][GetRef<Expr>(call_node)][0]);
        new_arg = DeviceCopy(new_arg, virtual_device_from, virtual_device_to);
        new_arg = OnDevice(
            new_arg,
            VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                          virtual_device->target, storage_scope_[arg][GetRef<Expr>(call_node)][0]),
            true);
      }
      call_args.push_back(new_arg);
    }

    auto new_call = WithFields(GetRef<Call>(call_node), new_op, call_args, {}, ty_args);

    auto virtual_device = GetVirtualDevice(GetRef<Expr>(call_node));
    std::string memory_scope = "";
    if (storage_scope_.find(GetRef<Expr>(call_node)) != storage_scope_.end() &&
        storage_scope_[GetRef<Expr>(call_node)].find(Expr()) !=
            storage_scope_[GetRef<Expr>(call_node)].end()) {
      memory_scope = storage_scope_[GetRef<Expr>(call_node)][Expr()][0];
    } else if (virtual_device->memory_scope != "") {
      memory_scope = virtual_device->memory_scope;
    } else if (!call_node->op.as<FunctionNode>()) {
      memory_scope = "";
    }
    if (!memory_scope.empty()) {
      new_call =
          OnDevice(new_call,
                   VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                                 virtual_device->target, memory_scope),
                   true);
    }
    return std::move(new_call);
  }


 private:
  Map<Expr, Map<Expr, Array<String>>> storage_scope_;
  VarMap new_vars_;
  Array<String> current_function_scope_;
};

Map<Expr, Map<Expr, Array<String>>> CollectTextureStorage(const Expr& expr) {
  return StorageInfo::GetStorageMap(expr);
}

/**
 * @brief Collects all target devices participating in the graph
 */
class CollectVirtualDevices : public transform::DeviceAwareExprVisitor {
 public:
  CollectVirtualDevices() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
  /**
   * @brief Get all unique device elements from the target of each VirtualDevice
   *
   * @param expr The expression (IR) to analyze
   * @return set of devices
   */
  std::set<std::string> GetDevices(const Expr& expr) {
    this->Run(expr);
    return std::move(devices_);
  }

  void Visit(const Expr& expr) {
    // Pre-order traversal of the function body followed by its parameters.
    if (const auto* fn = expr.as<FunctionNode>()) {
      this->VisitExpr(fn->body);
      for (const auto& param : fn->params) {
        this->VisitExpr(param);
      }
    } else {
      this->VisitExpr(expr);
    }
  }

  void DeviceAwareVisitExpr_(const CallNode* call) final {
    auto vd = GetVirtualDevice(GetRef<Expr>(call));
    if (vd != VirtualDevice::FullyUnconstrained()) {
      if (Optional<String> t_device = vd->target->GetAttr<String>("device")) {
        devices_.insert(vd->target->kind->name + "." + t_device.value());
      }
    }
    for (auto& arg : call->args) {
      Visit(arg);
    }
  }

  void Run(const Expr& expr) { VisitExpr(expr); }
  using transform::DeviceAwareExprVisitor::VisitExpr_;
  std::set<std::string> devices_;
};

/*!
 * \brief Collect the target-specific tensor storage info for each expression's output.
 * \param expr The expression.
 * \return The device-based storage mapping.
 */
Map<Expr, Map<Expr, Array<String>>> CollectStorageInfo(const Expr& expr) {
  std::set<std::string> device_types = CollectVirtualDevices().GetDevices(expr);
  // TODO(amalyshe): the current approach collects all targets within the graph and calls the
  // single function corresponding to all of these targets, concatenated in alphabetical order.
  // This works reliably only when there is a single device and should be redesigned to handle
  // the general case.
  std::string ftarget_prefix = "relay.backend";
  for (auto& dev_id : device_types) {
    ftarget_prefix += (std::string(".") + dev_id);
  }

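  // For example, with a single OpenCL Adreno target ("opencl" target kind with device
  // "adreno"), the lookup below resolves to "relay.backend.opencl.adreno._CollectStorageInfo",
  // which is registered at the bottom of this file.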
  Map<Expr, Map<Expr, Array<String>>> storage_info = {};
  if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectStorageInfo")) {
    storage_info = (*f)(expr);
  }
  return storage_info;
}

Expr AnnotateMemoryScopeExpr(const Expr& expr, const IRModule& mod) {
  auto storage_scope = CollectStorageInfo(expr);
  if (storage_scope.size()) {
    return RewriteVDStorageScopes(storage_scope).Rewrite(expr);
  } else {
    return expr;
  }
}

namespace transform {
tvm::transform::Pass AnnotateMemoryScope() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(AnnotateMemoryScopeExpr(f, m));
      };
  return CreateFunctionPass(pass_func, 2, "AnnotateMemoryScope", {});
}
}  // namespace transform

TVM_REGISTER_GLOBAL("relay.backend.opencl.adreno._CollectStorageInfo")
    .set_body_typed(CollectTextureStorage);

}  // namespace relay
}  // namespace tvm