/*******************************************************************************
* Copyright 2018-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

/*
 * Cell execution GRU
 */

#include "common/bit_cast.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"

#include "cpu/simple_q10n.hpp"

#include "cpu/rnn/postgemm_dispatcher.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace dnnl::impl::utils;
using namespace dnnl::impl::math;
using namespace rnn_utils;
#define AOC array_offset_calculator

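// GRU forward postgemm, part 1: from the first GEMM output, compute the
// update gate u (G0) and the reset gate r (G1) with the elementwise
// activation func1 (sigmoid by default), then write r * h_{t-1} into
// dst_layer/dst_iter so the second GEMM can consume it. G0 is stashed back
// into scratch_gates (via reinterpret_as_acc) for use in part 2.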
template <typename T1, typename T2, typename T3, typename T4, typename T5,
        typename src_data_t, typename scratch_data_t>
void gru_fwd_part1_postgemm_template(T1 func1, T2 to_src, T3 acc_to_float,
        T4 src_to_float, T5 reinterpret_as_acc, const float *scales,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, const src_data_t *augru_attention_,
        src_data_t *dst_layer_, src_data_t *dst_iter_,
        const src_data_t *src_iter_, const void *bias_, int block_step) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
    const auto bias = [&](int gate_id, int dhc_id) {
        return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };

    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);
    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);

    const float *scales_G1 = scales ? scales + 1 : nullptr;

    const auto postgemm_call = [&](int i) {
        const int n_elem = block_step;
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < n_elem; j++) {
            const auto G0 // default func1 is sigmoid
                    = func1(scales,
                            acc_to_float(scratch_gates(i, 0, j), 0, j)
                                    + bias(0, j));
            const auto G1 // default func1 is sigmoid
                    = func1(scales_G1,
                            acc_to_float(scratch_gates(i, 1, j), 1, j)
                                    + bias(1, j));
            /* TODO: Can be optimized for fwd_training by using ws_gates
               instead of scratch_gates in p2 */
            scratch_gates(i, 0, j) = reinterpret_as_acc(G0);
            const auto t = to_src(src_to_float(src_iter(i, j)) * G1);
            if (dst_layer_) dst_layer(i, j) = t;
            if (dst_iter_) dst_iter(i, j) = t;

            if (rnn.is_training) {
                ws_gates(i, 0, j) = to_src(G0);
                ws_gates(i, 1, j) = to_src(G1);
            }
        }
    };

    if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
        for (int i = 0; i < rnn.m_block; i++)
            postgemm_call(i);
    } else {
        parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); });
    }
}

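// GRU forward postgemm, part 2: restore the update gate u (G0) stashed by
// part 1, compute the candidate state c (G2) with func1 (tanh by default)
// from the second GEMM output (which already saw r * h_{t-1}), and form the
// new state h_t = u * h_{t-1} + (1 - u) * c. For AUGRU, u is first scaled by
// (1 - attention).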
template <typename T1, typename T2, typename T3, typename T4, typename T5,
        typename src_data_t, typename scratch_data_t>
void gru_fwd_part2_postgemm_template(T1 func1, T2 to_src, T3 acc_to_float,
        T4 src_to_float, T5 reinterpret_as_float, const float *scales,
        const rnn_utils::rnn_conf_t &rnn,
        rnn_utils::cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, const src_data_t *augru_attention_,
        src_data_t *dst_layer_, src_data_t *dst_iter_,
        const src_data_t *src_iter_, const void *bias_, int block_step) {
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const scratch_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const auto bias_aoc = rnn_utils::make_raw_aoc(
            bias_, types::data_type_size(rnn.bias_dt), rnn.n_bias, rnn.dhc);
    const auto bias = [&](int gate_id, int dhc_id) {
        return to_float(bias_aoc(gate_id, dhc_id), rnn.bias_dt);
    };

    const auto dst_layer_ld = rnn.dst_layer_ld(cell_position);
    const auto dst_iter_ld = rnn.dst_iter_ld(cell_position);
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);
    const augru_attention_aoc<const src_data_t> augru_attention(
            rnn, augru_attention_);
    const ws_states_layer_aoc<src_data_t> dst_layer(
            rnn, dst_layer_, dst_layer_ld);
    const ws_states_iter_aoc<src_data_t> dst_iter(rnn, dst_iter_, dst_iter_ld);
    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);

    const float *scales_G2 = scales ? scales + 2 : nullptr;

    const auto postgemm_call = [&](int i) {
        const int n_elem = block_step;
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < n_elem; j++) {
            auto G0 = reinterpret_as_float(scratch_gates(i, 0, j));
            const auto G2 // default func1 is tanh
                    = func1(scales_G2,
                            acc_to_float(scratch_gates(i, 2, j), 2, j)
                                    + bias(2, j));

            if (rnn.is_augru) {
                const auto a = reinterpret_as_float(augru_attention(i));
                G0 = (1.0f - a) * G0;
            }

            const auto tmp = to_src(
                    src_to_float(src_iter(i, j)) * G0 + (1.0f - G0) * G2);
            if (dst_layer_ != nullptr) dst_layer(i, j) = tmp;
            if (dst_iter_ != nullptr) dst_iter(i, j) = tmp;

            if (rnn.is_training) { ws_gates(i, 2, j) = to_src(G2); }
        }
    };

    if (rnn.is_brgemm && !rnn.unfused_post_gemm) {
        for (int i = 0; i < rnn.m_block; i++)
            postgemm_call(i);
    } else {
        parallel_nd(rnn.mb, [&](dim_t i) { postgemm_call(i); });
    }
}

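// Data-type-specific specializations follow. When the testing-only attribute
// rnn_tparams_.test_mode_ is set, each gate activation is replaced by a
// per-gate linear scaling (linear_f); otherwise the regular sigmoid/tanh
// activations are used.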
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part1_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };

    const auto deq_id = [](float f, int i, int j) { return f; };
    const auto id = [](float f) { return f; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part1_postgemm_template(logistic_f, id, deq_id, id, id, scales,
                rnn, cell_position, ws_gates_, scratch_gates_, augru_attention_,
                dst_layer_, dst_iter_, src_iter_, bias_, block_step);
    else
        gru_fwd_part1_postgemm_template(linear_f, id, deq_id, id, id, scales,
                rnn, cell_position, ws_gates_, scratch_gates_, augru_attention_,
                dst_layer_, dst_iter_, src_iter_, bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_f32_t::gru_part2_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    const auto deq_id = [](float f, int i, int j) { return f; };
    const auto id = [](float f) { return f; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part2_postgemm_template(tanh_f, id, deq_id, id, id, scales, rnn,
                cell_position, ws_gates_, scratch_gates_, augru_attention_,
                dst_layer_, dst_iter_, src_iter_, bias_, block_step);
    else
        gru_fwd_part2_postgemm_template(linear_f, id, deq_id, id, id, scales,
                rnn, cell_position, ws_gates_, scratch_gates_, augru_attention_,
                dst_layer_, dst_iter_, src_iter_, bias_, block_step);
}

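// bf16 path: gates are computed in f32; values read as bf16 are up-converted
// (up_cvt_bf16_f32) and results are written back as bf16 (dn_cvt_f32_bf16).
// The GEMM accumulators are already f32, so dequantization is the identity.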
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part1_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };

    const auto dn_cvt_f32_bf16 = [](float f) { return bfloat16_t(f); };
    const auto up_cvt_bf16_f32 = [](bfloat16_t b) { return float(b); };
    const auto deq_id = [](float f, int i, int j) { return f; };
    const auto id = [](float f) { return f; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part1_postgemm_template(logistic_f, dn_cvt_f32_bf16, deq_id,
                up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
                scratch_gates_, augru_attention_, dst_layer_, dst_iter_,
                src_iter_, bias_, block_step);
    else
        gru_fwd_part1_postgemm_template(linear_f, dn_cvt_f32_bf16, deq_id,
                up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
                scratch_gates_, augru_attention_, dst_layer_, dst_iter_,
                src_iter_, bias_, block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_bf16_t::gru_part2_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    const auto dn_cvt_f32_bf16 = [](float f) { return bfloat16_t(f); };
    const auto up_cvt_bf16_f32 = [](bfloat16_t b) { return float(b); };
    const auto deq_id = [](float f, int i, int j) { return f; };
    const auto id = [](float f) { return f; };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part2_postgemm_template(tanh_f, dn_cvt_f32_bf16, deq_id,
                up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
                scratch_gates_, augru_attention_, dst_layer_, dst_iter_,
                src_iter_, bias_, block_step);
    else
        gru_fwd_part2_postgemm_template(linear_f, dn_cvt_f32_bf16, deq_id,
                up_cvt_bf16_f32, id, scales, rnn, cell_position, ws_gates_,
                scratch_gates_, augru_attention_, dst_layer_, dst_iter_,
                src_iter_, bias_, block_step);
}

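// u8 path: activations are computed in f32. The s32 GEMM accumulators are
// dequantized by 1 / (weights_scale * data_scale), using per-channel weights
// scales when the quantization mask is nonzero; u8 states are dequantized by
// subtracting data_shift and dividing by data_scale; results are re-quantized
// by saturating f * data_scale + data_shift to [0, 255]. G0 is carried from
// part 1 to part 2 by bit-casting the f32 value into the s32 scratch_gates
// storage.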
template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part1_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto logistic_f = [](const float *scale, float a) {
        return logistic_fwd<float>(a);
    };

    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;

    const auto quantize_f32_u8 = [&](float f) {
        float qf = f * data_scale + data_shift;
        qf = nstl::min(qf, 255.0f);
        qf = nstl::max(qf, 0.0f);
        return (dst_layer_t)mxcsr_cvt(qf);
    };

    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
        const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
                ? weights_scales_[0]
                : weights_scales_[gate * rnn.dhc + j];
        return saturate<float>(s) * (1.f / (wscale * data_scale));
    };

    const auto dequantize_u8_f32 = [&](src_iter_t s) {
        return (static_cast<float>(s) - data_shift) * (1.f / data_scale);
    };

    const auto reinterpret_f32_s32
            = [](float a) { return bit_cast<gemm_acc_t>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part1_postgemm_template(logistic_f, quantize_f32_u8,
                dequantize_s32_f32, dequantize_u8_f32, reinterpret_f32_s32,
                scales, rnn, cell_position, ws_gates_, scratch_gates_,
                augru_attention_, dst_layer_, dst_iter_, src_iter_, bias_,
                block_step);
    else
        gru_fwd_part1_postgemm_template(linear_f, quantize_f32_u8,
                dequantize_s32_f32, dequantize_u8_f32, reinterpret_f32_s32,
                scales, rnn, cell_position, ws_gates_, scratch_gates_,
                augru_attention_, dst_layer_, dst_iter_, src_iter_, bias_,
                block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_u8_t::gru_part2_postgemm) {
    const float *scales = pd_->attr()->rnn_tparams_.scales_;
    const auto linear_f
            = [](const float *scale, float a) { return *scale * a; };
    const auto tanh_f
            = [](const float *scale, float a) { return tanh_fwd<float>(a); };

    const float data_shift = pd_->attr()->rnn_data_qparams_.shift_;
    const float data_scale = pd_->attr()->rnn_data_qparams_.scale_;

    const auto quantize_f32_u8 = [&](float f) {
        float qf = f * data_scale + data_shift;
        qf = nstl::min(qf, 255.0f);
        qf = nstl::max(qf, 0.0f);
        return (dst_layer_t)mxcsr_cvt(qf);
    };

    const auto dequantize_s32_f32 = [&](gemm_acc_t s, int gate, int j) {
        const float wscale = pd_->attr()->rnn_weights_qparams_.mask_ == 0
                ? weights_scales_[0]
                : weights_scales_[gate * rnn.dhc + j];
        return saturate<float>(s) * (1.f / (wscale * data_scale));
    };

    const auto dequantize_u8_f32 = [&](src_iter_t s) {
        return (static_cast<float>(s) - data_shift) * (1.f / data_scale);
    };

    const auto reinterpret_s32_f32
            = [](gemm_acc_t a) { return bit_cast<float>(a); };

    if (!pd_->attr()->rnn_tparams_.test_mode_)
        gru_fwd_part2_postgemm_template(tanh_f, quantize_f32_u8,
                dequantize_s32_f32, dequantize_u8_f32, reinterpret_s32_f32,
                scales, rnn, cell_position, ws_gates_, scratch_gates_,
                augru_attention_, dst_layer_, dst_iter_, src_iter_, bias_,
                block_step);
    else
        gru_fwd_part2_postgemm_template(linear_f, quantize_f32_u8,
                dequantize_s32_f32, dequantize_u8_f32, reinterpret_s32_f32,
                scales, rnn, cell_position, ws_gates_, scratch_gates_,
                augru_attention_, dst_layer_, dst_iter_, src_iter_, bias_,
                block_step);
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part1_postgemm) {
    assert(!"GRU signed int8 is not supported");
}

template <>
rnn_postgemm_sig(rnn_postgemm_fwd_s8_t::gru_part2_postgemm) {
    assert(!"GRU signed int8 is not supported");
}

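// GRU backward postgemm, part 1: given dh_t (the sum of diff_dst_iter and
// diff_dst_layer) and the gates saved during the forward pass, compute dG0
// and dG2 into scratch_gates and the partial dh_{t-1} contribution through
// the update gate. For AUGRU, the attention gradient is also accumulated.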
template <typename T, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void gru_bwd_part1_postgemm_template(T to_src, const rnn_utils::rnn_conf_t &rnn,
        cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, const src_data_t *augru_attention_,
        src_data_t *dst_layer_, const src_data_t *src_iter_,
        acc_data_t *diff_src_iter_, acc_data_t *diff_dst_iter_,
        acc_data_t *diff_augru_attention_, acc_data_t *diff_dst_layer_) {
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);

    const augru_attention_aoc<const src_data_t> augru_attention(
            rnn, augru_attention_);
    const augru_attention_aoc<acc_data_t> diff_augru_attention(
            rnn, diff_augru_attention_);

    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter(
            rnn, diff_src_iter_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);

    // dG2^ = dh * (1 - G0) * (1 - G2^2)
    // dG0^ = dh * (ht-1 - G2) * u * (1 - G0)
    // dht-1 (part) = dh * G0
    parallel_nd(rnn.mb, [&](dim_t i) {
        acc_data_t diff_attention = 0.0f;
        PRAGMA_OMP_SIMD(reduction(+ : diff_attention))
        for (int j = 0; j < rnn.dhc; j++) {
            const float h = src_iter(i, j);
            const float dHt = diff_dst_iter(i, j) + diff_dst_layer(i, j);
            const float dG2 = (1.0f - ws_gates(i, 0, j)) * dHt
                    * one_m_square(ws_gates(i, 2, j));
            float dG0 = (h - ws_gates(i, 2, j)) * dHt
                    * x_m_square(ws_gates(i, 0, j));

            if (rnn.is_augru) {
                diff_attention -= dG0 * ws_gates(i, 0, j);
                dG0 *= 1.0f - augru_attention(i);
            }

            diff_src_iter(i, j) = dHt * ws_gates(i, 0, j);
            scratch_gates(i, 0, j) = to_src(dG0);
            scratch_gates(i, 2, j) = to_src(dG2);
        }
        if (rnn.is_augru) diff_augru_attention(i) = diff_attention;
    });
}

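// GRU backward postgemm, part 2: using d(hG1), read here from
// diff_src_layer_, compute dG1 into scratch_gates, add the reset-gate
// contribution to dh_{t-1}, and store h_{t-1} * G1 into scratch_cell_
// (required for the weight gradient dWh).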
template <typename T, typename src_data_t, typename acc_data_t,
        typename scratch_data_t>
void gru_bwd_part2_postgemm_template(T to_src, const rnn_utils::rnn_conf_t &rnn,
        cell_position_t cell_position, src_data_t *ws_gates_,
        scratch_data_t *scratch_gates_, src_data_t *dst_layer_,
        const src_data_t *src_iter_, acc_data_t *diff_src_layer_,
        acc_data_t *diff_src_iter_, acc_data_t *diff_dst_iter_,
        acc_data_t *diff_dst_layer_, scratch_data_t *scratch_cell_) {
    const auto src_iter_ld = rnn.src_iter_ld(cell_position);
    // auto dst_ld = rnn.dst_ld(cell_position);
    // ws_states_layer_aoc<src_data_t> dst_layer(rnn, dst_layer_, dst_ld);
    const ws_states_iter_aoc<const src_data_t> src_iter(
            rnn, src_iter_, src_iter_ld);
    const ws_gates_aoc<src_data_t> ws_gates(rnn, ws_gates_);
    const ws_gates_aoc<scratch_data_t> scratch_gates(rnn, scratch_gates_);
    const ws_diff_states_layer_aoc<acc_data_t> diff_dst_layer(
            rnn, diff_dst_layer_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_dst_iter(
            rnn, diff_dst_iter_);

    const ws_diff_states_layer_aoc<acc_data_t> dhG1(rnn, diff_src_layer_);
    const ws_diff_states_iter_aoc<acc_data_t> diff_src_iter(
            rnn, diff_src_iter_);
    const AOC<scratch_data_t, 2> hG1(
            scratch_cell_, rnn.ws_states_layer_nld, rnn.ws_states_layer_ld);

    // dG1^ = d(hG1) * h * G1 * (1 - G1)
    // dht-1 (part) += d(hG1) * G1
    // h * G1 (required for dWh)
    parallel_nd(rnn.mb, [&](dim_t i) {
        PRAGMA_OMP_SIMD()
        for (int j = 0; j < rnn.dhc; j++) {
            const float h = src_iter(i, j);
            const float G1 = ws_gates(i, 1, j);
            diff_src_iter(i, j) += dhG1(i, j) * G1;
            scratch_gates(i, 1, j) = to_src(dhG1(i, j) * h * x_m_square(G1));
            hG1(i, j) = to_src(G1 * h);
        }
    });
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part1_postgemm) {
    const auto to_src = [](float a) { return a; };

    gru_bwd_part1_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, augru_attention_, dst_layer_, src_iter_,
            diff_src_iter_, diff_dst_iter_, diff_augru_attention_,
            diff_dst_layer_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_f32_t::gru_part2_postgemm) {
    const auto to_src = [](float a) { return a; };

    gru_bwd_part2_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, dst_layer_, src_iter_, diff_src_layer_,
            diff_src_iter_, diff_dst_iter_, diff_dst_layer_, scratch_cell_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part1_postgemm) {
    const auto to_src = [](float a) { return bfloat16_t(a); };

    gru_bwd_part1_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, augru_attention_, dst_layer_, src_iter_,
            diff_src_iter_, diff_dst_iter_, diff_augru_attention_,
            diff_dst_layer_);
}

template <>
rnn_postgemm_sig(rnn_postgemm_bwd_bf16_t::gru_part2_postgemm) {
    const auto to_src = [](float a) { return bfloat16_t(a); };

    gru_bwd_part2_postgemm_template(to_src, rnn, cell_position, ws_gates_,
            scratch_gates_, dst_layer_, src_iter_, diff_src_layer_,
            diff_src_iter_, diff_dst_iter_, diff_dst_layer_, scratch_cell_);
}

#undef AOC
} // namespace cpu
} // namespace impl
} // namespace dnnl