1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/dnnl_traits.hpp"
20#include "common/math_utils.hpp"
21#include "common/type_helpers.hpp"
22
23#include "cpu/cpu_primitive.hpp"
24#include "cpu/ref_io_helper.hpp"
25#include "cpu/simple_q10n.hpp"
26
27#include "cpu/ref_convolution_int8.hpp"
28#include "cpu/ref_convolution_utils.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33
34namespace {
35void dequantize(float &d, dim_t g, dim_t C, dim_t c, const float *wei_scales,
36 bool with_groups, int wei_mask, const float *src_scales) {
37 // scale_idx_mult = 1 for per_channel scales and 0, otherwise
38 const int wei_scale_idx_mult = wei_mask == (1 << (int)with_groups);
39 float scale = 1.0f;
40 if (src_scales) scale *= src_scales[0];
41 if (wei_scales) scale *= wei_scales[(g * C + c) * wei_scale_idx_mult];
42 d *= scale;
43}
44
45void quantize(float &d, dim_t g, dim_t C, dim_t c, const float *dst_scales) {
46 float scale = 1.0f;
47 if (dst_scales) scale *= dst_scales[0];
48 // dst_scale is inverted in DEFINE_ARG_SCALES_BUFFER
49 d *= scale;
50}
51} // namespace
52
53status_t ref_convolution_int8_fwd_t::execute_forward(
54 const exec_ctx_t &ctx) const {
55 status_t status = status::success;
56 auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
57 auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
58 auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
59 auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
60 CHECK(status);
61
62 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
63 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
64 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
65
66 const int wei_scale_mask
67 = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;
68
69 DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
70 DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
71
72 const memory_desc_wrapper src_d(pd()->src_md());
73 const memory_desc_wrapper dst_d(pd()->dst_md());
74 const memory_desc_wrapper weights_d(pd()->weights_md(0));
75 const memory_desc_wrapper bias_d(pd()->weights_md(1));
76
77 const bool with_groups = pd()->with_groups();
78
79 const auto G = pd()->G();
80 const auto MB = pd()->MB();
81 const auto OD = pd()->OD();
82 const auto OH = pd()->OH();
83 const auto OW = pd()->OW();
84 const auto ID = pd()->ID();
85 const auto IH = pd()->IH();
86 const auto IW = pd()->IW();
87
88 const auto OC = pd()->OC() / G;
89 const auto IC = pd()->IC() / G;
90 const auto KD = pd()->KD();
91 const auto KH = pd()->KH();
92 const auto KW = pd()->KW();
93
94 const auto KSD = pd()->KSD();
95 const auto KSH = pd()->KSH();
96 const auto KSW = pd()->KSW();
97
98 const auto KDD = pd()->KDD() + 1;
99 const auto KDH = pd()->KDH() + 1;
100 const auto KDW = pd()->KDW() + 1;
101
102 const auto padFront = pd()->padFront();
103 const auto padT = pd()->padT();
104 const auto padL = pd()->padL();
105
106 const auto ndims = pd()->desc()->src_desc.ndims;
107
108 // zp_idx_mult = 1 for per_dim1 zero points and 0, otherwise
109 const int src_zp_idx_mult
110 = !pd()->attr()->zero_points_.common(DNNL_ARG_SRC);
111 const int dst_zp_idx_mult
112 = !pd()->attr()->zero_points_.common(DNNL_ARG_DST);
113
114 auto ker = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
115 int d = 0;
116 for_(dim_t ic = 0; ic < IC; ++ic)
117 for_(dim_t kd = 0; kd < KD; ++kd)
118 for_(dim_t kh = 0; kh < KH; ++kh)
119 for (dim_t kw = 0; kw < KW; ++kw) {
120 const dim_t id = od * KSD - padFront + kd * KDD;
121 const dim_t ih = oh * KSH - padT + kh * KDH;
122 const dim_t iw = ow * KSW - padL + kw * KDW;
123
124 if (id < 0 || id >= ID) continue;
125 if (ih < 0 || ih >= IH) continue;
126 if (iw < 0 || iw >= IW) continue;
127
128 const auto src_off = ref_conv_utils::get_data_off(
129 src_d, ndims, mb, g * IC + ic, id, ih, iw);
130 const auto wei_off = ref_conv_utils::get_weights_off(
131 weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
132
133 const int s = io::load_int_value(src_d.data_type(), src, src_off);
134 const int src_zp = src_zero_point
135 ? io::load_int_value(data_type::s32, src_zero_point,
136 src_zp_idx_mult * (g * IC + ic))
137 : 0;
138 const int w = io::load_int_value(
139 weights_d.data_type(), weights, wei_off);
140 d += (s - src_zp) * w;
141 }
142 return d;
143 };
144
145 // help compiler optimize the code constants for plain layouts kernel
146 const dims_t &src_str = src_d.blocking_desc().strides;
147 const dim_t src_ic_stride = src_str[1];
148 const dim_t src_id_stride = (ndims == 5) ? src_str[2] : 0;
149 const dim_t src_ih_stride = (ndims >= 4) ? src_str[ndims - 2] : 0;
150 const dim_t src_iw_stride = (ndims >= 3) ? src_str[ndims - 1] : 0;
151 const dims_t &weights_str = weights_d.blocking_desc().strides;
152 const int gr_shift = with_groups ? 1 : 0;
153 const dim_t weights_ic_stride = weights_str[1 + gr_shift];
154 const dim_t weights_kd_stride
155 = (ndims == 5) ? weights_str[2 + gr_shift] : 0;
156 const dim_t weights_kh_stride
157 = (ndims >= 4) ? weights_str[ndims - 2 + gr_shift] : 0;
158 const dim_t weights_kw_stride
159 = (ndims >= 3) ? weights_str[ndims - 1 + gr_shift] : 0;
160
161 auto ker_plain = [=](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh,
162 dim_t ow) {
163 assert(3 <= ndims && ndims <= 5);
164 int d = 0;
165
166 const dim_t src_loc_off = ref_conv_utils::get_data_off(
167 src_d, ndims, mb, g * IC, 0, 0, 0);
168 const dim_t weights_loc_off = ref_conv_utils::get_weights_off(
169 weights_d, with_groups, ndims, g, oc, 0, 0, 0, 0);
170
171 const void *__restrict src_loc = src;
172 const void *__restrict weights_loc = weights;
173
174 if (IC > KW) {
175 for_(dim_t kd = 0; kd < KD; ++kd)
176 for_(dim_t kh = 0; kh < KH; ++kh)
177 for (dim_t kw = 0; kw < KW; ++kw) {
178 const dim_t id = od * KSD - padFront + kd * KDD;
179 const dim_t ih = oh * KSH - padT + kh * KDH;
180 const dim_t iw = ow * KSW - padL + kw * KDW;
181 if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0
182 || iw >= IW)
183 continue;
184
185 for (dim_t ic = 0; ic < IC; ++ic) {
186 const dim_t src_off = ic + id * src_id_stride
187 + ih * src_ih_stride + iw * src_iw_stride;
188 const dim_t weights_off = ic * weights_ic_stride
189 + kd * weights_kd_stride + kh * weights_kh_stride
190 + kw;
191 const int s = io::load_int_value(
192 src_d.data_type(), src_loc, src_off + src_loc_off);
193 const int src_zp = src_zero_point
194 ? io::load_int_value(data_type::s32, src_zero_point,
195 src_zp_idx_mult * (g * IC + ic))
196 : 0;
197 const int w = io::load_int_value(weights_d.data_type(),
198 weights_loc, weights_off + weights_loc_off);
199 d += (s - src_zp) * w;
200 }
201 }
202 } else {
203 for_(dim_t ic = 0; ic < IC; ++ic)
204 for_(dim_t kd = 0; kd < KD; ++kd)
205 for_(dim_t kh = 0; kh < KH; ++kh)
206 for (dim_t kw = 0; kw < KW; ++kw) {
207 const dim_t id = od * KSD - padFront + kd * KDD;
208 const dim_t ih = oh * KSH - padT + kh * KDH;
209 const dim_t iw = ow * KSW - padL + kw * KDW;
210 if (id < 0 || id >= ID || ih < 0 || ih >= IH || iw < 0
211 || iw >= IW)
212 continue;
213
214 const dim_t src_off = ic + id * src_id_stride
215 + ih * src_ih_stride + iw * src_iw_stride;
216 const dim_t weights_off = ic * weights_ic_stride
217 + kd * weights_kd_stride + kh * weights_kh_stride + kw;
218 const int s = io::load_int_value(
219 src_d.data_type(), src_loc, src_off + src_loc_off);
220 const int src_zp = src_zero_point
221 ? io::load_int_value(data_type::s32, src_zero_point,
222 src_zp_idx_mult * (g * IC + ic))
223 : 0;
224 const int w = io::load_int_value(weights_d.data_type(),
225 weights_loc, weights_off + weights_loc_off);
226 d += (s - src_zp) * w;
227 }
228 }
229 return d;
230 };
231
232 const auto sum_dt = pd()->attr()->post_ops_.get_sum_dt(dst_d.data_type());
233
234 parallel_nd(G, MB, OC, OD, OH, OW,
235 [&](dim_t g, dim_t mb, dim_t oc, dim_t od, dim_t oh, dim_t ow) {
236 int acc = 0;
237 if (src_d.is_plain() && weights_d.is_plain()
238 && src_ic_stride == 1 && weights_kw_stride == 1)
239 acc += ker_plain(g, mb, oc, od, oh, ow);
240 else
241 acc += ker(g, mb, oc, od, oh, ow);
242
243 float d = static_cast<float>(acc);
244
245 dequantize(d, g, OC, oc, wei_scales, with_groups,
246 wei_scale_mask, src_scales);
247
248 if (bias) {
249 const auto bias_off = bias_d.off(g * OC + oc);
250 const float b = io::load_float_value(
251 bias_d.data_type(), bias, bias_off);
252 d += b;
253 }
254
255 dim_t dst_off = ref_conv_utils::get_data_off(
256 dst_d, ndims, mb, g * OC + oc, od, oh, ow);
257
258 dim_t dst_l_off = (mb * OC * G + g * OC + oc) * OD * OH * OW
259 + od * OH * OW + oh * OW + ow;
260
261 ref_post_ops_t::args_t args;
262 args.dst_val = io::load_float_value(sum_dt, dst, dst_off);
263 args.ctx = &ctx;
264 args.l_offset = dst_l_off;
265 args.dst_md = pd()->dst_md();
266 ref_post_ops->execute(d, args);
267
268 if (dst_scales) quantize(d, g, OC, oc, dst_scales);
269
270 if (dst_zero_point) {
271 const int dst_zp = io::load_int_value(data_type::s32,
272 dst_zero_point, dst_zp_idx_mult * (g * OC + oc));
273 d += dst_zp;
274 }
275 io::store_float_value(dst_d.data_type(), d, dst, dst_off);
276 });
277
278 return status::success;
279}
280
281status_t ref_convolution_int8_bwd_data_t::execute_backward_data(
282 const exec_ctx_t &ctx) const {
283 status_t status = status::success;
284 auto diff_dst = CTX_IN_MEM(const void *, DNNL_ARG_DIFF_DST);
285 auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
286 auto diff_src = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DIFF_SRC, status);
287 CHECK(status);
288
289 DEFINE_ARG_SCALES_BUFFER(diff_src_scales, DNNL_ARG_SRC);
290 DEFINE_ARG_SCALES_BUFFER(diff_wei_scales, DNNL_ARG_WEIGHTS);
291 DEFINE_ARG_SCALES_BUFFER(diff_dst_scales, DNNL_ARG_DST);
292
293 const int diff_wei_scale_mask
294 = pd()->attr()->scales_.get(DNNL_ARG_WEIGHTS).mask_;
295
296 const memory_desc_wrapper diff_dst_d(pd()->diff_dst_md());
297 const memory_desc_wrapper diff_src_d(pd()->diff_src_md());
298 const memory_desc_wrapper weights_d(pd()->weights_md(0));
299
300 const bool with_groups = pd()->with_groups();
301
302 const auto G = pd()->G();
303 const auto MB = pd()->MB();
304 const auto OD = pd()->OD();
305 const auto OH = pd()->OH();
306 const auto OW = pd()->OW();
307 const auto ID = pd()->ID();
308 const auto IH = pd()->IH();
309 const auto IW = pd()->IW();
310
311 const auto OC = pd()->OC() / G;
312 const auto IC = pd()->IC() / G;
313 const auto KD = pd()->KD();
314 const auto KH = pd()->KH();
315 const auto KW = pd()->KW();
316
317 const auto KSD = pd()->KSD();
318 const auto KSH = pd()->KSH();
319 const auto KSW = pd()->KSW();
320
321 const auto KDD = pd()->KDD() + 1;
322 const auto KDH = pd()->KDH() + 1;
323 const auto KDW = pd()->KDW() + 1;
324
325 const auto padFront = pd()->padFront();
326 const auto padT = pd()->padT();
327 const auto padL = pd()->padL();
328
329 const auto ndims = pd()->desc()->diff_src_desc.ndims;
330
331 auto ker = [=](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) {
332 int ds = 0;
333 for_(dim_t oc = 0; oc < OC; ++oc)
334 for_(dim_t kd = 0; kd < KD; ++kd)
335 for_(dim_t kh = 0; kh < KH; ++kh)
336 for (dim_t kw = 0; kw < KW; ++kw) {
337 if (iw + padL < kw * KDW || ih + padT < kh * KDH
338 || id + padFront < kd * KDD)
339 continue;
340 dim_t ow = iw - kw * KDW + padL;
341 dim_t oh = ih - kh * KDH + padT;
342 dim_t od = id - kd * KDD + padFront;
343 if (ow % KSW != 0 || oh % KSH != 0 || od % KSD != 0) continue;
344
345 ow /= KSW;
346 oh /= KSH;
347 od /= KSD;
348
349 if (od < OD && oh < OH && ow < OW) {
350 const auto diff_dst_off = ref_conv_utils::get_data_off(
351 diff_dst_d, ndims, mb, g * OC + oc, od, oh, ow);
352 const auto weights_off = ref_conv_utils::get_weights_off(
353 weights_d, with_groups, ndims, g, oc, ic, kd, kh, kw);
354 const int dd = io::load_int_value(
355 diff_dst_d.data_type(), diff_dst, diff_dst_off);
356 const int w = io::load_int_value(
357 weights_d.data_type(), weights, weights_off);
358 ds += dd * w;
359 }
360 }
361 return ds;
362 };
363
364 // help compiler optimize the code constants for plain layouts kernel
365 const dims_t &diff_dst_str = diff_dst_d.blocking_desc().strides;
366 const dim_t diff_dst_oc_stride = diff_dst_str[1];
367 const dim_t diff_dst_ow_stride = diff_dst_str[ndims - 1];
368 const dim_t diff_dst_oh_stride = (ndims >= 4) ? diff_dst_str[ndims - 2] : 0;
369 const dim_t diff_dst_od_stride = (ndims >= 5) ? diff_dst_str[ndims - 3] : 0;
370
371 const dims_t &weights_str = weights_d.blocking_desc().strides;
372 const int gr_shift = with_groups ? 1 : 0;
373 const dim_t weights_oc_stride = weights_str[0 + gr_shift];
374 const dim_t weights_kw_stride = weights_str[ndims - 1 + gr_shift];
375 const dim_t weights_kh_stride
376 = (ndims >= 4) ? weights_str[ndims - 2 + gr_shift] : 0;
377 const dim_t weights_kd_stride
378 = (ndims >= 5) ? weights_str[ndims - 3 + gr_shift] : 0;
379
380 auto ker_plain = [=](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih,
381 dim_t iw) {
382 assert(3 <= ndims && ndims <= 5);
383 int ds = 0;
384 const dim_t diff_dst_loc_off = ref_conv_utils::get_data_off(
385 diff_dst_d, ndims, mb, g * OC, 0, 0, 0);
386 const dim_t weights_loc_off = ref_conv_utils::get_weights_off(
387 weights_d, with_groups, ndims, g, 0, ic, 0, 0, 0);
388
389 const void *__restrict diff_dst_loc = diff_dst;
390 const void *__restrict weights_loc = weights;
391
392 if (OC > KW) {
393 for_(dim_t kd = 0; kd < KD; ++kd)
394 for_(dim_t kh = 0; kh < KH; ++kh)
395 for (dim_t kw = 0; kw < KW; ++kw) {
396 dim_t ow = iw - kw * KDW + padL;
397 dim_t oh = ih - kh * KDH + padT;
398 dim_t od = id - kd * KDD + padFront;
399 if (ow < 0 || oh < 0 || od < 0 || ow % KSW != 0 || oh % KSH != 0
400 || od % KSD != 0)
401 continue;
402 ow /= KSW;
403 oh /= KSH;
404 od /= KSD;
405 if (od >= OD || oh >= OH || ow >= OW) continue;
406 for (dim_t oc = 0; oc < OC; ++oc) {
407 const dim_t diff_dst_off = oc + od * diff_dst_od_stride
408 + oh * diff_dst_oh_stride + ow * diff_dst_ow_stride;
409 const dim_t weights_off = oc * weights_oc_stride
410 + kd * weights_kd_stride + kh * weights_kh_stride
411 + kw;
412 const int dd = io::load_int_value(diff_dst_d.data_type(),
413 diff_dst_loc, diff_dst_off + diff_dst_loc_off);
414 const int w = io::load_int_value(weights_d.data_type(),
415 weights_loc, weights_off + weights_loc_off);
416 ds += dd * w;
417 }
418 }
419 } else {
420 for_(dim_t oc = 0; oc < OC; ++oc)
421 for_(dim_t kd = 0; kd < KD; ++kd)
422 for (dim_t kh = 0; kh < KH; ++kh) {
423 // Note: placing these 2 params outside the `kw-loop` because
424 // of a compiler-generated bug. Declaring 'od' as volatile
425 // fixes a recurring seg-fault.
426 const volatile dim_t od_ = id - kd * KDD + padFront;
427 const dim_t weights_off_ = oc * weights_oc_stride
428 + kd * weights_kd_stride + kh * weights_kh_stride;
429 for (dim_t kw = 0; kw < KW; ++kw) {
430 dim_t ow = iw - kw * KDW + padL;
431 dim_t oh = ih - kh * KDH + padT;
432 dim_t od = od_;
433 if (ow < 0 || oh < 0 || od < 0 || ow % KSW != 0
434 || oh % KSH != 0 || od % KSD != 0)
435 continue;
436 ow /= KSW;
437 oh /= KSH;
438 od /= KSD;
439 if (od >= OD || oh >= OH || ow >= OW) continue;
440 const dim_t diff_dst_off = oc + od * diff_dst_od_stride
441 + oh * diff_dst_oh_stride + ow * diff_dst_ow_stride;
442 const dim_t weights_off = weights_off_ + kw;
443 const int dd = io::load_int_value(diff_dst_d.data_type(),
444 diff_dst_loc, diff_dst_off + diff_dst_loc_off);
445 const int w = io::load_int_value(weights_d.data_type(),
446 weights_loc, weights_off + weights_loc_off);
447 ds += dd * w;
448 }
449 }
450 }
451 return ds;
452 };
453
454 parallel_nd(G, MB, IC, ID, IH, IW,
455 [&](dim_t g, dim_t mb, dim_t ic, dim_t id, dim_t ih, dim_t iw) {
456 int acc = 0;
457 if (diff_dst_d.is_plain() && weights_d.is_plain()
458 && diff_dst_oc_stride == 1 && weights_kw_stride == 1)
459 acc += ker_plain(g, mb, ic, id, ih, iw);
460 else
461 acc += ker(g, mb, ic, id, ih, iw);
462
463 float ds = static_cast<float>(acc);
464 dequantize(ds, g, IC, ic, diff_wei_scales, with_groups,
465 diff_wei_scale_mask, diff_dst_scales);
466 quantize(ds, g, IC, ic, diff_src_scales);
467
468 const auto diff_src_off = ref_conv_utils::get_data_off(
469 diff_src_d, ndims, mb, g * IC + ic, id, ih, iw);
470 io::store_float_value(
471 diff_src_d.data_type(), ds, diff_src, diff_src_off);
472 });
473
474 return status::success;
475}
476
477} // namespace cpu
478} // namespace impl
479} // namespace dnnl
480
481// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
482