/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file annotate_texture_storage.cc
 * \brief Collection of target specific relay passes which collect
 * storage scope related information.
 *
 *  - CollectStorageInfo returns a mapping from relay expr
 *    to a map of storage scopes for each call argument.
 *    These scopes are used during memory planning as well
 *    as downstream during codegen and in the graph runtime when doing runtime data-space
 *    allocations.
 *
 *  - AnnotateMemoryScope calls *target.CollectStorageInfo for every target represented
 *    in the graph and rewrites the graph, modifying or inserting VirtualDevice annotations
 *    with the required memory_scope collected by CollectStorageInfo.
 */

#include <tvm/relay/attrs/nn.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/transform.h>
#include <tvm/tir/expr.h>

#include <memory>
#include <unordered_map>
#include <unordered_set>

#include "../op/memory/device_copy.h"
#include "../op/memory/memory.h"
#include "../transforms/device_aware_visitors.h"

namespace tvm {
namespace relay {
namespace {

/**
 * @brief Analyzes the graph and returns a mapping from expression to desired memory scope
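 *
 * The returned map is keyed by expression; the inner map is keyed by consuming call, with
 * the empty Expr() key holding the expression's own output scopes, and maps to a list of
 * memory scopes, one per output.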
 */
class StorageInfo : private transform::DeviceAwareExprVisitor {
 public:
  StorageInfo() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}

  static Map<Expr, Map<Expr, Array<String>>> GetStorageMap(const Expr& expr) {
    StorageInfo storage_info;
    storage_info.VisitExpr(expr);
    storage_info.LegalizeProducerStorage();
    Map<Expr, Map<Expr, Array<String>>> storage_map = storage_info.accept_textures_;
    for (auto& kv : storage_info.storage_scope_) {
      std::vector<String> storage_scopes;
      std::copy(kv.second.begin(), kv.second.end(), std::back_inserter(storage_scopes));
      Map<Expr, Array<String>> ent;
      ent.Set(Expr(), Array<String>{storage_scopes});
      storage_map.Set(GetRef<Expr>(kv.first), ent);
    }

    // Fill the input arguments with "global" scope to handle the PlanDevices algorithm, which
    // propagates virtual devices from outputs to inputs. At the same time, outputs must remain
    // unconstrained to avoid a useless device_copy
    for (const auto& cs : storage_info.consumer_storage_scopes_) {
      // A record in consumer_storage_scopes_ means the consumer potentially dealt with
      // textures in some way, so it is safe to mark this expr with "global" scope even
      // without verifying the consumer's output scopes
      if (storage_info.CanConsumeTextures(cs.second) &&
          storage_map.find(GetRef<Expr>(cs.first)) == storage_map.end()) {
        Map<Expr, Array<String>> ent;
        ent.Set(Expr(), Array<String>{"global"});
        storage_map.Set(GetRef<Expr>(cs.first), ent);
      }
    }

    // The initial algorithm maps only the outputs of expressions, which is not enough: the
    // VirtualDevice of function variables must be updated as well to get proper codegen.
    // Adding vars to storage_map
    for (const auto& a : storage_info.args_to_vars_) {
      if (storage_map.count(a.first)) {
        for (const auto& v : a.second) {
          if (storage_info.buffers_params.find(v) != storage_info.buffers_params.end()) {
            Map<Expr, Array<String>> ent;
            ent.Set(Expr(), Array<String>{"global"});
            storage_map.Set(v, ent);
          } else {
            storage_map.Set(v, storage_map[a.first]);
            if (storage_map[a.first][Expr()][0] == "global" &&
                storage_info.accept_textures_.count(v)) {
              Map<Expr, Array<String>> ent;
              ent.Set(Expr(), storage_info.accept_textures_[v][Expr()]);
              storage_map.Set(v, ent);
              for (const auto& calls : storage_info.accept_textures_[v]) {
                if (calls.first != Expr()) {
                  // a.first is guaranteed to be in storage_map here (checked above)
                  Map<Expr, Array<String>> ent_call = storage_map[a.first];
                  ent_call.Set(calls.first, calls.second);
                  storage_map.Set(a.first, ent_call);
                }
              }
            }
          }
        }
      }
    }
    return storage_map;
  }

 private:
  using transform::DeviceAwareExprVisitor::VisitExpr_;

  void Visit(const Expr& expr) {
    // Pre-order traversal to enable upward propagation
    // of consumer storage scopes to producers when desirable.
    if (const auto* fn = expr.as<FunctionNode>()) {
      this->VisitExpr(fn->body);
      for (const auto& param : fn->params) {
        this->VisitExpr(param);
      }
    } else {
      this->VisitExpr(expr);
    }
  }

  void VisitExpr_(const VarNode* vn) final { ApplyConsumerScopeToInputs(vn); }

  void VisitExpr_(const ConstantNode* cn) final { ApplyConsumerScopeToInputs(cn); }

  void DeviceAwareVisitExpr_(const CallNode* call) final {
    // Check the contents of this primitive function
    if (const auto* fn = call->op.as<FunctionNode>()) {
      if (fn->HasNonzeroAttr(attr::kPrimitive)) {
        primitive_supports_texture_ = false;
        Visit(call->op);
        if (primitive_supports_texture_) {
          if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
            std::string scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
            storage_scope_[call].push_back(scope);
          } else {
            const auto* tuple_type = call->type_as<TupleTypeNode>();
            ICHECK(tuple_type);
            // TODO(csullivan): Add support for mixed output storage scope.
            // In the current Adreno storage planner all outputs of a
            // primitive function are assumed to be of the same storage
            // type. This should be easy to extend in the future.
            for (size_t i = 0; i < tuple_type->fields.size(); i++) {
              storage_scope_[call].push_back("global.texture");
            }
          }
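          // by relay convention, conv2d weights are the second argument (position 1)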
          const int weights_pos = 1;
          for (size_t i = 0; i < fn->params.size(); i++) {
            args_to_vars_[call->args[i]].push_back(fn->params[i]);
            // record whether this argument can be converted to a texture
            for (const auto& ttype : FlattenTupleType(fn->params[i]->checked_type())) {
              std::string scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(call)));
              if (expr_attrib.as<Conv2DAttrs>() || expr_attrib.as<Conv2DWinogradAttrs>()) {
                if ((i == weights_pos) && !ttype->dtype.is_float16() &&
                    CanUseBuffers(call->args[i], ttype->shape, fn->attrs)) {
                  buffers_params.insert(fn->params[i]);
                  buffers_args.insert(call->args[i]);
                  scope = "global";
                }
              }
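              // remember, both per consuming call and under the default Expr() key,
              // that this parameter accepts a texture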
              if (scope.find("global.texture") != std::string::npos) {
                Map<Expr, Array<String>> ent;
                if (accept_textures_.count(fn->params[i])) {
                  ent = accept_textures_[fn->params[i]];
                }
                ent.Set(GetRef<Expr>(call), Array<String>{scope});
                ent.Set(Expr(), Array<String>{scope});
                accept_textures_.Set(fn->params[i], ent);
              }
            }
          }
        }
        // Add consumer storage scope information for call arguments
        for (auto& arg : call->args) {
          if (storage_scope_.count(call)) {
            ICHECK(!HasMixedStorageOutputs(call))
                << "Mixed output storage scopes are not currently supported";
            consumer_storage_scopes_[arg.operator->()].push_back("global.texture");
          } else {
            consumer_storage_scopes_[arg.operator->()].push_back("global");
          }
        }
      }
    }
    if (!primitive_supports_texture_) {
      expr_attrib = call->attrs;
      primitive_supports_texture_ = SupportsTextureStorage(call);
    }

    for (auto& arg : call->args) {
      if (buffers_args.find(arg) == buffers_args.end()) {
        Visit(arg);
      }
    }
    // All callees that support textures are now recorded in storage_scope_. We still need
    // to verify whether each consumer actually expects a texture and, if it does not, remove
    // the entry from storage_scope_: initially storage_scope_ is filled only with the
    // knowledge that a function is able to work with textures, not that a texture is
    // necessarily expected by its callee
    for (auto& arg : call->args) {
      if (consumer_storage_scopes_.count(arg.operator->()) &&
          GetConsumerScope(consumer_storage_scopes_[arg.operator->()]) != "global.texture") {
        storage_scope_.erase(arg.operator->());
      }
    }
  }

  /**
   * Determines the name of the memory scope which can fit a tensor of the required shape
   *
   * The scope is "global" if the tensor does not satisfy the current flattening rules for
   * textures (a texture currently has to be a 5d tensor whose last dimension equals 4)
   *
   * The packing layout inside the texture scope (the part after the dash) is determined
   * by the shape itself. Hardware can have limitations on the texture spatial dimensions,
   * and these sizes must not be exceeded. In addition to fitting the h/w limitation we
   * want a balanced packing where the final spatial sizes of the texture are not too
   * different
   * @param shape shape to be analyzed
   * @param vd VirtualDevice of the tensor whose memory scope is being determined
   * @return string representing the memory scope, either "global" or "global.texture-layout"
   */
  std::string Scope(Array<PrimExpr> shape, const VirtualDevice& vd) {
    // Currently we support only textures made from 5d tensors.
    // The 5d requirement is not a limitation of textures in general; it is a limitation of
    // how we represent memory scopes/layouts and flatten textures in TIR
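    // note: the IntImm casts below assume statically known shapes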
    if (vd != VirtualDevice::FullyUnconstrained() && shape.size() == 5 &&
        shape[4].as<IntImmNode>()->value == 4) {
      std::map<int, std::string> diffs;
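      // Three ways of collapsing the 5d shape into a 2D texture are considered below.
      // Each candidate maps its size imbalance |rows - cols| to a layout tag; std::map
      // keeps the candidates ordered by imbalance, so diffs.begin() is the most balanced
      // packing that fits within the spatial limit.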
      int limit =
          vd->target->GetAttr<Integer>("texture_spatial_limit").value_or(Integer(16384))->value;
      int a0 = shape[0].as<IntImmNode>()->value;
      int a1 = shape[1].as<IntImmNode>()->value;
      int a2 = shape[2].as<IntImmNode>()->value;
      int a3 = shape[3].as<IntImmNode>()->value;

      int d3l = a0 * a1 * a2;
      int d3r = a3;
      int diff3 = d3l > d3r ? d3l - d3r : d3r - d3l;
      if (d3l < limit && d3r < limit) diffs[diff3] = "";

      int d2l = a0 * a1;
      int d2r = a2 * a3;
      int diff2 = d2l > d2r ? d2l - d2r : d2r - d2l;
      if (d2l < limit && d2r < limit) diffs[diff2] = "nhwc";

      int d1l = a0;
      int d1r = a1 * a2 * a3;
      int diff1 = d1l > d1r ? d1l - d1r : d1r - d1l;
      if (d1l < limit && d1r < limit) diffs[diff1] = "weight";
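      // Illustrative example: for shape (1, 40, 40, 32, 4) and the default limit 16384,
      // diff3 = |1600 - 32| = 1568, diff2 = |40 - 1280| = 1240, and the third split
      // (1 x 51200) exceeds the limit, so the most balanced fit is "nhwc" and the
      // resulting scope is "global.texture-nhwc".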
      if (!diffs.empty()) {
        std::string scope = "global.texture";
        if (!diffs.begin()->second.empty()) {
          scope += ("-" + diffs.begin()->second);
        }
        return scope;
      }
    }
    return "global";
  }

  void ApplyConsumerScopeToInputs(const ExprNode* expr) {
    std::string scope;
    auto consumer_scopes_it = consumer_storage_scopes_.find(expr);
    if (consumer_scopes_it != consumer_storage_scopes_.end()) {
      std::string consumer_scope = GetConsumerScope(consumer_scopes_it->second);
      ICHECK(!storage_scope_.count(expr))
          << "Already propagated consumer scopes to input: " << GetRef<Expr>(expr);

      bool expr_is_rgba_vectorizable = false;
      if (const auto* ttype = expr->checked_type().as<TensorTypeNode>()) {
        scope = Scope(ttype->shape, GetVirtualDevice(GetRef<Expr>(expr)));
        if (scope != "global") {
          auto inner_dim = ttype->shape.back().as<IntImmNode>();
          if (inner_dim && inner_dim->value == 4) {
            expr_is_rgba_vectorizable = true;
          }
        }
      }

      // Only propagate a texture scope from consumers to the input expr if
      // the shape of the input expr is rgba vectorizable.
      if (consumer_scope.find("global.texture") != std::string::npos) {
        if (expr_is_rgba_vectorizable) {
          storage_scope_[expr].push_back(scope);
        }
      } else {
        storage_scope_[expr].push_back(consumer_scope);
      }
    }
  }

  void LegalizeProducerStorage() {
    for (auto& kv : consumer_storage_scopes_) {
      const ExprNode* producer = kv.first;
      std::string legal_scope = GetConsumerScope(kv.second);
      if (storage_scope_.count(producer)) {
        ICHECK(!HasMixedStorageOutputs(producer))
            << "Mixed output storage scopes are not currently supported";
        if (storage_scope_[producer][0].find(legal_scope) == std::string::npos) {
          for (size_t i = 0; i < storage_scope_[producer].size(); i++) {
            // Only support uniform storage scope across all outputs for now
            storage_scope_[producer][i] = legal_scope;
          }
        }
      }
    }
  }

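  /*! \brief Returns "global.texture" only when every consumer reads from a texture;
   * a single non-texture consumer forces the producer to plain "global". */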
  std::string GetConsumerScope(const std::vector<std::string>& consumer_scopes) const {
    if (!consumer_scopes.size()) {
      return "global";
    }
    std::string texture_tag = "global.texture";
    for (auto& consumer_scope : consumer_scopes) {
      if (consumer_scope.find(texture_tag) == std::string::npos) {
        return "global";
      }
    }
    return texture_tag;
  }

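  /*! \brief Returns true if at least one consumer scope starts with "global.texture",
   * i.e. some consumer is able to take a texture on its input. */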
  bool CanConsumeTextures(const std::vector<std::string>& consumer_scopes) const {
    std::string texture_tag = "global.texture";
    for (auto& consumer_scope : consumer_scopes) {
      if (consumer_scope.find(texture_tag) == 0) {
        return true;
      }
    }
    return false;
  }

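  /*! \brief True if the outputs of expr were assigned different storage scopes. */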
  bool HasMixedStorageOutputs(const ExprNode* expr) {
    if (storage_scope_.count(expr)) {
      std::string ref_scope = storage_scope_[expr][0];
      for (std::string& scope : storage_scope_[expr]) {
        if (scope != ref_scope) {
          return true;
        }
      }
    }
    return false;
  }

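  /*! \brief Checks whether the operation behind the call is known to work with textures,
   * judging by its layout attributes: the *4c/*4o layouts pack values in groups of 4,
   * matching the 4-element texel width. */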
  bool SupportsTextureStorage(const CallNode* call) const {
    bool supports_texture_storage = false;
    // we only need to verify entry functions since one of the entry ops defines the main schedule
    for (const auto& arg : call->args) {
      if (!arg.as<VarNode>()) {
        return false;
      }
    }
    if (auto attrs = call->attrs.as<Conv2DAttrs>()) {
      if (attrs->data_layout == "NCHW4c" && attrs->kernel_layout == "OIHW4o") {
        supports_texture_storage = true;
      } else if (attrs->data_layout == "NHWC4c" &&
                 (attrs->kernel_layout == "HWOI4o" || attrs->kernel_layout == "HWIO4o" ||
                  attrs->kernel_layout == "OIHW4o")) {
        supports_texture_storage = true;
      }
    } else if (auto attrs = call->attrs.as<Conv2DWinogradAttrs>()) {
      if ((attrs->data_layout == "NCHW4c" || attrs->data_layout == "NHWC4c") &&
          (attrs->kernel_layout == "OIHW4o" || attrs->kernel_layout == "HWIO4o")) {
        supports_texture_storage = true;
      }
    } else if (auto attrs = call->attrs.as<GlobalPool2DAttrs>()) {
      if (attrs->layout == "NCHW4c") {
        supports_texture_storage = true;
      }
    } else if (auto attrs = call->attrs.as<MaxPool2DAttrs>()) {
      if (attrs->layout == "NCHW4c") {
        supports_texture_storage = true;
      }
    } else if (auto attrs = call->attrs.as<AvgPool2DAttrs>()) {
      if (attrs->layout == "NCHW4c") {
        supports_texture_storage = true;
      }
    } else if (const OpNode* opnode = call->op.as<OpNode>()) {
      auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");
      auto pattern = fpattern[GetRef<Op>(opnode)];
      if (pattern <= kCommReduce) {
        if (const auto* ttype = call->checked_type().as<TensorTypeNode>()) {
          if (ttype->shape.size() == 5) {
            supports_texture_storage = true;
          }
        }
      }
    }

    return supports_texture_storage;
  }

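  /*! \brief Heuristic: constant weights of non-1x1 convolutions are kept in "global"
   * buffers instead of textures; which spatial dims are checked depends on the kernel
   * layout. */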
  bool CanUseBuffers(const Expr param, const Array<PrimExpr> shape,
                     const tvm::DictAttrs param_attrs) const {
    bool use_buffer = false;
    if (param.as<ConstantNode>() && shape.size() == 5) {
      auto kernel_layout = param_attrs.GetAttr<String>("kernel_layout");
      if (kernel_layout == "HWOI4o" || kernel_layout == "HWIO4o") {
        int a0 = shape[0].as<IntImmNode>()->value;
        int a1 = shape[1].as<IntImmNode>()->value;
        if (a0 != 1 && a1 != 1) {
          use_buffer = true;
        }
      } else if (kernel_layout == "OIHW4o") {
        int a2 = shape[2].as<IntImmNode>()->value;
        int a3 = shape[3].as<IntImmNode>()->value;
        if (a2 != 1 && a3 != 1) {
          use_buffer = true;
        }
      }
    }
    return use_buffer;
  }

  /*! \brief Temporary state marking whether the currently visited primitive
   * function supports the texture storage scope */
  bool primitive_supports_texture_ = false;
  /*! \brief expr storage scope mapping for each output */
  std::unordered_map<const ExprNode*, std::vector<std::string>> storage_scope_;
  /*! \brief output storage scopes used by consumers of the expr key */
  std::unordered_map<const ExprNode*, std::vector<std::string>> consumer_storage_scopes_;
  /*! \brief mapping of call arguments to function variables */
  std::unordered_map<Expr, std::vector<Var>, ObjectPtrHash, ObjectPtrEqual> args_to_vars_;
  /*! \brief mapping of arguments that can be converted to textures */
  Map<Expr, Map<Expr, Array<String>>> accept_textures_;
  /*! \brief attributes of the main operation of the current expression */
  tvm::Attrs expr_attrib;
  /*! \brief parameters that are filtered out of storage_map so that they use buffers */
  std::unordered_set<Expr, ObjectPtrHash> buffers_params;
  /*! \brief arguments in the expression that will use buffers */
  std::unordered_set<Expr, ObjectPtrHash> buffers_args;
};

}  // namespace

/**
 * @brief Rewrites VirtualDevices (the memory_scope part) of expressions determined
 * by the StorageInfo analysis pass
 *
 * Currently this workflow supports analysis and rewriting of the VirtualDevice for
 * Constants and function Variables
 */
class RewriteVDStorageScopes : public transform::DeviceAwareExprMutator {
  using VarMap = std::unordered_map<Expr, Var, ObjectPtrHash, ObjectPtrEqual>;

 public:
  using transform::DeviceAwareExprMutator::VisitExpr_;

  explicit RewriteVDStorageScopes(const Map<Expr, Map<Expr, Array<String>>>& storage_scope)
      : transform::DeviceAwareExprMutator(Optional<IRModule>()), storage_scope_(storage_scope) {}

  Function Rewrite(const Expr& expr) { return Downcast<Function>(Mutate(expr)); }

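  // Rebind the variable to a VirtualDevice that carries the memory scope chosen by the
  // analysis; variables mapped to plain "global" keep their original VirtualDevice.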
  Expr VisitExpr_(const VarNode* vn) final {
    if (storage_scope_.find(GetRef<Expr>(vn)) != storage_scope_.end() &&
        storage_scope_[GetRef<Expr>(vn)].find(Expr()) != storage_scope_[GetRef<Expr>(vn)].end() &&
        storage_scope_[GetRef<Expr>(vn)][Expr()][0] != "global") {
      Var c = Var(vn->vid, vn->type_annotation, vn->span);
      auto virtual_device = GetVirtualDevice(GetRef<Expr>(vn));
      c->virtual_device_ =
          VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                        virtual_device->target, storage_scope_[GetRef<Expr>(vn)][Expr()][0]);
      return std::move(c);
    }
    return GetRef<Var>(vn);
  }

  Expr VisitExpr_(const ConstantNode* vn) final {
    if (storage_scope_.find(GetRef<Expr>(vn)) != storage_scope_.end() &&
        storage_scope_[GetRef<Expr>(vn)].find(Expr()) != storage_scope_[GetRef<Expr>(vn)].end()) {
      Expr c = Constant(vn->data, vn->span);
      auto virtual_device = GetVirtualDevice(GetRef<Expr>(vn));
      c = OnDevice(
          c,
          VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                        virtual_device->target, storage_scope_[GetRef<Expr>(vn)][Expr()][0]),
          true);
      return c;
    }
    return GetRef<Constant>(vn);
  }

  Expr DeviceAwareVisitExpr_(const CallNode* call_node) final {
    // we need to duplicate ExprMutator::VisitExpr_ here in order to correct the argument
    // scopes and to insert device_copy where required
    auto new_op = this->Mutate(call_node->op);

    tvm::Array<Type> ty_args;
    ty_args.reserve(call_node->type_args.size());

    for (auto ty_arg : call_node->type_args) {
      auto new_ty_arg = this->VisitType(ty_arg);
      ty_args.push_back(new_ty_arg);
    }

    tvm::Array<Expr> call_args;
    call_args.reserve(call_node->args.size());
    for (auto arg : call_node->args) {
      auto new_arg = this->Mutate(arg);
      // check whether we need to insert a device_copy for this argument
      if (storage_scope_.count(arg) && storage_scope_[arg].count(GetRef<Expr>(call_node))) {
        auto virtual_device = GetVirtualDevice(GetRef<Expr>(call_node));
        VirtualDevice virtual_device_from =
            VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                          virtual_device->target, virtual_device->memory_scope);
        VirtualDevice virtual_device_to =
            VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                          virtual_device->target, storage_scope_[arg][GetRef<Expr>(call_node)][0]);
        new_arg = DeviceCopy(new_arg, virtual_device_from, virtual_device_to);
        new_arg = OnDevice(
            new_arg,
            VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                          virtual_device->target, storage_scope_[arg][GetRef<Expr>(call_node)][0]),
            true);
      }
      call_args.push_back(new_arg);
    }

    auto new_call = WithFields(GetRef<Call>(call_node), new_op, call_args, {}, ty_args);

    auto virtual_device = GetVirtualDevice(GetRef<Expr>(call_node));
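    // Pick the scope for this call: the analysis result takes priority, then any scope
    // already present on the VirtualDevice; plain ops with neither stay unannotated.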
    std::string memory_scope = "";
    if (storage_scope_.find(GetRef<Expr>(call_node)) != storage_scope_.end() &&
        storage_scope_[GetRef<Expr>(call_node)].find(Expr()) !=
            storage_scope_[GetRef<Expr>(call_node)].end()) {
      memory_scope = storage_scope_[GetRef<Expr>(call_node)][Expr()][0];
    } else if (virtual_device->memory_scope != "") {
      memory_scope = virtual_device->memory_scope;
    } else if (!call_node->op.as<FunctionNode>()) {
      memory_scope = "";
    }
    if (!memory_scope.empty()) {
      new_call =
          OnDevice(new_call,
                   VirtualDevice(virtual_device->device_type(), virtual_device->virtual_device_id,
                                 virtual_device->target, memory_scope),
                   true);
    }
    return std::move(new_call);
  }

 private:
  Map<Expr, Map<Expr, Array<String>>> storage_scope_;
  VarMap new_vars_;
  Array<String> current_function_scope_;
};

Map<Expr, Map<Expr, Array<String>>> CollectTextureStorage(const Expr& expr) {
  return StorageInfo::GetStorageMap(expr);
}

/**
 * @brief Collects all target devices participating in the graph
 */
class CollectVirtualDevices : public transform::DeviceAwareExprVisitor {
 public:
  CollectVirtualDevices() : transform::DeviceAwareExprVisitor(Optional<IRModule>()) {}
  /**
   * @brief Get all unique devices from the target of each VirtualDevice
   *
   * @param expr the IR to scan
   * @return set of devices
   */
  std::set<std::string> GetDevices(const Expr& expr) {
    this->Run(expr);
    return std::move(devices_);
  }

  void Visit(const Expr& expr) {
    // Pre-order traversal to enable upward propagation
    // of consumer storage scopes to producers when desirable.
    if (const auto* fn = expr.as<FunctionNode>()) {
      this->VisitExpr(fn->body);
      for (const auto& param : fn->params) {
        this->VisitExpr(param);
      }
    } else {
      this->VisitExpr(expr);
    }
  }

  void DeviceAwareVisitExpr_(const CallNode* call) final {
    auto vd = GetVirtualDevice(GetRef<Expr>(call));
    if (vd != VirtualDevice::FullyUnconstrained()) {
      if (Optional<String> t_device = vd->target->GetAttr<String>("device")) {
        devices_.insert(vd->target->kind->name + "." + t_device.value());
      }
    }
    for (auto& arg : call->args) {
      Visit(arg);
    }
  }

  void Run(const Expr& expr) { VisitExpr(expr); }
  using transform::DeviceAwareExprVisitor::VisitExpr_;
  std::set<std::string> devices_;
};

/*!
 * \brief Collect the target specific tensor storage info for each expression's output.
 * \param expr The expression.
 * \return The device based storage mapping.
 */
Map<Expr, Map<Expr, Array<String>>> CollectStorageInfo(const Expr& expr) {
  std::set<std::string> device_types = CollectVirtualDevices().GetDevices(expr);
  // TODO(amalyshe): the current approach collects all targets within the graph and calls the
  // single function corresponding to all of these targets joined in alphabetical order.
  // This works reliably only in the single-device case and should be redesigned to handle
  // the common case
  std::string ftarget_prefix = "relay.backend";
  for (auto& dev_id : device_types) {
    ftarget_prefix += (std::string(".") + dev_id);
  }

  Map<Expr, Map<Expr, Array<String>>> storage_info = {};
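  // For a single OpenCL Adreno device the lookup key becomes
  // "relay.backend.opencl.adreno._CollectStorageInfo", matching the registration at the
  // bottom of this file.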
  if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectStorageInfo")) {
    storage_info = (*f)(expr);
  }
  return storage_info;
}

Expr AnnotateMemoryScopeExpr(const Expr& expr, const IRModule& mod) {
  auto storage_scope = CollectStorageInfo(expr);
  if (storage_scope.size()) {
    return RewriteVDStorageScopes(storage_scope).Rewrite(expr);
  } else {
    return expr;
  }
}

namespace transform {
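/*! \brief Annotates memory scopes (e.g. "global.texture") on functions using the
 * target-specific storage info collected above; registered at opt level 2. */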
tvm::transform::Pass AnnotateMemoryScope() {
  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
      [](Function f, IRModule m, PassContext pc) {
        return Downcast<Function>(AnnotateMemoryScopeExpr(f, m));
      };
  return CreateFunctionPass(pass_func, 2, "AnnotateMemoryScope", {});
}
}  // namespace transform

TVM_REGISTER_GLOBAL("relay.backend.opencl.adreno._CollectStorageInfo")
    .set_body_typed(CollectTextureStorage);

}  // namespace relay
}  // namespace tvm