1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file tvm/tir/index_map.h
22 * \brief Defines a remapping of buffer indices
23 *
24 * For use with tvm::tir::Buffer.
25 */
26#ifndef TVM_TIR_INDEX_MAP_H_
27#define TVM_TIR_INDEX_MAP_H_
28
#include <tvm/ir/expr.h>
#include <tvm/runtime/container/array.h>
#include <tvm/runtime/object.h>
#include <tvm/tir/var.h>

#include <functional>
#include <utility>
35
namespace tvm {
namespace arith {
// Only the name of the analyzer type is needed in this header: it is
// referenced exclusively through `arith::Analyzer*` parameters, so a
// forward declaration avoids pulling in the full arithmetic headers.
class Analyzer;
}  // namespace arith
}  // namespace tvm
41
42namespace tvm {
43namespace tir {
44
45/*!
46 * \brief Defines a mapping between two representations of indices
47 * into a buffer.
48 *
49 * This is primarily used for layout transformations of Buffer
50 * objects.
51 */
52class IndexMapNode : public Object {
53 public:
54 /*! \brief Variables representing the indices prior to remapping.
55 *
56 * If initial_indices is empty, then final_indices should also be
57 * empty, and no mapping is applied.
58 */
59 Array<Var> initial_indices;
60
61 /*!
62 * \brief Expressions defining the indices after remapping.
63 *
64 * These expressions should only be in terms of the initial_indices,
65 * and must be expressible as an IterSumExpr. The mapping from
66 * initial_indices to final_indices must be injective.
67 *
68 * If final_indices is empty, then initial_indices should also be
69 * empty, and the map is an identity function.
70 */
71 Array<PrimExpr> final_indices;
72
73 /*!
74 * \brief The inverse index map.
75 *
76 * When this is defined, IndexMap::Inverse will return the
77 * pre-defined inverse index map. Otherwise, the inverse index map
78 * will be computed on the fly. It is the user's responsibility to
79 * ensure the correctness of the pre-defined inverse index map.
80 *
81 * \note ObjectRef is used here instead of IndexMap to avoid circular reference.
82 */
83 Optional<ObjectRef> inverse_index_map;
84
85 /*!
86 * \brief Default constructor
87 *
88 * Defines the mapping as an identity function, with initial_indices
89 * equal to the final indices.
90 */
91 IndexMapNode() {}
92
93 /*!
94 * \brief Map indices to the output space
95 *
96 * \param indices The indices in the input space. Should contain
97 * one value for each variable in `initial_indices`.
98 *
99 * \param analyzer An optional analyzer to be used to simplify the
100 * resulting expressions. If null, will use a fresh analyzer.
101 *
102 * \returns The indices in the output space. Contains one value for
103 * each expression in `final_indices`.
104 */
105 Array<PrimExpr> MapIndices(const Array<PrimExpr>& indices,
106 arith::Analyzer* analyzer = nullptr) const;
107
108 /*! \brief Map a memory range to the output space
109 *
110 * If contiguous memory locations in the input space are not
111 * necessarily contiguous in the output space (e.g. `lambda i:
112 * [8*(i%8) + (i//8)]`), then this will return the smallest range
113 * such that all valid indices are contained within the given range.
114 *
115 * \param ranges The ranges in the input space. Should contain one
116 * value for each variable in `initial_indices`.
117 *
118 * \param analyzer An optional analyzer to be used to simplify the
119 * resulting expressions. If null, will use a fresh analyzer.
120 *
121 * \returns The ranges in the output space. Contains one value for
122 * each expression in `final_indices`.
123 */
124 Array<Range> MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer = nullptr) const;
125
126 /*! \brief Map a buffer shape to the output space
127 *
128 * \param shape The buffer shape in the input space. Should contain
129 * one value for each variable in `initial_indices`.
130 *
131 * \param analyzer An optional analyzer to be used to simplify the
132 * resulting expressions. If null, will use a fresh analyzer.
133 *
134 * \returns The buffer shape in the output space. Contains one
135 * value for each expression in `final_indices`.
136 */
137 Array<PrimExpr> MapShape(const Array<PrimExpr>& shape, arith::Analyzer* analyzer = nullptr) const;
138
139 /* \brief Map an NDArray according to this index map
140 *
141 * \param arr_src The NDArray whose layout is transformed by this index map.
142 *
143 * \returns The transformed NDArray.
144 */
145 runtime::NDArray MapNDArray(runtime::NDArray arr_src) const;
146
147 /*!
148 * \brief Convert to string representation in Python.
149 * \param f_name_map Optional function to specify the stringified name of the variables.
150 * \return The stringified lambda expression in Python.
151 */
152 String ToPythonString(
153 const std::function<Optional<String>(const Var& var)>& f_name_map = nullptr) const;
154
155 void VisitAttrs(AttrVisitor* v) {
156 v->Visit("initial_indices", &initial_indices);
157 v->Visit("final_indices", &final_indices);
158 v->Visit("inverse_index_map", &inverse_index_map);
159 }
160
161 bool SEqualReduce(const IndexMapNode* other, SEqualReducer equal) const {
162 return equal.DefEqual(initial_indices, other->initial_indices) &&
163 equal(final_indices, other->final_indices);
164 }
165
166 void SHashReduce(SHashReducer hash_reduce) const {
167 hash_reduce.DefHash(initial_indices);
168 hash_reduce(final_indices);
169 }
170
171 static constexpr const char* _type_key = "tir.IndexMap";
172 static constexpr const bool _type_has_method_sequal_reduce = true;
173 static constexpr const bool _type_has_method_shash_reduce = true;
174 TVM_DECLARE_FINAL_OBJECT_INFO(IndexMapNode, Object);
175};
176
177class IndexMap : public ObjectRef {
178 public:
179 /*!
180 * \brief The constructor
181 * \param initial_indices Variables representing the indices prior to remapping
182 * \param final_indices Expressions defining the indices after remapping.
183 * \param inverse_index_map The optional pre-defined inverse index map
184 */
185 IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
186 Optional<IndexMap> inverse_index_map = NullOpt);
187
188 /*!
189 * \brief Create an index map from a packed function
190 * \param ndim The number of dimensions
191 * \param func The function to be applied
192 * \param inverse_index_map The optional pre-defined inverse index map
193 * \return The created index map
194 */
195 static IndexMap FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
196 Optional<IndexMap> inverse_index_map = NullOpt);
197
198 /*! \brief Generate the inverse mapping.
199 *
200 * The range of the input indices is required in order to ensure
201 * that the transformation is bijective over the input domain.
202 *
203 * If the user has supplied an `inverse_index_map`, that map is
204 * assumed to be correct and bijective, and is returned.
205 */
206 IndexMap Inverse(Array<Range> initial_ranges) const;
207
208 /*! \brief Rename the variables in the index map and ensure the names are unique.
209 *
210 * Construct a new index map with the same transformation, but with name_hint of variables to be
211 * guaranteed unique. The optional f_name_map can be provided to rename the variables.
212 *
213 * \param f_name_map The optional name map to rename the variables.
214 * \return The renamed index map.
215 */
216 IndexMap RenameVariables(
217 const std::function<Optional<String>(const Var& var)>& f_name_map = nullptr) const;
218
219 /*! \brief Generate the inverse mapping.
220 *
221 * Determine the inverse, where the output range may contain
222 * addresses that do not correspond to an address in the input
223 * range.
224 *
225 * \return The inverted index map, along with the predicate for
226 * which the inverse maps to a valid range.
227 */
228 std::pair<IndexMap, PrimExpr> NonSurjectiveInverse(Array<Range> initial_ranges) const;
229
230 TVM_DEFINE_OBJECT_REF_METHODS(IndexMap, ObjectRef, IndexMapNode);
231};
232
233/*! \brief Substitute variables in an index map.
234 *
235 * \param index_map The index_map
236 * \param f_subst The substitution function
237 */
238IndexMap Substitute(const IndexMap& index_map,
239 std::function<Optional<PrimExpr>(const Var& var)> f_subst);
240
241} // namespace tir
242} // namespace tvm
243
244#endif // TVM_TIR_INDEX_MAP_H_
245