1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file auto_scheduler/utils.h |
22 | * \brief Common utilities. |
23 | */ |
24 | |
25 | #ifndef TVM_AUTO_SCHEDULER_UTILS_H_ |
26 | #define TVM_AUTO_SCHEDULER_UTILS_H_ |
27 | |
28 | #include <dmlc/common.h> |
29 | #include <tvm/tir/expr.h> |
30 | |
31 | #include <algorithm> |
32 | #include <deque> |
33 | #include <exception> |
34 | #include <future> |
35 | #include <iomanip> |
36 | #include <numeric> |
37 | #include <random> |
38 | #include <string> |
39 | #include <thread> |
40 | #include <tuple> |
41 | #include <utility> |
42 | #include <vector> |
43 | |
44 | namespace std { |
45 | |
46 | /*! \brief Hash function for std::pair */ |
47 | template <typename T1, typename T2> |
48 | struct hash<std::pair<T1, T2>> { |
49 | std::size_t operator()(const std::pair<T1, T2>& k) const { |
50 | return ::dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second)); |
51 | } |
52 | }; |
53 | |
54 | /*! \brief Hash function for std::tuple */ |
55 | template <typename T1, typename T2, typename T3> |
56 | struct hash<std::tuple<T1, T2, T3>> { |
57 | std::size_t operator()(const std::tuple<T1, T2, T3>& k) const { |
58 | return ::dmlc::HashCombine( |
59 | ::dmlc::HashCombine(std::hash<T1>()(std::get<0>(k)), std::hash<T2>()(std::get<1>(k))), |
60 | std::hash<T3>()(std::get<2>(k))); |
61 | } |
62 | }; |
63 | |
64 | } // namespace std |
65 | |
66 | namespace tvm { |
67 | namespace auto_scheduler { |
68 | |
69 | /********** Utilities for Array, std::vector, std::string **********/ |
70 | /*! \brief Get the first appearance index of elements in an Array */ |
71 | template <typename T> |
72 | inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, Array<Integer>* indices) { |
73 | for (const auto& v : to_locate) { |
74 | auto it = std::find(array.begin(), array.end(), v); |
75 | if (it != array.end()) { |
76 | indices->push_back(it - array.begin()); |
77 | } else { |
78 | LOG(FATAL) << "Cannot find the item" ; |
79 | } |
80 | } |
81 | } |
82 | |
83 | /*! \brief Get the first appearance index of an element in an Array */ |
84 | template <typename T> |
85 | inline int GetIndex(const Array<T>& array, const T& to_locate) { |
86 | for (size_t i = 0; i < array.size(); ++i) { |
87 | if (array[i] == to_locate) { |
88 | return i; |
89 | } |
90 | } |
91 | LOG(FATAL) << "Cannot find the item" ; |
92 | } |
93 | |
/*!
 * \brief Remove the first element equal to `to_delete` from a std::vector.
 * \param array The vector to modify in place.
 * \param to_delete The value to look for.
 * \note Only the first match is erased; the vector is left untouched when there is no match.
 */
template <typename T>
inline void FindAndDeleteItem(std::vector<T>* array, const T& to_delete) {
  const auto last = array->end();
  const auto match = std::find(array->begin(), last, to_delete);
  if (match == last) {
    return;
  }
  array->erase(match);
}
102 | |
/*!
 * \brief Multiply all elements of an int vector together.
 * \param array The factors.
 * \return The product accumulated in 64-bit arithmetic; 1 for an empty vector.
 */
inline int64_t ElementProduct(const std::vector<int>& array) {
  return std::accumulate(array.begin(), array.end(), static_cast<int64_t>(1),
                         [](int64_t acc, int factor) { return acc * factor; });
}
111 | |
/*!
 * \brief Append the elements of one vector to another by moving them.
 * \param out The destination vector; elements are appended at its end.
 * \param in The source vector. Its elements are moved from, but the vector itself is
 *  not resized or cleared.
 * \return A reference to `*out`.
 */
template <typename T>
std::vector<T>& ConcatenateMove(std::vector<T>* out, std::vector<T>* in) {
  auto src_begin = std::make_move_iterator(in->begin());
  auto src_end = std::make_move_iterator(in->end());
  out->insert(out->end(), src_begin, src_end);
  return *out;
}

/*!
 * \brief Append the elements of several vectors to one vector by moving them.
 * \param out The destination vector.
 * \param first The first source vector.
 * \param args The remaining source vectors, appended in argument order after `first`.
 * \return A reference to `*out`.
 */
template <typename T, typename... Args>
std::vector<T>& ConcatenateMove(std::vector<T>* out, std::vector<T>* first, Args... args) {
  ConcatenateMove(out, first);
  // Peel off one source per call until only the two-argument overload is left.
  return ConcatenateMove(out, args...);
}
126 | |
/*!
 * \brief Fill a vector with a random permutation of the integers [0, n-1].
 * \param n The number of integers to permute.
 * \param out Output vector; its previous content is replaced.
 * \param gen The random number generator handed to std::shuffle.
 */
template <typename G>
void RandomPermutation(int n, std::vector<int>* out, G* gen) {
  std::vector<int>& perm = *out;
  // Size the vector first, then overwrite the zeros with 0, 1, ..., n-1.
  perm.assign(n, 0);
  std::iota(perm.begin(), perm.end(), 0);
  std::shuffle(perm.begin(), perm.end(), *gen);
}
134 | |
/*!
 * \brief Replace every occurrence of a sub-string with another sub-string, in place.
 * \param base The string to modify.
 * \param from The sub-string to search for. An empty `from` leaves `base` unchanged.
 * \param to The replacement text.
 * \note The search resumes right after each inserted replacement, so text produced by a
 *  replacement is never matched again.
 */
inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
  // An empty pattern matches at every position, so without this guard the loop below
  // would never terminate.
  if (from.empty()) {
    return;
  }
  for (auto pos = base->find(from); pos != std::string::npos;
       pos = base->find(from, pos + to.size())) {
    base->replace(pos, from.size(), to);
  }
}
143 | |
144 | /*! \brief Return whether two int arrays are elementwise-equal */ |
145 | inline bool IntArrayEqual(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) { |
146 | if (arr1.size() != arr2.size()) { |
147 | return false; |
148 | } |
149 | |
150 | for (size_t i = 0; i < arr1.size(); ++i) { |
151 | auto int1 = arr1[i].as<IntImmNode>(); |
152 | auto int2 = arr2[i].as<IntImmNode>(); |
153 | ICHECK(int1 != nullptr); |
154 | ICHECK(int2 != nullptr); |
155 | if (int1->value != int2->value) { |
156 | return false; |
157 | } |
158 | } |
159 | return true; |
160 | } |
161 | |
162 | /********** Utilities for TVM Containers / ByteArray **********/ |
163 | /*! \brief Compute mean of a FloatImm array */ |
164 | inline double FloatArrayMean(const Array<PrimExpr>& float_array) { |
165 | double sum = 0; |
166 | if (float_array.empty()) { |
167 | return 0.0; |
168 | } |
169 | |
170 | for (const auto& x : float_array) { |
171 | auto floatimm = x.as<tir::FloatImmNode>(); |
172 | ICHECK(floatimm != nullptr); |
173 | sum += floatimm->value; |
174 | } |
175 | return sum / float_array.size(); |
176 | } |
177 | |
178 | /*! \brief Return whether a string starts with another substring */ |
179 | inline bool StrStartsWith(const String& a, const String& b) { |
180 | if (b.size() > a.size()) return false; |
181 | return std::equal(a.c_str(), a.c_str() + b.size(), b.c_str()); |
182 | } |
183 | |
184 | /*! \brief Return whether a string ends with another substring */ |
185 | inline bool StrEndsWith(const String& a, const String& b) { |
186 | if (b.size() > a.size()) return false; |
187 | return std::equal(a.c_str() + a.size() - b.size(), a.c_str() + a.size(), b.c_str()); |
188 | } |
189 | |
190 | /********** Other Utilities **********/ |
191 | /*! \brief Get an int value from an Expr */ |
192 | inline int64_t GetIntImm(const PrimExpr& expr) { |
193 | auto pint = expr.as<IntImmNode>(); |
194 | if (pint == nullptr) { |
195 | return 1; |
196 | } |
197 | return pint->value; |
198 | } |
199 | |
200 | /*! \brief Compute the product of the lengths of axes */ |
201 | inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) { |
202 | int64_t ret = 1.0; |
203 | for (const auto& x : axes) { |
204 | if (const IntImmNode* imm = x->dom->extent.as<IntImmNode>()) { |
205 | ret *= imm->value; |
206 | } else { |
207 | return -1.0; |
208 | } |
209 | } |
210 | return ret; |
211 | } |
212 | |
213 | /*! |
214 | * \brief Clean the name of an iterator or an op to make it valid in python code. |
215 | * \param str The original name. |
216 | * \param prefix The name prefix to differentiate the same name (e.g., the same iterator names). |
217 | * \return The cleaned name. |
218 | */ |
219 | inline std::string CleanName(const std::string& str, const std::string& prefix = "" ) { |
220 | std::string ret = str; |
221 | StrReplace(&ret, "." , "_" ); |
222 | StrReplace(&ret, "@" , "_" ); |
223 | StrReplace(&ret, "outer" , "o" ); |
224 | StrReplace(&ret, "inner" , "i" ); |
225 | if (prefix != "" ) { |
226 | return prefix + "_" + ret; |
227 | } |
228 | return ret; |
229 | } |
230 | |
/*!
 * \brief An output stream that is constructed without a stream buffer.
 *
 * The std::ostream base is always initialized with a null buffer pointer.
 */
