1/*******************************************************************************
2* Copyright 2019-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/ocl/ref_resampling.hpp"
18#include "common/c_types_map.hpp"
19
20namespace dnnl {
21namespace impl {
22namespace gpu {
23namespace ocl {
24
25// -------- Common functions ----------- //
26
27static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
28 const resampling_conf_t &conf, const resampling_desc_t *desc) {
29 switch (desc->alg_kind) {
30 case alg_kind::resampling_nearest:
31 kernel_ctx.define_int("RESAMPLING_ALG_NEAREST", 1);
32 break;
33 case alg_kind::resampling_linear:
34 kernel_ctx.define_int("RESAMPLING_ALG_LINEAR", 1);
35 break;
36 default: return status::unimplemented;
37 }
38
39 kernel_ctx.define_int("NDIMS", conf.ndims);
40 kernel_ctx.define_int("MB", conf.MB);
41 kernel_ctx.define_int("C", conf.C);
42 kernel_ctx.define_int("ID", conf.ID);
43 kernel_ctx.define_int("IH", conf.IH);
44 kernel_ctx.define_int("IW", conf.IW);
45 kernel_ctx.define_int("OD", conf.OD);
46 kernel_ctx.define_int("OH", conf.OH);
47 kernel_ctx.define_int("OW", conf.OW);
48 kernel_ctx.define_float("FD", conf.FD);
49 kernel_ctx.define_float("FH", conf.FH);
50 kernel_ctx.define_float("FW", conf.FW);
51
52 def_offsets(conf.off.src_off, kernel_ctx, "SRC", conf.ndims);
53 def_offsets(conf.off.dst_off, kernel_ctx, "DST", conf.ndims);
54
55 def_dispatch(kernel_ctx, conf.dispatch);
56 return status::success;
57}
58
59// ---------- ref_resampling_fwd_t ------------ //
60
61status_t ref_resampling_fwd_t::pd_t::init_conf(engine_t *engine) {
62
63 auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
64 conf.dispatch = compute_engine->create_dispatch(dst_md());
65
66 conf.dispatch.define_dim("MB", 0, dst_md()->padded_dims[0]);
67 conf.dispatch.define_dim("C", 1, dst_md()->padded_dims[1]);
68 conf.dispatch.define_dim("OD", nstl::max(2, dst_md()->ndims - 3), OD());
69 conf.dispatch.define_dim("OH", nstl::max(2, dst_md()->ndims - 2), OH());
70 conf.dispatch.define_dim("OW", nstl::max(2, dst_md()->ndims - 1), OW());
71 conf.dispatch.generate();
72
73 conf.ndims = dst_md()->ndims;
74
75 const memory_desc_wrapper src_d(src_md());
76 set_offsets(src_d, conf.off.src_off);
77
78 const memory_desc_wrapper dst_d(dst_md());
79 set_offsets(dst_d, conf.off.dst_off);
80
81 conf.MB = MB();
82 conf.C = C();
83 conf.ID = ID();
84 conf.IH = IH();
85 conf.IW = IW();
86 conf.OD = OD();
87 conf.OH = OH();
88 conf.OW = OW();
89 conf.FD = FD();
90 conf.FH = FH();
91 conf.FW = FW();
92
93 conf.attr_info = attr_info_t::create(attr());
94
95 return status::success;
96}
97
98status_t ref_resampling_fwd_t::pd_t::init_kernel_ctx(
99 compute::kernel_ctx_t &kernel_ctx) const {
100 kernel_ctx.set_data_type(src_md()->data_type);
101 kernel_ctx.define_int("IS_FWD", 1);
102
103 status_t status = init_kernel_ctx_common(kernel_ctx, conf, desc());
104
105 def_data_type(kernel_ctx, src_md()->data_type, "SRC");
106 def_data_type(kernel_ctx, dst_md()->data_type, "DST");
107
108 // Set post-op variables
109 def_attr_info(kernel_ctx, conf.attr_info, attr()->post_ops_);
110
111 return status;
112}
113
114status_t ref_resampling_fwd_t::execute_forward(const exec_ctx_t &ctx) const {
115
116 auto &src = CTX_IN_STORAGE(DNNL_ARG_SRC);
117 auto &dst = CTX_OUT_STORAGE(DNNL_ARG_DST);
118
119 compute::kernel_arg_list_t arg_list;
120 arg_list.set(0, src);
121 arg_list.set(1, dst);
122 append_post_ops_to_arg_list(ctx, arg_list, 2, pd()->attr()->post_ops_);
123
124 auto nd_range = pd()->conf.dispatch.nd_range();
125
126 return parallel_for(ctx, nd_range, kernel_, arg_list);
127}
128
129// -------- ref_resampling_bwd_t ---------- //
130
131status_t ref_resampling_bwd_t::pd_t::init_conf(engine_t *engine) {
132
133 auto *compute_engine = utils::downcast<compute::compute_engine_t *>(engine);
134 conf.dispatch = compute_engine->create_dispatch(diff_src_md());
135
136 conf.dispatch.define_dim("MB", 0, diff_src_md()->padded_dims[0]);
137 conf.dispatch.define_dim("C", 1, diff_src_md()->padded_dims[1]);
138 conf.dispatch.define_dim(
139 "ID", nstl::max(2, diff_src_md()->ndims - 3), ID());
140 conf.dispatch.define_dim(
141 "IH", nstl::max(2, diff_src_md()->ndims - 2), IH());
142 conf.dispatch.define_dim(
143 "IW", nstl::max(2, diff_src_md()->ndims - 1), IW());
144 conf.dispatch.generate();
145
146 conf.ndims = diff_dst_md()->ndims;
147
148 const memory_desc_wrapper diff_src_d(diff_src_md());
149 set_offsets(diff_src_d, conf.off.src_off);
150
151 const memory_desc_wrapper diff_dst_d(diff_dst_md());
152 set_offsets(diff_dst_d, conf.off.dst_off);
153
154 conf.MB = MB();
155 conf.C = C();
156 conf.ID = ID();
157 conf.IH = IH();
158 conf.IW = IW();
159 conf.OD = OD();
160 conf.OH = OH();
161 conf.OW = OW();
162 conf.FD = FD();
163 conf.FH = FH();
164 conf.FW = FW();
165
166 conf.attr_info = attr_info_t::create(attr());
167
168 return status::success;
169}
170
171status_t ref_resampling_bwd_t::pd_t::init_kernel_ctx(
172 compute::kernel_ctx_t &kernel_ctx) const {
173 kernel_ctx.set_data_type(diff_src_md()->data_type);
174 kernel_ctx.define_int("IS_BWD", 1);
175
176 status_t status = init_kernel_ctx_common(kernel_ctx, conf, desc());
177
178 def_data_type(kernel_ctx, diff_src_md()->data_type, "SRC");
179 def_data_type(kernel_ctx, diff_dst_md()->data_type, "DST");
180
181 return status;
182}
183
184status_t ref_resampling_bwd_t::execute_backward(const exec_ctx_t &ctx) const {
185 auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
186 auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);
187
188 compute::kernel_arg_list_t arg_list;
189 arg_list.set(0, diff_src);
190 arg_list.set(1, diff_dst);
191
192 auto nd_range = pd()->conf.dispatch.nd_range();
193
194 return parallel_for(ctx, nd_range, kernel_, arg_list);
195}
196
197} // namespace ocl
198} // namespace gpu
199} // namespace impl
200} // namespace dnnl
201