1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/arith/iter_affine_map.h
22 * \brief Iterator quasi-affine mapping patterns.
23 *
 * This file defines a collection of mapping patterns
 * that map a collection of independent iterators to another
 * collection of independent iterators.
27 *
28 * There are two main kinds of mapping patterns:
29 *
30 * - Fuse: fuse a collection of iterators into a single one
31 *
32 * domain(x0) = [0, 4), domain(x1) = [0, 3), domain(x2) = [0, 2)
33 * fuse(x0, x1, x2): y = x2 * 12 + x1 * 4 + x0
34 * domain(y) = [0, 24)
35 *
36 * - Split: split an iterator into multiple ones
37 *
38 * domain(x) = [0, 24)
39 * split(x, 3, 12): [y0, y1, y2] = [x % 3, (x % 12) / 3, x / 12]
40 * domain(y0) = [0, 3), domain(y1) = [0, 4), domain(y2) = [0, 2)
41 *
42 * We use the name "(quasi)affine" to be consistent with
43 * the terminology used in the polyhedral compilation.
44 * Notably, fuse is an affine transformation,
45 * while split corresponds to additional floordiv/mod operations
46 * that can appear in quasi-affine transformations.
47 */
48#ifndef TVM_ARITH_ITER_AFFINE_MAP_H_
49#define TVM_ARITH_ITER_AFFINE_MAP_H_
50
51#include <tvm/arith/analyzer.h>
52#include <tvm/ir/diagnostic.h>
53#include <tvm/ir/expr.h>
54#include <tvm/tir/var.h>
55
56namespace tvm {
57namespace arith {
58
59/*!
60 * \brief Base class of all iter map expressions.
61 *
62 * An IterMapExpr is a special expression to store
63 * the result of IterMapDetection.
64 * It should not appear in a legal TIR PrimFunc.
65 */
66class IterMapExprNode : public PrimExprNode {
67 public:
68 // overrides
69 void VisitAttrs(tvm::AttrVisitor* v) {}
70
71 static constexpr const char* _type_key = "arith.IterMapExpr";
72 static constexpr const uint32_t _type_child_slots = 3;
73 TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode);
74};
75
76/*!
77 * \brief Managed reference to IterMapExprNode.
78 * \sa IterMapExprNode
79 */
class IterMapExpr : public PrimExpr {
 public:
  // Object-reference boilerplate associating this handle with IterMapExprNode.
  TVM_DEFINE_OBJECT_REF_METHODS(IterMapExpr, PrimExpr, IterMapExprNode);
};
84
85/*!
86 * \brief Mark the source as an iterator in [0, extent).
87 *
88 * IterMark is used to mark source expression as a valid
89 * iterator to make future analysis easy.
90 */
91class IterMarkNode : public Object {
92 public:
93 /*!
94 * \brief The source expression, can either be
95 * a IterSumExpr or a Var.
96 */
97 PrimExpr source;
98 /*!
99 * \brief The extent of the iteration.
100 */
101 PrimExpr extent;
102
103 // overrides
104 void VisitAttrs(tvm::AttrVisitor* v) {
105 v->Visit("source", &source);
106 v->Visit("extent", &extent);
107 }
108
109 bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const {
110 equal->MarkGraphNode();
111 return equal(source, other->source) && equal(extent, other->extent);
112 }
113
114 void SHashReduce(SHashReducer hash_reduce) const {
115 hash_reduce->MarkGraphNode();
116 hash_reduce(source);
117 hash_reduce(extent);
118 }
119
120 static constexpr const bool _type_has_method_sequal_reduce = true;
121 static constexpr const bool _type_has_method_shash_reduce = true;
122 static constexpr const char* _type_key = "arith.IterMark";
123 TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object);
124};
125
126/*!
 * \brief Managed reference to IterMarkNode.
 * \sa IterMarkNode
129 */
class IterMark : public ObjectRef {
 public:
  /*!
   * \brief Construct a mark from a source expression and its extent.
   * \param source The source expression.
   * \param extent The extent of the iterator.
   */
  TVM_DLL IterMark(PrimExpr source, PrimExpr extent);

