1 | /******************************************************************************* |
2 | * Copyright 2017-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef DNNL_MEMORY_HPP |
18 | #define DNNL_MEMORY_HPP |
19 | |
20 | #include "oneapi/dnnl/dnnl.h" |
21 | |
22 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_DPCPP |
23 | #include "oneapi/dnnl/dnnl_sycl.h" |
24 | #endif |
25 | |
26 | #include "common.hpp" |
27 | #include "utils/dims.hpp" |
28 | #include "utils/wrapper.hpp" |
29 | |
// Default fill value for benchdnn memory buffers.
// NOTE(review): presumably the byte pattern written into freshly created
// memory (e.g. passed to dnn_mem_t::memset()); confirm at the use sites.
#define dnnl_mem_default_value 0xFF
31 | |
32 | struct dnn_mem_t { |
33 | struct handle_info_t { |
34 | bool is_host_ptr; |
35 | void *ptr; |
36 | |
37 | bool is_allocate() const { return ptr == DNNL_MEMORY_ALLOCATE; } |
38 | |
39 | static handle_info_t allocate() { |
40 | return {false, DNNL_MEMORY_ALLOCATE}; |
41 | } |
42 | }; |
43 | |
44 | dnn_mem_t() { map(); } |
45 | dnn_mem_t(const_dnnl_memory_desc_t md, dnnl_engine_t engine, |
46 | const handle_info_t &handle_info = handle_info_t::allocate()); |
47 | dnn_mem_t(const_dnnl_memory_desc_t md, dnnl_data_type_t dt, |
48 | const std::string &tag, dnnl_engine_t engine); |
49 | |
50 | dnn_mem_t(int ndims, const dnnl_dims_t dims, dnnl_data_type_t dt, |
51 | const std::string &tag, dnnl_engine_t engine); |
52 | dnn_mem_t(int ndims, const dnnl_dims_t dims, dnnl_data_type_t dt, |
53 | const dnnl_dims_t strides, dnnl_engine_t engine); |
54 | |
55 | dnn_mem_t(const dnn_mem_t &rhs, dnnl_data_type_t dt, const std::string &tag, |
56 | dnnl_engine_t engine); |
57 | |
58 | dnn_mem_t(const dnn_mem_t &rhs) = delete; |
59 | dnn_mem_t &operator=(const dnn_mem_t &rhs) = delete; |
60 | |
61 | dnn_mem_t &operator=(dnn_mem_t &&rhs) { |
62 | if (&rhs == this) return *this; |
63 | cleanup(); |
64 | |
65 | md_ = rhs.md_; |
66 | m_ = rhs.m_; |
67 | m_padded_ = rhs.m_padded_; |
68 | data_ = rhs.data_; |
69 | is_data_owner_ = rhs.is_data_owner_; |
70 | active_ = rhs.active_; |
71 | engine_kind_ = rhs.engine_kind_; |
72 | engine_ = rhs.engine_; |
73 | is_mapped_ = (bool)rhs.is_mapped_; |
74 | mapped_ptr_ = rhs.mapped_ptr_; |
75 | |
76 | rhs.active_ = false; |
77 | return *this; |
78 | } |
79 | dnn_mem_t(dnn_mem_t &&rhs) : dnn_mem_t() { *this = std::move(rhs); } |
80 | |
81 | ~dnn_mem_t() { cleanup(); } |
82 | |
83 | int reorder(const dnn_mem_t &rhs, const_dnnl_primitive_attr_t attr); |
84 | int reorder(const dnn_mem_t &rhs) { return reorder(rhs, nullptr); } |
85 | |
86 | size_t size() const; |
87 | |
88 | int64_t nelems(bool with_padded_dims = false) const { |
89 | const auto &_dims = with_padded_dims ? padded_dims() : dims(); |
90 | if (ndims() == 0) return 0; |
91 | |
92 | int64_t n = 1; |
93 | for (int i = 0; i < ndims(); ++i) |
94 | n *= _dims[i]; |
95 | return n; |
96 | } |
97 | |
98 | // Queries from memory descriptor. |
99 | int ndims() const; |
100 | const dnnl_dims_t &dims() const; |
101 | const dnnl_dims_t &padded_dims() const; |
102 | dnnl_data_type_t dt() const; |
103 | const dnnl_dims_t &padded_offsets() const; |
104 | dnnl_dim_t offset0() const; |
105 | dnnl_format_kind_t format_kind() const; |
106 | const dnnl_dims_t &strides() const; |
107 | int inner_nblks() const; |
108 | const dnnl_dims_t &inner_blks() const; |
109 | const dnnl_dims_t &inner_idxs() const; |
110 | |
111 | size_t sizeof_dt() const; |
112 | |
113 | void set_dt(dnnl_data_type_t dt) const; |
114 | |
115 | template <typename T> |
116 | explicit operator T *() const { |
117 | assert(is_mapped_); |
118 | return static_cast<T *>(mapped_ptr_); |
119 | } |
120 | |
121 | explicit operator bool() const { return active_; } |
122 | |
123 | float get_elem(int64_t idx) const; |
124 | void set_elem(int64_t idx, float value) const; |
125 | |
126 | int64_t get_scale_idx( |
127 | int64_t data_idx, int scale_mask, const int ndims) const { |
128 | const auto &_dims = dims(); |
129 | int64_t stride = 1; |
130 | int64_t offset = 0; |
131 | |
132 | if (scale_mask != 0) { |
133 | for (int i = 0; i < ndims; ++i) { |
134 | int d = ndims - 1 - i; |
135 | auto pos = data_idx % _dims[d]; |
136 | data_idx /= _dims[d]; |
137 | if (scale_mask & (1 << d)) { |
138 | offset += pos * stride; |
139 | stride *= _dims[d]; |
140 | } |
141 | } |
142 | } |
143 | |
144 | return offset; |
145 | } |
146 | |
147 | int64_t get_scale_idx(int64_t data_idx, int scale_mask) const { |
148 | return get_scale_idx(data_idx, scale_mask, ndims()); |
149 | } |
150 | |
151 | dnnl_engine_t engine() const { return engine_; } |
152 | dnnl_engine_kind_t engine_kind() const { return engine_kind_; } |
153 | |
154 | bool is_mapped() const { return is_mapped_; } |
155 | |
156 | bool is_canary_protected() const { return is_canary_protected_; } |
157 | |
158 | void map() const; |
159 | void unmap() const; |
160 | void memset(int value, size_t size) const; |
161 | |
162 | static dnn_mem_t create_from_host_ptr( |
163 | const dnnl_memory_desc_t &md, dnnl_engine_t engine, void *host_ptr); |
164 | |
165 | // Increases memory size to catch potential buffer overreads and |
166 | // overwrites. The padded area is filled with a canary value. |
167 | static size_t pad_memory_size(size_t sz, dnnl_engine_kind_t engine_kind, |
168 | bool *was_padded = nullptr); |
169 | // Increases memory descriptor size to catch potential buffer overreads and |
170 | // overwrites. The padded area is filled with a canary value. |
171 | static dnnl_memory_desc_t pad_memory_desc(const_dnnl_memory_desc_t md, |
172 | dnnl_engine_kind_t engine_kind, bool *was_padded = nullptr); |
173 | // Initializes memory descriptor from sporadic tag or strides. |
174 | static benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> init_md(int ndims, |
175 | const dnnl_dims_t dims, dnnl_data_type_t data_type, |
176 | const std::string &tag, const dims_t &strides_ = {}); |
177 | |
178 | /* fields */ |
179 | dnnl_memory_desc_t md_ {}; |
180 | dnnl_memory_t m_ {}; |
181 | |
182 | // "Base" memory with a canary-padded buffer for buffer overflow |
183 | // protection. |
184 | dnnl_memory_t m_padded_ {}; |
185 | bool is_canary_protected_ = false; |
186 | |
187 | private: |
188 | void *data_ = NULL; |
189 | bool is_data_owner_ = false; |
190 | bool active_ = false; |
191 | |
192 | dnnl_engine_kind_t engine_kind_ = dnnl_any_engine; |
193 | dnnl_engine_t engine_ = NULL; |
194 | |
195 | mutable bool is_mapped_ = false; |
196 | mutable void *mapped_ptr_ = NULL; |
197 | |
198 | int initialize_memory_create_sycl(const handle_info_t &handle_info); |
199 | int initialize_memory_create_opencl(const handle_info_t &handle_info); |
200 | int initialize_memory_create(const handle_info_t &handle_info); |
201 | |
202 | int initialize(dnnl_engine_t engine, |
203 | const handle_info_t &handle_info = handle_info_t::allocate()); |
204 | |
205 | int cleanup(); |
206 | }; |
207 | |
208 | dnnl_memory_desc_t clone_md(const_dnnl_memory_desc_t md); |
209 | |
210 | // Checks that zero padding is preserved. |
211 | int check_zero_padding(const dnn_mem_t &mem, int arg, res_t *res = nullptr, |
212 | int *error_count = nullptr); |
213 | |
214 | // Checks that the buffer is not overrun if it was protected by a canary. |
215 | int check_buffer_overwrite(const dnn_mem_t &mem, int arg, res_t *res = nullptr); |
216 | |
217 | // Returns physical offset by logical one. Logical offset is represented by an |
218 | // array pos. If is_pos_padded is true pos represents the position in already |
219 | // padded area. |
220 | dnnl_dim_t md_off_v(const dnn_mem_t &mem, const dnnl_dims_t pos, |
221 | bool is_pos_padded = false); |
222 | |
223 | #endif |
224 | |