/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
#include <cmath>
#include <cstdlib> // std::div
#include <cstring> // std::memset

#include "common/dnnl_thread.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/type_helpers.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/prelu/jit_prelu_backward.hpp"
#include "cpu/x64/prelu/jit_prelu_reduction_kernel.hpp"
#include "cpu/x64/prelu/jit_prelu_utils.hpp"
#include "cpu/x64/prelu/jit_uni_prelu_backward_kernel.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

static constexpr dim_t alignment = platform::get_cache_line_size()
        / sizeof(float); // align to cache line size to avoid false sharing

status_t jit_prelu_bwd_t::pd_t::init(engine_t *engine) {
    const memory_desc_wrapper src_d {src_md(0)};
    const memory_desc_wrapper weights_d {weights_md(0)};
    const memory_desc_wrapper src_diff_d {diff_src_md(0)};
    const memory_desc_wrapper weights_diff_d {diff_weights_md(0)};
    const memory_desc_wrapper dst_diff_d {diff_dst_md(0)};

    bool ok = !is_fwd() && !has_zero_dim_memory()
            && prelu::dt_supported({src_d.data_type(), weights_d.data_type(),
                    src_diff_d.data_type(), weights_diff_d.data_type(),
                    dst_diff_d.data_type()})
            && set_default_formats() && src_d.is_dense(true)
            && weights_d.is_dense(true) && src_diff_d.is_dense(true)
            && weights_diff_d.is_dense(true) && dst_diff_d.is_dense(true)
            && attr()->has_default_values()
            && utils::one_of(prelu::get_supported_isa(), avx512_core_fp16,
                    avx512_core_bf16, avx512_core, avx2, avx, sse41)
            && dst_diff_d == src_diff_d;
    if (!ok) return status::unimplemented;

    const auto bcast = prelu::get_bcast_type(src_diff_d, weights_diff_d);

    ok = ok
            && bcast_supported(bcast, src_diff_d, weights_diff_d,
                    prelu::get_simd_w({src_d.data_type(), weights_d.data_type(),
                            src_diff_d.data_type(), weights_diff_d.data_type(),
                            dst_diff_d.data_type()}));

    if (ok) {
        nthr_ = dnnl_get_max_threads();
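        // For the per-channel broadcast variants each thread accumulates its
        // partial diff-weights into a private scratchpad row of
        // rnd_up(C, alignment) floats; rows are padded to a cache-line
        // multiple to avoid false sharing and are reduced into diff_weights
        // after the main pass.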
        if (utils::one_of(bcast, prelu::bcast::per_oc_blocked,
                    prelu::bcast::per_oc_n_spatial_c,
                    prelu::bcast::per_oc_n_c_spatial)) {
            auto scratchpad = scratchpad_registry().registrar();
            const dim_t C = src_diff_d.ndims() >= 2 ? src_diff_d.dims()[1] : 1;
            scratchpad.book<float>(memory_tracking::names::key_prelu_reduction,
                    nthr_ * utils::rnd_up(C, alignment));
        }

        return status::success;
    }

    return status::unimplemented;
}

bool jit_prelu_bwd_t::pd_t::bcast_supported(const prelu::bcast &bcast,
        const memory_desc_wrapper &src_diff_d,
        const memory_desc_wrapper &weights_diff_d, int simd_w) const {

    if (bcast == prelu::bcast::full)
        return true;
    else if (bcast == prelu::bcast::unsupported)
        return false;
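    // For blocked layouts the channel dimension must be the only inner block
    // and its size must match the kernel SIMD width, so one vector register
    // maps exactly onto one channel block.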
    else if (bcast == prelu::bcast::per_oc_blocked) {
        const auto check_block_consistency
                = [&](const memory_desc_wrapper &mdw) {
                      const auto &bd = mdw.blocking_desc();

                      return bd.inner_nblks == 1 && bd.inner_blks[0] == simd_w
                              && bd.inner_idxs[0] == 1;
                  };

        return check_block_consistency(src_diff_d)
                && check_block_consistency(weights_diff_d);
    } else {
        const auto &src_strides = src_diff_d.blocking_desc().strides;
        const auto &weights_strides = weights_diff_d.blocking_desc().strides;
        // C should be in the second position in the tag (e.g. nchw or ncw) or
        // in the last position (nhwc)
        return src_strides[0] >= src_strides[1]
                && IMPLICATION(
                        src_strides[1] > 1, src_strides[1] >= src_strides[2])
                && weights_strides[0] >= weights_strides[1];
    }
}

const jit_prelu_bwd_t::pd_t *jit_prelu_bwd_t::pd() const {
    return static_cast<const pd_t *>(primitive_t::pd().get());
}

jit_prelu_bwd_t::jit_prelu_bwd_t(const pd_t *apd) : primitive_t(apd) {}
jit_prelu_bwd_t::~jit_prelu_bwd_t() = default;

status_t jit_prelu_bwd_t::init(engine_t *engine) {
    const memory_desc_wrapper weights_diff_d {pd()->diff_weights_md(0)};
    const memory_desc_wrapper src_diff_d {pd()->diff_src_md(0)};

    const auto bcast = prelu::get_bcast_type(src_diff_d, weights_diff_d);

    CHECK(safe_ptr_assign(kernel_, jit_prelu_backward_kernel_t::create(pd())));
    if (utils::one_of(bcast, prelu::bcast::per_oc_blocked,
                prelu::bcast::per_oc_n_spatial_c,
                prelu::bcast::per_oc_n_c_spatial)) {

        CHECK(safe_ptr_assign(
                reduction_kernel_, jit_prelu_reduction_kernel_t::create(pd())));
        CHECK(reduction_kernel_->create_kernel());
    }

    return kernel_->create_kernel();
}

status_t jit_prelu_bwd_t::execute(const exec_ctx_t &ctx) const {
    const byte *const src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC);
    const byte *const weights = CTX_IN_MEM(const byte *, DNNL_ARG_WEIGHTS);
    const byte *const dst_diff = CTX_IN_MEM(const byte *, DNNL_ARG_DIFF_DST);
    byte *const weights_diff = CTX_OUT_MEM(byte *, DNNL_ARG_DIFF_WEIGHTS);
    byte *const src_diff = CTX_OUT_MEM(byte *, DNNL_ARG_DIFF_SRC);
    const memory_desc_wrapper src_d {pd()->src_md(0)};
    const auto src_dt_size = types::data_type_size(src_d.data_type());
    const auto wei_dt_size
            = types::data_type_size(pd()->weights_md(0)->data_type);
    const auto diff_wei_dt_size
            = types::data_type_size(pd()->diff_weights_md(0)->data_type);
    const auto diff_src_dt_size
            = types::data_type_size(pd()->diff_src_md(0)->data_type);
    const auto diff_dst_dt_size
            = types::data_type_size(pd()->diff_dst_md(0)->data_type);

    const auto kernel = kernel_.get();
    const auto &bcast = kernel->get_bcast();
    const auto &simd_w = kernel->simd_w();
    int nthr = pd()->nthr_;

    if (bcast == prelu::bcast::full) {
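        // Split the flat element range into full SIMD vectors plus at most
        // one tail chunk and balance them across threads; the thread that
        // gets the last chunk also processes the tail. E.g. nelems = 1000
        // with simd_w = 16 gives 62 full vectors and a tail of 8 elements,
        // i.e. 63 parallel work items.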
        const auto nelems = src_d.nelems(true);
        const auto res = std::div(nelems, simd_w);
        const auto &nelems_simd = res.quot;
        const auto &nelems_tail = res.rem;
        const auto nelems_parallel = nelems_simd + (nelems_tail ? 1 : 0);

        parallel(nthr, [&](const int ithr, const int nthr) {
            dim_t start = 0, end = 0;
            balance211(nelems_parallel, nthr, ithr, start, end);
            if (start >= end) return;

            const bool ithr_process_tail
                    = nelems_tail && end == nelems_parallel;
            const auto n_simd_size = (end - start - ithr_process_tail) * simd_w;
            const auto offset = start * simd_w;

            jit_prelu_backward_kernel_t::call_params_t params;

            params.compute_data_size
                    = n_simd_size + (ithr_process_tail ? nelems_tail : 0);
            params.src = src + offset * src_dt_size;
            params.weights = weights + offset * wei_dt_size;
            params.dst_diff = dst_diff + offset * diff_dst_dt_size;
            params.src_diff = src_diff + offset * diff_src_dt_size;
            params.weights_diff = weights_diff + offset * diff_wei_dt_size;
            (*kernel)(&params);
        });
    } else {
        const auto ndims = src_d.ndims();
        const auto &dims = src_d.dims();
        const dim_t MB = dims[0];
        const dim_t C = ndims >= 2 ? dims[1] : 1;
        const dim_t D = ndims >= 5 ? dims[ndims - 3] : 1;
        const dim_t H = ndims >= 4 ? dims[ndims - 2] : 1;
        const dim_t W = ndims >= 3 ? dims[ndims - 1] : 1;
        const dim_t SP = D * H * W;
        const dim_t nelems_single_mb
                = utils::array_product(src_d.padded_dims() + 1, ndims - 1);

        auto scratchpad = ctx.get_scratchpad_grantor();
        float *const weights_diff_scratchpad = scratchpad.template get<float>(
                memory_tracking::names::key_prelu_reduction);
        const auto C_cache_line_aligned = utils::rnd_up(C, alignment);
        size_t work_amount = 0;

        fill_scratchpad_zeros(
                weights_diff_scratchpad, C_cache_line_aligned, nthr);

        if (bcast == prelu::bcast::per_oc_blocked) {
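            // Blocked layout: each (mb, c_blk) work item covers SP * simd_w
            // contiguous elements and reads a simd_w-wide slice of weights.
            // E.g. C = 70 with simd_w = 16 gives C_blocks = 5; the blocked
            // layout pads the channel dimension, which is why the per-sample
            // offset is computed from padded_dims.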
            const dim_t C_blocks = std::ceil(static_cast<float>(C) / simd_w);
            work_amount = MB * C_blocks;
            parallel_nd_ext(nthr, MB, C_blocks,
                    [&](int ithr, int, dim_t mb, dim_t c_blk) {
                        jit_prelu_backward_kernel_t::call_params_t params;
                        params.compute_data_size = SP * simd_w;
                        const dim_t offset
                                = (mb * nelems_single_mb + c_blk * SP * simd_w);
                        params.src = src + offset * src_dt_size;
                        params.dst_diff = dst_diff + offset * diff_dst_dt_size;
                        params.src_diff = src_diff + offset * diff_src_dt_size;
                        params.weights = weights + c_blk * simd_w * wei_dt_size;
                        params.weights_diff = reinterpret_cast<void *>(
                                weights_diff_scratchpad
                                + ithr * C_cache_line_aligned + c_blk * simd_w);

                        (*kernel)(&params);
                    });
        } else if (bcast == prelu::bcast::per_oc_n_c_spatial) {
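            // Plain NCsp layout (e.g. nchw): the SP spatial elements of one
            // channel are contiguous, so each (mb, c) work item processes one
            // spatial slice against a single scalar weight.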
            work_amount = MB * C;

            parallel_nd_ext(nthr, MB, C, [&](int ithr, int, dim_t mb, dim_t c) {
                jit_prelu_backward_kernel_t::call_params_t params;
                const auto offset = (mb * nelems_single_mb + c * SP);
                params.compute_data_size = SP;
                params.src = src + offset * src_dt_size;
                params.dst_diff = dst_diff + offset * diff_dst_dt_size;
                params.src_diff = src_diff + offset * diff_src_dt_size;
                params.weights = weights + c * wei_dt_size;
                params.weights_diff
                        = reinterpret_cast<void *>(weights_diff_scratchpad
                                + ithr * C_cache_line_aligned + c);
                (*kernel)(&params);
            });
        } else if (bcast == prelu::bcast::per_oc_n_spatial_c) {
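            // Channels-last layout (e.g. nhwc): all C channels of one spatial
            // point are contiguous, so each (mb, sp) work item consumes the
            // whole weights vector and accumulates into its full per-thread
            // C-sized scratchpad row.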
            work_amount = MB * SP;

            parallel_nd_ext(
                    nthr, MB, SP, [&](int ithr, int, dim_t mb, dim_t sp) {
                        jit_prelu_backward_kernel_t::call_params_t params;
                        const auto offset = (mb * nelems_single_mb + sp * C);
                        params.compute_data_size = C;
                        params.src = src + offset * src_dt_size;
                        params.dst_diff = dst_diff + offset * diff_dst_dt_size;
                        params.src_diff = src_diff + offset * diff_src_dt_size;
                        params.weights = weights;
                        params.weights_diff = reinterpret_cast<void *>(
                                weights_diff_scratchpad
                                + ithr * C_cache_line_aligned);
                        (*kernel)(&params);
                    });
        }

        const size_t reduction_blocks = nstl::min(work_amount, (size_t)nthr);
        scratchpad_to_diff_weights_reduction(weights_diff_scratchpad,
                weights_diff, diff_wei_dt_size, C, reduction_blocks);
    }

    return status::success;
}

void jit_prelu_bwd_t::scratchpad_to_diff_weights_reduction(float *scratchpad,
        byte *weights_diff, size_t weights_diff_dt, dim_t C,
        size_t reduction_blocks) const {
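    // Sum the per-thread scratchpad rows into the user diff_weights buffer,
    // one simd_w-wide block of channels at a time; only the last block may
    // carry a tail when C is not a multiple of simd_w.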
    const auto reduction_kernel = reduction_kernel_.get();
    const auto &simd_w = reduction_kernel->simd_w();
    const bool tail_exists = C % simd_w;
    const dim_t C_blocks = std::ceil(static_cast<float>(C) / simd_w);

    parallel_nd(C_blocks, [&](dim_t c_blk) {
        const auto blk_offset = c_blk * simd_w;
        jit_prelu_reduction_kernel_t::call_params_t params;
        params.reduction_blocks = reduction_blocks;
        params.weights_diff_scratch
                = reinterpret_cast<void *>(scratchpad + blk_offset);
        params.weights_diff = weights_diff + blk_offset * weights_diff_dt;
        params.tail = tail_exists && c_blk == C_blocks - 1;
        params.is_last_c_blk = c_blk == C_blocks - 1;
        (*reduction_kernel)(&params);
    });
}

void jit_prelu_bwd_t::fill_scratchpad_zeros(float *const scratchpad,
        size_t thread_scratchpad_size, int nthr) const {

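    // Each thread zeroes its own scratchpad row so the later accumulation
    // starts from zero; doing this in parallel typically also first-touches
    // each row on the thread that will write to it.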
    parallel(nthr, [&](std::size_t ithr, std::size_t) {
        float *scratchpad_ithr = scratchpad + ithr * thread_scratchpad_size;
        std::memset(scratchpad_ithr, 0, thread_scratchpad_size * sizeof(float));
    });
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl