1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file index_map.cc
22 */
23
24#include <tvm/arith/analyzer.h>
25#include <tvm/arith/int_set.h>
26#include <tvm/arith/iter_affine_map.h>
27#include <tvm/ir/name_supply.h>
28#include <tvm/tir/index_map.h>
29#include <tvm/tir/op.h>
30#include <tvm/tir/stmt_functor.h>
31
32#include <sstream>
33
34namespace tvm {
35namespace tir {
36
37IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices,
38 Optional<IndexMap> inverse_index_map) {
39 auto n = make_object<IndexMapNode>();
40 n->initial_indices = std::move(initial_indices);
41 n->final_indices = std::move(final_indices);
42 n->inverse_index_map = std::move(inverse_index_map);
43 data_ = std::move(n);
44}
45
46IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func,
47 Optional<IndexMap> inverse_index_map) {
48 Array<Var> initial_indices;
49 initial_indices.reserve(ndim);
50 for (int i = 0; i < ndim; ++i) {
51 initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32)));
52 }
53 return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map));
54}
55
56std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self,
57 const Array<Range>& initial_ranges,
58 arith::IterMapLevel check_level) {
59 if (self->inverse_index_map.defined()) {
60 // return the pre-defined inverse index map if exists. In this
61 // case, the user-defined inverse is assumed to be correct and
62 // bijective.
63 PrimExpr padding_predicate = Bool(false);
64 return {Downcast<IndexMap>(self->inverse_index_map.value()), padding_predicate};
65 }
66
67 // Dummy variables to represent the inverse's inputs.
68 Array<Var> output_vars;
69 for (size_t i = 0; i < self->final_indices.size(); i++) {
70 PrimExpr index = self->final_indices[i];
71 // TODO(Lunderberg): Better names for these variables. A variable
72 // that is passed through unmodified (`index` is an element of
73 // `initial_indices`) should use that input index's name. A pair
74 // of output indices variables split from a single input index
75 // should be named (X.outer,X.inner).
76 std::stringstream ss;
77 ss << "axis" << i;
78 Var var_index(ss.str(), index.dtype());
79 output_vars.push_back(var_index);
80 }
81
82 // Dummy ranges for the extent of each input.
83 Map<Var, Range> input_iters;
84 ICHECK_EQ(self->initial_indices.size(), initial_ranges.size());
85 for (size_t i = 0; i < initial_ranges.size(); i++) {
86 input_iters.Set(self->initial_indices[i], initial_ranges[i]);
87 }
88
89 // Unpack the output indices into linear combinations of the initial
90 // indices.
91 arith::Analyzer analyzer;
92 auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /* predicate = */ 1,
93 /*check_level=*/check_level, &analyzer,
94 /*simplify_trivial_iterators=*/false);
95 CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. "
96 << "Error: " << padded_iter_map->errors[0];
97
98 // Determine expressions for the input variables, in terms of the
99 // output variables.
100 Map<Var, PrimExpr> inverse_exprs_map = InverseAffineIterMap(
101 padded_iter_map->indices, Array<PrimExpr>(output_vars.begin(), output_vars.end()));
102
103 // Unpack the map to an array, maintaining the same parameter order.
104 Array<PrimExpr> inverse_exprs;
105 for (int i = 0, n = self->initial_indices.size(); i < n; ++i) {
106 Var index = self->initial_indices[i];
107 PrimExpr expr;
108 if (is_one(initial_ranges[i]->extent) && !inverse_exprs_map.count(index)) {
109 expr = initial_ranges[i]->min;
110 } else {
111 expr = inverse_exprs_map.at(index);
112 }
113 inverse_exprs.push_back(analyzer.Simplify(expr));
114 }
115
116 PrimExpr padding_predicate = padded_iter_map->padding_predicate;
117 padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate);
118 padding_predicate = Substitute(padding_predicate, inverse_exprs_map);
119
120 {
121 auto output_ranges = self->MapRanges(initial_ranges);
122 ICHECK_EQ(output_ranges.size(), output_vars.size());
123
124 arith::Analyzer analyzer;
125 for (size_t i = 0; i < output_vars.size(); ++i) {
126 analyzer.Bind(output_vars[i], output_ranges[i]);
127 }
128
129 // Additional simplification steps required to unwrap nested floordiv/floormod
130 padding_predicate = analyzer.Simplify(padding_predicate, 10);
131 }
132
133 return {IndexMap(output_vars, inverse_exprs), padding_predicate};
134}
135
136std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const {
137 return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck);
138}
139
140IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const {
141 auto [inverse, padding_predicate] =
142 IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective);
143 arith::Analyzer analyzer;
144 CHECK(analyzer.CanProve(!padding_predicate))
145 << "Bijective inverse should not contain padding, but inverse of " << *this << " over range "
146 << initial_ranges << " resulted in a padding predicate of " << padding_predicate;
147 return inverse;
148}
149
150Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices,
151 arith::Analyzer* analyzer) const {
152 ICHECK_EQ(indices.size(), initial_indices.size());
153
154 Map<Var, PrimExpr> vmap;
155
156 for (size_t i = 0; i < initial_indices.size(); i++) {
157 vmap.Set(initial_indices[i], indices[i]);
158 }
159
160 arith::Analyzer local_analyzer;
161 if (!analyzer) {
162 analyzer = &local_analyzer;
163 }
164
165 Array<PrimExpr> output = final_indices.Map([&](PrimExpr index) {
166 PrimExpr result = SubstituteWithDataTypeLegalization(
167 std::move(index), [&](const Var& var) { return vmap.Get(var); });
168 return analyzer->Simplify(result);
169 });
170 return output;
171}
172
173Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const {
174 ICHECK_EQ(ranges.size(), initial_indices.size());
175
176 Map<Var, Range> input_iters;
177 for (size_t i = 0; i < initial_indices.size(); i++) {
178 input_iters.Set(initial_indices[i], ranges[i]);
179 }
180
181 arith::Analyzer local_analyzer;
182 if (!analyzer) {
183 analyzer = &local_analyzer;
184 }
185
186 auto iter_map = DetectIterMap(final_indices, input_iters, /* predicate = */ 1,
187 /*check_level=*/arith::IterMapLevel::NoCheck, analyzer,
188 /*simplify_trivial_iterators=*/false);
189 Array<Range> output;
190 if (iter_map->indices.size()) {
191 // Preferred route, requires the map to be expressible as an
192 // affine sum. Since the terms are orthogonal, the extent of the
193 // sum is the extent of the largest term.
194 for (const auto& index : iter_map->indices) {
195 Optional<PrimExpr> extent = NullOpt;
196 for (const auto& term : index->args) {
197 PrimExpr term_extent = term->extent * term->scale;
198 if (extent.defined()) {
199 extent = tvm::max(extent.value(), term_extent);
200 } else {
201 extent = term_extent;
202 }
203 }
204 output.push_back(Range::FromMinExtent(index->base, extent.value_or(1)));
205 }
206
207 } else {
208 // Fall-back method, more general but can ignore intended padding.
209 // For example, [N] mapped through i=>[i//4,i%4] should have shape
210 // [ceildiv(N,4), 4]. However, for N<4, this method instead
211 // results in a shape [1, N].
212 std::unordered_map<const VarNode*, arith::IntSet> dom_map;
213 for (size_t i = 0; i < initial_indices.size(); i++) {
214 dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]);
215 }
216
217 for (const auto& final_index : final_indices) {
218 auto int_set = arith::EvalSet(final_index, dom_map);
219 output.push_back(Range::FromMinExtent(analyzer->Simplify(int_set.min()),
220 analyzer->Simplify(int_set.max() - int_set.min() + 1)));
221 }
222 }
223 auto output_dtype = [&]() {
224 int max_bits = 0;
225 for (const auto& range : ranges) {
226 max_bits = std::max(max_bits, range->extent.dtype().bits());
227 }
228 return DataType::Int(max_bits);
229 }();
230 output.MutateByApply([&](const Range& range) {
231 if (range->min.dtype() != output_dtype || range->extent.dtype() != output_dtype) {
232 return Range::FromMinExtent(cast(output_dtype, range->min),
233 cast(output_dtype, range->extent));
234 } else {
235 return range;
236 }
237 });
238 return output;
239}
240
241Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape,
242 arith::Analyzer* analyzer) const {
243 ICHECK_EQ(shape.size(), initial_indices.size());
244
245 Array<Range> ranges;
246 for (auto& dim : shape) {
247 ranges.push_back(Range(make_zero(dim.dtype()), dim));
248 }
249 Array<Range> mapped = MapRanges(std::move(ranges), analyzer);
250
251 Array<PrimExpr> output;
252 for (auto& range : mapped) {
253 ICHECK(is_zero(range->min));
254 output.push_back(range->extent);
255 }
256
257 return output;
258}
259
260runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const {
261 auto shape = arr_src.Shape();
262 ICHECK(shape.size() == initial_indices.size())
263 << "The rank of the input array should be " << initial_indices.size() << " but got "
264 << shape.size();
265 size_t size_1d = 1;
266 Array<PrimExpr> orig_shape;
267 for (size_t i = 0; i < shape.size(); ++i) {
268 size_1d *= shape[i];
269 orig_shape.push_back(PrimExpr(static_cast<int>((shape[i]))));
270 }
271 auto dst_shape = MapShape(orig_shape);
272
273 std::vector<int64_t> dst_shape_int;
274 for (size_t i = 0; i < dst_shape.size(); ++i) {
275 dst_shape_int.push_back(dst_shape[i].as<IntImmNode>()->value);
276 }
277
278 auto elem_bytes = (arr_src->dtype.bits / 8) * arr_src->dtype.lanes;
279 std::vector<uint8_t> bytes_src(size_1d * elem_bytes);
280 arr_src.CopyToBytes(bytes_src.data(), bytes_src.size());
281
282 std::vector<uint8_t> bytes_dst(bytes_src.size());
283
284 for (size_t i = 0; i < size_1d; ++i) {
285 // Convert a linear coordinate to an N-d coordinate tuple
286 // z * height * width + y * width + x -> (z, y, x)
287 Array<PrimExpr> src_indices;
288 auto div_factor = size_1d;
289 auto src_linear_index = i;
290 for (auto s : shape) {
291 div_factor /= s;
292 src_indices.push_back(PrimExpr(static_cast<int>((src_linear_index / div_factor))));
293 src_linear_index %= div_factor;
294 }
295 auto dst_indices = MapIndices(src_indices);
296
297 // Convert an N-d coordinate to a linear coordinate
298 // (z, y, x) -> z * height * width + y * width + x
299 size_t dst_linear_index = 0;
300 auto mul_factor = size_1d;
301 for (size_t j = 0; j < dst_indices.size(); ++j) {
302 mul_factor /= dst_shape_int[j];
303 dst_linear_index += dst_indices[j].as<IntImmNode>()->value * mul_factor;
304 }
305 std::copy(bytes_src.begin() + i * elem_bytes, bytes_src.begin() + (i + 1) * elem_bytes,
306 bytes_dst.begin() + dst_linear_index * elem_bytes);
307 }
308
309 auto arr_dst = runtime::NDArray::Empty(dst_shape_int, arr_src->dtype, arr_src->device);
310 arr_dst.CopyFromBytes(bytes_dst.data(), bytes_dst.size());
311 return arr_dst;
312}
313
314IndexMap IndexMap::RenameVariables(
315 const std::function<Optional<String>(const Var& var)>& f_name_map) const {
316 std::unordered_set<std::string> used_names;
317 Map<Var, PrimExpr> var_remap;
318 NameSupply name_supply{""};
319 const IndexMapNode* n = this->get();
320 if (f_name_map != nullptr) {
321 // Collect variables with pre-defined names provided by f_name_map.
322 std::unordered_set<const Object*> visited;
323 std::for_each(n->final_indices.begin(), n->final_indices.end(), [&](const PrimExpr& expr) {
324 PostOrderVisit(expr, [&](const ObjectRef& obj) {
325 if (!obj->IsInstance<VarNode>()) {
326 return;
327 }
328 if (visited.count(obj.get())) {
329 return;
330 }
331 visited.emplace(obj.get());
332 Var var = Downcast<Var>(obj);
333 if (Optional<String> opt_name = f_name_map(var); opt_name.defined()) {
334 String name = opt_name.value();
335 ICHECK(!name_supply->ContainsName(name, /*add_prefix=*/false));
336 name_supply->ReserveName(name, /*add_prefix=*/false);
337 var_remap.Set(var, Var(name, var->dtype));
338 }
339 });
340 });
341 }
342
343 for (const Var& initial_index : n->initial_indices) {
344 if (var_remap.count(initial_index)) {
345 // The name of the variable is pre-defined.
346 continue;
347 }
348 String unique_name = name_supply->FreshName(initial_index->name_hint, /*add_prefix=*/false);
349 if (unique_name != initial_index->name_hint) {
350 var_remap.Set(initial_index, Var(unique_name));
351 }
352 }
353
354 auto new_initial_indices = n->initial_indices.Map(
355 [&](const Var& var) { return Downcast<Var>(Substitute(var, var_remap)); });
356 auto new_final_indices =
357 n->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, var_remap); });
358 Optional<IndexMap> new_inverse_index_map = NullOpt;
359 if (n->inverse_index_map.defined()) {
360 new_inverse_index_map = Downcast<IndexMap>(n->inverse_index_map).RenameVariables(f_name_map);
361 }
362 return IndexMap(new_initial_indices, new_final_indices, new_inverse_index_map);
363}
364
365/*!
366 * \brief Auxilarry function to convert an index map to lambda expression in Python.
367 * \param initial_indices The initial indices in the index map.
368 * \param final_indices The final indices in the index map.
369 * \return The lambda expression string.
370 */
371std::string IndexMap2PythonLambdaExpr(const Array<Var>& initial_indices,
372 const Array<PrimExpr>& final_indices) {
373 std::unordered_set<std::string> used_names;
374 Map<Var, PrimExpr> var_remap;
375 std::ostringstream oss;
376 oss << "lambda ";
377 for (size_t i = 0; i < initial_indices.size(); ++i) {
378 if (i != 0) {
379 oss << ", ";
380 }
381 oss << initial_indices[i];
382 }
383 oss << ": (";
384 for (size_t i = 0; i < final_indices.size(); ++i) {
385 if (i != 0) {
386 oss << " ";
387 }
388 oss << final_indices[i];
389 oss << ",";
390 }
391 oss << ")";
392 return oss.str();
393}
394
395String IndexMapNode::ToPythonString(
396 const std::function<Optional<String>(const Var& var)>& f_name_map) const {
397 auto index_map = GetRef<IndexMap>(this).RenameVariables(f_name_map);
398 std::string lambda_expr =
399 IndexMap2PythonLambdaExpr(index_map->initial_indices, index_map->final_indices);
400 if (!index_map->inverse_index_map.defined()) {
401 return String(lambda_expr);
402 }
403 // Also convert the inverse index map.
404 IndexMap inverse = Downcast<IndexMap>(index_map->inverse_index_map.value());
405 std::string inverse_lambda_expr =
406 IndexMap2PythonLambdaExpr(inverse->initial_indices, inverse->final_indices);
407 std::ostringstream oss;
408 oss << "tvm.tir.IndexMap.from_func(" << lambda_expr
409 << ", inverse_index_map=" << inverse_lambda_expr << ")";
410 return String(oss.str());
411}
412
413IndexMap Substitute(const IndexMap& index_map,
414 std::function<Optional<PrimExpr>(const Var& var)> f_subst) {
415 Array<PrimExpr> new_output =
416 index_map->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, f_subst); });
417 Optional<IndexMap> new_inverse_map = NullOpt;
418 if (index_map->inverse_index_map.defined()) {
419 new_inverse_map = Substitute(Downcast<IndexMap>(index_map->inverse_index_map.value()), f_subst);
420 }
421 return IndexMap{index_map->initial_indices, new_output, new_inverse_map};
422}
423
// Register the node type, followed by the global functions exposing the
// IndexMap API under the "tir." prefix.
TVM_REGISTER_NODE_TYPE(IndexMapNode);

