1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "common/dnnl_thread.hpp" |
18 | |
19 | #include "cpu/x64/jit_uni_reduction.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace cpu { |
24 | namespace x64 { |
25 | |
26 | static cpu_isa_t get_supported_isa() { |
27 | if (mayiuse(avx512_core_fp16)) return avx512_core_fp16; |
28 | if (mayiuse(avx512_core_bf16)) return avx512_core_bf16; |
29 | if (mayiuse(avx512_core)) return avx512_core; |
30 | if (mayiuse(avx2)) return avx2; |
31 | if (mayiuse(avx)) return avx; |
32 | if (mayiuse(sse41)) return sse41; |
33 | |
34 | return isa_undef; |
35 | } |
36 | |
37 | static bool impl_supports_datatype(data_type_t data_type) { |
38 | switch (data_type) { |
39 | case data_type::bf16: return x64::mayiuse(x64::avx512_core); |
40 | case data_type::f16: return x64::mayiuse(x64::avx512_core_fp16); |
41 | case data_type::f32: |
42 | case data_type::s32: |
43 | case data_type::s8: |
44 | case data_type::u8: return true; |
45 | default: return false; |
46 | } |
47 | } |
48 | |
49 | status_t jit_uni_reduction_t::pd_t::init(engine_t *engine) { |
50 | using namespace alg_kind; |
51 | using namespace data_type; |
52 | using namespace format_tag; |
53 | using sm = primitive_attr_t::skip_mask_t; |
54 | |
55 | conf_.isa = get_supported_isa(); |
56 | |
57 | conf_.src_type = src_md()->data_type; |
58 | conf_.dst_type = dst_md()->data_type; |
59 | conf_.acc_type |
60 | = types::default_accum_data_type(conf_.src_type, conf_.dst_type); |
61 | conf_.src_dt_size = types::data_type_size(conf_.src_type); |
62 | conf_.dst_dt_size = types::data_type_size(conf_.dst_type); |
63 | conf_.acc_dt_size = types::data_type_size(conf_.acc_type); |
64 | |
65 | const bool ok = impl_supports_datatype(conf_.src_type) |
66 | && impl_supports_datatype(conf_.dst_type) |
67 | && set_default_params() == status::success |
68 | && attr()->has_default_values(sm::post_ops) |
69 | && attr_.set_default_formats(dst_md(0)) == status::success; |
70 | if (!ok) return status::unimplemented; |
71 | |
72 | const auto src_mdw = memory_desc_wrapper(src_md()); |
73 | const auto dst_mdw = memory_desc_wrapper(dst_md()); |
74 | |
75 | const std::vector<injector::post_op_type> accepted_post_ops |
76 | = {injector::sum, injector::eltwise, injector::binary}; |
77 | static constexpr bool sum_at_0_pos_only = false; |
78 | static constexpr bool sum_requires_scale_one = false; |
79 | static constexpr bool sum_requires_zp_zero = true; |
80 | const bcast_set_t accepted_broadcasts |
81 | = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc, |
82 | broadcasting_strategy_t::per_oc_spatial, |
83 | broadcasting_strategy_t::no_broadcast}; |
84 | injector::post_ops_ok_args_t post_ops_args(conf_.isa, accepted_post_ops, |
85 | attr()->post_ops_, &dst_mdw, sum_at_0_pos_only, |
86 | sum_requires_scale_one, sum_requires_zp_zero, accepted_broadcasts); |
87 | if (!post_ops_ok(post_ops_args)) return status::unimplemented; |
88 | |
89 | conf_.post_ops = attr()->post_ops_; |
90 | |
91 | static constexpr bool require_scale_one = false; |
92 | conf_.with_eltwise = conf_.with_binary = conf_.with_sum = false; |
93 | for (const auto &entry : conf_.post_ops.entry_) { |
94 | if (entry.is_eltwise()) { |
95 | conf_.with_eltwise = true; |
96 | } else if (entry.is_binary()) { |
97 | conf_.with_binary = true; |
98 | } else if (entry.is_sum(require_scale_one) && entry.sum.scale != 0.f) { |
99 | conf_.with_sum = true; |
100 | conf_.sum_scales.push(entry.sum.scale); |
101 | } |
102 | } |
103 | conf_.with_postops |
104 | = conf_.with_eltwise || conf_.with_binary || conf_.with_sum; |
105 | |
106 | const format_tag_t src_md_desired_format = memory_desc_matches_one_of_tag( |
107 | *src_md(), x, nc, ncw, nchw, ncdhw); |
108 | const format_tag_t dst_md_desired_format = memory_desc_matches_one_of_tag( |
109 | *dst_md(), x, nc, ncw, nchw, ncdhw); |
110 | if (src_md_desired_format != dst_md_desired_format |
111 | || src_md_desired_format == format_tag::undef) |
112 | return status::unimplemented; |
113 | |
114 | const int ndims = src_mdw.ndims(); |
115 | const auto &src_dims = src_mdw.dims(); |
116 | const auto &dst_dims = dst_mdw.dims(); |
117 | |
118 | conf_.is_saturation_needed = utils::one_of(conf_.dst_type, s32, s8, u8); |
119 | |
120 | int num_of_reduced_dims = 0; |
121 | conf_.idle_size = dst_mdw.nelems(); |
122 | conf_.reduce_size = 1; |
123 | for (int d = ndims - 1; d >= 0; --d) { |
124 | if (src_dims[d] != dst_dims[d]) { |
125 | num_of_reduced_dims++; |
126 | conf_.reduce_size *= src_dims[d]; |
127 | } else |
128 | break; |
129 | } |
130 | |
131 | if (num_of_reduced_dims == 0) return status::unimplemented; |
132 | |
133 | for (int d = 0; d < ndims - num_of_reduced_dims; ++d) |
134 | if (src_dims[d] != dst_dims[d]) return status::unimplemented; |
135 | |
136 | conf_.alg = desc()->alg_kind; |
137 | if (utils::one_of(conf_.alg, reduction_norm_lp_max, reduction_norm_lp_sum, |
138 | reduction_norm_lp_power_p_max, reduction_norm_lp_power_p_sum)) |
139 | return status::unimplemented; |
140 | |
141 | return status::success; |
142 | } |
143 | |
144 | status_t jit_uni_reduction_t::init(engine_t *engine) { |
145 | using namespace format_tag; |
146 | |
147 | const memory_desc_t *dst_md = pd()->dst_md(); |
148 | const jit_reduction_conf_t &conf = pd()->get_conf(); |
149 | |
150 | CHECK(get_proper_kernel(dst_md, conf)); |
151 | CHECK(kernel_->create_kernel()); |
152 | |
153 | return status::success; |
154 | } |
155 | |
156 | status_t jit_uni_reduction_t::execute(const exec_ctx_t &ctx) const { |
157 | const auto src = CTX_IN_MEM(const uint8_t *, DNNL_ARG_SRC); |
158 | auto dst = CTX_OUT_MEM(uint8_t *, DNNL_ARG_DST); |
159 | |
160 | const dim_t idle_size = pd()->get_conf().idle_size; |
161 | const dim_t reduce_size = pd()->get_conf().reduce_size; |
162 | const std::size_t src_dt_size = pd()->get_conf().src_dt_size; |
163 | const std::size_t dst_dt_size = pd()->get_conf().dst_dt_size; |
164 | const auto &post_ops = pd()->attr()->post_ops_; |
165 | const auto &post_ops_binary_rhs_arg_vec |
166 | = binary_injector::prepare_binary_args(post_ops, ctx); |
167 | |
168 | parallel_nd(idle_size, [&](dim_t i) { |
169 | const dim_t src_off = i * reduce_size * src_dt_size; |
170 | const dim_t dst_off = i * dst_dt_size; |
171 | |
172 | jit_reduction_call_s args = jit_reduction_call_s(); |
173 | args.src = src + src_off; |
174 | args.dst = dst + dst_off; |
175 | args.dst_orig = dst; |
176 | args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data(); |
177 | |
178 | (*kernel_)(&args); |
179 | }); |
180 | |
181 | return status::success; |
182 | } |
183 | |
184 | status_t jit_uni_reduction_t::get_proper_kernel( |
185 | const memory_desc_t *dst_md, const jit_reduction_conf_t &conf) { |
186 | using namespace data_type; |
187 | |
188 | if (conf.isa == avx512_core_fp16) |
189 | return safe_ptr_assign(kernel_, |
190 | new jit_uni_reduction_kernel_t<avx512_core_fp16>(conf, dst_md)); |
191 | if (conf.isa == avx512_core_bf16) |
192 | return safe_ptr_assign(kernel_, |
193 | new jit_uni_reduction_kernel_t<avx512_core_bf16>(conf, dst_md)); |
194 | else if (conf.isa == avx512_core) |
195 | return safe_ptr_assign(kernel_, |
196 | new jit_uni_reduction_kernel_t<avx512_core>(conf, dst_md)); |
197 | else if (is_superset(conf.isa, avx)) { |
198 | const bool is_src_i8 = utils::one_of(conf.src_type, s8, u8); |
199 | const bool is_dst_i8 = utils::one_of(conf.dst_type, s8, u8); |
200 | if (conf.isa == avx2) { |
201 | if (is_src_i8 || is_dst_i8) |
202 | return safe_ptr_assign(kernel_, |
203 | new jit_uni_reduction_kernel_t<avx2, Xbyak::Xmm>( |
204 | conf, dst_md)); |
205 | else |
206 | return safe_ptr_assign(kernel_, |
207 | new jit_uni_reduction_kernel_t<avx2>(conf, dst_md)); |
208 | } else { |
209 | if (is_src_i8 || is_dst_i8) |
210 | return safe_ptr_assign(kernel_, |
211 | new jit_uni_reduction_kernel_t<avx, Xbyak::Xmm>( |
212 | conf, dst_md)); |
213 | else |
214 | return safe_ptr_assign(kernel_, |
215 | new jit_uni_reduction_kernel_t<avx>(conf, dst_md)); |
216 | } |
217 | } else if (conf.isa == sse41) |
218 | return safe_ptr_assign( |
219 | kernel_, new jit_uni_reduction_kernel_t<sse41>(conf, dst_md)); |
220 | else |
221 | return status::runtime_error; |
222 | } |
223 | |
224 | } // namespace x64 |
225 | } // namespace cpu |
226 | } // namespace impl |
227 | } // namespace dnnl |
228 | |