1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <algorithm>
18#include <cstdlib>
19#include <memory>
20
21#include "common/math_utils.hpp"
22
23#include "cpu/platform.hpp"
24#include "cpu/primitive_attr_postops.hpp"
25#include "cpu/ref_io_helper.hpp"
26#include "cpu/simple_q10n.hpp"
27
28#if DNNL_X64
29#include "cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp"
30#endif
31
32#include "cpu/gemm_x8s8s32x_convolution_utils.hpp"
33
34namespace dnnl {
35namespace impl {
36namespace cpu {
37namespace gemm_x8s8s32x_convolution_utils {
38
39template <typename dst_data_t>
40struct ref_pp_ker_t : pp_ker_t {
41 ref_pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp)
42 : pp_ker_t(pd, jcp) {
43 if (jcp.with_eltwise || jcp.with_binary) {
44 ref_post_ops_.reset(new ref_post_ops_t(jcp.post_ops));
45 }
46 }
47
48 using acc_data_t = pp_ker_t::acc_data_t;
49
50 void operator()(void *dst, const acc_data_t *acc, const char *bias,
51 const float *scales, float dst_scale, float sum_scale,
52 float signed_scale, int g, size_t start, size_t end,
53 const zero_point_call_params_t &zp,
54 const void *post_ops_binary_rhs_arg_vec, const void *dst_orig,
55 const exec_ctx_t &ctx, const memory_desc_t &dst_md,
56 const single_gemm_conv_chunk_desc_t &chunk_desc) const override;
57
58private:
59 std::unique_ptr<ref_post_ops_t> ref_post_ops_;
60};
61
62template <typename dst_data_t>
63void ref_pp_ker_t<dst_data_t>::operator()(void *void_dst, const acc_data_t *acc,
64 const char *bias, const float *scales, float dst_scale, float sum_scale,
65 float signed_scale, int g, size_t start, size_t end,
66 const zero_point_call_params_t &zp,
67 const void * /* post_ops_binary_rhs_arg_vec */,
68 const void * /* dst_orig */, const exec_ctx_t &ctx,
69 const memory_desc_t &dst_md,
70 const single_gemm_conv_chunk_desc_t &chunk_desc) const {
71
72 if (end <= start) return;
73
74 assert(data_traits<dst_data_t>::data_type == jcp_.dst_data_type);
75
76 const lldiv_t dv_start = std::div((long long)start, (long long)jcp_.oc);
77 const lldiv_t dv_end = std::div((long long)(end - 1), (long long)jcp_.oc);
78 const size_t first_oc = dv_start.rem;
79 const size_t last_oc = dv_end.rem;
80 const size_t first_os = dv_start.quot;
81 const size_t last_os = dv_end.quot;
82 const int32_t zp_dst_val = jcp_.zp.dst_exists ? *(zp.dst) : 0;
83
84 ref_post_ops_t::args_t args;
85 args.ctx = &ctx;
86 args.dst_md = &dst_md;
87
88 for (size_t os = first_os; os <= last_os; os++) {
89 const size_t start_oc = (os == first_os) ? first_oc : 0;
90 const size_t end_oc = (os == last_os) ? last_oc : jcp_.oc - 1;
91 for (size_t oc = start_oc; oc <= end_oc; oc++) {
92 const size_t acc_off = os * jcp_.oc + oc;
93 const size_t dst_off = os * jcp_.dst_os_stride + oc;
94
95 int32_t data_s32 = acc[acc_off];
96
97 if (jcp_.zp.src_exists) {
98 const auto oc_offset = g * jcp_.oc + oc;
99 data_s32 += zp.src_comp[oc_offset];
100 }
101
102 float data = static_cast<float>(data_s32);
103
104 if (jcp_.signed_input) data *= signed_scale;
105
106 // dequantize data
107 data *= scales[(g * jcp_.oc + oc) * jcp_.scale_idx_mult];
108
109 if (jcp_.with_bias) {
110 const float b = io::load_float_value(
111 jcp_.bias_data_type, bias, g * jcp_.oc + oc);
112 data += b;
113 }
114
115 if (jcp_.with_sum)
116 data += sum_scale
117 * io::load_float_value(
118 jcp_.sum_data_type, void_dst, dst_off);
119 if (jcp_.with_eltwise || jcp_.with_binary) {
120 args.l_offset = (g * jcp_.oc + oc) * jcp_.os;
121 ref_post_ops_->execute(data, args);
122 }
123
124 // quantize data
125 if (jcp_.with_dst_scale) data *= dst_scale;
126 if (jcp_.zp.dst_exists) data += zp_dst_val;
127
128 io::store_float_value(jcp_.dst_data_type, data, void_dst, dst_off);
129 }
130 }
131}
132
133// Interface section
134
135pp_ker_t::pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp)
136 : jcp_(jcp) {}
137
138pp_ker_t *pp_ker_t::create(
139 const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) {
140#if DNNL_X64
141 auto *res
142 = x64::gemm_x8s8s32x_convolution_utils::jit_pp_ker_create(pd, jcp);
143 if (res) return res;
144#endif
145 switch (pd->dst_md()->data_type) {
146 case data_type::f32: return new ref_pp_ker_t<float>(pd, jcp);
147 case data_type::bf16: return new ref_pp_ker_t<bfloat16_t>(pd, jcp);
148 case data_type::s32: return new ref_pp_ker_t<int32_t>(pd, jcp);
149 case data_type::s8: return new ref_pp_ker_t<int8_t>(pd, jcp);
150 case data_type::u8: return new ref_pp_ker_t<uint8_t>(pd, jcp);
151 default: assert(!"unexpected data type");
152 }
153 return nullptr;
154}
155
156bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d) {
157#if DNNL_X64
158 return x64::gemm_x8s8s32x_convolution_utils::post_ops_ok(post_ops, dst_d);
159#endif
160 return std::all_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(),
161 [](const dnnl_post_ops::entry_t &post_op) {
162 return post_op.is_eltwise() || post_op.is_sum()
163 || post_op.is_binary();
164 });
165}
166
167bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d) {
168 const auto dst_md = memory_desc_wrapper(dst_d);
169 return post_ops_ok(post_ops, &dst_md);
170}
171
172bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept {
173#if DNNL_X64
174 return x64::gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(dst_dt);
175#else
176 return false;
177#endif
178}
179
180} // namespace gemm_x8s8s32x_convolution_utils
181} // namespace cpu
182} // namespace impl
183} // namespace dnnl
184