1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_COMPUTE_UTILS_HPP
18#define GPU_COMPUTE_UTILS_HPP
19
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <initializer_list>
#include <sstream>
#include <string>
#include <vector>
23
24namespace dnnl {
25namespace impl {
26namespace gpu {
27namespace compute {
28
// Raw byte buffer type.
using binary_t = std::vector<uint8_t>;

// Stores global/local ranges to use for kernel enqueueing.
//
// Ranges are always stored with three dimensions. Dimensions the caller does
// not provide are padded with 1. The local range is optional: local_range()
// returns nullptr when none was specified.
class nd_range_t {
public:
    // Creates a 1x1x1 global range with no local range.
    nd_range_t() = default;

    // Creates a range from the first `n` (at most 3) entries of
    // `global_range`, and of `local_range` when it is non-null.
    nd_range_t(size_t n, const size_t *global_range,
            const size_t *local_range = nullptr) {
        init_ranges(global_range, n, local_range != nullptr, local_range, n);
    }

    // Creates a 3D range from three-element arrays.
    nd_range_t(const size_t *global_range, const size_t *local_range = nullptr)
        : nd_range_t(3, global_range, local_range) {}

    // Creates a range from braced lists, e.g. nd_range_t({gws0, gws1}).
    // An empty `local_range` means "no local range". Otherwise it must have
    // the same number of dimensions as `global_range` (at most 3).
    template <typename int_type>
    nd_range_t(std::initializer_list<int_type> global_range,
            std::initializer_list<int_type> local_range = {}) {
        init_ranges(global_range.begin(), global_range.size(),
                local_range.size() > 0, local_range.begin(),
                local_range.size());
    }

    // Same contract as the std::initializer_list overload, for vectors.
    template <typename int_type>
    nd_range_t(const std::vector<int_type> &global_range,
            const std::vector<int_type> &local_range = {}) {
        init_ranges(global_range.begin(), global_range.size(),
                !local_range.empty(), local_range.begin(),
                local_range.size());
    }

    // Ranges are always stored as 3D.
    size_t ndims() const { return 3; }

    // Pointer to the 3 global range values.
    const size_t *global_range() const { return global_range_; }

    // Pointer to the 3 local range values, or nullptr if no local range was
    // specified.
    const size_t *local_range() const {
        return with_local_range_ ? local_range_ : nullptr;
    }

    // True if any global dimension is 0.
    bool is_zero() const {
        return global_range_[0] == 0 || global_range_[1] == 0
                || global_range_[2] == 0;
    }

    // Human-readable form, e.g. "gws = [8, 4, 1] lws = (nil)".
    std::string str() const {
        std::ostringstream oss;
        oss << "gws = [" << global_range_[0] << ", " << global_range_[1] << ", "
            << global_range_[2] << "] lws = ";
        if (local_range()) {
            oss << "[" << local_range_[0] << ", " << local_range_[1] << ", "
                << local_range_[2] << "]";
        } else {
            oss << "(nil)";
        }
        return oss.str();
    }

private:
    // Shared constructor logic. Each range is read only up to its own length
    // (`n_gws` / `n_lws`), so a mismatch cannot cause an out-of-bounds read
    // even when asserts are compiled out (NDEBUG).
    template <typename gws_iter_t, typename lws_iter_t>
    void init_ranges(gws_iter_t gws, size_t n_gws, bool has_lws,
            lws_iter_t lws, size_t n_lws) {
        assert(n_gws <= 3);
        assert(!has_lws || n_lws == n_gws);
        with_local_range_ = has_lws;
        fill_range(global_range_, gws, n_gws);
        // `lws` may be nullptr / an empty range when `has_lws` is false.
        if (has_lws) fill_range(local_range_, lws, n_lws);
    }

    // Copies the first min(n, 3) values from `src` into `dst` and pads the
    // remaining dimensions with 1.
    template <typename iter_t>
    static void fill_range(size_t (&dst)[3], iter_t src, size_t n) {
        for (size_t i = 0; i < 3; ++i)
            dst[i] = (i < n) ? static_cast<size_t>(src[i]) : size_t(1);
    }

    size_t global_range_[3] = {1, 1, 1};
    size_t local_range_[3] = {1, 1, 1};
    bool with_local_range_ = false;
};
120
121} // namespace compute
122} // namespace gpu
123} // namespace impl
124} // namespace dnnl
125
126#endif // GPU_COMPUTE_UTILS_HPP
127