1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
#include <algorithm>
#include <vector>

#include <assert.h>

#include "prelu/prelu.hpp"
#include "utils/parallel.hpp"
23 | |
24 | namespace prelu { |
25 | |
26 | void compute_ref_fwd(const prb_t *prb, const args_t &args) { |
27 | const dnn_mem_t &src = args.find(DNNL_ARG_SRC); |
28 | const dnn_mem_t &wei = args.find(DNNL_ARG_WEIGHTS); |
29 | const dnn_mem_t &dst = args.find(DNNL_ARG_DST); |
30 | |
31 | float *dst_ptr = (float *)dst; |
32 | |
33 | const auto nelems = src.nelems(); |
34 | const auto weights_broadcast_mask = prb->get_broadcast_mask(); |
35 | |
36 | benchdnn_parallel_nd(nelems, [&](int64_t i) { |
37 | const auto wei_idx = src.get_scale_idx(i, weights_broadcast_mask); |
38 | const float s = src.get_elem(i); |
39 | float res = s * (s > 0 ? 1.f : wei.get_elem(wei_idx)); |
40 | maybe_saturate(prb->sdt[0], res); |
41 | dst_ptr[i] = res; |
42 | }); |
43 | } |
44 | |
45 | void compute_ref_bwd(const prb_t *prb, const args_t &args) { |
46 | const dnn_mem_t &src = args.find(DNNL_ARG_SRC); |
47 | const dnn_mem_t &wei = args.find(DNNL_ARG_WEIGHTS); |
48 | const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST); |
49 | const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC); |
50 | const dnn_mem_t &d_wei = args.find(DNNL_ARG_DIFF_WEIGHTS); |
51 | |
52 | float *d_src_ptr = (float *)d_src; |
53 | float *d_wei_ptr = (float *)d_wei; |
54 | float *d_wei_buf = d_wei_ptr; |
55 | |
56 | const auto src_nelems = d_src.nelems(); |
57 | const auto wei_nelems = d_wei.nelems(); |
58 | |
59 | const auto ker = [&](int64_t i, int64_t wei_idx, int64_t d_wei_idx) { |
60 | float s = src.get_elem(i); |
61 | float dd = d_dst.get_elem(i); |
62 | float d_src = dd * (s > 0 ? 1.f : wei.get_elem(wei_idx)); |
63 | maybe_saturate(prb->sdt[0], d_src); |
64 | d_src_ptr[i] = d_src; |
65 | d_wei_buf[d_wei_idx] += MIN2(0.f, s) * dd; |
66 | }; |
67 | |
68 | benchdnn_parallel_nd(wei_nelems, [&](int64_t i) { d_wei_ptr[i] = 0; }); |
69 | |
70 | if (wei_nelems == 1) { |
71 | const int reduce_dim = 0; |
72 | const int64_t N = d_src.dims()[reduce_dim]; |
73 | const int64_t nelems_per_thr = src_nelems / N; |
74 | d_wei_buf = new float[N]; |
75 | benchdnn_parallel_nd(N, [&](int64_t n) { |
76 | d_wei_buf[n] = 0; |
77 | |
78 | for (int64_t ithr_i = 0; ithr_i < nelems_per_thr; ++ithr_i) { |
79 | int64_t idx = nelems_per_thr * n + ithr_i; |
80 | ker(idx, 0, n); |
81 | } |
82 | }); |
83 | |
84 | for (int64_t i = 0; i < N; i++) |
85 | d_wei_ptr[0] += d_wei_buf[i]; |
86 | delete[] d_wei_buf; |
87 | |
88 | } else if (src_nelems == wei_nelems) { |
89 | benchdnn_parallel_nd(src_nelems, [&](int64_t i) { ker(i, i, i); }); |
90 | } else { |
91 | const int64_t reduce_size = src_nelems / wei_nelems; |
92 | |
93 | // Re-used from ref_reduction.cpp |
94 | // TODO: make a common reduction kernel to avoid duplication. |
95 | const auto &src_dims = prb->vdims[0]; |
96 | const auto &wei_dims = prb->vdims[1]; |
97 | dims_t reduce_dims(prb->ndims, 1); |
98 | for (int d = 0; d < prb->ndims; ++d) |
99 | if (src_dims[d] != wei_dims[d]) reduce_dims[d] = src_dims[d]; |
100 | |
101 | benchdnn_parallel_nd(wei_nelems, [&](int64_t f) { |
102 | dims_t wei_pos = off2dims_idx(wei_dims, f); |
103 | const int64_t wei_off = md_off_v(wei, wei_pos.data()); |
104 | const int64_t src_wei_off = md_off_v(src, wei_pos.data()); |
105 | |
106 | for (int64_t r = 0; r < reduce_size; ++r) { |
107 | dims_t reduce_pos = off2dims_idx(reduce_dims, r); |
108 | const int64_t src_reduce_off = md_off_v(src, reduce_pos.data()); |
109 | const int64_t src_off = src_wei_off + src_reduce_off; |
110 | ker(src_off, wei_off, wei_off); |
111 | } |
112 | }); |
113 | } |
114 | } |
115 | |
116 | void compute_ref( |
117 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
118 | if (prb->dir & FLAG_FWD) |
119 | compute_ref_fwd(prb, args); |
120 | else |
121 | compute_ref_bwd(prb, args); |
122 | } |
123 | |
124 | } // namespace prelu |
125 | |