1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "common/dnnl_thread.hpp"
18
19#include "cpu/x64/jit_uni_reduction.hpp"
20
21namespace dnnl {
22namespace impl {
23namespace cpu {
24namespace x64 {
25
26static cpu_isa_t get_supported_isa() {
27 if (mayiuse(avx512_core_fp16)) return avx512_core_fp16;
28 if (mayiuse(avx512_core_bf16)) return avx512_core_bf16;
29 if (mayiuse(avx512_core)) return avx512_core;
30 if (mayiuse(avx2)) return avx2;
31 if (mayiuse(avx)) return avx;
32 if (mayiuse(sse41)) return sse41;
33
34 return isa_undef;
35}
36
37static bool impl_supports_datatype(data_type_t data_type) {
38 switch (data_type) {
39 case data_type::bf16: return x64::mayiuse(x64::avx512_core);
40 case data_type::f16: return x64::mayiuse(x64::avx512_core_fp16);
41 case data_type::f32:
42 case data_type::s32:
43 case data_type::s8:
44 case data_type::u8: return true;
45 default: return false;
46 }
47}
48
49status_t jit_uni_reduction_t::pd_t::init(engine_t *engine) {
50 using namespace alg_kind;
51 using namespace data_type;
52 using namespace format_tag;
53 using sm = primitive_attr_t::skip_mask_t;
54
55 conf_.isa = get_supported_isa();
56
57 conf_.src_type = src_md()->data_type;
58 conf_.dst_type = dst_md()->data_type;
59 conf_.acc_type
60 = types::default_accum_data_type(conf_.src_type, conf_.dst_type);
61 conf_.src_dt_size = types::data_type_size(conf_.src_type);
62 conf_.dst_dt_size = types::data_type_size(conf_.dst_type);
63 conf_.acc_dt_size = types::data_type_size(conf_.acc_type);
64
65 const bool ok = impl_supports_datatype(conf_.src_type)
66 && impl_supports_datatype(conf_.dst_type)
67 && set_default_params() == status::success
68 && attr()->has_default_values(sm::post_ops)
69 && attr_.set_default_formats(dst_md(0)) == status::success;
70 if (!ok) return status::unimplemented;
71
72 const auto src_mdw = memory_desc_wrapper(src_md());
73 const auto dst_mdw = memory_desc_wrapper(dst_md());
74
75 const std::vector<injector::post_op_type> accepted_post_ops
76 = {injector::sum, injector::eltwise, injector::binary};
77 static constexpr bool sum_at_0_pos_only = false;
78 static constexpr bool sum_requires_scale_one = false;
79 static constexpr bool sum_requires_zp_zero = true;
80 const bcast_set_t accepted_broadcasts
81 = {broadcasting_strategy_t::scalar, broadcasting_strategy_t::per_oc,
82 broadcasting_strategy_t::per_oc_spatial,
83 broadcasting_strategy_t::no_broadcast};
84 injector::post_ops_ok_args_t post_ops_args(conf_.isa, accepted_post_ops,
85 attr()->post_ops_, &dst_mdw, sum_at_0_pos_only,
86 sum_requires_scale_one, sum_requires_zp_zero, accepted_broadcasts);
87 if (!post_ops_ok(post_ops_args)) return status::unimplemented;
88
89 conf_.post_ops = attr()->post_ops_;
90
91 static constexpr bool require_scale_one = false;
92 conf_.with_eltwise = conf_.with_binary = conf_.with_sum = false;
93 for (const auto &entry : conf_.post_ops.entry_) {
94 if (entry.is_eltwise()) {
95 conf_.with_eltwise = true;
96 } else if (entry.is_binary()) {
97 conf_.with_binary = true;
98 } else if (entry.is_sum(require_scale_one) && entry.sum.scale != 0.f) {
99 conf_.with_sum = true;
100 conf_.sum_scales.push(entry.sum.scale);
101 }
102 }
103 conf_.with_postops
104 = conf_.with_eltwise || conf_.with_binary || conf_.with_sum;
105
106 const format_tag_t src_md_desired_format = memory_desc_matches_one_of_tag(
107 *src_md(), x, nc, ncw, nchw, ncdhw);
108 const format_tag_t dst_md_desired_format = memory_desc_matches_one_of_tag(
109 *dst_md(), x, nc, ncw, nchw, ncdhw);
110 if (src_md_desired_format != dst_md_desired_format
111 || src_md_desired_format == format_tag::undef)
112 return status::unimplemented;
113
114 const int ndims = src_mdw.ndims();
115 const auto &src_dims = src_mdw.dims();
116 const auto &dst_dims = dst_mdw.dims();
117
118 conf_.is_saturation_needed = utils::one_of(conf_.dst_type, s32, s8, u8);
119
120 int num_of_reduced_dims = 0;
121 conf_.idle_size = dst_mdw.nelems();
122 conf_.reduce_size = 1;
123 for (int d = ndims - 1; d >= 0; --d) {
124 if (src_dims[d] != dst_dims[d]) {
125 num_of_reduced_dims++;
126 conf_.reduce_size *= src_dims[d];
127 } else
128 break;
129 }
130
131 if (num_of_reduced_dims == 0) return status::unimplemented;
132
133 for (int d = 0; d < ndims - num_of_reduced_dims; ++d)
134 if (src_dims[d] != dst_dims[d]) return status::unimplemented;
135
136 conf_.alg = desc()->alg_kind;
137 if (utils::one_of(conf_.alg, reduction_norm_lp_max, reduction_norm_lp_sum,
138 reduction_norm_lp_power_p_max, reduction_norm_lp_power_p_sum))
139 return status::unimplemented;
140
141 return status::success;
142}
143
144status_t jit_uni_reduction_t::init(engine_t *engine) {
145 using namespace format_tag;
146
147 const memory_desc_t *dst_md = pd()->dst_md();
148 const jit_reduction_conf_t &conf = pd()->get_conf();
149
150 CHECK(get_proper_kernel(dst_md, conf));
151 CHECK(kernel_->create_kernel());
152
153 return status::success;
154}
155
156status_t jit_uni_reduction_t::execute(const exec_ctx_t &ctx) const {
157 const auto src = CTX_IN_MEM(const uint8_t *, DNNL_ARG_SRC);
158 auto dst = CTX_OUT_MEM(uint8_t *, DNNL_ARG_DST);
159
160 const dim_t idle_size = pd()->get_conf().idle_size;
161 const dim_t reduce_size = pd()->get_conf().reduce_size;
162 const std::size_t src_dt_size = pd()->get_conf().src_dt_size;
163 const std::size_t dst_dt_size = pd()->get_conf().dst_dt_size;
164 const auto &post_ops = pd()->attr()->post_ops_;
165 const auto &post_ops_binary_rhs_arg_vec
166 = binary_injector::prepare_binary_args(post_ops, ctx);
167
168 parallel_nd(idle_size, [&](dim_t i) {
169 const dim_t src_off = i * reduce_size * src_dt_size;
170 const dim_t dst_off = i * dst_dt_size;
171
172 jit_reduction_call_s args = jit_reduction_call_s();
173 args.src = src + src_off;
174 args.dst = dst + dst_off;
175 args.dst_orig = dst;
176 args.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec.data();
177
178 (*kernel_)(&args);
179 });
180
181 return status::success;
182}
183
184status_t jit_uni_reduction_t::get_proper_kernel(
185 const memory_desc_t *dst_md, const jit_reduction_conf_t &conf) {
186 using namespace data_type;
187
188 if (conf.isa == avx512_core_fp16)
189 return safe_ptr_assign(kernel_,
190 new jit_uni_reduction_kernel_t<avx512_core_fp16>(conf, dst_md));
191 if (conf.isa == avx512_core_bf16)
192 return safe_ptr_assign(kernel_,
193 new jit_uni_reduction_kernel_t<avx512_core_bf16>(conf, dst_md));
194 else if (conf.isa == avx512_core)
195 return safe_ptr_assign(kernel_,
196 new jit_uni_reduction_kernel_t<avx512_core>(conf, dst_md));
197 else if (is_superset(conf.isa, avx)) {
198 const bool is_src_i8 = utils::one_of(conf.src_type, s8, u8);
199 const bool is_dst_i8 = utils::one_of(conf.dst_type, s8, u8);
200 if (conf.isa == avx2) {
201 if (is_src_i8 || is_dst_i8)
202 return safe_ptr_assign(kernel_,
203 new jit_uni_reduction_kernel_t<avx2, Xbyak::Xmm>(
204 conf, dst_md));
205 else
206 return safe_ptr_assign(kernel_,
207 new jit_uni_reduction_kernel_t<avx2>(conf, dst_md));
208 } else {
209 if (is_src_i8 || is_dst_i8)
210 return safe_ptr_assign(kernel_,
211 new jit_uni_reduction_kernel_t<avx, Xbyak::Xmm>(
212 conf, dst_md));
213 else
214 return safe_ptr_assign(kernel_,
215 new jit_uni_reduction_kernel_t<avx>(conf, dst_md));
216 }
217 } else if (conf.isa == sse41)
218 return safe_ptr_assign(
219 kernel_, new jit_uni_reduction_kernel_t<sse41>(conf, dst_md));
220 else
221 return status::runtime_error;
222}
223
224} // namespace x64
225} // namespace cpu
226} // namespace impl
227} // namespace dnnl
228