1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file tvm/tir/index_map.h |
22 | * \brief Defines a remapping of buffer indices |
23 | * |
24 | * For use with tvm::tir::Buffer. |
25 | */ |
26 | #ifndef TVM_TIR_INDEX_MAP_H_ |
27 | #define TVM_TIR_INDEX_MAP_H_ |
28 | |
29 | #include <tvm/ir/expr.h> |
30 | #include <tvm/runtime/container/array.h> |
31 | #include <tvm/runtime/object.h> |
32 | #include <tvm/tir/var.h> |
33 | |
34 | #include <utility> |
35 | |
// Forward declaration only: within this header arith::Analyzer appears
// solely as a pointer parameter type, so a declaration is sufficient.
namespace tvm {
namespace arith {
class Analyzer;
}  // namespace arith
}  // namespace tvm
41 | |
42 | namespace tvm { |
43 | namespace tir { |
44 | |
45 | /*! |
46 | * \brief Defines a mapping between two representations of indices |
47 | * into a buffer. |
48 | * |
49 | * This is primarily used for layout transformations of Buffer |
50 | * objects. |
51 | */ |
52 | class IndexMapNode : public Object { |
53 | public: |
54 | /*! \brief Variables representing the indices prior to remapping. |
55 | * |
56 | * If initial_indices is empty, then final_indices should also be |
57 | * empty, and no mapping is applied. |
58 | */ |
59 | Array<Var> initial_indices; |
60 | |
61 | /*! |
62 | * \brief Expressions defining the indices after remapping. |
63 | * |
64 | * These expressions should only be in terms of the initial_indices, |
65 | * and must be expressible as an IterSumExpr. The mapping from |
66 | * initial_indices to final_indices must be injective. |
67 | * |
68 | * If final_indices is empty, then initial_indices should also be |
69 | * empty, and the map is an identity function. |
70 | */ |
71 | Array<PrimExpr> final_indices; |
72 | |
73 | /*! |
74 | * \brief The inverse index map. |
75 | * |
76 | * When this is defined, IndexMap::Inverse will return the |
77 | * pre-defined inverse index map. Otherwise, the inverse index map |
78 | * will be computed on the fly. It is the user's responsibility to |
79 | * ensure the correctness of the pre-defined inverse index map. |
80 | * |
81 | * \note ObjectRef is used here instead of IndexMap to avoid circular reference. |
82 | */ |
83 | Optional<ObjectRef> inverse_index_map; |
84 | |
85 | /*! |
86 | * \brief Default constructor |
87 | * |
88 | * Defines the mapping as an identity function, with initial_indices |
89 | * equal to the final indices. |
90 | */ |
91 | IndexMapNode() {} |
92 | |
93 | /*! |
94 | * \brief Map indices to the output space |
95 | * |
96 | * \param indices The indices in the input space. Should contain |
97 | * one value for each variable in `initial_indices`. |
98 | * |
99 | * \param analyzer An optional analyzer to be used to simplify the |
100 | * resulting expressions. If null, will use a fresh analyzer. |
101 | * |
102 | * \returns The indices in the output space. Contains one value for |
103 | * each expression in `final_indices`. |
104 | */ |
105 | Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices, |
106 | arith::Analyzer* analyzer = nullptr) const; |
107 | |
108 | /*! \brief Map a memory range to the output space |
109 | * |
110 | * If contiguous memory locations in the input space are not |
111 | * necessarily contiguous in the output space (e.g. `lambda i: |
112 | * [8*(i%8) + (i//8)]`), then this will return the smallest range |
113 | * such that all valid indices are contained within the given range. |
114 | * |
115 | * \param ranges The ranges in the input space. Should contain one |
116 | * value for each variable in `initial_indices`. |
117 | * |
118 | * \param analyzer An optional analyzer to be used to simplify the |
119 | * resulting expressions. If null, will use a fresh analyzer. |
120 | * |
121 | * \returns The ranges in the output space. Contains one value for |
122 | * each expression in `final_indices`. |
123 | */ |
124 | Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer = nullptr) const; |
125 | |
126 | /*! \brief Map a buffer shape to the output space |
127 | * |
128 | * \param shape The buffer shape in the input space. Should contain |
129 | * one value for each variable in `initial_indices`. |
130 | * |
131 | * \param analyzer An optional analyzer to be used to simplify the |
132 | * resulting expressions. If null, will use a fresh analyzer. |
133 | * |
134 | * \returns The buffer shape in the output space. Contains one |
135 | * value for each expression in `final_indices`. |
136 | */ |
137 | Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* analyzer = nullptr) const; |
138 | |
139 | /* \brief Map an NDArray according to this index map |
140 | * |
141 | * \param arr_src The NDArray whose layout is transformed by this index map. |
142 | * |
143 | * \returns The transformed NDArray. |
144 | */ |
145 | runtime::NDArray MapNDArray(runtime::NDArray arr_src) const; |
146 | |
147 | /*! |
148 | * \brief Convert to string representation in Python. |
149 | * \param f_name_map Optional function to specify the stringified name of the variables. |
150 | * \return The stringified lambda expression in Python. |
151 | */ |
152 | String ToPythonString( |
153 | const std::function<Optional<String>(const Var& var)>& f_name_map = nullptr) const; |
154 | |
155 | void VisitAttrs(AttrVisitor* v) { |
156 | v->Visit("initial_indices" , &initial_indices); |
157 | v->Visit("final_indices" , &final_indices); |
158 | v->Visit("inverse_index_map" , &inverse_index_map); |
159 | } |
160 | |
161 | bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const { |
162 | return equal.DefEqual(initial_indices, other->initial_indices) && |
163 | equal(final_indices, other->final_indices); |
164 | } |
165 | |
166 | void SHashReduce(SHashReducer hash_reduce) const { |
167 | hash_reduce.DefHash(initial_indices); |
168 | hash_reduce(final_indices); |
169 | } |
170 | |
171 | static constexpr const char* _type_key = "tir.IndexMap" ; |
172 | static constexpr const bool _type_has_method_sequal_reduce = true; |
173 | static constexpr const bool _type_has_method_shash_reduce = true; |
174 | TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object); |
175 | }; |
176 | |
177 | class IndexMap : public ObjectRef { |
178 | public: |
179 | /*! |
180 | * \brief The constructor |
181 | * \param initial_indices Variables representing the indices prior to remapping |
182 | * \param final_indices Expressions defining the indices after remapping. |
183 | * \param inverse_index_map The optional pre-defined inverse index map |
184 | */ |
185 | IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices, |
186 | Optional<IndexMap> inverse_index_map = NullOpt); |
187 | |
188 | /*! |
189 | * \brief Create an index map from a packed function |
190 | * \param ndim The number of dimensions |
191 | * \param func The function to be applied |
192 | * \param inverse_index_map The optional pre-defined inverse index map |
193 | * \return The created index map |
194 | */ |
195 | static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func, |
196 | Optional<IndexMap> inverse_index_map = NullOpt); |
197 | |
198 | /*! \brief Generate the inverse mapping. |
199 | * |
200 | * The range of the input indices is required in order to ensure |
201 | * that the transformation is bijective over the input domain. |
202 | * |
203 | * If the user has supplied an `inverse_index_map`, that map is |
204 | * assumed to be correct and bijective, and is returned. |
205 | */ |
206 | IndexMap Inverse(Array<Range> initial_ranges) const; |
207 | |
208 | /*! \brief Rename the variables in the index map and ensure the names are unique. |
209 | * |
210 | * Construct a new index map with the same transformation, but with name_hint of variables to be |
211 | * guaranteed unique. The optional f_name_map can be provided to rename the variables. |
212 | * |
213 | * \param f_name_map The optional name map to rename the variables. |
214 | * \return The renamed index map. |
215 | */ |
216 | IndexMap RenameVariables( |
217 | const std::function<Optional<String>(const Var& var)>& f_name_map = nullptr) const; |
218 | |
219 | /*! \brief Generate the inverse mapping. |
220 | * |
221 | * Determine the inverse, where the output range may contain |
222 | * addresses that do not correspond to an address in the input |
223 | * range. |
224 | * |
225 | * \return The inverted index map, along with the predicate for |
226 | * which the inverse maps to a valid range. |
227 | */ |
228 | std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges) const; |
229 | |
230 | TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode); |
231 | }; |
232 | |
233 | /*! \brief Substitute variables in an index map. |
234 | * |
235 | * \param index_map The index_map |
236 | * \param f_subst The substitution function |
237 | */ |
238 | IndexMap Substitute(const IndexMap& index_map, |
239 | std::function<Optional<PrimExpr>(const Var& var)> f_subst); |
240 | |
241 | } // namespace tir |
242 | } // namespace tvm |
243 | |
244 | #endif // TVM_TIR_INDEX_MAP_H_ |
245 | |