1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_UTILS_UTILS_HPP
18#define GPU_JIT_UTILS_UTILS_HPP
19
#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <functional>
#include <initializer_list>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
29
30#include "common/utils.hpp"
31#include "gpu/compute/device_info.hpp"
32
33// Uncomment this when jit::ir debugging is required:
34//#define GEN_IR_DEBUG
35#ifdef GEN_IR_DEBUG
36#define GEN_CONV_DEBUG
37#endif
38
39// Uncomment this when jit::ir profiling is required:
40//#define GEN_IR_PROFILE
41#ifdef GEN_IR_PROFILE
42#define GEN_CONV_PROFILE
43#endif
44
45// Uncomment this when aborting on ir_assert is desired:
46// #define IR_ABORT_ON_ERROR
47
48#ifdef GEN_CONV_PROFILE
49#include "common/profiler.hpp"
50#endif
51
52namespace dnnl {
53namespace impl {
54namespace gpu {
55namespace jit {
56namespace ir_utils {
57
// Verbosity levels used by logger_t. A message of a given level is printed
// when the active threshold is greater than or equal to that level.
constexpr int LOG_OFF = 0;
constexpr int LOG_WARNING = 100;
constexpr int LOG_INFO = 150;
constexpr int LOG_PERF = 170;
constexpr int LOG_TRACE = 200;

// Default threshold: no logging unless GEN_CONV_DEBUG is defined, in which
// case warnings are enabled.
#ifndef GEN_CONV_DEBUG
constexpr int LOG_LEVEL = LOG_OFF;
#else
constexpr int LOG_LEVEL = LOG_WARNING;
#endif
69
// Forward declarations of the single-object and std::vector overloads of
// get_hash(). They are referenced by get_hash_impl() below, which is defined
// before the overload definitions.
template <typename T>
size_t get_hash(const T &t);

template <typename T>
size_t get_hash(const std::vector<T> &v);
75
76template <typename T>
77void get_hash_impl(size_t &h, const T &t) {
78 h = hash_combine(h, get_hash(t));
79}
80
// Recursive step of the variadic hash fold: mixes the hash of the first
// argument into `h`, then continues with the remaining arguments.
template <typename ArgHeadT, typename... ArgsT>
void get_hash_impl(size_t &h, const ArgHeadT &head, const ArgsT &... args) {
    h = hash_combine(h, get_hash(head));
    get_hash_impl(h, args...);
}
87
// Hash functor for enumeration types: hashes the enumerator value converted
// to size_t.
template <typename E>
struct enum_hash_t {
    size_t operator()(const E &e) const noexcept {
        const size_t value = static_cast<size_t>(e);
        return std::hash<size_t>()(value);
    }
};
94
// Default hashing path for types without a get_hash() member: delegates to
// std::hash<T>.
template <typename T, typename = void>
struct get_std_hash_helper_t {
    static size_t call(const T &t) {
        const std::hash<T> hasher {};
        return hasher(t);
    }
};
99
100template <typename T>
101struct get_std_hash_helper_t<T,
102 typename std::enable_if<std::is_enum<T>::value>::type> {
103 static size_t call(const T &t) { return enum_hash_t<T>()(t); }
104};
105
106template <typename T, typename = void>
107struct get_hash_helper_t {
108 static size_t call(const T &t) { return get_std_hash_helper_t<T>::call(t); }
109};
110
111template <typename T>
112struct get_hash_helper_t<T, decltype(std::declval<T>().get_hash(), void())> {
113 static size_t call(const T &t) { return t.get_hash(); }
114};
115
116template <typename T>
117size_t get_hash(const T &t) {
118 return get_hash_helper_t<T>::call(t);
119}
120
// Hashes a vector by combining its element hashes in order, starting from a
// seed of 0 (an empty vector hashes to 0).
template <typename T>
size_t get_hash(const std::vector<T> &v) {
    size_t seed = 0;
    for (size_t i = 0; i < v.size(); i++) {
        seed = hash_combine(seed, get_hash(v[i]));
    }
    return seed;
}
128
// Hashes several objects at once by folding their individual hashes, in
// argument order, into a seed of 0.
template <typename... ArgsT>
size_t get_hash(const ArgsT &... args) {
    size_t seed = 0;
    get_hash_impl(seed, args...);
    return seed;
}
135
// Fallback equality check for types without an is_equal() member: uses
// operator==.
template <typename T, typename U, typename = void>
struct is_equal_helper_t {
    static bool call(const T &t, const U &u) {
        const bool eq = (t == u);
        return eq;
    }
};
140
141template <typename T, typename U>
142struct is_equal_helper_t<T, U,
143 decltype(std::declval<T>().is_equal(std::declval<U>()), void())> {
144 static bool call(const T &t, const U &u) { return t.is_equal(u); }
145};
146
147// Checks equality of objects:
148// 1. Uses t.is_equal(u) if is_equal() is available
149// 2. Uses (t == u) otherwise
150template <typename T, typename U>
151bool is_equal(const T &t, const U &u) {
152 return is_equal_helper_t<T, U>::call(t, u);
153}
154
155// Checks equality of vector elements.
156template <typename T, typename U>
157bool is_equal(const std::vector<T> &a, const std::vector<U> &b) {
158 if (a.size() != b.size()) return false;
159 for (size_t i = 0; i < a.size(); i++)
160 if (!ir_utils::is_equal(a[i], b[i])) return false;
161 return true;
162}
163
164// Checks equality of vector elements between each other.
165template <typename T>
166bool are_all_equal(const std::vector<T> &a) {
167 if (a.empty()) return true;
168 for (size_t i = 1; i < a.size(); i++)
169 if (!ir_utils::is_equal(a[i], a[0])) return false;
170 return true;
171}
172
// Checks identity of vector elements: sizes must match and each element of
// `a` must report is_same() with the element of `b` at the same position.
template <typename T, typename U>
bool is_same(const std::vector<T> &a, const std::vector<U> &b) {
    if (a.size() != b.size()) return false;
    auto it_b = b.begin();
    for (auto it_a = a.begin(); it_a != a.end(); ++it_a, ++it_b) {
        if (!it_a->is_same(*it_b)) return false;
    }
    return true;
}
181
// Accumulates an assertion-failure message and reports it when the object is
// destroyed. Unless IR_ABORT_ON_ERROR is defined, reporting means throwing
// std::runtime_error with the message; with it, the message is printed to
// std::cout and std::abort() is called. Intended to be created as a temporary
// by ir_assert() (see below) so that reporting happens at the end of the
// full expression, after all streamed values are appended.
class error_stream_t {
public:
    error_stream_t(const char *file, int line, const char *assert_msg)
        : data_(new data_t(file, line, assert_msg)) {}

    // The object owns its message buffer and reports exactly once from its
    // destructor. A copy would report twice (and, with the previous raw
    // pointer, delete the buffer twice), so copying is disabled.
    error_stream_t(const error_stream_t &) = delete;
    error_stream_t &operator=(const error_stream_t &) = delete;

    // This allows a stream object to be used in short-circuit evaluation with
    // booleans, see ir_assert() below.
    operator bool() const { return true; }

    // Appends `t` to the error message.
    template <typename T>
    error_stream_t &operator<<(const T &t) {
        data_->out << t;
        return *this;
    }

    ~error_stream_t() noexcept(false) {
        if (!data_) return;

#ifdef IR_ABORT_ON_ERROR
        std::cout << data_->out.str() << "\n";
        std::abort();
#else
        auto err = std::runtime_error(data_->out.str());
        data_.reset();

        // This is technically unsafe. Since error_stream_t is only used in
        // debug builds and since it is only used by ir_assert() which signals
        // an ill-defined program state, nested throws are not a concern.
        throw err; // NOLINT
#endif
    }

private:
    // Message buffer, pre-filled with the failed condition and its location.
    struct data_t {
        data_t(const char *file, int line, const char *assert_msg)
            : file(file), line(line) {
            out << "Assertion " << assert_msg << " failed at " << file << ":"
                << line << std::endl;
        }

        const char *file;
        int line;
        std::ostringstream out;
    };

    std::unique_ptr<data_t> data_;
};
231
// Checks assertion and, in case of error, evaluates output operators to print
// related messages. Usage:
//     ir_assert(condition) << "Error message" << ...;
//
// When the condition is false, an error_stream_t temporary is created, the
// streamed values are appended to its message, and its destructor reports the
// failure at the end of the full expression. Because `<<` binds tighter than
// `&&`, the streamed values are only evaluated when the condition fails.
//
// When NDEBUG is defined and GEN_CONV_DEBUG is not, the leading `(false)`
// short-circuits the whole expression: neither the condition nor the streamed
// values are evaluated at run time (they must still compile).

#if !defined(NDEBUG) || defined(GEN_CONV_DEBUG)
#define ir_assert(cond) \
    !(cond) \
            && dnnl::impl::gpu::jit::ir_utils::error_stream_t( \
                    __FILE__, __LINE__, #cond)
#else
#define ir_assert(cond) \
    (false) && !(cond) \
            && dnnl::impl::gpu::jit::ir_utils::error_stream_t( \
                    __FILE__, __LINE__, #cond)
#endif

// Unconditional failures for unexpected/unsupported code paths; they fire
// whenever ir_assert() checks are enabled.
#define ir_error_not_expected() ir_assert(false) << "Not expected. "
#define ir_error_not_implemented() ir_assert(false) << "Not implemented. "
250
251template <int level>
252class logger_t {
253public:
254 logger_t(std::ostream &out = std::cout) : out_(out) {}
255
256 operator bool() const { return true; }
257
258 static bool is_enabled() {
259#if defined(GEN_CONV_DEBUG) || defined(GEN_CONV_PROFILE)
260 static const int log_level(getenv_int("log_level", LOG_LEVEL));
261 return log_level >= level;
262#else
263 return LOG_LEVEL >= level;
264#endif
265 }
266
267 template <typename T>
268 logger_t &operator<<(const T &obj) {
269 maybe_print_header();
270 out_ << obj;
271 return *this;
272 }
273
274 logger_t &operator<<(std::ostream &(*os)(std::ostream &)) {
275 maybe_print_header();
276 out_ << os;
277 return *this;
278 }
279
280private:
281 void maybe_print_header() {
282 if (!is_first_print_) return;
283
284 switch (level) {
285 case LOG_WARNING: out_ << "[WARNING] "; break;
286 default: break;
287 }
288 is_first_print_ = false;
289 }
290
291 std::ostream &out_;
292 bool is_first_print_ = true;
293};
294
// Logging macros. Usage:
//     ir_info() << "message" << std::endl;
// Each expands to `logger_t<LEVEL>::is_enabled() && logger_t<LEVEL>()`, so
// when the level is disabled the && short-circuits and the values streamed
// after the macro are not evaluated. The expansions refer to `ir_utils::`
// relatively, so they presumably must be used where that name resolves
// (e.g. inside namespace jit) - verify at call sites.
#define ir_perf() \
    ir_utils::logger_t<ir_utils::LOG_PERF>::is_enabled() \
            && ir_utils::logger_t<ir_utils::LOG_PERF>()

// Like ir_perf(), but suppressed while LOG_TRACE is enabled: trace output
// can add enough overhead to make performance measurements meaningless.
#define ir_perf_no_trace() \
    ir_utils::logger_t<ir_utils::LOG_PERF>::is_enabled() \
            && !ir_utils::logger_t<ir_utils::LOG_TRACE>::is_enabled() \
            && ir_utils::logger_t<ir_utils::LOG_PERF>()

#define ir_info() \
    ir_utils::logger_t<ir_utils::LOG_INFO>::is_enabled() \
            && ir_utils::logger_t<ir_utils::LOG_INFO>()

#define ir_warning() \
    ir_utils::logger_t<ir_utils::LOG_WARNING>::is_enabled() \
            && ir_utils::logger_t<ir_utils::LOG_WARNING>()

#define ir_trace() \
    ir_utils::logger_t<ir_utils::LOG_TRACE>::is_enabled() \
            && ir_utils::logger_t<ir_utils::LOG_TRACE>()
316
// Pretty printers for STL objects.

// Prints an unordered_set as "{a, b, c}" (in the set's iteration order).
template <typename KeyT, typename HashT, typename EqualT>
inline std::ostream &operator<<(
        std::ostream &out, const std::unordered_set<KeyT, HashT, EqualT> &s) {
    const char *sep = "";
    out << "{";
    for (const auto &key : s) {
        out << sep << key;
        sep = ", ";
    }
    out << "}";
    return out;
}
328
// Prints an unordered_map as "{k1: v1, k2: v2}" (in the map's iteration
// order).
template <typename KeyT, typename ValueT, typename HashT, typename EqualT>
inline std::ostream &operator<<(std::ostream &out,
        const std::unordered_map<KeyT, ValueT, HashT, EqualT> &m) {
    const char *sep = "";
    out << "{";
    for (const auto &kv : m) {
        out << sep << kv.first << ": " << kv.second;
        sep = ", ";
    }
    out << "}";
    return out;
}
339
// Printing adapter for a sequence: holds a reference to the container (which
// must outlive the helper), the element separator and the std::setw() width
// applied to each element.
template <typename ContainerT>
struct seq_print_helper_t {
    seq_print_helper_t(const ContainerT &container,
            const std::string &separator, int elem_width)
        : v(container), sep(separator), width(elem_width) {}

    const ContainerT &v;
    const std::string sep;
    int width;
};
349
350template <typename T>
351seq_print_helper_t<T> make_seq_print_helper(
352 const T &v, const std::string &sep = ", ", int width = 0) {
353 return seq_print_helper_t<T>(v, sep, width);
354}
355
356template <typename T>
357inline std::ostream &operator<<(
358 std::ostream &out, const seq_print_helper_t<T> &seq) {
359 for (auto it = seq.v.begin(); it != seq.v.end(); it++) {
360 out << (it != seq.v.begin() ? seq.sep : "") << std::setw(seq.width)
361 << *it;
362 }
363 return out;
364}
365
// Prints a vector as "[a, b, c]".
template <typename T>
inline std::ostream &operator<<(std::ostream &out, const std::vector<T> &v) {
    return out << "[" << make_seq_print_helper(v) << "]";
}
373
374class table_t {
375public:
376 table_t(const std::string &title, const std::vector<std::string> &header)
377 : title_(title), header_(header) {}
378
379 template <typename T>
380 table_t &operator<<(const T &value) {
381 std::ostringstream oss;
382 oss << value;
383 auto str_value = oss.str();
384 size_t pos = 0;
385 for (size_t i = 0; i < str_value.length(); i++) {
386 if (str_value[i] != '\n') continue;
387 cur_row_.push_back(str_value.substr(pos, i - pos));
388 new_row();
389 pos = i + 1;
390 }
391 if (str_value.empty() || pos != str_value.length()) {
392 cur_row_.push_back(str_value.substr(pos, str_value.length() - pos));
393 }
394 return *this;
395 }
396
397 table_t &operator<<(std::ostream &(*f)(std::ostream &)) {
398 auto _endl
399 = (std::basic_ostream<char> & (*)(std::basic_ostream<char> &))
400 std::endl;
401 if (f == _endl) new_row();
402 return *this;
403 }
404
405 std::string str() const {
406 std::ostringstream oss;
407 size_t n = header_.size();
408 std::vector<size_t> widths(n);
409 for (size_t i = 0; i < n; i++)
410 widths[i] = header_[i].length();
411 for (auto &r : rows_) {
412 for (size_t i = 0; i < n; i++) {
413 widths[i] = std::max(widths[i], r[i].length());
414 }
415 }
416 auto print = [&](std::ostream &out, size_t idx, const std::string &s) {
417 int w = (int)widths[idx] + 2;
418 out << std::setw(w);
419 out << (idx > 0 ? std::right : std::left);
420 out << s;
421 };
422 oss << title_ << std::endl;
423 for (size_t i = 0; i < n; i++) {
424 print(oss, i, header_[i]);
425 }
426 oss << std::endl;
427 for (auto &r : rows_) {
428 for (size_t i = 0; i < n; i++) {
429 print(oss, i, r[i]);
430 }
431 if (&r != &rows_.back()) oss << std::endl;
432 }
433 return oss.str();
434 }
435
436private:
437 void new_row() {
438 ir_assert(cur_row_.size() == header_.size());
439 rows_.emplace_back();
440 rows_.back().swap(cur_row_);
441 }
442
443 std::string title_;
444 std::vector<std::string> header_;
445 std::vector<std::vector<std::string>> rows_;
446
447 std::vector<std::string> cur_row_;
448};
449
450inline std::ostream &operator<<(std::ostream &out, const table_t &table) {
451 out << table.str();
452 return out;
453}
454
455inline bool getenv_bool(const char *s, bool def) {
456 return getenv_int(s, def ? 1 : 0) == 1;
457}
458
459inline std::string getenv_str(const char *s, const std::string &def) {
460 char buf[1024];
461 int ret = getenv(s, buf, sizeof(buf));
462 if (ret > 0) return buf;
463 return def;
464}
465
466// Input is a comma separate list containing gpu_arch and optionally eu_count.
467inline compute::gpu_arch_t getenv_gpu(const char *s, compute::gpu_arch_t arch,
468 int *eu_count = nullptr, int *max_wg_size = nullptr) {
469 char buf[1024];
470 int ret = getenv(s, buf, sizeof(buf));
471 if (ret > 0) {
472 char *arch_str = buf, *eu_str = nullptr;
473 for (int i = 0; i < ret; i++) {
474 if (buf[i] == ',') {
475 buf[i] = 0;
476 if (i < ret - 1) { eu_str = &buf[i + 1]; }
477 break;
478 }
479 }
480 arch = compute::str2gpu_arch(arch_str);
481 if (eu_count && eu_str) { *eu_count = atoi(eu_str); }
482 if (max_wg_size) {
483 // Assume maximum wg size is basically the number of threads
484 // available in a subslice with simd_size 16
485 const int max_eus_per_wg
486 = compute::device_info_t::max_eus_per_wg(arch);
487 const int simd_size = 16;
488 const int thr_per_eu = utils::rnd_down_pow2(
489 compute::device_info_t::threads_per_eu(arch));
490 *max_wg_size = simd_size * max_eus_per_wg * thr_per_eu;
491 }
492 }
493 return arch;
494}
495
// Formats a boolean as "True" or "False".
inline std::string to_string(bool b) {
    if (b) return "True";
    return "False";
}
499
// Parses a boolean flag. "0" and "false" in any letter case are false; any
// other string, including the empty string, is true. The comparison is
// case-insensitive so that "False", as produced by to_string(bool),
// round-trips correctly.
inline bool to_bool(const std::string &s) {
    if (s == "0") return false;
    const std::string false_str = "false";
    if (s.size() != false_str.size()) return true;
    for (size_t i = 0; i < s.size(); i++) {
        // Cast to unsigned char: std::tolower() is undefined for negative
        // values other than EOF.
        const int c = std::tolower(static_cast<unsigned char>(s[i]));
        if (c != false_str[i]) return true;
    }
    return false;
}
504
// Splits `s` at every occurrence of `delimiter` (a single space by default).
// Empty fields are preserved, so N delimiters always produce N + 1 pieces,
// e.g. split(" a") == {"", "a"} and split("a,,b", ",") == {"a", "", "b"}.
// An empty input yields {""}; an empty delimiter yields {s}.
//
// The previous implementation used `end == 0` as its "first iteration" test.
// It looped forever when `s` started with the delimiter, and for an empty
// delimiter.
inline std::vector<std::string> split(const std::string &s,
        const std::string &delimiter = std::string(1, ' ')) {
    std::vector<std::string> ret;
    if (delimiter.empty()) {
        ret.push_back(s);
        return ret;
    }
    size_t beg = 0;
    while (true) {
        const size_t end = s.find(delimiter, beg);
        if (end == std::string::npos) {
            ret.push_back(s.substr(beg));
            break;
        }
        ret.push_back(s.substr(beg, end - beg));
        beg = end + delimiter.size();
    }
    return ret;
}
519
// Returns a copy of `s` with every character passed through std::tolower()
// (current C locale).
inline std::string to_lower(const std::string &s) {
    auto ret = s;
    std::transform(ret.begin(), ret.end(), ret.begin(), [](char c) {
        // std::tolower() requires a value representable as unsigned char (or
        // EOF); passing a negative char, e.g. a non-ASCII byte where char is
        // signed, is undefined behavior.
        return static_cast<char>(std::tolower(static_cast<unsigned char>(c)));
    });
    return ret;
}
526
527template <typename T>
528inline T max_divisor(T n, std::initializer_list<T> divisors) {
529 T ret = -1;
530 for (auto d : divisors) {
531 if (n % d == 0) ret = std::max(ret, d);
532 }
533 ir_assert(ret != -1);
534 return ret;
535}
536
// Equivalent of the BLSI instruction (extract lowest set isolated bit):
// returns the largest power of two dividing `n`, or 0 for n == 0.
template <typename T>
inline T max_pow2_divisor(T n) {
    const auto n_minus_one = n - 1;
    return n & ~n_minus_one;
}
542
// Returns a / b after asserting (via ir_assert) that b is non-zero and that
// the division is exact.
template <typename T, typename U>
inline T safe_divide(T a, U b) {
    ir_assert(b != 0 && a % b == 0) << "Can't divide: " << a << " / " << b;
    const T quotient = a / b;
    return quotient;
}
548
// Returns the position of the first element of `c` equal to `value`, or -1
// when there is none. `c` must support size() and operator[].
template <typename ContainerT, typename T>
inline int find_index(const ContainerT &c, const T &value) {
    int pos = 0;
    while (pos < int(c.size())) {
        if (c[pos] == value) return pos;
        ++pos;
    }
    return -1;
}
556
// Recursive helper of for_each(). Assigns every value of [0, bounds[pos]) to
// idx[pos] in increasing order and recurses into the next dimension. Once all
// dimensions are assigned (pos == bounds.size()), it invokes f(idx).
template <typename T, typename F>
void for_each_impl(size_t pos, std::vector<T> &idx,
        const std::vector<T> &bounds, const F &f) {
    if (pos != bounds.size()) {
        for (T i = 0; i < bounds[pos]; ++i) {
            idx[pos] = i;
            for_each_impl(pos + 1, idx, bounds, f);
        }
        return;
    }
    f(idx);
}
570
// Invokes f(idx) for every index tuple in [0, bounds[0]) x ... x
// [0, bounds[n - 1]], with the last dimension varying fastest. With empty
// `bounds`, f is invoked once with an empty vector.
template <typename T, typename F>
void for_each(const std::vector<T> &bounds, const F &f) {
    const size_t ndims = bounds.size();
    std::vector<T> idx(ndims);
    for_each_impl(size_t(0), idx, bounds, f);
}
576
// Returns (a copy of) the value mapped to `key`, or `default_value` when the
// key is absent. The map is never modified.
template <typename MapContainerT, typename KeyT,
        typename ValueT = typename MapContainerT::mapped_type>
ValueT get_or_default(const MapContainerT &map, const KeyT &key,
        const ValueT &default_value) {
    const auto found = map.find(key);
    if (found != map.end()) return found->second;
    return default_value;
}
585
// Optional profiler for debugging. It forwards to profiler_t when
// GEN_CONV_PROFILE is defined. Otherwise every method is a no-op and str()
// returns an empty string, so profiling calls can stay in the code.
struct debug_profiler_t {
#ifdef GEN_CONV_PROFILE
    debug_profiler_t(std::string profile_name) : profile(profile_name) {}
    void start() { profile.start(); }
    void stamp(const char *name) { profile.stamp(name); }
    void stop(const char *name) { profile.stop(name); }
    void stop() { profile.stop(); }
    void reset() { profile.reset(); }
    std::string str() const { return profile.str(); }

private:
    profiler_t profile;
#else
    // Parameters are left unnamed to avoid -Wunused-parameter warnings.
    debug_profiler_t(std::string) {}
    void start() {}
    void stamp(const char *) {}
    void stop(const char *) {}
    void stop() {}
    void reset() {}
    std::string str() const { return ""; }
#endif
};
608
609inline std::ostream &operator<<(
610 std::ostream &out, const debug_profiler_t &profile) {
611 out << profile.str();
612 return out;
613}
614
615} // namespace ir_utils
616} // namespace jit
617} // namespace gpu
618} // namespace impl
619} // namespace dnnl
620
621#endif
622