1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file ravel_unravel.h
 * \brief Index ravel and unravel operations
23 */
24#ifndef TVM_TOPI_DETAIL_RAVEL_UNRAVEL_H_
25#define TVM_TOPI_DETAIL_RAVEL_UNRAVEL_H_
26
27#include <tvm/te/operation.h>
28
29#include <vector>
30
31namespace tvm {
32namespace topi {
33namespace detail {
34
35using namespace tvm::te;
36
37/*!
38 * \brief Flatten the indices to 1D
39 *
40 * \param indices The input coordinates
41 * \param shape Shape of the tensor
42 *
43 * \return The index after flattening
44 */
45inline PrimExpr RavelIndex(Array<PrimExpr> indices, Array<PrimExpr> shape) {
46 ICHECK_EQ(indices.size(), shape.size()) << "indices and shape must have equal size";
47 if (indices.size() == 0U) {
48 return 0;
49 }
50 PrimExpr idx;
51 for (size_t i = 0; i < indices.size(); ++i) {
52 if (i == 0) {
53 idx = indices[i];
54 } else {
55 idx = idx * shape[i] + indices[i];
56 }
57 }
58 return idx;
59}
60
61/*!
62 * \brief Convert flattened index to coordinate array
63 *
64 * \param idx The 1D index
65 * \param shape Shape of the tensor
66 *
67 * \return The coordinate corresponding to the 1D index
68 */
69inline Array<PrimExpr> UnravelIndex(PrimExpr idx, Array<PrimExpr> shape) {
70 std::vector<PrimExpr> indices;
71
72 for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
73 indices.push_back(indexmod(idx, shape[i]));
74 idx = indexdiv(idx, shape[i]);
75 }
76 std::reverse(indices.begin(), indices.end());
77 return indices;
78}
79
80} // namespace detail
81} // namespace topi
82} // namespace tvm
83#endif // TVM_TOPI_DETAIL_RAVEL_UNRAVEL_H_
84