1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file auto_scheduler/utils.h
22 * \brief Common utilities.
23 */
24
25#ifndef TVM_AUTO_SCHEDULER_UTILS_H_
26#define TVM_AUTO_SCHEDULER_UTILS_H_
27
#include <dmlc/common.h>
#include <tvm/tir/expr.h>

#include <algorithm>
#include <chrono>
#include <deque>
#include <exception>
#include <future>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <tuple>
#include <utility>
#include <vector>
43
44namespace std {
45
46/*! \brief Hash function for std::pair */
47template <typename T1, typename T2>
48struct hash<std::pair<T1, T2>> {
49 std::size_t operator()(const std::pair<T1, T2>& k) const {
50 return ::dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second));
51 }
52};
53
54/*! \brief Hash function for std::tuple */
55template <typename T1, typename T2, typename T3>
56struct hash<std::tuple<T1, T2, T3>> {
57 std::size_t operator()(const std::tuple<T1, T2, T3>& k) const {
58 return ::dmlc::HashCombine(
59 ::dmlc::HashCombine(std::hash<T1>()(std::get<0>(k)), std::hash<T2>()(std::get<1>(k))),
60 std::hash<T3>()(std::get<2>(k)));
61 }
62};
63
64} // namespace std
65
66namespace tvm {
67namespace auto_scheduler {
68
69/********** Utilities for Array, std::vector, std::string **********/
70/*! \brief Get the first appearance index of elements in an Array */
71template <typename T>
72inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, Array<Integer>* indices) {
73 for (const auto& v : to_locate) {
74 auto it = std::find(array.begin(), array.end(), v);
75 if (it != array.end()) {
76 indices->push_back(it - array.begin());
77 } else {
78 LOG(FATAL) << "Cannot find the item";
79 }
80 }
81}
82
83/*! \brief Get the first appearance index of an element in an Array */
84template <typename T>
85inline int GetIndex(const Array<T>& array, const T& to_locate) {
86 for (size_t i = 0; i < array.size(); ++i) {
87 if (array[i] == to_locate) {
88 return i;
89 }
90 }
91 LOG(FATAL) << "Cannot find the item";
92}
93
/*!
 * \brief Erase the first element equal to `to_delete` from a std::vector, if there is one.
 * \param array The vector to modify in place.
 * \param to_delete The value to remove; the vector is left unchanged when it is absent.
 */
template <typename T>
inline void FindAndDeleteItem(std::vector<T>* array, const T& to_delete) {
  const auto match = std::find(array->begin(), array->end(), to_delete);
  if (match == array->end()) {
    return;
  }
  array->erase(match);
}
102
/*!
 * \brief Multiply all elements of an int vector together.
 * \param array The factors.
 * \return The product, accumulated in 64-bit arithmetic; 1 for an empty vector.
 */
inline int64_t ElementProduct(const std::vector<int>& array) {
  return std::accumulate(array.begin(), array.end(), int64_t{1},
                         [](int64_t acc, int value) { return acc * value; });
}
111
/*!
 * \brief Move every element of one vector onto the end of another.
 * \param out The destination; elements are appended after its current contents.
 * \param in The source; its elements are moved from, but its size is not changed.
 * \return A reference to *out.
 */
template <typename T>
std::vector<T>& ConcatenateMove(std::vector<T>* out, std::vector<T>* in) {
  std::vector<T>& dst = *out;
  auto src_begin = std::make_move_iterator(in->begin());
  auto src_end = std::make_move_iterator(in->end());
  dst.insert(dst.end(), src_begin, src_end);
  return dst;
}
118
/*!
 * \brief Move the elements of several vectors, in argument order, onto the end of `out`.
 * \param out The destination vector.
 * \param first The first source vector.
 * \param args The remaining source vectors (as std::vector<T>* pointers).
 * \return A reference to *out.
 */
template <typename T, typename... Args>
std::vector<T>& ConcatenateMove(std::vector<T>* out, std::vector<T>* first, Args... args) {
  // Append `first` first, then recurse over the remaining sources.
  std::vector<T>& partial = ConcatenateMove(out, first);
  return ConcatenateMove(&partial, args...);
}
126
/*!
 * \brief Fill `out` with a random permutation of the integers [0, n-1].
 * \param n The number of elements.
 * \param out Output vector; its previous contents are replaced.
 * \param gen The random engine driving std::shuffle.
 */
template <typename G>
void RandomPermutation(int n, std::vector<int>* out, G* gen) {
  std::vector<int>& perm = *out;
  perm.resize(n);
  std::iota(perm.begin(), perm.end(), 0);
  std::shuffle(perm.begin(), perm.end(), *gen);
}
134
/*!
 * \brief Replace every occurrence of a sub-string in a string with another sub-string.
 * \param base The string to edit in place.
 * \param from The sub-string to search for. An empty `from` leaves `base` unchanged.
 * \param to The replacement text.
 * \note Matches are found left to right and scanning resumes just after each inserted `to`,
 *       so replacement text is never re-scanned (e.g. replacing "a" with "aa" terminates).
 */
inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
  // An empty pattern matches at every position; without this guard the loop would never
  // terminate and would keep growing `base`.
  if (from.empty()) {
    return;
  }
  for (size_t pos = base->find(from); pos != std::string::npos;
       pos = base->find(from, pos + to.size())) {
    base->replace(pos, from.size(), to);
  }
}
143
144/*! \brief Return whether two int arrays are elementwise-equal */
145inline bool IntArrayEqual(const Array<PrimExpr>& arr1, const Array<PrimExpr>& arr2) {
146 if (arr1.size() != arr2.size()) {
147 return false;
148 }
149
150 for (size_t i = 0; i < arr1.size(); ++i) {
151 auto int1 = arr1[i].as<IntImmNode>();
152 auto int2 = arr2[i].as<IntImmNode>();
153 ICHECK(int1 != nullptr);
154 ICHECK(int2 != nullptr);
155 if (int1->value != int2->value) {
156 return false;
157 }
158 }
159 return true;
160}
161
162/********** Utilities for TVM Containers / ByteArray **********/
163/*! \brief Compute mean of a FloatImm array */
164inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
165 double sum = 0;
166 if (float_array.empty()) {
167 return 0.0;
168 }
169
170 for (const auto& x : float_array) {
171 auto floatimm = x.as<tir::FloatImmNode>();
172 ICHECK(floatimm != nullptr);
173 sum += floatimm->value;
174 }
175 return sum / float_array.size();
176}
177
178/*! \brief Return whether a string starts with another substring */
179inline bool StrStartsWith(const String& a, const String& b) {
180 if (b.size() > a.size()) return false;
181 return std::equal(a.c_str(), a.c_str() + b.size(), b.c_str());
182}
183
184/*! \brief Return whether a string ends with another substring */
185inline bool StrEndsWith(const String& a, const String& b) {
186 if (b.size() > a.size()) return false;
187 return std::equal(a.c_str() + a.size() - b.size(), a.c_str() + a.size(), b.c_str());
188}
189
190/********** Other Utilities **********/
191/*! \brief Get an int value from an Expr */
192inline int64_t GetIntImm(const PrimExpr& expr) {
193 auto pint = expr.as<IntImmNode>();
194 if (pint == nullptr) {
195 return 1;
196 }
197 return pint->value;
198}
199
200/*! \brief Compute the product of the lengths of axes */
201inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
202 int64_t ret = 1.0;
203 for (const auto& x : axes) {
204 if (const IntImmNode* imm = x->dom->extent.as<IntImmNode>()) {
205 ret *= imm->value;
206 } else {
207 return -1.0;
208 }
209 }
210 return ret;
211}
212
213/*!
214 * \brief Clean the name of an iterator or an op to make it valid in python code.
215 * \param str The original name.
216 * \param prefix The name prefix to differentiate the same name (e.g., the same iterator names).
217 * \return The cleaned name.
218 */
219inline std::string CleanName(const std::string& str, const std::string& prefix = "") {
220 std::string ret = str;
221 StrReplace(&ret, ".", "_");
222 StrReplace(&ret, "@", "_");
223 StrReplace(&ret, "outer", "o");
224 StrReplace(&ret, "inner", "i");
225 if (prefix != "") {
226 return prefix + "_" + ret;
227 }
228 return ret;
229}
230
/*!
 * \brief An output stream that discards everything written to it.
 * It is constructed with no stream buffer, so the stream is in the bad state and
 * formatted output performs no writes.
 */
