1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #ifndef TVM_SUPPORT_ARRAY_H_ |
20 | #define TVM_SUPPORT_ARRAY_H_ |
#include <tvm/ir/expr.h>
#include <tvm/runtime/container/array.h>

#include <limits>
#include <string>
#include <vector>
25 | |
26 | namespace tvm { |
27 | namespace support { |
28 | |
29 | /*! |
30 | * \brief Checks if two arrays contain the same objects |
31 | * \tparam T The type of objects in the array |
32 | * \param a The first array |
33 | * \param b The second array |
34 | * \return A boolean indicating if they are the same |
35 | */ |
36 | template <class T> |
37 | inline bool ArrayWithSameContent(const Array<T>& a, const Array<T>& b) { |
38 | if (a.size() != b.size()) { |
39 | return false; |
40 | } |
41 | int n = a.size(); |
42 | for (int i = 0; i < n; ++i) { |
43 | if (!a[i].same_as(b[i])) { |
44 | return false; |
45 | } |
46 | } |
47 | return true; |
48 | } |
49 | |
/*!
 * \brief Checks if two vectors of raw pointers point to the same objects
 * \tparam T The pointee type
 * \param a The first vector
 * \param b The second vector
 * \return true iff both vectors have the same length and hold identical
 *         pointer values at every position (pointer identity, not pointee equality)
 */
template <class T>
inline bool ArrayWithSameContent(const std::vector<T*>& a, const std::vector<T*>& b) {
  // std::vector::operator== checks the sizes first, then compares element-wise
  // with ==, which for pointers is exactly an identity comparison.
  return a == b;
}
70 | |
/*!
 * \brief Convert a tvm::runtime::Array to std::vector
 * \tparam TSrc The type of elements in the source Array
 * \tparam TDst The type of elements in the result vector
 * \param vec The source Array to convert
 * \return The result vector
 * \note Only the (TSrc, TDst) pairs specialized by details::AsVectorImpl are
 *       supported:
 *       - TSrc == TDst: the elements are copied as-is;
 *       - TDst = int or int64_t: every element must be an IntImm (checked at runtime);
 *       - TDst = double: every element must be a FloatImm (checked at runtime).
 *       The primary details::AsVectorImpl template is empty, so any other pair
 *       fails to compile.
 */
template <class TSrc, class TDst>
inline std::vector<TDst> AsVector(const Array<TSrc>& vec);
79 | |
/*!
 * \brief Convert a std::vector to tvm::runtime::Array
 * \tparam TSrc The type of elements in the source vector
 * \tparam TDst The type of elements in the result Array
 * \param vec The source vector to convert
 * \return The result Array
 * \note Only the (TSrc, TDst) pairs specialized by details::AsArrayImpl are
 *       supported:
 *       - TSrc == TDst: the elements are copied as-is;
 *       - TSrc = int or int64_t: each value is wrapped as an Integer;
 *       - TSrc = double: each value is wrapped as a 64-bit FloatImm.
 *       The primary details::AsArrayImpl template is empty, so any other pair
 *       fails to compile.
 */
template <class TSrc, class TDst>
inline Array<TDst> AsArray(const std::vector<TSrc>& vec);
88 | |
89 | /*! |
90 | * \brief Get the shape tuple as array |
91 | * \param shape The shape tuple |
92 | * \return An array of the shape tuple |
93 | */ |
94 | inline Array<Integer> AsArray(const ShapeTuple& shape) { |
95 | Array<Integer> result; |
96 | result.reserve(shape->size); |
97 | for (ShapeTuple::index_type i : shape) { |
98 | result.push_back(Integer(i)); |
99 | } |
100 | return result; |
101 | } |
102 | |
103 | /*! |
104 | * \brief Concatenate a list of arrays into a single array |
105 | * \tparam T The type of elements in the arrays |
106 | * \tparam Iterator The type of the iterator into the list of arrays |
107 | * \param begin The begin iterator to the array list |
108 | * \param end The end iterator to the array list |
109 | * \return The concatenated array |
110 | */ |
111 | template <class T, class Iterator> |
112 | inline Array<T> ConcatArrayList(Iterator begin, Iterator end) { |
113 | int size = 0; |
114 | for (Iterator it = begin; it != end; ++it) { |
115 | size += (*it).size(); |
116 | } |
117 | Array<T> result; |
118 | result.reserve(size); |
119 | for (Iterator it = begin; it != end; ++it) { |
120 | const auto& item = *it; |
121 | result.insert(result.end(), item.begin(), item.end()); |
122 | } |
123 | return result; |
124 | } |
125 | |
126 | /********** Implementation details of AsVector<TSrc, TDst> **********/ |
127 | |
128 | namespace details { |
129 | |
130 | template <class TSrc, class TDst> |
131 | struct AsVectorImpl {}; |
132 | |
133 | template <class TSrc> |
134 | struct AsVectorImpl<TSrc, TSrc> { |
135 | inline std::vector<TSrc> operator()(const Array<TSrc>& vec) const { |
136 | return std::vector<TSrc>(vec.begin(), vec.end()); |
137 | } |
138 | }; |
139 | |
140 | template <class TSrcObjectRef> |
141 | struct AsVectorImpl<TSrcObjectRef, int> { |
142 | inline std::vector<int> operator()(const Array<TSrcObjectRef>& vec) const { |
143 | std::vector<int> results; |
144 | for (const TSrcObjectRef& x : vec) { |
145 | const auto* n = x.template as<IntImmNode>(); |
146 | ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); |
147 | results.push_back(n->value); |
148 | } |
149 | return results; |
150 | } |
151 | }; |
152 | |
153 | template <class TSrcObjectRef> |
154 | struct AsVectorImpl<TSrcObjectRef, int64_t> { |
155 | inline std::vector<int64_t> operator()(const Array<TSrcObjectRef>& vec) const { |
156 | std::vector<int64_t> results; |
157 | for (const TSrcObjectRef& x : vec) { |
158 | const auto* n = x.template as<IntImmNode>(); |
159 | ICHECK(n) << "TypeError: Expects IntImm, but gets: " << x->GetTypeKey(); |
160 | results.push_back(n->value); |
161 | } |
162 | return results; |
163 | } |
164 | }; |
165 | |
166 | template <class TSrcObjectRef> |
167 | struct AsVectorImpl<TSrcObjectRef, double> { |
168 | inline std::vector<double> operator()(const Array<TSrcObjectRef>& array) const { |
169 | std::vector<double> results; |
170 | for (const TSrcObjectRef& x : array) { |
171 | const auto* n = x.template as<FloatImmNode>(); |
172 | ICHECK(n) << "TypeError: Expects FloatImm, but gets: " << x->GetTypeKey(); |
173 | results.push_back(n->value); |
174 | } |
175 | return results; |
176 | } |
177 | }; |
178 | } // namespace details |
179 | |
180 | /********** Implementation details of AsArray<TSrc, TDst> **********/ |
181 | |
182 | namespace details { |
183 | |
184 | template <class TSrc, class TDst> |
185 | struct AsArrayImpl {}; |
186 | |
187 | template <class TSrc> |
188 | struct AsArrayImpl<TSrc, TSrc> { |
189 | inline Array<TSrc> operator()(const std::vector<TSrc>& vec) const { |
190 | return Array<TSrc>(vec.begin(), vec.end()); |
191 | } |
192 | }; |
193 | |
194 | template <class TDstObjectRef> |
195 | struct AsArrayImpl<int, TDstObjectRef> { |
196 | inline Array<TDstObjectRef> operator()(const std::vector<int>& vec) const { |
197 | Array<TDstObjectRef> result; |
198 | result.reserve(vec.size()); |
199 | for (int x : vec) { |
200 | result.push_back(Integer(x)); |
201 | } |
202 | return result; |
203 | } |
204 | }; |
205 | |
206 | template <class TDstObjectRef> |
207 | struct AsArrayImpl<int64_t, TDstObjectRef> { |
208 | inline Array<TDstObjectRef> operator()(const std::vector<int64_t>& vec) const { |
209 | Array<TDstObjectRef> result; |
210 | result.reserve(vec.size()); |
211 | for (int64_t x : vec) { |
212 | result.push_back(Integer(x)); |
213 | } |
214 | return result; |
215 | } |
216 | }; |
217 | |
218 | template <class TDstObjectRef> |
219 | struct AsArrayImpl<double, TDstObjectRef> { |
220 | inline Array<TDstObjectRef> operator()(const std::vector<double>& vec) const { |
221 | Array<TDstObjectRef> result; |
222 | result.reserve(vec.size()); |
223 | for (double x : vec) { |
224 | result.push_back(FloatImm(tvm::DataType::Float(64), x)); |
225 | } |
226 | return result; |
227 | } |
228 | }; |
229 | |
230 | } // namespace details |
231 | |
232 | template <class TSrc, class TDst> |
233 | inline std::vector<TDst> AsVector(const Array<TSrc>& vec) { |
234 | return details::AsVectorImpl<TSrc, TDst>()(vec); |
235 | } |
236 | |
237 | template <class TSrc, class TDst> |
238 | inline Array<TDst> AsArray(const std::vector<TSrc>& vec) { |
239 | return details::AsArrayImpl<TSrc, TDst>()(vec); |
240 | } |
241 | |
242 | } // namespace support |
243 | } // namespace tvm |
244 | #endif // TVM_SUPPORT_ARRAY_H_ |
245 | |