1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16#include <math.h>
17
18#include "utils/parallel.hpp"
19
20#include "resampling/resampling.hpp"
21
22namespace resampling {
23
// Maps index `y` of an axis with `y_max` points onto the (fractional)
// coordinate of an axis with `x_max` points. Both axes are treated as
// sequences of unit cells whose centers sit at `index + 0.5`, hence the
// half-cell shift before scaling and the opposite shift afterwards.
float linear_map(const int64_t y, const int64_t y_max, const int64_t x_max) {
    const float y_center = y + 0.5f;
    const float x_center = y_center * x_max / y_max;
    return x_center - 0.5f;
}
28int64_t left(const int64_t y, const int64_t y_max, const int64_t x_max) {
29 return MAX2((int64_t)floorf(linear_map(y, y_max, x_max)), (int64_t)0);
30}
31int64_t right(const int64_t y, const int64_t y_max, const int64_t x_max) {
32 return MIN2((int64_t)ceilf(linear_map(y, y_max, x_max)), x_max - 1);
33}
34int64_t near(const int64_t y, const int64_t y_max, const int64_t x_max) {
35 return roundf(linear_map(y, y_max, x_max));
36}
37float weight(const int64_t y, const int64_t y_max, const int64_t x_max) {
38 return fabs(linear_map(y, y_max, x_max) - left(y, y_max, x_max));
39}
40
41void compute_ref_fwd(const prb_t *prb, const args_t &args) {
42 const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
43 const dnn_mem_t &dst = args.find(DNNL_ARG_DST);
44
45 float *dst_ptr = (float *)dst;
46
47 int64_t MB = prb->mb;
48 int64_t IC = prb->ic;
49 int64_t ID = prb->id;
50 int64_t IH = prb->ih;
51 int64_t IW = prb->iw;
52 int64_t OD = prb->od;
53 int64_t OH = prb->oh;
54 int64_t OW = prb->ow;
55
56 auto ker_nearest = [&](float &result, int64_t mb, int64_t ic, int64_t od,
57 int64_t oh, int64_t ow) {
58 const int64_t id = near(od, OD, ID);
59 const int64_t ih = near(oh, OH, IH);
60 const int64_t iw = near(ow, OW, IW);
61 result = src.get_elem(src_off_f(prb, mb, ic, id, ih, iw));
62 };
63
64 auto ker_linear = [&](float &result, int64_t mb, int64_t ic, int64_t od,
65 int64_t oh, int64_t ow) {
66 const int64_t id[2] = {left(od, OD, ID), right(od, OD, ID)};
67 const int64_t ih[2] = {left(oh, OH, IH), right(oh, OH, IH)};
68 const int64_t iw[2] = {left(ow, OW, IW), right(ow, OW, IW)};
69 const float wd[2] = {1.f - weight(od, OD, ID), weight(od, OD, ID)};
70 const float wh[2] = {1.f - weight(oh, OH, IH), weight(oh, OH, IH)};
71 const float ww[2] = {1.f - weight(ow, OW, IW), weight(ow, OW, IW)};
72
73 float cd[2][2] = {{0}};
74 for_(int i = 0; i < 2; i++)
75 for (int j = 0; j < 2; j++)
76 cd[i][j] = src.get_elem(src_off_f(prb, mb, ic, id[0], ih[i], iw[j]))
77 * wd[0]
78 + src.get_elem(src_off_f(prb, mb, ic, id[1], ih[i], iw[j]))
79 * wd[1];
80
81 float ch[2] = {0};
82 for (int i = 0; i < 2; i++)
83 ch[i] = cd[0][i] * wh[0] + cd[1][i] * wh[1];
84
85 float cw = ch[0] * ww[0] + ch[1] * ww[1];
86
87 result = cw;
88 };
89
90 auto v_po_masks = prb->attr.post_ops.get_po_masks();
91 benchdnn_parallel_nd(MB, IC, OD, OH, OW,
92 [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) {
93 float result = 0.f;
94 if (prb->alg == nearest) {
95 ker_nearest(result, mb, ic, od, oh, ow);
96 } else {
97 ker_linear(result, mb, ic, od, oh, ow);
98 }
99 const auto dst_off = dst_off_f(prb, mb, ic, od, oh, ow);
100
101 const auto v_po_vals
102 = prepare_po_vals(dst, args, v_po_masks, dst_off);
103
104 maybe_post_ops(
105 prb->attr, result, dst.get_elem(dst_off), v_po_vals);
106 dst_ptr[dst_off] = result;
107 });
108}
109
110void compute_ref_bwd(const prb_t *prb, const args_t &args) {
111 const dnn_mem_t &d_dst = args.find(DNNL_ARG_DIFF_DST);
112 const dnn_mem_t &d_src = args.find(DNNL_ARG_DIFF_SRC);
113
114 float *d_src_ptr = (float *)d_src;
115
116 int64_t MB = prb->mb;
117 int64_t IC = prb->ic;
118 int64_t ID = prb->id;
119 int64_t IH = prb->ih;
120 int64_t IW = prb->iw;
121 int64_t OD = prb->od;
122 int64_t OH = prb->oh;
123 int64_t OW = prb->ow;
124
125 auto ker_nearest
126 = [&](int64_t mb, int64_t ic, int64_t od, int64_t oh, int64_t ow) {
127 const auto d_dst_off = dst_off_f(prb, mb, ic, od, oh, ow);
128 float d_dst_val = d_dst.get_elem(d_dst_off);
129 const int64_t id = near(od, OD, ID);
130 const int64_t ih = near(oh, OH, IH);
131 const int64_t iw = near(ow, OW, IW);
132 d_src_ptr[src_off_f(prb, mb, ic, id, ih, iw)] += d_dst_val;
133 };
134
135 auto ker_linear = [&](int64_t mb, int64_t ic, int64_t od, int64_t oh,
136 int64_t ow) {
137 const auto d_dst_off = dst_off_f(prb, mb, ic, od, oh, ow);
138 float d_dst_val = d_dst.get_elem(d_dst_off);
139 const int64_t id[2] = {left(od, OD, ID), right(od, OD, ID)};
140 const int64_t ih[2] = {left(oh, OH, IH), right(oh, OH, IH)};
141 const int64_t iw[2] = {left(ow, OW, IW), right(ow, OW, IW)};
142 const float wd[2] = {1.f - weight(od, OD, ID), weight(od, OD, ID)};
143 const float wh[2] = {1.f - weight(oh, OH, IH), weight(oh, OH, IH)};
144 const float ww[2] = {1.f - weight(ow, OW, IW), weight(ow, OW, IW)};
145 for_(int i = 0; i < 2; i++)
146 for_(int j = 0; j < 2; j++)
147 for (int k = 0; k < 2; k++) {
148 d_src_ptr[src_off_f(prb, mb, ic, id[i], ih[j], iw[k])]
149 += wd[i] * wh[j] * ww[k] * d_dst_val;
150 }
151 };
152
153 // zeroing d_src for correct result
154 benchdnn_parallel_nd(MB, IC, ID, IH, IW,
155 [&](int64_t mb, int64_t ic, int64_t id, int64_t ih, int64_t iw) {
156 d_src_ptr[src_off_f(prb, mb, ic, id, ih, iw)] = 0;
157 });
158
159 benchdnn_parallel_nd(MB, IC, [&](int64_t mb, int64_t ic) {
160 for_(int64_t od = 0; od < OD; ++od)
161 for_(int64_t oh = 0; oh < OH; ++oh)
162 for (int64_t ow = 0; ow < OW; ++ow)
163 if (prb->alg == nearest) {
164 ker_nearest(mb, ic, od, oh, ow);
165 } else {
166 ker_linear(mb, ic, od, oh, ow);
167 }
168 });
169}
170
171void compute_ref(
172 const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
173 if (prb->dir & FLAG_FWD)
174 compute_ref_fwd(prb, args);
175 else
176 compute_ref_bwd(prb, args);
177}
178
179} // namespace resampling
180