1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
#include <cmath>
#include <cstring>

#include "common/dnnl_thread.hpp"
#include "common/memory_desc_wrapper.hpp"
#include "common/type_helpers.hpp"

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/prelu/jit_prelu_backward.hpp"
#include "cpu/x64/prelu/jit_prelu_reduction_kernel.hpp"
#include "cpu/x64/prelu/jit_prelu_utils.hpp"
#include "cpu/x64/prelu/jit_uni_prelu_backward_kernel.hpp"
27 | |
28 | namespace dnnl { |
29 | namespace impl { |
30 | namespace cpu { |
31 | namespace x64 { |
32 | |
// Per-thread scratchpad rows are rounded up to a whole cache line's worth of
// floats so that concurrent threads never write into the same cache line.
static constexpr dim_t alignment = platform::get_cache_line_size()
        / sizeof(float); // align to cache line size to avoid false sharing
35 | |
36 | status_t jit_prelu_bwd_t::pd_t::init(engine_t *engine) { |
37 | const memory_desc_wrapper src_d {src_md(0)}; |
38 | const memory_desc_wrapper weights_d {weights_md(0)}; |
39 | const memory_desc_wrapper src_diff_d {diff_src_md(0)}; |
40 | const memory_desc_wrapper weights_diff_d {diff_weights_md(0)}; |
41 | const memory_desc_wrapper dst_diff_d {diff_dst_md(0)}; |
42 | |
43 | bool ok = !is_fwd() && !has_zero_dim_memory() |
44 | && prelu::dt_supported({src_d.data_type(), weights_d.data_type(), |
45 | src_diff_d.data_type(), weights_diff_d.data_type(), |
46 | dst_diff_d.data_type()}) |
47 | && set_default_formats() && src_d.is_dense(true) |
48 | && weights_d.is_dense(true) && src_diff_d.is_dense(true) |
49 | && weights_diff_d.is_dense(true) && dst_diff_d.is_dense(true) |
50 | && attr()->has_default_values() |
51 | && utils::one_of(prelu::get_supported_isa(), avx512_core_fp16, |
52 | avx512_core_bf16, avx512_core, avx2, avx, sse41) |
53 | && dst_diff_d == src_diff_d; |
54 | if (!ok) return status::unimplemented; |
55 | |
56 | const auto bcast = prelu::get_bcast_type(src_diff_d, weights_diff_d); |
57 | |
58 | ok = ok |
59 | && bcast_supported(bcast, src_diff_d, weights_diff_d, |
60 | prelu::get_simd_w({src_d.data_type(), weights_d.data_type(), |
61 | src_diff_d.data_type(), weights_diff_d.data_type(), |
62 | dst_diff_d.data_type()})); |
63 | |
64 | if (ok) { |
65 | nthr_ = dnnl_get_max_threads(); |
66 | if (utils::one_of(bcast, prelu::bcast::per_oc_blocked, |
67 | prelu::bcast::per_oc_n_spatial_c, |
68 | prelu::bcast::per_oc_n_c_spatial)) { |
69 | auto scratchpad = scratchpad_registry().registrar(); |
70 | const dim_t C = src_diff_d.ndims() >= 2 ? src_diff_d.dims()[1] : 1; |
71 | scratchpad.book<float>(memory_tracking::names::key_prelu_reduction, |
72 | nthr_ * utils::rnd_up(C, alignment)); |
73 | } |
74 | |
75 | return status::success; |
76 | } |
77 | |
78 | return status::unimplemented; |
79 | } |
80 | |
81 | bool jit_prelu_bwd_t::pd_t::bcast_supported(const prelu::bcast &bcast, |
82 | const memory_desc_wrapper &src_diff_d, |
83 | const memory_desc_wrapper &weights_diff_d, int simd_w) const { |
84 | |
85 | if (bcast == prelu::bcast::full) |
86 | return true; |
87 | else if (bcast == prelu::bcast::unsupported) |
88 | return false; |
89 | else if (bcast == prelu::bcast::per_oc_blocked) { |
90 | const auto check_block_consistency |
91 | = [&](const memory_desc_wrapper &mdw) { |
92 | const auto &bd = mdw.blocking_desc(); |
93 | |
94 | return bd.inner_nblks == 1 && bd.inner_blks[0] == simd_w |
95 | && bd.inner_idxs[0] == 1; |
96 | }; |
97 | |
98 | return check_block_consistency(src_diff_d) |
99 | && check_block_consistency(weights_diff_d); |
100 | } else { |
101 | const auto &src_strides = src_diff_d.blocking_desc().strides; |
102 | const auto &weights_strides = weights_diff_d.blocking_desc().strides; |
103 | // C should be on second position in tag (example nchw or ncw) or on |
104 | // last postion (nhwc) |
105 | return src_strides[0] >= src_strides[1] |
106 | && IMPLICATION( |
107 | src_strides[1] > 1, src_strides[1] >= src_strides[2]) |
108 | && weights_strides[0] >= weights_strides[1]; |
109 | } |
110 | |
111 | return true; |
112 | } |
113 | |
// Convenience accessor: downcasts the base primitive descriptor to this
// implementation's pd_t (safe — the primitive is only built from a pd_t).
const jit_prelu_bwd_t::pd_t *jit_prelu_bwd_t::pd() const {
    return static_cast<const pd_t *>(primitive_t::pd().get());
}

jit_prelu_bwd_t::jit_prelu_bwd_t(const pd_t *apd) : primitive_t(apd) {}
// Out-of-line defaulted destructor — presumably needed so the smart-pointer
// members' deleters see complete kernel types here (TODO confirm in header).
jit_prelu_bwd_t::~jit_prelu_bwd_t() = default;
120 | |
121 | status_t jit_prelu_bwd_t::init(engine_t *engine) { |
122 | const memory_desc_wrapper weights_diff_d {pd()->diff_weights_md(0)}; |
123 | const memory_desc_wrapper src_diff_d {pd()->diff_src_md(0)}; |
124 | |
125 | const auto bcast = prelu::get_bcast_type(src_diff_d, weights_diff_d); |
126 | |
127 | CHECK(safe_ptr_assign(kernel_, jit_prelu_backward_kernel_t::create(pd()))); |
128 | if (utils::one_of(bcast, prelu::bcast::per_oc_blocked, |
129 | prelu::bcast::per_oc_n_spatial_c, |
130 | prelu::bcast::per_oc_n_c_spatial)) { |
131 | |
132 | CHECK(safe_ptr_assign( |
133 | reduction_kernel_, jit_prelu_reduction_kernel_t::create(pd()))); |
134 | CHECK(reduction_kernel_->create_kernel()); |
135 | } |
136 | |
137 | return kernel_->create_kernel(); |
138 | } |
139 | |
140 | status_t jit_prelu_bwd_t::execute(const exec_ctx_t &ctx) const { |
141 | const byte *const src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC); |
142 | const byte *const weights = CTX_IN_MEM(const byte *, DNNL_ARG_WEIGHTS); |
143 | const byte *const dst_diff = CTX_IN_MEM(const byte *, DNNL_ARG_DIFF_DST); |
144 | byte *const weights_diff = CTX_OUT_MEM(const byte *, DNNL_ARG_DIFF_WEIGHTS); |
145 | byte *const src_diff = CTX_OUT_MEM(byte *, DNNL_ARG_DIFF_SRC); |
146 | const memory_desc_wrapper src_d {pd()->src_md(0)}; |
147 | const auto src_dt_size = types::data_type_size(src_d.data_type()); |
148 | const auto wei_dt_size |
149 | = types::data_type_size(pd()->weights_md(0)->data_type); |
150 | const auto diff_wei_dt_size |
151 | = types::data_type_size(pd()->diff_weights_md(0)->data_type); |
152 | const auto diff_src_dt_size |
153 | = types::data_type_size(pd()->diff_src_md(0)->data_type); |
154 | const auto diff_dst_dt_size |
155 | = types::data_type_size(pd()->diff_dst_md(0)->data_type); |
156 | |
157 | const auto kernel = kernel_.get(); |
158 | const auto &bcast = kernel->get_bcast(); |
159 | const auto &simd_w = kernel->simd_w(); |
160 | int nthr = pd()->nthr_; |
161 | |
162 | if (bcast == prelu::bcast::full) { |
163 | const auto nelems = src_d.nelems(true); |
164 | const auto res = std::div(nelems, simd_w); |
165 | const auto &nelems_simd = res.quot; |
166 | const auto &nelems_tail = res.rem; |
167 | const auto nelems_parallel = nelems_simd + (nelems_tail ? 1 : 0); |
168 | |
169 | parallel(nthr, [&](const int ithr, const int nthr) { |
170 | dim_t start = 0, end = 0; |
171 | balance211(nelems_parallel, nthr, ithr, start, end); |
172 | if (start >= end) return; |
173 | |
174 | const bool ithr_process_tail |
175 | = nelems_tail && end == nelems_parallel; |
176 | const auto n_simd_size = (end - start - ithr_process_tail) * simd_w; |
177 | const auto offset = start * simd_w; |
178 | |
179 | jit_prelu_backward_kernel_t::call_params_t params; |
180 | |
181 | params.compute_data_size |
182 | = (n_simd_size + (nelems_tail ? nelems_tail : 0)); |
183 | params.src = src + offset * src_dt_size; |
184 | params.weights = weights + offset * wei_dt_size; |
185 | params.dst_diff = dst_diff + offset * diff_dst_dt_size; |
186 | params.src_diff = src_diff + offset * diff_src_dt_size; |
187 | params.weights_diff = weights_diff + offset * diff_wei_dt_size; |
188 | (*kernel)(¶ms); |
189 | }); |
190 | } else { |
191 | const auto ndims = src_d.ndims(); |
192 | const auto &dims = src_d.dims(); |
193 | const dim_t MB = dims[0]; |
194 | const dim_t C = ndims >= 2 ? dims[1] : 1; |
195 | const dim_t D = ndims >= 5 ? dims[ndims - 3] : 1; |
196 | const dim_t H = ndims >= 4 ? dims[ndims - 2] : 1; |
197 | const dim_t W = ndims >= 3 ? dims[ndims - 1] : 1; |
198 | const dim_t SP = D * H * W; |
199 | const dim_t nelems_single_mb |
200 | = utils::array_product(src_d.padded_dims() + 1, ndims - 1); |
201 | |
202 | auto scratchpad = ctx.get_scratchpad_grantor(); |
203 | float *const weights_diff_scratchpad = scratchpad.template get<float>( |
204 | memory_tracking::names::key_prelu_reduction); |
205 | const auto C_cache_line_aligned = utils::rnd_up(C, alignment); |
206 | size_t work_amount = 0; |
207 | |
208 | fill_scratchpad_zeros( |
209 | weights_diff_scratchpad, C_cache_line_aligned, nthr); |
210 | |
211 | if (bcast == prelu::bcast::per_oc_blocked) { |
212 | const dim_t C_blocks = std::ceil(static_cast<float>(C) / simd_w); |
213 | work_amount = MB * C_blocks; |
214 | parallel_nd_ext(nthr, MB, C_blocks, |
215 | [&](int ithr, int, dim_t mb, dim_t c_blk) { |
216 | jit_prelu_backward_kernel_t::call_params_t params; |
217 | params.compute_data_size = SP * simd_w; |
218 | const dim_t offset |
219 | = (mb * nelems_single_mb + c_blk * SP * simd_w); |
220 | params.src = src + offset * src_dt_size; |
221 | params.dst_diff = dst_diff + offset * diff_dst_dt_size; |
222 | params.src_diff = src_diff + offset * diff_src_dt_size; |
223 | params.weights = weights + c_blk * simd_w * wei_dt_size; |
224 | params.weights_diff = reinterpret_cast<void *>( |
225 | weights_diff_scratchpad |
226 | + ithr * C_cache_line_aligned + c_blk * simd_w); |
227 | |
228 | (*kernel)(¶ms); |
229 | }); |
230 | } else if (bcast == prelu::bcast::per_oc_n_c_spatial) { |
231 | work_amount = MB * C; |
232 | |
233 | parallel_nd_ext(nthr, MB, C, [&](int ithr, int, dim_t mb, dim_t c) { |
234 | jit_prelu_backward_kernel_t::call_params_t params; |
235 | const auto offset = (mb * nelems_single_mb + c * SP); |
236 | params.compute_data_size = SP; |
237 | params.src = src + offset * src_dt_size; |
238 | params.dst_diff = dst_diff + offset * diff_dst_dt_size; |
239 | params.src_diff = src_diff + offset * diff_src_dt_size; |
240 | params.weights = weights + c * wei_dt_size; |
241 | params.weights_diff |
242 | = reinterpret_cast<void *>(weights_diff_scratchpad |
243 | + ithr * C_cache_line_aligned + c); |
244 | (*kernel)(¶ms); |
245 | }); |
246 | } else if (bcast == prelu::bcast::per_oc_n_spatial_c) { |
247 | work_amount = MB * SP; |
248 | |
249 | parallel_nd_ext( |
250 | nthr, MB, SP, [&](int ithr, int, dim_t mb, dim_t sp) { |
251 | jit_prelu_backward_kernel_t::call_params_t params; |
252 | const auto offset = (mb * nelems_single_mb + sp * C); |
253 | params.compute_data_size = C; |
254 | params.src = src + offset * src_dt_size; |
255 | params.dst_diff = dst_diff + offset * diff_dst_dt_size; |
256 | params.src_diff = src_diff + offset * diff_src_dt_size; |
257 | params.weights = weights; |
258 | params.weights_diff = reinterpret_cast<void *>( |
259 | weights_diff_scratchpad |
260 | + ithr * C_cache_line_aligned); |
261 | (*kernel)(¶ms); |
262 | }); |
263 | } |
264 | |
265 | const size_t reduction_blocks = nstl::min(work_amount, (size_t)nthr); |
266 | scratchpad_to_diff_weights_reduction(weights_diff_scratchpad, |
267 | weights_diff, diff_wei_dt_size, C, reduction_blocks); |
268 | } |
269 | |
270 | return status::success; |
271 | } |
272 | |
273 | void jit_prelu_bwd_t::scratchpad_to_diff_weights_reduction(float *scratchpad, |
274 | byte *weights_diff, size_t weights_diff_dt, dim_t C, |
275 | size_t reduction_blocks) const { |
276 | const auto reduction_kernel = reduction_kernel_.get(); |
277 | const auto &simd_w = reduction_kernel_->simd_w(); |
278 | const bool tail_exists = C % simd_w; |
279 | const dim_t C_blocks = std::ceil(static_cast<float>(C) / simd_w); |
280 | |
281 | parallel_nd(C_blocks, [&](dim_t c_blk) { |
282 | const auto blk_offset = c_blk * simd_w; |
283 | jit_prelu_reduction_kernel_t::call_params_t params; |
284 | params.reduction_blocks = reduction_blocks; |
285 | params.weights_diff_scratch |
286 | = reinterpret_cast<void *>(scratchpad + blk_offset); |
287 | params.weights_diff = weights_diff + blk_offset * weights_diff_dt; |
288 | params.tail = tail_exists && c_blk == C_blocks - 1; |
289 | params.is_last_c_blk = c_blk == C_blocks - 1; |
290 | (*reduction_kernel)(¶ms); |
291 | }); |
292 | } |
293 | |
294 | void jit_prelu_bwd_t::fill_scratchpad_zeros(float *const scratchpad, |
295 | size_t thread_scratchpad_size, int nthr) const { |
296 | |
297 | parallel(nthr, [&](std::size_t ithr, std::size_t) { |
298 | float *scratchpad_ithr = scratchpad + ithr * thread_scratchpad_size; |
299 | std::memset(scratchpad_ithr, 0, thread_scratchpad_size * sizeof(float)); |
300 | }); |
301 | } |
302 | |
303 | } // namespace x64 |
304 | } // namespace cpu |
305 | } // namespace impl |
306 | } // namespace dnnl |
307 | |