1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tir/analysis/apply_device_constraints.cc |
22 | * \brief Applies device-related constraints to \p PrimFunc parameters. |
23 | * |
24 | * This is used by the \p PlanDevices pass to flow device-constraints *into* \p PrimFuncs. |
25 | * |
26 | * Currently only applies memory scope constraints into \p Buffer data pointer |
27 | * storage scopes. Aliased ('matched') buffers take on any scope introduced on |
28 | * the buffer they alias. However currently does not attempt to flow constraints into |
29 | * allocated buffers. |
30 | */ |
31 | |
32 | #include "./device_constraint_utils.h" |
33 | |
34 | #include <tvm/relay/attrs/memory.h> |
35 | #include <tvm/target/virtual_device.h> |
36 | #include <tvm/tir/function.h> |
37 | #include <tvm/tir/stmt_functor.h> |
38 | |
39 | namespace tvm { |
40 | namespace tir { |
41 | namespace { |
42 | |
43 | /*! |
44 | * \brief Returns the \p PointerTypeNode for \p buffer, or nullptr if \p buffer does not describe a |
45 | * pointer. |
46 | */ |
47 | const PointerTypeNode* PointerInBuffer(const tir::Buffer& buffer) { |
48 | return buffer->data->type_annotation.defined() |
49 | ? buffer->data->type_annotation.as<PointerTypeNode>() |
50 | : nullptr; |
51 | } |
52 | |
53 | /*! |
54 | * \brief Returns the parameter variable and corresponding buffer at or after \p |
55 | * *current_primfunc_param_index in \p prim_func. Will skip over any non-pointer parameters. This |
56 | * can be used to find the parameter matching a tensor type in a flattened Relay function parameter |
57 | * or result. |
58 | */ |
59 | std::pair<tir::Var, tir::Buffer> FindPointerParam(const tir::PrimFunc& prim_func, |
60 | size_t* current_primfunc_param_index) { |
61 | while (true) { |
62 | ICHECK_LT(*current_primfunc_param_index, prim_func->params.size()); |
63 | const tir::Var& param = prim_func->params[*current_primfunc_param_index]; |
64 | auto itr = prim_func->buffer_map.find(param); |
65 | if (itr == prim_func->buffer_map.end()) { |
66 | VLOG(2) << "no buffer map entry for '" << param->name_hint << "'" ; |
67 | ++*current_primfunc_param_index; |
68 | continue; |
69 | } |
70 | const auto* pointer_type_node = PointerInBuffer((*itr).second); |
71 | if (pointer_type_node == nullptr) { |
72 | VLOG(2) << "not a pointer type for '" << param->name_hint << "'" ; |
73 | ++*current_primfunc_param_index; |
74 | continue; |
75 | } |
76 | VLOG(2) << "using PrimFunc param '" << param->name_hint << "'" ; |
77 | return *itr; |
78 | } |
79 | } |
80 | |
81 | /*! |
82 | * \brief Check fails if any parameter at or after \p *current_primfunc_param_index in \p prim_func |
83 | * is for a pointer type. This can be used to check all \p prim_func parameters have been accounted |
84 | * for when using \p FindPointerParam above. |
85 | */ |
86 | void CheckNoRemainingPointerParams(const tir::PrimFunc& prim_func, |
87 | size_t* current_primfunc_param_index) { |
88 | while (*current_primfunc_param_index < prim_func->params.size()) { |
89 | const tir::Var& param = prim_func->params[*current_primfunc_param_index]; |
90 | auto itr = prim_func->buffer_map.find(param); |
91 | if (itr == prim_func->buffer_map.end()) { |
92 | VLOG(1) << "no buffer map entry for '" << param->name_hint << "'" ; |
93 | ++*current_primfunc_param_index; |
94 | continue; |
95 | } |
96 | const auto* pointer_type_node = PointerInBuffer((*itr).second); |
97 | ICHECK(pointer_type_node == nullptr); |
98 | ++*current_primfunc_param_index; |
99 | } |
100 | } |
101 | |
102 | /*! |
103 | * \brief Returns the (consistent) constraint to use for a Relay parameter of \p type, |
104 | * using \p prim_func parameters at or after \p *current_primfunc_param_index. Currently |
105 | * only memory scope is extracted. Fails if constraints are not consistent, ie \p type is a tuple |
106 | * type and the \p prim_func is attempting to map different fields of that tuple to different memory |
107 | * scopes. Returns the fully unconstrained \p VirtualDevice if no memory scopes constraints arise |
108 | * from the \p prim_func, ie all storage scope strings in pointer types are empty. |
109 | */ |
110 | VirtualDevice ConsistentParamConstraint(const tir::PrimFunc& prim_func, const Type& type, |
111 | size_t* current_primfunc_param_index) { |
112 | std::string memory_scope; // default empty => no constraint |
113 | for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) { |
114 | std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func, current_primfunc_param_index); |
115 | const tir::Buffer& buffer = kv.second; |
116 | const auto* pointer_type_node = buffer->data->type_annotation.as<PointerTypeNode>(); |
117 | const MemoryScope& buffer_memory_scope = pointer_type_node->storage_scope; |
118 | if (memory_scope.empty()) { |
119 | memory_scope = buffer_memory_scope; |
120 | } else if (buffer_memory_scope.empty()) { |
121 | // No constraint. |
122 | } else { |
123 | // Tuples must be homogenous on their VirtualDevice and thus memory scope. |
124 | ICHECK_EQ(buffer_memory_scope, memory_scope); |
125 | } |
126 | ++*current_primfunc_param_index; |
127 | } |
128 | return VirtualDevice::ForMemoryScope(memory_scope); |
129 | } |
130 | |
131 | /*! |
132 | * \brief Insert into param_constraints an entry for each parameter of \p prim_func starting from |
133 | * \p *current_primfunc_param_index for the flattened form of a Rleay parameters of \p type. Each |
134 | * entry maps to \p virtual_device. |
135 | */ |
136 | void InsertParamConstraints( |
137 | const tir::PrimFunc& prim_func, const Type& type, const VirtualDevice& virtual_device, |
138 | size_t* current_primfunc_param_index, |
139 | std::unordered_map<const tir::VarNode*, VirtualDevice>* param_constraints) { |
140 | for (size_t i = 0; i < relay::FlattenTupleType(type).size(); ++i) { |
141 | std::pair<tir::Var, tir::Buffer> kv = FindPointerParam(prim_func, current_primfunc_param_index); |
142 | param_constraints->emplace(kv.first.get(), virtual_device); |
143 | ++*current_primfunc_param_index; |
144 | } |
145 | } |
146 | |
147 | /*! |
148 | * \brief Apply the memory scope constraints to the \p Buffers and data \p Vars of a \p PrimFunc. |
149 | * |
150 | * All definitional occurrences of buffer Vars are rewritten to capture memory scopes in their |
151 | * PointerTypes: |
152 | * - Buffer::data (if the buffer itself is a definitional occurrence) |
153 | * - AllocateNode::buffer_var |
154 | * - FUTURE: LetStmtNode::var if aliasing a buffer data var. |
155 | * |
156 | * All referential occurrences of buffer Vars are replaced with their new definitions: |
157 | * - LoadNode::buffer_var |
158 | * - StoreNode::buffer_var |
159 | * |
160 | * Similarly all definitional occurrences of Buffers are rewritten to account for any new memory |
161 | * scopes: |
162 | * - PrimFuncNode::buffer_map keys. |
163 | * - BlockNode::match_buffers.buffer |
164 | * - FUTURE: BlockNode::alloc_buffers? |
165 | * |
166 | * And all referential occurrences of Buffers are replaced with their new definitions: |
167 | * - BufferLoadNode::buffer |
168 | * - BufferStoreNode::buffer |
169 | * - BufferRealizeNode::buffer |
170 | * - PrefetchNode::buffer |
171 | * - BufferRegionNode:buffer |
172 | * - BlockNode.match_buffers.source.buffer |
173 | * - BlockNode::{reads, writes}.buffer |
174 | * |
175 | * CAUTION: We assume strict sharing of Buffer objects and do not attempt to rewrite the bodies |
176 | * of referential buffers. |
177 | * |
178 | * CAUTION: EXPERIMENTAL: We don't yet account for all buffers and pointer types. |
179 | */ |
180 | class ApplyDeviceConstraintsMutator : public StmtExprMutator { |
181 | public: |
182 | ApplyDeviceConstraintsMutator() = default; |
183 | |
184 | /*! |
185 | * \brief Returns \p prim_func written to capture the memory scope constraints in \p |
186 | * param_constraints for each pointer \p prim_func parameter. Returns \p prim_func unchanged if no |
187 | * memory scopes needed to change. |
188 | */ |
189 | PrimFunc Rewrite(const PrimFunc& prim_func, const FuncType& relay_func_type, |
190 | const Array<VirtualDevice>& arg_and_result_virtual_devices) { |
191 | size_t current_primfunc_param_index = 0; |
192 | std::unordered_map<const tir::VarNode*, VirtualDevice> param_constraints; |
193 | |
194 | // For each Relay function parameter... |
195 | for (size_t i = 0; i < relay_func_type->arg_types.size(); ++i) { |
196 | const Type& param_type = relay_func_type->arg_types[i]; |
197 | const VirtualDevice& param_virtual_device = arg_and_result_virtual_devices[i]; |
198 | InsertParamConstraints(prim_func, param_type, param_virtual_device, |
199 | ¤t_primfunc_param_index, ¶m_constraints); |
200 | } |
201 | |
202 | // For the Relay function result... |
203 | const Type& ret_type = relay_func_type->ret_type; |
204 | const VirtualDevice& ret_virtual_device = arg_and_result_virtual_devices.back(); |
205 | InsertParamConstraints(prim_func, ret_type, ret_virtual_device, ¤t_primfunc_param_index, |
206 | ¶m_constraints); |
207 | |
208 | // Make sure we accounted for all prim_func parameters. |
209 | CheckNoRemainingPointerParams(prim_func, ¤t_primfunc_param_index); |
210 | |
211 | // Start with a copy of the current prim_func buffer map. |
212 | Map<Var, Buffer> new_buffer_map(prim_func->buffer_map.begin(), prim_func->buffer_map.end()); |
213 | bool any_change = false; |
214 | |
215 | // For each constrained parameter... |
216 | for (const auto& kv : param_constraints) { |
217 | const tir::Var param = GetRef<tir::Var>(kv.first); |
218 | const VirtualDevice& virtual_device = kv.second; |
219 | const tir::Buffer& buffer = prim_func->buffer_map[param]; |
220 | // Rewrite the buffer to account for constraint. |
221 | const Buffer new_buffer = RewriteBuffer(buffer, virtual_device); |
222 | if (!new_buffer.same_as(buffer)) { |
223 | any_change = true; |
224 | } |
225 | new_buffer_map.Set(param, new_buffer); |
226 | } |
227 | // Make sure we have accounted for all prim_func parameters. |
228 | CheckNoRemainingPointerParams(prim_func, ¤t_primfunc_param_index); |
229 | |
230 | // Apply data variable and buffer substitutions to the prim_func body. These will have been |
231 | // accumulated from processing the parameters above. |
232 | Stmt new_body = VisitStmt(prim_func->body); |
233 | if (!new_body.same_as(prim_func->body)) { |
234 | any_change = true; |
235 | } |
236 | |
237 | // We are done with the substitutions. |
238 | var_subst_.clear(); |
239 | buffer_subst_.clear(); |
240 | |
241 | if (any_change) { |
242 | return PrimFunc(prim_func->params, std::move(new_body), prim_func->ret_type, |
243 | std::move(new_buffer_map), prim_func->attrs, prim_func->span); |
244 | } else { |
245 | return prim_func; |
246 | } |
247 | } |
248 | |
249 | private: |
250 | PrimExpr VisitExpr_(const VarNode* var_node) final { return Subst(var_node); } |
251 | |
252 | PrimExpr VisitExpr_(const LoadNode* load_node) final { |
253 | Load new_load = Downcast<Load>(StmtExprMutator::VisitExpr_(load_node)); |
254 | Var new_buffer_var = Subst(new_load->buffer_var.get()); |
255 | if (!new_buffer_var.same_as(new_load->buffer_var)) { |
256 | return Load(load_node->dtype, new_buffer_var, load_node->index, load_node->predicate); |
257 | } |
258 | return std::move(new_load); |
259 | } |
260 | |
261 | PrimExpr VisitExpr_(const BufferLoadNode* buffer_load_node) final { |
262 | BufferLoad new_buffer_load = |
263 | Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(buffer_load_node)); |
264 | Buffer new_buffer = Subst(new_buffer_load->buffer.get()); |
265 | if (!new_buffer.same_as(new_buffer_load->buffer)) { |
266 | return BufferLoad(new_buffer, new_buffer_load->indices, new_buffer_load->span); |
267 | } |
268 | return std::move(new_buffer_load); |
269 | } |
270 | |
271 | Stmt VisitStmt_(const LetStmtNode* let_stmt_node) final { |
272 | // TODO(mbs): If the let-bound var is aliasing an existing buffer data var we need to |
273 | // rewrite it. |
274 | return StmtExprMutator::VisitStmt_(let_stmt_node); |
275 | } |
276 | |
277 | Stmt VisitStmt_(const AttrStmtNode* attr_stmt_node) final { |
278 | AttrStmt new_attr_stmt = Downcast<AttrStmt>(StmtExprMutator::VisitStmt_(attr_stmt_node)); |
279 | // remap node if a var |
280 | if (const auto* var_node = new_attr_stmt->node.as<VarNode>()) { |
281 | Var new_var = Subst(var_node); |
282 | if (!new_var.same_as(new_attr_stmt->node)) { |
283 | return AttrStmt(new_var, new_attr_stmt->attr_key, new_attr_stmt->value, |
284 | new_attr_stmt->body); |
285 | } |
286 | } |
287 | return std::move(new_attr_stmt); |
288 | } |
289 | |
290 | // ForNode default ok since loop_var never of PointerType |
291 | |
292 | // WhileNode default ok |
293 | |
294 | Stmt VisitStmt_(const AllocateNode* allocate_node) final { |
295 | // TODO(mbs): What memory scope should we assign to the new pointer? |
296 | return StmtExprMutator::VisitStmt_(allocate_node); |
297 | } |
298 | |
299 | Stmt VisitStmt_(const StoreNode* store_node) final { |
300 | Store new_store = Downcast<Store>(StmtExprMutator::VisitStmt_(store_node)); |
301 | Var new_buffer_var = Subst(new_store->buffer_var.get()); |
302 | if (!new_buffer_var.same_as(new_store->buffer_var)) { |
303 | Store(new_buffer_var, new_store->value, new_store->index, new_store->predicate); |
304 | } |
305 | return std::move(new_store); |
306 | } |
307 | |
308 | Stmt VisitStmt_(const BufferStoreNode* buffer_store_node) final { |
309 | BufferStore new_buffer_store = |
310 | Downcast<BufferStore>(StmtExprMutator::VisitStmt_(buffer_store_node)); |
311 | Buffer new_buffer = Subst(new_buffer_store->buffer.get()); |
312 | if (!new_buffer.same_as(new_buffer_store->buffer)) { |
313 | return BufferStore(new_buffer, new_buffer_store->value, new_buffer_store->indices, |
314 | new_buffer_store->span); |
315 | } |
316 | return std::move(new_buffer_store); |
317 | } |
318 | |
319 | Stmt VisitStmt_(const BufferRealizeNode* buffer_realize_node) final { |
320 | BufferRealize new_buffer_realize = |
321 | Downcast<BufferRealize>(StmtExprMutator::VisitStmt_(buffer_realize_node)); |
322 | Buffer new_buffer = Subst(new_buffer_realize->buffer.get()); |
323 | if (!new_buffer.same_as(new_buffer_realize->buffer)) { |
324 | return BufferRealize(new_buffer, new_buffer_realize->bounds, new_buffer_realize->condition, |
325 | new_buffer_realize->body, new_buffer_realize->span); |
326 | } |
327 | return std::move(new_buffer_realize); |
328 | } |
329 | |
330 | // IfThenElseNode default ok |
331 | // AssertStmtNode default ok |
332 | // ProducerStoreNode default ok (though does not visit producer) |
333 | // ProducerRealizeNode default ok (though does not visit producer) |
334 | |
335 | Stmt VisitStmt_(const PrefetchNode* prefetch_node) final { |
336 | Prefetch new_prefetch = Downcast<Prefetch>(StmtExprMutator::VisitStmt_(prefetch_node)); |
337 | Buffer new_buffer = Subst(new_prefetch->buffer.get()); |
338 | if (!new_buffer.same_as(new_prefetch->buffer)) { |
339 | return Prefetch(new_buffer, prefetch_node->bounds, prefetch_node->span); |
340 | } |
341 | return std::move(new_prefetch); |
342 | } |
343 | |
344 | // SeqStmtNode default ok |
345 | // EvaluateNode default ok |
346 | |
347 | BufferRegion VisitItem(const BufferRegionNode* buffer_region_node) { |
348 | Buffer new_buffer = Subst(buffer_region_node->buffer.get()); |
349 | if (!new_buffer.same_as(buffer_region_node->buffer)) { |
350 | return BufferRegion(new_buffer, buffer_region_node->region); |
351 | } |
352 | return GetRef<BufferRegion>(buffer_region_node); |
353 | } |
354 | |
355 | MatchBufferRegion VisitItem(const MatchBufferRegionNode* match_buffer_region_node) { |
356 | // The source field has a referential occurrence of the buffer. Apply the buffer substitution |
357 | // to that. |
358 | BufferRegion new_source = VisitItem(match_buffer_region_node->source.get()); |
359 | // The buffer field however is a definitional occurrence, aliased on top of the source. |
360 | // Transfer any memory scope from the source to the destination. |
361 | Optional<VirtualDevice> opt_virtual_device = GetBufferConstraint(new_source->buffer); |
362 | tir::Buffer new_buffer; |
363 | if (opt_virtual_device.defined()) { |
364 | new_buffer = RewriteBuffer(match_buffer_region_node->buffer, opt_virtual_device.value()); |
365 | } else { |
366 | new_buffer = match_buffer_region_node->buffer; |
367 | } |
368 | if (!new_buffer.same_as(match_buffer_region_node->buffer) || |
369 | !new_source.same_as(match_buffer_region_node->source)) { |
370 | return MatchBufferRegion(new_buffer, new_source); |
371 | } |
372 | return GetRef<MatchBufferRegion>(match_buffer_region_node); |
373 | } |
374 | |
375 | template <typename T> |
376 | Array<T> VisitItems(const Array<T>& items) { |
377 | return items.Map([this](T item) -> T { return VisitItem(item.get()); }); |
378 | } |
379 | |
380 | Stmt VisitStmt_(const BlockNode* block_node) final { |
381 | Block new_block = Downcast<Block>(StmtExprMutator::VisitStmt_(block_node)); |
382 | Array<BufferRegion> new_reads = VisitItems(new_block->reads); |
383 | Array<BufferRegion> new_writes = VisitItems(new_block->writes); |
384 | // TODO(mbs): What memory scope should we assign to the new buffers? |
385 | Array<MatchBufferRegion> new_match_buffers = VisitItems(new_block->match_buffers); |
386 | if (!new_reads.same_as(new_block->reads) || new_writes.same_as(new_block->writes) || |
387 | new_match_buffers.same_as(new_block->match_buffers)) { |
388 | return Block(new_block->iter_vars, std::move(new_reads), std::move(new_writes), |
389 | new_block->name_hint, new_block->body, new_block->init, new_block->alloc_buffers, |
390 | std::move(new_match_buffers), new_block->annotations, new_block->span); |
391 | } |
392 | return std::move(new_block); |
393 | } |
394 | |
395 | // BlockRealizeNode default ok |
396 | |
397 | /*! Applies \p var_subst_ substitution to \p var_node. */ |
398 | Var Subst(const VarNode* var_node) const { |
399 | auto itr = var_subst_.find(var_node); |
400 | return itr == var_subst_.end() ? GetRef<Var>(var_node) : itr->second; |
401 | } |
402 | |
403 | /*! Applies \p buffer_subst_ substitution to \p buffer. */ |
404 | Buffer Subst(const BufferNode* buffer_node) const { |
405 | auto itr = buffer_subst_.find(buffer_node); |
406 | return itr == buffer_subst_.end() ? GetRef<Buffer>(buffer_node) : itr->second; |
407 | } |
408 | |
409 | /*! |
410 | * \brief Rewrites \p buffer so as to follow the constraints in \p virtual_device |
411 | * (currently just memory scope). |
412 | * |
413 | * Updates both the var_subst_ and buffer_subst_ to capture the rewrite, but |
414 | * also returns the new buffer. |
415 | */ |
416 | Buffer RewriteBuffer(const Buffer& buffer, const VirtualDevice& virtual_device) { |
417 | ICHECK(buffer->data->type_annotation.defined()); |
418 | const auto* pointer_type_node = buffer->data->type_annotation.as<PointerTypeNode>(); |
419 | ICHECK(pointer_type_node); |
420 | if (pointer_type_node->storage_scope == virtual_device->memory_scope) { |
421 | // No change. |
422 | return buffer; |
423 | } |
424 | PointerType new_pointer_type(pointer_type_node->element_type, virtual_device->memory_scope); |
425 | Var new_data(buffer->data->name_hint, new_pointer_type, buffer->data->span); |
426 | var_subst_.emplace(buffer->data.get(), new_data); |
427 | Buffer new_buffer = buffer; |
428 | new_buffer.CopyOnWrite()->data = new_data; |
429 | buffer_subst_.emplace(buffer.get(), new_buffer); |
430 | return new_buffer; |
431 | } |
432 | |
433 | /*! |
434 | * \brief Returns the VirtualDevice capturing any memory scope in \p buffer. Returns nullptr if |
435 | * buffer's data var does not have a type annotation of \p PointerType. Returns the fully |
436 | * unconstrained \p VirtualDevice if no memory scope is given. |
437 | */ |
438 | static Optional<VirtualDevice> GetBufferConstraint(const tir::Buffer& buffer) { |
439 | const auto* pointer_type_node = PointerInBuffer(buffer); |
440 | return pointer_type_node == nullptr |
441 | ? Optional<VirtualDevice>() |
442 | : VirtualDevice::ForMemoryScope(pointer_type_node->storage_scope); |
443 | } |
444 | |
445 | /*! |
446 | * \brief Maps each \p Buffer::data \p Var to its constrained equivalent. |
447 | */ |
448 | std::unordered_map<const VarNode*, Var> var_subst_; |
449 | |
450 | /*! |
451 | * \brief Maps each \p Buffer to its constrained equivalent. |
452 | */ |
453 | std::unordered_map<const BufferNode*, Buffer> buffer_subst_; |
454 | }; |
455 | |
456 | } // namespace |
457 | |
458 | Array<VirtualDevice> GetPrimFuncArgAndResultConstraints(const tir::PrimFunc& prim_func, |
459 | const FuncType& relay_func_type) { |
460 | // Build the implied domain (in terms of the function's Relay type) implied by any memory scope |
461 | // constrains in the function's buffers, for both arguments and results. |
462 | Array<VirtualDevice> virtual_devices; |
463 | virtual_devices.reserve(relay_func_type->arg_types.size() + 1); |
464 | |
465 | // For each Relay function parameter... |
466 | size_t current_primfunc_param_index = 0; |
467 | for (const auto& param_type : relay_func_type->arg_types) { |
468 | VirtualDevice param_virtual_device = |
469 | ConsistentParamConstraint(prim_func, param_type, ¤t_primfunc_param_index); |
470 | virtual_devices.push_back(param_virtual_device); |
471 | } |
472 | |
473 | // For the Relay function result... |
474 | const Type& ret_type = relay_func_type->ret_type; |
475 | VirtualDevice ret_virtual_device = |
476 | ConsistentParamConstraint(prim_func, ret_type, ¤t_primfunc_param_index); |
477 | virtual_devices.push_back(ret_virtual_device); |
478 | |
479 | // Make sure all parameters of the prim_func have been accounted for. |
480 | CheckNoRemainingPointerParams(prim_func, ¤t_primfunc_param_index); |
481 | |
482 | return virtual_devices; |
483 | } |
484 | |
485 | TVM_REGISTER_GLOBAL("tir.analysis.GetPrimFuncArgAndResultMemoryConstraints" ) |
486 | .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type) { |
487 | Array<String> memory_scopes; |
488 | memory_scopes.reserve(relay_func_type->type_params.size() + 1); |
489 | for (const auto& virtual_device : |
490 | GetPrimFuncArgAndResultConstraints(prim_func, relay_func_type)) { |
491 | memory_scopes.push_back(virtual_device->memory_scope); |
492 | } |
493 | return memory_scopes; |
494 | }); |
495 | |
496 | PrimFunc ApplyPrimFuncArgAndResultConstraints( |
497 | const PrimFunc& prim_func, const FuncType& relay_func_type, |
498 | const Array<VirtualDevice>& arg_and_result_virtual_devices) { |
499 | return ApplyDeviceConstraintsMutator().Rewrite(prim_func, relay_func_type, |
500 | arg_and_result_virtual_devices); |
501 | } |
502 | |
503 | TVM_REGISTER_GLOBAL("tir.analysis.ApplyPrimFuncArgAndResultMemoryConstraints" ) |
504 | .set_body_typed([](const PrimFunc& prim_func, const FuncType& relay_func_type, |
505 | const Array<String>& arg_and_result_memory_scopes) { |
506 | Array<VirtualDevice> virtual_devices; |
507 | virtual_devices.reserve(arg_and_result_memory_scopes.size()); |
508 | for (const auto& memory_scope : arg_and_result_memory_scopes) { |
509 | virtual_devices.push_back(VirtualDevice::ForMemoryScope(memory_scope)); |
510 | } |
511 | return ApplyPrimFuncArgAndResultConstraints(prim_func, relay_func_type, virtual_devices); |
512 | }); |
513 | |
514 | } // namespace tir |
515 | } // namespace tvm |
516 | |