1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file index_map.cc |
22 | */ |
23 | |
24 | #include <tvm/arith/analyzer.h> |
25 | #include <tvm/arith/int_set.h> |
26 | #include <tvm/arith/iter_affine_map.h> |
27 | #include <tvm/ir/name_supply.h> |
28 | #include <tvm/tir/index_map.h> |
29 | #include <tvm/tir/op.h> |
30 | #include <tvm/tir/stmt_functor.h> |
31 | |
32 | #include <sstream> |
33 | |
34 | namespace tvm { |
35 | namespace tir { |
36 | |
37 | IndexMap::IndexMap(Array<Var> initial_indices, Array<PrimExpr> final_indices, |
38 | Optional<IndexMap> inverse_index_map) { |
39 | auto n = make_object<IndexMapNode>(); |
40 | n->initial_indices = std::move(initial_indices); |
41 | n->final_indices = std::move(final_indices); |
42 | n->inverse_index_map = std::move(inverse_index_map); |
43 | data_ = std::move(n); |
44 | } |
45 | |
46 | IndexMap IndexMap::FromFunc(int ndim, runtime::TypedPackedFunc<Array<PrimExpr>(Array<Var>)> func, |
47 | Optional<IndexMap> inverse_index_map) { |
48 | Array<Var> initial_indices; |
49 | initial_indices.reserve(ndim); |
50 | for (int i = 0; i < ndim; ++i) { |
51 | initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32))); |
52 | } |
53 | return IndexMap(initial_indices, func(initial_indices), std::move(inverse_index_map)); |
54 | } |
55 | |
56 | std::pair<IndexMap, PrimExpr> IndexMapInverseImpl(const IndexMap& self, |
57 | const Array<Range>& initial_ranges, |
58 | arith::IterMapLevel check_level) { |
59 | if (self->inverse_index_map.defined()) { |
60 | // return the pre-defined inverse index map if exists. In this |
61 | // case, the user-defined inverse is assumed to be correct and |
62 | // bijective. |
63 | PrimExpr padding_predicate = Bool(false); |
64 | return {Downcast<IndexMap>(self->inverse_index_map.value()), padding_predicate}; |
65 | } |
66 | |
67 | // Dummy variables to represent the inverse's inputs. |
68 | Array<Var> output_vars; |
69 | for (size_t i = 0; i < self->final_indices.size(); i++) { |
70 | PrimExpr index = self->final_indices[i]; |
71 | // TODO(Lunderberg): Better names for these variables. A variable |
72 | // that is passed through unmodified (`index` is an element of |
73 | // `initial_indices`) should use that input index's name. A pair |
74 | // of output indices variables split from a single input index |
75 | // should be named (X.outer,X.inner). |
76 | std::stringstream ss; |
77 | ss << "axis" << i; |
78 | Var var_index(ss.str(), index.dtype()); |
79 | output_vars.push_back(var_index); |
80 | } |
81 | |
82 | // Dummy ranges for the extent of each input. |
83 | Map<Var, Range> input_iters; |
84 | ICHECK_EQ(self->initial_indices.size(), initial_ranges.size()); |
85 | for (size_t i = 0; i < initial_ranges.size(); i++) { |
86 | input_iters.Set(self->initial_indices[i], initial_ranges[i]); |
87 | } |
88 | |
89 | // Unpack the output indices into linear combinations of the initial |
90 | // indices. |
91 | arith::Analyzer analyzer; |
92 | auto padded_iter_map = DetectIterMap(self->final_indices, input_iters, /* predicate = */ 1, |
93 | /*check_level=*/check_level, &analyzer, |
94 | /*simplify_trivial_iterators=*/false); |
95 | CHECK(padded_iter_map->errors.empty()) << "Could not parse mapping as sum of iterators. " |
96 | << "Error: " << padded_iter_map->errors[0]; |
97 | |
98 | // Determine expressions for the input variables, in terms of the |
99 | // output variables. |
100 | Map<Var, PrimExpr> inverse_exprs_map = InverseAffineIterMap( |
101 | padded_iter_map->indices, Array<PrimExpr>(output_vars.begin(), output_vars.end())); |
102 | |
103 | // Unpack the map to an array, maintaining the same parameter order. |
104 | Array<PrimExpr> inverse_exprs; |
105 | for (int i = 0, n = self->initial_indices.size(); i < n; ++i) { |
106 | Var index = self->initial_indices[i]; |
107 | PrimExpr expr; |
108 | if (is_one(initial_ranges[i]->extent) && !inverse_exprs_map.count(index)) { |
109 | expr = initial_ranges[i]->min; |
110 | } else { |
111 | expr = inverse_exprs_map.at(index); |
112 | } |
113 | inverse_exprs.push_back(analyzer.Simplify(expr)); |
114 | } |
115 | |
116 | PrimExpr padding_predicate = padded_iter_map->padding_predicate; |
117 | padding_predicate = arith::NormalizeIterMapToExpr(padding_predicate); |
118 | padding_predicate = Substitute(padding_predicate, inverse_exprs_map); |
119 | |
120 | { |
121 | auto output_ranges = self->MapRanges(initial_ranges); |
122 | ICHECK_EQ(output_ranges.size(), output_vars.size()); |
123 | |
124 | arith::Analyzer analyzer; |
125 | for (size_t i = 0; i < output_vars.size(); ++i) { |
126 | analyzer.Bind(output_vars[i], output_ranges[i]); |
127 | } |
128 | |
129 | // Additional simplification steps required to unwrap nested floordiv/floormod |
130 | padding_predicate = analyzer.Simplify(padding_predicate, 10); |
131 | } |
132 | |
133 | return {IndexMap(output_vars, inverse_exprs), padding_predicate}; |
134 | } |
135 | |
136 | std::pair<IndexMap, PrimExpr> IndexMap::NonSurjectiveInverse(Array<Range> initial_ranges) const { |
137 | return IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::NoCheck); |
138 | } |
139 | |
140 | IndexMap IndexMap::Inverse(Array<Range> initial_ranges) const { |
141 | auto [inverse, padding_predicate] = |
142 | IndexMapInverseImpl(*this, initial_ranges, arith::IterMapLevel::Bijective); |
143 | arith::Analyzer analyzer; |
144 | CHECK(analyzer.CanProve(!padding_predicate)) |
145 | << "Bijective inverse should not contain padding, but inverse of " << *this << " over range " |
146 | << initial_ranges << " resulted in a padding predicate of " << padding_predicate; |
147 | return inverse; |
148 | } |
149 | |
150 | Array<PrimExpr> IndexMapNode::MapIndices(const Array<PrimExpr>& indices, |
151 | arith::Analyzer* analyzer) const { |
152 | ICHECK_EQ(indices.size(), initial_indices.size()); |
153 | |
154 | Map<Var, PrimExpr> vmap; |
155 | |
156 | for (size_t i = 0; i < initial_indices.size(); i++) { |
157 | vmap.Set(initial_indices[i], indices[i]); |
158 | } |
159 | |
160 | arith::Analyzer local_analyzer; |
161 | if (!analyzer) { |
162 | analyzer = &local_analyzer; |
163 | } |
164 | |
165 | Array<PrimExpr> output = final_indices.Map([&](PrimExpr index) { |
166 | PrimExpr result = SubstituteWithDataTypeLegalization( |
167 | std::move(index), [&](const Var& var) { return vmap.Get(var); }); |
168 | return analyzer->Simplify(result); |
169 | }); |
170 | return output; |
171 | } |
172 | |
173 | Array<Range> IndexMapNode::MapRanges(const Array<Range>& ranges, arith::Analyzer* analyzer) const { |
174 | ICHECK_EQ(ranges.size(), initial_indices.size()); |
175 | |
176 | Map<Var, Range> input_iters; |
177 | for (size_t i = 0; i < initial_indices.size(); i++) { |
178 | input_iters.Set(initial_indices[i], ranges[i]); |
179 | } |
180 | |
181 | arith::Analyzer local_analyzer; |
182 | if (!analyzer) { |
183 | analyzer = &local_analyzer; |
184 | } |
185 | |
186 | auto iter_map = DetectIterMap(final_indices, input_iters, /* predicate = */ 1, |
187 | /*check_level=*/arith::IterMapLevel::NoCheck, analyzer, |
188 | /*simplify_trivial_iterators=*/false); |
189 | Array<Range> output; |
190 | if (iter_map->indices.size()) { |
191 | // Preferred route, requires the map to be expressible as an |
192 | // affine sum. Since the terms are orthogonal, the extent of the |
193 | // sum is the extent of the largest term. |
194 | for (const auto& index : iter_map->indices) { |
195 | Optional<PrimExpr> extent = NullOpt; |
196 | for (const auto& term : index->args) { |
197 | PrimExpr term_extent = term->extent * term->scale; |
198 | if (extent.defined()) { |
199 | extent = tvm::max(extent.value(), term_extent); |
200 | } else { |
201 | extent = term_extent; |
202 | } |
203 | } |
204 | output.push_back(Range::FromMinExtent(index->base, extent.value_or(1))); |
205 | } |
206 | |
207 | } else { |
208 | // Fall-back method, more general but can ignore intended padding. |
209 | // For example, [N] mapped through i=>[i//4,i%4] should have shape |
210 | // [ceildiv(N,4), 4]. However, for N<4, this method instead |
211 | // results in a shape [1, N]. |
212 | std::unordered_map<const VarNode*, arith::IntSet> dom_map; |
213 | for (size_t i = 0; i < initial_indices.size(); i++) { |
214 | dom_map[initial_indices[i].get()] = arith::IntSet::FromRange(ranges[i]); |
215 | } |
216 | |
217 | for (const auto& final_index : final_indices) { |
218 | auto int_set = arith::EvalSet(final_index, dom_map); |
219 | output.push_back(Range::FromMinExtent(analyzer->Simplify(int_set.min()), |
220 | analyzer->Simplify(int_set.max() - int_set.min() + 1))); |
221 | } |
222 | } |
223 | auto output_dtype = [&]() { |
224 | int max_bits = 0; |
225 | for (const auto& range : ranges) { |
226 | max_bits = std::max(max_bits, range->extent.dtype().bits()); |
227 | } |
228 | return DataType::Int(max_bits); |
229 | }(); |
230 | output.MutateByApply([&](const Range& range) { |
231 | if (range->min.dtype() != output_dtype || range->extent.dtype() != output_dtype) { |
232 | return Range::FromMinExtent(cast(output_dtype, range->min), |
233 | cast(output_dtype, range->extent)); |
234 | } else { |
235 | return range; |
236 | } |
237 | }); |
238 | return output; |
239 | } |
240 | |
241 | Array<PrimExpr> IndexMapNode::MapShape(const Array<PrimExpr>& shape, |
242 | arith::Analyzer* analyzer) const { |
243 | ICHECK_EQ(shape.size(), initial_indices.size()); |
244 | |
245 | Array<Range> ranges; |
246 | for (auto& dim : shape) { |
247 | ranges.push_back(Range(make_zero(dim.dtype()), dim)); |
248 | } |
249 | Array<Range> mapped = MapRanges(std::move(ranges), analyzer); |
250 | |
251 | Array<PrimExpr> output; |
252 | for (auto& range : mapped) { |
253 | ICHECK(is_zero(range->min)); |
254 | output.push_back(range->extent); |
255 | } |
256 | |
257 | return output; |
258 | } |
259 | |
260 | runtime::NDArray IndexMapNode::MapNDArray(runtime::NDArray arr_src) const { |
261 | auto shape = arr_src.Shape(); |
262 | ICHECK(shape.size() == initial_indices.size()) |
263 | << "The rank of the input array should be " << initial_indices.size() << " but got " |
264 | << shape.size(); |
265 | size_t size_1d = 1; |
266 | Array<PrimExpr> orig_shape; |
267 | for (size_t i = 0; i < shape.size(); ++i) { |
268 | size_1d *= shape[i]; |
269 | orig_shape.push_back(PrimExpr(static_cast<int>((shape[i])))); |
270 | } |
271 | auto dst_shape = MapShape(orig_shape); |
272 | |
273 | std::vector<int64_t> dst_shape_int; |
274 | for (size_t i = 0; i < dst_shape.size(); ++i) { |
275 | dst_shape_int.push_back(dst_shape[i].as<IntImmNode>()->value); |
276 | } |
277 | |
278 | auto elem_bytes = (arr_src->dtype.bits / 8) * arr_src->dtype.lanes; |
279 | std::vector<uint8_t> bytes_src(size_1d * elem_bytes); |
280 | arr_src.CopyToBytes(bytes_src.data(), bytes_src.size()); |
281 | |
282 | std::vector<uint8_t> bytes_dst(bytes_src.size()); |
283 | |
284 | for (size_t i = 0; i < size_1d; ++i) { |
285 | // Convert a linear coordinate to an N-d coordinate tuple |
286 | // z * height * width + y * width + x -> (z, y, x) |
287 | Array<PrimExpr> src_indices; |
288 | auto div_factor = size_1d; |
289 | auto src_linear_index = i; |
290 | for (auto s : shape) { |
291 | div_factor /= s; |
292 | src_indices.push_back(PrimExpr(static_cast<int>((src_linear_index / div_factor)))); |
293 | src_linear_index %= div_factor; |
294 | } |
295 | auto dst_indices = MapIndices(src_indices); |
296 | |
297 | // Convert an N-d coordinate to a linear coordinate |
298 | // (z, y, x) -> z * height * width + y * width + x |
299 | size_t dst_linear_index = 0; |
300 | auto mul_factor = size_1d; |
301 | for (size_t j = 0; j < dst_indices.size(); ++j) { |
302 | mul_factor /= dst_shape_int[j]; |
303 | dst_linear_index += dst_indices[j].as<IntImmNode>()->value * mul_factor; |
304 | } |
305 | std::copy(bytes_src.begin() + i * elem_bytes, bytes_src.begin() + (i + 1) * elem_bytes, |
306 | bytes_dst.begin() + dst_linear_index * elem_bytes); |
307 | } |
308 | |
309 | auto arr_dst = runtime::NDArray::Empty(dst_shape_int, arr_src->dtype, arr_src->device); |
310 | arr_dst.CopyFromBytes(bytes_dst.data(), bytes_dst.size()); |
311 | return arr_dst; |
312 | } |
313 | |
314 | IndexMap IndexMap::RenameVariables( |
315 | const std::function<Optional<String>(const Var& var)>& f_name_map) const { |
316 | std::unordered_set<std::string> used_names; |
317 | Map<Var, PrimExpr> var_remap; |
318 | NameSupply name_supply{"" }; |
319 | const IndexMapNode* n = this->get(); |
320 | if (f_name_map != nullptr) { |
321 | // Collect variables with pre-defined names provided by f_name_map. |
322 | std::unordered_set<const Object*> visited; |
323 | std::for_each(n->final_indices.begin(), n->final_indices.end(), [&](const PrimExpr& expr) { |
324 | PostOrderVisit(expr, [&](const ObjectRef& obj) { |
325 | if (!obj->IsInstance<VarNode>()) { |
326 | return; |
327 | } |
328 | if (visited.count(obj.get())) { |
329 | return; |
330 | } |
331 | visited.emplace(obj.get()); |
332 | Var var = Downcast<Var>(obj); |
333 | if (Optional<String> opt_name = f_name_map(var); opt_name.defined()) { |
334 | String name = opt_name.value(); |
335 | ICHECK(!name_supply->ContainsName(name, /*add_prefix=*/false)); |
336 | name_supply->ReserveName(name, /*add_prefix=*/false); |
337 | var_remap.Set(var, Var(name, var->dtype)); |
338 | } |
339 | }); |
340 | }); |
341 | } |
342 | |
343 | for (const Var& initial_index : n->initial_indices) { |
344 | if (var_remap.count(initial_index)) { |
345 | // The name of the variable is pre-defined. |
346 | continue; |
347 | } |
348 | String unique_name = name_supply->FreshName(initial_index->name_hint, /*add_prefix=*/false); |
349 | if (unique_name != initial_index->name_hint) { |
350 | var_remap.Set(initial_index, Var(unique_name)); |
351 | } |
352 | } |
353 | |
354 | auto new_initial_indices = n->initial_indices.Map( |
355 | [&](const Var& var) { return Downcast<Var>(Substitute(var, var_remap)); }); |
356 | auto new_final_indices = |
357 | n->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, var_remap); }); |
358 | Optional<IndexMap> new_inverse_index_map = NullOpt; |
359 | if (n->inverse_index_map.defined()) { |
360 | new_inverse_index_map = Downcast<IndexMap>(n->inverse_index_map).RenameVariables(f_name_map); |
361 | } |
362 | return IndexMap(new_initial_indices, new_final_indices, new_inverse_index_map); |
363 | } |
364 | |
365 | /*! |
366 | * \brief Auxilarry function to convert an index map to lambda expression in Python. |
367 | * \param initial_indices The initial indices in the index map. |
368 | * \param final_indices The final indices in the index map. |
369 | * \return The lambda expression string. |
370 | */ |
371 | std::string IndexMap2PythonLambdaExpr(const Array<Var>& initial_indices, |
372 | const Array<PrimExpr>& final_indices) { |
373 | std::unordered_set<std::string> used_names; |
374 | Map<Var, PrimExpr> var_remap; |
375 | std::ostringstream oss; |
376 | oss << "lambda " ; |
377 | for (size_t i = 0; i < initial_indices.size(); ++i) { |
378 | if (i != 0) { |
379 | oss << ", " ; |
380 | } |
381 | oss << initial_indices[i]; |
382 | } |
383 | oss << ": (" ; |
384 | for (size_t i = 0; i < final_indices.size(); ++i) { |
385 | if (i != 0) { |
386 | oss << " " ; |
387 | } |
388 | oss << final_indices[i]; |
389 | oss << "," ; |
390 | } |
391 | oss << ")" ; |
392 | return oss.str(); |
393 | } |
394 | |
395 | String IndexMapNode::ToPythonString( |
396 | const std::function<Optional<String>(const Var& var)>& f_name_map) const { |
397 | auto index_map = GetRef<IndexMap>(this).RenameVariables(f_name_map); |
398 | std::string lambda_expr = |
399 | IndexMap2PythonLambdaExpr(index_map->initial_indices, index_map->final_indices); |
400 | if (!index_map->inverse_index_map.defined()) { |
401 | return String(lambda_expr); |
402 | } |
403 | // Also convert the inverse index map. |
404 | IndexMap inverse = Downcast<IndexMap>(index_map->inverse_index_map.value()); |
405 | std::string inverse_lambda_expr = |
406 | IndexMap2PythonLambdaExpr(inverse->initial_indices, inverse->final_indices); |
407 | std::ostringstream oss; |
408 | oss << "tvm.tir.IndexMap.from_func(" << lambda_expr |
409 | << ", inverse_index_map=" << inverse_lambda_expr << ")" ; |
410 | return String(oss.str()); |
411 | } |
412 | |
413 | IndexMap Substitute(const IndexMap& index_map, |
414 | std::function<Optional<PrimExpr>(const Var& var)> f_subst) { |
415 | Array<PrimExpr> new_output = |
416 | index_map->final_indices.Map([&](const PrimExpr& expr) { return Substitute(expr, f_subst); }); |
417 | Optional<IndexMap> new_inverse_map = NullOpt; |
418 | if (index_map->inverse_index_map.defined()) { |
419 | new_inverse_map = Substitute(Downcast<IndexMap>(index_map->inverse_index_map.value()), f_subst); |
420 | } |
421 | return IndexMap{index_map->initial_indices, new_output, new_inverse_map}; |
422 | } |
423 | |
424 | TVM_REGISTER_NODE_TYPE(IndexMapNode); |
425 | |
426 | TVM_REGISTER_GLOBAL("tir.IndexMap" ) |
427 | .set_body_typed([](Array<Var> initial_indices, Array<PrimExpr> final_indices, |
428 | Optional<IndexMap> inverse_index_map) { |
429 | return IndexMap(initial_indices, final_indices, inverse_index_map); |
430 | }); |
431 | |
432 | TVM_REGISTER_GLOBAL("tir.IndexMapMapIndices" ) |
433 | .set_body_typed([](IndexMap map, Array<PrimExpr> indices) { return map->MapIndices(indices); }); |
434 | |
435 | TVM_REGISTER_GLOBAL("tir.IndexMapMapShape" ).set_body_typed([](IndexMap map, Array<PrimExpr> shape) { |
436 | return map->MapShape(shape); |
437 | }); |
438 | TVM_REGISTER_GLOBAL("tir.IndexMapInverse" ).set_body_method(&IndexMap::Inverse); |
439 | |
440 | TVM_REGISTER_GLOBAL("tir.IndexMapMapNDArray" ) |
441 | .set_body_typed([](IndexMap map, runtime::NDArray arr) { return map->MapNDArray(arr); }); |
442 | |
443 | TVM_REGISTER_GLOBAL("tir.IndexMapNonSurjectiveInverse" ) |
444 | .set_body_typed([](IndexMap forward, Array<Range> initial_ranges) { |
445 | auto result = forward.NonSurjectiveInverse(initial_ranges); |
446 | return Array<ObjectRef>{result.first, result.second}; |
447 | }); |
448 | |
449 | } // namespace tir |
450 | } // namespace tvm |
451 | |