1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/c_types_map.hpp" |
18 | #include "common/dnnl_thread.hpp" |
19 | #include "common/type_helpers.hpp" |
20 | |
21 | #include "cpu/ref_io_helper.hpp" |
22 | #include "cpu/simple_q10n.hpp" |
23 | |
24 | #include "cpu/ref_inner_product_int8.hpp" |
25 | #include "cpu/ref_inner_product_utils.hpp" |
26 | |
27 | #include "cpu/cpu_primitive.hpp" |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | namespace cpu { |
32 | |
33 | status_t ref_inner_product_int8_fwd_t::execute_forward( |
34 | const exec_ctx_t &ctx) const { |
35 | status_t status = status::success; |
36 | auto src = CTX_IN_MEM(const void *, DNNL_ARG_SRC); |
37 | auto weights = CTX_IN_MEM(const void *, DNNL_ARG_WEIGHTS); |
38 | auto bias = CTX_IN_MEM(const void *, DNNL_ARG_BIAS); |
39 | auto dst = CTX_OUT_CLEAN_MEM(void *, DNNL_ARG_DST, status); |
40 | CHECK(status); |
41 | |
42 | const memory_desc_wrapper src_d(pd()->src_md()); |
43 | const memory_desc_wrapper dst_d(pd()->dst_md()); |
44 | const memory_desc_wrapper weights_d(pd()->weights_md(0)); |
45 | const memory_desc_wrapper bias_d(pd()->weights_md(1)); |
46 | |
47 | const auto MB = pd()->MB(); |
48 | const auto OC = pd()->OC(); |
49 | const auto IC = pd()->IC(); |
50 | |
51 | const auto ndims = pd()->ndims(); |
52 | |
53 | auto ker = [=](dim_t mb, dim_t oc) { |
54 | int d = 0; |
55 | const dim_t KD = pd()->KD(); |
56 | const dim_t KH = pd()->KH(); |
57 | const dim_t KW = pd()->KW(); |
58 | for_(dim_t ic = 0; ic < IC; ++ic) |
59 | for_(dim_t kd = 0; kd < KD; ++kd) |
60 | for_(dim_t kh = 0; kh < KH; ++kh) |
61 | for (dim_t kw = 0; kw < KW; ++kw) { |
62 | const auto src_off = ref_ip_utils::get_data_off( |
63 | src_d, ndims, mb, ic, kd, kh, kw); |
64 | const auto wei_off = ref_ip_utils::get_weights_off( |
65 | weights_d, ndims, oc, ic, kd, kh, kw); |
66 | const int s = io::load_int_value(src_d.data_type(), src, src_off); |
67 | const int w = io::load_int_value( |
68 | weights_d.data_type(), weights, wei_off); |
69 | d += s * w; |
70 | } |
71 | return d; |
72 | }; |
73 | |
74 | DEFINE_ARG_SCALES_BUFFER(src_scales, DNNL_ARG_SRC); |
75 | DEFINE_ARG_SCALES_BUFFER(wei_scales, DNNL_ARG_WEIGHTS); |
76 | DEFINE_ARG_SCALES_BUFFER(dst_scales, DNNL_ARG_DST); |
77 | |
78 | const auto &attr_scales = pd()->attr()->scales_; |
79 | const bool with_dst_scales |
80 | = !attr_scales.get(DNNL_ARG_DST).has_default_values(); |
81 | |
82 | auto maybe_oscale = [=](float &d, dim_t oc) { |
83 | // scale_idx_mult = 1 for per_oc scales and 0, otherwise |
84 | const int scale_idx_mult |
85 | = attr_scales.get(DNNL_ARG_WEIGHTS).mask_ == (1 << 0); |
86 | d *= src_scales[0] * wei_scales[oc * scale_idx_mult]; |
87 | }; |
88 | |
89 | parallel_nd(MB, OC, [&](dim_t mb, dim_t oc) { |
90 | int acc = ker(mb, oc); |
91 | |
92 | float d = acc; |
93 | maybe_oscale(d, oc); |
94 | |
95 | if (bias) { |
96 | const auto bias_off = bias_d.off(oc); |
97 | const float b |
98 | = io::load_float_value(bias_d.data_type(), bias, bias_off); |
99 | d += b; |
100 | } |
101 | |
102 | dim_t dst_off = dst_d.off(mb, oc); |
103 | dim_t dst_l_off = (mb * OC + oc); |
104 | |
105 | ref_post_ops_t::args_t args; |
106 | args.dst_val = io::load_float_value(dst_d.data_type(), dst, dst_off); |
107 | args.ctx = &ctx; |
108 | args.l_offset = dst_l_off; |
109 | args.dst_md = pd()->dst_md(); |
110 | ref_post_ops->execute(d, args); |
111 | |
112 | if (with_dst_scales) d *= dst_scales[0]; |
113 | io::store_float_value(dst_d.data_type(), d, dst, dst_off); |
114 | }); |
115 | |
116 | return status::success; |
117 | } |
118 | |
119 | } // namespace cpu |
120 | } // namespace impl |
121 | } // namespace dnnl |
122 | |
123 | // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s |
124 | |