1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_SUPPORT_ARRAY_H_
20#define TVM_SUPPORT_ARRAY_H_
21#include <tvm/ir/expr.h>
22#include <tvm/runtime/container/array.h>
23
24#include <vector>
25
26namespace tvm {
27namespace support {
28
29/*!
30 * \brief Checks if two arrays contain the same objects
31 * \tparam T The type of objects in the array
32 * \param a The first array
33 * \param b The second array
34 * \return A boolean indicating if they are the same
35 */
36template <class T>
37inline bool ArrayWithSameContent(const Array<T>& a, const Array<T>& b) {
38 if (a.size() != b.size()) {
39 return false;
40 }
41 int n = a.size();
42 for (int i = 0; i < n; ++i) {
43 if (!a[i].same_as(b[i])) {
44 return false;
45 }
46 }
47 return true;
48}
49
50/*!
51 * \brief Checks if two arrays contain the same objects
52 * \tparam T The type of objects in the array
53 * \param a The first array
54 * \param b The second array
55 * \return A boolean indicating if they are the same
56 */
template <class T>
inline bool ArrayWithSameContent(const std::vector<T*>& a, const std::vector<T*>& b) {
  // std::vector equality first compares sizes, then compares elements in order with ==.
  // For raw pointers that is address identity, i.e. "the same objects". This also avoids
  // narrowing size() into a signed int as the hand-written loop did.
  return a == b;
}
70
71/*!
72 * \brief Convert a tvm::runtime::Array to std::vector
73 * \tparam TSrc The type of elements in the source Array
74 * \tparam TDst The type of elements in the result vector
75 * \return The result vector
76 */
77template <class TSrc, class TDst>
78inline std::vector<TDst> AsVector(const Array<TSrc>& vec);
79
80/*!
81 * \brief Convert a std::vector to tvm::runtime::Array
82 * \tparam TSrc The type of elements in the source vector
83 * \tparam TDst The type of elements in the result Array
84 * \return The result vector
85 */
86template <class TSrc, class TDst>
87inline Array<TDst> AsArray(const std::vector<TSrc>& vec);
88
89/*!
90 * \brief Get the shape tuple as array
91 * \param shape The shape tuple
92 * \return An array of the shape tuple
93 */
94inline Array<Integer> AsArray(const ShapeTuple& shape) {
95 Array<Integer> result;
96 result.reserve(shape->size);
97 for (ShapeTuple::index_type i : shape) {
98 result.push_back(Integer(i));
99 }
100 return result;
101}
102
103/*!
104 * \brief Concatenate a list of arrays into a single array
105 * \tparam T The type of elements in the arrays
106 * \tparam Iterator The type of the iterator into the list of arrays
107 * \param begin The begin iterator to the array list
108 * \param end The end iterator to the array list
109 * \return The concatenated array
110 */
111template <class T, class Iterator>
112inline Array<T> ConcatArrayList(Iterator begin, Iterator end) {
113 int size = 0;
114 for (Iterator it = begin; it != end; ++it) {
115 size += (*it).size();
116 }
117 Array<T> result;
118 result.reserve(size);
119 for (Iterator it = begin; it != end; ++it) {
120 const auto& item = *it;
121 result.insert(result.end(), item.begin(), item.end());
122 }
123 return result;
124}
125
126/********** Implementation details of AsVector<TSrc, TDst> **********/
127
128namespace details {
129
130template <class TSrc, class TDst>
131struct AsVectorImpl {};
132
133template <class TSrc>
134struct AsVectorImpl<TSrc, TSrc> {
135 inline std::vector<TSrc> operator()(const Array<TSrc>& vec) const {
136 return std::vector<TSrc>(vec.begin(), vec.end());
137 }
138};
139
140template <class TSrcObjectRef>
141struct AsVectorImpl<TSrcObjectRef, int> {
142 inline std::vector<int> operator()(const Array<TSrcObjectRef>& vec) const {
143 std::vector<int> results;
144 for (const TSrcObjectRef& x : vec) {
145 const auto* n = x.template as<IntImmNode>();
146 ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey();
147 results.push_back(n->value);
148 }
149 return results;
150 }
151};
152
153template <class TSrcObjectRef>
154struct AsVectorImpl<TSrcObjectRef, int64_t> {
155 inline std::vector<int64_t> operator()(const Array<TSrcObjectRef>& vec) const {
156 std::vector<int64_t> results;
157 for (const TSrcObjectRef& x : vec) {
158 const auto* n = x.template as<IntImmNode>();
159 ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey();
160 results.push_back(n->value);
161 }
162 return results;
163 }
164};
165
166template <class TSrcObjectRef>
167struct AsVectorImpl<TSrcObjectRef, double> {
168 inline std::vector<double> operator()(const Array<TSrcObjectRef>& array) const {
169 std::vector<double> results;
170 for (const TSrcObjectRef& x : array) {
171 const auto* n = x.template as<FloatImmNode>();
172 ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey();
173 results.push_back(n->value);
174 }
175 return results;
176 }
177};
178} // namespace details
179
180/********** Implementation details of AsArray<TSrc, TDst> **********/
181
182namespace details {
183
184template <class TSrc, class TDst>
185struct AsArrayImpl {};
186
187template <class TSrc>
188struct AsArrayImpl<TSrc, TSrc> {
189 inline Array<TSrc> operator()(const std::vector<TSrc>& vec) const {
190 return Array<TSrc>(vec.begin(), vec.end());
191 }
192};
193
194template <class TDstObjectRef>
195struct AsArrayImpl<int, TDstObjectRef> {
196 inline Array<TDstObjectRef> operator()(const std::vector<int>& vec) const {
197 Array<TDstObjectRef> result;
198 result.reserve(vec.size());
199 for (int x : vec) {
200 result.push_back(Integer(x));
201 }
202 return result;
203 }
204};
205
206template <class TDstObjectRef>
207struct AsArrayImpl<int64_t, TDstObjectRef> {
208 inline Array<TDstObjectRef> operator()(const std::vector<int64_t>& vec) const {
209 Array<TDstObjectRef> result;
210 result.reserve(vec.size());
211 for (int64_t x : vec) {
212 result.push_back(Integer(x));
213 }
214 return result;
215 }
216};
217
218template <class TDstObjectRef>
219struct AsArrayImpl<double, TDstObjectRef> {
220 inline Array<TDstObjectRef> operator()(const std::vector<double>& vec) const {
221 Array<TDstObjectRef> result;
222 result.reserve(vec.size());
223 for (double x : vec) {
224 result.push_back(FloatImm(tvm::DataType::Float(64), x));
225 }
226 return result;
227 }
228};
229
230} // namespace details
231
232template <class TSrc, class TDst>
233inline std::vector<TDst> AsVector(const Array<TSrc>& vec) {
234 return details::AsVectorImpl<TSrc, TDst>()(vec);
235}
236
237template <class TSrc, class TDst>
238inline Array<TDst> AsArray(const std::vector<TSrc>& vec) {
239 return details::AsArrayImpl<TSrc, TDst>()(vec);
240}
241
242} // namespace support
243} // namespace tvm
244#endif // TVM_SUPPORT_ARRAY_H_
245