1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/arith/int_set.h> |
20 | |
21 | #include "../../transforms/common_subexpr_elim_tools.h" |
22 | #include "../../transforms/replace_selected_expr.h" |
23 | #include "../utils.h" |
24 | |
25 | namespace tvm { |
26 | namespace tir { |
27 | |
28 | /******** Helper Functions/Classes ********/ |
29 | |
/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */
struct IndexInfo {
  /*! \brief The target block to perform cache_index on. */
  StmtSRef target_block;
  /*! \brief Repeat-count threshold at or above which a common subexpr is extracted. */
  size_t cse_thresh;
  /*! \brief The cache buffers that store the precomputed indices (one per cached expr). */
  std::vector<Buffer> cache_buffer;
  /*! \brief The index expressions to be precomputed. */
  std::vector<PrimExpr> index_exprs;
  /*! \brief The ranges of the loop vars relating to index computation. */
  Map<Var, Range> range_map;
  /*! \brief The binding table from the block vars to loop-var expressions. */
  Map<Var, PrimExpr> var_binding;
  /*! \brief For each cached expr, the block vars of the target block it references. */
  std::vector<Array<Var>> origin_block_vars;
  /*! \brief The position (within the scope's SeqStmt) to insert the cache stage. */
  size_t loc_pos;
  /*! \brief The cache stage (loop nests computing the indices) to be inserted. */
  Stmt cache_stage;
  /*! \brief The map used for ScheduleStateNode::Replace. */
  Map<Block, Block> block_reuse;
};
53 | |
54 | /*! |
55 | * \brief Determine the data type base on the integer range. |
56 | * \param range The range of the integer. |
57 | * \returns A data type that covers the input range. |
58 | */ |
59 | DataType DetermineDatatype(const arith::IntSet& range) { |
60 | arith::Analyzer ana; |
61 | if (ana.CanProve(range.min() >= INT32_MIN && range.max() <= INT32_MAX)) { |
62 | return DataType::Int(32); |
63 | } else { |
64 | ICHECK(ana.CanProve(range.min() >= make_const(DataType::Int(64), INT64_MIN) && |
65 | range.max() <= make_const(DataType::Int(64), INT64_MAX))); |
66 | return DataType::Int(64); |
67 | } |
68 | } |
69 | |
70 | /*! \brief Collect the index info to be cached */ |
71 | class IndexInfoCollector : public StmtExprVisitor { |
72 | public: |
73 | /*! |
74 | * \brief Collect the index info for cache_index and write into the IndexInfo |
75 | * \param self The state of the schedule \param block_sref The sref of the target |
76 | * block of the target buffer being applied cache_index \param scope_sref The sref |
77 | * of the scope block of the target block \param info The index info. |
78 | */ |
79 | static void Collect(const ScheduleState& self, const StmtSRef& block_sref, |
80 | const StmtSRef& scope_sref, IndexInfo* info) { |
81 | IndexInfoCollector collector(self, block_sref, scope_sref, info->cse_thresh); |
82 | collector(GetRef<Stmt>(scope_sref->stmt)); |
83 | info->loc_pos = collector.loc_pos_; |
84 | info->index_exprs = collector.exprs_; |
85 | info->range_map = collector.range_map_; |
86 | } |
87 | |
88 | private: |
89 | /*! |
90 | * \brief Constructor |
91 | * \param self The state of the schedule |
92 | * \param block_sref The sref of the target block of the buffer being applied cache_index |
93 | * \param scope_sref The sref of the scope block of the target block |
94 | * \param cse_thresh The repeat threshold that determines a common subexpr |
95 | */ |
96 | IndexInfoCollector(const ScheduleState self, const StmtSRef& block_sref, |
97 | const StmtSRef& scope_sref, int cse_thresh) |
98 | : self_(self), block_sref_(block_sref), scope_sref_(scope_sref), cse_thresh_(cse_thresh) {} |
99 | |
100 | void VisitStmt_(const SeqStmtNode* seq_stmt) final { |
101 | for (size_t i = 0; i < seq_stmt->size(); ++i) { |
102 | if (loc_pos_ != -1) { |
103 | break; |
104 | } |
105 | VisitStmt(seq_stmt->seq[i]); |
106 | // `pos` can be assigned only once when we visited `block_sref` |
107 | if (visited_block_ && loc_pos_ == -1 && update_seq_pos_) { |
108 | // The offset of insert position from the block |
109 | loc_pos_ = i; |
110 | return; |
111 | } |
112 | } |
113 | } |
114 | |
115 | void VisitStmt_(const BlockNode* block) final { |
116 | visiting_target_block = static_cast<bool>(block_sref_->stmt == block); |
117 | StmtVisitor::VisitStmt_(block); |
118 | visiting_target_block = false; |
119 | if (block == scope_sref_->stmt) { |
120 | // The block vistied is the current parent scope |
121 | // Handling cases when no SeqStmt in the scope |
122 | if (visited_block_ && loc_pos_ == -1) { |
123 | loc_pos_ = 0; |
124 | } |
125 | } else if (block_sref_->stmt == block) { |
126 | visited_block_ = true; |
127 | } |
128 | // Update seq pos only at top scope |
129 | if (visited_block_ && self_->stmt2ref.at(block)->parent == scope_sref_.get()) { |
130 | update_seq_pos_ = true; |
131 | } |
132 | } |
133 | |
134 | void VisitStmt_(const ForNode* loop) final { |
135 | range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent)); |
136 | StmtVisitor::VisitStmt_(loop); |
137 | // Update seq pos only at top scope |
138 | if (visited_block_ && self_->stmt2ref.at(loop)->parent == scope_sref_.get()) { |
139 | update_seq_pos_ = true; |
140 | } |
141 | } |
142 | |
143 | void VisitStmt_(const BufferStoreNode* store) final { |
144 | // Only analyze the cache candidate for stores in target block |
145 | if (visiting_target_block) { |
146 | auto IsEligibleComputation = [](const PrimExpr& expr) { |
147 | return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 && |
148 | (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr)); |
149 | }; |
150 | |
151 | // Analyze sub expr candidates |
152 | ComputationTable table_syntactic_comp_done_by_stmt = |
153 | ComputationsDoneBy::GetComputationsDoneBy(GetRef<Stmt>(store), IsEligibleComputation, |
154 | [](const PrimExpr& expr) { return true; }); |
155 | std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt = |
156 | SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, true); |
157 | |
158 | // Analyze the sub expr of a candidate whose repeat time is under cse_thresh_ |
159 | for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) { |
160 | std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i]; |
161 | if (computation_and_nb.second < cse_thresh_) { |
162 | std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions( |
163 | computation_and_nb.first, IsEligibleComputation, |
164 | [](const PrimExpr& expr) { return true; }); |
165 | InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs, |
166 | true, computation_and_nb.second); |
167 | } |
168 | } |
169 | |
170 | // Record the final sub expr with repeat time greater than cse_thresh_ |
171 | // In order to make the result stable, sort it by post order and then by complexity |
172 | PostOrderVisit(store->value, [&semantic_comp_done_by_stmt, this](const ObjectRef& node) { |
173 | if (node->IsInstance<PrimExprNode>()) { |
174 | PrimExpr this_expr = Downcast<PrimExpr>(node); |
175 | for (auto& it : semantic_comp_done_by_stmt) { |
176 | if (it.second >= this->cse_thresh_ && EquivalentTerms(this_expr, it.first, true)) { |
177 | auto find_result = |
178 | std::find_if(this->exprs_.begin(), this->exprs_.end(), |
179 | [&](PrimExpr expr) { return expr.get() == it.first.get(); }); |
180 | if (find_result == this->exprs_.end()) { |
181 | this->exprs_.push_back(it.first); |
182 | } |
183 | } |
184 | } |
185 | } |
186 | }); |
187 | auto cmp = [&](const PrimExpr& lhs, const PrimExpr& rhs) -> bool { |
188 | return CalculateExprComplexity(lhs) > CalculateExprComplexity(rhs); |
189 | }; |
190 | std::stable_sort(exprs_.begin(), exprs_.end(), cmp); |
191 | } |
192 | StmtVisitor::VisitStmt_(store); |
193 | } |
194 | |
195 | /*! \brief The schedule class */ |
196 | const ScheduleState self_; |
197 | /*! \brief The target block that read the target buffer */ |
198 | const StmtSRef& block_sref_; |
199 | /*! \brief The parent scope of the target block */ |
200 | const StmtSRef& scope_sref_; |
201 | /*! \brief Record the common subexpr extract threshold */ |
202 | size_t cse_thresh_; |
203 | /*! \brief The calculation expr to be precomputed */ |
204 | std::vector<PrimExpr> exprs_; |
205 | /*! \brief The flag whether we have visited the target block */ |
206 | bool visited_block_{false}; |
207 | /*! \brief The flag indicating currently visiting target block */ |
208 | bool visiting_target_block{false}; |
209 | /*! \brief The index to insert the cache_index stage */ |
210 | int loc_pos_{-1}; |
211 | /*! \brief The flag indicating the right scope to update seq pos */ |
212 | bool update_seq_pos_{false}; |
213 | /*! \brief Record the ranges of iter vars */ |
214 | Map<Var, Range> range_map_; |
215 | }; |
216 | |
/*!
 * \brief Create a loop nest that writes precomputed index into index buffer.
 * \param info The cache stage information, which will be updated in the function.
 * \param storage_scope The storage scope of the cached buffer (only used in naming here)
 * \returns A block indicating the body of the loop nesting.
 */
