1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/c_types_map.hpp"
18#include "common/dnnl_thread.hpp"
19#include "common/type_helpers.hpp"
20
21#include "cpu/ref_io_helper.hpp"
22#include "cpu/simple_q10n.hpp"
23
24#include "cpu/ref_inner_product_int8.hpp"
25#include "cpu/ref_inner_product_utils.hpp"
26
27#include "cpu/cpu_primitive.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32
33status_t ref_inner_product_int8_fwd_t::execute_forward(
34 const exec_ctx_t &ctx) const {
35 status_t status = status::success;
36 auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC);
37 auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS);
38 auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS);
39 auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status);
40 CHECK(status);
41
42 const memory_desc_wrapper src_d(pd()->src_md());
43 const memory_desc_wrapper dst_d(pd()->dst_md());
44 const memory_desc_wrapper weights_d(pd()->weights_md(0));
45 const memory_desc_wrapper bias_d(pd()->weights_md(1));
46
47 const auto MB = pd()->MB();
48 const auto OC = pd()->OC();
49 const auto IC = pd()->IC();
50
51 const auto ndims = pd()->ndims();
52
53 auto ker = [=](dim_t mb, dim_t oc) {
54 int d = 0;
55 const dim_t KD = pd()->KD();
56 const dim_t KH = pd()->KH();
57 const dim_t KW = pd()->KW();
58 for_(dim_t ic = 0; ic < IC; ++ic)
59 for_(dim_t kd = 0; kd < KD; ++kd)
60 for_(dim_t kh = 0; kh < KH; ++kh)
61 for (dim_t kw = 0; kw < KW; ++kw) {
62 const auto src_off = ref_ip_utils::get_data_off(
63 src_d, ndims, mb, ic, kd, kh, kw);
64 const auto wei_off = ref_ip_utils::get_weights_off(
65 weights_d, ndims, oc, ic, kd, kh, kw);
66 const int s = io::load_int_value(src_d.data_type(), src, src_off);
67 const int w = io::load_int_value(
68 weights_d.data_type(), weights, wei_off);
69 d += s * w;
70 }
71 return d;
72 };
73
74 DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC);
75 DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS);
76 DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST);
77
78 const auto &attr_scales = pd()->attr()->scales_;
79 const bool with_dst_scales
80 = !attr_scales.get(DNNL_ARG_DST).has_default_values();
81
82 auto maybe_oscale = [=](float &d, dim_t oc) {
83 // scale_idx_mult = 1 for per_oc scales and 0, otherwise
84 const int scale_idx_mult
85 = attr_scales.get(DNNL_ARG_WEIGHTS).mask_ == (1 << 0);
86 d *= src_scales[0] * wei_scales[oc * scale_idx_mult];
87 };
88
89 parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) {
90 int acc = ker(mb, oc);
91
92 float d = acc;
93 maybe_oscale(d, oc);
94
95 if (bias) {
96 const auto bias_off = bias_d.off(oc);
97 const float b
98 = io::load_float_value(bias_d.data_type(), bias, bias_off);
99 d += b;
100 }
101
102 dim_t dst_off = dst_d.off(mb, oc);
103 dim_t dst_l_off = (mb * OC + oc);
104
105 ref_post_ops_t::args_t args;
106 args.dst_val = io::load_float_value(dst_d.data_type(), dst, dst_off);
107 args.ctx = &ctx;
108 args.l_offset = dst_l_off;
109 args.dst_md = pd()->dst_md();
110 ref_post_ops->execute(d, args);
111
112 if (with_dst_scales) d *= dst_scales[0];
113 io::store_float_value(dst_d.data_type(), d, dst, dst_off);
114 });
115
116 return status::success;
117}
118
119} // namespace cpu
120} // namespace impl
121} // namespace dnnl
122
123// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
124