/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <float.h>
#include <math.h>
#include <random>
#include <stdio.h>
#include <stdlib.h>
#include <type_traits>

#include "oneapi/dnnl/dnnl.h"

#include "tests/test_isa_common.hpp"
#include "utils/parallel.hpp"

#include "dnnl_common.hpp"
#include "dnnl_memory.hpp"

#include "rnn/rnn.hpp"
#include "rnn/rnn_aux.hpp"

// Using hidden attr API for testing RNN
dnnl_status_t dnnl_primitive_attr_set_rnn_tparams(dnnl_primitive_attr_t attr,
        bool mode, dnnl_dim_t ngates, const float *scales, float cscale);

namespace {

// In order to have consistent filling across compilers and operating systems,
// we implement the equivalent of std::normal_distribution using the so-called
// Marsaglia polar method.
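//
// Rough sketch of the transform implemented below (for reference only): draw
// (x, y) uniformly from [-1, 1)^2, reject pairs outside the unit disc, and
// with r2 = x^2 + y^2 both
//     x * sqrt(-2 * ln(r2) / r2)  and  y * sqrt(-2 * ln(r2) / r2)
// are independent samples from N(0, 1). They are then scaled by stddev and
// shifted by mean; one of the two values is cached and returned on the next
// call (see `odd_` below).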
template <typename T>
class normal_distribution_t {
public:
    normal_distribution_t(T mean, T stddev)
        : gen(-1.f, 1.f)
        , is_odd_(false)
        , odd_(1.f)
        , mean_(mean)
        , stddev_(stddev) {
        static_assert(std::is_floating_point<T>::value,
                "T must be a floating point type.");
    }
    template <typename URNG>
    T operator()(URNG &g) {
        T r, r2, x, y;
        if (is_odd_) {
            is_odd_ = false;
            return odd_;
        }
        is_odd_ = true;
        do {
            x = gen(g); // x in [-1, 1)
            y = gen(g); // y in [-1, 1)
            r2 = x * x + y * y;
        } while (0.f == r2 || 1.f < r2); // r2 in (0, 1]
        r = stddev_ * std::sqrt(-2.f * std::log(r2) / r2);
        x = mean_ + x * r;
        y = mean_ + y * r;
        odd_ = x;
        return y;
    }

private:
    std::uniform_real_distribution<T> gen;
    bool is_odd_;
    T odd_;
    const T mean_;
    const T stddev_;
};

} // namespace

namespace rnn {

dnnl_primitive_attr_t create_dnnl_rnn_attr(const prb_t &prb) {
    dnnl_primitive_attr_t dnnl_attr = nullptr;
    DNN_SAFE_V(dnnl_primitive_attr_create(&dnnl_attr));

    if (prb.skip_nonlinear)
        DNN_SAFE_V(dnnl_primitive_attr_set_rnn_tparams(dnnl_attr, true,
                prb.n_gates(), prb.linear_scales, prb.linear_cscale));

    DNN_SAFE_V(dnnl_primitive_attr_set_rnn_weights_qparams(
            dnnl_attr, prb.wei_nscales, prb.wei_scales_mask, prb.wei_scales));

    if (prb.is_lstm_projection() && prb.is_int8())
        DNN_SAFE_V(dnnl_primitive_attr_set_rnn_weights_projection_qparams(
                dnnl_attr, prb.wei_proj_nscales, prb.wei_proj_scales_mask,
                prb.wei_proj_scales));

    if (prb.data_scale != 1.0 || prb.data_shift != 0.0)
        DNN_SAFE_V(dnnl_primitive_attr_set_rnn_data_qparams(
                dnnl_attr, prb.data_scale, prb.data_shift));

    DNN_SAFE_V(dnnl_primitive_attr_set_scratchpad_mode(
            dnnl_attr, prb.attr.scratchpad_mode));

    DNN_SAFE_V(dnnl_primitive_attr_set_fpmath_mode(
            dnnl_attr, prb.attr.fpmath_mode));

    return dnnl_attr;
}

int check_s8s8_reorder(const prb_t &prb, rnn_data_kind_t kind,
        const dnn_mem_t &mem_dt, const dnn_mem_t &mem_fp) {
    // TODO: enable for all cpu_kind when supported
    if (is_gpu()) return OK;

#if DNNL_CPU_RUNTIME == DNNL_RUNTIME_DPCPP
    // DPC++ does not provide a simple way to access the underlying
    // buffer alignment.
    return OK;
#endif

    // In the main test, we fill buffers with f32 and reorder to s8
    // with quantization.

    // The end goal is to check that the reorder
    // f32_plain_nonquantized --reorder--> s8_packed_quantized
    // gives the same output as the sequence
    // f32_plain --quant--> s8_plain_quantized --reorder--> s8_packed_quantized

    // Here,
    // 1. we quantize the f32 plain memory to s8 plain memory,
    // 2. we reorder the s8 plain to s8 packed (queried from rnn primitive desc)
    // 3. we check that the two memories are bitwise identical.

    // Note: the two s8 packed memories need to have the same alignment, as
    // the packed buffer is aligned internally and the offset is kept in the
    // metadata.
    // This works fine with dnn_mem_t as it is aligned to a 2MB large page
    // boundary.
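    //
    // For illustration only (the numbers are made up, not from any real cfg):
    // with scale = 2.f and shift = 0.f, step 1 below maps
    //   1.3f -> saturate_and_round<dnnl_s8>(1.3f * 2.f + 0.f) = 3,
    // i.e. the usual affine quantization q(x) = saturate_s8(round(x * scale
    // + shift)).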
    dnn_mem_t mem_s8_src(mem_fp.md_, dnnl_s8, tag::abx, get_cpu_engine());
    dnn_mem_t mem_s8_dst(mem_dt.md_, get_test_engine());

    /* 1. compute f32_plain --quant--> s8_plain_quantized */
    /* Do fixed partitioning to have the same filling for any number of
     * threads */
    auto nelems = mem_fp.nelems();
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);
    const auto quantize = [&](const float *scales, int nscales, float shift,
                                  int idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            const float current_scale = scales[idx % nscales];
            float val_f32 = mem_fp.get_elem(idx);
            int8_t val_s8 = saturate_and_round<dnnl_s8>(
                    val_f32 * current_scale + shift);
            mem_s8_src.set_elem(idx, val_s8);
        }
    };
    switch (kind) {
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
            benchdnn_parallel_nd(n_chunks, [&](int64_t idx) {
                quantize(prb.wei_scales, prb.wei_nscales, 0, idx);
            });
            break;
        case WEIGHTS_PROJECTION:
            benchdnn_parallel_nd(n_chunks, [&](int64_t idx) {
                quantize(prb.wei_proj_scales, prb.wei_proj_nscales, 0, idx);
            });
            break;
        case SRC_LAYER:
        case SRC_ITER:
            benchdnn_parallel_nd(n_chunks, [&](int64_t idx) {
                quantize(&(prb.data_scale), 1, prb.data_shift, idx);
            });
            break;
        default: assert(!"unsupported kind");
    }

    /* 2. compute s8_plain_quantized --reorder--> s8_packed_quantized */
    mem_s8_dst.reorder(mem_s8_src);

    /* 3. check that the two memories are bitwise identical. */
    auto sz = mem_dt.size();
    uint8_t *s8_dst_handle = (uint8_t *)mem_s8_dst;
    uint8_t *mem_dt_handle = (uint8_t *)mem_dt;

    // check that both have the same size
    assert(mem_dt.size() == mem_s8_dst.size());
    // check that both have the same alignment modulo align_data in
    // gemm_pack_storage.hpp
    assert((uint64_t)s8_dst_handle % 0x1000
            == (uint64_t)mem_dt_handle % 0x1000);
    for (size_t i = 0; i < sz; ++i) {
        if (s8_dst_handle[i] != mem_dt_handle[i]) { return FAIL; }
    }

    return OK;
}

int fill_memory(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, dnnl_data_type_t dt, float mean, float stddev,
        float min, float max, const_dnnl_primitive_attr_t attr = nullptr,
        bool flip_sign = false) {
    const auto nelems = mem_dt.nelems();
    if (nelems == 0) return OK;
    assert(mem_dt.nelems() == mem_fp.nelems());

    // For non-int8 RNN the data is filled according to cfg directly.
    // However, for int8 RNN we have slightly more obscure logic, at least for
    // now:
    // 1. cfg describes the quantized data;
    // 2. We first fill f32 de-quantized data by inverse-applying the scale
    //    and shift to the data generated by the cfg distribution;
    // 3. We reorder the data for the oneDNN RNN primitive;
    // 4. Q10n of the data for the reference benchdnn RNN:
    //    4.a. If the tensor is weights -- q10n it here;
    //    4.b. If the tensor is data -- the reference benchdnn RNN will
    //         quantize it.
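    //
    // Illustration only (the numbers are hypothetical, not from any cfg): if
    // the cfg distribution produces a quantized value v = 70 with
    // data_scale = 63 and data_shift = 64, bullet 2 stores the de-quantized
    // value (70 - 64) / 63 ~= 0.095f, so that the subsequent f32 -> int8
    // reorder with the same scale and shift recovers (approximately) the
    // original quantized value.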

    // pass rnn attributes to f32 -> int8 reorders only
    const_dnnl_primitive_attr_t reorder_attr = nullptr;
    if (prb.is_int8() && (dt != dnnl_f32)) reorder_attr = attr;
    float default_scales[1] = {1.0f};
    float default_shift = 0.0f;

    /* Do fixed partitioning to have the same filling for any number of
     * threads */
    const int64_t n_chunks = 16;
    const int64_t chunk_size = div_up(nelems, n_chunks);

    // 2. We first fill f32 de-quantized data by inverse-applying the scale
    //    and shift to the data generated by the cfg distribution;
    auto fill_chunk = [&](const float *scales, int nscales, float shift,
                              int idx_chunk) {
        int64_t idx_start = idx_chunk * chunk_size;
        int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
        std::minstd_rand msr;
        msr.seed(idx_start + kind);
        normal_distribution_t<float> gen(mean, stddev);
        for (int64_t idx = idx_start; idx < idx_end; ++idx) {
            float val = round_to_nearest_representable(dt, gen(msr));
            val = MAX2(MIN2(val, max), min);
            val = (val - shift)
                    / scales[idx % nscales]; // affects the int8 case only

            // Related to Vanilla RNN with ReLU testing only: flip the sign of
            // inputs for `mb` == 0 to test the ReLU part.
            if (flip_sign) {
                assert(kind == SRC_LAYER || kind == SRC_ITER);
                auto ld = kind == SRC_LAYER ? prb.slc : prb.sic;
                if (idx % (prb.mb * ld) < ld) val *= -1;
            }
            mem_fp.set_elem(idx, val);
        }
    };
    switch (kind) {
        case WEIGHTS_PROJECTION:
            benchdnn_parallel_nd(n_chunks, [&](int64_t idx) {
                fill_chunk(
                        prb.wei_proj_scales, prb.wei_proj_nscales, 0.0f, idx);
            });
            break;
        case WEIGHTS_LAYER:
        case WEIGHTS_ITER:
            benchdnn_parallel_nd(n_chunks, [&](int64_t idx) {
                fill_chunk(prb.wei_scales, prb.wei_nscales, 0.0f, idx);
            });
            break;
        case SRC_LAYER:
        case SRC_ITER:
            benchdnn_parallel_nd(n_chunks, [&](int64_t idx) {
                fill_chunk(&(prb.data_scale), 1, prb.data_shift, idx);
            });
            break;
        default: // we do no scale/shift
            benchdnn_parallel_nd(n_chunks, [&](int64_t idx) {
                fill_chunk(default_scales, 1, default_shift, idx);
            });
    }

    // 3. We reorder the data for the oneDNN RNN primitive
    mem_dt.reorder(mem_fp, reorder_attr);
    if ((reorder_attr != nullptr) && (dt == dnnl_s8))
        if (check_s8s8_reorder(prb, kind, mem_dt, mem_fp) != OK) return FAIL;

    // Bullet 4.a holds: quantize weights for the int8 benchdnn reference RNN
    if (prb.is_int8()) {
        auto quantize_chunk
                = [&](const float *scales, int nscales, int idx_chunk) {
                      int64_t idx_start = idx_chunk * chunk_size;
                      int64_t idx_end = MIN2(idx_start + chunk_size, nelems);
                      for (int64_t idx = idx_start; idx < idx_end; ++idx) {
                          float current_scale = scales[idx % nscales];
                          float val = ((float *)mem_fp)[idx];
                          val = round(current_scale * val);
                          mem_fp.set_elem(idx, MAX2(MIN2(val, max), min));
                      }
                  };
        switch (kind) {
            case WEIGHTS_LAYER:
            case WEIGHTS_ITER:
                benchdnn_parallel_nd(n_chunks, [&](int64_t idx) {
                    quantize_chunk(prb.wei_scales, prb.wei_nscales, idx);
                });
                break;
            case WEIGHTS_PROJECTION:
                benchdnn_parallel_nd(n_chunks, [&](int64_t idx) {
                    quantize_chunk(
                            prb.wei_proj_scales, prb.wei_proj_nscales, idx);
                });
                break;
            default: // Nothing to do
                break;
        }
    }

    return OK;
}

int fill_memory(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, const_dnnl_primitive_attr_t attr = nullptr,
        bool flip_sign = false) {
    const dt_conf_t::entry_t &c = prb.cfg[kind];
    return fill_memory(prb, kind, mem_dt, mem_fp, c.dt, c.f_mean, c.f_stddev,
            c.f_min, c.f_max, attr, flip_sign);
}

int fill_activation(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, const_dnnl_primitive_attr_t attr = nullptr) {
    // In general, we mostly want to use positive values to avoid
    // cancellation from happening during computation. The only case
    // where we actually want negative values to appear is for 1-layer,
    // 1-iteration tests using vanilla_rnn and non-zero alpha. In that
    // case, we want to check that alpha is applied accordingly. Here
    // skip_nonlinear is checked as we want to test relu with non-zero
    // alpha, and not the linear function that would replace it under
    // skip_nonlinear=true.
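    //
    // For example (hypothetical sizes, not a real test case): for SRC_LAYER
    // with mb = 2 and slc = 4, the plain layout is (t, n, c), so linear
    // indices with idx % (mb * slc) < slc correspond to the batch entry
    // n == 0 at every time step; only those elements get their sign flipped
    // by the fill_chunk lambda above.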
    bool flip_sign = prb.skip_nonlinear == false && prb.alg == VANILLA_RNN
            && prb.activation == RELU
            && (kind == SRC_LAYER || kind == SRC_ITER);
    return fill_memory(prb, kind, mem_dt, mem_fp, attr, flip_sign);
}

int fill_src_iter_c(const prb_t &prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp,
        const_dnnl_primitive_attr_t attr = nullptr) {
    const bool special_case = prb.prop == dnnl_backward && prb.skip_nonlinear;
    if (!special_case)
        return fill_memory(prb, SRC_ITER_C, mem_dt, mem_fp, attr);

    // The scaling factors in tparams when testing backward are common for
    // forward and backward passes, and are computed as 1 over the maximum of
    // the accumulation chain:
    // - ~n_gates on FWD
    // - ~dhc * n_gates on BWD_D
    // - ~mb * n_gates on BWD_W
    //
    // This makes tparams relatively small for the forward pass (compared to
    // testing the forward pass alone). This, in turn, makes src_iter_c
    // converge relatively fast to a value ~ i_gate * c_gate, which is
    // (typically) way smaller than the original distribution for src_iter_c.
    //
    // TODO: use different tparams for forward and
    // backward passes when testing BWD_DW.
    //
    // The problem appears on the backward pass. Consider diff_f_gate that
    // contributes to backward weights when the batch or the number of
    // iterations is big:
    //   diff_f_gate[iter] = src_iter_c[iter] * diff_dst[iter].
    //   diff_weights += ~diff_f_gate[iter].
    //
    // Assume that diff_dst[iter] is always about the same for every iter.
    // Since src_iter_c[0] >> src_iter_c[iter] for iter > 0, this makes
    // diff_weights highly dependent on the order of accumulating the
    // diff_f_gate[iter].
    //
    // Originally we had something like:
    //   diff_weights = v + v * 10^-5 + ... + v * 10^-5 (n_iter * MB summands).
    // Depending on the order of summation the difference might exceed the
    // typical bound approximation: coefficient * log(number_of_summands).
    //
    // Anyway, the algorithm below tries to put the first src_iter_c[iter = 0]
    // in the same ballpark as all the subsequent src_iter_c[iter > 0].
    //
    // The estimation is based on the following rough assumptions:
    //   src_iter_c[iter+1] = f_gate * src_iter_c[iter] + i_gate * c_gate
    //                     ~= f_gate * small_value + i_gate * c_gate
    //                     ~= i_gate * c_gate.
    //   i_gate ~= tparams[i_gate] * (
    //       1 / ngates * mean_src_layer +
    //       1 / ngates * mean_src_iter +
    //       mean_bias);
    //
    // The same holds for c_gate.
    // The (1 / ngates) factor is taken from fill_weights().
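    //
    // A rough numeric sketch (all numbers are hypothetical, only to show the
    // mechanics): with n_gates() = 4, f_mean = 1 for SRC_LAYER, SRC_ITER and
    // BIAS, and linear_scales[LSTM_I] = linear_scales[LSTM_C] = 0.1f:
    //   expect_gemm_output     = 1/4 + 1/4 + 1 = 1.5
    //   expect_i_gate = expect_c_gate = 0.1 * 1.5 = 0.15
    //   expect_src_iter_c_mean = 0.15 * 0.15 = 0.0225
    // and, if cfg[SRC_ITER_C].f_mean were 1, the whole SRC_ITER_C
    // distribution would get scaled down by adjust_factor = 0.0225.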

    float expect_gemm_output = (1.f / prb.n_gates()) * prb.cfg[SRC_LAYER].f_mean
            + (1.f / prb.n_gates()) * prb.cfg[SRC_ITER].f_mean
            + prb.cfg[BIAS].f_mean;
    float expect_i_gate = (float)prb.linear_scales[LSTM_I] * expect_gemm_output;
    float expect_c_gate = (float)prb.linear_scales[LSTM_C] * expect_gemm_output;
    float expect_src_iter_c_mean = expect_i_gate * expect_c_gate;

    float adjust_factor = 1;

    const bool need_adjust = expect_src_iter_c_mean < prb.cfg[SRC_ITER_C].f_mean
            && prb.cfg[SRC_ITER_C].f_mean != 0;
    if (need_adjust)
        adjust_factor = expect_src_iter_c_mean / prb.cfg[SRC_ITER_C].f_mean;

    const dt_conf_t::entry_t &c = prb.cfg[SRC_ITER_C];
    return fill_memory(prb, SRC_ITER_C, mem_dt, mem_fp, c.dt,
            c.f_mean * adjust_factor, c.f_stddev * adjust_factor,
            c.f_min * adjust_factor, c.f_max * adjust_factor, attr);
}

int fill_weights(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp, const_dnnl_primitive_attr_t attr = nullptr) {
    const auto nelems = mem_dt.nelems();
    if (nelems == 0) return OK;
    const dt_conf_t::entry_t &c = prb.cfg[kind];

    assert(kind == WEIGHTS_PROJECTION ? mem_fp.ndims() == 4
                                      : mem_fp.ndims() == 5);

    const auto &dims = mem_fp.dims();
    const int64_t L = dims[0];
    const int64_t D = dims[1];
    const int64_t I = dims[2];
    const int64_t G = (kind == WEIGHTS_PROJECTION) ? 1 : dims[3];
    const int64_t O = (kind == WEIGHTS_PROJECTION) ? dims[3] : dims[4];

    float gate_factor
            = (kind == WEIGHTS_PROJECTION) ? 1.f : 1.f / prb.n_gates();

    const auto tag = tag::abx;
    dnn_mem_t mem_pure_fp(mem_dt.md_, dnnl_f32, tag, get_cpu_engine());

    for (int64_t i = 0; i < mem_fp.nelems(); i++) {
        mem_fp.set_elem(i, 0);
        mem_pure_fp.set_elem(i, 0);
    }

    auto scales = (kind == WEIGHTS_PROJECTION) ? prb.wei_proj_scales
                                               : prb.wei_scales;
    auto n_scales = (kind == WEIGHTS_PROJECTION) ? prb.wei_proj_nscales
                                                 : prb.wei_nscales;

    // Fill weights sparsely to avoid accumulation errors. Use two memories:
    // one is quantized for the reference, the other is for a reorder.
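    // Roughly speaking, for every (l, d, g, o) combination exactly one input
    // index i_off = (19 * o + 7 * g + 11 * d + 13 * l) % I gets the value
    // gate_factor while all other inputs stay zero, so each gate output
    // accumulates only a single non-zero product from this tensor.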
    for_(int64_t l = 0; l < L; l++)
    for_(int64_t d = 0; d < D; d++)
    for_(int64_t g = 0; g < G; g++)
    for (int64_t o = 0; o < O; o++) {
        int64_t i_off = ((19 * o + 7 * g + 11 * d + 13 * l) % I);
        int64_t off = (((l * D + d) * I + i_off) * G + g) * O + o;
        float val = gate_factor;
        mem_pure_fp.set_elem(off, val);
        if (prb.is_int8()) val *= scales[off % n_scales];
        mem_fp.set_elem(off, round_to_nearest_representable(c.dt, val));
    }

    // Pass rnn attributes to f32 -> s8 reorders only
    const_dnnl_primitive_attr_t reorder_attr = nullptr;
    if (prb.is_int8()) reorder_attr = attr;
    mem_dt.reorder(mem_pure_fp, reorder_attr);

    // Test that the s8 -> s8 reorder works correctly
    if ((reorder_attr != nullptr) && (c.dt == dnnl_s8))
        return check_s8s8_reorder(prb, kind, mem_dt, mem_pure_fp);
    return OK;
}

int fill_bias(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt,
        dnn_mem_t &mem_fp) {
    // To reduce the likelihood of cancellation happening in bwd by bias
    // (especially for GRU), we want diff_bias to be sparse.
    const auto &dims = mem_fp.dims();
    auto L = dims[0];
    auto D = dims[1];
    auto G = dims[2];
    auto O = dims[3];

    std::minstd_rand msr;
    normal_distribution_t<float> gen(
            prb.cfg[kind].f_mean, prb.cfg[kind].f_stddev);
    msr.seed(kind);

    for_(int64_t l = 0; l < L; l++)
    for_(int64_t d = 0; d < D; d++)
    for_(int64_t g = 0; g < G; g++)
    for (int64_t o = 0; o < O; o++) {
        auto idx = l * D * G * O + d * G * O + g * O + o;
        auto val = round_to_nearest_representable(
                prb.cfg[kind].dt, gen(msr) * flip_coin(idx, 0.05f));
        mem_fp.set_elem(idx, val);
    }
    mem_dt.reorder(mem_fp);
    return OK;
}

void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    const prb_t &prb_ = *prb;
    if (prb_.prop != dnnl_backward)
        compute_ref_fwd(prb_, args);
    else
        compute_ref_bwd(prb_, args);
}

dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) {
    const prb_t &prb = *init_pd_args.prb;
    const dir_t dir = init_pd_args.dir;

    dnnl_prop_kind_t fwd_prop = dnnl_prop_kind_undef;
    switch (prb.prop) {
        case dnnl_forward_training: fwd_prop = dnnl_forward_training; break;
        case dnnl_forward_inference: fwd_prop = dnnl_forward_inference; break;
        // If we are testing backward, we have to run forward training first
        // in order to generate a valid workspace.
        case dnnl_backward: fwd_prop = dnnl_forward_training; break;
        default: DNN_SAFE_STATUS(dnnl_invalid_arguments);
    }

    const bool is_gru_lbr = prb.alg == LBR_GRU || prb.alg == LBR_AUGRU;
    // Enable testing with non-trivial strides
    const int the_stride = prb.trivial_strides ? 0 : 1;
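    // For illustration (a hypothetical shape, not a real test case): for a
    // dense tnc src_layer of dims {2, 8, 16}, the queried strides would be
    // {128, 16, 1}; with the_stride = 1 the adjustment below turns them into
    // {129, 16, 1} for src_layer, i.e. the outermost dimension becomes
    // non-dense, which is what "non-trivial strides" means here.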

    // bidirectional = 2, s for lstm = 2, for all others = 1
    dnnl_dims_t weights_layer_dims
            = {prb.n_layer, prb.n_dir(), prb.slc, prb.n_gates(), prb.dhc};
    dnnl_dims_t weights_iter_dims
            = {prb.n_layer, prb.n_dir(), prb.sic, prb.n_gates(), prb.dhc};
    dnnl_dims_t attention_dims = {prb.n_iter, prb.mb, 1};
    dnnl_dims_t weights_peephole_dims = {prb.n_layer, prb.n_dir(), 3, prb.dhc};
    dnnl_dims_t weights_projection_dims
            = {prb.n_layer, prb.n_dir(), prb.dhc, prb.dic};
    dnnl_dims_t bias_dims
            = {prb.n_layer, prb.n_dir(), prb.n_gates() + is_gru_lbr, prb.dhc};
    // dnnl_tnc
    dnnl_dims_t dst_layer_dims = {prb.n_iter, prb.mb, prb.dlc(PRIMITIVE)};

    dnnl_dims_t src_layer_dims = {prb.n_iter, prb.mb, prb.slc};
    auto src_layer_d = dnn_mem_t::init_md(
            3, src_layer_dims, prb.cfg[SRC_LAYER].dt, tag::abx /* dnnl_tnc */);
    dims_t src_layer_strides(query_md_ndims(src_layer_d));
    std::memcpy(src_layer_strides.data(), query_md_strides(src_layer_d),
            src_layer_strides.size() * sizeof(dnnl_dim_t));
    src_layer_strides[0] += the_stride;
    src_layer_d = dnn_mem_t::init_md(query_md_ndims(src_layer_d),
            query_md_dims(src_layer_d), query_md_data_type(src_layer_d), "",
            src_layer_strides);

    dnnl_dims_t src_iter_dims = {prb.n_layer, prb.n_dir(), prb.mb, prb.sic};
    auto src_iter_d = dnn_mem_t::init_md(
            4, src_iter_dims, prb.cfg[SRC_ITER].dt, tag::abx /* dnnl_ldnc */);
    // Adjust strides for src_iter_d.
    dims_t src_iter_strides(query_md_ndims(src_iter_d));
    std::memcpy(src_iter_strides.data(), query_md_strides(src_iter_d),
            src_iter_strides.size() * sizeof(dnnl_dim_t));
    src_iter_strides[2] = prb.sic + the_stride;
    for (int d = 1; d >= 0; --d)
        src_iter_strides[d]
                = src_iter_strides[d + 1] * query_md_dims(src_iter_d)[d + 1];
    src_iter_d = dnn_mem_t::init_md(query_md_ndims(src_iter_d),
            query_md_dims(src_iter_d), query_md_data_type(src_iter_d), "",
            src_iter_strides);

    dnnl_dims_t src_iter_c_dims = {prb.n_layer, prb.n_dir(), prb.mb, prb.dhc};
    auto src_iter_c_d = dnn_mem_t::init_md(4, src_iter_c_dims,
            prb.cfg[SRC_ITER_C].dt, tag::abx /* dnnl_ldnc */);
    // Adjust strides for src_iter_c_d.
    dims_t src_iter_c_strides(query_md_ndims(src_iter_c_d));
    std::memcpy(src_iter_c_strides.data(), query_md_strides(src_iter_c_d),
            src_iter_c_strides.size() * sizeof(dnnl_dim_t));
    src_iter_c_strides[2] = prb.dhc + the_stride;
    for (int d = 1; d >= 0; --d)
        src_iter_c_strides[d] = src_iter_c_strides[d + 1]
                * query_md_dims(src_iter_c_d)[d + 1];
    src_iter_c_d = dnn_mem_t::init_md(query_md_ndims(src_iter_c_d),
            query_md_dims(src_iter_c_d), query_md_data_type(src_iter_c_d), "",
            src_iter_c_strides);

    auto weights_layer_d = dnn_mem_t::init_md(
            5, weights_layer_dims, prb.cfg[WEIGHTS_LAYER].dt, tag::any);
    auto weights_iter_d = dnn_mem_t::init_md(
            5, weights_iter_dims, prb.cfg[WEIGHTS_ITER].dt, tag::any);

    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> attention_d {};
    if (prb.is_augru())
        attention_d = dnn_mem_t::init_md(3, attention_dims,
                prb.cfg[AUGRU_ATTENTION].dt, tag::abx /* dnnl_tnc */);

    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> weights_peephole_d {};
    if (prb.is_lstm_peephole())
        weights_peephole_d = dnn_mem_t::init_md(4, weights_peephole_dims,
                prb.cfg[WEIGHTS_PEEPHOLE].dt, tag::abx /* dnnl_ldgo */);

    benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> weights_projection_d {};
    if (prb.is_lstm_projection())
        weights_projection_d = dnn_mem_t::init_md(4, weights_projection_dims,
                prb.cfg[WEIGHTS_PROJECTION].dt, tag::any);

    auto bias_d = dnn_mem_t::init_md(4, bias_dims, prb.cfg[BIAS].dt, tag::any);

    auto dst_layer_d = dnn_mem_t::init_md(
            3, dst_layer_dims, prb.cfg[DST_LAYER].dt, tag::abx /* dnnl_tnc */);
    dims_t dst_layer_strides(query_md_ndims(dst_layer_d));
    std::memcpy(dst_layer_strides.data(), query_md_strides(dst_layer_d),
            dst_layer_strides.size() * sizeof(dnnl_dim_t));
    dst_layer_strides[0] += the_stride;
    dst_layer_d = dnn_mem_t::init_md(query_md_ndims(dst_layer_d),
            query_md_dims(dst_layer_d), query_md_data_type(dst_layer_d), "",
            dst_layer_strides);

    dnnl_dims_t dst_iter_dims = {prb.n_layer, prb.n_dir(), prb.mb, prb.dic};
    auto dst_iter_d = dnn_mem_t::init_md(
            4, dst_iter_dims, prb.cfg[DST_ITER].dt, tag::abx /* dnnl_ldnc */);
    // Adjust strides for dst_iter_d.
    dims_t dst_iter_strides(query_md_ndims(dst_iter_d));
    std::memcpy(dst_iter_strides.data(), query_md_strides(dst_iter_d),
            dst_iter_strides.size() * sizeof(dnnl_dim_t));
    dst_iter_strides[2] = prb.dic + the_stride;
    for (int d = 1; d >= 0; --d)
        dst_iter_strides[d]
                = dst_iter_strides[d + 1] * query_md_dims(dst_iter_d)[d + 1];
    dst_iter_d = dnn_mem_t::init_md(query_md_ndims(dst_iter_d),
            query_md_dims(dst_iter_d), query_md_data_type(dst_iter_d), "",
            dst_iter_strides);

    dnnl_dims_t dst_iter_c_dims = {prb.n_layer, prb.n_dir(), prb.mb, prb.dhc};
    auto dst_iter_c_d = dnn_mem_t::init_md(4, dst_iter_c_dims,
            prb.cfg[DST_ITER_C].dt, tag::abx /* dnnl_ldnc */);
    // Adjust strides for dst_iter_c_d.
    dims_t dst_iter_c_strides(query_md_ndims(dst_iter_c_d));
    std::memcpy(dst_iter_c_strides.data(), query_md_strides(dst_iter_c_d),
            dst_iter_c_strides.size() * sizeof(dnnl_dim_t));
    dst_iter_c_strides[2] = prb.dhc + the_stride;
    for (int d = 1; d >= 0; --d)
        dst_iter_c_strides[d] = dst_iter_c_strides[d + 1]
                * query_md_dims(dst_iter_c_d)[d + 1];
    dst_iter_c_d = dnn_mem_t::init_md(query_md_ndims(dst_iter_c_d),
            query_md_dims(dst_iter_c_d), query_md_data_type(dst_iter_c_d), "",
            dst_iter_c_strides);

    auto dnnl_attr = make_benchdnn_dnnl_wrapper(create_dnnl_rnn_attr(prb));

    // Initializing the forward pass.
    // For inference we use forward_inference; for training we use
    // forward_training.
    if (dir & FLAG_FWD) {
        DNN_SAFE_STATUS(init_rnn_fwd_pd(&init_pd_args.pd, init_pd_args.engine,
                prb, fwd_prop, src_layer_d, src_iter_d, src_iter_c_d,
                attention_d, weights_layer_d, weights_iter_d,
                weights_peephole_d, weights_projection_d, bias_d, dst_layer_d,
                dst_iter_d, dst_iter_c_d, dnnl_attr));
    } else {
        // TODO: add stride support for diff_* tensors
        auto diff_src_layer_d = dnn_mem_t::init_md(
                3, src_layer_dims, prb.cfg[DIFF_SRC_LAYER].dt, tag::any);
        auto diff_src_iter_d = dnn_mem_t::init_md(
                4, src_iter_dims, prb.cfg[DIFF_SRC_ITER].dt, tag::any);
        auto diff_src_iter_c_d = dnn_mem_t::init_md(
                4, src_iter_c_dims, prb.cfg[DIFF_SRC_ITER_C].dt, tag::any);
        auto diff_weights_layer_d = dnn_mem_t::init_md(5, weights_layer_dims,
                prb.cfg[DIFF_WEIGHTS_LAYER].dt, tag::any);
        auto diff_weights_iter_d = dnn_mem_t::init_md(
                5, weights_iter_dims, prb.cfg[DIFF_WEIGHTS_ITER].dt, tag::any);

        benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> diff_attention_d {};
        if (prb.is_augru())
            diff_attention_d = dnn_mem_t::init_md(3, attention_dims,
                    prb.cfg[DIFF_AUGRU_ATTENTION].dt, tag::abx /* dnnl_tnc */);

        benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> diff_weights_peephole_d {};
        if (prb.is_lstm_peephole())
            diff_weights_peephole_d = dnn_mem_t::init_md(4,
                    weights_peephole_dims, prb.cfg[DIFF_WEIGHTS_PEEPHOLE].dt,
                    tag::abx /* dnnl_ldgo */);

        benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t>
                diff_weights_projection_d {};
        if (prb.is_lstm_projection())
            diff_weights_projection_d
                    = dnn_mem_t::init_md(4, weights_projection_dims,
                            prb.cfg[DIFF_WEIGHTS_PROJECTION].dt, tag::any);

        auto diff_bias_d = dnn_mem_t::init_md(
                4, bias_dims, prb.cfg[DIFF_BIAS].dt, tag::any);
        auto diff_dst_layer_d = dnn_mem_t::init_md(
                3, dst_layer_dims, prb.cfg[DIFF_DST_LAYER].dt, tag::any);
        auto diff_dst_iter_d = dnn_mem_t::init_md(
                4, dst_iter_dims, prb.cfg[DIFF_DST_ITER].dt, tag::any);
        auto diff_dst_iter_c_d = dnn_mem_t::init_md(
                4, dst_iter_c_dims, prb.cfg[DIFF_DST_ITER_C].dt, tag::any);

        DNN_SAFE_STATUS(init_rnn_bwd_pd(&init_pd_args.pd, init_pd_args.engine,
                prb, prb.prop, src_layer_d, src_iter_d, src_iter_c_d,
                attention_d, weights_layer_d, weights_iter_d,
                weights_peephole_d, weights_projection_d, bias_d, dst_layer_d,
                dst_iter_d, dst_iter_c_d, diff_src_layer_d, diff_src_iter_d,
                diff_src_iter_c_d, diff_attention_d, diff_weights_layer_d,
                diff_weights_iter_d, diff_weights_peephole_d,
                diff_weights_projection_d, diff_bias_d, diff_dst_layer_d,
                diff_dst_iter_d, diff_dst_iter_c_d, init_pd_args.hint,
                dnnl_attr));
    }

    return dnnl_success;
}

void skip_unimplemented_prb(const prb_t *prb_, res_t *res) {
    const prb_t &prb = *prb_;
    dir_t dir = str2dir(prop2str(prb.prop));
    skip_unimplemented_data_type({prb.cfg[SRC_LAYER].dt}, dir, res);
    skip_unimplemented_sum_po(prb.attr, res);

#if !defined(DNNL_X64) || DNNL_X64 == 0 \
        || DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL
    // int8 is not supported altogether since RNN relies on packed IGEMM
    // FIXME: this will disable int8 RNN testing if the library is built with
    // Intel MKL, which does provide packed IGEMM
    if (prb.is_int8()) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }
#endif

#if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE
    static auto isa = dnnl_get_effective_cpu_isa();
    const bool is_f16_not_ok = prb.cfg[SRC_LAYER].dt == dnnl_f16
            && dnnl::is_superset(isa, dnnl_cpu_isa_avx512_core_fp16);
    if (is_f16_not_ok) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }
#endif

#ifdef DNNL_AARCH64_USE_ACL
    const bool is_acl_f16_not_ok = prb.cfg[SRC_LAYER].dt == dnnl_f16
            && dnnl::impl::cpu::platform::has_data_type_support(dnnl_f16);
    if (is_acl_f16_not_ok) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }
#endif

    // int8 weights reorder does not support non-trivial strides;
    // only LSTM and GRU cell kinds support int8 so far;
    if (prb.is_int8()) {
        if (!prb.trivial_strides) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
        if (prb.alg != VANILLA_LSTM && prb.alg != VANILLA_GRU) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }

    // LSTM w/ projection is not supported for bf16
    if (prb.is_lstm_projection() && prb.cfg[SRC_LAYER].dt == dnnl_bf16) {
        res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
        return;
    }

    // GPU limitations for RNN
    if (is_gpu()) {
        bool is_AUGRU = prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU;
        if (is_AUGRU) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
        if (prb.is_lstm_projection() || prb.is_lstm_peephole()) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
        if (prb.is_int8() && prb.alg != VANILLA_LSTM) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
        if (prb.is_s8() && prb.alg == VANILLA_LSTM) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
        // Implemented only for CPU
        if (prb.cfg[BIAS].dt == dnnl_bf16 || prb.cfg[SRC_ITER_C].dt == dnnl_bf16
                || prb.cfg[DST_ITER_C].dt == dnnl_bf16) {
            res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED;
            return;
        }
    }
}

void skip_invalid_prb(const prb_t *prb_, res_t *res) {
    const prb_t &prb = *prb_;

    // Consistency validation.
    bool consistent_proj
            = IMPLICATION(!prb.with_projection, prb.dhc == prb.dic);
    bool consistent_L = IMPLICATION(prb.n_layer > 1, prb.slc == prb.dic);
    bool consistent_T = IMPLICATION(prb.n_iter > 1, prb.sic == prb.dic);
    bool is_GRU = prb.alg == VANILLA_GRU || prb.alg == LBR_GRU;
    bool consistent_GRU = IMPLICATION(is_GRU, prb.sic == prb.dic);
    bool is_AUGRU = prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU;
    bool consistent_AUGRU = IMPLICATION(is_AUGRU,
            prb.sic == prb.dic && prb.n_layer == 1
                    && prb.direction == dnnl_unidirectional_left2right);
    if (!consistent_proj || !consistent_L || !consistent_T || !consistent_GRU
            || !consistent_AUGRU) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }

    // Only LSTM supports peephole and projection layers.
    bool is_lstm_peephole
            = IMPLICATION(prb.with_peephole, prb.alg == VANILLA_LSTM);
    bool is_lstm_projection
            = IMPLICATION(prb.with_projection, prb.alg == VANILLA_LSTM);
    if (!is_lstm_peephole || !is_lstm_projection) {
        res->state = SKIPPED, res->reason = INVALID_CASE;
        return;
    }
}

void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind,
        const args_t &ref_args) {
    const auto rnn_kind = data_kind2rnn_data_kind(kind);
    const auto &cfg = prb->cfg[rnn_kind];
    // The factor 2 is because of the sum of 2 GEMMs.
    int64_t fwd_acc_dim = 2 * prb->n_gates() + 1;
    if (prb->alg == VANILLA_GRU || prb->alg == VANILLA_AUGRU)
        fwd_acc_dim *= prb->sic;
    int64_t bwdd_acc_dim = prb->n_gates() * prb->dhc;
    int64_t bwdw_acc_dim = prb->mb;
    int64_t acc_dim = fwd_acc_dim;
    if (prb->prop == dnnl_backward) acc_dim *= MAX2(bwdd_acc_dim, bwdw_acc_dim);
    // Here the factor 4 just gives some wiggle room for fp32 testing.

    float trh = 4
            * (1 + (prb->prop == dnnl_backward)) // double wiggle room for bwd
            * ((prb->direction == dnnl_bidirectional_sum)
                    + 1) // double trh if bidir_sum
            * ceilf(log2f(acc_dim * prb->n_iter)) * cfg.eps;
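    // A rough worked example (hypothetical problem, not a real config): for a
    // forward, unidirectional f32 LSTM with n_gates() = 4 and n_iter = 7:
    //   acc_dim = fwd_acc_dim = 2 * 4 + 1 = 9
    //   trh = 4 * 1 * 1 * ceil(log2(9 * 7)) * eps = 4 * 6 * eps = 24 * eps.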
    // Expect an exact value match for int8.
    if (cfg.dt == dnnl_u8 || cfg.dt == dnnl_s8) trh = 0.f;
    cmp.set_threshold(trh);

    // Note: we do an eltwise comparison only when:
    // - we use skip_nonlinear;
    // - we do not use skip_nonlinear and we test only one cell execution;
    // - for int8 computations the tensor is not DST_ITER_C;
    // If the above conditions are not met, we check only the L1, L2 and L8
    // (L_inf) norms.

    // Rough rationale for the `DST_ITER_C` exception in the int8 case:
    // - The formula for the one-step c-state is:
    //     c_t = f_t * c_{t-1} + i_t * c~_t.
    //   Here all computations happen in f32 (f_t, i_t, and c~_t are
    //   dequantized right before the computations + the corresponding bias is
    //   added).
    // - In the int8 case we don't have much control over these components and
    //   cannot surmount potential cancellations, if any.
    //   In practice, I observed that the relative element-wise error of
    //   values in `DST_ITER_C` was bigger (up to 8e-5) whenever the values
    //   themselves were smaller (which indirectly means the problem is
    //   exactly in the cancellation). Unfortunately, this happened even with
    //   only one layer and one time step.
    // - So, for now the solution is to use the l1-, l2-, and l_inf-norms to
    //   validate `DST_ITER_C`. When we switch testing to precise integer
    //   arithmetic based on the modulo operation in rnn_tparams (instead of
    //   the current unreliable re-scaling), this testing weakness should go
    //   away.
    // - Just an obvious side note: `DST_LAYER` and `DST_ITER`
    //   are immediate dequantizations of the corresponding u8 tensors. Hence,
    //   as long as we get precise u8 intermediate results (and so far we do),
    //   the f32 results should be pretty accurate -- the dequantization is
    //   just two simple ops: f32 = scale * u8 + shift.
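    // (In norm validation mode the comparison is done on aggregate relative
    // norms over the whole tensor rather than element by element, which is
    // what makes it tolerant to the cancellation described above.)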
    bool check_p2p = (prb->skip_nonlinear
            || ((prb->n_layer == 1) && (prb->n_iter == 1)));
    if (prb->is_int8() && rnn_kind == DST_ITER_C) check_p2p = false;
    cmp.set_norm_validation_mode(!check_p2p);

    const auto rnn_add_check =
            [&, prb](const compare::compare_t::driver_check_func_args_t &args) {
                // Limitation from current filling.
                // TODO: find a better filling to get rid of this...
                if ((prb->alg == VANILLA_GRU || prb->alg == LBR_AUGRU
                            || prb->alg == VANILLA_RNN || prb->alg == LBR_GRU)
                        && prb->prop == dnnl_backward) {
                    return args.diff < args.trh;
                }
                return false;
            };
    cmp.set_driver_check_function(rnn_add_check);
}

int doit(const prb_t &prb, res_t *res) {
    if (bench_mode == LIST) return res->state = LISTED, OK;

    benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim;
    bool is_service_prim = prb.dir & FLAG_BWD;
    SAFE(init_prim(prb.ctx_init, prim, init_pd, &prb, res, FLAG_FWD, nullptr,
                 is_service_prim),
            WARN);
    if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;

    auto const_fpd = query_pd(prim);

    const auto &src_layer_md = query_md(const_fpd, DNNL_ARG_SRC_LAYER);
    const auto &src_layer_attention_md
            = query_md(const_fpd, DNNL_ARG_AUGRU_ATTENTION);
    const auto &src_iter_md = query_md(const_fpd, DNNL_ARG_SRC_ITER);
    const auto &src_iter_c_md = query_md(const_fpd, DNNL_ARG_SRC_ITER_C);
    const auto &weights_layer_md = query_md(const_fpd, DNNL_ARG_WEIGHTS_LAYER);
    const auto &weights_iter_md = query_md(const_fpd, DNNL_ARG_WEIGHTS_ITER);
    const auto &weights_peephole_md
            = query_md(const_fpd, DNNL_ARG_WEIGHTS_PEEPHOLE);
    const auto &weights_projection_md
            = query_md(const_fpd, DNNL_ARG_WEIGHTS_PROJECTION);
    const auto &bias_md = query_md(const_fpd, DNNL_ARG_BIAS);
    const auto &dst_layer_md = query_md(const_fpd, DNNL_ARG_DST_LAYER);
    const auto &dst_iter_md = query_md(const_fpd, DNNL_ARG_DST_ITER);
    const auto &dst_iter_c_md = query_md(const_fpd, DNNL_ARG_DST_ITER_C);
    const auto &workspace_md = query_md(const_fpd, DNNL_ARG_WORKSPACE);
    const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD);

    const auto &test_engine = get_test_engine();
    const auto &ref_engine = get_cpu_engine();

    dnn_mem_t src_layer_dt(src_layer_md, test_engine);
    dnn_mem_t src_layer_attention_dt(src_layer_attention_md, test_engine);
    dnn_mem_t src_iter_dt(src_iter_md, test_engine);
    dnn_mem_t src_iter_c_dt(src_iter_c_md, test_engine);
    dnn_mem_t weights_layer_dt(weights_layer_md, test_engine);
    dnn_mem_t weights_iter_dt(weights_iter_md, test_engine);
    dnn_mem_t weights_peephole_dt(weights_peephole_md, test_engine);
    dnn_mem_t weights_projection_dt(weights_projection_md, test_engine);
    dnn_mem_t bias_dt(bias_md, test_engine);
    dnn_mem_t dst_layer_dt(dst_layer_md, test_engine);
    dnn_mem_t dst_iter_dt(dst_iter_md, test_engine);
    dnn_mem_t dst_iter_c_dt(dst_iter_c_md, test_engine);
    dnn_mem_t workspace_dt(workspace_md, test_engine);
    dnn_mem_t scratchpad_dt(scratchpad_md, test_engine);

    dnn_mem_t src_layer_fp(
            src_layer_md, dnnl_f32, tag::abx /*tnc*/, ref_engine);
    dnn_mem_t src_layer_attention_fp(
            src_layer_attention_md, dnnl_f32, tag::abx /*tnc*/, ref_engine);
    dnn_mem_t src_iter_fp(src_iter_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine);
    dnn_mem_t src_iter_c_fp(
            src_iter_c_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine);
    dnn_mem_t weights_layer_fp(
            weights_layer_md, dnnl_f32, tag::abx /*ldigo*/, ref_engine);
    dnn_mem_t weights_iter_fp(
            weights_iter_md, dnnl_f32, tag::abx /*ldigo*/, ref_engine);
    dnn_mem_t weights_peephole_fp(
            weights_peephole_md, dnnl_f32, tag::abx /*ldgo*/, ref_engine);
    dnn_mem_t weights_projection_fp(
            weights_projection_md, dnnl_f32, tag::abx /*ldio*/, ref_engine);
    dnn_mem_t bias_fp(bias_md, dnnl_f32, tag::abx /*ldgo*/, ref_engine);
    dnn_mem_t dst_layer_fp(
            dst_layer_md, dnnl_f32, tag::abx /*tnc*/, ref_engine);
    dnn_mem_t dst_iter_fp(dst_iter_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine);
    dnn_mem_t dst_iter_c_fp(
            dst_iter_c_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine);

    dnn_mem_t bwd_weights_layer_dt;
    dnn_mem_t bwd_weights_iter_dt;
    dnn_mem_t bwd_weights_projection_dt;
    dnn_mem_t diff_src_layer_dt;
    dnn_mem_t diff_src_layer_attention_dt;
    dnn_mem_t diff_src_iter_dt;
    dnn_mem_t diff_src_iter_c_dt;
    dnn_mem_t diff_weights_layer_dt;
    dnn_mem_t diff_weights_iter_dt;
    dnn_mem_t diff_weights_peephole_dt;
    dnn_mem_t diff_weights_projection_dt;
    dnn_mem_t diff_bias_dt;
    dnn_mem_t diff_dst_layer_dt;
    dnn_mem_t diff_dst_iter_dt;
    dnn_mem_t diff_dst_iter_c_dt;

    // For int8 RNN we need to pass attributes for data q10n.
    auto rnn_attr = query_attr(const_fpd);
    SAFE(fill_activation(prb, SRC_LAYER, src_layer_dt, src_layer_fp, rnn_attr),
            WARN);
    if (prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU)
        SAFE(fill_activation(prb, AUGRU_ATTENTION, src_layer_attention_dt,
                     src_layer_attention_fp, rnn_attr),
                WARN);
    SAFE(fill_activation(prb, SRC_ITER, src_iter_dt, src_iter_fp, rnn_attr),
            WARN);
    if (prb.alg == VANILLA_LSTM)
        SAFE(fill_src_iter_c(prb, src_iter_c_dt, src_iter_c_fp, rnn_attr),
                WARN);
    SAFE(fill_weights(prb, WEIGHTS_LAYER, weights_layer_dt, weights_layer_fp,
                 rnn_attr),
            WARN);
    SAFE(fill_weights(
                 prb, WEIGHTS_ITER, weights_iter_dt, weights_iter_fp, rnn_attr),
            WARN);
    SAFE(fill_memory(prb, WEIGHTS_PEEPHOLE, weights_peephole_dt,
                 weights_peephole_fp),
            WARN);
    SAFE(fill_weights(prb, WEIGHTS_PROJECTION, weights_projection_dt,
                 weights_projection_fp, rnn_attr),
            WARN);
    SAFE(fill_memory(prb, BIAS, bias_dt, bias_fp), WARN);
    SAFE(fill_activation(prb, DST_LAYER, dst_layer_dt, dst_layer_fp), WARN);
    SAFE(fill_activation(prb, DST_ITER, dst_iter_dt, dst_iter_fp), WARN);
    if (prb.alg == VANILLA_LSTM)
        SAFE(fill_memory(prb, DST_ITER_C, dst_iter_c_dt, dst_iter_c_fp), WARN);
    args_t args, ref_args;

    // Running the forward pass
    args.set(DNNL_ARG_SRC_LAYER, src_layer_dt);
    args.set(DNNL_ARG_AUGRU_ATTENTION, src_layer_attention_dt);
    args.set(DNNL_ARG_SRC_ITER, src_iter_dt);
    args.set(DNNL_ARG_SRC_ITER_C, src_iter_c_dt);
    args.set(DNNL_ARG_WEIGHTS_LAYER, weights_layer_dt);
    args.set(DNNL_ARG_WEIGHTS_ITER, weights_iter_dt);
    args.set(DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_dt);
    args.set(DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_dt);
    args.set(DNNL_ARG_BIAS, bias_dt);
    args.set(DNNL_ARG_DST_LAYER, dst_layer_dt);
    args.set(DNNL_ARG_DST_ITER, dst_iter_dt);
    args.set(DNNL_ARG_DST_ITER_C, dst_iter_c_dt);
    args.set(DNNL_ARG_WORKSPACE, workspace_dt);
    args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

    SAFE(execute_and_wait(prim, args, res), WARN);

    if (prb.prop != dnnl_backward) {
        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC_LAYER, src_layer_fp);
            ref_args.set(DNNL_ARG_AUGRU_ATTENTION, src_layer_attention_fp);
            ref_args.set(DNNL_ARG_SRC_ITER, src_iter_fp);
            ref_args.set(DNNL_ARG_SRC_ITER_C, src_iter_c_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_LAYER, weights_layer_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_ITER, weights_iter_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_fp);
            ref_args.set(DNNL_ARG_BIAS, bias_fp);
            ref_args.set(DNNL_ARG_DST_LAYER, dst_layer_fp);
            ref_args.set(DNNL_ARG_DST_ITER, dst_iter_fp);
            ref_args.set(DNNL_ARG_DST_ITER_C, dst_iter_c_fp);

            std::vector<data_kind_t> kinds {
                    data_kind_t::DST, data_kind_t::DST_ITER};
            if (prb.alg == VANILLA_LSTM) {
                kinds.push_back(data_kind_t::DST_ITER_C);
            }

            check_correctness(&prb, kinds, args, ref_args, setup_cmp, res);
        }
    } else {
        benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim;
        SAFE(init_prim(tmp_prim, init_pd, &prb, res, FLAG_BWD), WARN);
        if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK;
        prim.reset(tmp_prim.release());

        auto const_bpd = query_pd(prim);

        const auto &bwd_weights_layer_md
                = query_md(const_bpd, DNNL_ARG_WEIGHTS_LAYER);
        const auto &bwd_weights_iter_md
                = query_md(const_bpd, DNNL_ARG_WEIGHTS_ITER);
        const auto &bwd_weights_projection_md
                = query_md(const_bpd, DNNL_ARG_WEIGHTS_PROJECTION);
        const auto &diff_src_layer_md
                = query_md(const_bpd, DNNL_ARG_DIFF_SRC_LAYER);
        const auto &diff_src_layer_attention_md
                = query_md(const_bpd, DNNL_ARG_DIFF_AUGRU_ATTENTION);
        const auto &diff_src_iter_md
                = query_md(const_bpd, DNNL_ARG_DIFF_SRC_ITER);
        const auto &diff_src_iter_c_md
                = query_md(const_bpd, DNNL_ARG_DIFF_SRC_ITER_C);
        const auto &diff_weights_layer_md
                = query_md(const_bpd, DNNL_ARG_DIFF_WEIGHTS_LAYER);
        const auto &diff_weights_iter_md
                = query_md(const_bpd, DNNL_ARG_DIFF_WEIGHTS_ITER);
        const auto &diff_weights_peephole_md
                = query_md(const_bpd, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE);
        const auto &diff_weights_projection_md
                = query_md(const_bpd, DNNL_ARG_DIFF_WEIGHTS_PROJECTION);
        const auto &diff_bias_md = query_md(const_bpd, DNNL_ARG_DIFF_BIAS);
        const auto &diff_dst_layer_md
                = query_md(const_bpd, DNNL_ARG_DIFF_DST_LAYER);
        const auto &diff_dst_iter_md
                = query_md(const_bpd, DNNL_ARG_DIFF_DST_ITER);
        const auto &diff_dst_iter_c_md
                = query_md(const_bpd, DNNL_ARG_DIFF_DST_ITER_C);
        const auto &bwd_scratchpad_md
                = query_md(const_bpd, DNNL_ARG_SCRATCHPAD);

        bwd_weights_layer_dt = dnn_mem_t(bwd_weights_layer_md, test_engine);
        bwd_weights_iter_dt = dnn_mem_t(bwd_weights_iter_md, test_engine);
        bwd_weights_projection_dt
                = dnn_mem_t(bwd_weights_projection_md, test_engine);
        diff_src_layer_dt = dnn_mem_t(diff_src_layer_md, test_engine);
        diff_src_layer_attention_dt
                = dnn_mem_t(diff_src_layer_attention_md, test_engine);
        diff_src_iter_dt = dnn_mem_t(diff_src_iter_md, test_engine);
        diff_src_iter_c_dt = dnn_mem_t(diff_src_iter_c_md, test_engine);
        diff_weights_layer_dt = dnn_mem_t(diff_weights_layer_md, test_engine);
        diff_weights_iter_dt = dnn_mem_t(diff_weights_iter_md, test_engine);
        diff_weights_peephole_dt
                = dnn_mem_t(diff_weights_peephole_md, test_engine);
        diff_weights_projection_dt
                = dnn_mem_t(diff_weights_projection_md, test_engine);
        diff_bias_dt = dnn_mem_t(diff_bias_md, test_engine);
        diff_dst_layer_dt = dnn_mem_t(diff_dst_layer_md, test_engine);
        diff_dst_iter_dt = dnn_mem_t(diff_dst_iter_md, test_engine);
        diff_dst_iter_c_dt = dnn_mem_t(diff_dst_iter_c_md, test_engine);
        scratchpad_dt = dnn_mem_t(bwd_scratchpad_md, test_engine);

        dnn_mem_t diff_src_layer_fp(
                diff_src_layer_md, dnnl_f32, tag::abx /*tnc*/, ref_engine);
        dnn_mem_t diff_src_layer_attention_fp(diff_src_layer_attention_md,
                dnnl_f32, tag::abx /*tnc*/, ref_engine);
        dnn_mem_t diff_src_iter_fp(
                diff_src_iter_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine);
        dnn_mem_t diff_src_iter_c_fp(
                diff_src_iter_c_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine);
        dnn_mem_t diff_weights_layer_fp(diff_weights_layer_md, dnnl_f32,
                tag::abx /*ldigo*/, ref_engine);
        dnn_mem_t diff_weights_iter_fp(
                diff_weights_iter_md, dnnl_f32, tag::abx /*ldigo*/, ref_engine);
        dnn_mem_t diff_weights_peephole_fp(diff_weights_peephole_md, dnnl_f32,
                tag::abx /*ldgo*/, ref_engine);
        dnn_mem_t diff_weights_projection_fp(diff_weights_projection_md,
                dnnl_f32, tag::abx /*ldio*/, ref_engine);
        dnn_mem_t diff_bias_fp(
                diff_bias_md, dnnl_f32, tag::abx /*ldgo*/, ref_engine);
        dnn_mem_t diff_dst_layer_fp(
                diff_dst_layer_md, dnnl_f32, tag::abx /*tnc*/, ref_engine);
        dnn_mem_t diff_dst_iter_fp(
                diff_dst_iter_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine);
        dnn_mem_t diff_dst_iter_c_fp(
                diff_dst_iter_c_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine);

        SAFE(bwd_weights_iter_dt.reorder(weights_iter_dt), WARN);
        SAFE(bwd_weights_layer_dt.reorder(weights_layer_dt), WARN);
        if (prb.is_lstm_projection())
            SAFE(bwd_weights_projection_dt.reorder(weights_projection_dt),
                    WARN);
        SAFE(fill_activation(
                     prb, DIFF_SRC_LAYER, diff_src_layer_dt, diff_src_layer_fp),
                WARN);
        if (prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU)
            SAFE(fill_activation(prb, DIFF_AUGRU_ATTENTION,
                         diff_src_layer_attention_dt,
                         diff_src_layer_attention_fp),
                    WARN);
        SAFE(fill_activation(
                     prb, DIFF_SRC_ITER, diff_src_iter_dt, diff_src_iter_fp),
                WARN);
        if (prb.alg == VANILLA_LSTM)
            SAFE(fill_memory(prb, DIFF_SRC_ITER_C, diff_src_iter_c_dt,
                         diff_src_iter_c_fp),
                    WARN);
        SAFE(fill_weights(prb, DIFF_WEIGHTS_LAYER, diff_weights_layer_dt,
                     diff_weights_layer_fp),
                WARN);
        SAFE(fill_weights(prb, DIFF_WEIGHTS_ITER, diff_weights_iter_dt,
                     diff_weights_iter_fp),
                WARN);
        SAFE(fill_memory(prb, DIFF_WEIGHTS_PEEPHOLE, diff_weights_peephole_dt,
                     diff_weights_peephole_fp),
                WARN);
        SAFE(fill_memory(prb, DIFF_WEIGHTS_PROJECTION,
                     diff_weights_projection_dt, diff_weights_projection_fp),
                WARN);
        SAFE(fill_bias(prb, DIFF_BIAS, diff_bias_dt, diff_bias_fp), WARN);
        SAFE(fill_activation(
                     prb, DIFF_DST_LAYER, diff_dst_layer_dt, diff_dst_layer_fp),
                WARN);
        SAFE(fill_activation(
                     prb, DIFF_DST_ITER, diff_dst_iter_dt, diff_dst_iter_fp),
                WARN);
        if (prb.alg == VANILLA_LSTM)
            SAFE(fill_memory(prb, DIFF_DST_ITER_C, diff_dst_iter_c_dt,
                         diff_dst_iter_c_fp),
                    WARN);

        args.clear();
        args.set(DNNL_ARG_SRC_LAYER, src_layer_dt);
        args.set(DNNL_ARG_AUGRU_ATTENTION, src_layer_attention_dt);
        args.set(DNNL_ARG_SRC_ITER, src_iter_dt);
        args.set(DNNL_ARG_SRC_ITER_C, src_iter_c_dt);
        args.set(DNNL_ARG_WEIGHTS_LAYER, bwd_weights_layer_dt);
        args.set(DNNL_ARG_WEIGHTS_ITER, bwd_weights_iter_dt);
        args.set(DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_dt);
        args.set(DNNL_ARG_WEIGHTS_PROJECTION, bwd_weights_projection_dt);
        args.set(DNNL_ARG_BIAS, bias_dt);
        args.set(DNNL_ARG_DST_LAYER, dst_layer_dt);
        args.set(DNNL_ARG_DST_ITER, dst_iter_dt);
        args.set(DNNL_ARG_DST_ITER_C, dst_iter_c_dt);
        args.set(DNNL_ARG_DIFF_DST_LAYER, diff_dst_layer_dt);
        args.set(DNNL_ARG_DIFF_DST_ITER, diff_dst_iter_dt);
        args.set(DNNL_ARG_DIFF_DST_ITER_C, diff_dst_iter_c_dt);
        args.set(DNNL_ARG_WORKSPACE, workspace_dt);
        args.set(DNNL_ARG_DIFF_SRC_LAYER, diff_src_layer_dt);
        args.set(DNNL_ARG_DIFF_AUGRU_ATTENTION, diff_src_layer_attention_dt);
        args.set(DNNL_ARG_DIFF_SRC_ITER, diff_src_iter_dt);
        args.set(DNNL_ARG_DIFF_SRC_ITER_C, diff_src_iter_c_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE, diff_weights_peephole_dt);
        args.set(DNNL_ARG_DIFF_WEIGHTS_PROJECTION, diff_weights_projection_dt);
        args.set(DNNL_ARG_DIFF_BIAS, diff_bias_dt);
        args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt);

        SAFE(execute_and_wait(prim, args, res), WARN);

        if (is_bench_mode(CORR)) {
            ref_args.set(DNNL_ARG_SRC_LAYER, src_layer_fp);
            ref_args.set(DNNL_ARG_AUGRU_ATTENTION, src_layer_attention_fp);
            ref_args.set(DNNL_ARG_SRC_ITER, src_iter_fp);
            ref_args.set(DNNL_ARG_SRC_ITER_C, src_iter_c_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_LAYER, weights_layer_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_ITER, weights_iter_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_fp);
            ref_args.set(DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_fp);
            ref_args.set(DNNL_ARG_BIAS, bias_fp);
            ref_args.set(DNNL_ARG_DST_LAYER, dst_layer_fp);
            ref_args.set(DNNL_ARG_DST_ITER, dst_iter_fp);
            ref_args.set(DNNL_ARG_DST_ITER_C, dst_iter_c_fp);
            ref_args.set(DNNL_ARG_DIFF_DST_LAYER, diff_dst_layer_fp);
            ref_args.set(DNNL_ARG_DIFF_DST_ITER, diff_dst_iter_fp);
            ref_args.set(DNNL_ARG_DIFF_DST_ITER_C, diff_dst_iter_c_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC_LAYER, diff_src_layer_fp);
            ref_args.set(
                    DNNL_ARG_DIFF_AUGRU_ATTENTION, diff_src_layer_attention_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC_ITER, diff_src_iter_fp);
            ref_args.set(DNNL_ARG_DIFF_SRC_ITER_C, diff_src_iter_c_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_fp);
            ref_args.set(
                    DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE, diff_weights_peephole_fp);
            ref_args.set(DNNL_ARG_DIFF_WEIGHTS_PROJECTION,
                    diff_weights_projection_fp);
            ref_args.set(DNNL_ARG_DIFF_BIAS, diff_bias_fp);

            std::vector<data_kind_t> kinds {data_kind_t::DST,
                    data_kind_t::DST_ITER, data_kind_t::SRC,
                    data_kind_t::SRC_ITER, data_kind_t::WEI,
                    data_kind_t::WEI_ITER, data_kind_t::BIA};
            if (prb.alg == VANILLA_LSTM) {
                kinds.push_back(data_kind_t::DST_ITER_C);
                kinds.push_back(data_kind_t::SRC_ITER_C);
            }
            if (prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU)
                kinds.push_back(data_kind_t::AUGRU_ATTENTION);
            if (prb.is_lstm_peephole())
                kinds.push_back(data_kind_t::WEI_PEEPHOLE);
            if (prb.is_lstm_projection())
                kinds.push_back(data_kind_t::WEI_PROJECTION);

            check_correctness(&prb, kinds, args, ref_args, setup_cmp, res);
        }
    }

    return measure_perf(prb.ctx_exe, res, prim, args);
}

} // namespace rnn