1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include <algorithm>
#include <vector>

#include <assert.h>

#include "prelu/prelu.hpp"
#include "utils/parallel.hpp"
23
24namespace prelu {
25
26void compute_ref_fwd(const prb_t *prb, const args_t &args) {
27 const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
28 const dnn_mem_t &wei = args.find(DNNL_ARG_WEIGHTS);
29 const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
30
31 float *dst_ptr = (float *)dst;
32
33 const auto nelems = src.nelems();
34 const auto weights_broadcast_mask = prb->get_broadcast_mask();
35
36 benchdnn_parallel_nd(nelems, [&](int64_t i) {
37 const auto wei_idx = src.get_scale_idx(i, weights_broadcast_mask);
38 const float s = src.get_elem(i);
39 float res = s * (s > 0 ? 1.f : wei.get_elem(wei_idx));
40 maybe_saturate(prb->sdt[0], res);
41 dst_ptr[i] = res;
42 });
43}
44
45void compute_ref_bwd(const prb_t *prb, const args_t &args) {
46 const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
47 const dnn_mem_t &wei = args.find(DNNL_ARG_WEIGHTS);
48 const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST);
49 const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC);
50 const dnn_mem_t &d_wei = args.find(DNNL_ARG_DIFF_WEIGHTS);
51
52 float *d_src_ptr = (float *)d_src;
53 float *d_wei_ptr = (float *)d_wei;
54 float *d_wei_buf = d_wei_ptr;
55
56 const auto src_nelems = d_src.nelems();
57 const auto wei_nelems = d_wei.nelems();
58
59 const auto ker = [&](int64_t i, int64_t wei_idx, int64_t d_wei_idx) {
60 float s = src.get_elem(i);
61 float dd = d_dst.get_elem(i);
62 float d_src = dd * (s > 0 ? 1.f : wei.get_elem(wei_idx));
63 maybe_saturate(prb->sdt[0], d_src);
64 d_src_ptr[i] = d_src;
65 d_wei_buf[d_wei_idx] += MIN2(0.f, s) * dd;
66 };
67
68 benchdnn_parallel_nd(wei_nelems, [&](int64_t i) { d_wei_ptr[i] = 0; });
69
70 if (wei_nelems == 1) {
71 const int reduce_dim = 0;
72 const int64_t N = d_src.dims()[reduce_dim];
73 const int64_t nelems_per_thr = src_nelems / N;
74 d_wei_buf = new float[N];
75 benchdnn_parallel_nd(N, [&](int64_t n) {
76 d_wei_buf[n] = 0;
77
78 for (int64_t ithr_i = 0; ithr_i < nelems_per_thr; ++ithr_i) {
79 int64_t idx = nelems_per_thr * n + ithr_i;
80 ker(idx, 0, n);
81 }
82 });
83
84 for (int64_t i = 0; i < N; i++)
85 d_wei_ptr[0] += d_wei_buf[i];
86 delete[] d_wei_buf;
87
88 } else if (src_nelems == wei_nelems) {
89 benchdnn_parallel_nd(src_nelems, [&](int64_t i) { ker(i, i, i); });
90 } else {
91 const int64_t reduce_size = src_nelems / wei_nelems;
92
93 // Re-used from ref_reduction.cpp
94 // TODO: make a common reduction kernel to avoid duplication.
95 const auto &src_dims = prb->vdims[0];
96 const auto &wei_dims = prb->vdims[1];
97 dims_t reduce_dims(prb->ndims, 1);
98 for (int d = 0; d < prb->ndims; ++d)
99 if (src_dims[d] != wei_dims[d]) reduce_dims[d] = src_dims[d];
100
101 benchdnn_parallel_nd(wei_nelems, [&](int64_t f) {
102 dims_t wei_pos = off2dims_idx(wei_dims, f);
103 const int64_t wei_off = md_off_v(wei, wei_pos.data());
104 const int64_t src_wei_off = md_off_v(src, wei_pos.data());
105
106 for (int64_t r = 0; r < reduce_size; ++r) {
107 dims_t reduce_pos = off2dims_idx(reduce_dims, r);
108 const int64_t src_reduce_off = md_off_v(src, reduce_pos.data());
109 const int64_t src_off = src_wei_off + src_reduce_off;
110 ker(src_off, wei_off, wei_off);
111 }
112 });
113 }
114}
115
116void compute_ref(
117 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
118 if (prb->dir & FLAG_FWD)
119 compute_ref_fwd(prb, args);
120 else
121 compute_ref_bwd(prb, args);
122}
123
124} // namespace prelu
125