/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
#include <tvm/arith/int_set.h>

#include "../../transforms/common_subexpr_elim_tools.h"
#include "../../transforms/replace_selected_expr.h"
#include "../utils.h"

namespace tvm {
namespace tir {

/******** Helper Functions/Classes ********/

/*! \brief The auxiliary info used for the insertion point and content of the cache stage. */
struct IndexInfo {
  /*! \brief The target block to perform cache_index */
  StmtSRef target_block;
  /*! \brief The common subexpression extraction threshold */
  size_t cse_thresh;
  /*! \brief The cache buffers that store the precomputed indices */
  std::vector<Buffer> cache_buffer;
  /*! \brief The exprs to be precomputed */
  std::vector<PrimExpr> index_exprs;
  /*! \brief The ranges of the loop vars involved in index computation */
  Map<Var, Range> range_map;
  /*! \brief The bindings from the block vars to the loop vars */
  Map<Var, PrimExpr> var_binding;
  /*! \brief The block vars used by each index expr of the target block */
  std::vector<Array<Var>> origin_block_vars;
  /*! \brief The position at which to insert the cache stage. */
  size_t loc_pos;
  /*! \brief The cache stage to be inserted. */
  Stmt cache_stage;
  /*! \brief The map used for ScheduleStateNode::Replace. */
  Map<Block, Block> block_reuse;
};
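
// How the fields above flow through this primitive (a summary of this file):
// IndexInfoCollector fills `index_exprs`, `range_map` and `loc_pos`;
// MakeIndexCacheStage fills `origin_block_vars`, `cache_buffer` and `cache_stage`;
// CacheIndexRewriter then consumes them to splice the stage into the scope and
// reports the rewritten blocks through `block_reuse`.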

/*!
 * \brief Determine the data type based on the integer range.
 * \param range The range of the integer.
 * \returns A data type that covers the input range.
 */
DataType DetermineDatatype(const arith::IntSet& range) {
  arith::Analyzer ana;
  if (ana.CanProve(range.min() >= INT32_MIN && range.max() <= INT32_MAX)) {
    return DataType::Int(32);
  } else {
    ICHECK(ana.CanProve(range.min() >= make_const(DataType::Int(64), INT64_MIN) &&
                        range.max() <= make_const(DataType::Int(64), INT64_MAX)));
    return DataType::Int(64);
  }
}
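
// A minimal sketch of how DetermineDatatype behaves (values are illustrative):
//
//   arith::IntSet small = arith::IntSet::FromMinExtent(0, 1 << 20);
//   DetermineDatatype(small);  // -> DataType::Int(32)
//   arith::IntSet large =
//       arith::IntSet::FromMinExtent(0, make_const(DataType::Int(64), 1LL << 40));
//   DetermineDatatype(large);  // -> DataType::Int(64), since the max exceeds INT32_MAX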

/*! \brief Collect the index info to be cached */
class IndexInfoCollector : public StmtExprVisitor {
 public:
  /*!
   * \brief Collect the index info for cache_index and write it into the IndexInfo
   * \param self The state of the schedule
   * \param block_sref The sref of the target block of the buffer being applied cache_index
   * \param scope_sref The sref of the scope block of the target block
   * \param info The index info.
   */
  static void Collect(const ScheduleState& self, const StmtSRef& block_sref,
                      const StmtSRef& scope_sref, IndexInfo* info) {
    IndexInfoCollector collector(self, block_sref, scope_sref, info->cse_thresh);
    collector(GetRef<Stmt>(scope_sref->stmt));
    info->loc_pos = collector.loc_pos_;
    info->index_exprs = collector.exprs_;
    info->range_map = collector.range_map_;
  }

 private:
  /*!
   * \brief Constructor
   * \param self The state of the schedule
   * \param block_sref The sref of the target block of the buffer being applied cache_index
   * \param scope_sref The sref of the scope block of the target block
   * \param cse_thresh The repeat threshold that determines a common subexpr
   */
  IndexInfoCollector(const ScheduleState self, const StmtSRef& block_sref,
                     const StmtSRef& scope_sref, int cse_thresh)
      : self_(self), block_sref_(block_sref), scope_sref_(scope_sref), cse_thresh_(cse_thresh) {}

  void VisitStmt_(const SeqStmtNode* seq_stmt) final {
    for (size_t i = 0; i < seq_stmt->size(); ++i) {
      if (loc_pos_ != -1) {
        break;
      }
      VisitStmt(seq_stmt->seq[i]);
      // `loc_pos_` can be assigned only once, after `block_sref` has been visited
      if (visited_block_ && loc_pos_ == -1 && update_seq_pos_) {
        // Insert right before the subtree that contains the target block, so the
        // precomputed indices are available before their first use
        loc_pos_ = i;
        return;
      }
    }
  }

  void VisitStmt_(const BlockNode* block) final {
    visiting_target_block = static_cast<bool>(block_sref_->stmt == block);
    StmtVisitor::VisitStmt_(block);
    visiting_target_block = false;
    if (block == scope_sref_->stmt) {
      // The visited block is the current parent scope.
      // Handle the case where there is no SeqStmt in the scope.
      if (visited_block_ && loc_pos_ == -1) {
        loc_pos_ = 0;
      }
    } else if (block_sref_->stmt == block) {
      visited_block_ = true;
    }
    // Update seq pos only at the top scope
    if (visited_block_ && self_->stmt2ref.at(block)->parent == scope_sref_.get()) {
      update_seq_pos_ = true;
    }
  }

  void VisitStmt_(const ForNode* loop) final {
    range_map_.Set(loop->loop_var, Range::FromMinExtent(loop->min, loop->extent));
    StmtVisitor::VisitStmt_(loop);
    // Update seq pos only at the top scope
    if (visited_block_ && self_->stmt2ref.at(loop)->parent == scope_sref_.get()) {
      update_seq_pos_ = true;
    }
  }

  void VisitStmt_(const BufferStoreNode* store) final {
    // Only analyze the cache candidates for stores in the target block
    if (visiting_target_block) {
      auto IsEligibleComputation = [](const PrimExpr& expr) {
        return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 &&
                (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
      };

      // Analyze sub expr candidates
      ComputationTable table_syntactic_comp_done_by_stmt =
          ComputationsDoneBy::GetComputationsDoneBy(GetRef<Stmt>(store), IsEligibleComputation,
                                                    [](const PrimExpr& expr) { return true; });
      std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_stmt =
          SyntacticToSemanticComputations(table_syntactic_comp_done_by_stmt, true);
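      // With the second argument set to true above, syntactically different but
      // semantically equivalent computations (e.g. `i * 4 + j` and `j + i * 4`)
      // are merged, so their occurrence counts accumulate into a single candidate.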

      // Analyze the sub exprs of candidates whose repeat count is under cse_thresh_
      for (size_t i = 0; i < semantic_comp_done_by_stmt.size(); i++) {
        std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_stmt[i];
        if (computation_and_nb.second < cse_thresh_) {
          std::vector<PrimExpr> direct_subexprs = DirectSubexpr::GetDirectSubexpressions(
              computation_and_nb.first, IsEligibleComputation,
              [](const PrimExpr& expr) { return true; });
          InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_stmt, direct_subexprs,
                                                   true, computation_and_nb.second);
        }
      }

      // Record the final sub exprs whose repeat count is no less than cse_thresh_.
      // To make the result stable, visit in post order and then sort by complexity.
      PostOrderVisit(store->value, [&semantic_comp_done_by_stmt, this](const ObjectRef& node) {
        if (node->IsInstance<PrimExprNode>()) {
          PrimExpr this_expr = Downcast<PrimExpr>(node);
          for (auto& it : semantic_comp_done_by_stmt) {
            if (it.second >= this->cse_thresh_ && EquivalentTerms(this_expr, it.first, true)) {
              auto find_result =
                  std::find_if(this->exprs_.begin(), this->exprs_.end(),
                               [&](PrimExpr expr) { return expr.get() == it.first.get(); });
              if (find_result == this->exprs_.end()) {
                this->exprs_.push_back(it.first);
              }
            }
          }
        }
      });
      auto cmp = [&](const PrimExpr& lhs, const PrimExpr& rhs) -> bool {
        return CalculateExprComplexity(lhs) > CalculateExprComplexity(rhs);
      };
      std::stable_sort(exprs_.begin(), exprs_.end(), cmp);
    }
    StmtVisitor::VisitStmt_(store);
  }

  /*! \brief The schedule state */
  const ScheduleState self_;
  /*! \brief The target block that reads the target buffer */
  const StmtSRef& block_sref_;
  /*! \brief The parent scope of the target block */
  const StmtSRef& scope_sref_;
  /*! \brief The common subexpression extraction threshold */
  size_t cse_thresh_;
  /*! \brief The index exprs to be precomputed */
  std::vector<PrimExpr> exprs_;
  /*! \brief Whether we have visited the target block */
  bool visited_block_{false};
  /*! \brief Whether we are currently visiting the target block */
  bool visiting_target_block{false};
  /*! \brief The position at which to insert the cache_index stage */
  int loc_pos_{-1};
  /*! \brief Whether we are in the right scope to update the seq pos */
  bool update_seq_pos_{false};
  /*! \brief The ranges of the iter vars */
  Map<Var, Range> range_map_;
};
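
// An illustrative example of the collection above (hypothetical store): with
// cse_thresh = 2, in
//   C[v0, v1] = A[(v0 * 256 + v1) % 97] + B[(v0 * 256 + v1) % 97]
// the expr `(v0 * 256 + v1) % 97` is recorded (pure, complexity > 1, two
// occurrences), while Ramp/Broadcast nodes and trivial single-node exprs are
// never eligible.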

/*!
 * \brief Create the loop nests that write the precomputed indices into the index buffers.
 * \param info The cache stage information, which will be updated in the function.
 * \param storage_scope The storage scope of the cache buffers.
 * \returns The blocks of the cache stage, one per precomputed index expr.
 */
Array<Block> MakeIndexCacheStage(IndexInfo* info, const String& storage_scope) {
  Array<Block> blocks;
  Array<Stmt> bodies;
  bodies.reserve(info->index_exprs.size());
  info->cache_buffer.reserve(info->index_exprs.size());

  // For each index computation, create a block that precomputes it.
  for (size_t expr_index = 0; expr_index < info->index_exprs.size(); expr_index++) {
    const PrimExpr& index_expr = info->index_exprs[expr_index];

    // Collect the block vars in the original index computation
    info->origin_block_vars.push_back({});
    PostOrderVisit(index_expr, [&info, &expr_index](const ObjectRef& node) {
      if (node->IsInstance<VarNode>()) {
        Var iter_var = Downcast<Var>(node);
        const Array<Var>& origin_block_var = info->origin_block_vars[expr_index];
        auto find_result = std::find_if(origin_block_var.begin(), origin_block_var.end(),
                                        [&](Var it) { return it.get() == iter_var.get(); });
        if (find_result == origin_block_var.end()) {
          info->origin_block_vars[expr_index].push_back(iter_var);
        }
      }
    });

    // Collect the loop vars corresponding to the collected block vars,
    // which will be used to create the new loop vars
    std::vector<Var> iter_vars;
    for (const Var& it : info->origin_block_vars[expr_index]) {
      PostOrderVisit(info->var_binding.at(it), [&iter_vars](const ObjectRef& node) {
        if (node->IsInstance<VarNode>()) {
          Var iter_var = Downcast<Var>(node);
          if (std::find_if(iter_vars.begin(), iter_vars.end(),
                           [&](Var it) { return it.get() == iter_var.get(); }) == iter_vars.end()) {
            iter_vars.push_back(iter_var);
          }
        }
      });
    }

    DataType data_type = index_expr.dtype();
    Var index_buffer_var("index_var_" + std::to_string(expr_index),
                         PointerType(PrimType(data_type), storage_scope));
    Array<PrimExpr> buffer_shape;
    for (const Var& it : info->origin_block_vars[expr_index]) {
      buffer_shape.push_back(
          arith::EvalSet(info->var_binding.at(it), arith::AsIntSet(info->range_map)).max() + 1);
    }
    info->cache_buffer.push_back(Buffer(index_buffer_var, data_type, buffer_shape, {1}, {0},
                                        index_buffer_var->name_hint, 0, 0, kDefault));

    // Create loop vars and the binding values of the block vars
    std::vector<Var> loop_vars;
    Map<Var, PrimExpr> replace_table;
    for (const Var& it : iter_vars) {
      DataType data_type = DetermineDatatype(arith::IntSet::FromRange(info->range_map.at(it)));
      Var loop_var("ax" + std::to_string(replace_table.size()), data_type);
      loop_vars.push_back(loop_var);
      replace_table.Set(it, loop_var);
    }
    // Create iter_values from the original block.
    std::vector<PrimExpr> iter_values;
    for (const Var& it : info->origin_block_vars[expr_index]) {
      iter_values.push_back(Substitute(info->var_binding.at(it), replace_table));
    }
    // Block variables
    Array<IterVar> block_vars;
    // Block access region for the write buffer
    Region access_region;
    // Indices used in the block body
    Array<PrimExpr> access_indices;
    Map<Var, PrimExpr> block_var_map;
    // Create the block vars, the block's accessed region and the accessing indices
    for (size_t i = 0; i < info->origin_block_vars[expr_index].size(); i++) {
      const Var& block_var = info->origin_block_vars[expr_index][i];
      Var var("v" + std::to_string(access_indices.size()), block_var.dtype());
      Range range = Range::FromMinExtent(make_zero(block_var.dtype()),
                                         info->range_map.at(iter_vars[i])->extent);
      block_vars.push_back(IterVar(/*dom=*/range,
                                   /*var=*/var,
                                   /*IterVarType=*/kDataPar));

      access_indices.push_back(var);
      access_region.push_back(Range::FromMinExtent(var, make_const(var.dtype(), 1)));
      block_var_map.Set(block_var, var);
    }

    // Create the index computing block
    PrimExpr new_expr = Substitute(index_expr, block_var_map);
    Block block(
        /*iter_vars=*/std::move(block_vars),
        /*reads=*/{},
        /*writes=*/{BufferRegion(info->cache_buffer[expr_index], access_region)},
        /*name_hint=*/"index_" + std::to_string(expr_index),
        /*body=*/
        BufferStore(info->cache_buffer[expr_index], new_expr, access_indices),
        /*init=*/NullOpt,
        /*alloc_buffers=*/{},
        /*match_buffers=*/{},
        /*annotations=*/{});
    blocks.push_back(block);
    // Create the block realize node
    Stmt body = BlockRealize(/*values=*/iter_values,
                             /*predicate=*/const_true(),
                             /*block=*/block);
    // Create the surrounding loops
    for (size_t i = loop_vars.size(); i >= 1; --i) {
      body = For(/*loop_var=*/loop_vars[i - 1],
                 /*min=*/0,
                 /*extent=*/info->range_map.at(iter_vars[i - 1])->extent,
                 /*kind=*/ForKind::kSerial,
                 /*body=*/body);
    }
    bodies.push_back(body);
  }

  info->cache_stage = SeqStmt(bodies);
  return blocks;
}
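
// An illustrative sketch (hypothetical names and extents) of one generated stage,
// in TVMScript-like pseudocode, for the cached expr `(v0 * 256 + v1) % 97`:
//
//   for ax0, ax1 in T.grid(16, 256):
//     with T.block("index_0"):
//       v0, v1 = T.axis.remap("SS", [ax0, ax1])
//       index_var_0[v0, v1] = (v0 * 256 + v1) % 97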

/*!
 * \brief Insert the cache stage into the specific position
 * \param stmt A sequence of statements, or a single statement, into which the new stage is
 * inserted
 * \param pos The position where the cache stage is inserted
 * \param stage The stage to be inserted
 * \return A SeqStmt, the result after insertion
 */
Stmt InsertIndexStage(const Stmt& stmt, int pos, const Stmt& stage) {
  if (const auto* seq_stmt = stmt.as<SeqStmtNode>()) {
    ObjectPtr<SeqStmtNode> result = make_object<SeqStmtNode>(*seq_stmt);
    result->seq.insert(result->seq.begin() + pos, stage);
    return SeqStmt(result);
  }
  if (pos == 0) {
    return SeqStmt::Flatten<Array<Stmt>>({stage, stmt});
  }
  ICHECK_EQ(pos, 1);
  return SeqStmt::Flatten<Array<Stmt>>({stmt, stage});
}
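
// In other words: within an existing SeqStmt the stage lands at index `pos`;
// around a single statement, pos == 0 prepends the stage and pos == 1 appends it,
// and any other position is a schedule-internal error.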

/*! \brief The mutator for CacheIndex. */
class CacheIndexRewriter : public StmtExprMutator {
 public:
  /*!
   * \brief Rewrite the AST and add the stages that write the precomputed indices
   * \param scope_sref The parent scope of this mutation
   * \param info The index information
   * \return The new AST rooted at the original parent scope
   */
  static Stmt Rewrite(const StmtSRef& scope_sref, IndexInfo* info) {
    CacheIndexRewriter rewriter(scope_sref, info);
    return rewriter(GetRef<Stmt>(scope_sref->stmt));
  }

 private:
  explicit CacheIndexRewriter(const StmtSRef& scope_sref, IndexInfo* info)
      : scope_sref_(scope_sref), info_(info) {
    cache_indices_.reserve(info_->origin_block_vars.size());
    for (const Array<Var>& group_it : info_->origin_block_vars) {
      cache_indices_.push_back({});
      for (const Var& it : group_it) {
        cache_indices_.back().push_back(it);
      }
    }
  }

  Stmt VisitStmt_(const BlockNode* block) final {
    Block old_stmt = GetRef<Block>(block);
    // Mutate the body
    visiting_target_block = static_cast<bool>(block == info_->target_block->stmt);
    Block stmt = Downcast<Block>(StmtMutator::VisitStmt_(block));
    visiting_target_block = false;

    // Check if this is the block corresponding to the parent scope
    if (block == scope_sref_->stmt) {
      // If so, allocate the cache buffers and insert the cache stages on the parent scope
      ObjectPtr<BlockNode> n = make_object<BlockNode>(*stmt.as<BlockNode>());
      n->body = InsertIndexStage(n->body, info_->loc_pos, info_->cache_stage);
      for (const Buffer& it : info_->cache_buffer) {
        n->alloc_buffers.push_back(it);
      }
      stmt = Block(n);
    }
    info_->block_reuse.Set(old_stmt, stmt);
    return std::move(stmt);
  }

  Stmt VisitStmt_(const BufferStoreNode* store) final {
    Stmt ret_stmt = StmtMutator::VisitStmt_(store);
    // Replace the common subexprs in the target block with loads from the cache buffers
    if (visiting_target_block) {
      for (size_t i = 0; i < info_->index_exprs.size(); i++) {
        PrimExpr& computation = info_->index_exprs[i];
        std::function<bool(const PrimExpr&)> predicate_selector =
            [computation](const PrimExpr& current_expr) {
              return (EquivalentTerms(current_expr, computation, true));
            };
        BufferLoad load = BufferLoad(info_->cache_buffer[i], cache_indices_[i]);
        ret_stmt = ReplaceSelectedExpr::ReplaceSelectedExprInStmt(
            ret_stmt, predicate_selector, std::move(load),
            [](const PrimExpr& expr) { return true; });
      }
    }
    return ret_stmt;
  }

  PrimExpr VisitExpr_(const LoadNode* op) final {
    LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
  }

 private:
  /*! \brief The parent scope of the insertion */
  const StmtSRef& scope_sref_;
  /*! \brief The info for inserting the cache stage */
  IndexInfo* info_;
  /*! \brief The indices into the cache buffers */
  std::vector<Array<PrimExpr>> cache_indices_;
  /*! \brief Whether we are currently visiting the target block; index replacement
   *  only happens there */
  bool visiting_target_block{false};
};
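
// For example (hypothetical exprs): if `(v0 * 256 + v1) % 97` was cached into
// index_var_0, an occurrence written as `(v1 + v0 * 256) % 97` is rewritten to
// `index_var_0[v0, v1]` as well, assuming EquivalentTerms identifies the two
// as semantically equal.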

Array<StmtSRef> CacheIndex(ScheduleState self, const StmtSRef& block_sref,
                           const String& storage_scope, int cse_thresh) {
  /*!
   * Check:
   * - The index is in the array of the block reading region
   *
   * Mutate:
   * - Allocate new cache buffers under the current scope.
   * - Precompute the indices and store them in the cache buffers.
   */

  // Step 0. Check the inputs and get the parent scope
  IndexInfo info;
  info.target_block = block_sref;
  CHECK_GE(cse_thresh, 0) << "cse_thresh should not be a negative number";
  info.cse_thresh = cse_thresh;
  StmtSRef scope_sref = GetScopeRoot(self, block_sref, /*require_stage_pipeline=*/false);

  // Step 1. Collect the index info of the target block.
  IndexInfoCollector::Collect(self, block_sref, scope_sref, &info);

  // Step 2. Create the cache stages and rewrite the stmt.
  BlockRealize realize = GetBlockRealize(self, block_sref);
  info.var_binding = GetBindings(realize);
  Array<Block> cache_stages = MakeIndexCacheStage(&info, storage_scope);
  Stmt new_scope = CacheIndexRewriter::Rewrite(/*scope_sref=*/scope_sref, /*info=*/&info);

  bool old_stage_pipeline = self->block_info[block_sref].scope->stage_pipeline;

  // Step 3. Replace the scope and update the flags of the new blocks.
  self->Replace(scope_sref, new_scope, info.block_reuse);
  Array<StmtSRef> result_block_srefs;
  for (const Block& it : cache_stages) {
    StmtSRef result_block_sref = self->stmt2ref.at(it.get());
    result_block_srefs.push_back(result_block_sref);
    BlockInfo& block_info = self->block_info[result_block_sref];

    bool affine_binding = false;
    if (result_block_sref->parent == nullptr) {
      affine_binding = true;
    } else {
      arith::Analyzer analyzer;
      StmtSRef parent_sref = GetRef<StmtSRef>(result_block_sref->parent);
      affine_binding = IsAffineBinding(/*realize=*/GetBlockRealize(self, result_block_sref),
                                       /*loop_var_ranges=*/LoopDomainOfSRefTreePath(parent_sref),
                                       /*analyzer=*/&analyzer);
    }

    block_info.affine_binding = affine_binding;
    block_info.region_cover = true;
    block_info.scope->stage_pipeline = old_stage_pipeline;
  }

  return result_block_srefs;
}
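
// A hedged usage sketch at the C++ level (the block name "consumer" is
// hypothetical, and the exact Schedule::Concrete signature may vary across
// TVM versions):
//
//   tir::Schedule sch = tir::Schedule::Concrete(
//       mod, /*seed=*/-1, /*debug_mask=*/0, tir::ScheduleErrorRenderLevel::kDetail);
//   BlockRV block = sch->GetBlock("consumer");
//   Array<BlockRV> index_blocks = sch->CacheIndex(block, "global", /*cse_thresh=*/2);
//
// This corresponds to `sch.cache_index(block, storage_scope, cse_thresh)` in the
// Python API, as printed by UnpackedAsPython below.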

/******** InstructionKind Registration ********/

struct CacheIndexTraits : public UnpackedInstTraits<CacheIndexTraits> {
  static constexpr const char* kName = "CacheIndex";
  static constexpr bool kIsPure = false;

 private:
  static constexpr size_t kNumInputs = 1;
  static constexpr size_t kNumAttrs = 2;
  static constexpr size_t kNumDecisions = 0;

  static Array<BlockRV> UnpackedApplyToSchedule(Schedule sch, BlockRV block, String storage_scope,
                                                Integer cse_thresh) {
    return sch->CacheIndex(block, storage_scope, cse_thresh->value);
  }

  static String UnpackedAsPython(Array<String> outputs, String block, String storage_scope,
                                 Integer cse_thresh) {
    PythonAPICall py("cache_index");
    py.Input("block", block);
    py.Input("storage_scope", storage_scope);
    py.Input("cse_thresh", cse_thresh->value);
    py.OutputList(outputs);
    return py.Str();
  }

  template <typename>
  friend struct ::tvm::tir::UnpackedInstTraits;
};

TVM_REGISTER_INST_KIND_TRAITS(CacheIndexTraits);

}  // namespace tir
}  // namespace tvm