1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/arith/iter_affine_map.h |
22 | * \brief Iterator quasi-affine mapping patterns. |
23 | * |
 * This file defines a collection of mapping patterns that
 * map a collection of independent iterators to another
 * collection of independent iterators.
27 | * |
28 | * There are two main kinds of mapping patterns: |
29 | * |
30 | * - Fuse: fuse a collection of iterators into a single one |
31 | * |
32 | * domain(x0) = [0, 4), domain(x1) = [0, 3), domain(x2) = [0, 2) |
33 | * fuse(x0, x1, x2): y = x2 * 12 + x1 * 4 + x0 |
34 | * domain(y) = [0, 24) |
35 | * |
36 | * - Split: split an iterator into multiple ones |
37 | * |
38 | * domain(x) = [0, 24) |
39 | * split(x, 3, 12): [y0, y1, y2] = [x % 3, (x % 12) / 3, x / 12] |
40 | * domain(y0) = [0, 3), domain(y1) = [0, 4), domain(y2) = [0, 2) |
41 | * |
 * We use the name "(quasi)affine" to be consistent with
 * the terminology used in polyhedral compilation.
 * Notably, fuse is an affine transformation,
 * while split corresponds to additional floordiv/mod operations
 * that can appear in quasi-affine transformations.
47 | */ |
48 | #ifndef TVM_ARITH_ITER_AFFINE_MAP_H_ |
49 | #define TVM_ARITH_ITER_AFFINE_MAP_H_ |
50 | |
51 | #include <tvm/arith/analyzer.h> |
52 | #include <tvm/ir/diagnostic.h> |
53 | #include <tvm/ir/expr.h> |
54 | #include <tvm/tir/var.h> |
55 | |
56 | namespace tvm { |
57 | namespace arith { |
58 | |
59 | /*! |
60 | * \brief Base class of all iter map expressions. |
61 | * |
62 | * An IterMapExpr is a special expression to store |
63 | * the result of IterMapDetection. |
64 | * It should not appear in a legal TIR PrimFunc. |
65 | */ |
66 | class IterMapExprNode : public PrimExprNode { |
67 | public: |
68 | // overrides |
69 | void VisitAttrs(tvm::AttrVisitor* v) {} |
70 | |
71 | static constexpr const char* _type_key = "arith.IterMapExpr" ; |
72 | static constexpr const uint32_t _type_child_slots = 3; |
73 | TVM_DECLARE_BASE_OBJECT_INFO(IterMapExprNode, PrimExprNode); |
74 | }; |
75 | |
76 | /*! |
77 | * \brief Managed reference to IterMapExprNode. |
78 | * \sa IterMapExprNode |
79 | */ |
80 | class IterMapExpr : public PrimExpr { |
81 | public: |
82 | TVM_DEFINE_OBJECT_REF_METHODS(IterMapExpr, PrimExpr, IterMapExprNode); |
83 | }; |
84 | |
85 | /*! |
86 | * \brief Mark the source as an iterator in [0, extent). |
87 | * |
88 | * IterMark is used to mark source expression as a valid |
89 | * iterator to make future analysis easy. |
90 | */ |
91 | class IterMarkNode : public Object { |
92 | public: |
93 | /*! |
94 | * \brief The source expression, can either be |
95 | * a IterSumExpr or a Var. |
96 | */ |
97 | PrimExpr source; |
98 | /*! |
99 | * \brief The extent of the iteration. |
100 | */ |
101 | PrimExpr extent; |
102 | |
103 | // overrides |
104 | void VisitAttrs(tvm::AttrVisitor* v) { |
105 | v->Visit("source" , &source); |
106 | v->Visit("extent" , &extent); |
107 | } |
108 | |
109 | bool SEqualReduce(const IterMarkNode* other, SEqualReducer equal) const { |
110 | equal->MarkGraphNode(); |
111 | return equal(source, other->source) && equal(extent, other->extent); |
112 | } |
113 | |
114 | void SHashReduce(SHashReducer hash_reduce) const { |
115 | hash_reduce->MarkGraphNode(); |
116 | hash_reduce(source); |
117 | hash_reduce(extent); |
118 | } |
119 | |
120 | static constexpr const bool _type_has_method_sequal_reduce = true; |
121 | static constexpr const bool _type_has_method_shash_reduce = true; |
122 | static constexpr const char* _type_key = "arith.IterMark" ; |
123 | TVM_DECLARE_FINAL_OBJECT_INFO(IterMarkNode, Object); |
124 | }; |
125 | |
126 | /*! |
127 | * \brief Managed reference to IterMarkExprNode. |
128 | * \sa IterMarkExprNode |
129 | */ |
130 | class IterMark : public ObjectRef { |
131 | public: |
132 | /*! |
133 | * \brief constructor. |
134 | * \param source The source expression. |
135 | * \param extent The extent of the iterator. |
136 | */ |
137 | TVM_DLL IterMark(PrimExpr source, PrimExpr extent); |
138 | |
139 | TVM_DEFINE_OBJECT_REF_METHODS(IterMark, ObjectRef, IterMarkNode); |
140 | TVM_DEFINE_OBJECT_REF_COW_METHOD(IterMarkNode); |
141 | }; |
142 | |
143 | /*! |
144 | * \brief Split of an iterator. |
145 | * |
146 | * result = floormod(floordiv(source, lower_factor), extent) * scale |
147 | */ |
148 | class IterSplitExprNode : public IterMapExprNode { |
149 | public: |
150 | /*! \brief The source marked iterator. */ |
151 | IterMark source; |
152 | /*! \brief The lower factor to split the source. */ |
153 | PrimExpr lower_factor; |
154 | /*! \brief The extent of the split. */ |
155 | PrimExpr extent; |
156 | /*! \brief Additional scale. */ |
157 | PrimExpr scale; |
158 | |
159 | // overrides |
160 | void VisitAttrs(tvm::AttrVisitor* v) { |
161 | v->Visit("source" , &source); |
162 | v->Visit("lower_factor" , &lower_factor); |
163 | v->Visit("extent" , &extent); |
164 | v->Visit("scale" , &scale); |
165 | } |
166 | |
167 | bool SEqualReduce(const IterSplitExprNode* other, SEqualReducer equal) const { |
168 | return equal(source, other->source) && equal(lower_factor, other->lower_factor) && |
169 | equal(extent, other->extent) && equal(scale, other->scale); |
170 | } |
171 | |
172 | void SHashReduce(SHashReducer hash_reduce) const { |
173 | hash_reduce(source); |
174 | hash_reduce(lower_factor); |
175 | hash_reduce(extent); |
176 | hash_reduce(scale); |
177 | } |
178 | |
179 | static constexpr const char* _type_key = "arith.IterSplitExpr" ; |
180 | TVM_DECLARE_FINAL_OBJECT_INFO(IterSplitExprNode, IterMapExprNode); |
181 | }; |
182 | |
183 | /*! |
184 | * \brief Managed reference to IterSplitExprNode. |
185 | * \sa IterSplitExprNode |
186 | */ |
187 | class IterSplitExpr : public IterMapExpr { |
188 | public: |
189 | /*! |
190 | * \brief constructor from just source. |
191 | * \param source The source expression. |
192 | */ |
193 | TVM_DLL explicit IterSplitExpr(IterMark source); |
194 | /*! |
195 | * \brief constructor from just source. |
196 | * \param source The source expression. |
197 | * \param scale The additional scaling factor. |
198 | */ |
199 | TVM_DLL explicit IterSplitExpr(IterMark source, PrimExpr scale); |
200 | /*! |
201 | * \brief constructor |
202 | * \param source The source expression. |
203 | * \param lower_factor The lower factor to split the source. |
204 | * \param extent The extent of the split. |
205 | * \param scale The additional scaling factor. |
206 | */ |
207 | TVM_DLL explicit IterSplitExpr(IterMark source, PrimExpr lower_factor, PrimExpr extent, |
208 | PrimExpr scale); |
209 | |
210 | TVM_DEFINE_OBJECT_REF_METHODS(IterSplitExpr, IterMapExpr, IterSplitExprNode); |
211 | TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSplitExprNode); |
212 | }; |
213 | |
214 | /*! |
215 | * \brief Fuse multiple iterators by summing them with scaling. |
216 | * |
217 | * result = sum(args) + base |
218 | */ |
219 | class IterSumExprNode : public IterMapExprNode { |
220 | public: |
221 | /*! \brief The args to the sum. */ |
222 | Array<IterSplitExpr> args; |
223 | /*! \brief The base offset. */ |
224 | PrimExpr base; |
225 | |
226 | // overrides |
227 | void VisitAttrs(tvm::AttrVisitor* v) { |
228 | v->Visit("args" , &args); |
229 | v->Visit("base" , &base); |
230 | } |
231 | |
232 | bool SEqualReduce(const IterSumExprNode* other, SEqualReducer equal) const { |
233 | return equal(args, other->args) && equal(base, other->base); |
234 | } |
235 | |
236 | void SHashReduce(SHashReducer hash_reduce) const { |
237 | hash_reduce(args); |
238 | hash_reduce(base); |
239 | } |
240 | |
241 | static constexpr const char* _type_key = "arith.IterSumExpr" ; |
242 | TVM_DECLARE_FINAL_OBJECT_INFO(IterSumExprNode, IterMapExprNode); |
243 | }; |
244 | |
245 | /*! |
246 | * \brief Managed reference to IterSumExprNode. |
247 | * \sa IterSumExprNode |
248 | */ |
249 | class IterSumExpr : public IterMapExpr { |
250 | public: |
251 | /*! |
252 | * \brief constructor. |
253 | * \param args The args to the sum. |
254 | * \param base The base offset. |
255 | */ |
256 | TVM_DLL IterSumExpr(Array<IterSplitExpr> args, PrimExpr base); |
257 | |
258 | TVM_DEFINE_OBJECT_REF_METHODS(IterSumExpr, IterMapExpr, IterSumExprNode); |
259 | TVM_DEFINE_OBJECT_REF_COW_METHOD(IterSumExprNode); |
260 | }; |
261 | |
/*!
 * \brief How strictly a detected iterator mapping is validated.
 *
 * NOTE(review): value 2 is unused. Keep these integer values stable in case
 * they are mirrored outside C++ (e.g. language bindings) — TODO confirm.
 */
enum IterMapLevel {
  /*! \brief The mapping must be bijective. */
  Bijective = 0,
  /*! \brief The mapping must be surjective. */
  Surjective = 1,
  /*! \brief Skip the mapping safety check. */
  NoCheck = 3
};
271 | |
272 | /*! |
273 | * \brief Result of DetectIterMap. |
274 | */ |
275 | class IterMapResultNode : public Object { |
276 | public: |
277 | // The detected pattern if a match exists. |
278 | Array<IterSumExpr> indices; |
279 | |
280 | // Any errors that occurred while converting the input indices. If |
281 | // the array is empty, the conversion was successful. |
282 | Array<String> errors; |
283 | |
284 | /*! \brief Boolean expression indicating if a specific value w |
285 | * |
286 | * `padding_predicate` evaluates to true for a set of indices that |
287 | * are outside the bounds of the provided index iterators, but |
288 | * inside the bounds of the returned index iterators. This |
289 | * expression is in terms of the variables provided in |
290 | * `input_iters`. |
291 | */ |
292 | PrimExpr padding_predicate; |
293 | |
294 | // overrides |
295 | void VisitAttrs(tvm::AttrVisitor* v) { |
296 | v->Visit("errors" , &errors); |
297 | v->Visit("indices" , &indices); |
298 | v->Visit("padding_predicate" , &padding_predicate); |
299 | } |
300 | |
301 | static constexpr const char* _type_key = "arith.IterMapResult" ; |
302 | TVM_DECLARE_FINAL_OBJECT_INFO(IterMapResultNode, Object); |
303 | }; |
304 | |
305 | /*! |
306 | * \brief Managed reference to IterMapResultNode. |
307 | * \sa IterMapResultNode |
308 | */ |
309 | class IterMapResult : public ObjectRef { |
310 | public: |
311 | // constructor |
312 | IterMapResult() { data_ = make_object<IterMapResultNode>(); } |
313 | |
314 | /*! \return mutable pointers to the node. */ |
315 | IterMapResultNode* operator->() const { return static_cast<IterMapResultNode*>(get_mutable()); } |
316 | }; |
317 | |
318 | /*! |
319 | * \brief Detect if indices can be written as |
320 | * [y_0 + c_0, y_1 + c_1, ..., y_n + c_n] |
321 | * |
322 | * Here y = some-quasi-affine-iter-map(input_iters) |
323 | * and c are symbolic constants. |
324 | * |
325 | * We also requires that y_i and y_j to be independent for i != j. |
326 | * |
327 | * For returned value rv, the following is always true: |
328 | * - rv[i]->args.size() <=1: only one iterator per element. |
329 | * |
330 | * \param indices The indices to detect pattern for. |
331 | * \param input_iters Map from variable to iterator's range. |
332 | * \param predicate The predicate constraints on the input iterators |
333 | * \param check_level The iter mapping checking level. |
334 | * \param analyzer Analyzer used to get context information. |
335 | * \param simplify_trivial_iterators If true, iterators with extent of |
336 | * 1 will be replaced with a constant value. |
337 | * |
338 | * \return The detected iteration result. |
339 | * The return object's .indices is empty on failure. |
340 | */ |
341 | IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters, |
342 | const PrimExpr& predicate, IterMapLevel check_level, |
343 | arith::Analyzer* analyzer, bool simplify_trivial_iterators = true); |
344 | |
345 | /*! |
346 | * \brief Use IterVarMap detector to rewrite and simplify the indices |
347 | * |
348 | * \param indices The indices to detect pattern for. |
349 | * \param input_iters Map from variable to iterator's range. |
350 | * \param input_pred The predicate constraints on the input iterators |
351 | * \param check_level The iter mapping checking level. |
352 | * \param simplify_trivial_iterators If true, iterators with unit extents are simplified |
353 | * \return The indices after rewrite |
354 | */ |
355 | Array<PrimExpr> IterMapSimplify(const Array<PrimExpr>& indices, const Map<Var, Range>& input_iters, |
356 | const PrimExpr& input_pred, IterMapLevel check_level, |
357 | bool simplify_trivial_iterators = true); |
358 | |
359 | /*! |
360 | * \brief Apply the inverse of the affine transformation to the outputs. |
361 | * |
362 | * Similar to the back-propagation, starting from the outputs, it visits the DAG of the expressions |
363 | * in reverse topology order and applies the inverse of the affine transformation until it reaches |
364 | * the input. The affine iter map is required to be bijective. |
365 | * |
366 | * For example, iter_map = [l0 // 16, l0 % 16], outputs = [output_0, output_1], |
367 | * the affine transformation specified by `iter_map` will be applied to `outputs` and the result |
368 | * will be {l0: ((output_0*16) + output_1)}. |
369 | * |
370 | * The range of `outputs` should be the same as the output range of the affine transmation. |
371 | * |
372 | * \sa DetectIterMap |
373 | * |
374 | * \param iter_map The bijective affine iter map. |
375 | * \param outputs The outputs of the affine transformation. |
376 | * |
377 | * \return The map from the input to the transformed result. |
378 | */ |
379 | Map<Var, PrimExpr> InverseAffineIterMap(const Array<IterSumExpr>& iter_map, |
380 | const Array<PrimExpr> outputs); |
381 | |
382 | /*! |
383 | * \brief Detect if bindings can be written as |
384 | * [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n] |
385 | * |
386 | * where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters) |
387 | * b = some-quasi-affine-iter-map(sub_iters) |
388 | * c is constant symbols |
389 | * e is the extent of b |
390 | * |
391 | * For example, z*12 + y*3 + x + c = (z*4+y)*3 + x, if sub_iters={x} |
392 | * |
393 | * \param bindings The input bindings |
394 | * \param input_iters Map from variable to iterator's range. |
395 | * \param sub_iters Iterators of subspace. |
396 | * \param predicate The predicate constraints on the input iterators |
397 | * \param check_level The iter mapping checking level. |
398 | * \param analyzer Analyzer used to get context information. |
399 | * \param simplify_trivial_iterators If true, iterators with extent of |
400 | * 1 will be replaced with a constant value. |
401 | * |
402 | * \return The result list has length len(bindings) + 1 |
403 | [0, len(bindings)): The iter map matching result. The inner list is of length 2. |
404 | The first expr is the basis of the quotient space. |
405 | The second expr is the basis of the subspace. |
406 | len(bindings): the predicate of outer space and inner space |
407 | Empty array if no match can be found. |
408 | */ |
409 | Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings, |
410 | const Map<Var, Range>& input_iters, |
411 | const Array<Var>& sub_iters, const PrimExpr& predicate, |
412 | IterMapLevel check_level, arith::Analyzer* analyzer, |
413 | bool simplify_trivial_iterators = true); |
414 | |
415 | /*! |
416 | * \brief Given an expression that may contain IterMapExpr, transform it to normal PrimExpr. |
417 | * \param expr The input expression, which may contain IterMapExpr. |
418 | * \return The corresponding normal PrimExpr. |
419 | */ |
420 | PrimExpr NormalizeIterMapToExpr(const PrimExpr& expr); |
421 | |
422 | } // namespace arith |
423 | } // namespace tvm |
424 | #endif // TVM_ARITH_ITER_AFFINE_MAP_H_ |
425 | |