// "tir.IndexMap": construct an IndexMap from its three fields.
TVM_REGISTER_GLOBAL("tir.IndexMap")
    .set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices,
                       Optional<IndexMap> inverse_index_map) {
      return IndexMap(initial_indices, final_indices, inverse_index_map);
    });

// "tir.IndexMapMapIndices": forwards to IndexMapNode::MapIndices without an explicit analyzer.
TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices")
    .set_body_typed([](IndexMap map, Array<PrimExpr> indices) { return map->MapIndices(indices); });

// "tir.IndexMapMapShape": forwards to IndexMapNode::MapShape without an explicit analyzer.
TVM_REGISTER_GLOBAL("tir.IndexMapMapShape").set_body_typed([](IndexMap map, Array<PrimExpr> shape) {
  return map->MapShape(shape);
});
// "tir.IndexMapInverse": bound directly to IndexMap::Inverse.
TVM_REGISTER_GLOBAL("tir.IndexMapInverse").set_body_method(&IndexMap::Inverse);

// "tir.IndexMapMapNDArray": forwards to IndexMapNode::MapNDArray.
TVM_REGISTER_GLOBAL("tir.IndexMapMapNDArray")
    .set_body_typed([](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); });

// "tir.IndexMapNonSurjectiveInverse": the (inverse, padding predicate) pair is
// repacked as a two-element array so it can be returned as a single object.
TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse")
    .set_body_typed([](IndexMap forward, Array<Range> initial_ranges) {
      auto result = forward.NonSurjectiveInverse(initial_ranges);
      return Array<ObjectRef>{result.first, result.second};
    });
448
449} // namespace tir
450} // namespace tvm
451