1 | /******************************************************************************* |
2 | * Copyright 2018-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <float.h> |
18 | #include <math.h> |
19 | #include <random> |
20 | #include <stdio.h> |
21 | #include <stdlib.h> |
22 | #include <type_traits> |
23 | |
24 | #include "oneapi/dnnl/dnnl.h" |
25 | |
26 | #include "tests/test_isa_common.hpp" |
27 | #include "utils/parallel.hpp" |
28 | |
29 | #include "dnnl_common.hpp" |
30 | #include "dnnl_memory.hpp" |
31 | |
32 | #include "rnn/rnn.hpp" |
33 | #include "rnn/rnn_aux.hpp" |
34 | |
35 | // Using hidden attr API for testing RNN |
36 | dnnl_status_t dnnl_primitive_attr_set_rnn_tparams(dnnl_primitive_attr_t attr, |
37 | bool mode, dnnl_dim_t ngates, const float *scales, float cscale); |
38 | |
39 | namespace { |
40 | |
41 | // In order to have consistent filling across compilers and operating systems, |
42 | // we implement the equivalent of std::normal_distribution using the so-called |
43 | // Marsaglia polar method. |
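// The method draws (x, y) uniformly from [-1, 1) x [-1, 1) until
// r2 = x^2 + y^2 lands in (0, 1]; then x * sqrt(-2 * ln(r2) / r2) and
// y * sqrt(-2 * ln(r2) / r2) are two independent normal samples (scaled by
// stddev and shifted by mean below). One of them is returned right away, the
// other is cached in `odd_` and returned by the next call, so each rejection
// loop serves two requests.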
44 | template <typename T> |
45 | class normal_distribution_t { |
46 | public: |
47 | normal_distribution_t(T mean, T stddev) |
48 | : gen(-1.f, 1.f) |
49 | , is_odd_(false) |
50 | , odd_(1.f) |
51 | , mean_(mean) |
52 | , stddev_(stddev) { |
53 | static_assert(std::is_floating_point<T>::value, |
54 | "T must be a floating point type." ); |
55 | } |
56 | template <typename URNG> |
57 | T operator()(URNG &g) { |
58 | T r, r2, x, y; |
59 | if (is_odd_) { |
60 | is_odd_ = false; |
61 | return odd_; |
62 | } |
63 | is_odd_ = true; |
64 | do { |
65 | x = gen(g); // x E [-1, 1) |
66 | y = gen(g); // y E [-1, 1) |
67 | r2 = x * x + y * y; |
68 | } while (0.f == r2 || 1.f < r2); // r2 E (0, 1] |
69 | r = stddev_ * std::sqrt(-2.f * std::log(r2) / r2); |
70 | x = mean_ + x * r; |
71 | y = mean_ + y * r; |
72 | odd_ = x; |
73 | return y; |
74 | } |
75 | |
76 | private: |
77 | std::uniform_real_distribution<T> gen; |
78 | bool is_odd_; |
79 | T odd_; |
80 | const T mean_; |
81 | const T stddev_; |
82 | }; |
83 | |
84 | } // namespace |
85 | |
86 | namespace rnn { |
87 | |
88 | dnnl_primitive_attr_t create_dnnl_rnn_attr(const prb_t &prb) { |
89 | dnnl_primitive_attr_t dnnl_attr = nullptr; |
90 | DNN_SAFE_V(dnnl_primitive_attr_create(&dnnl_attr)); |
91 | |
92 | if (prb.skip_nonlinear) |
93 | DNN_SAFE_V(dnnl_primitive_attr_set_rnn_tparams(dnnl_attr, true, |
94 | prb.n_gates(), prb.linear_scales, prb.linear_cscale)); |
95 | |
96 | DNN_SAFE_V(dnnl_primitive_attr_set_rnn_weights_qparams( |
97 | dnnl_attr, prb.wei_nscales, prb.wei_scales_mask, prb.wei_scales)); |
98 | |
99 | if (prb.is_lstm_projection() && prb.is_int8()) |
100 | DNN_SAFE_V(dnnl_primitive_attr_set_rnn_weights_projection_qparams( |
101 | dnnl_attr, prb.wei_proj_nscales, prb.wei_proj_scales_mask, |
102 | prb.wei_proj_scales)); |
103 | |
104 | if (prb.data_scale != 1.0 || prb.data_shift != 0.0) |
105 | DNN_SAFE_V(dnnl_primitive_attr_set_rnn_data_qparams( |
106 | dnnl_attr, prb.data_scale, prb.data_shift)); |
107 | |
108 | DNN_SAFE_V(dnnl_primitive_attr_set_scratchpad_mode( |
109 | dnnl_attr, prb.attr.scratchpad_mode)); |
110 | |
111 | DNN_SAFE_V(dnnl_primitive_attr_set_fpmath_mode( |
112 | dnnl_attr, prb.attr.fpmath_mode)); |
113 | |
114 | return dnnl_attr; |
115 | } |
116 | |
117 | int check_s8s8_reorder(const prb_t &prb, rnn_data_kind_t kind, |
118 | const dnn_mem_t &mem_dt, const dnn_mem_t &mem_fp) { |
119 | // TODO: enable for all cpu_kind when supported |
120 | if (is_gpu()) return OK; |
121 | |
122 | #if DNNL_CPU_RUNTIME == DNNL_RUNTIME_DPCPP |
123 | // DPC++ does not provide a simple way to access the underlying |
124 | // buffer alignment. |
125 | return OK; |
126 | #endif |
127 | |
128 | // In the main test, we fill buffers with f32 and reorder to s8 |
129 | // with quantization. |
130 | |
    // The end goal is to check that the reorder
132 | // f32_plain_nonquantized --reorder--> s8_packed_quantized |
133 | // gives the same output as the sequence |
134 | // f32_plain --quant--> s8_plain_quantized --reorder--> s8_packed_quantized |
135 | |
136 | // Here, |
137 | // 1. we quantize the f32 plain memory to s8 plain memory, |
138 | // 2. we reorder the s8 plain to s8 packed (queried from rnn primitive desc) |
    // 3. we check that the two memories are bitwise identical.
140 | |
    // Note: the two s8 packed memories need to have the same
    // alignment as the packed buffer is aligned internally and the offset
    // is kept in the metadata.
    // This works fine with dnn_mem_t as it is aligned to a 2MB large-page
    // boundary.
145 | dnn_mem_t mem_s8_src(mem_fp.md_, dnnl_s8, tag::abx, get_cpu_engine()); |
146 | dnn_mem_t mem_s8_dst(mem_dt.md_, get_test_engine()); |
147 | |
148 | /* 1. compute f32_plain --quant--> s8_plain_quantized */ |
149 | /* Do fixed partitioning to have same filling for any number of threads */ |
150 | auto nelems = mem_fp.nelems(); |
151 | const int64_t n_chunks = 16; |
152 | const int64_t chunk_size = div_up(nelems, n_chunks); |
153 | const auto quantize = [&](const float *scales, int nscales, float shift, |
154 | int idx_chunk) { |
155 | int64_t idx_start = idx_chunk * chunk_size; |
156 | int64_t idx_end = MIN2(idx_start + chunk_size, nelems); |
157 | for (int64_t idx = idx_start; idx < idx_end; ++idx) { |
158 | const float current_scale = scales[idx % nscales]; |
159 | float val_f32 = mem_fp.get_elem(idx); |
160 | int8_t val_s8 = saturate_and_round<dnnl_s8>( |
161 | val_f32 * current_scale + shift); |
162 | mem_s8_src.set_elem(idx, val_s8); |
163 | } |
164 | }; |
165 | switch (kind) { |
166 | case WEIGHTS_LAYER: |
167 | case WEIGHTS_ITER: |
168 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx) { |
169 | quantize(prb.wei_scales, prb.wei_nscales, 0, idx); |
170 | }); |
171 | break; |
172 | case WEIGHTS_PROJECTION: |
173 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx) { |
174 | quantize(prb.wei_proj_scales, prb.wei_proj_nscales, 0, idx); |
175 | }); |
176 | break; |
177 | case SRC_LAYER: |
178 | case SRC_ITER: |
179 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx) { |
180 | quantize(&(prb.data_scale), 1, prb.data_shift, idx); |
181 | }); |
182 | break; |
        default: assert(!"unsupported kind");
184 | } |
185 | |
186 | /* 2. compute s8_plain_quantized --reorder--> s8_packed_quantized */ |
187 | mem_s8_dst.reorder(mem_s8_src); |
188 | |
    /* 3. check that the two memories are bitwise identical */
190 | auto sz = mem_dt.size(); |
191 | uint8_t *s8_dst_handle = (uint8_t *)mem_s8_dst; |
192 | uint8_t *mem_dt_handle = (uint8_t *)mem_dt; |
193 | |
194 | // check that both have the same size |
195 | assert(mem_dt.size() == mem_s8_dst.size()); |
196 | // check that both have the same alignment modulo align_data in gemm_pack_storage.hpp |
197 | assert((uint64_t)s8_dst_handle % 0x1000 |
198 | == (uint64_t)mem_dt_handle % 0x1000); |
199 | for (size_t i = 0; i < sz; ++i) { |
200 | if (s8_dst_handle[i] != mem_dt_handle[i]) { return FAIL; } |
201 | } |
202 | |
203 | return OK; |
204 | } |
205 | |
206 | int fill_memory(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt, |
207 | dnn_mem_t &mem_fp, dnnl_data_type_t dt, float mean, float stddev, |
208 | float min, float max, const_dnnl_primitive_attr_t attr = nullptr, |
209 | bool flip_sign = false) { |
210 | const auto nelems = mem_dt.nelems(); |
211 | if (nelems == 0) return OK; |
212 | assert(mem_dt.nelems() == mem_fp.nelems()); |
213 | |
214 | // For non-int8 RNN the data is filled according to cfg directly. |
215 | // However, for int8 RNN we have slightly obscure logic, at least for now: |
216 | // 1. cfg describes the quantized data; |
    // 2. We first fill f32 de-quantized data by inverse-applying the scale
    //    and shift to the data generated by the cfg distribution;
219 | // 3. We reorder the data for the oneDNN RNN primitive |
220 | // 4. Q10n of the data for reference benchdnn RNN: |
221 | // 4.a. If the tensor is weights -- q10n it here; |
222 | // 4.b. If the tensor is data -- reference benchdnn RNN will quantize it. |
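    //
    // For example (illustrative numbers only): with scale = 2 and shift = 64,
    // bullet 2 stores a cfg-generated quantized value q in f32 as
    // (q - 64) / 2, so that the reorder with the RNN attributes quantizes it
    // back to approximately q.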
223 | |
224 | // pass rnn attributes to f32 -> int8 reorders only |
225 | const_dnnl_primitive_attr_t reorder_attr = nullptr; |
226 | if (prb.is_int8() && (dt != dnnl_f32)) reorder_attr = attr; |
227 | float default_scales[1] = {1.0f}; |
228 | float default_shift = 0.0f; |
229 | |
230 | /* Do fixed partitioning to have same filling for any number of threads */ |
231 | const int64_t n_chunks = 16; |
232 | const int64_t chunk_size = div_up(nelems, n_chunks); |
233 | |
    // 2. We first fill f32 de-quantized data by inverse-applying the scale
    //    and shift to the data generated by the cfg distribution;
236 | auto fill_chunk = [&](const float *scales, int nscales, float shift, |
237 | int idx_chunk) { |
238 | int64_t idx_start = idx_chunk * chunk_size; |
239 | int64_t idx_end = MIN2(idx_start + chunk_size, nelems); |
240 | std::minstd_rand msr; |
241 | msr.seed(idx_start + kind); |
242 | normal_distribution_t<float> gen(mean, stddev); |
243 | for (int64_t idx = idx_start; idx < idx_end; ++idx) { |
244 | float val = round_to_nearest_representable(dt, gen(msr)); |
245 | val = MAX2(MIN2(val, max), min); |
            val = (val - shift)
                    / scales[idx % nscales]; // affects only the int8 case
248 | |
            // Relevant only for Vanilla RNN with ReLU testing: flip the sign
            // of the inputs for `mb` == 0 to exercise the ReLU part.
251 | if (flip_sign) { |
252 | assert(kind == SRC_LAYER || kind == SRC_ITER); |
253 | auto ld = kind == SRC_LAYER ? prb.slc : prb.sic; |
254 | if (idx % (prb.mb * ld) < ld) val *= -1; |
255 | } |
256 | mem_fp.set_elem(idx, val); |
257 | } |
258 | }; |
259 | switch (kind) { |
260 | case WEIGHTS_PROJECTION: |
261 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx) { |
262 | fill_chunk( |
263 | prb.wei_proj_scales, prb.wei_proj_nscales, 0.0f, idx); |
264 | }); |
265 | break; |
266 | case WEIGHTS_LAYER: |
267 | case WEIGHTS_ITER: |
268 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx) { |
269 | fill_chunk(prb.wei_scales, prb.wei_nscales, 0.0f, idx); |
270 | }); |
271 | break; |
272 | case SRC_LAYER: |
273 | case SRC_ITER: |
274 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx) { |
275 | fill_chunk(&(prb.data_scale), 1, prb.data_shift, idx); |
276 | }); |
277 | break; |
278 | default: // we do no scale/shift |
279 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx) { |
280 | fill_chunk(default_scales, 1, default_shift, idx); |
281 | }); |
282 | } |
283 | |
    // 3. We reorder the data for the oneDNN RNN primitive.
285 | mem_dt.reorder(mem_fp, reorder_attr); |
286 | if ((reorder_attr != nullptr) && (dt == dnnl_s8)) |
287 | if (check_s8s8_reorder(prb, kind, mem_dt, mem_fp) != OK) return FAIL; |
288 | |
289 | // Bullet 4.a holds: quantize weights for int8 benchdnn reference RNN |
290 | if (prb.is_int8()) { |
291 | auto quantize_chunk |
292 | = [&](const float *scales, int nscales, int idx_chunk) { |
293 | int64_t idx_start = idx_chunk * chunk_size; |
294 | int64_t idx_end = MIN2(idx_start + chunk_size, nelems); |
295 | for (int64_t idx = idx_start; idx < idx_end; ++idx) { |
296 | float current_scale = scales[idx % nscales]; |
297 | float val = ((float *)mem_fp)[idx]; |
298 | val = round(current_scale * val); |
299 | mem_fp.set_elem(idx, MAX2(MIN2(val, max), min)); |
300 | } |
301 | }; |
302 | switch (kind) { |
303 | case WEIGHTS_LAYER: |
304 | case WEIGHTS_ITER: |
305 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx) { |
306 | quantize_chunk(prb.wei_scales, prb.wei_nscales, idx); |
307 | }); |
308 | break; |
309 | case WEIGHTS_PROJECTION: |
310 | benchdnn_parallel_nd(n_chunks, [&](int64_t idx) { |
311 | quantize_chunk( |
312 | prb.wei_proj_scales, prb.wei_proj_nscales, idx); |
313 | }); |
314 | break; |
315 | default: // Nothing to do |
316 | break; |
317 | } |
318 | } |
319 | |
320 | return OK; |
321 | } |
322 | |
323 | int fill_memory(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt, |
324 | dnn_mem_t &mem_fp, const_dnnl_primitive_attr_t attr = nullptr, |
325 | bool flip_sign = false) { |
326 | const dt_conf_t::entry_t &c = prb.cfg[kind]; |
327 | return fill_memory(prb, kind, mem_dt, mem_fp, c.dt, c.f_mean, c.f_stddev, |
328 | c.f_min, c.f_max, attr, flip_sign); |
329 | } |
330 | |
331 | int fill_activation(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt, |
332 | dnn_mem_t &mem_fp, const_dnnl_primitive_attr_t attr = nullptr) { |
    // In general, we mostly want to use positive values to avoid
    // cancellation during the computation. The only case
335 | // where we actually want negative values to appear is for 1 layer |
336 | // 1 iteration tests using vanilla_rnn and non-zero alpha. In that |
337 | // case, we want to check that alpha is applied accordingly. Here |
338 | // skip_nonlinear is checked as we want to test relu with non-zero |
339 | // alpha, and not the linear function that would replace it under |
340 | // skip_nonlinear=true. |
341 | bool flip_sign = prb.skip_nonlinear == false && prb.alg == VANILLA_RNN |
342 | && prb.activation == RELU |
343 | && (kind == SRC_LAYER || kind == SRC_ITER); |
344 | return fill_memory(prb, kind, mem_dt, mem_fp, attr, flip_sign); |
345 | } |
346 | |
347 | int fill_src_iter_c(const prb_t &prb, dnn_mem_t &mem_dt, dnn_mem_t &mem_fp, |
348 | const_dnnl_primitive_attr_t attr = nullptr) { |
349 | const bool special_case = prb.prop == dnnl_backward && prb.skip_nonlinear; |
350 | if (!special_case) |
351 | return fill_memory(prb, SRC_ITER_C, mem_dt, mem_fp, attr); |
352 | |
    // The scaling factors in tparams when testing backward are common for
    // the forward and backward passes, and are computed as 1 over the maximum
    // of the accumulation chain:
356 | // - ~n_gates on FWD |
357 | // - ~dhc * n_gates on BWD_D |
358 | // - ~mb * n_gates on BWD_W |
359 | // |
    // This makes the tparams relatively small for the forward pass (compared
    // to the forward pass when we test forward only). This, in turn, makes
362 | // src_iter_c converge relatively fast to the value ~ i_gate * c_gate, |
363 | // which is (typically) way smaller than the original distribution for |
364 | // src_iter_c. |
365 | // |
366 | // TODO: use different tparams for forward and |
367 | // backward passes when testing BWD_DW. |
368 | // |
369 | // The problem appears on backward pass. Consider diff_f_gate that |
370 | // contributes to backward weights when batch or number of iterations |
371 | // is big: |
372 | // diff_f_gate[iter] = src_iter_c[iter] * diff_dst[iter]. |
373 | // diff_weights += ~diff_f_gate[iter]. |
374 | // |
    // Assume that diff_dst[iter] is always about the same for every iter.
376 | // Since src_iter_c[0] >> src_iter_c[iter] for iter > 0, this makes the |
377 | // diff_weight be highly dependent on the order of accumulating the |
378 | // diff_f_gate[iter]. |
379 | // |
380 | // Originally we had something like: |
381 | // diff_weights = v + v * 10^-5 + ... + v * 10^-5 (n_iter * MB summands). |
382 | // Depending on the order of summation the difference might exceed the |
383 | // typical bound approximation: coefficient * log(number_of_summands). |
384 | // |
    // Anyway, the algorithm below tries to put the first src_iter_c[iter = 0]
386 | // in the same ballpark as all the subsequent src_iter_c[iter > 0]. |
387 | // |
388 | // The estimation is based on the following rough assumptions: |
389 | // src_iter_c[iter+1] = f_gate * src_iter_c[iter] + i_gate * c_gate |
390 | // ~= f_gate * small_value + i_gate * c_gate |
391 | // ~= i_gate * c_gate. |
392 | // i_gate ~= tparams[i_gate] * ( |
393 | // 1 / ngates * mean_src_layer + |
394 | // 1 / ngates * mean_src_iter + |
395 | // mean_bias); |
396 | // |
397 | // Same for c_gate. |
398 | // The (1 / ngates) factor is taken from fill_weights(). |
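    //
    // Illustrative example (made-up numbers): with n_gates() = 4, f_mean = 1
    // for SRC_LAYER, SRC_ITER and BIAS, and both linear scales equal to s,
    // the formulas below give expect_gemm_output = 0.25 + 0.25 + 1 = 1.5 and
    // expect_src_iter_c_mean = (1.5 * s)^2.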
399 | |
400 | float expect_gemm_output = (1.f / prb.n_gates()) * prb.cfg[SRC_LAYER].f_mean |
401 | + (1.f / prb.n_gates()) * prb.cfg[SRC_ITER].f_mean |
402 | + prb.cfg[BIAS].f_mean; |
403 | float expect_i_gate = (float)prb.linear_scales[LSTM_I] * expect_gemm_output; |
404 | float expect_c_gate = (float)prb.linear_scales[LSTM_C] * expect_gemm_output; |
405 | float expect_src_iter_c_mean = expect_i_gate * expect_c_gate; |
406 | |
407 | float adjust_factor = 1; |
408 | |
409 | const bool need_adjust = expect_src_iter_c_mean < prb.cfg[SRC_ITER_C].f_mean |
410 | && prb.cfg[SRC_ITER_C].f_mean != 0; |
411 | if (need_adjust) |
412 | adjust_factor = expect_src_iter_c_mean / prb.cfg[SRC_ITER_C].f_mean; |
413 | |
414 | const dt_conf_t::entry_t &c = prb.cfg[SRC_ITER_C]; |
415 | return fill_memory(prb, SRC_ITER_C, mem_dt, mem_fp, c.dt, |
416 | c.f_mean * adjust_factor, c.f_stddev * adjust_factor, |
417 | c.f_min * adjust_factor, c.f_max * adjust_factor, attr); |
418 | } |
419 | |
420 | int fill_weights(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt, |
421 | dnn_mem_t &mem_fp, const_dnnl_primitive_attr_t attr = nullptr) { |
422 | const auto nelems = mem_dt.nelems(); |
423 | if (nelems == 0) return OK; |
424 | const dt_conf_t::entry_t &c = prb.cfg[kind]; |
425 | |
426 | assert(kind == WEIGHTS_PROJECTION ? mem_fp.ndims() == 4 |
427 | : mem_fp.ndims() == 5); |
428 | |
429 | const auto &dims = mem_fp.dims(); |
430 | const int64_t L = dims[0]; |
431 | const int64_t D = dims[1]; |
432 | const int64_t I = dims[2]; |
433 | const int64_t G = (kind == WEIGHTS_PROJECTION) ? 1 : dims[3]; |
434 | const int64_t O = (kind == WEIGHTS_PROJECTION) ? dims[3] : dims[4]; |
435 | |
436 | float gate_factor |
437 | = (kind == WEIGHTS_PROJECTION) ? 1.f : 1.f / prb.n_gates(); |
438 | |
439 | const auto tag = tag::abx; |
440 | dnn_mem_t mem_pure_fp(mem_dt.md_, dnnl_f32, tag, get_cpu_engine()); |
441 | |
442 | for (int64_t i = 0; i < mem_fp.nelems(); i++) { |
443 | mem_fp.set_elem(i, 0); |
444 | mem_pure_fp.set_elem(i, 0); |
445 | } |
446 | |
447 | auto scales = (kind == WEIGHTS_PROJECTION) ? prb.wei_proj_scales |
448 | : prb.wei_scales; |
449 | auto n_scales = (kind == WEIGHTS_PROJECTION) ? prb.wei_proj_nscales |
450 | : prb.wei_nscales; |
451 | |
    // Fill weights sparsely to avoid accumulation errors. Two memories are
    // used: one is quantized for the reference, the other feeds the reorder.
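    // For each (l, d, g, o) combination a single input channel i_off is picked
    // via a simple hash of the indices, so along the input-channel dimension
    // there is exactly one non-zero value equal to gate_factor (additionally
    // scaled for int8 in the quantized copy).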
454 | for_(int64_t l = 0; l < L; l++) |
455 | for_(int64_t d = 0; d < D; d++) |
456 | for_(int64_t g = 0; g < G; g++) |
457 | for (int64_t o = 0; o < O; o++) { |
458 | int64_t i_off = ((19 * o + 7 * g + 11 * d + 13 * l) % I); |
459 | int64_t off = (((l * D + d) * I + i_off) * G + g) * O + o; |
460 | float val = gate_factor; |
461 | mem_pure_fp.set_elem(off, val); |
462 | if (prb.is_int8()) val *= scales[off % n_scales]; |
463 | mem_fp.set_elem(off, round_to_nearest_representable(c.dt, val)); |
464 | } |
465 | |
466 | // Pass rnn attributes to f32 -> s8 reorders only |
467 | const_dnnl_primitive_attr_t reorder_attr = nullptr; |
468 | if (prb.is_int8()) reorder_attr = attr; |
469 | mem_dt.reorder(mem_pure_fp, reorder_attr); |
470 | |
471 | // Test that s8 -> s8 reorder works correctly |
472 | if ((reorder_attr != nullptr) && (c.dt == dnnl_s8)) |
473 | return check_s8s8_reorder(prb, kind, mem_dt, mem_pure_fp); |
474 | return OK; |
475 | } |
476 | |
477 | int fill_bias(const prb_t &prb, rnn_data_kind_t kind, dnn_mem_t &mem_dt, |
478 | dnn_mem_t &mem_fp) { |
    // To reduce the likelihood of cancellation happening in bwd by bias
    // (especially for GRU), we want diff_bias to be sparse.
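    // flip_coin(idx, 0.05f) below keeps roughly 5% of the entries non-zero;
    // the remaining values are multiplied by zero.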
481 | const auto &dims = mem_fp.dims(); |
482 | auto L = dims[0]; |
483 | auto D = dims[1]; |
484 | auto G = dims[2]; |
485 | auto O = dims[3]; |
486 | |
487 | std::minstd_rand msr; |
488 | normal_distribution_t<float> gen( |
489 | prb.cfg[kind].f_mean, prb.cfg[kind].f_stddev); |
490 | msr.seed(kind); |
491 | |
492 | for_(int64_t l = 0; l < L; l++) |
493 | for_(int64_t d = 0; d < D; d++) |
494 | for_(int64_t g = 0; g < G; g++) |
495 | for (int64_t o = 0; o < O; o++) { |
496 | auto idx = l * D * G * O + d * G * O + g * O + o; |
497 | auto val = round_to_nearest_representable( |
498 | prb.cfg[kind].dt, gen(msr) * flip_coin(idx, 0.05f)); |
499 | mem_fp.set_elem(idx, val); |
500 | } |
501 | mem_dt.reorder(mem_fp); |
502 | return OK; |
503 | } |
504 | |
505 | void compute_ref( |
506 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
507 | const prb_t &prb_ = *prb; |
508 | if (prb_.prop != dnnl_backward) |
509 | compute_ref_fwd(prb_, args); |
510 | else |
511 | compute_ref_bwd(prb_, args); |
512 | } |
513 | |
514 | dnnl_status_t init_pd(init_pd_args_t<prb_t> &init_pd_args) { |
515 | const prb_t &prb = *init_pd_args.prb; |
516 | const dir_t dir = init_pd_args.dir; |
517 | |
518 | dnnl_prop_kind_t fwd_prop = dnnl_prop_kind_undef; |
519 | switch (prb.prop) { |
520 | case dnnl_forward_training: fwd_prop = dnnl_forward_training; break; |
521 | case dnnl_forward_inference: fwd_prop = dnnl_forward_inference; break; |
522 | // If we are testing backward, we have to run forward training first |
523 | // in order to generate a valid workspace. |
524 | case dnnl_backward: fwd_prop = dnnl_forward_training; break; |
525 | default: DNN_SAFE_STATUS(dnnl_invalid_arguments); |
526 | } |
527 | |
528 | const bool is_gru_lbr = prb.alg == LBR_GRU || prb.alg == LBR_AUGRU; |
    // Enable testing with non-trivial strides.
530 | const int the_stride = prb.trivial_strides ? 0 : 1; |
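    // `the_stride` is later added to selected strides so that the tensors are
    // no longer dense; 0 keeps the default (trivial) strides.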
531 | |
    // Number of directions: bidirectional = 2, otherwise 1.
    // Number of states (s): 2 for LSTM (h and c), 1 for all other cell kinds.
533 | dnnl_dims_t weights_layer_dims |
534 | = {prb.n_layer, prb.n_dir(), prb.slc, prb.n_gates(), prb.dhc}; |
535 | dnnl_dims_t weights_iter_dims |
536 | = {prb.n_layer, prb.n_dir(), prb.sic, prb.n_gates(), prb.dhc}; |
537 | dnnl_dims_t attention_dims = {prb.n_iter, prb.mb, 1}; |
538 | dnnl_dims_t weights_peephole_dims = {prb.n_layer, prb.n_dir(), 3, prb.dhc}; |
539 | dnnl_dims_t weights_projection_dims |
540 | = {prb.n_layer, prb.n_dir(), prb.dhc, prb.dic}; |
541 | dnnl_dims_t bias_dims |
542 | = {prb.n_layer, prb.n_dir(), prb.n_gates() + is_gru_lbr, prb.dhc}; |
543 | // dnnl_tnc |
544 | dnnl_dims_t dst_layer_dims = {prb.n_iter, prb.mb, prb.dlc(PRIMITIVE)}; |
545 | |
546 | dnnl_dims_t src_layer_dims = {prb.n_iter, prb.mb, prb.slc}; |
547 | auto src_layer_d = dnn_mem_t::init_md( |
548 | 3, src_layer_dims, prb.cfg[SRC_LAYER].dt, tag::abx /* dnnl_tnc */); |
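    // Adjust strides for src_layer_d: bump the outermost (iter) stride by
    // `the_stride` to make the layout non-dense when requested.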
549 | dims_t src_layer_strides(query_md_ndims(src_layer_d)); |
550 | std::memcpy(src_layer_strides.data(), query_md_strides(src_layer_d), |
551 | src_layer_strides.size() * sizeof(dnnl_dim_t)); |
552 | src_layer_strides[0] += the_stride; |
553 | src_layer_d = dnn_mem_t::init_md(query_md_ndims(src_layer_d), |
            query_md_dims(src_layer_d), query_md_data_type(src_layer_d), "",
555 | src_layer_strides); |
556 | |
557 | dnnl_dims_t src_iter_dims = {prb.n_layer, prb.n_dir(), prb.mb, prb.sic}; |
558 | auto src_iter_d = dnn_mem_t::init_md( |
559 | 4, src_iter_dims, prb.cfg[SRC_ITER].dt, tag::abx /* dnnl_ldnc */); |
560 | // Adjust strides for src_iter_d. |
561 | dims_t src_iter_strides(query_md_ndims(src_iter_d)); |
562 | std::memcpy(src_iter_strides.data(), query_md_strides(src_iter_d), |
563 | src_iter_strides.size() * sizeof(dnnl_dim_t)); |
564 | src_iter_strides[2] = prb.sic + the_stride; |
565 | for (int d = 1; d >= 0; --d) |
566 | src_iter_strides[d] |
567 | = src_iter_strides[d + 1] * query_md_dims(src_iter_d)[d + 1]; |
568 | src_iter_d = dnn_mem_t::init_md(query_md_ndims(src_iter_d), |
            query_md_dims(src_iter_d), query_md_data_type(src_iter_d), "",
570 | src_iter_strides); |
571 | |
572 | dnnl_dims_t src_iter_c_dims = {prb.n_layer, prb.n_dir(), prb.mb, prb.dhc}; |
573 | auto src_iter_c_d = dnn_mem_t::init_md(4, src_iter_c_dims, |
574 | prb.cfg[SRC_ITER_C].dt, tag::abx /* dnnl_ldnc */); |
575 | // Adjust strides for src_iter_c_d. |
576 | dims_t src_iter_c_strides(query_md_ndims(src_iter_c_d)); |
577 | std::memcpy(src_iter_c_strides.data(), query_md_strides(src_iter_c_d), |
578 | src_iter_c_strides.size() * sizeof(dnnl_dim_t)); |
579 | src_iter_c_strides[2] = prb.dhc + the_stride; |
580 | for (int d = 1; d >= 0; --d) |
581 | src_iter_c_strides[d] = src_iter_c_strides[d + 1] |
582 | * query_md_dims(src_iter_c_d)[d + 1]; |
583 | src_iter_c_d = dnn_mem_t::init_md(query_md_ndims(src_iter_c_d), |
            query_md_dims(src_iter_c_d), query_md_data_type(src_iter_c_d), "",
585 | src_iter_c_strides); |
586 | |
587 | auto weights_layer_d = dnn_mem_t::init_md( |
588 | 5, weights_layer_dims, prb.cfg[WEIGHTS_LAYER].dt, tag::any); |
589 | auto weights_iter_d = dnn_mem_t::init_md( |
590 | 5, weights_iter_dims, prb.cfg[WEIGHTS_ITER].dt, tag::any); |
591 | |
592 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> attention_d {}; |
593 | if (prb.is_augru()) |
594 | attention_d = dnn_mem_t::init_md(3, attention_dims, |
595 | prb.cfg[AUGRU_ATTENTION].dt, tag::abx /* dnnl_tnc */); |
596 | |
597 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> weights_peephole_d {}; |
598 | if (prb.is_lstm_peephole()) |
599 | weights_peephole_d = dnn_mem_t::init_md(4, weights_peephole_dims, |
600 | prb.cfg[WEIGHTS_PEEPHOLE].dt, tag::abx /* dnnl_ldgo */); |
601 | |
602 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> weights_projection_d {}; |
603 | if (prb.is_lstm_projection()) |
604 | weights_projection_d = dnn_mem_t::init_md(4, weights_projection_dims, |
605 | prb.cfg[WEIGHTS_PROJECTION].dt, tag::any); |
606 | |
607 | auto bias_d = dnn_mem_t::init_md(4, bias_dims, prb.cfg[BIAS].dt, tag::any); |
608 | |
609 | auto dst_layer_d = dnn_mem_t::init_md( |
610 | 3, dst_layer_dims, prb.cfg[DST_LAYER].dt, tag::abx /* dnnl_tnc */); |
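    // Adjust strides for dst_layer_d (same outermost-stride bump as for
    // src_layer_d).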
611 | dims_t dst_layer_strides(query_md_ndims(dst_layer_d)); |
612 | std::memcpy(dst_layer_strides.data(), query_md_strides(dst_layer_d), |
613 | dst_layer_strides.size() * sizeof(dnnl_dim_t)); |
614 | dst_layer_strides[0] += the_stride; |
615 | dst_layer_d = dnn_mem_t::init_md(query_md_ndims(dst_layer_d), |
            query_md_dims(dst_layer_d), query_md_data_type(dst_layer_d), "",
617 | dst_layer_strides); |
618 | |
619 | dnnl_dims_t dst_iter_dims = {prb.n_layer, prb.n_dir(), prb.mb, prb.dic}; |
620 | auto dst_iter_d = dnn_mem_t::init_md( |
621 | 4, dst_iter_dims, prb.cfg[DST_ITER].dt, tag::abx /* dnnl_ldnc */); |
622 | // Adjust strides for dst_iter_d. |
623 | dims_t dst_iter_strides(query_md_ndims(dst_iter_d)); |
624 | std::memcpy(dst_iter_strides.data(), query_md_strides(dst_iter_d), |
625 | dst_iter_strides.size() * sizeof(dnnl_dim_t)); |
626 | dst_iter_strides[2] = prb.dic + the_stride; |
627 | for (int d = 1; d >= 0; --d) |
628 | dst_iter_strides[d] |
629 | = dst_iter_strides[d + 1] * query_md_dims(dst_iter_d)[d + 1]; |
630 | dst_iter_d = dnn_mem_t::init_md(query_md_ndims(dst_iter_d), |
            query_md_dims(dst_iter_d), query_md_data_type(dst_iter_d), "",
632 | dst_iter_strides); |
633 | |
634 | dnnl_dims_t dst_iter_c_dims = {prb.n_layer, prb.n_dir(), prb.mb, prb.dhc}; |
635 | auto dst_iter_c_d = dnn_mem_t::init_md(4, dst_iter_c_dims, |
636 | prb.cfg[DST_ITER_C].dt, tag::abx /* dnnl_ldnc */); |
637 | // Adjust strides for dst_iter_c_d. |
638 | dims_t dst_iter_c_strides(query_md_ndims(dst_iter_c_d)); |
639 | std::memcpy(dst_iter_c_strides.data(), query_md_strides(dst_iter_c_d), |
640 | dst_iter_c_strides.size() * sizeof(dnnl_dim_t)); |
641 | dst_iter_c_strides[2] = prb.dhc + the_stride; |
642 | for (int d = 1; d >= 0; --d) |
643 | dst_iter_c_strides[d] = dst_iter_c_strides[d + 1] |
644 | * query_md_dims(dst_iter_c_d)[d + 1]; |
645 | dst_iter_c_d = dnn_mem_t::init_md(query_md_ndims(dst_iter_c_d), |
            query_md_dims(dst_iter_c_d), query_md_data_type(dst_iter_c_d), "",
647 | dst_iter_c_strides); |
648 | |
649 | auto dnnl_attr = make_benchdnn_dnnl_wrapper(create_dnnl_rnn_attr(prb)); |
650 | |
    // Initialize the forward pass.
    // For inference we use forward_inference; for training (and as the
    // service primitive for backward) we use forward_training.
654 | if (dir & FLAG_FWD) { |
655 | DNN_SAFE_STATUS(init_rnn_fwd_pd(&init_pd_args.pd, init_pd_args.engine, |
656 | prb, fwd_prop, src_layer_d, src_iter_d, src_iter_c_d, |
657 | attention_d, weights_layer_d, weights_iter_d, |
658 | weights_peephole_d, weights_projection_d, bias_d, dst_layer_d, |
659 | dst_iter_d, dst_iter_c_d, dnnl_attr)); |
660 | } else { |
661 | // TODO: add stride support for diff_* tensors |
662 | auto diff_src_layer_d = dnn_mem_t::init_md( |
663 | 3, src_layer_dims, prb.cfg[DIFF_SRC_LAYER].dt, tag::any); |
664 | auto diff_src_iter_d = dnn_mem_t::init_md( |
665 | 4, src_iter_dims, prb.cfg[DIFF_SRC_ITER].dt, tag::any); |
666 | auto diff_src_iter_c_d = dnn_mem_t::init_md( |
667 | 4, src_iter_c_dims, prb.cfg[DIFF_SRC_ITER_C].dt, tag::any); |
668 | auto diff_weights_layer_d = dnn_mem_t::init_md(5, weights_layer_dims, |
669 | prb.cfg[DIFF_WEIGHTS_LAYER].dt, tag::any); |
670 | auto diff_weights_iter_d = dnn_mem_t::init_md( |
671 | 5, weights_iter_dims, prb.cfg[DIFF_WEIGHTS_ITER].dt, tag::any); |
672 | |
673 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> diff_attention_d {}; |
674 | if (prb.is_augru()) |
675 | diff_attention_d = dnn_mem_t::init_md(3, attention_dims, |
676 | prb.cfg[DIFF_AUGRU_ATTENTION].dt, tag::abx /* dnnl_tnc */); |
677 | |
678 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> diff_weights_peephole_d {}; |
679 | if (prb.is_lstm_peephole()) |
680 | diff_weights_peephole_d = dnn_mem_t::init_md(4, |
681 | weights_peephole_dims, prb.cfg[DIFF_WEIGHTS_PEEPHOLE].dt, |
682 | tag::abx /* dnnl_ldgo */); |
683 | |
684 | benchdnn_dnnl_wrapper_t<dnnl_memory_desc_t> |
685 | diff_weights_projection_d {}; |
686 | if (prb.is_lstm_projection()) |
687 | diff_weights_projection_d |
688 | = dnn_mem_t::init_md(4, weights_projection_dims, |
689 | prb.cfg[DIFF_WEIGHTS_PROJECTION].dt, tag::any); |
690 | |
691 | auto diff_bias_d = dnn_mem_t::init_md( |
692 | 4, bias_dims, prb.cfg[DIFF_BIAS].dt, tag::any); |
693 | auto diff_dst_layer_d = dnn_mem_t::init_md( |
694 | 3, dst_layer_dims, prb.cfg[DIFF_DST_LAYER].dt, tag::any); |
695 | auto diff_dst_iter_d = dnn_mem_t::init_md( |
696 | 4, dst_iter_dims, prb.cfg[DIFF_DST_ITER].dt, tag::any); |
697 | auto diff_dst_iter_c_d = dnn_mem_t::init_md( |
698 | 4, dst_iter_c_dims, prb.cfg[DIFF_DST_ITER_C].dt, tag::any); |
699 | |
700 | DNN_SAFE_STATUS(init_rnn_bwd_pd(&init_pd_args.pd, init_pd_args.engine, |
701 | prb, prb.prop, src_layer_d, src_iter_d, src_iter_c_d, |
702 | attention_d, weights_layer_d, weights_iter_d, |
703 | weights_peephole_d, weights_projection_d, bias_d, dst_layer_d, |
704 | dst_iter_d, dst_iter_c_d, diff_src_layer_d, diff_src_iter_d, |
705 | diff_src_iter_c_d, diff_attention_d, diff_weights_layer_d, |
706 | diff_weights_iter_d, diff_weights_peephole_d, |
707 | diff_weights_projection_d, diff_bias_d, diff_dst_layer_d, |
708 | diff_dst_iter_d, diff_dst_iter_c_d, init_pd_args.hint, |
709 | dnnl_attr)); |
710 | } |
711 | |
712 | return dnnl_success; |
713 | } |
714 | |
715 | void skip_unimplemented_prb(const prb_t *prb_, res_t *res) { |
716 | const prb_t &prb = *prb_; |
717 | dir_t dir = str2dir(prop2str(prb.prop)); |
718 | skip_unimplemented_data_type({prb.cfg[SRC_LAYER].dt}, dir, res); |
719 | skip_unimplemented_sum_po(prb.attr, res); |
720 | |
721 | #if !defined(DNNL_X64) || DNNL_X64 == 0 \ |
722 | || DNNL_CPU_RUNTIME == DNNL_RUNTIME_THREADPOOL |
    // int8 is not supported at all since RNN relies on packed IGEMM.
    // FIXME: this will disable int8 RNN testing if the library is built with
    // Intel MKL, which does have packed IGEMM.
726 | if (prb.is_int8()) { |
727 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
728 | return; |
729 | } |
730 | #endif |
731 | |
732 | #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE |
733 | static auto isa = dnnl_get_effective_cpu_isa(); |
734 | const bool is_f16_not_ok = prb.cfg[SRC_LAYER].dt == dnnl_f16 |
735 | && dnnl::is_superset(isa, dnnl_cpu_isa_avx512_core_fp16); |
736 | if (is_f16_not_ok) { |
737 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
738 | return; |
739 | } |
740 | #endif |
741 | |
742 | #ifdef DNNL_AARCH64_USE_ACL |
743 | const bool is_acl_f16_not_ok = prb.cfg[SRC_LAYER].dt == dnnl_f16 |
744 | && dnnl::impl::cpu::platform::has_data_type_support(dnnl_f16); |
745 | if (is_acl_f16_not_ok) { |
746 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
747 | return; |
748 | } |
749 | #endif |
750 | |
    // The int8 weights reorder does not support non-trivial strides;
    // only LSTM and GRU cell kinds support int8 so far.
753 | if (prb.is_int8()) { |
754 | if (!prb.trivial_strides) { |
755 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
756 | return; |
757 | } |
758 | if (prb.alg != VANILLA_LSTM && prb.alg != VANILLA_GRU) { |
759 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
760 | return; |
761 | } |
762 | } |
763 | |
764 | // LSTM w/ projection is not supported for bf16 |
765 | if (prb.is_lstm_projection() && prb.cfg[SRC_LAYER].dt == dnnl_bf16) { |
766 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
767 | return; |
768 | } |
769 | |
770 | // GPU limitations for RNN |
771 | if (is_gpu()) { |
772 | bool is_AUGRU = prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU; |
773 | if (is_AUGRU) { |
774 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
775 | return; |
776 | } |
777 | if (prb.is_lstm_projection() || prb.is_lstm_peephole()) { |
778 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
779 | return; |
780 | } |
781 | if (prb.is_int8() && prb.alg != VANILLA_LSTM) { |
782 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
783 | return; |
784 | } |
785 | if (prb.is_s8() && prb.alg == VANILLA_LSTM) { |
786 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
787 | return; |
788 | } |
789 | // Implemented only for CPU |
790 | if (prb.cfg[BIAS].dt == dnnl_bf16 || prb.cfg[SRC_ITER_C].dt == dnnl_bf16 |
791 | || prb.cfg[DST_ITER_C].dt == dnnl_bf16) { |
792 | res->state = SKIPPED, res->reason = CASE_NOT_SUPPORTED; |
793 | return; |
794 | } |
795 | } |
796 | } |
797 | |
798 | void skip_invalid_prb(const prb_t *prb_, res_t *res) { |
799 | const prb_t &prb = *prb_; |
800 | |
801 | // Consistency validation. |
802 | bool consistent_proj |
803 | = IMPLICATION(!prb.with_projection, prb.dhc == prb.dic); |
804 | bool consistent_L = IMPLICATION(prb.n_layer > 1, prb.slc == prb.dic); |
805 | bool consistent_T = IMPLICATION(prb.n_iter > 1, prb.sic == prb.dic); |
806 | bool is_GRU = prb.alg == VANILLA_GRU || prb.alg == LBR_GRU; |
807 | bool consistent_GRU = IMPLICATION(is_GRU, prb.sic == prb.dic); |
808 | bool is_AUGRU = prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU; |
809 | bool consistent_AUGRU = IMPLICATION(is_AUGRU, |
810 | prb.sic == prb.dic && prb.n_layer == 1 |
811 | && prb.direction == dnnl_unidirectional_left2right); |
812 | if (!consistent_proj || !consistent_L || !consistent_T || !consistent_GRU |
813 | || !consistent_AUGRU) { |
814 | res->state = SKIPPED, res->reason = INVALID_CASE; |
815 | return; |
816 | } |
817 | |
818 | // Only LSTM supports peephole and projection layer. |
819 | bool is_lstm_peephole |
820 | = IMPLICATION(prb.with_peephole, prb.alg == VANILLA_LSTM); |
821 | bool is_lstm_projection |
822 | = IMPLICATION(prb.with_projection, prb.alg == VANILLA_LSTM); |
823 | if (!is_lstm_peephole || !is_lstm_projection) { |
824 | res->state = SKIPPED, res->reason = INVALID_CASE; |
825 | return; |
826 | } |
827 | } |
828 | |
829 | void setup_cmp(compare::compare_t &cmp, const prb_t *prb, data_kind_t kind, |
830 | const args_t &ref_args) { |
831 | const auto rnn_kind = data_kind2rnn_data_kind(kind); |
832 | const auto &cfg = prb->cfg[rnn_kind]; |
833 | // factor 2 is because of the sum of 2 GEMMs |
834 | int64_t fwd_acc_dim = 2 * prb->n_gates() + 1; |
835 | if (prb->alg == VANILLA_GRU || prb->alg == VANILLA_AUGRU) |
836 | fwd_acc_dim *= prb->sic; |
837 | int64_t bwdd_acc_dim = prb->n_gates() * prb->dhc; |
838 | int64_t bwdw_acc_dim = prb->mb; |
839 | int64_t acc_dim = fwd_acc_dim; |
840 | if (prb->prop == dnnl_backward) acc_dim *= MAX2(bwdd_acc_dim, bwdw_acc_dim); |
841 | // Here the factor 4 just gives some wiggle room for fp32 testing |
842 | |
843 | float trh = 4 |
844 | * (1 + (prb->prop == dnnl_backward)) // double wiggle room for bwd |
845 | * ((prb->direction == dnnl_bidirectional_sum) |
846 | + 1) // double trh if bidir_sum |
847 | * ceilf(log2f(acc_dim * prb->n_iter)) * cfg.eps; |
848 | // expect exact value for int8 |
849 | if (cfg.dt == dnnl_u8 || cfg.dt == dnnl_s8) trh = 0.f; |
850 | cmp.set_threshold(trh); |
851 | |
852 | // Note: we do an eltwise comparison only when: |
853 | // - we use skip_nonlinear; |
854 | // - we do not use skip_nonlinear and we test only one cell execution; |
855 | // - for int8 computations the tensor is not DST_ITER_C; |
856 | // If the above conditions are not met, we check only L1, L2 and L8. |
857 | |
858 | // Rough rationale for the `DST_ITER_C` exception in int8 case: |
859 | // - The formula for one-step c-state is: |
860 | // c_t = f_t * c_{t−1} + i_t * c~_t. |
861 | // Here all computations happen in f32 (f_t, i_t, and c~_t are dequantized |
    // right before the computations, and the corresponding bias is added).
863 | // - In int8 case we don't have much control over these components and |
864 | // cannot surmount potential cancellations, if any. |
    // In practice, I observed that the relative element-wise error of values
    // in `DST_ITER_C` was bigger (up to 8e-5) whenever the values
    // themselves were smaller (which indirectly means the problem is exactly
    // in the cancellation). Unfortunately, this even happened with only one
    // layer and one time step.
    // - So, for now the solution is to use L1-, L2- and L_inf-norms to validate
871 | // `DST_ITER_C`. When we switch testing on using precise |
872 | // integer arithmetic based on modulo operation in rnn_tparams (instead of |
873 | // current unreliable re-scaling), this testing weakness should go away. |
874 | // - Just an obvious side note: `DST_LAYER` and `DST_ITER` |
875 | // are immediate dequantization of the corresponding u8 tensors. Hence, |
876 | // as long as we get precise u8 intermediate results (and so far we do), |
877 | // the f32 result should be pretty accurate -- the dequantization is just |
878 | // two simple ops: f32 = scale * u8 + shift. |
879 | bool check_p2p = (prb->skip_nonlinear |
880 | || ((prb->n_layer == 1) && (prb->n_iter == 1))); |
881 | if (prb->is_int8() && rnn_kind == DST_ITER_C) check_p2p = false; |
882 | cmp.set_norm_validation_mode(!check_p2p); |
883 | |
884 | const auto rnn_add_check = |
885 | [&, prb](const compare::compare_t::driver_check_func_args_t &args) { |
886 | // Limitation from current filling. |
887 | // TODO: find a better filling to get rid of this... |
888 | if ((prb->alg == VANILLA_GRU || prb->alg == LBR_AUGRU |
889 | || prb->alg == VANILLA_RNN || prb->alg == LBR_GRU) |
890 | && prb->prop == dnnl_backward) { |
891 | return args.diff < args.trh; |
892 | } |
893 | return false; |
894 | }; |
895 | cmp.set_driver_check_function(rnn_add_check); |
896 | } |
897 | |
898 | int doit(const prb_t &prb, res_t *res) { |
899 | if (bench_mode == LIST) return res->state = LISTED, OK; |
900 | |
901 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> prim; |
902 | bool is_service_prim = prb.dir & FLAG_BWD; |
903 | SAFE(init_prim(prb.ctx_init, prim, init_pd, &prb, res, FLAG_FWD, nullptr, |
904 | is_service_prim), |
905 | WARN); |
906 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
907 | |
908 | auto const_fpd = query_pd(prim); |
909 | |
910 | const auto &src_layer_md = query_md(const_fpd, DNNL_ARG_SRC_LAYER); |
911 | const auto &src_layer_attention_md |
912 | = query_md(const_fpd, DNNL_ARG_AUGRU_ATTENTION); |
913 | const auto &src_iter_md = query_md(const_fpd, DNNL_ARG_SRC_ITER); |
914 | const auto &src_iter_c_md = query_md(const_fpd, DNNL_ARG_SRC_ITER_C); |
915 | const auto &weights_layer_md = query_md(const_fpd, DNNL_ARG_WEIGHTS_LAYER); |
916 | const auto &weights_iter_md = query_md(const_fpd, DNNL_ARG_WEIGHTS_ITER); |
917 | const auto &weights_peephole_md |
918 | = query_md(const_fpd, DNNL_ARG_WEIGHTS_PEEPHOLE); |
919 | const auto &weights_projection_md |
920 | = query_md(const_fpd, DNNL_ARG_WEIGHTS_PROJECTION); |
921 | const auto &bias_md = query_md(const_fpd, DNNL_ARG_BIAS); |
922 | const auto &dst_layer_md = query_md(const_fpd, DNNL_ARG_DST_LAYER); |
923 | const auto &dst_iter_md = query_md(const_fpd, DNNL_ARG_DST_ITER); |
924 | const auto &dst_iter_c_md = query_md(const_fpd, DNNL_ARG_DST_ITER_C); |
925 | const auto &workspace_md = query_md(const_fpd, DNNL_ARG_WORKSPACE); |
926 | const auto &scratchpad_md = query_md(const_fpd, DNNL_ARG_SCRATCHPAD); |
927 | |
928 | const auto &test_engine = get_test_engine(); |
929 | const auto &ref_engine = get_cpu_engine(); |
930 | |
931 | dnn_mem_t src_layer_dt(src_layer_md, test_engine); |
932 | dnn_mem_t src_layer_attention_dt(src_layer_attention_md, test_engine); |
933 | dnn_mem_t src_iter_dt(src_iter_md, test_engine); |
934 | dnn_mem_t src_iter_c_dt(src_iter_c_md, test_engine); |
935 | dnn_mem_t weights_layer_dt(weights_layer_md, test_engine); |
936 | dnn_mem_t weights_iter_dt(weights_iter_md, test_engine); |
937 | dnn_mem_t weights_peephole_dt(weights_peephole_md, test_engine); |
938 | dnn_mem_t weights_projection_dt(weights_projection_md, test_engine); |
939 | dnn_mem_t bias_dt(bias_md, test_engine); |
940 | dnn_mem_t dst_layer_dt(dst_layer_md, test_engine); |
941 | dnn_mem_t dst_iter_dt(dst_iter_md, test_engine); |
942 | dnn_mem_t dst_iter_c_dt(dst_iter_c_md, test_engine); |
943 | dnn_mem_t workspace_dt(workspace_md, test_engine); |
944 | dnn_mem_t scratchpad_dt(scratchpad_md, test_engine); |
945 | |
946 | dnn_mem_t src_layer_fp( |
947 | src_layer_md, dnnl_f32, tag::abx /*tnc*/, ref_engine); |
948 | dnn_mem_t src_layer_attention_fp( |
949 | src_layer_attention_md, dnnl_f32, tag::abx /*tnc*/, ref_engine); |
950 | dnn_mem_t src_iter_fp(src_iter_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine); |
951 | dnn_mem_t src_iter_c_fp( |
952 | src_iter_c_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine); |
953 | dnn_mem_t weights_layer_fp( |
954 | weights_layer_md, dnnl_f32, tag::abx /*ldigo*/, ref_engine); |
955 | dnn_mem_t weights_iter_fp( |
956 | weights_iter_md, dnnl_f32, tag::abx /*ldigo*/, ref_engine); |
957 | dnn_mem_t weights_peephole_fp( |
958 | weights_peephole_md, dnnl_f32, tag::abx /*ldgo*/, ref_engine); |
959 | dnn_mem_t weights_projection_fp( |
960 | weights_projection_md, dnnl_f32, tag::abx /*ldio*/, ref_engine); |
961 | dnn_mem_t bias_fp(bias_md, dnnl_f32, tag::abx /*ldgo*/, ref_engine); |
962 | dnn_mem_t dst_layer_fp( |
963 | dst_layer_md, dnnl_f32, tag::abx /*tnc*/, ref_engine); |
964 | dnn_mem_t dst_iter_fp(dst_iter_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine); |
965 | dnn_mem_t dst_iter_c_fp( |
966 | dst_iter_c_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine); |
967 | |
968 | dnn_mem_t bwd_weights_layer_dt; |
969 | dnn_mem_t bwd_weights_iter_dt; |
970 | dnn_mem_t bwd_weights_projection_dt; |
971 | dnn_mem_t diff_src_layer_dt; |
972 | dnn_mem_t diff_src_layer_attention_dt; |
973 | dnn_mem_t diff_src_iter_dt; |
974 | dnn_mem_t diff_src_iter_c_dt; |
975 | dnn_mem_t diff_weights_layer_dt; |
976 | dnn_mem_t diff_weights_iter_dt; |
977 | dnn_mem_t diff_weights_peephole_dt; |
978 | dnn_mem_t diff_weights_projection_dt; |
979 | dnn_mem_t diff_bias_dt; |
980 | dnn_mem_t diff_dst_layer_dt; |
981 | dnn_mem_t diff_dst_iter_dt; |
982 | dnn_mem_t diff_dst_iter_c_dt; |
983 | |
    // For int8 RNN we need to pass attributes for data q10n.
985 | auto rnn_attr = query_attr(const_fpd); |
986 | SAFE(fill_activation(prb, SRC_LAYER, src_layer_dt, src_layer_fp, rnn_attr), |
987 | WARN); |
988 | if (prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU) |
989 | SAFE(fill_activation(prb, AUGRU_ATTENTION, src_layer_attention_dt, |
990 | src_layer_attention_fp, rnn_attr), |
991 | WARN); |
992 | SAFE(fill_activation(prb, SRC_ITER, src_iter_dt, src_iter_fp, rnn_attr), |
993 | WARN); |
994 | if (prb.alg == VANILLA_LSTM) |
995 | SAFE(fill_src_iter_c(prb, src_iter_c_dt, src_iter_c_fp, rnn_attr), |
996 | WARN); |
997 | SAFE(fill_weights(prb, WEIGHTS_LAYER, weights_layer_dt, weights_layer_fp, |
998 | rnn_attr), |
999 | WARN); |
1000 | SAFE(fill_weights( |
1001 | prb, WEIGHTS_ITER, weights_iter_dt, weights_iter_fp, rnn_attr), |
1002 | WARN); |
1003 | SAFE(fill_memory(prb, WEIGHTS_PEEPHOLE, weights_peephole_dt, |
1004 | weights_peephole_fp), |
1005 | WARN); |
1006 | SAFE(fill_weights(prb, WEIGHTS_PROJECTION, weights_projection_dt, |
1007 | weights_projection_fp, rnn_attr), |
1008 | WARN); |
1009 | SAFE(fill_memory(prb, BIAS, bias_dt, bias_fp), WARN); |
1010 | SAFE(fill_activation(prb, DST_LAYER, dst_layer_dt, dst_layer_fp), WARN); |
1011 | SAFE(fill_activation(prb, DST_ITER, dst_iter_dt, dst_iter_fp), WARN); |
1012 | if (prb.alg == VANILLA_LSTM) |
1013 | SAFE(fill_memory(prb, DST_ITER_C, dst_iter_c_dt, dst_iter_c_fp), WARN); |
1014 | |
1015 | args_t args, ref_args; |
1016 | |
1017 | // Running the forward pass |
1018 | args.set(DNNL_ARG_SRC_LAYER, src_layer_dt); |
1019 | args.set(DNNL_ARG_AUGRU_ATTENTION, src_layer_attention_dt); |
1020 | args.set(DNNL_ARG_SRC_ITER, src_iter_dt); |
1021 | args.set(DNNL_ARG_SRC_ITER_C, src_iter_c_dt); |
1022 | args.set(DNNL_ARG_WEIGHTS_LAYER, weights_layer_dt); |
1023 | args.set(DNNL_ARG_WEIGHTS_ITER, weights_iter_dt); |
1024 | args.set(DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_dt); |
1025 | args.set(DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_dt); |
1026 | args.set(DNNL_ARG_BIAS, bias_dt); |
1027 | args.set(DNNL_ARG_DST_LAYER, dst_layer_dt); |
1028 | args.set(DNNL_ARG_DST_ITER, dst_iter_dt); |
1029 | args.set(DNNL_ARG_DST_ITER_C, dst_iter_c_dt); |
1030 | args.set(DNNL_ARG_WORKSPACE, workspace_dt); |
1031 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
1032 | |
1033 | SAFE(execute_and_wait(prim, args, res), WARN); |
1034 | |
1035 | if (prb.prop != dnnl_backward) { |
1036 | if (is_bench_mode(CORR)) { |
1037 | ref_args.set(DNNL_ARG_SRC_LAYER, src_layer_fp); |
1038 | ref_args.set(DNNL_ARG_AUGRU_ATTENTION, src_layer_attention_fp); |
1039 | ref_args.set(DNNL_ARG_SRC_ITER, src_iter_fp); |
1040 | ref_args.set(DNNL_ARG_SRC_ITER_C, src_iter_c_fp); |
1041 | ref_args.set(DNNL_ARG_WEIGHTS_LAYER, weights_layer_fp); |
1042 | ref_args.set(DNNL_ARG_WEIGHTS_ITER, weights_iter_fp); |
1043 | ref_args.set(DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_fp); |
1044 | ref_args.set(DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_fp); |
1045 | ref_args.set(DNNL_ARG_BIAS, bias_fp); |
1046 | ref_args.set(DNNL_ARG_DST_LAYER, dst_layer_fp); |
1047 | ref_args.set(DNNL_ARG_DST_ITER, dst_iter_fp); |
1048 | ref_args.set(DNNL_ARG_DST_ITER_C, dst_iter_c_fp); |
1049 | |
1050 | std::vector<data_kind_t> kinds { |
1051 | data_kind_t::DST, data_kind_t::DST_ITER}; |
1052 | if (prb.alg == VANILLA_LSTM) { |
1053 | kinds.push_back(data_kind_t::DST_ITER_C); |
1054 | } |
1055 | |
1056 | check_correctness(&prb, kinds, args, ref_args, setup_cmp, res); |
1057 | } |
1058 | } else { |
1059 | benchdnn_dnnl_wrapper_t<dnnl_primitive_t> tmp_prim; |
1060 | SAFE(init_prim(tmp_prim, init_pd, &prb, res, FLAG_BWD), WARN); |
1061 | if (res->state == SKIPPED || res->state == UNIMPLEMENTED) return OK; |
1062 | prim.reset(tmp_prim.release()); |
1063 | |
1064 | auto const_bpd = query_pd(prim); |
1065 | |
1066 | const auto &bwd_weights_layer_md |
1067 | = query_md(const_bpd, DNNL_ARG_WEIGHTS_LAYER); |
1068 | const auto &bwd_weights_iter_md |
1069 | = query_md(const_bpd, DNNL_ARG_WEIGHTS_ITER); |
1070 | const auto &bwd_weights_projection_md |
1071 | = query_md(const_bpd, DNNL_ARG_WEIGHTS_PROJECTION); |
1072 | const auto &diff_src_layer_md |
1073 | = query_md(const_bpd, DNNL_ARG_DIFF_SRC_LAYER); |
1074 | const auto &diff_src_layer_attention_md |
1075 | = query_md(const_bpd, DNNL_ARG_DIFF_AUGRU_ATTENTION); |
1076 | const auto &diff_src_iter_md |
1077 | = query_md(const_bpd, DNNL_ARG_DIFF_SRC_ITER); |
1078 | const auto &diff_src_iter_c_md |
1079 | = query_md(const_bpd, DNNL_ARG_DIFF_SRC_ITER_C); |
1080 | const auto &diff_weights_layer_md |
1081 | = query_md(const_bpd, DNNL_ARG_DIFF_WEIGHTS_LAYER); |
1082 | const auto &diff_weights_iter_md |
1083 | = query_md(const_bpd, DNNL_ARG_DIFF_WEIGHTS_ITER); |
1084 | const auto &diff_weights_peephole_md |
1085 | = query_md(const_bpd, DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE); |
1086 | const auto &diff_weights_projection_md |
1087 | = query_md(const_bpd, DNNL_ARG_DIFF_WEIGHTS_PROJECTION); |
1088 | const auto &diff_bias_md = query_md(const_bpd, DNNL_ARG_DIFF_BIAS); |
1089 | const auto &diff_dst_layer_md |
1090 | = query_md(const_bpd, DNNL_ARG_DIFF_DST_LAYER); |
1091 | const auto &diff_dst_iter_md |
1092 | = query_md(const_bpd, DNNL_ARG_DIFF_DST_ITER); |
1093 | const auto &diff_dst_iter_c_md |
1094 | = query_md(const_bpd, DNNL_ARG_DIFF_DST_ITER_C); |
1095 | const auto &bwd_scratchpad_md |
1096 | = query_md(const_bpd, DNNL_ARG_SCRATCHPAD); |
1097 | |
1098 | bwd_weights_layer_dt = dnn_mem_t(bwd_weights_layer_md, test_engine); |
1099 | bwd_weights_iter_dt = dnn_mem_t(bwd_weights_iter_md, test_engine); |
1100 | bwd_weights_projection_dt |
1101 | = dnn_mem_t(bwd_weights_projection_md, test_engine); |
1102 | diff_src_layer_dt = dnn_mem_t(diff_src_layer_md, test_engine); |
1103 | diff_src_layer_attention_dt |
1104 | = dnn_mem_t(diff_src_layer_attention_md, test_engine); |
1105 | diff_src_iter_dt = dnn_mem_t(diff_src_iter_md, test_engine); |
1106 | diff_src_iter_c_dt = dnn_mem_t(diff_src_iter_c_md, test_engine); |
1107 | diff_weights_layer_dt = dnn_mem_t(diff_weights_layer_md, test_engine); |
1108 | diff_weights_iter_dt = dnn_mem_t(diff_weights_iter_md, test_engine); |
1109 | diff_weights_peephole_dt |
1110 | = dnn_mem_t(diff_weights_peephole_md, test_engine); |
1111 | diff_weights_projection_dt |
1112 | = dnn_mem_t(diff_weights_projection_md, test_engine); |
1113 | diff_bias_dt = dnn_mem_t(diff_bias_md, test_engine); |
1114 | diff_dst_layer_dt = dnn_mem_t(diff_dst_layer_md, test_engine); |
1115 | diff_dst_iter_dt = dnn_mem_t(diff_dst_iter_md, test_engine); |
1116 | diff_dst_iter_c_dt = dnn_mem_t(diff_dst_iter_c_md, test_engine); |
1117 | scratchpad_dt = dnn_mem_t(bwd_scratchpad_md, test_engine); |
1118 | |
1119 | dnn_mem_t diff_src_layer_fp( |
1120 | diff_src_layer_md, dnnl_f32, tag::abx /*tnc*/, ref_engine); |
1121 | dnn_mem_t diff_src_layer_attention_fp(diff_src_layer_attention_md, |
1122 | dnnl_f32, tag::abx /*tnc*/, ref_engine); |
1123 | dnn_mem_t diff_src_iter_fp( |
1124 | diff_src_iter_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine); |
1125 | dnn_mem_t diff_src_iter_c_fp( |
1126 | diff_src_iter_c_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine); |
1127 | dnn_mem_t diff_weights_layer_fp(diff_weights_layer_md, dnnl_f32, |
1128 | tag::abx /*ldigo*/, ref_engine); |
1129 | dnn_mem_t diff_weights_iter_fp( |
1130 | diff_weights_iter_md, dnnl_f32, tag::abx /*ldigo*/, ref_engine); |
1131 | dnn_mem_t diff_weights_peephole_fp(diff_weights_peephole_md, dnnl_f32, |
1132 | tag::abx /*ldgo*/, ref_engine); |
1133 | dnn_mem_t diff_weights_projection_fp(diff_weights_projection_md, |
1134 | dnnl_f32, tag::abx /*ldio*/, ref_engine); |
1135 | dnn_mem_t diff_bias_fp( |
1136 | diff_bias_md, dnnl_f32, tag::abx /*ldgo*/, ref_engine); |
1137 | dnn_mem_t diff_dst_layer_fp( |
1138 | diff_dst_layer_md, dnnl_f32, tag::abx /*tnc*/, ref_engine); |
1139 | dnn_mem_t diff_dst_iter_fp( |
1140 | diff_dst_iter_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine); |
1141 | dnn_mem_t diff_dst_iter_c_fp( |
1142 | diff_dst_iter_c_md, dnnl_f32, tag::abx /*ldnc*/, ref_engine); |
1143 | |
1144 | SAFE(bwd_weights_iter_dt.reorder(weights_iter_dt), WARN); |
1145 | SAFE(bwd_weights_layer_dt.reorder(weights_layer_dt), WARN); |
1146 | if (prb.is_lstm_projection()) |
1147 | SAFE(bwd_weights_projection_dt.reorder(weights_projection_dt), |
1148 | WARN); |
1149 | SAFE(fill_activation( |
1150 | prb, DIFF_SRC_LAYER, diff_src_layer_dt, diff_src_layer_fp), |
1151 | WARN); |
1152 | if (prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU) |
1153 | SAFE(fill_activation(prb, DIFF_AUGRU_ATTENTION, |
1154 | diff_src_layer_attention_dt, |
1155 | diff_src_layer_attention_fp), |
1156 | WARN); |
1157 | SAFE(fill_activation( |
1158 | prb, DIFF_SRC_ITER, diff_src_iter_dt, diff_src_iter_fp), |
1159 | WARN); |
1160 | if (prb.alg == VANILLA_LSTM) |
1161 | SAFE(fill_memory(prb, DIFF_SRC_ITER_C, diff_src_iter_c_dt, |
1162 | diff_src_iter_c_fp), |
1163 | WARN); |
1164 | SAFE(fill_weights(prb, DIFF_WEIGHTS_LAYER, diff_weights_layer_dt, |
1165 | diff_weights_layer_fp), |
1166 | WARN); |
1167 | SAFE(fill_weights(prb, DIFF_WEIGHTS_ITER, diff_weights_iter_dt, |
1168 | diff_weights_iter_fp), |
1169 | WARN); |
1170 | SAFE(fill_memory(prb, DIFF_WEIGHTS_PEEPHOLE, diff_weights_peephole_dt, |
1171 | diff_weights_peephole_fp), |
1172 | WARN); |
1173 | SAFE(fill_memory(prb, DIFF_WEIGHTS_PROJECTION, |
1174 | diff_weights_projection_dt, diff_weights_projection_fp), |
1175 | WARN); |
1176 | SAFE(fill_bias(prb, DIFF_BIAS, diff_bias_dt, diff_bias_fp), WARN); |
1177 | SAFE(fill_activation( |
1178 | prb, DIFF_DST_LAYER, diff_dst_layer_dt, diff_dst_layer_fp), |
1179 | WARN); |
1180 | SAFE(fill_activation( |
1181 | prb, DIFF_DST_ITER, diff_dst_iter_dt, diff_dst_iter_fp), |
1182 | WARN); |
1183 | if (prb.alg == VANILLA_LSTM) |
1184 | SAFE(fill_memory(prb, DIFF_DST_ITER_C, diff_dst_iter_c_dt, |
1185 | diff_dst_iter_c_fp), |
1186 | WARN); |
1187 | |
1188 | args.clear(); |
1189 | args.set(DNNL_ARG_SRC_LAYER, src_layer_dt); |
1190 | args.set(DNNL_ARG_AUGRU_ATTENTION, src_layer_attention_dt); |
1191 | args.set(DNNL_ARG_SRC_ITER, src_iter_dt); |
1192 | args.set(DNNL_ARG_SRC_ITER_C, src_iter_c_dt); |
1193 | args.set(DNNL_ARG_WEIGHTS_LAYER, bwd_weights_layer_dt); |
1194 | args.set(DNNL_ARG_WEIGHTS_ITER, bwd_weights_iter_dt); |
1195 | args.set(DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_dt); |
1196 | args.set(DNNL_ARG_WEIGHTS_PROJECTION, bwd_weights_projection_dt); |
1197 | args.set(DNNL_ARG_BIAS, bias_dt); |
1198 | args.set(DNNL_ARG_DST_LAYER, dst_layer_dt); |
1199 | args.set(DNNL_ARG_DST_ITER, dst_iter_dt); |
1200 | args.set(DNNL_ARG_DST_ITER_C, dst_iter_c_dt); |
1201 | args.set(DNNL_ARG_DIFF_DST_LAYER, diff_dst_layer_dt); |
1202 | args.set(DNNL_ARG_DIFF_DST_ITER, diff_dst_iter_dt); |
1203 | args.set(DNNL_ARG_DIFF_DST_ITER_C, diff_dst_iter_c_dt); |
1204 | args.set(DNNL_ARG_WORKSPACE, workspace_dt); |
1205 | args.set(DNNL_ARG_DIFF_SRC_LAYER, diff_src_layer_dt); |
1206 | args.set(DNNL_ARG_DIFF_AUGRU_ATTENTION, diff_src_layer_attention_dt); |
1207 | args.set(DNNL_ARG_DIFF_SRC_ITER, diff_src_iter_dt); |
1208 | args.set(DNNL_ARG_DIFF_SRC_ITER_C, diff_src_iter_c_dt); |
1209 | args.set(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_dt); |
1210 | args.set(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_dt); |
1211 | args.set(DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE, diff_weights_peephole_dt); |
1212 | args.set(DNNL_ARG_DIFF_WEIGHTS_PROJECTION, diff_weights_projection_dt); |
1213 | args.set(DNNL_ARG_DIFF_BIAS, diff_bias_dt); |
1214 | args.set(DNNL_ARG_SCRATCHPAD, scratchpad_dt); |
1215 | |
1216 | SAFE(execute_and_wait(prim, args, res), WARN); |
1217 | |
1218 | if (is_bench_mode(CORR)) { |
1219 | ref_args.set(DNNL_ARG_SRC_LAYER, src_layer_fp); |
1220 | ref_args.set(DNNL_ARG_AUGRU_ATTENTION, src_layer_attention_fp); |
1221 | ref_args.set(DNNL_ARG_SRC_ITER, src_iter_fp); |
1222 | ref_args.set(DNNL_ARG_SRC_ITER_C, src_iter_c_fp); |
1223 | ref_args.set(DNNL_ARG_WEIGHTS_LAYER, weights_layer_fp); |
1224 | ref_args.set(DNNL_ARG_WEIGHTS_ITER, weights_iter_fp); |
1225 | ref_args.set(DNNL_ARG_WEIGHTS_PEEPHOLE, weights_peephole_fp); |
1226 | ref_args.set(DNNL_ARG_WEIGHTS_PROJECTION, weights_projection_fp); |
1227 | ref_args.set(DNNL_ARG_BIAS, bias_fp); |
1228 | ref_args.set(DNNL_ARG_DST_LAYER, dst_layer_fp); |
1229 | ref_args.set(DNNL_ARG_DST_ITER, dst_iter_fp); |
1230 | ref_args.set(DNNL_ARG_DST_ITER_C, dst_iter_c_fp); |
1231 | ref_args.set(DNNL_ARG_DIFF_DST_LAYER, diff_dst_layer_fp); |
1232 | ref_args.set(DNNL_ARG_DIFF_DST_ITER, diff_dst_iter_fp); |
1233 | ref_args.set(DNNL_ARG_DIFF_DST_ITER_C, diff_dst_iter_c_fp); |
1234 | ref_args.set(DNNL_ARG_DIFF_SRC_LAYER, diff_src_layer_fp); |
1235 | ref_args.set( |
1236 | DNNL_ARG_DIFF_AUGRU_ATTENTION, diff_src_layer_attention_fp); |
1237 | ref_args.set(DNNL_ARG_DIFF_SRC_ITER, diff_src_iter_fp); |
1238 | ref_args.set(DNNL_ARG_DIFF_SRC_ITER_C, diff_src_iter_c_fp); |
1239 | ref_args.set(DNNL_ARG_DIFF_WEIGHTS_LAYER, diff_weights_layer_fp); |
1240 | ref_args.set(DNNL_ARG_DIFF_WEIGHTS_ITER, diff_weights_iter_fp); |
1241 | ref_args.set( |
1242 | DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE, diff_weights_peephole_fp); |
1243 | ref_args.set(DNNL_ARG_DIFF_WEIGHTS_PROJECTION, |
1244 | diff_weights_projection_fp); |
1245 | ref_args.set(DNNL_ARG_DIFF_BIAS, diff_bias_fp); |
1246 | |
1247 | std::vector<data_kind_t> kinds {data_kind_t::DST, |
1248 | data_kind_t::DST_ITER, data_kind_t::SRC, |
1249 | data_kind_t::SRC_ITER, data_kind_t::WEI, |
1250 | data_kind_t::WEI_ITER, data_kind_t::BIA}; |
1251 | if (prb.alg == VANILLA_LSTM) { |
1252 | kinds.push_back(data_kind_t::DST_ITER_C); |
1253 | kinds.push_back(data_kind_t::SRC_ITER_C); |
1254 | } |
1255 | if (prb.alg == VANILLA_AUGRU || prb.alg == LBR_AUGRU) |
1256 | kinds.push_back(data_kind_t::AUGRU_ATTENTION); |
1257 | if (prb.is_lstm_peephole()) |
1258 | kinds.push_back(data_kind_t::WEI_PEEPHOLE); |
1259 | if (prb.is_lstm_projection()) |
1260 | kinds.push_back(data_kind_t::WEI_PROJECTION); |
1261 | |
1262 | check_correctness(&prb, kinds, args, ref_args, setup_cmp, res); |
1263 | } |
1264 | } |
1265 | |
1266 | return measure_perf(prb.ctx_exe, res, prim, args); |
1267 | } |
1268 | |
1269 | } // namespace rnn |
1270 | |