class NullStream : public std::ostream {
 public:
  /*! \brief Build a stream with no stream buffer attached. */
  NullStream() : std::ostream(nullptr) {}
  /*! \brief Copying ignores the source and builds another buffer-less stream. */
  NullStream(const NullStream&) : NullStream() {}
  /*! \brief Accessor for a NullStream instance; only declared here, defined elsewhere. */
  static NullStream& Global();
};
238 | |
239 | template <class T> |
240 | NullStream& operator<<(NullStream& os, const T& value) { |
241 | return os; |
242 | } |
243 | |
244 | /*! \brief Get std cout with verbose control */ |
245 | inline std::ostream& StdCout(int verbose, int setting = 1) { |
246 | return verbose >= setting ? std::cout : NullStream::Global(); |
247 | } |
248 | |
/*!
 * \brief Build a string that repeats one character.
 * \param str The character to repeat.
 * \param times The number of repetitions.
 * \return A string of `times` copies of `str`; empty when `times` is zero or negative.
 */
inline std::string Chars(const char& str, int times) {
  // Guard non-positive counts explicitly: converting a negative int to the unsigned
  // count of the fill constructor would request an enormous string.
  if (times <= 0) {
    return std::string();
  }
  return std::string(static_cast<std::string::size_type>(times), str);
}
257 | |
258 | /*! \brief Print the time elapsed */ |
259 | inline void PrintTimeElapsed(std::chrono::time_point<std::chrono::high_resolution_clock> t_begin, |
260 | const std::string& info, int verbose) { |
261 | double duration = std::chrono::duration_cast<std::chrono::duration<double>>( |
262 | std::chrono::high_resolution_clock::now() - t_begin) |
263 | .count(); |
264 | StdCout(verbose) << "Time elapsed for " << info << ": " << std::fixed << std::setprecision(2) |
265 | << duration << " s" << std::endl; |
266 | } |
267 | |
268 | /*! |
269 | * \brief Parse shape and axis names from layout string |
270 | */ |
271 | inline void ParseKernelLayout(const String& layout, Array<PrimExpr>* shape, |
272 | std::vector<std::string>* axes) { |
273 | int32_t factor = 0; |
274 | std::string axis = "" ; |
275 | for (char c : std::string(layout)) { |
276 | if (c >= 'A' && c <= 'z') { |
277 | axis += c; |
278 | if (factor != 0) { |
279 | shape->push_back(factor); |
280 | factor = 0; |
281 | } |
282 | } else if (c >= '0' && c <= '9') { |
283 | factor = factor * 10 + c - '0'; |
284 | if (!axis.empty()) { |
285 | axes->push_back(axis); |
286 | axis = "" ; |
287 | } |
288 | } else { |
289 | LOG(FATAL) << "Invalid layout " << layout; |
290 | } |
291 | } |
292 | if (!axis.empty()) { |
293 | axes->push_back(axis); |
294 | } |
295 | } |
296 | |
/*!
 * \brief Get the base name of an axis, i.e. the part before its last '_'.
 * \param str The axis name.
 * \return Everything before the last underscore, or the whole name when it contains none.
 */
inline std::string AxisBaseName(const std::string& str) {
  const std::string::size_type cut = str.rfind("_");
  if (cut == std::string::npos) {
    return str;
  }
  return std::string(str, 0, cut);
}
299 | |
300 | } // namespace auto_scheduler |
301 | } // namespace tvm |
302 | |
303 | #endif // TVM_AUTO_SCHEDULER_UTILS_H_ |
304 | |