  // Object-reference and copy-on-write boilerplate for IterMarkNode.
  TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMarkNode);
};
142
143/*!
144 * \brief Split of an iterator.
145 *
146 * result = floormod(floordiv(source, lower_factor), extent) * scale
147 */
148class IterSplitExprNode : public IterMapExprNode {
149 public:
150 /*! \brief The source marked iterator. */
151 IterMark source;
152 /*! \brief The lower factor to split the source. */
153 PrimExpr lower_factor;
154 /*! \brief The extent of the split. */
155 PrimExpr extent;
156 /*! \brief Additional scale. */
157 PrimExpr scale;
158
159 // overrides
160 void VisitAttrs(tvm::AttrVisitor* v) {
161 v->Visit("source", &source);
162 v->Visit("lower_factor", &lower_factor);
163 v->Visit("extent", &extent);
164 v->Visit("scale", &scale);
165 }
166
167 bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const {
168 return equal(source, other->source) && equal(lower_factor, other->lower_factor) &&
169 equal(extent, other->extent) && equal(scale, other->scale);
170 }
171
172 void SHashReduce(SHashReducer hash_reduce) const {
173 hash_reduce(source);
174 hash_reduce(lower_factor);
175 hash_reduce(extent);
176 hash_reduce(scale);
177 }
178
179 static constexpr const char* _type_key = "arith.IterSplitExpr";
180 TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode);
181};
182
183/*!
184 * \brief Managed reference to IterSplitExprNode.
185 * \sa IterSplitExprNode
186 */
class IterSplitExpr : public IterMapExpr {
 public:
  /*!
   * \brief constructor from just source.
   * \param source The source expression.
   */
  TVM_DLL explicit IterSplitExpr(IterMark source);
  /*!
   * \brief constructor from source and scale.
   * \param source The source expression.
   * \param scale The additional scaling factor.
   */
  TVM_DLL explicit IterSplitExpr(IterMark source, PrimExpr scale);
  /*!
   * \brief constructor specifying every field.
   * \param source The source expression.
   * \param lower_factor The lower factor to split the source.
   * \param extent The extent of the split.
   * \param scale The additional scaling factor.
   */
  TVM_DLL explicit IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent,
                                 PrimExpr scale);

  // Object-reference and copy-on-write boilerplate for IterSplitExprNode.
  TVM_DEFINE_OBJECT_REF_METHODS(IterSplitExpr, IterMapExpr, IterSplitExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSplitExprNode);
};
213
214/*!
215 * \brief Fuse multiple iterators by summing them with scaling.
216 *
217 * result = sum(args) + base
218 */
219class IterSumExprNode : public IterMapExprNode {
220 public:
221 /*! \brief The args to the sum. */
222 Array<IterSplitExpr> args;
223 /*! \brief The base offset. */
224 PrimExpr base;
225
226 // overrides
227 void VisitAttrs(tvm::AttrVisitor* v) {
228 v->Visit("args", &args);
229 v->Visit("base", &base);
230 }
231
232 bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const {
233 return equal(args, other->args) && equal(base, other->base);
234 }
235
236 void SHashReduce(SHashReducer hash_reduce) const {
237 hash_reduce(args);
238 hash_reduce(base);
239 }
240
241 static constexpr const char* _type_key = "arith.IterSumExpr";
242 TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode);
243};
244
245/*!
246 * \brief Managed reference to IterSumExprNode.
247 * \sa IterSumExprNode
248 */
class IterSumExpr : public IterMapExpr {
 public:
  /*!
   * \brief Construct a sum expression from its arguments and base offset.
   * \param args The args to the sum.
   * \param base The base offset.
   */
  TVM_DLL IterSumExpr(Array<IterSplitExpr> args, PrimExpr base);

  // Object-reference and copy-on-write boilerplate for IterSumExprNode.
  TVM_DEFINE_OBJECT_REF_METHODS(IterSumExpr, IterMapExpr, IterSumExprNode);
  TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode);
};
261
262/*! \brief Mapping level for iterators. */
enum IterMapLevel {
  /*! \brief Require the mapping to be bijective. */
  Bijective = 0,
  /*! \brief Require the mapping to be surjective. */
  Surjective = 1,
  /*! \brief No mapping safety check. */
  NoCheck = 3
};
271
272/*!
273 * \brief Result of DetectIterMap.
274 */
class IterMapResultNode : public Object {
 public:
  /*! \brief The detected pattern if a match exists. */
  Array<IterSumExpr> indices;

  /*!
   * \brief Any errors that occurred while converting the input indices.
   *
   * If the array is empty, the conversion was successful.
   */
  Array<String> errors;

  /*!
   * \brief Boolean expression indicating whether a given set of indices
   * falls in the padded region.
   *
   * `padding_predicate` evaluates to true for a set of indices that
   * are outside the bounds of the provided index iterators, but
   * inside the bounds of the returned index iterators.  This
   * expression is in terms of the variables provided in
   * `input_iters`.
   */
  PrimExpr padding_predicate;

