1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include <algorithm>
18#include <cmath>
19
20#include "common/dnnl_thread.hpp"
21#include "common/memory_desc_wrapper.hpp"
22#include "common/type_helpers.hpp"
23#include "cpu/x64/cpu_isa_traits.hpp"
24#include "cpu/x64/prelu/jit_prelu_forward.hpp"
25#include "cpu/x64/prelu/jit_prelu_utils.hpp"
26#include "cpu/x64/prelu/jit_uni_prelu_forward_kernel.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
33status_t jit_prelu_fwd_t::pd_t::init(engine_t *engine) {
34 const memory_desc_wrapper src_d {src_md(0)};
35 const memory_desc_wrapper weights_d {weights_md(0)};
36 const memory_desc_wrapper dst_d {dst_md(0)};
37
38 const bool ok = is_fwd()
39 && prelu::dt_supported({src_d.data_type(), weights_d.data_type(),
40 dst_d.data_type()})
41 && set_default_formats() && bcast_supported(src_d, weights_d, dst_d)
42 && !has_zero_dim_memory() && src_d.is_dense(true)
43 && weights_d.is_dense(true) && attr()->has_default_values()
44 && utils::one_of(prelu::get_supported_isa(), avx512_core_fp16,
45 avx512_core_bf16, avx512_core, avx2, avx, sse41)
46 && dst_d == src_d;
47
48 return ok ? status::success : status::unimplemented;
49}
50
51bool jit_prelu_fwd_t::pd_t::bcast_supported(const memory_desc_wrapper &src_d,
52 const memory_desc_wrapper &weights_d,
53 const memory_desc_wrapper &dst_d) const {
54
55 const auto bcast = prelu::get_bcast_type(src_d, weights_d);
56 if (bcast == prelu::bcast::full)
57 return true;
58 else if (bcast == prelu::bcast::unsupported)
59 return false;
60 else if (bcast == prelu::bcast::per_oc_blocked) {
61 const int simd_w = prelu::get_simd_w(
62 {src_d.data_type(), weights_d.data_type(), dst_d.data_type()});
63
64 const auto check_block_consistency
65 = [&](const memory_desc_wrapper &mdw) {
66 const auto &bd = mdw.blocking_desc();
67
68 return bd.inner_nblks == 1 && bd.inner_blks[0] == simd_w
69 && bd.inner_idxs[0] == 1;
70 };
71
72 return check_block_consistency(src_d)
73 && check_block_consistency(weights_d);
74 } else {
75 const auto &src_strides = src_d.blocking_desc().strides;
76 const auto &weights_strides = weights_d.blocking_desc().strides;
77 // C should be on second position in tag (example nchw or ncw) or on
78 // last postion (nhwc)
79 return src_strides[0] >= src_strides[1]
80 && IMPLICATION(
81 src_strides[1] > 1, src_strides[1] >= src_strides[2])
82 && weights_strides[0] >= weights_strides[1];
83 }
84
85 return true;
86}
87
88const jit_prelu_fwd_t::pd_t *jit_prelu_fwd_t::pd() const {
89 return static_cast<const pd_t *>(primitive_t::pd().get());
90}
91
// Construction only stores the primitive descriptor; the JIT kernel itself is
// generated later in init(). Destructor is defaulted — kernel_ owns its
// resources.
jit_prelu_fwd_t::jit_prelu_fwd_t(const pd_t *apd) : primitive_t(apd) {}
jit_prelu_fwd_t::~jit_prelu_fwd_t() = default;
94
95status_t jit_prelu_fwd_t::init(engine_t *engine) {
96 CHECK(safe_ptr_assign(kernel_, jit_prelu_forward_kernel_t::create(pd())));
97 return kernel_->create_kernel();
98}
99
100status_t jit_prelu_fwd_t::execute(const exec_ctx_t &ctx) const {
101 using byte = unsigned char;
102 const byte *const src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC);
103 const byte *const weights = CTX_IN_MEM(const byte *, DNNL_ARG_WEIGHTS);
104 byte *const dst = CTX_OUT_MEM(byte *, DNNL_ARG_DST);
105 const memory_desc_wrapper src_d {pd()->src_md(0)};
106
107 const auto src_dt_size = types::data_type_size(src_d.data_type());
108 const auto weights_dt_size
109 = types::data_type_size(pd()->weights_md(0)->data_type);
110 const auto dst_dt_size = types::data_type_size(pd()->dst_md(0)->data_type);
111
112 const auto kernel = kernel_.get();
113 const auto bcast = kernel->get_bcast();
114 const auto ndims = src_d.ndims();
115 const dim_t MB = pd()->N();
116 const dim_t C = pd()->C();
117 const dim_t D = pd()->D();
118 const dim_t H = pd()->H();
119 const dim_t W = pd()->W();
120 const dim_t SP = D * H * W;
121
122 if (bcast == prelu::bcast::full) {
123 const auto nelems = src_d.nelems(true);
124 const auto simd_w = kernel->simd_w();
125 const auto res = std::div(nelems, simd_w);
126 const auto &nelems_simd = res.quot;
127 const auto &nelems_tail = res.rem;
128 const auto nelems_parallel = nelems_simd + (nelems_tail ? 1 : 0);
129
130 parallel(0, [&](const int ithr, const int nthr) {
131 dim_t start = 0, end = 0;
132 balance211(nelems_parallel, nthr, ithr, start, end);
133 if (start >= end) return;
134
135 const bool ithr_process_tail
136 = nelems_tail && end == nelems_parallel;
137 const auto n_simd_size = (end - start - ithr_process_tail) * simd_w;
138 const auto offset = start * simd_w;
139
140 jit_prelu_forward_kernel_t::call_params_t params;
141
142 params.compute_data_size
143 = (n_simd_size + (nelems_tail ? nelems_tail : 0));
144 params.src = src + (offset * src_dt_size);
145 params.weights = weights + (offset * weights_dt_size);
146 params.dst = dst + (offset * dst_dt_size);
147
148 (*kernel)(&params);
149 });
150 } else {
151
152 const dim_t nelems_single_mb
153 = utils::array_product(src_d.padded_dims() + 1, ndims - 1);
154
155 if (bcast == prelu::bcast::per_oc_n_spatial_c) {
156 parallel_nd(MB, SP, [&](dim_t mb, dim_t sp) {
157 const auto offset = (mb * nelems_single_mb + sp * C);
158 jit_prelu_forward_kernel_t::call_params_t params;
159 params.compute_data_size = C;
160 params.src = src + offset * src_dt_size;
161 params.weights = weights;
162 params.dst = dst + offset * dst_dt_size;
163 (*kernel)(&params);
164 });
165 } else if (bcast == prelu::bcast::per_oc_n_c_spatial) {
166 parallel_nd(MB, C, [&](dim_t mb, dim_t c) {
167 jit_prelu_forward_kernel_t::call_params_t params;
168 const auto offset = (mb * nelems_single_mb + c * SP);
169 params.compute_data_size = SP;
170 params.src = src + offset * src_dt_size;
171 params.weights = weights + c * weights_dt_size;
172 params.dst = dst + offset * dst_dt_size;
173 (*kernel)(&params);
174 });
175 } else if (bcast == prelu::bcast::per_oc_blocked) {
176 const auto simd_w = kernel->simd_w();
177 const dim_t C_blocks = std::ceil(static_cast<float>(C) / simd_w);
178
179 parallel_nd(MB, C_blocks, [&](dim_t mb, dim_t c_blk) {
180 jit_prelu_forward_kernel_t::call_params_t params;
181 params.compute_data_size = SP * simd_w;
182 const dim_t offset
183 = (mb * nelems_single_mb + c_blk * SP * simd_w);
184
185 params.src = src + offset * src_dt_size;
186 params.weights = weights + c_blk * simd_w * weights_dt_size;
187 params.dst = dst + offset * dst_dt_size;
188 (*kernel)(&params);
189 });
190 }
191 }
192 return status::success;
193}
194
195} // namespace x64
196} // namespace cpu
197} // namespace impl
198} // namespace dnnl
199