/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <initializer_list>

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/rnn.hpp"
#include "common/type_helpers.hpp"

#include "cpu/gemm/gemm_pack.hpp"

#include "cpu/rnn/ref_rnn.hpp"
#include "cpu/rnn/rnn_utils.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace rnn_utils;
using namespace format_tag;
using namespace rnn_packed_format;
using namespace data_type;
static bool check_dims_contiguous_except_one(const memory_desc_wrapper &mdw,
        int idx_with_arbitrary_stride, std::initializer_list<int> perm) {
    if (mdw.format_kind() != format_kind::blocked) return false;
    if ((size_t)mdw.ndims() != perm.size()) return false;

    const auto &blk = mdw.blocking_desc();

    dim_t expect_stride = 1;
    for (int idx = mdw.ndims() - 1; idx >= 0; --idx) {
        const int permuted_idx = *(perm.begin() + idx);
        const bool ok = (idx == idx_with_arbitrary_stride)
                ? expect_stride <= blk.strides[permuted_idx]
                : expect_stride == blk.strides[permuted_idx];
        if (!ok) return false;
        expect_stride = mdw.dims()[permuted_idx] * blk.strides[permuted_idx];
    }

    return true;
}

bool rnn_utils::is_ldigo(const memory_desc_wrapper &mdw) {
    return check_dims_contiguous_except_one(mdw, 2, {0, 1, 2, 3, 4});
}

bool rnn_utils::is_ldgoi(const memory_desc_wrapper &mdw) {
    return check_dims_contiguous_except_one(mdw, 3, {0, 1, 3, 4, 2});
}

bool rnn_utils::is_ldio(const memory_desc_wrapper &mdw) {
    return check_dims_contiguous_except_one(mdw, 2, {0, 1, 2, 3});
}

bool rnn_utils::is_ldoi(const memory_desc_wrapper &mdw) {
    return check_dims_contiguous_except_one(mdw, 2, {0, 1, 3, 2});
}

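// The *_blocked predicates match brgemm-friendly blocked layouts, e.g.
// ldgOi32o keeps the o dimension split into blocks of 32 as the innermost
// block, and the *o2i / *o4i variants additionally interleave 2 (bf16) or
// 4 (int8) i elements inside each o block.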
bool rnn_utils::is_ldigo_blocked(const memory_desc_wrapper &mdw) {
    format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldgOi32o,
            format_tag::ldgOI32o2i, format_tag::ldgOI32o4i,
            format_tag::ldgOI64o2i, format_tag::ldgOI64o4i);
    return md_format_tag != format_tag::undef;
}

bool rnn_utils::is_ldgoi_blocked(const memory_desc_wrapper &mdw) {
    format_tag_t md_format_tag = mdw.matches_one_of_tag(
            format_tag::ldgIo32i, format_tag::ldgIO32i2o);
    return md_format_tag != format_tag::undef;
}

bool rnn_utils::is_ldio_blocked(const memory_desc_wrapper &mdw) {
    format_tag_t md_format_tag = mdw.matches_one_of_tag(
            format_tag::ldOi32o, format_tag::ldOI32o4i);
    return md_format_tag != format_tag::undef;
}

int rnn_utils::get_good_ld(int dim, int sizeof_dt) {
    // We want matrices' leading dimensions to be 64-byte aligned,
    // and not divisible by 256 to avoid 4K aliasing effects.
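    // Illustrative arithmetic: for dim = 512 and f32 (sizeof_dt = 4),
    // rnd_up(512, 64 / 4) = 512, which is a multiple of 256, so the
    // returned ld is 512 + 16 = 528 elements (2112 bytes: still a
    // multiple of 64 bytes, but no longer 4 KiB-periodic).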
    const int ld = rnd_up(dim, 64 / sizeof_dt);
    return (ld % 256 == 0) ? ld + 64 / sizeof_dt : ld;
}

void rnn_utils::set_offsets(const rnn_conf_t &rnn, size_t &ws_gates_offset,
        size_t &ws_ht_offset, size_t &ws_states_layer_offset,
        size_t &ws_states_iter_offset, size_t &ws_states_iter_c_offset,
        size_t &ws_diff_states_layer_offset, size_t &ws_diff_states_iter_offset,
        size_t &ws_diff_states_iter_c_offset, size_t &ws_grid_comp_offset,
        size_t &ws_bias_offset, size_t &scratch_gates_offset,
        size_t &scratch_ht_offset, size_t &scratch_diff_ht_offset,
        size_t &scratch_cell_offset, size_t &scratchpad_size,
        size_t &workspace_size) {

    const size_t page_size = 4096; // 2097152;
    size_t current_offset;
    /* Mandatory workspaces: go to workspace if use_workspace, scratchpad
     * otherwise */
    current_offset = 0; // assumes the workspace base pointer is page aligned

#define register_space(a) \
    do { \
        current_offset = utils::rnd_up(current_offset, page_size); \
        CONCAT2(a, _offset) = current_offset; \
        current_offset += rnn.CONCAT2(a, _size); \
    } while (false)
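
    // E.g. register_space(ws_gates) expands to: round current_offset up to
    // the next page boundary, record it in ws_gates_offset, then advance
    // current_offset by rnn.ws_gates_size.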
    register_space(ws_gates);
    register_space(ws_ht);
    register_space(ws_states_layer);
    register_space(ws_states_iter);

    // For all currently supported cells, ws_states_iter should not be used
    // at all since dst_iter == dst_layer
    assert(rnn.ws_states_layer_size == rnn.ws_states_iter_size);
    ws_states_iter_offset = ws_states_layer_offset;

    register_space(ws_states_iter_c);
    register_space(ws_diff_states_layer);
    register_space(ws_diff_states_iter);
    register_space(ws_diff_states_iter_c);
    register_space(ws_grid_comp);

    workspace_size = rnn.use_workspace ? current_offset : 0;

    /* Optional scratchpads */
    // Assumes the scratchpad base pointer is page aligned.
    // If use_workspace, only the following goes to the scratchpad;
    // otherwise, everything goes to the scratchpad and we keep incrementing
    // the offset.
    current_offset = rnn.use_workspace ? 0 : current_offset;

    register_space(scratch_gates);
    register_space(scratch_ht);
    register_space(scratch_diff_ht);
    register_space(scratch_cell);
    if (rnn.copy_bias)
        register_space(ws_bias);
    else
        ws_bias_offset = 0;

    scratchpad_size = current_offset;
#undef register_space
}

void rnn_utils::get_scratchpad_and_workspace_sizes(const rnn_conf_t &rnn,
        size_t &scratchpad_size, size_t &workspace_size) {
    size_t ws_gates_offset, ws_ht_offset, ws_states_layer_offset,
            ws_states_iter_offset, ws_states_iter_c_offset,
            ws_diff_states_layer_offset, ws_diff_states_iter_offset,
            ws_diff_states_iter_c_offset, ws_grid_comp_offset,
            scratch_gates_offset, scratch_ht_offset, scratch_diff_ht_offset,
            scratch_cell_offset, ws_bias_offset;
    set_offsets(rnn, ws_gates_offset, ws_ht_offset, ws_states_layer_offset,
            ws_states_iter_offset, ws_states_iter_c_offset,
            ws_diff_states_layer_offset, ws_diff_states_iter_offset,
            ws_diff_states_iter_c_offset, ws_grid_comp_offset, ws_bias_offset,
            scratch_gates_offset, scratch_ht_offset, scratch_diff_ht_offset,
            scratch_cell_offset, scratchpad_size, workspace_size);
}

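// Illustrative arithmetic (hypothetical sizes): for ldigo weights with
// G = 4 gates and O = 256 outputs in f32, the dense stride of I is
// G * O = 1024; get_good_ld bumps it to 1024 + 16 = 1040 because 1024 is a
// multiple of 256, and the strides of D and L are then recomputed on top of
// the adjusted leading dimension.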
status_t rnn_utils::set_good_strides(
        memory_desc_t &weights_md, format_tag_t tag) {
    auto &strides = weights_md.format_desc.blocking.strides;
    const auto dims = weights_md.dims;

    int ld_dim_idx = 0;
    switch (tag) {
        case ldio:
        case ldigo:
            strides[2] = rnn_utils::get_good_ld((int)strides[2],
                    (int)types::data_type_size(weights_md.data_type));
            ld_dim_idx = 2;
            break;
        case ldoi:
        case ldgoi:
            strides[weights_md.ndims - 1]
                    = rnn_utils::get_good_ld((int)strides[weights_md.ndims - 1],
                            (int)types::data_type_size(weights_md.data_type));
            if (tag == ldgoi) strides[3] = dims[4] * strides[4];
            ld_dim_idx = 3;
            break;
        default: return status::unimplemented;
    }
    strides[1] = dims[ld_dim_idx] * strides[ld_dim_idx];
    strides[0] = dims[1] * strides[1];

    return status::success;
}

status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
        memory_desc_t &weights_md, rnn_utils::weights_type_t weights_type) {
    using namespace rnn_utils;
    bool use_packed_gemm = false;
    switch (weights_type) {
        case weights_type_t::layer:
            use_packed_gemm = rnn.use_layer_packed_gemm;
            break;
        case weights_type_t::iter:
            use_packed_gemm = rnn.use_iter_packed_gemm;
            break;
        case weights_type_t::projection:
            use_packed_gemm = rnn.use_projection_packed_gemm;
            break;
        default: assert(!"unsupported weights type");
    }

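    // For packed GEMM the descriptor switches to format_kind::rnn_packed and
    // records the parameters later consumed by GEMM packing: the expected
    // leading dimension of the activations (ldb), the number of columns n
    // (minibatch, possibly merged over iterations for layer weights), the
    // per-part split of the gates, and the total packed buffer size.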
    if (use_packed_gemm) {
        weights_md.format_kind = format_kind::rnn_packed;
        rnn_packed_desc_t &rnn_pdata = weights_md.format_desc.rnn_packed_desc;
        switch (weights_type) {
            case weights_type_t::iter:
                rnn_pdata.format = rnn.is_fwd
                        ? rnn_packed_memory_format_t::ldigo_p
                        : rnn_packed_memory_format_t::ldgoi_p;
                rnn_pdata.ldb = rnn.ws_states_iter_ld;
                rnn_pdata.n = rnn.mb;
                rnn_pdata.n_parts = rnn.n_parts_weights_iter;
                array_copy(rnn_pdata.parts, rnn.parts_weights_iter,
                        DNNL_RNN_MAX_N_PARTS);
                array_copy(rnn_pdata.part_pack_size,
                        rnn.part_weights_iter_pack_size, DNNL_RNN_MAX_N_PARTS);
                rnn_pdata.offset_compensation = rnn.weights_iter_comp_offset;
                rnn_pdata.size = rnn.weights_iter_pack_size;
                break;
            case weights_type_t::layer:
                rnn_pdata.format = rnn.is_fwd
                        ? rnn_packed_memory_format_t::ldigo_p
                        : rnn_packed_memory_format_t::ldgoi_p;
                rnn_pdata.ldb = rnn.ws_states_layer_ld;
                rnn_pdata.n
                        = rnn.merge_gemm_layer ? rnn.n_iter * rnn.mb : rnn.mb;
                rnn_pdata.n_parts = rnn.n_parts_weights_layer;
                array_copy(rnn_pdata.parts, rnn.parts_weights_layer,
                        DNNL_RNN_MAX_N_PARTS);
                array_copy(rnn_pdata.part_pack_size,
                        rnn.part_weights_layer_pack_size, DNNL_RNN_MAX_N_PARTS);
                rnn_pdata.offset_compensation = rnn.weights_layer_comp_offset;
                rnn_pdata.size = rnn.weights_layer_pack_size;
                break;
            case weights_type_t::projection:
                // TODO: add ldoi_p for bwd?
                rnn_pdata.format = rnn_packed_memory_format_t::ldio_p;
                rnn_pdata.ldb = rnn.proj_ht_ld;
                rnn_pdata.n = rnn.mb;
                rnn_pdata.n_parts = rnn.n_parts_weights_projection;
                array_copy(rnn_pdata.parts, rnn.parts_weights_projection,
                        DNNL_RNN_MAX_N_PARTS);
                array_copy(rnn_pdata.part_pack_size,
                        rnn.part_weights_projection_pack_size,
                        DNNL_RNN_MAX_N_PARTS);
                rnn_pdata.offset_compensation
                        = rnn.weights_projection_comp_offset;
                rnn_pdata.size = rnn.weights_projection_pack_size;
                break;
            default: assert(!"unsupported weights type");
        }
        if (rnn.is_signed_int8_conf()) {
            weights_md.extra.flags
                    = 0 | memory_extra_flags::rnn_s8s8_compensation;
            weights_md.extra.compensation_mask = 0;
        }
    } else {
        using namespace format_tag;
        if (rnn.is_brgemm) {
            format_tag_t tag = format_tag::undef;

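            // Summary of the tag choice below (n_block == 64 variants in
            // parentheses): projection weights use ldOI32o4i for int8 and
            // ldOi32o otherwise; forward layer/iter weights use ldgOI32o4i
            // (ldgOI64o4i) for int8, ldgOI32o2i (ldgOI64o2i) for bf16, and
            // ldgOi32o for f32; backward uses ldgIO32i2o for bf16 and
            // ldgIo32i otherwise.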
            if (weights_type == weights_type_t::projection) {
                tag = rnn.is_int8_conf() ? format_tag::ldOI32o4i
                                         : format_tag::ldOi32o;
            } else if (rnn.is_fwd) {
                tag = rnn.is_int8_conf()
                        ? (rnn.n_block == 64 ? format_tag::ldgOI64o4i
                                             : format_tag::ldgOI32o4i)
                        : rnn.is_bf16_conf()
                        ? (rnn.n_block == 64 ? format_tag::ldgOI64o2i
                                             : format_tag::ldgOI32o2i)
                        : format_tag::ldgOi32o;
            } else {
                tag = rnn.is_bf16_conf() ? format_tag::ldgIO32i2o
                                         : format_tag::ldgIo32i;
            }

            if (tag != format_tag::undef) {
                CHECK(memory_desc_init_by_tag(weights_md, tag));
                if (rnn.is_unsigned_int8_conf()) {
                    weights_md.extra.flags
                            = 0 | memory_extra_flags::rnn_u8s8_compensation;
                    weights_md.extra.compensation_mask
                            = (weights_type == weights_type_t::projection)
                            ? 13 // 1101
                            : 27; // 11011
                } else if (rnn.is_signed_int8_conf()) {
                    weights_md.extra.flags
                            = 0 | memory_extra_flags::rnn_s8s8_compensation;
                    weights_md.extra.compensation_mask = 0;
                }
                return status::success;
            } else {
                return status::unimplemented;
            }
        } else {
            const format_tag_t tag = weights_type == weights_type_t::projection
                    ? rnn.is_fwd ? ldio : ldoi
                    : rnn.is_fwd ? ldigo : ldgoi;
            CHECK(memory_desc_init_by_tag(weights_md, tag));
            // Adjust strides for good leading dimension in GEMM
            CHECK(set_good_strides(weights_md, tag));
        }
    }
    return status::success;
}

float rnn_utils::to_float(const void *data, const data_type_t dt) {
    if (dt == data_type::f32)
        return *static_cast<const float *>(data);
    else if (dt == data_type::bf16)
        return float(*static_cast<const bfloat16_t *>(data));
    return 0.0f;
}

const void *rnn_utils::inc_ptr(
        const void *data, data_type_t data_type, int offset) {
    if (data_type == data_type::f32)
        return static_cast<const float *>(data) + offset;
    else if (data_type == data_type::bf16)
        return static_cast<const bfloat16_t *>(data) + offset;
    else
        return data;
}

void *rnn_utils::inc_ptr(void *data, data_type_t data_type, int offset) {
    return const_cast<void *>(
            inc_ptr(static_cast<const void *>(data), data_type, offset));
}

} // namespace cpu
} // namespace impl
} // namespace dnnl