1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tir/analysis/apply_device_constraints.cc
22 * \brief Applies device-related constraints to \p PrimFunc parameters.
23 *
24 * This is used by the \p PlanDevices pass to flow device-constraints *into* \p PrimFuncs.
25 *
 * Currently this only applies memory scope constraints to the storage scopes of \p Buffer
 * data pointers. Aliased ('matched') buffers take on any scope introduced on the buffer
 * they alias. However, no attempt is currently made to flow constraints into allocated
 * buffers.
30 */
31
32#include "./device_constraint_utils.h"
33
34#include <tvm/relay/attrs/memory.h>
35#include <tvm/target/virtual_device.h>
36#include <tvm/tir/function.h>
37#include <tvm/tir/stmt_functor.h>
38
39namespace tvm {
40namespace tir {
41namespace {
42
43/*!
44 * \brief Returns the \p PointerTypeNode for \p buffer, or nullptr if \p buffer does not describe a
45 * pointer.
46 */
47const PointerTypeNode* PointerInBuffer(const tir::Buffer& buffer) {
48 return buffer->data->type_annotation.defined()
49 ? buffer->data->type_annotation.as<PointerTypeNode>()
50 : nullptr;
51}
52
53/*!
54 * \brief Returns the parameter variable and corresponding buffer at or after \p
55 * *current_primfunc_param_index in \p prim_func. Will skip over any non-pointer parameters. This
56 * can be used to find the parameter matching a tensor type in a flattened Relay function parameter
57 * or result.
58 */
59std::pair<tir::Var, tir::Buffer> FindPointerParam(const tir::PrimFunc& prim_func,
60 size_t* current_primfunc_param_index) {
61 while (true) {
62 ICHECK_LT(*current_primfunc_param_index, prim_func->params.size());
63 const tir::Var& param = prim_func->params[*current_primfunc_param_index];
64 auto itr = prim_func->buffer_map.find(param);
65 if (itr == prim_func->buffer_map.end()) {
66 VLOG(2) << "no buffer map entry for '" << param->name_hint << "'";
67 ++*current_primfunc_param_index;
68 continue;
69 }
70 const auto* pointer_type_node = PointerInBuffer((*itr).second);
71 if (pointer_type_node == nullptr) {
72 VLOG(2) << "not a pointer type for '" << param->name_hint << "'";
73 ++*current_primfunc_param_index;
74 continue;
75 }
76 VLOG(2) << "using PrimFunc param '" << param->name_hint << "'";
77 return *itr;
78 }
79}
80
81/*!
82 * \brief Check fails if any parameter at or after \p *current_primfunc_param_index in \p prim_func
83 * is for a pointer type. This can be used to check all \p prim_func parameters have been accounted
84 * for when using \p FindPointerParam above.
85 */
86void CheckNoRemainingPointerParams(const tir::PrimFunc& prim_func,
87 size_t* current_primfunc_param_index) {
88 while (*current_primfunc_param_index < prim_func->params.size()) {
89 const tir::Var& param = prim_func->params[*current_primfunc_param_index];
90 auto itr = prim_func->buffer_map.find(param);
91 if (itr == prim_func->buffer_map.end()) {
92 VLOG(1) << "no buffer map entry for '" << param->name_hint << "'";
93 ++*current_primfunc_param_index;
94 continue;
95 }
96 const auto* pointer_type_node = PointerInBuffer((*itr).second);
97 ICHECK(pointer_type_node == nullptr);
98 ++*current_primfunc_param_index;
99 }
100}
101
102/*!
103 * \brief Returns the (consistent) constraint to use for a Relay parameter of \p type,
104 * using \p prim_func parameters at or after \p *current_primfunc_param_index. Currently
105 * only memory scope is extracted. Fails if constraints are not consistent, ie \p type is a tuple
106 * type and the \p prim_func is attempting to map different fields of that tuple to different memory
107 * scopes. Returns the fully unconstrained \p VirtualDevice if no memory scopes constraints arise
108 * from the \p prim_func, ie all storage scope strings in pointer types are empty.
109 */
110VirtualDevice ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& type,
111 size_t* current_primfunc_param_index) {
112 std::string memory_scope; // default empty => no constraint
113 for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) {
114 std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func, current_primfunc_param_index);
115 const tir::Buffer& buffer = kv.second;
116 const auto* pointer_type_node = buffer->data->type_annotation.as<PointerTypeNode>();
117 const MemoryScope& buffer_memory_scope = pointer_type_node->storage_scope;
118 if (memory_scope.empty()) {
119 memory_scope = buffer_memory_scope;
120 } else if (buffer_memory_scope.empty()) {
121 // No constraint.
122 } else {
123 // Tuples must be homogenous on their VirtualDevice and thus memory scope.
124 ICHECK_EQ(buffer_memory_scope, memory_scope);
125 }
126 ++*current_primfunc_param_index;
127 }
128 return VirtualDevice::ForMemoryScope(memory_scope);
129}
130
131/*!
132 * \brief Insert into param_constraints an entry for each parameter of \p prim_func starting from
133 * \p *current_primfunc_param_index for the flattened form of a Rleay parameters of \p type. Each
134 * entry maps to \p virtual_device.
135 */
136void InsertParamConstraints(
137 const tir::PrimFunc& prim_func, const Type& type, const VirtualDevice& virtual_device,
138 size_t* current_primfunc_param_index,
139 std::unordered_map<const tir::VarNode*, VirtualDevice>* param_constraints) {
140 for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) {
141 std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func, current_primfunc_param_index);
142 param_constraints->emplace(kv.first.get(), virtual_device);
143 ++*current_primfunc_param_index;
144 }
145}
146
147/*!
148 * \brief Apply the memory scope constraints to the \p Buffers and data \p Vars of a \p PrimFunc.
149 *
150 * All definitional occurrences of buffer Vars are rewritten to capture memory scopes in their
151 * PointerTypes:
152 * - Buffer::data (if the buffer itself is a definitional occurrence)
153 * - AllocateNode::buffer_var
154 * - FUTURE: LetStmtNode::var if aliasing a buffer data var.
155 *
156 * All referential occurrences of buffer Vars are replaced with their new definitions:
157 * - LoadNode::buffer_var
158 * - StoreNode::buffer_var
159 *
160 * Similarly all definitional occurrences of Buffers are rewritten to account for any new memory
161 * scopes:
162 * - PrimFuncNode::buffer_map keys.
163 * - BlockNode::match_buffers.buffer
164 * - FUTURE: BlockNode::alloc_buffers?
165 *
166 * And all referential occurrences of Buffers are replaced with their new definitions:
167 * - BufferLoadNode::buffer
168 * - BufferStoreNode::buffer
169 * - BufferRealizeNode::buffer
170 * - PrefetchNode::buffer
171 * - BufferRegionNode:buffer
172 * - BlockNode.match_buffers.source.buffer
173 * - BlockNode::{reads, writes}.buffer
174 *
175 * CAUTION: We assume strict sharing of Buffer objects and do not attempt to rewrite the bodies
176 * of referential buffers.
177 *
178 * CAUTION: EXPERIMENTAL: We don't yet account for all buffers and pointer types.
179 */
180class ApplyDeviceConstraintsMutator : public StmtExprMutator {
181 public:
182 ApplyDeviceConstraintsMutator() = default;
183
184 /*!
185 * \brief Returns \p prim_func written to capture the memory scope constraints in \p
186 * param_constraints for each pointer \p prim_func parameter. Returns \p prim_func unchanged if no
187 * memory scopes needed to change.
188 */
189 PrimFunc Rewrite(const PrimFunc& prim_func, const FuncType& relay_func_type,
190 const Array<VirtualDevice>& arg_and_result_virtual_devices) {
191 size_t current_primfunc_param_index = 0;
192 std::unordered_map<const tir::VarNode*, VirtualDevice> param_constraints;
193
194 // For each Relay function parameter...
195 for (size_t i = 0; i < relay_func_type->arg_types.size(); ++i) {
196 const Type& param_type = relay_func_type->arg_types[i];
197 const VirtualDevice& param_virtual_device = arg_and_result_virtual_devices[i];
198 InsertParamConstraints(prim_func, param_type, param_virtual_device,
199 &current_primfunc_param_index, &param_constraints);
200 }
201
202 // For the Relay function result...
203 const Type& ret_type = relay_func_type->ret_type;
204 const VirtualDevice& ret_virtual_device = arg_and_result_virtual_devices.back();
205 InsertParamConstraints(prim_func, ret_type, ret_virtual_device, &current_primfunc_param_index,
206 &param_constraints);
207
208 // Make sure we accounted for all prim_func parameters.
209 CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
210
211 // Start with a copy of the current prim_func buffer map.
212 Map<Var, Buffer> new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end());
213 bool any_change = false;
214
215 // For each constrained parameter...
216 for (const auto& kv : param_constraints) {
217 const tir::Var param = GetRef<tir::Var>(kv.first);
218 const VirtualDevice& virtual_device = kv.second;
219 const tir::Buffer& buffer = prim_func->buffer_map[param];
220 // Rewrite the buffer to account for constraint.
221 const Buffer new_buffer = RewriteBuffer(buffer, virtual_device);
222 if (!new_buffer.same_as(buffer)) {
223 any_change = true;
224 }
225 new_buffer_map.Set(param, new_buffer);
226 }
227 // Make sure we have accounted for all prim_func parameters.
228 CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
229
230 // Apply data variable and buffer substitutions to the prim_func body. These will have been
231 // accumulated from processing the parameters above.
232 Stmt new_body = VisitStmt(prim_func->body);
233 if (!new_body.same_as(prim_func->body)) {
234 any_change = true;
235 }
236
237 // We are done with the substitutions.
238 var_subst_.clear();
239 buffer_subst_.clear();
240
241 if (any_change) {
242 return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type,
243 std::move(new_buffer_map), prim_func->attrs, prim_func->span);
244 } else {
245 return prim_func;
246 }
247 }
248
249 private:
250 PrimExpr VisitExpr_(const VarNode* var_node) final { return Subst(var_node); }
251
252 PrimExpr VisitExpr_(const LoadNode* load_node) final {
253 Load new_load = Downcast<Load>(StmtExprMutator::VisitExpr_(load_node));
254 Var new_buffer_var = Subst(new_load->buffer_var.get());
255 if (!new_buffer_var.same_as(new_load->buffer_var)) {
256 return Load(load_node->dtype, new_buffer_var, load_node->index, load_node->predicate);
257 }
258 return std::move(new_load);
259 }
260
261 PrimExpr VisitExpr_(const BufferLoadNode* buffer_load_node) final {
262 BufferLoad new_buffer_load =
263 Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(buffer_load_node));
264 Buffer new_buffer = Subst(new_buffer_load->buffer.get());
265 if (!new_buffer.same_as(new_buffer_load->buffer)) {
266 return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span);
267 }
268 return std::move(new_buffer_load);
269 }
270
271 Stmt VisitStmt_(const LetStmtNode* let_stmt_node) final {
272 // TODO(mbs): If the let-bound var is aliasing an existing buffer data var we need to
273 // rewrite it.
274 return StmtExprMutator::VisitStmt_(let_stmt_node);
275 }
276
277 Stmt VisitStmt_(const AttrStmtNode* attr_stmt_node) final {
278 AttrStmt new_attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(attr_stmt_node));
279 // remap node if a var
280 if (const auto* var_node = new_attr_stmt->node.as<VarNode>()) {
281 Var new_var = Subst(var_node);
282 if (!new_var.same_as(new_attr_stmt->node)) {
283 return AttrStmt(new_var, new_attr_stmt->attr_key, new_attr_stmt->value,
284 new_attr_stmt->body);
285 }
286 }
287 return std::move(new_attr_stmt);
288 }
289
290 // ForNode default ok since loop_var never of PointerType
291
292 // WhileNode default ok
293
294 Stmt VisitStmt_(const AllocateNode* allocate_node) final {
295 // TODO(mbs): What memory scope should we assign to the new pointer?
296 return StmtExprMutator::VisitStmt_(allocate_node);
297 }
298
299 Stmt VisitStmt_(const StoreNode* store_node) final {
300 Store new_store = Downcast<Store>(StmtExprMutator::VisitStmt_(store_node));
301 Var new_buffer_var = Subst(new_store->buffer_var.get());
302 if (!new_buffer_var.same_as(new_store->buffer_var)) {
303 Store(new_buffer_var, new_store->value, new_store->index, new_store->predicate);
304 }
305 return std::move(new_store);
306 }
307
308 Stmt VisitStmt_(const BufferStoreNode* buffer_store_node) final {
309 BufferStore new_buffer_store =
310 Downcast<BufferStore>(StmtExprMutator::VisitStmt_(buffer_store_node));
311 Buffer new_buffer = Subst(new_buffer_store->buffer.get());
312 if (!new_buffer.same_as(new_buffer_store->buffer)) {
313 return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices,
314 new_buffer_store->span);
315 }
316 return std::move(new_buffer_store);
317 }
318
319 Stmt VisitStmt_(const BufferRealizeNode* buffer_realize_node) final {
320 BufferRealize new_buffer_realize =
321 Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(buffer_realize_node));
322 Buffer new_buffer = Subst(new_buffer_realize->buffer.get());
323 if (!new_buffer.same_as(new_buffer_realize->buffer)) {
324 return BufferRealize(new_buffer, new_buffer_realize->bounds, new_buffer_realize->condition,
325 new_buffer_realize->body, new_buffer_realize->span);
326 }
327 return std::move(new_buffer_realize);
328 }
329
330 // IfThenElseNode default ok
331 // AssertStmtNode default ok
332 // ProducerStoreNode default ok (though does not visit producer)
333 // ProducerRealizeNode default ok (though does not visit producer)
334
335 Stmt VisitStmt_(const PrefetchNode* prefetch_node) final {
336 Prefetch new_prefetch = Downcast<Prefetch>(StmtExprMutator::VisitStmt_(prefetch_node));
337 Buffer new_buffer = Subst(new_prefetch->buffer.get());
338 if (!new_buffer.same_as(new_prefetch->buffer)) {
339 return Prefetch(new_buffer, prefetch_node->bounds, prefetch_node->span);
340 }
341 return std::move(new_prefetch);
342 }
343
344 // SeqStmtNode default ok
345 // EvaluateNode default ok
346
347 BufferRegion VisitItem(const BufferRegionNode* buffer_region_node) {
348 Buffer new_buffer = Subst(buffer_region_node->buffer.get());
349 if (!new_buffer.same_as(buffer_region_node->buffer)) {
350 return BufferRegion(new_buffer, buffer_region_node->region);
351 }
352 return GetRef<BufferRegion>(buffer_region_node);
353 }
354
355 MatchBufferRegion VisitItem(const MatchBufferRegionNode* match_buffer_region_node) {
356 // The source field has a referential occurrence of the buffer. Apply the buffer substitution
357 // to that.
358 BufferRegion new_source = VisitItem(match_buffer_region_node->source.get());
359 // The buffer field however is a definitional occurrence, aliased on top of the source.
360 // Transfer any memory scope from the source to the destination.
361 Optional<VirtualDevice> opt_virtual_device = GetBufferConstraint(new_source->buffer);
362 tir::Buffer new_buffer;
363 if (opt_virtual_device.defined()) {
364 new_buffer = RewriteBuffer(match_buffer_region_node->buffer, opt_virtual_device.value());
365 } else {
366 new_buffer = match_buffer_region_node->buffer;
367 }
368 if (!new_buffer.same_as(match_buffer_region_node->buffer) ||
369 !new_source.same_as(match_buffer_region_node->source)) {
370 return MatchBufferRegion(new_buffer, new_source);
371 }
372 return GetRef<MatchBufferRegion>(match_buffer_region_node);
373 }
374
375 template <typename T>
376 Array<T> VisitItems(const Array<T>& items) {
377 return items.Map([this](T item) -> T { return VisitItem(item.get()); });
378 }
379
380 Stmt VisitStmt_(const BlockNode* block_node) final {
381 Block new_block = Downcast<Block>(StmtExprMutator::VisitStmt_(block_node));
382 Array<BufferRegion> new_reads = VisitItems(new_block->reads);
383 Array<BufferRegion> new_writes = VisitItems(new_block->writes);
384 // TODO(mbs): What memory scope should we assign to the new buffers?
385 Array<MatchBufferRegion> new_match_buffers = VisitItems(new_block->match_buffers);
386 if (!new_reads.same_as(new_block->reads) || new_writes.same_as(new_block->writes) ||
387 new_match_buffers.same_as(new_block->match_buffers)) {
388 return Block(new_block->iter_vars, std::move(new_reads), std::move(new_writes),
389 new_block->name_hint, new_block->body, new_block->init, new_block->alloc_buffers,
390 std::move(new_match_buffers), new_block->annotations, new_block->span);
391 }
392 return std::move(new_block);
393 }
394
395 // BlockRealizeNode default ok
396
397 /*! Applies \p var_subst_ substitution to \p var_node. */
398 Var Subst(const VarNode* var_node) const {
399 auto itr = var_subst_.find(var_node);
400 return itr == var_subst_.end() ? GetRef<Var>(var_node) : itr->second;
401 }
402
403 /*! Applies \p buffer_subst_ substitution to \p buffer. */
404 Buffer Subst(const BufferNode* buffer_node) const {
405 auto itr = buffer_subst_.find(buffer_node);
406 return itr == buffer_subst_.end() ? GetRef<Buffer>(buffer_node) : itr->second;
407 }
408
409 /*!
410 * \brief Rewrites \p buffer so as to follow the constraints in \p virtual_device
411 * (currently just memory scope).
412 *
413 * Updates both the var_subst_ and buffer_subst_ to capture the rewrite, but
414 * also returns the new buffer.
415 */
416 Buffer RewriteBuffer(const Buffer& buffer, const VirtualDevice& virtual_device) {
417 ICHECK(buffer->data->type_annotation.defined());
418 const auto* pointer_type_node = buffer->data->type_annotation.as<PointerTypeNode>();
419 ICHECK(pointer_type_node);
420 if (pointer_type_node->storage_scope == virtual_device->memory_scope) {
421 // No change.
422 return buffer;
423 }
424 PointerType new_pointer_type(pointer_type_node->element_type, virtual_device->memory_scope);
425 Var new_data(buffer->data->name_hint, new_pointer_type, buffer->data->span);
426 var_subst_.emplace(buffer->data.get(), new_data);
427 Buffer new_buffer = buffer;
428 new_buffer.CopyOnWrite()->data = new_data;
429 buffer_subst_.emplace(buffer.get(), new_buffer);
430 return new_buffer;
431 }
432
433 /*!
434 * \brief Returns the VirtualDevice capturing any memory scope in \p buffer. Returns nullptr if
435 * buffer's data var does not have a type annotation of \p PointerType. Returns the fully
436 * unconstrained \p VirtualDevice if no memory scope is given.
437 */
438 static Optional<VirtualDevice> GetBufferConstraint(const tir::Buffer& buffer) {
439 const auto* pointer_type_node = PointerInBuffer(buffer);
440 return pointer_type_node == nullptr
441 ? Optional<VirtualDevice>()
442 : VirtualDevice::ForMemoryScope(pointer_type_node->storage_scope);
443 }
444
445 /*!
446 * \brief Maps each \p Buffer::data \p Var to its constrained equivalent.
447 */
448 std::unordered_map<const VarNode*, Var> var_subst_;
449
450 /*!
451 * \brief Maps each \p Buffer to its constrained equivalent.
452 */
453 std::unordered_map<const BufferNode*, Buffer> buffer_subst_;
454};
455
456} // namespace
457
458Array<VirtualDevice> GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func,
459 const FuncType& relay_func_type) {
460 // Build the implied domain (in terms of the function's Relay type) implied by any memory scope
461 // constrains in the function's buffers, for both arguments and results.
462 Array<VirtualDevice> virtual_devices;
463 virtual_devices.reserve(relay_func_type->arg_types.size() + 1);
464
465 // For each Relay function parameter...
466 size_t current_primfunc_param_index = 0;
467 for (const auto& param_type : relay_func_type->arg_types) {
468 VirtualDevice param_virtual_device =
469 ConsistentParamConstraint(prim_func, param_type, &current_primfunc_param_index);
470 virtual_devices.push_back(param_virtual_device);
471 }
472
473 // For the Relay function result...
474 const Type& ret_type = relay_func_type->ret_type;
475 VirtualDevice ret_virtual_device =
476 ConsistentParamConstraint(prim_func, ret_type, &current_primfunc_param_index);
477 virtual_devices.push_back(ret_virtual_device);
478
479 // Make sure all parameters of the prim_func have been accounted for.
480 CheckNoRemainingPointerParams(prim_func, &current_primfunc_param_index);
481
482 return virtual_devices;
483}
484
485TVM_REGISTER_GLOBAL("tir.analysis.GetPrimFuncArgAndResultMemoryConstraints")
486 .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type) {
487 Array<String> memory_scopes;
488 memory_scopes.reserve(relay_func_type->type_params.size() + 1);
489 for (const auto& virtual_device :
490 GetPrimFuncArgAndResultConstraints(prim_func, relay_func_type)) {
491 memory_scopes.push_back(virtual_device->memory_scope);
492 }
493 return memory_scopes;
494 });
495
496PrimFunc ApplyPrimFuncArgAndResultConstraints(
497 const PrimFunc& prim_func, const FuncType& relay_func_type,
498 const Array<VirtualDevice>& arg_and_result_virtual_devices) {
499 return ApplyDeviceConstraintsMutator().Rewrite(prim_func, relay_func_type,
500 arg_and_result_virtual_devices);
501}
502
503TVM_REGISTER_GLOBAL("tir.analysis.ApplyPrimFuncArgAndResultMemoryConstraints")
504 .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type,
505 const Array<String>& arg_and_result_memory_scopes) {
506 Array<VirtualDevice> virtual_devices;
507 virtual_devices.reserve(arg_and_result_memory_scopes.size());
508 for (const auto& memory_scope : arg_and_result_memory_scopes) {
509 virtual_devices.push_back(VirtualDevice::ForMemoryScope(memory_scope));
510 }
511 return ApplyPrimFuncArgAndResultConstraints(prim_func, relay_func_type, virtual_devices);
512 });
513
514} // namespace tir
515} // namespace tvm
516