1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <initializer_list> |
18 | |
19 | #include "common/c_types_map.hpp" |
20 | #include "common/dnnl_thread.hpp" |
21 | #include "common/math_utils.hpp" |
22 | #include "common/rnn.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | #include "cpu/gemm/gemm_pack.hpp" |
26 | |
27 | #include "cpu/rnn/ref_rnn.hpp" |
28 | #include "cpu/rnn/rnn_utils.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
34 | using namespace dnnl::impl::utils; |
35 | using namespace rnn_utils; |
36 | using namespace format_tag; |
37 | using namespace rnn_packed_format; |
38 | using namespace data_type; |
39 | |
40 | static bool check_dims_contiguous_except_one(const memory_desc_wrapper &mdw, |
41 | int idx_with_arbitrary_stride, std::initializer_list<int> perm) { |
42 | if (mdw.format_kind() != format_kind::blocked) return false; |
43 | if ((size_t)mdw.ndims() != perm.size()) return false; |
44 | |
45 | const auto &blk = mdw.blocking_desc(); |
46 | |
47 | dim_t expect_stride = 1; |
48 | for (int idx = mdw.ndims() - 1; idx >= 0; --idx) { |
49 | const int permuted_idx = *(perm.begin() + idx); |
50 | const bool ok = (idx == idx_with_arbitrary_stride) |
51 | ? expect_stride <= blk.strides[permuted_idx] |
52 | : expect_stride == blk.strides[permuted_idx]; |
53 | if (!ok) return false; |
54 | expect_stride = mdw.dims()[permuted_idx] * blk.strides[permuted_idx]; |
55 | } |
56 | |
57 | return true; |
58 | } |
59 | |
60 | bool rnn_utils::is_ldigo(const memory_desc_wrapper &mdw) { |
61 | return check_dims_contiguous_except_one(mdw, 2, {0, 1, 2, 3, 4}); |
62 | } |
63 | |
64 | bool rnn_utils::is_ldgoi(const memory_desc_wrapper &mdw) { |
65 | return check_dims_contiguous_except_one(mdw, 3, {0, 1, 3, 4, 2}); |
66 | } |
67 | |
68 | bool rnn_utils::is_ldio(const memory_desc_wrapper &mdw) { |
69 | return check_dims_contiguous_except_one(mdw, 2, {0, 1, 2, 3}); |
70 | } |
71 | |
72 | bool rnn_utils::is_ldoi(const memory_desc_wrapper &mdw) { |
73 | return check_dims_contiguous_except_one(mdw, 2, {0, 1, 3, 2}); |
74 | } |
75 | |
76 | bool rnn_utils::is_ldigo_blocked(const memory_desc_wrapper &mdw) { |
77 | format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldgOi32o, |
78 | format_tag::ldgOI32o2i, format_tag::ldgOI32o4i, |
79 | format_tag::ldgOI64o2i, format_tag::ldgOI64o4i); |
80 | return md_format_tag != format_tag::undef; |
81 | } |
82 | |
83 | bool rnn_utils::is_ldgoi_blocked(const memory_desc_wrapper &mdw) { |
84 | format_tag_t md_format_tag = mdw.matches_one_of_tag( |
85 | format_tag::ldgIo32i, format_tag::ldgIO32i2o); |
86 | return md_format_tag != format_tag::undef; |
87 | } |
88 | |
89 | bool rnn_utils::is_ldio_blocked(const memory_desc_wrapper &mdw) { |
90 | format_tag_t md_format_tag = mdw.matches_one_of_tag( |
91 | format_tag::ldOi32o, format_tag::ldOI32o4i); |
92 | return md_format_tag != format_tag::undef; |
93 | } |
94 | |
95 | int rnn_utils::get_good_ld(int dim, int sizeof_dt) { |
96 | // we want matrices leading dimentions to be 64-byte aligned, |
97 | // and not divisible by 256 to avoid 4K aliasing effects |
98 | const int ld = rnd_up(dim, 64 / sizeof_dt); |
99 | return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld; |
100 | } |
101 | |
102 | void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset, |
103 | size_t &ws_ht_offset, size_t &ws_states_layer_offset, |
104 | size_t &ws_states_iter_offset, size_t &ws_states_iter_c_offset, |
105 | size_t &ws_diff_states_layer_offset, size_t &ws_diff_states_iter_offset, |
106 | size_t &ws_diff_states_iter_c_offset, size_t &ws_grid_comp_offset, |
107 | size_t &ws_bias_offset, size_t &scratch_gates_offset, |
108 | size_t &scratch_ht_offset, size_t &scratch_diff_ht_offset, |
109 | size_t &scratch_cell_offset, size_t &scratchpad_size, |
110 | size_t &workspace_size) { |
111 | |
112 | const size_t page_size = 4096; // 2097152; |
113 | size_t current_offset; |
114 | /* Mandatory workspaces: go to workspace if use_workspace, scratchpad |
115 | * otherwise */ |
116 | current_offset = 0; // assumes the workspace base pointer is page aligned |
117 | |
118 | #define register_space(a) \ |
119 | do { \ |
120 | current_offset = utils::rnd_up(current_offset, page_size); \ |
121 | CONCAT2(a, _offset) = current_offset; \ |
122 | current_offset += rnn.CONCAT2(a, _size); \ |
123 | } while (false) |
124 | |
125 | register_space(ws_gates); |
126 | register_space(ws_ht); |
127 | register_space(ws_states_layer); |
128 | register_space(ws_states_iter); |
129 | register_space(ws_states_iter); |
130 | |
131 | // For all currently supported cells, ws_iter should not be used |
132 | // at all since dst_iter == dst_layer |
133 | assert(rnn.ws_states_layer_size == rnn.ws_states_iter_size); |
134 | ws_states_iter_offset = ws_states_layer_offset; |
135 | |
136 | register_space(ws_states_iter_c); |
137 | register_space(ws_diff_states_layer); |
138 | register_space(ws_diff_states_iter); |
139 | register_space(ws_diff_states_iter_c); |
140 | register_space(ws_grid_comp); |
141 | |
142 | workspace_size = rnn.use_workspace ? current_offset : 0; |
143 | |
144 | /* Optional scratchpads */ |
145 | // Assumes the scratchpad base pointer is page aligned. |
146 | // If use_workspace, the following goes to scratchpad alone, |
147 | // otherwise, all goes to scratchpad and continue incrementing offset |
148 | current_offset = rnn.use_workspace ? 0 : current_offset; |
149 | |
150 | register_space(scratch_gates); |
151 | register_space(scratch_ht); |
152 | register_space(scratch_diff_ht); |
153 | register_space(scratch_cell); |
154 | if (rnn.copy_bias) |
155 | register_space(ws_bias); |
156 | else |
157 | ws_bias_offset = 0; |
158 | |
159 | scratchpad_size = current_offset; |
160 | #undef register_space |
161 | } |
162 | |
163 | void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn, |
164 | size_t &scratchpad_size, size_t &workspace_size) { |
165 | size_t ws_gates_offset, ws_ht_offset, ws_states_layer_offset, |
166 | ws_states_iter_offset, ws_states_iter_c_offset, |
167 | ws_diff_states_layer_offset, ws_diff_states_iter_offset, |
168 | ws_diff_states_iter_c_offset, ws_grid_comp_offset, |
169 | scratch_gates_offset, scratch_ht_offset, scratch_diff_ht_offset, |
170 | scratch_cell_offset, ws_bias_offset; |
171 | set_offsets(rnn, ws_gates_offset, ws_ht_offset, ws_states_layer_offset, |
172 | ws_states_iter_offset, ws_states_iter_c_offset, |
173 | ws_diff_states_layer_offset, ws_diff_states_iter_offset, |
174 | ws_diff_states_iter_c_offset, ws_grid_comp_offset, ws_bias_offset, |
175 | scratch_gates_offset, scratch_ht_offset, scratch_diff_ht_offset, |
176 | scratch_cell_offset, scratchpad_size, workspace_size); |
177 | } |
178 | |
// Pads the GEMM leading dimension of a plain-layout weights descriptor to a
// "good" value (see get_good_ld) and recomputes the outer strides to match.
// Assumes `weights_md` currently holds dense strides for `tag`.
// Returns status::unimplemented for any tag other than ldio/ldigo/ldoi/ldgoi.
status_t rnn_utils::set_good_strides(
        memory_desc_t &weights_md, format_tag_t tag) {
    auto &strides = weights_md.format_desc.blocking.strides;
    const auto dims = weights_md.dims;

    int ld_dim_idx = 0;
    switch (tag) {
        case ldio:
        case ldigo:
            // Forward layouts: dim 2 ('i' in the ld(i)(g)o dim order) holds
            // the leading dimension; pad its stride.
            strides[2] = rnn_utils::get_good_ld((int)strides[2],
                    (int)types::data_type_size(weights_md.data_type));
            ld_dim_idx = 2;
            break;
        case ldoi:
        case ldgoi:
            // Backward layouts: the last logical dim ('o' per the tag names)
            // holds the leading dimension; pad its stride.
            strides[weights_md.ndims - 1]
                    = rnn_utils::get_good_ld((int)strides[weights_md.ndims - 1],
                            (int)types::data_type_size(weights_md.data_type));
            // For 5D ldgoi, the 'g' stride must span the padded o-by-i tile.
            if (tag == ldgoi) strides[3] = dims[4] * strides[4];
            ld_dim_idx = 3;
            break;
        default: return status::unimplemented;
    }
    // Propagate the padded stride outwards to the 'd' (dim 1) and 'l'
    // (dim 0) strides so the descriptor stays self-consistent.
    strides[1] = dims[ld_dim_idx] * strides[ld_dim_idx];
    strides[0] = dims[1] * strides[1];

    return status::success;
}
207 | |
// Fills `weights_md` with the memory layout this RNN configuration expects
// for the given kind of weights (layer / iter / projection):
// - packed-GEMM path: builds an rnn_packed descriptor from the conf fields;
// - brgemm path: selects one of the blocked format tags;
// - plain GEMM path: plain tag plus a padded leading dimension.
// Also sets the s8s8/u8s8 compensation extra flags for int8 configurations.
status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
        memory_desc_t &weights_md, rnn_utils::weights_type_t weights_type) {
    using namespace rnn_utils;
    // Whether this particular kind of weights goes through gemm_pack.
    bool use_packed_gemm = false;
    switch (weights_type) {
        case weights_type_t::layer:
            use_packed_gemm = rnn.use_layer_packed_gemm;
            break;
        case weights_type_t::iter:
            use_packed_gemm = rnn.use_iter_packed_gemm;
            break;
        case weights_type_t::projection:
            use_packed_gemm = rnn.use_projection_packed_gemm;
            break;
        default: assert(!"unsupported weights type" );
    }

    if (use_packed_gemm) {
        // Packed path: describe the pre-packed buffer produced by the
        // gemm_pack API (ldb/n/parts/sizes come from the conf).
        weights_md.format_kind = format_kind::rnn_packed;
        rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc;
        switch (weights_type) {
            case weights_type_t::iter:
                rnn_pdata.format = rnn.is_fwd
                        ? rnn_packed_memory_format_t::ldigo_p
                        : rnn_packed_memory_format_t::ldgoi_p;
                rnn_pdata.ldb = rnn.ws_states_iter_ld;
                rnn_pdata.n = rnn.mb;
                rnn_pdata.n_parts = rnn.n_parts_weights_iter;
                array_copy(rnn_pdata.parts, rnn.parts_weights_iter,
                        DNNL_RNN_MAX_N_PARTS);
                array_copy(rnn_pdata.part_pack_size,
                        rnn.part_weights_iter_pack_size, DNNL_RNN_MAX_N_PARTS);
                rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset;
                rnn_pdata.size = rnn.weights_iter_pack_size;
                break;
            case weights_type_t::layer:
                rnn_pdata.format = rnn.is_fwd
                        ? rnn_packed_memory_format_t::ldigo_p
                        : rnn_packed_memory_format_t::ldgoi_p;
                rnn_pdata.ldb = rnn.ws_states_layer_ld;
                // With merged layer GEMM, all iterations are batched into a
                // single multiplication, hence the larger n.
                rnn_pdata.n
                        = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb;
                rnn_pdata.n_parts = rnn.n_parts_weights_layer;
                array_copy(rnn_pdata.parts, rnn.parts_weights_layer,
                        DNNL_RNN_MAX_N_PARTS);
                array_copy(rnn_pdata.part_pack_size,
                        rnn.part_weights_layer_pack_size, DNNL_RNN_MAX_N_PARTS);
                rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset;
                rnn_pdata.size = rnn.weights_layer_pack_size;
                break;
            case weights_type_t::projection:
                // TODO: add ldoi_p for bwd?
                rnn_pdata.format = rnn_packed_memory_format_t::ldio_p;
                rnn_pdata.ldb = rnn.proj_ht_ld;
                rnn_pdata.n = rnn.mb;
                rnn_pdata.n_parts = rnn.n_parts_weights_projection;
                array_copy(rnn_pdata.parts, rnn.parts_weights_projection,
                        DNNL_RNN_MAX_N_PARTS);
                array_copy(rnn_pdata.part_pack_size,
                        rnn.part_weights_projection_pack_size,
                        DNNL_RNN_MAX_N_PARTS);
                rnn_pdata.offset_compensation
                        = rnn.weights_projection_comp_offset;
                rnn_pdata.size = rnn.weights_projection_pack_size;
                break;
            default: assert(!"unsupported weights type" );
        }
        if (rnn.is_signed_int8_conf()) {
            weights_md.extra.flags
                    = 0 | memory_extra_flags::rnn_s8s8_compensation;
            weights_md.extra.compensation_mask = 0;
        }
    } else {
        using namespace format_tag;
        if (rnn.is_brgemm) {
            // brgemm path: pick a blocked tag based on direction, data type,
            // and the kernel's n-block size (32 or 64).
            format_tag_t tag = format_tag::undef;

            if (weights_type == weights_type_t::projection) {
                tag = rnn.is_int8_conf() ? format_tag::ldOI32o4i
                                         : format_tag::ldOi32o;
            } else if (rnn.is_fwd) {
                tag = rnn.is_int8_conf()
                        ? (rnn.n_block == 64 ? format_tag::ldgOI64o4i
                                             : format_tag::ldgOI32o4i)
                        : rnn.is_bf16_conf()
                        ? (rnn.n_block == 64 ? format_tag::ldgOI64o2i
                                             : format_tag::ldgOI32o2i)
                        : format_tag::ldgOi32o;
            } else {
                tag = rnn.is_bf16_conf() ? format_tag::ldgIO32i2o
                                         : format_tag::ldgIo32i;
            }

            if (tag != format_tag::undef) {
                CHECK(memory_desc_init_by_tag(weights_md, tag));
                if (rnn.is_unsigned_int8_conf()) {
                    weights_md.extra.flags
                            = 0 | memory_extra_flags::rnn_u8s8_compensation;
                    // Per-dim compensation mask (bit set = dim participates).
                    weights_md.extra.compensation_mask
                            = (weights_type == weights_type_t::projection)
                            ? 13 // 1101
                            : 27; // 11011
                } else if (rnn.is_signed_int8_conf()) {
                    weights_md.extra.flags
                            = 0 | memory_extra_flags::rnn_s8s8_compensation;
                    weights_md.extra.compensation_mask = 0;
                }
                return status::success;
            } else {
                return status::unimplemented;
            }
        } else {
            // Plain GEMM path: plain tag per direction/kind, then pad the
            // leading dimension for better GEMM performance.
            const format_tag_t tag = weights_type == weights_type_t::projection
                    ? rnn.is_fwd ? ldio : ldoi
                    : rnn.is_fwd ? ldigo : ldgoi;
            CHECK(memory_desc_init_by_tag(weights_md, tag));
            // Adjust strides for good leading dimension in GEMM
            CHECK(set_good_strides(weights_md, tag));
        }
    }
    return status::success;
}
330 | |
331 | float rnn_utils::to_float(const void *data, const data_type_t dt) { |
332 | if (dt == data_type::f32) |
333 | return *static_cast<const float *>(data); |
334 | else if (dt == data_type::bf16) |
335 | return float(*static_cast<const bfloat16_t *>(data)); |
336 | return 0.0; |
337 | } |
338 | |
339 | const void *rnn_utils::inc_ptr( |
340 | const void *data, data_type_t data_type, int offset) { |
341 | if (data_type == data_type::f32) |
342 | return static_cast<const float *>(data) + offset; |
343 | else if (data_type == data_type::bf16) |
344 | return static_cast<const bfloat16_t *>(data) + offset; |
345 | else |
346 | return data; |
347 | } |
348 | |
349 | void *rnn_utils::inc_ptr(void *data, data_type_t data_type, int offset) { |
350 | return const_cast<void *>( |
351 | inc_ptr(static_cast<const void *>(data), data_type, offset)); |
352 | } |
353 | |
354 | } // namespace cpu |
355 | } // namespace impl |
356 | } // namespace dnnl |
357 | |