  /*! \brief Expose `errors`, `indices` and `padding_predicate` to the attribute visitor. */
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("errors", &errors);
    v->Visit("indices", &indices);
    v->Visit("padding_predicate", &padding_predicate);
  }

  static constexpr const char* _type_key = "arith.IterMapResult";
  TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object);
};
304
305/*!
306 * \brief Managed reference to IterMapResultNode.
307 * \sa IterMapResultNode
308 */
309class IterMapResult : public ObjectRef {
310 public:
311 // constructor
312 IterMapResult() { data_ = make_object<IterMapResultNode>(); }
313
314 /*! \return mutable pointers to the node. */
315 IterMapResultNode* operator->() const { return static_cast<IterMapResultNode*>(get_mutable()); }
316};
317
/*!
 * \brief Detect if indices can be written as
 *  [y_0 + c_0, y_1 + c_1, ..., y_n + c_n]
 *
 *  Here y = some-quasi-affine-iter-map(input_iters)
 *  and c are symbolic constants.
 *
 *  We also require y_i and y_j to be independent for i != j.
 *
 *  For the returned value rv, the following is always true:
 *  - rv->indices[i]->args.size() <= 1: only one iterator per element.
 *
 * \param indices The indices to detect pattern for.
 * \param input_iters Map from variable to iterator's range.
 * \param predicate The predicate constraints on the input iterators.
 * \param check_level The iter mapping checking level.
 * \param analyzer Analyzer used to get context information.
 * \param simplify_trivial_iterators If true, iterators with extent of
 *        1 will be replaced with a constant value.
 *
 * \return The detected iteration result.
 *         The returned object's `indices` is empty on failure.
 */
IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
                            const PrimExpr& predicate, IterMapLevel check_level,
                            arith::Analyzer* analyzer, bool simplify_trivial_iterators = true);
344
/*!
 * \brief Use the IterVarMap detector to rewrite and simplify the indices.
 *
 * \param indices The indices to detect pattern for.
 * \param input_iters Map from variable to iterator's range.
 * \param input_pred The predicate constraints on the input iterators.
 * \param check_level The iter mapping checking level.
 * \param simplify_trivial_iterators If true, iterators with unit extents are simplified.
 * \return The indices after rewrite.
 */
Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters,
                                const PrimExpr& input_pred, IterMapLevel check_level,
                                bool simplify_trivial_iterators = true);
358
/*!
 * \brief Apply the inverse of the affine transformation to the outputs.
 *
 * Similar to back-propagation: starting from the outputs, it visits the DAG of the expressions
 * in reverse topological order and applies the inverse of the affine transformation until it
 * reaches the input. The affine iter map is required to be bijective.
 *
 * For example, with iter_map = [l0 // 16, l0 % 16] and outputs = [output_0, output_1],
 * the inverse of the affine transformation specified by `iter_map` is applied to `outputs`
 * and the result is {l0: ((output_0*16) + output_1)}.
 *
 * The range of `outputs` should be the same as the output range of the affine transformation.
 *
 * \sa DetectIterMap
 *
 * \param iter_map The bijective affine iter map.
 * \param outputs The outputs of the affine transformation.
 *
 * \return The map from the input to the transformed result.
 */
Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map,
                                        const Array<PrimExpr> outputs);
381
/*!
 * \brief Detect if bindings can be written as
 *  [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
 *
 *  where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
 *        b = some-quasi-affine-iter-map(sub_iters)
 *        c are constant symbols
 *        e is the extent of b
 *
 *  For example, z*12 + y*3 + x + c = (z*4+y)*3 + x + c, if sub_iters={x}
 *
 * \param bindings The input bindings.
 * \param input_iters Map from variable to iterator's range.
 * \param sub_iters Iterators of subspace.
 * \param predicate The predicate constraints on the input iterators.
 * \param check_level The iter mapping checking level.
 * \param analyzer Analyzer used to get context information.
 * \param simplify_trivial_iterators If true, iterators with extent of
 *        1 will be replaced with a constant value.
 *
 * \return The result list has length len(bindings) + 1
 *         [0, len(bindings)): The iter map matching result. The inner list is of length 2.
 *                             The first expr is the basis of the quotient space.
 *                             The second expr is the basis of the subspace.
 *         len(bindings): the predicate of outer space and inner space
 *         Empty array if no match can be found.
 */
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
                                      const Map<Var, Range>& input_iters,
                                      const Array<Var>& sub_iters, const PrimExpr& predicate,
                                      IterMapLevel check_level, arith::Analyzer* analyzer,
                                      bool simplify_trivial_iterators = true);
414
/*!
 * \brief Given an expression that may contain IterMapExpr, transform it into a normal PrimExpr.
 * \param expr The input expression, which may contain IterMapExpr.
 * \return The corresponding normal PrimExpr.
 */
PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr);
421
422} // namespace arith
423} // namespace tvm
424#endif // TVM_ARITH_ITER_AFFINE_MAP_H_
425