Array<Block> MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) {
  Array<Block> blocks;
  Array<Stmt> bodies;
  bodies.reserve(info->index_exprs.size());
  info->cache_buffer.reserve(info->index_exprs.size());

  // For each index calculation, create a block to pre-compute.
  for (size_t expr_index = 0; expr_index < info->index_exprs.size(); expr_index++) {
    const PrimExpr& index_expr = info->index_exprs[expr_index];

    // Collect the block vars in original index computation
    // (de-duplicated, in post-order of first appearance).
    info->origin_block_vars.push_back({});
    PostOrderVisit(index_expr, [&info, &expr_index](const ObjectRef& node) {
      if (node->IsInstance<VarNode>()) {
        Var iter_var = Downcast<Var>(node);
        const Array<Var>& origin_block_var = info->origin_block_vars[expr_index];
        auto find_result = std::find_if(origin_block_var.begin(), origin_block_var.end(),
                                        [&](Var it) { return it.get() == iter_var.get(); });
        if (find_result == origin_block_var.end()) {
          info->origin_block_vars[expr_index].push_back(iter_var);
        }
      }
    });

    // Collect the loop vars corresponding to collected block vars,
    // which will be used to create new loop vars
    std::vector<Var> iter_vars;
    for (const Var& it : info->origin_block_vars[expr_index]) {
      PostOrderVisit(info->var_binding.at(it), [&iter_vars](const ObjectRef& node) {
        if (node->IsInstance<VarNode>()) {
          Var iter_var = Downcast<Var>(node);
          if (std::find_if(iter_vars.begin(), iter_vars.end(),
                           [&](Var it) { return it.get() == iter_var.get(); }) == iter_vars.end()) {
            iter_vars.push_back(iter_var);
          }
        }
      });
    }

    // The cache buffer has one dimension per referenced block var; each extent is
    // the maximum value that var's binding can take, plus one.
    DataType data_type = index_expr.dtype();
    Var index_buffer_var("index_var_" + std::to_string(expr_index),
                         PointerType(PrimType(data_type), storage_scope));
    Array<PrimExpr> buffer_shape;
    for (const Var& it : info->origin_block_vars[expr_index]) {
      buffer_shape.push_back(
          arith::EvalSet(info->var_binding.at(it), arith::AsIntSet(info->range_map)).max() + 1);
    }
    info->cache_buffer.push_back(Buffer(index_buffer_var, data_type, buffer_shape, {1}, {0},
                                        index_buffer_var->name_hint, 0, 0, kDefault));

    // Create loop vars and block vars' binding_value
    std::vector<Var> loop_vars;
    Map<Var, PrimExpr> replace_table;
    for (const Var& it : iter_vars) {
      DataType data_type = DetermineDatatype(arith::IntSet::FromRange(info->range_map.at(it)));
      Var loop_var("ax" + std::to_string(replace_table.size()), data_type);
      loop_vars.push_back(loop_var);
      replace_table.Set(it, loop_var);
    }
    // Create iter_values from the original block.
    std::vector<PrimExpr> iter_values;
    for (const Var& it : info->origin_block_vars[expr_index]) {
      iter_values.push_back(Substitute(info->var_binding.at(it), replace_table));
    }
    // block variables
    Array<IterVar> block_vars;
    // block access region for write buffers
    Region access_region;
    // indices used in block body
    Array<PrimExpr> access_indices;
    Map<Var, PrimExpr> block_var_map;
    // Create block vars, block's accessed region and accessing indices.
    // NOTE(review): this pairs origin_block_vars[expr_index][i] with iter_vars[i],
    // assuming the two lists are index-aligned — verify for bindings that reference
    // zero or several loop vars.
    for (size_t i = 0; i < info->origin_block_vars[expr_index].size(); i++) {
      const Var& block_var = info->origin_block_vars[expr_index][i];
      Var var("v" + std::to_string(access_indices.size()), block_var.dtype());
      Range range = Range::FromMinExtent(make_zero(block_var.dtype()),
                                         info->range_map.at(iter_vars[i])->extent);
      block_vars.push_back(IterVar(/*dom=*/range,
                                   /*var=*/var,
                                   /*IterVarType=*/kDataPar));

      access_indices.push_back(var);
      access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
      block_var_map.Set(block_var, var);
    }

    // Create the index computing block
    PrimExpr new_expr = Substitute(index_expr, block_var_map);
    Block block(
        /*iter_vars=*/std::move(block_vars),
        /*reads=*/{},
        /*writes=*/{BufferRegion(info->cache_buffer[expr_index], access_region)},
        /*name_hint=*/"index_" + std::to_string(expr_index),
        /*body=*/
        BufferStore(info->cache_buffer[expr_index], new_expr, access_indices),
        /*init=*/NullOpt,
        /*alloc_buffers=*/{},
        /*match_buffers=*/{},
        /*annotations=*/{});
    blocks.push_back(block);
    // Create the block realize node
    Stmt body = BlockRealize(/*values=*/iter_values,
                             /*predicate=*/const_true(),
                             /*block=*/block);
    // Create surrounding loops, innermost first (iterating from the back of loop_vars)
    for (size_t i = loop_vars.size(); i >= 1; --i) {
      body = For(/*loop_var=*/loop_vars[i - 1],
                 /*min=*/0,
                 /*extent=*/info->range_map.at(iter_vars[i - 1])->extent,
                 /*kind=*/ForKind::kSerial,
                 /*body=*/body);
    }
    bodies.push_back(body);
  }

  // Concatenate the per-expr loop nests into a single cache stage.
  info->cache_stage = SeqStmt(bodies);
  return blocks;
}
341 | |
342 | /*! |
343 | * \brief Insert the cache stages into the specific position |
344 | * \param stmt A sequence of statements or a single statement that the new stage is inserted in |
345 | * \param pos The position where the cache stage is inserted |
346 | * \param stage The stage to be inserted |
347 | * \return A SeqStmt, the result after insertion |
348 | */ |
349 | Stmt InsertIndexStage(const Stmt& stmt, int pos, const Stmt& stage) { |
350 | if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) { |
351 | ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt); |
352 | result->seq.insert(result->seq.begin() + pos, stage); |
353 | return SeqStmt(result); |
354 | } |
355 | if (pos == 0) { |
356 | return SeqStmt::Flatten<Array<Stmt>>({stage, stmt}); |
357 | } |
358 | ICHECK_EQ(pos, 1); |
359 | return SeqStmt::Flatten<Array<Stmt>>({stmt, stage}); |
360 | } |
361 | |
362 | /*! \brief Mutator for CacheIndex. */ |
363 | class CacheIndexRewriter : public StmtExprMutator { |
364 | public: |
365 | /*! |
366 | * \brief Rewrite the AST and add stages of writting precomputed index |
367 | * \param scope_sref The parent scope of this mutation |
368 | * \param info The index information |
369 | * \return The new AST rooting at the original parent scope |
370 | */ |
371 | static Stmt Rewrite(const StmtSRef& scope_sref, IndexInfo* info) { |
372 | CacheIndexRewriter rewriter(scope_sref, info); |
373 | return rewriter(GetRef<Stmt>(scope_sref->stmt)); |
374 | } |
375 | |
376 | private: |
377 | explicit CacheIndexRewriter(const StmtSRef& scope_sref, IndexInfo* info) |
378 | : scope_sref_(scope_sref), info_(info) { |
379 | cache_indices_.reserve(info_->origin_block_vars.size()); |
380 | for (const Array<Var>& group_it : info_->origin_block_vars) { |
381 | cache_indices_.push_back({}); |
382 | for (const Var& it : group_it) { |
383 | cache_indices_.back().push_back(it); |
384 | } |
385 | } |
386 | } |
387 | |
388 | Stmt VisitStmt_(const BlockNode* block) final { |
389 | Block old_stmt = GetRef<Block>(block); |
390 | // Mutate the body |
391 | visiting_target_block = static_cast<bool>(block == info_->target_block->stmt); |
392 | Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block)); |
393 | visiting_target_block = false; |
394 | |
395 | // Check if it is the block corresponding to the parent scope |
396 | if (block == scope_sref_->stmt) { |
397 | // If so, put buffer allocation and insert cache stages on the parent scope |
398 | ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>()); |
399 | n->body = InsertIndexStage(n->body, info_->loc_pos, info_->cache_stage); |
400 | for (const Buffer& it : info_->cache_buffer) { |
401 | n->alloc_buffers.push_back(it); |
402 | } |
403 | stmt = Block(n); |
404 | } |
405 | info_->block_reuse.Set(old_stmt, stmt); |
406 | return std::move(stmt); |
407 | } |
408 | |
409 | Stmt VisitStmt_(const BufferStoreNode* store) final { |
410 | Stmt ret_stmt = StmtMutator::VisitStmt_(store); |
411 | // Replace common sub expr for target block, with cached buffer load |
412 | if (visiting_target_block) { |
413 | for (size_t i = 0; i < info_->index_exprs.size(); i++) { |
414 | PrimExpr& computation = info_->index_exprs[i]; |
415 | std::function<bool(const PrimExpr&)> predicate_selector = |
416 | [computation](const PrimExpr& current_expr) { |
417 | return (EquivalentTerms(current_expr, computation, true)); |
418 | }; |
419 | BufferLoad load = BufferLoad(info_->cache_buffer[i], cache_indices_[i]); |
420 | ret_stmt = ReplaceSelectedExpr::ReplaceSelectedExprInStmt( |
421 | ret_stmt, predicate_selector, std::move(load), |
422 | [](const PrimExpr& expr) { return true; }); |
423 | } |
424 | } |
425 | return ret_stmt; |
426 | } |
427 | |
428 | PrimExpr VisitExpr_(const LoadNode* op) final { |
429 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
430 | } |
431 | |
432 | private: |
433 | /*! \brief The parent scope of the insertion */ |
434 | const StmtSRef& scope_sref_; |
435 | /*! \brief The info for inserting cache stage */ |
436 | IndexInfo* info_; |
437 | /*! \brief The indices for the cache buffer */ |
438 | std::vector<Array<PrimExpr>> cache_indices_; |
439 | /*! \brief Indicating whether cache stage is inserted, only do index replacement afterwards*/ |
440 | bool visiting_target_block{false}; |
441 | }; |
442 | |
/*!
 * \brief Precompute and cache repeated index computations of a block.
 * \param self The schedule state
 * \param block_sref The sref of the target block
 * \param storage_scope The storage scope of the cache buffers
 * \param cse_thresh The repeat threshold that determines a common subexpr; must be >= 0
 * \return The srefs of the newly created index-computing blocks
 */
Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
                           const String& storage_scope, int cse_thresh) {
  /*!
   * Check:
   * - The index is in the array of block reading region
   *
   * Mutate:
   * - Allocate new cache buffers under the current scope.
   * - Precompute the index and store it in cache buffers.
   */

  // Step 0. Checking index, getting the target buffer and the parent scope
  IndexInfo info;
  info.target_block = block_sref;
  CHECK_GE(cse_thresh, 0) << "cse_thresh should not be negative number";
  info.cse_thresh = cse_thresh;
  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);

  // Step 1. Collect the indexing info of target buffer.
  IndexInfoCollector::Collect(self, block_sref, scope_sref, &info);

  // Step 2. Create cache stages and rewrite the stmt.
  BlockRealize realize = GetBlockRealize(self, block_sref);
  info.var_binding = GetBindings(realize);
  Array<Block> cache_stages = MakeIndexCacheStage(&info, storage_scope);
  Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);

  // Snapshot the stage_pipeline flag before Replace invalidates the old block info,
  // so the new blocks can inherit it below.
  bool old_stage_pipeline = self->block_info[block_sref].scope->stage_pipeline;

  // Step 3. Replacing and updating flags.
  self->Replace(scope_sref, new_scope, info.block_reuse);
  Array<StmtSRef> result_block_srefs;
  for (const Block& it : cache_stages) {
    StmtSRef result_block_sref = self->stmt2ref.at(it.get());
    result_block_srefs.push_back(result_block_sref);
    BlockInfo& block_info = self->block_info[result_block_sref];

    // A root block (no parent) trivially has affine bindings; otherwise check the
    // realize against the loop domain of the sref tree path.
    bool affine_binding = false;
    if (result_block_sref->parent == nullptr) {
      affine_binding = true;
    } else {
      arith::Analyzer analyzer;
      StmtSRef parent_sref = GetRef<StmtSRef>(result_block_sref->parent);
      affine_binding = IsAffineBinding(/*realize=*/GetBlockRealize(self, result_block_sref),
                                       /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref),
                                       /*analyzer=*/&analyzer);
    }

    block_info.affine_binding = affine_binding;
    block_info.region_cover = true;
    block_info.scope->stage_pipeline = old_stage_pipeline;
  }

  return result_block_srefs;
}
498 | |
499 | /******** InstructionKind Registration ********/ |
500 | |
501 | struct : public UnpackedInstTraits<CacheIndexTraits> { |
502 | static constexpr const char* = "CacheIndex" ; |
503 | static constexpr bool = false; |
504 | |
505 | private: |
506 | static constexpr size_t = 1; |
507 | static constexpr size_t = 2; |
508 | static constexpr size_t = 0; |
509 | |
510 | static Array<BlockRV> (Schedule sch, BlockRV block, String storage_scope, |
511 | Integer cse_thresh) { |
512 | return sch->CacheIndex(block, storage_scope, cse_thresh->value); |
513 | } |
514 | |
515 | static String (Array<String> outputs, String block, String storage_scope, |
516 | Integer cse_thresh) { |
517 | PythonAPICall py("cache_index" ); |
518 | py.Input("block" , block); |
519 | py.Input("storage_scope" , storage_scope); |
520 | py.Input("cse_thresh" , cse_thresh->value); |
521 | py.OutputList(outputs); |
522 | return py.Str(); |
523 | } |
524 | |
525 | template <typename> |
526 | friend struct ::tvm::tir::UnpackedInstTraits; |
527 | }; |
528 | |
529 | TVM_REGISTER_INST_KIND_TRAITS(CacheIndexTraits); |
530 | |
531 | } // namespace tir |
532 | } // namespace tvm |
533 | |