/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include <array>
#include <cassert>
#include <cmath>
#include <vector>

#include "common/broadcast_strategy.hpp"
#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/math_utils.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/ref_io_helper.hpp"
#include "cpu/ref_prelu.hpp"

namespace dnnl {
namespace impl {
namespace cpu {

using namespace math;
using namespace data_type;

static constexpr int max_supported_ndims = 5;

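// Maps a logical coordinate tuple onto the physical element offset of the
// given memory descriptor. Only 1D-5D tensors are expected (see
// max_supported_ndims above); any other rank trips the assert.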
static dim_t offset(const memory_desc_wrapper &mem, dims_t dims) {
    const int ndims = mem.ndims();
    switch (ndims) {
        case 1: return mem.off(dims[0]);
        case 2: return mem.off(dims[0], dims[1]);
        case 3: return mem.off(dims[0], dims[1], dims[2]);
        case 4: return mem.off(dims[0], dims[1], dims[2], dims[3]);
        case 5: return mem.off(dims[0], dims[1], dims[2], dims[3], dims[4]);
        default: assert(!"Unsupported ndims count");
    }
    return -1;
}

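// Translates a data-tensor coordinate into the corresponding weights offset:
// coordinates along broadcast dimensions (those not set in the mask) are
// zeroed before the regular offset computation.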
static dim_t weights_offset(
        const int mask, const memory_desc_wrapper &mem, dims_t &dims) {
    dims_t dims_w {};
    std::copy(dims, dims + max_supported_ndims, dims_w);
    utils::apply_mask_on_dims(dims_w, mem.ndims(), mask);
    return offset(mem, dims_w);
}

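// Returns true when any dimension of the descriptor is physically padded.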
static bool is_padding(const memory_desc_wrapper &md) {
    for (int i = 0; i < md.ndims(); i++)
        if (md.dims()[i] != md.padded_dims()[i]) return true;
    return false;
}

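// Forward pass: dst[i] = src[i] > 0 ? src[i] : src[i] * weights[j], where j
// follows the weights broadcast mask. The flattened data tensor is split
// evenly across threads; for out-of-place execution on padded layouts the
// destination is zero-padded up front so the padded area stays defined.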
status_t ref_prelu_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
    if (pd()->has_zero_dim_memory()) return status::success;

    const auto src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const byte *, DNNL_ARG_WEIGHTS);
    auto dst = CTX_OUT_MEM(byte *, DNNL_ARG_DST);

    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const auto is_inplace = (src == dst);
    const auto has_padding = is_padding(data_d);
    if (has_padding && !is_inplace) ctx.zero_pad_output(DNNL_ARG_TO);

    const int mask = utils::get_dims_mask(
            data_d.dims(), weights_d.dims(), data_d.ndims());
    const dim_t work_amount = data_d.nelems();

    parallel(0, [&](std::size_t ithr, std::size_t nthr) {
        if ((dim_t)ithr >= work_amount) return;

        dim_t start {0}, end {0};
        dims_t dims_d, off;
        for (int i = 0; i < max_supported_ndims; i++) {
            off[i] = 0;
            dims_d[i] = (data_d.dims()[i] != 0) ? data_d.dims()[i] : 1;
        }

        balance211(work_amount, nthr, ithr, start, end);
        utils::nd_iterator_init(start, off[0], dims_d[0], off[1], dims_d[1],
                off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);

        for (dim_t iwork = start; iwork < end; ++iwork) {
            const auto data_off = offset(data_d, off);
            const auto weight_off = weights_offset(mask, weights_d, off);
            const float src_val
                    = io::load_float_value(data_d.data_type(), src, data_off);
            const float weights_val = io::load_float_value(
                    weights_d.data_type(), weights, weight_off);

            const float res = relu_fwd(src_val, weights_val);

            io::store_float_value(data_d.data_type(), res, dst, data_off);
            utils::nd_iterator_step(off[0], dims_d[0], off[1], dims_d[1],
                    off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);
        }
    });
    return status::success;
}

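// In-place tree reduction of `size` floats stored in `mem`: each iteration
// sums adjacent pairs (an odd leftover element is folded into mem[0]) and
// halves the element count until a single value remains. The balanced
// pairwise sums help keep float rounding error in check compared to one
// long serial accumulation.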
static float reduce(float *mem, dim_t size) {
    bool tail = size % 2;
    const auto reduce_iteration = [&](float *mem) {
        const auto div_res = std::div(size, (dim_t)2);
        tail = div_res.rem;
        size = div_res.quot;
        if (!tail && !size) {
            mem[0] = 0;
            return;
        }
        dim_t i {0}, off {0};
        if (tail) {
            if (size) mem[0] += mem[1 + off] + mem[2 + off];
            ++off;
            ++i;
        }
        for (; i < size; i++) {
            mem[i] = mem[2 * i + off] + mem[(2 * i + 1) + off];
        }
    };
    while (size > 1) {
        reduce_iteration(mem);
    }
    return mem[0];
}

namespace prelu {
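// Splits a per-thread workload into roughly sqrt(work_amount) groups of
// roughly sqrt(work_amount) elements: buf_size elements are accumulated per
// group slot and group_size slots are reduced at the end. group_size is
// bumped by one when the sqrt-based split does not cover the whole workload.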
void set_reduction_buffers(
        const dim_t work_amount, dim_t &group_size, dim_t &buf_size) {
    float sqrt = std::sqrt(work_amount);
    group_size = std::ceil(sqrt);
    buf_size = std::floor(sqrt);
    if (group_size * buf_size < work_amount) group_size++;
}

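// Offset of a thread's private slice in the reduction scratchpad for the
// scalar case: the running sum of the buf_size + group_size floats reserved
// by every lower-numbered thread, each sized from that thread's balanced
// share of the work.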
dim_t get_scalar_scratchpad_offset(const std::size_t ithr,
        const std::size_t nthr, const dim_t work_amount) {
    dim_t offset {0}, group_size, buf_size;
    for (std::size_t i = 0; i < ithr; i++) {
        dim_t start {0}, end {0};
        balance211(work_amount, nthr, i, start, end);
        const dim_t workload = end - start;
        set_reduction_buffers(workload, group_size, buf_size);
        offset += buf_size;
        offset += group_size;
    }
    return offset;
}
} // namespace prelu

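// Backward kernel for a single element: writes diff_src (diff_dst scaled by
// the weight for negative src, passed through otherwise) and returns this
// element's contribution to diff_weights (diff_dst * src for negative src,
// zero otherwise).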
float ref_prelu_bwd_t::ker(const byte *src, const byte *weights,
        const byte *diff_dst, byte *diff_src, dim_t data_off,
        dim_t weight_off) const {

    const auto dtype = pd()->src_md(0)->data_type;
    const auto wtype = pd()->weights_md(0)->data_type;
    const float src_val = io::load_float_value(dtype, src, data_off);
    const float diff_dst_val = io::load_float_value(dtype, diff_dst, data_off);
    const float weights_val = io::load_float_value(wtype, weights, weight_off);

    const float diff_src_res
            = relu_bwd_use_dst(diff_dst_val, src_val, weights_val);
    const float diff_weight_res = src_val > 0 ? 0 : (diff_dst_val * src_val);

    io::store_float_value(dtype, diff_src_res, diff_src, data_off);

    return diff_weight_res;
}

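// Single shared weight (scalar broadcast): each thread runs the two-level
// (buf_size x group_size) reduction over its share of the data and stores
// one partial result; the per-thread partials are reduced into the final
// diff_weights value once the parallel section ends.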
void ref_prelu_bwd_t::calculate_scalar(const byte *src, const byte *weights,
        byte *diff_weights, const byte *diff_dst, byte *diff_src,
        float *scratchpad_buf) const {

    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const int nthr = pd()->nthr_;
    const dim_t work_amount = data_d.nelems();
    const int thread_count = nstl::min((dim_t)nthr, work_amount);

    std::vector<float> buf_nthr_partial_results(nthr);

    parallel(nthr, [&](std::size_t ithr, std::size_t nthr) {
        if ((dim_t)ithr >= work_amount) return;

        dim_t start {0}, end {0};
        dims_t dims_d, off;
        for (int i = 0; i < max_supported_ndims; i++) {
            off[i] = 0;
            dims_d[i] = (data_d.dims()[i] != 0) ? data_d.dims()[i] : 1;
        }

        balance211(work_amount, nthr, ithr, start, end);
        const dim_t workload = end - start;

        utils::nd_iterator_init(start, off[0], dims_d[0], off[1], dims_d[1],
                off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);

        dim_t group_size, buf_size;
        prelu::set_reduction_buffers(workload, group_size, buf_size);

        const dim_t scratchpad_offset
                = prelu::get_scalar_scratchpad_offset(ithr, nthr, work_amount);
        auto *buf = &scratchpad_buf[scratchpad_offset];
        auto *group_buf = &scratchpad_buf[scratchpad_offset + buf_size];

        dim_t offset_buf {0}, group_off {0}, data_size {buf_size};
        for (dim_t iwork = start; iwork < end; ++iwork) {
            const auto data_off = offset(data_d, off);
            const auto weight_off = 0;
            buf[offset_buf] = ker(
                    src, weights, diff_dst, diff_src, data_off, weight_off);
            if (++offset_buf == data_size) {
                group_buf[group_off++] = reduce(buf, offset_buf);
                offset_buf = 0;
                data_size = ((group_off + 1) * buf_size <= workload)
                        ? buf_size
                        : workload - (group_off * buf_size);
            }
            utils::nd_iterator_step(off[0], dims_d[0], off[1], dims_d[1],
                    off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);
        }
        buf_nthr_partial_results[ithr] = reduce(group_buf, group_size);
    });
    io::store_float_value(weights_d.data_type(),
            reduce(&buf_nthr_partial_results[0], thread_count), diff_weights,
            0);
}

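// Weights have the same shape as the data, so each element produces exactly
// one diff_weights value and no reduction is needed.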
void ref_prelu_bwd_t::calculate_no_broadcast(const byte *src,
        const byte *weights, byte *diff_weights, const byte *diff_dst,
        byte *diff_src, float *scratchpad_buf) const {

    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    const int nthr = pd()->nthr_;
    const dim_t work_amount = data_d.nelems();
    const int mask = utils::get_dims_mask(
            data_d.dims(), weights_d.dims(), data_d.ndims());

    parallel(nthr, [&](std::size_t ithr, std::size_t nthr) {
        if ((dim_t)ithr >= work_amount) return;

        dim_t start {0}, end {0};
        dims_t dims_d, off;
        for (int i = 0; i < max_supported_ndims; i++) {
            off[i] = 0;
            dims_d[i] = (data_d.dims()[i] != 0) ? data_d.dims()[i] : 1;
        }

        balance211(work_amount, nthr, ithr, start, end);
        utils::nd_iterator_init(start, off[0], dims_d[0], off[1], dims_d[1],
                off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);

        for (dim_t iwork = start; iwork < end; ++iwork) {
            const auto data_off = offset(data_d, off);
            const auto weight_off = weights_offset(mask, weights_d, off);
            const auto res = ker(
                    src, weights, diff_dst, diff_src, data_off, weight_off);

            io::store_float_value(
                    weights_d.data_type(), res, diff_weights, weight_off);
            utils::nd_iterator_step(off[0], dims_d[0], off[1], dims_d[1],
                    off[2], dims_d[2], off[3], dims_d[3], off[4], dims_d[4]);
        }
    });
}

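// Generic broadcast case: threads are spread over the weights elements, and
// for each weight the kernel walks the data subspace that shares it (the
// dimensions where data and weights sizes differ), reducing the per-element
// contributions with the same two-level buffer scheme as the scalar case.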
void ref_prelu_bwd_t::calculate_shared_axes(const byte *src,
        const byte *weights, byte *diff_weights, const byte *diff_dst,
        byte *diff_src, float *scratchpad_buf) const {

    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper weights_d(pd()->weights_md(0));

    dims_t dims_d, dims_w;
    for (int i = 0; i < max_supported_ndims; i++) {
        dims_d[i] = (data_d.dims()[i] != 0) ? data_d.dims()[i] : 1;
        dims_w[i] = (weights_d.dims()[i] != 0) ? weights_d.dims()[i] : 1;
    }

    const int nthr = pd()->nthr_;
    const dim_t work_amount = weights_d.nelems();

    parallel(nthr, [&](std::size_t ithr, std::size_t nthr) {
        if ((dim_t)ithr >= work_amount) return;

        dim_t start {0}, end {0};
        balance211(work_amount, nthr, ithr, start, end);

        dim_t group_size, buf_size;
        const dim_t workload = data_d.nelems() / weights_d.nelems();
        prelu::set_reduction_buffers(workload, group_size, buf_size);
        dim_t scratchpad_offset = (buf_size + group_size) * ithr;
        auto *buf = &scratchpad_buf[scratchpad_offset];
        auto *group_buf = &scratchpad_buf[scratchpad_offset + buf_size];

        dims_t off_w, off_d, dims_start, dims_end;
        utils::nd_iterator_init(start, off_w[0], dims_w[0], off_w[1], dims_w[1],
                off_w[2], dims_w[2], off_w[3], dims_w[3], off_w[4], dims_w[4]);

        for (dim_t iwork = start; iwork < end; ++iwork) {
            const auto weight_off = offset(weights_d, off_w);
            for (int i = 0; i < max_supported_ndims; i++) {
                dims_start[i] = (dims_d[i] == dims_w[i]) ? off_w[i] : 0;
                dims_end[i]
                        = (dims_d[i] == dims_w[i]) ? off_w[i] + 1 : dims_d[i];
            }
            dim_t buf_off {0}, group_off {0}, data_size {buf_size};
            for_(off_d[0] = dims_start[0]; off_d[0] < dims_end[0]; ++off_d[0])
            for_(off_d[1] = dims_start[1]; off_d[1] < dims_end[1]; ++off_d[1])
            for_(off_d[2] = dims_start[2]; off_d[2] < dims_end[2]; ++off_d[2])
            for_(off_d[3] = dims_start[3]; off_d[3] < dims_end[3]; ++off_d[3])
            for (off_d[4] = dims_start[4]; off_d[4] < dims_end[4]; ++off_d[4]) {
                const auto data_off = offset(data_d, off_d);
                const auto diff_weight = ker(
                        src, weights, diff_dst, diff_src, data_off, weight_off);
                buf[buf_off] = diff_weight;
                if (++buf_off == data_size) {
                    group_buf[group_off++] = reduce(buf, buf_off);
                    buf_off = 0;
                    data_size = ((group_off + 1) * buf_size <= workload)
                            ? buf_size
                            : workload - (group_off * buf_size);
                }
            }
            io::store_float_value(weights_d.data_type(),
                    reduce(group_buf, group_size), diff_weights, weight_off);
            utils::nd_iterator_step(off_w[0], dims_w[0], off_w[1], dims_w[1],
                    off_w[2], dims_w[2], off_w[3], dims_w[3], off_w[4],
                    dims_w[4]);
        }
    });
}

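// Backward pass entry point: zero-pads the padded outputs where needed and
// dispatches on the weights broadcasting strategy. The scalar and
// no-broadcast cases get dedicated paths; every other supported strategy is
// handled by the generic shared-axes implementation.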
status_t ref_prelu_bwd_t::execute_backward(const exec_ctx_t &ctx) const {

    if (pd()->has_zero_dim_memory()) return status::success;

    const auto scratchpad = ctx.get_scratchpad_grantor();
    auto scratchpad_buf = scratchpad.template get<float>(
            memory_tracking::names::key_prelu_reduction);

    const auto src = CTX_IN_MEM(const byte *, DNNL_ARG_SRC);
    const auto weights = CTX_IN_MEM(const byte *, DNNL_ARG_WEIGHTS);
    auto diff_weights = CTX_OUT_MEM(byte *, DNNL_ARG_DIFF_WEIGHTS);
    const auto diff_dst = CTX_IN_MEM(const byte *, DNNL_ARG_DIFF_DST);
    auto diff_src = CTX_OUT_MEM(byte *, DNNL_ARG_DIFF_SRC);

    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper data_d(pd()->src_md(0));
    const memory_desc_wrapper diff_src_d(pd()->diff_src_md(0));
    const memory_desc_wrapper diff_weights_d(pd()->diff_weights_md(0));
    const auto bcast_type = dnnl::impl::get_rhs_arg_broadcasting_strategy(
            *weights_d.md_, data_d);

    const auto is_inplace = (diff_src == diff_dst);
    if (is_padding(diff_src_d) && !is_inplace)
        ctx.zero_pad_output(DNNL_ARG_DIFF_SRC);

    if (is_padding(diff_weights_d)) ctx.zero_pad_output(DNNL_ARG_DIFF_WEIGHTS);

    switch (bcast_type) {
        case broadcasting_strategy_t::scalar:
            calculate_scalar(src, weights, diff_weights, diff_dst, diff_src,
                    scratchpad_buf);
            break;
        case broadcasting_strategy_t::no_broadcast:
            calculate_no_broadcast(src, weights, diff_weights, diff_dst,
                    diff_src, scratchpad_buf);
            break;
        case broadcasting_strategy_t::per_oc:
        case broadcasting_strategy_t::per_oc_spatial:
        case broadcasting_strategy_t::per_mb_spatial:
        case broadcasting_strategy_t::per_mb_w:
        case broadcasting_strategy_t::per_w:
        case broadcasting_strategy_t::shared_axes:
            calculate_shared_axes(src, weights, diff_weights, diff_dst,
                    diff_src, scratchpad_buf);
            break;
        default: assert(!"unsupported broadcast type");
    }
    return status::success;
}

} // namespace cpu
} // namespace impl
} // namespace dnnl

// vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s