class NullStream : public std::ostream {
 public:
  /*! \brief Construct a stream with no buffer attached. */
  NullStream() : std::ostream(nullptr) {}
  /*! \brief std::ostream is not copyable; a "copy" is simply another buffer-less stream. */
  NullStream(const NullStream&) : NullStream() {}
  /*! \brief Get a shared NullStream instance (defined outside this header). */
  static NullStream& Global();
};
238
239template <class T>
240NullStream& operator<<(NullStream& os, const T& value) {
241 return os;
242}
243
244/*! \brief Get std cout with verbose control */
245inline std::ostream& StdCout(int verbose, int setting = 1) {
246 return verbose >= setting ? std::cout : NullStream::Global();
247}
248
/*!
 * \brief Build a string consisting of one character repeated several times.
 * \param str The character to repeat.
 * \param times The repetition count; zero or negative yields an empty string.
 * \return The repeated-character string.
 */
inline std::string Chars(const char& str, int times) {
  return times > 0 ? std::string(static_cast<size_t>(times), str) : std::string();
}
257
258/*! \brief Print the time elapsed */
259inline void PrintTimeElapsed(std::chrono::time_point<std::chrono::high_resolution_clock> t_begin,
260 const std::string& info, int verbose) {
261 double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
262 std::chrono::high_resolution_clock::now() - t_begin)
263 .count();
264 StdCout(verbose) << "Time elapsed for " << info << ": " << std::fixed << std::setprecision(2)
265 << duration << " s" << std::endl;
266}
267
/*!
 * \brief Parse shape factors and axis names from a layout string.
 * \param layout The layout string. Each character must be either in the range 'A'..'z' or a
 *        decimal digit; any other character is reported with LOG(FATAL).
 * \param shape Output; a run of digits is appended as one integer factor when the next
 *        'A'..'z' character is seen, provided its value is non-zero.
 * \param axes Output; a run of 'A'..'z' characters is appended as one axis name when the next
 *        digit is seen, or at the end of the string.
 *
 * NOTE(review): a digit run at the very end of `layout` is never appended to `shape`, and a
 * digit run whose value is 0 is never appended either (the push is guarded by `factor != 0`).
 * Confirm callers only pass layouts where every factor is non-zero and followed by a name.
 */
