1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include <algorithm> |
18 | #include <cstdlib> |
19 | #include <memory> |
20 | |
21 | #include "common/math_utils.hpp" |
22 | |
23 | #include "cpu/platform.hpp" |
24 | #include "cpu/primitive_attr_postops.hpp" |
25 | #include "cpu/ref_io_helper.hpp" |
26 | #include "cpu/simple_q10n.hpp" |
27 | |
28 | #if DNNL_X64 |
29 | #include "cpu/x64/jit_gemm_x8s8s32x_convolution_utils.hpp" |
30 | #endif |
31 | |
32 | #include "cpu/gemm_x8s8s32x_convolution_utils.hpp" |
33 | |
34 | namespace dnnl { |
35 | namespace impl { |
36 | namespace cpu { |
37 | namespace gemm_x8s8s32x_convolution_utils { |
38 | |
39 | template <typename dst_data_t> |
40 | struct ref_pp_ker_t : pp_ker_t { |
41 | ref_pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) |
42 | : pp_ker_t(pd, jcp) { |
43 | if (jcp.with_eltwise || jcp.with_binary) { |
44 | ref_post_ops_.reset(new ref_post_ops_t(jcp.post_ops)); |
45 | } |
46 | } |
47 | |
48 | using acc_data_t = pp_ker_t::acc_data_t; |
49 | |
50 | void operator()(void *dst, const acc_data_t *acc, const char *bias, |
51 | const float *scales, float dst_scale, float sum_scale, |
52 | float signed_scale, int g, size_t start, size_t end, |
53 | const zero_point_call_params_t &zp, |
54 | const void *post_ops_binary_rhs_arg_vec, const void *dst_orig, |
55 | const exec_ctx_t &ctx, const memory_desc_t &dst_md, |
56 | const single_gemm_conv_chunk_desc_t &chunk_desc) const override; |
57 | |
58 | private: |
59 | std::unique_ptr<ref_post_ops_t> ref_post_ops_; |
60 | }; |
61 | |
62 | template <typename dst_data_t> |
63 | void ref_pp_ker_t<dst_data_t>::operator()(void *void_dst, const acc_data_t *acc, |
64 | const char *bias, const float *scales, float dst_scale, float sum_scale, |
65 | float signed_scale, int g, size_t start, size_t end, |
66 | const zero_point_call_params_t &zp, |
67 | const void * /* post_ops_binary_rhs_arg_vec */, |
68 | const void * /* dst_orig */, const exec_ctx_t &ctx, |
69 | const memory_desc_t &dst_md, |
70 | const single_gemm_conv_chunk_desc_t &chunk_desc) const { |
71 | |
72 | if (end <= start) return; |
73 | |
74 | assert(data_traits<dst_data_t>::data_type == jcp_.dst_data_type); |
75 | |
76 | const lldiv_t dv_start = std::div((long long)start, (long long)jcp_.oc); |
77 | const lldiv_t dv_end = std::div((long long)(end - 1), (long long)jcp_.oc); |
78 | const size_t first_oc = dv_start.rem; |
79 | const size_t last_oc = dv_end.rem; |
80 | const size_t first_os = dv_start.quot; |
81 | const size_t last_os = dv_end.quot; |
82 | const int32_t zp_dst_val = jcp_.zp.dst_exists ? *(zp.dst) : 0; |
83 | |
84 | ref_post_ops_t::args_t args; |
85 | args.ctx = &ctx; |
86 | args.dst_md = &dst_md; |
87 | |
88 | for (size_t os = first_os; os <= last_os; os++) { |
89 | const size_t start_oc = (os == first_os) ? first_oc : 0; |
90 | const size_t end_oc = (os == last_os) ? last_oc : jcp_.oc - 1; |
91 | for (size_t oc = start_oc; oc <= end_oc; oc++) { |
92 | const size_t acc_off = os * jcp_.oc + oc; |
93 | const size_t dst_off = os * jcp_.dst_os_stride + oc; |
94 | |
95 | int32_t data_s32 = acc[acc_off]; |
96 | |
97 | if (jcp_.zp.src_exists) { |
98 | const auto oc_offset = g * jcp_.oc + oc; |
99 | data_s32 += zp.src_comp[oc_offset]; |
100 | } |
101 | |
102 | float data = static_cast<float>(data_s32); |
103 | |
104 | if (jcp_.signed_input) data *= signed_scale; |
105 | |
106 | // dequantize data |
107 | data *= scales[(g * jcp_.oc + oc) * jcp_.scale_idx_mult]; |
108 | |
109 | if (jcp_.with_bias) { |
110 | const float b = io::load_float_value( |
111 | jcp_.bias_data_type, bias, g * jcp_.oc + oc); |
112 | data += b; |
113 | } |
114 | |
115 | if (jcp_.with_sum) |
116 | data += sum_scale |
117 | * io::load_float_value( |
118 | jcp_.sum_data_type, void_dst, dst_off); |
119 | if (jcp_.with_eltwise || jcp_.with_binary) { |
120 | args.l_offset = (g * jcp_.oc + oc) * jcp_.os; |
121 | ref_post_ops_->execute(data, args); |
122 | } |
123 | |
124 | // quantize data |
125 | if (jcp_.with_dst_scale) data *= dst_scale; |
126 | if (jcp_.zp.dst_exists) data += zp_dst_val; |
127 | |
128 | io::store_float_value(jcp_.dst_data_type, data, void_dst, dst_off); |
129 | } |
130 | } |
131 | } |
132 | |
133 | // Interface section |
134 | |
135 | pp_ker_t::pp_ker_t(const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) |
136 | : jcp_(jcp) {} |
137 | |
138 | pp_ker_t *pp_ker_t::create( |
139 | const convolution_pd_t *pd, const conv_gemm_conf_t &jcp) { |
140 | #if DNNL_X64 |
141 | auto *res |
142 | = x64::gemm_x8s8s32x_convolution_utils::jit_pp_ker_create(pd, jcp); |
143 | if (res) return res; |
144 | #endif |
145 | switch (pd->dst_md()->data_type) { |
146 | case data_type::f32: return new ref_pp_ker_t<float>(pd, jcp); |
147 | case data_type::bf16: return new ref_pp_ker_t<bfloat16_t>(pd, jcp); |
148 | case data_type::s32: return new ref_pp_ker_t<int32_t>(pd, jcp); |
149 | case data_type::s8: return new ref_pp_ker_t<int8_t>(pd, jcp); |
150 | case data_type::u8: return new ref_pp_ker_t<uint8_t>(pd, jcp); |
151 | default: assert(!"unexpected data type" ); |
152 | } |
153 | return nullptr; |
154 | } |
155 | |
156 | bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_wrapper *dst_d) { |
157 | #if DNNL_X64 |
158 | return x64::gemm_x8s8s32x_convolution_utils::post_ops_ok(post_ops, dst_d); |
159 | #endif |
160 | return std::all_of(post_ops.entry_.cbegin(), post_ops.entry_.cend(), |
161 | [](const dnnl_post_ops::entry_t &post_op) { |
162 | return post_op.is_eltwise() || post_op.is_sum() |
163 | || post_op.is_binary(); |
164 | }); |
165 | } |
166 | |
167 | bool post_ops_ok(const post_ops_t &post_ops, const memory_desc_t *dst_d) { |
168 | const auto dst_md = memory_desc_wrapper(dst_d); |
169 | return post_ops_ok(post_ops, &dst_md); |
170 | } |
171 | |
172 | bool mayiuse_jit_pp_kernel(data_type_t dst_dt) noexcept { |
173 | #if DNNL_X64 |
174 | return x64::gemm_x8s8s32x_convolution_utils::mayiuse_jit_pp_kernel(dst_dt); |
175 | #else |
176 | return false; |
177 | #endif |
178 | } |
179 | |
180 | } // namespace gemm_x8s8s32x_convolution_utils |
181 | } // namespace cpu |
182 | } // namespace impl |
183 | } // namespace dnnl |
184 | |