1 | /******************************************************************************* |
2 | * Copyright 2019-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | #include <math.h> |
17 | |
18 | #include "utils/parallel.hpp" |
19 | |
20 | #include "resampling/resampling.hpp" |
21 | |
22 | namespace resampling { |
23 | |
24 | float linear_map(const int64_t y, const int64_t y_max, const int64_t x_max) { |
25 | const float s = (y + 0.5f) * x_max / y_max; |
26 | return s - 0.5f; |
27 | } |
28 | int64_t left(const int64_t y, const int64_t y_max, const int64_t x_max) { |
29 | return MAX2((int64_t)floorf(linear_map(y, y_max, x_max)), (int64_t)0); |
30 | } |
31 | int64_t right(const int64_t y, const int64_t y_max, const int64_t x_max) { |
32 | return MIN2((int64_t)ceilf(linear_map(y, y_max, x_max)), x_max - 1); |
33 | } |
34 | int64_t near(const int64_t y, const int64_t y_max, const int64_t x_max) { |
35 | return roundf(linear_map(y, y_max, x_max)); |
36 | } |
37 | float weight(const int64_t y, const int64_t y_max, const int64_t x_max) { |
38 | return fabs(linear_map(y, y_max, x_max) - left(y, y_max, x_max)); |
39 | } |
40 | |
// Reference forward resampling (f32): computes each dst element from src by
// nearest-neighbor or linear interpolation (selected by prb->alg), applies
// the attribute post-ops, and stores the result into dst.
void compute_ref_fwd(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &dst = args.find(DNNL_ARG_DST);

    // Raw f32 view of dst for direct stores after post-ops.
    float *dst_ptr = (float *)dst;

    // I* = input (src) spatial dims, O* = output (dst) spatial dims.
    int64_t MB = prb->mb;
    int64_t IC = prb->ic;
    int64_t ID = prb->id;
    int64_t IH = prb->ih;
    int64_t IW = prb->iw;
    int64_t OD = prb->od;
    int64_t OH = prb->oh;
    int64_t OW = prb->ow;

    // Nearest-neighbor: take the single src element whose index is closest
    // to the mapped coordinate in each spatial dimension.
    auto ker_nearest = [&](float &result, int64_t mb, int64_t ic, int64_t od,
                               int64_t oh, int64_t ow) {
        const int64_t id = near(od, OD, ID);
        const int64_t ih = near(oh, OH, IH);
        const int64_t iw = near(ow, OW, IW);
        result = src.get_elem(src_off_f(prb, mb, ic, id, ih, iw));
    };

    // Linear interpolation over the 2x2x2 neighborhood: id/ih/iw hold the
    // left/right src indices per dimension; wd/wh/ww the matching weights
    // (index 0 = left neighbor, index 1 = right neighbor).
    auto ker_linear = [&](float &result, int64_t mb, int64_t ic, int64_t od,
                              int64_t oh, int64_t ow) {
        const int64_t id[2] = {left(od, OD, ID), right(od, OD, ID)};
        const int64_t ih[2] = {left(oh, OH, IH), right(oh, OH, IH)};
        const int64_t iw[2] = {left(ow, OW, IW), right(ow, OW, IW)};
        const float wd[2] = {1.f - weight(od, OD, ID), weight(od, OD, ID)};
        const float wh[2] = {1.f - weight(oh, OH, IH), weight(oh, OH, IH)};
        const float ww[2] = {1.f - weight(ow, OW, IW), weight(ow, OW, IW)};

        // Reduce along D first: cd[i][j] interpolates depth for each (h, w)
        // corner of the neighborhood.
        float cd[2][2] = {{0}};
        for_(int i = 0; i < 2; i++)
        for (int j = 0; j < 2; j++)
            cd[i][j] = src.get_elem(src_off_f(prb, mb, ic, id[0], ih[i], iw[j]))
                            * wd[0]
                    + src.get_elem(src_off_f(prb, mb, ic, id[1], ih[i], iw[j]))
                            * wd[1];

        // Then reduce along H...
        float ch[2] = {0};
        for (int i = 0; i < 2; i++)
            ch[i] = cd[0][i] * wh[0] + cd[1][i] * wh[1];

        // ...and finally along W to get the interpolated value.
        float cw = ch[0] * ww[0] + ch[1] * ww[1];

        result = cw;
    };

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
    // Every output point is independent, so parallelize over all five dims.
    benchdnn_parallel_nd(MB, IC, OD, OH, OW,
            [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) {
                float result = 0.f;
                if (prb->alg == nearest) {
                    ker_nearest(result, mb, ic, od, oh, ow);
                } else {
                    ker_linear(result, mb, ic, od, oh, ow);
                }
                const auto dst_off = dst_off_f(prb, mb, ic, od, oh, ow);

                // Per-point values for binary/other post-ops at this offset.
                const auto v_po_vals
                        = prepare_po_vals(dst, args, v_po_masks, dst_off);

                // Applies post-ops in place on `result` (sum reads the
                // current dst value at dst_off).
                maybe_post_ops(
                        prb->attr, result, dst.get_elem(dst_off), v_po_vals);
                dst_ptr[dst_off] = result;
            });
}
109 | |
// Reference backward resampling: scatter-accumulates each d_dst element into
// the d_src location(s) it was interpolated from on the forward pass.
void compute_ref_bwd(const prb_t *prb, const args_t &args) {
    const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST);
    const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC);

    // Raw f32 view of d_src for in-place accumulation.
    float *d_src_ptr = (float *)d_src;

    // I* = d_src spatial dims, O* = d_dst spatial dims.
    int64_t MB = prb->mb;
    int64_t IC = prb->ic;
    int64_t ID = prb->id;
    int64_t IH = prb->ih;
    int64_t IW = prb->iw;
    int64_t OD = prb->od;
    int64_t OH = prb->oh;
    int64_t OW = prb->ow;

    // Nearest: the whole gradient of this output point goes to the single
    // nearest d_src element.
    auto ker_nearest
            = [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) {
                  const auto d_dst_off = dst_off_f(prb, mb, ic, od, oh, ow);
                  float d_dst_val = d_dst.get_elem(d_dst_off);
                  const int64_t id = near(od, OD, ID);
                  const int64_t ih = near(oh, OH, IH);
                  const int64_t iw = near(ow, OW, IW);
                  d_src_ptr[src_off_f(prb, mb, ic, id, ih, iw)] += d_dst_val;
              };

    // Linear: distribute the gradient over the 2x2x2 neighborhood with the
    // same left/right weights the forward interpolation used.
    auto ker_linear = [&](int64_t mb, int64_t ic, int64_t od, int64_t oh,
                              int64_t ow) {
        const auto d_dst_off = dst_off_f(prb, mb, ic, od, oh, ow);
        float d_dst_val = d_dst.get_elem(d_dst_off);
        const int64_t id[2] = {left(od, OD, ID), right(od, OD, ID)};
        const int64_t ih[2] = {left(oh, OH, IH), right(oh, OH, IH)};
        const int64_t iw[2] = {left(ow, OW, IW), right(ow, OW, IW)};
        const float wd[2] = {1.f - weight(od, OD, ID), weight(od, OD, ID)};
        const float wh[2] = {1.f - weight(oh, OH, IH), weight(oh, OH, IH)};
        const float ww[2] = {1.f - weight(ow, OW, IW), weight(ow, OW, IW)};
        for_(int i = 0; i < 2; i++)
        for_(int j = 0; j < 2; j++)
        for (int k = 0; k < 2; k++) {
            d_src_ptr[src_off_f(prb, mb, ic, id[i], ih[j], iw[k])]
                    += wd[i] * wh[j] * ww[k] * d_dst_val;
        }
    };

    // Zero d_src first: the kernels below accumulate with += and each d_src
    // element may receive contributions from several output points.
    benchdnn_parallel_nd(MB, IC, ID, IH, IW,
            [&](int64_t mb, int64_t ic, int64_t id, int64_t ih, int64_t iw) {
                d_src_ptr[src_off_f(prb, mb, ic, id, ih, iw)] = 0;
            });

    // Parallelize over mb/ic only: within one (mb, ic) slice, different
    // output points can scatter into the same d_src element, so the spatial
    // loops stay serial to avoid data races on the += above.
    benchdnn_parallel_nd(MB, IC, [&](int64_t mb, int64_t ic) {
        for_(int64_t od = 0; od < OD; ++od)
        for_(int64_t oh = 0; oh < OH; ++oh)
        for (int64_t ow = 0; ow < OW; ++ow)
            if (prb->alg == nearest) {
                ker_nearest(mb, ic, od, oh, ow);
            } else {
                ker_linear(mb, ic, od, oh, ow);
            }
    });
}
170 | |
171 | void compute_ref( |
172 | const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) { |
173 | if (prb->dir & FLAG_FWD) |
174 | compute_ref_fwd(prb, args); |
175 | else |
176 | compute_ref_bwd(prb, args); |
177 | } |
178 | |
179 | } // namespace resampling |
180 | |