1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/dnnl_traits.hpp" |
20 | #include "common/math_utils.hpp" |
21 | #include "common/type_helpers.hpp" |
22 | |
23 | #include "cpu/cpu_primitive.hpp" |
24 | #include "cpu/ref_io_helper.hpp" |
25 | #include "cpu/simple_q10n.hpp" |
26 | |
27 | #include "cpu/ref_convolution_int8.hpp" |
28 | #include "cpu/ref_convolution_utils.hpp" |
29 | |
30 | namespace dnnl { |
31 | namespace impl { |
32 | namespace cpu { |
33 | |
34 | namespace { |
35 | void dequantize(float &d, dim_t g, dim_t C, dim_t c, const float *wei_scales, |
36 | bool with_groups, int wei_mask, const float *src_scales) { |
37 | // scale_idx_mult = 1 for per_channel scales and 0, otherwise |
38 | const int wei_scale_idx_mult = wei_mask == (1 << (int)with_groups); |
39 | float scale = 1.0f; |
40 | if (src_scales) scale *= src_scales[0]; |
41 | if (wei_scales) scale *= wei_scales[(g * C + c) * wei_scale_idx_mult]; |
42 | d *= scale; |
43 | } |
44 | |
45 | void quantize(float &d, dim_t g, dim_t C, dim_t c, const float *dst_scales) { |
46 | float scale = 1.0f; |
47 | if (dst_scales) scale *= dst_scales[0]; |
48 | // dst_scale is inverted in DEFINE_ARG_SCALES_BUFFER |
49 | d *= scale; |
50 | } |
51 | } // namespace |
52 | |
53 | status_t ref_convolution_int8_fwd_t::execute_forward( |
54 | const exec_ctx_t &ctx) const { |
55 | status_t status = status::success; |
56 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
57 | auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS); |
58 | auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
59 | auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status); |
60 | CHECK(status); |
61 | |
62 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
63 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
64 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
65 | |
66 | const int wei_scale_mask |
67 | = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
68 | |
69 | DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC); |
70 | DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST); |
71 | |
72 | const memory_desc_wrapper src_d(pd()->src_md()); |
73 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
74 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
75 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
76 | |
77 | const bool with_groups = pd()->with_groups(); |
78 | |
79 | const auto G = pd()->G(); |
80 | const auto MB = pd()->MB(); |
81 | const auto OD = pd()->OD(); |
82 | const auto OH = pd()->OH(); |
83 | const auto OW = pd()->OW(); |
84 | const auto ID = pd()->ID(); |
85 | const auto IH = pd()->IH(); |
86 | const auto IW = pd()->IW(); |
87 | |
88 | const auto OC = pd()->OC() / G; |
89 | const auto IC = pd()->IC() / G; |
90 | const auto KD = pd()->KD(); |
91 | const auto KH = pd()->KH(); |
92 | const auto KW = pd()->KW(); |
93 | |
94 | const auto KSD = pd()->KSD(); |
95 | const auto KSH = pd()->KSH(); |
96 | const auto KSW = pd()->KSW(); |
97 | |
98 | const auto KDD = pd()->KDD() + 1; |
99 | const auto KDH = pd()->KDH() + 1; |
100 | const auto KDW = pd()->KDW() + 1; |
101 | |
102 | const auto padFront = pd()->padFront(); |
103 | const auto padT = pd()->padT(); |
104 | const auto padL = pd()->padL(); |
105 | |
106 | const auto ndims = pd()->desc()->src_desc.ndims; |
107 | |
108 | // zp_idx_mult = 1 for per_dim1 zero points and 0, otherwise |
109 | const int src_zp_idx_mult |
110 | = !pd()->attr()->zero_points_.common(DNNL_ARG_SRC); |
111 | const int dst_zp_idx_mult |
112 | = !pd()->attr()->zero_points_.common(DNNL_ARG_DST); |
113 | |
114 | auto ker = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) { |
115 | int d = 0; |
116 | for_(dim_t ic = 0; ic < IC; ++ic) |
117 | for_(dim_t kd = 0; kd < KD; ++kd) |
118 | for_(dim_t kh = 0; kh < KH; ++kh) |
119 | for (dim_t kw = 0; kw < KW; ++kw) { |
120 | const dim_t id = od * KSD - padFront + kd * KDD; |
121 | const dim_t ih = oh * KSH - padT + kh * KDH; |
122 | const dim_t iw = ow * KSW - padL + kw * KDW; |
123 | |
124 | if (id < 0 || id >= ID) continue; |
125 | if (ih < 0 || ih >= IH) continue; |
126 | if (iw < 0 || iw >= IW) continue; |
127 | |
128 | const auto src_off = ref_conv_utils::get_data_off( |
129 | src_d, ndims, mb, g * IC + ic, id, ih, iw); |
130 | const auto wei_off = ref_conv_utils::get_weights_off( |
131 | weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw); |
132 | |
133 | const int s = io::load_int_value(src_d.data_type(), src, src_off); |
134 | const int src_zp = src_zero_point |
135 | ? io::load_int_value(data_type::s32, src_zero_point, |
136 | src_zp_idx_mult * (g * IC + ic)) |
137 | : 0; |
138 | const int w = io::load_int_value( |
139 | weights_d.data_type(), weights, wei_off); |
140 | d += (s - src_zp) * w; |
141 | } |
142 | return d; |
143 | }; |
144 | |
145 | // help compiler optimize the code constants for plain layouts kernel |
146 | const dims_t &src_str = src_d.blocking_desc().strides; |
147 | const dim_t src_ic_stride = src_str[1]; |
148 | const dim_t src_id_stride = (ndims == 5) ? src_str[2] : 0; |
149 | const dim_t src_ih_stride = (ndims >= 4) ? src_str[ndims - 2] : 0; |
150 | const dim_t src_iw_stride = (ndims >= 3) ? src_str[ndims - 1] : 0; |
151 | const dims_t &weights_str = weights_d.blocking_desc().strides; |
152 | const int gr_shift = with_groups ? 1 : 0; |
153 | const dim_t weights_ic_stride = weights_str[1 + gr_shift]; |
154 | const dim_t weights_kd_stride |
155 | = (ndims == 5) ? weights_str[2 + gr_shift] : 0; |
156 | const dim_t weights_kh_stride |
157 | = (ndims >= 4) ? weights_str[ndims - 2 + gr_shift] : 0; |
158 | const dim_t weights_kw_stride |
159 | = (ndims >= 3) ? weights_str[ndims - 1 + gr_shift] : 0; |
160 | |
161 | auto ker_plain = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, |
162 | dim_t ow) { |
163 | assert(3 <= ndims && ndims <= 5); |
164 | int d = 0; |
165 | |
166 | const dim_t src_loc_off = ref_conv_utils::get_data_off( |
167 | src_d, ndims, mb, g * IC, 0, 0, 0); |
168 | const dim_t weights_loc_off = ref_conv_utils::get_weights_off( |
169 | weights_d, with_groups, ndims, g, oc, 0, 0, 0, 0); |
170 | |
171 | const void *__restrict src_loc = src; |
172 | const void *__restrict weights_loc = weights; |
173 | |
174 | if (IC > KW) { |
175 | for_(dim_t kd = 0; kd < KD; ++kd) |
176 | for_(dim_t kh = 0; kh < KH; ++kh) |
177 | for (dim_t kw = 0; kw < KW; ++kw) { |
178 | const dim_t id = od * KSD - padFront + kd * KDD; |
179 | const dim_t ih = oh * KSH - padT + kh * KDH; |
180 | const dim_t iw = ow * KSW - padL + kw * KDW; |
181 | if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 |
182 | || iw >= IW) |
183 | continue; |
184 | |
185 | for (dim_t ic = 0; ic < IC; ++ic) { |
186 | const dim_t src_off = ic + id * src_id_stride |
187 | + ih * src_ih_stride + iw * src_iw_stride; |
188 | const dim_t weights_off = ic * weights_ic_stride |
189 | + kd * weights_kd_stride + kh * weights_kh_stride |
190 | + kw; |
191 | const int s = io::load_int_value( |
192 | src_d.data_type(), src_loc, src_off + src_loc_off); |
193 | const int src_zp = src_zero_point |
194 | ? io::load_int_value(data_type::s32, src_zero_point, |
195 | src_zp_idx_mult * (g * IC + ic)) |
196 | : 0; |
197 | const int w = io::load_int_value(weights_d.data_type(), |
198 | weights_loc, weights_off + weights_loc_off); |
199 | d += (s - src_zp) * w; |
200 | } |
201 | } |
202 | } else { |
203 | for_(dim_t ic = 0; ic < IC; ++ic) |
204 | for_(dim_t kd = 0; kd < KD; ++kd) |
205 | for_(dim_t kh = 0; kh < KH; ++kh) |
206 | for (dim_t kw = 0; kw < KW; ++kw) { |
207 | const dim_t id = od * KSD - padFront + kd * KDD; |
208 | const dim_t ih = oh * KSH - padT + kh * KDH; |
209 | const dim_t iw = ow * KSW - padL + kw * KDW; |
210 | if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0 |
211 | || iw >= IW) |
212 | continue; |
213 | |
214 | const dim_t src_off = ic + id * src_id_stride |
215 | + ih * src_ih_stride + iw * src_iw_stride; |
216 | const dim_t weights_off = ic * weights_ic_stride |
217 | + kd * weights_kd_stride + kh * weights_kh_stride + kw; |
218 | const int s = io::load_int_value( |
219 | src_d.data_type(), src_loc, src_off + src_loc_off); |
220 | const int src_zp = src_zero_point |
221 | ? io::load_int_value(data_type::s32, src_zero_point, |
222 | src_zp_idx_mult * (g * IC + ic)) |
223 | : 0; |
224 | const int w = io::load_int_value(weights_d.data_type(), |
225 | weights_loc, weights_off + weights_loc_off); |
226 | d += (s - src_zp) * w; |
227 | } |
228 | } |
229 | return d; |
230 | }; |
231 | |
232 | const auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type()); |
233 | |
234 | parallel_nd(G, MB, OC, OD, OH, OW, |
235 | [&](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) { |
236 | int acc = 0; |
237 | if (src_d.is_plain() && weights_d.is_plain() |
238 | && src_ic_stride == 1 && weights_kw_stride == 1) |
239 | acc += ker_plain(g, mb, oc, od, oh, ow); |
240 | else |
241 | acc += ker(g, mb, oc, od, oh, ow); |
242 | |
243 | float d = static_cast<float>(acc); |
244 | |
245 | dequantize(d, g, OC, oc, wei_scales, with_groups, |
246 | wei_scale_mask, src_scales); |
247 | |
248 | if (bias) { |
249 | const auto bias_off = bias_d.off(g * OC + oc); |
250 | const float b = io::load_float_value( |
251 | bias_d.data_type(), bias, bias_off); |
252 | d += b; |
253 | } |
254 | |
255 | dim_t dst_off = ref_conv_utils::get_data_off( |
256 | dst_d, ndims, mb, g * OC + oc, od, oh, ow); |
257 | |
258 | dim_t dst_l_off = (mb * OC * G + g * OC + oc) * OD * OH * OW |
259 | + od * OH * OW + oh * OW + ow; |
260 | |
261 | ref_post_ops_t::args_t args; |
262 | args.dst_val = io::load_float_value(sum_dt, dst, dst_off); |
263 | args.ctx = &ctx; |
264 | args.l_offset = dst_l_off; |
265 | args.dst_md = pd()->dst_md(); |
266 | ref_post_ops->execute(d, args); |
267 | |
268 | if (dst_scales) quantize(d, g, OC, oc, dst_scales); |
269 | |
270 | if (dst_zero_point) { |
271 | const int dst_zp = io::load_int_value(data_type::s32, |
272 | dst_zero_point, dst_zp_idx_mult * (g * OC + oc)); |
273 | d += dst_zp; |
274 | } |
275 | io::store_float_value(dst_d.data_type(), d, dst, dst_off); |
276 | }); |
277 | |
278 | return status::success; |
279 | } |
280 | |
281 | status_t ref_convolution_int8_bwd_data_t::execute_backward_data( |
282 | const exec_ctx_t &ctx) const { |
283 | status_t status = status::success; |
284 | auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST); |
285 | auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS); |
286 | auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status); |
287 | CHECK(status); |
288 | |
289 | DEFINE_ARG_SCALES_BUFFER(diff_src_scales, DNNL_ARG_SRC); |
290 | DEFINE_ARG_SCALES_BUFFER(diff_wei_scales, DNNL_ARG_WEIGHTS); |
291 | DEFINE_ARG_SCALES_BUFFER(diff_dst_scales, DNNL_ARG_DST); |
292 | |
293 | const int diff_wei_scale_mask |
294 | = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_; |
295 | |
296 | const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md()); |
297 | const memory_desc_wrapper diff_src_d(pd()->diff_src_md()); |
298 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
299 | |
300 | const bool with_groups = pd()->with_groups(); |
301 | |
302 | const auto G = pd()->G(); |
303 | const auto MB = pd()->MB(); |
304 | const auto OD = pd()->OD(); |
305 | const auto OH = pd()->OH(); |
306 | const auto OW = pd()->OW(); |
307 | const auto ID = pd()->ID(); |
308 | const auto IH = pd()->IH(); |
309 | const auto IW = pd()->IW(); |
310 | |
311 | const auto OC = pd()->OC() / G; |
312 | const auto IC = pd()->IC() / G; |
313 | const auto KD = pd()->KD(); |
314 | const auto KH = pd()->KH(); |
315 | const auto KW = pd()->KW(); |
316 | |
317 | const auto KSD = pd()->KSD(); |
318 | const auto KSH = pd()->KSH(); |
319 | const auto KSW = pd()->KSW(); |
320 | |
321 | const auto KDD = pd()->KDD() + 1; |
322 | const auto KDH = pd()->KDH() + 1; |
323 | const auto KDW = pd()->KDW() + 1; |
324 | |
325 | const auto padFront = pd()->padFront(); |
326 | const auto padT = pd()->padT(); |
327 | const auto padL = pd()->padL(); |
328 | |
329 | const auto ndims = pd()->desc()->diff_src_desc.ndims; |
330 | |
331 | auto ker = [=](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) { |
332 | int ds = 0; |
333 | for_(dim_t oc = 0; oc < OC; ++oc) |
334 | for_(dim_t kd = 0; kd < KD; ++kd) |
335 | for_(dim_t kh = 0; kh < KH; ++kh) |
336 | for (dim_t kw = 0; kw < KW; ++kw) { |
337 | if (iw + padL < kw * KDW || ih + padT < kh * KDH |
338 | || id + padFront < kd * KDD) |
339 | continue; |
340 | dim_t ow = iw - kw * KDW + padL; |
341 | dim_t oh = ih - kh * KDH + padT; |
342 | dim_t od = id - kd * KDD + padFront; |
343 | if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0) continue; |
344 | |
345 | ow /= KSW; |
346 | oh /= KSH; |
347 | od /= KSD; |
348 | |
349 | if (od < OD && oh < OH && ow < OW) { |
350 | const auto diff_dst_off = ref_conv_utils::get_data_off( |
351 | diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow); |
352 | const auto weights_off = ref_conv_utils::get_weights_off( |
353 | weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw); |
354 | const int dd = io::load_int_value( |
355 | diff_dst_d.data_type(), diff_dst, diff_dst_off); |
356 | const int w = io::load_int_value( |
357 | weights_d.data_type(), weights, weights_off); |
358 | ds += dd * w; |
359 | } |
360 | } |
361 | return ds; |
362 | }; |
363 | |
364 | // help compiler optimize the code constants for plain layouts kernel |
365 | const dims_t &diff_dst_str = diff_dst_d.blocking_desc().strides; |
366 | const dim_t diff_dst_oc_stride = diff_dst_str[1]; |
367 | const dim_t diff_dst_ow_stride = diff_dst_str[ndims - 1]; |
368 | const dim_t diff_dst_oh_stride = (ndims >= 4) ? diff_dst_str[ndims - 2] : 0; |
369 | const dim_t diff_dst_od_stride = (ndims >= 5) ? diff_dst_str[ndims - 3] : 0; |
370 | |
371 | const dims_t &weights_str = weights_d.blocking_desc().strides; |
372 | const int gr_shift = with_groups ? 1 : 0; |
373 | const dim_t weights_oc_stride = weights_str[0 + gr_shift]; |
374 | const dim_t weights_kw_stride = weights_str[ndims - 1 + gr_shift]; |
375 | const dim_t weights_kh_stride |
376 | = (ndims >= 4) ? weights_str[ndims - 2 + gr_shift] : 0; |
377 | const dim_t weights_kd_stride |
378 | = (ndims >= 5) ? weights_str[ndims - 3 + gr_shift] : 0; |
379 | |
380 | auto ker_plain = [=](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, |
381 | dim_t iw) { |
382 | assert(3 <= ndims && ndims <= 5); |
383 | int ds = 0; |
384 | const dim_t diff_dst_loc_off = ref_conv_utils::get_data_off( |
385 | diff_dst_d, ndims, mb, g * OC, 0, 0, 0); |
386 | const dim_t weights_loc_off = ref_conv_utils::get_weights_off( |
387 | weights_d, with_groups, ndims, g, 0, ic, 0, 0, 0); |
388 | |
389 | const void *__restrict diff_dst_loc = diff_dst; |
390 | const void *__restrict weights_loc = weights; |
391 | |
392 | if (OC > KW) { |
393 | for_(dim_t kd = 0; kd < KD; ++kd) |
394 | for_(dim_t kh = 0; kh < KH; ++kh) |
395 | for (dim_t kw = 0; kw < KW; ++kw) { |
396 | dim_t ow = iw - kw * KDW + padL; |
397 | dim_t oh = ih - kh * KDH + padT; |
398 | dim_t od = id - kd * KDD + padFront; |
399 | if (ow < 0 || oh < 0 || od < 0 || ow % KSW != 0 || oh % KSH != 0 |
400 | || od % KSD != 0) |
401 | continue; |
402 | ow /= KSW; |
403 | oh /= KSH; |
404 | od /= KSD; |
405 | if (od >= OD || oh >= OH || ow >= OW) continue; |
406 | for (dim_t oc = 0; oc < OC; ++oc) { |
407 | const dim_t diff_dst_off = oc + od * diff_dst_od_stride |
408 | + oh * diff_dst_oh_stride + ow * diff_dst_ow_stride; |
409 | const dim_t weights_off = oc * weights_oc_stride |
410 | + kd * weights_kd_stride + kh * weights_kh_stride |
411 | + kw; |
412 | const int dd = io::load_int_value(diff_dst_d.data_type(), |
413 | diff_dst_loc, diff_dst_off + diff_dst_loc_off); |
414 | const int w = io::load_int_value(weights_d.data_type(), |
415 | weights_loc, weights_off + weights_loc_off); |
416 | ds += dd * w; |
417 | } |
418 | } |
419 | } else { |
420 | for_(dim_t oc = 0; oc < OC; ++oc) |
421 | for_(dim_t kd = 0; kd < KD; ++kd) |
422 | for (dim_t kh = 0; kh < KH; ++kh) { |
423 | // Note: placing these 2 params outside the `kw-loop` because |
424 | // of a compiler-generated bug. Declaring 'od' as volatile |
425 | // fixes a recurring seg-fault. |
426 | const volatile dim_t od_ = id - kd * KDD + padFront; |
427 | const dim_t weights_off_ = oc * weights_oc_stride |
428 | + kd * weights_kd_stride + kh * weights_kh_stride; |
429 | for (dim_t kw = 0; kw < KW; ++kw) { |
430 | dim_t ow = iw - kw * KDW + padL; |
431 | dim_t oh = ih - kh * KDH + padT; |
432 | dim_t od = od_; |
433 | if (ow < 0 || oh < 0 || od < 0 || ow % KSW != 0 |
434 | || oh % KSH != 0 || od % KSD != 0) |
435 | continue; |
436 | ow /= KSW; |
437 | oh /= KSH; |
438 | od /= KSD; |
439 | if (od >= OD || oh >= OH || ow >= OW) continue; |
440 | const dim_t diff_dst_off = oc + od * diff_dst_od_stride |
441 | + oh * diff_dst_oh_stride + ow * diff_dst_ow_stride; |
442 | const dim_t weights_off = weights_off_ + kw; |
443 | const int dd = io::load_int_value(diff_dst_d.data_type(), |
444 | diff_dst_loc, diff_dst_off + diff_dst_loc_off); |
445 | const int w = io::load_int_value(weights_d.data_type(), |
446 | weights_loc, weights_off + weights_loc_off); |
447 | ds += dd * w; |
448 | } |
449 | } |
450 | } |
451 | return ds; |
452 | }; |
453 | |
454 | parallel_nd(G, MB, IC, ID, IH, IW, |
455 | [&](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) { |
456 | int acc = 0; |
457 | if (diff_dst_d.is_plain() && weights_d.is_plain() |
458 | && diff_dst_oc_stride == 1 && weights_kw_stride == 1) |
459 | acc += ker_plain(g, mb, ic, id, ih, iw); |
460 | else |
461 | acc += ker(g, mb, ic, id, ih, iw); |
462 | |
463 | float ds = static_cast<float>(acc); |
464 | dequantize(ds, g, IC, ic, diff_wei_scales, with_groups, |
465 | diff_wei_scale_mask, diff_dst_scales); |
466 | quantize(ds, g, IC, ic, diff_src_scales); |
467 | |
468 | const auto diff_src_off = ref_conv_utils::get_data_off( |
469 | diff_src_d, ndims, mb, g * IC + ic, id, ih, iw); |
470 | io::store_float_value( |
471 | diff_src_d.data_type(), ds, diff_src, diff_src_off); |
472 | }); |
473 | |
474 | return status::success; |
475 | } |
476 | |
477 | } // namespace cpu |
478 | } // namespace impl |
479 | } // namespace dnnl |
480 | |
481 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
482 | |