inline void ParseKernelLayout(const String& layout, Array<PrimExpr>* shape,
                              std::vector<std::string>* axes) {
  int32_t factor = 0;     // Value of the current digit run (no overflow check).
  std::string axis = "";  // Characters of the current name run.
  for (char c : std::string(layout)) {
    // NOTE(review): 'A'..'z' also covers the six ASCII characters between 'Z' and 'a'
    // ('[', backslash, ']', '^', '_', '`'). Axis names may rely on '_', so confirm before
    // narrowing this to letters only.
    if (c >= 'A' && c <= 'z') {
      axis += c;
      // A name character ends the preceding digit run: flush the pending factor.
      if (factor != 0) {
        shape->push_back(factor);
        factor = 0;
      }
    } else if (c >= '0' && c <= '9') {
      factor = factor * 10 + c - '0';
      // A digit ends the preceding name run: flush the pending axis name.
      if (!axis.empty()) {
        axes->push_back(axis);
        axis = "";
      }
    } else {
      LOG(FATAL) << "Invalid layout " << layout;
    }
  }
  // Flush a trailing axis name. A trailing factor, if any, is not flushed.
  if (!axis.empty()) {
    axes->push_back(axis);
  }
}
296
/*!
 * \brief Get the base name of an axis: the part before its last '_'.
 * \param str The axis name.
 * \return The prefix preceding the final '_', or the whole string if it contains no '_'.
 */
inline std::string AxisBaseName(const std::string& str) {
  const size_t cut = str.rfind('_');
  return str.substr(0, cut);
}
299
300} // namespace auto_scheduler
301} // namespace tvm
302
303#endif // TVM_AUTO_SCHEDULER_UTILS_H_
304