/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/vectorized_resampling.hpp"
#include "common/c_types_map.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// -------- Common functions ----------- //

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const resampling_conf_t &conf, const resampling_desc_t *desc) {

    using namespace alg_kind;

    status_t status = status::success;

    kernel_ctx.define_int("IS_BWD", 1);

    switch (desc->alg_kind) {
        case resampling_nearest:
            kernel_ctx.define_int("RESAMPLING_ALG_NEAREST", 1);
            break;
        case resampling_linear:
            kernel_ctx.define_int("RESAMPLING_ALG_LINEAR", 1);
            break;
        default: status = status::unimplemented;
    }
    if (status != status::success) return status;

    const memory_desc_wrapper diff_src_d(desc->diff_src_desc);
    const memory_desc_wrapper diff_dst_d(desc->diff_dst_desc);
    const int ndims = diff_dst_d.ndims();

    // Compute strides and set variables
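    // Lower-rank problems have no depth/height dimension; the missing strides
    // are defined as 1 so the kernel can use a single addressing expression
    // for all supported ranks.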
    kernel_ctx.define_int("MB_STRIDE", conf.padded_strides[0]);
    kernel_ctx.define_int("C_STRIDE", conf.padded_strides[1]);
    kernel_ctx.define_int(
            "ID_STRIDE", ndims < 5 ? 1 : conf.padded_strides[ndims - 3]);
    kernel_ctx.define_int(
            "IH_STRIDE", ndims < 4 ? 1 : conf.padded_strides[ndims - 2]);
    kernel_ctx.define_int(
            "IW_STRIDE", ndims < 3 ? 1 : conf.padded_strides[ndims - 1]);

    // kernel_ctx.define_int("VECT_SIZE", conf.vect_size);
    kernel_ctx.define_int("VECT_DT_N", conf.vect_size);
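    // Hand the dispatch configuration (local work-group sizes and sub-group
    // size) to the kernel through the GWS_* macros.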
    kernel_ctx.define_int("GWS_WITH_SG_DEFAULT", 1);
    kernel_ctx.define_int("GWS_LWS0_DEFAULT", conf.lws[0]);
    kernel_ctx.define_int("GWS_LWS1_DEFAULT", conf.lws[1]);
    kernel_ctx.define_int("GWS_LWS2_DEFAULT", conf.lws[2]);
    kernel_ctx.define_int("GWS_SGS_DEFAULT", conf.sub_group_size);

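    // Vectorized (block) reads/writes are only used when the channel row size
    // in bytes is a multiple of the access width; the factor of 4 presumably
    // reflects the 4-byte granularity of sub-group block accesses.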
    const size_t dst_size = types::data_type_size(diff_dst_d.data_type());
    kernel_ctx.define_int("ALIGNED_READ",
            conf.C * dst_size % (4 * conf.vect_size) == 0 ? 1 : 0);
    const size_t src_size = types::data_type_size(diff_src_d.data_type());
    kernel_ctx.define_int("ALIGNED_WRITE",
            conf.C * src_size % (4 * conf.vect_size) == 0 ? 1 : 0);

    kernel_ctx.define_int("NDIMS", ndims);
    kernel_ctx.define_int("MB", conf.MB);
    kernel_ctx.define_int("C", conf.C);
    kernel_ctx.define_int("PADDED_C", conf.padded_c);
    kernel_ctx.define_int("ID", conf.ID);
    kernel_ctx.define_int("IH", conf.IH);
    kernel_ctx.define_int("IW", conf.IW);
    kernel_ctx.define_int("OD", conf.OD);
    kernel_ctx.define_int("OH", conf.OH);
    kernel_ctx.define_int("OW", conf.OW);
    kernel_ctx.define_float("FD", conf.FD);
    kernel_ctx.define_float("FH", conf.FH);
    kernel_ctx.define_float("FW", conf.FW);

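    // Conservative per-dimension upper bound on how many output elements can
    // contribute to a single input-gradient element in linear mode.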
    const int max_d = std::max(1, (int)std::ceil(conf.FD * 1.5 - 0.5));
    const int max_h = std::max(1, (int)std::ceil(conf.FH * 1.5 - 0.5));
    const int max_w = std::max(1, (int)std::ceil(conf.FW * 1.5 - 0.5));
    kernel_ctx.define_int("MAX_NUM_D", max_d);
    kernel_ctx.define_int("MAX_NUM_H", max_h);
    kernel_ctx.define_int("MAX_NUM_W", max_w);

    def_offsets(conf.off.src_off, kernel_ctx, "SRC", ndims);
    def_offsets(conf.off.dst_off, kernel_ctx, "DST", ndims);
    return status::success;
}

status_t vectorized_resampling_bwd_t::pd_t::init_conf(engine_t *engine) {
    using namespace data_type;
    assert(engine->kind() == engine_kind::gpu);
    bool ok = !is_fwd() && set_default_params() == status::success
            && attr()->has_default_values();
    if (!ok) return status::unimplemented;

    const memory_desc_wrapper diff_src_d(diff_src_md());
    const memory_desc_wrapper diff_dst_d(diff_dst_md());

    // Restriction: Only works for axb cases
    using namespace dnnl::impl::format_tag;
    const bool is_axb = (diff_src_d.matches_one_of_tag(acb, acdb, acdeb)
            != format_tag::undef);
    if (!is_axb) { return status::unimplemented; }

    // ------- Heuristics -------- //
    // Tuned for PVC
    // TODO: Use hw config to determine optimal heuristics

    conf.vect_size = 4;
    conf.lws[0] = 512;
    conf.sub_group_size = 32;

    // For large cases where cache reuse is less likely, use a smaller lws to
    // increase parallelism via thread dispatching
    dim_t num_wi_estimate = diff_src_md()->padded_dims[0]
            * diff_src_md()->padded_dims[1] * ID() * IH() * IW()
            / conf.vect_size;
    if (num_wi_estimate > (1 << 20)) { conf.lws[0] = 256; }

    // ------ End of Heuristics ------- //

    // Padded C: multiple of sub_group_size and vect_size (subgroup padding),
    // and at least vect_size * sub_group_size
    const int c_divisor = math::lcm(conf.sub_group_size, conf.vect_size);
    conf.padded_c = utils::rnd_up(diff_src_md()->padded_dims[1], c_divisor);
    conf.padded_c
            = std::max(conf.padded_c, conf.vect_size * conf.sub_group_size);
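    // Illustrative example: with sub_group_size = 32 and vect_size = 4,
    // c_divisor = lcm(32, 4) = 32 and the floor is 32 * 4 = 128, so a padded
    // channel count of 100 becomes max(rnd_up(100, 32), 128) = 128.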

    // lws: Multiple of sub_group_size
    conf.lws[0] = utils::rnd_up(conf.lws[0], conf.sub_group_size);
    conf.lws[1] = conf.lws[2] = 1;

    // gws: number of vectorized work items over MB x padded C x spatial dims,
    // rounded up to a multiple of both lws and padded C
    const int gws_divisor = math::lcm((int)conf.lws[0], (int)conf.padded_c);
    conf.gws[0] = diff_src_md()->padded_dims[0] * conf.padded_c * ID() * IH()
            * IW() / conf.vect_size;
    conf.gws[0] = utils::rnd_up(conf.gws[0], gws_divisor);
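    // The round-up may add tail work items beyond the real problem size; these
    // are presumably discarded inside the kernel via bounds checks.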

    conf.gws[1] = conf.gws[2] = 1;

    // Copy src/dst shapes to conf
    conf.MB = MB();
    conf.C = C();
    conf.ID = ID();
    conf.IH = IH();
    conf.IW = IW();
    conf.OD = OD();
    conf.OH = OH();
    conf.OW = OW();
    conf.FD = FD();
    conf.FH = FH();
    conf.FW = FW();

    // Highly-upsampled cases are not supported
    // TODO: Implement multiple linear calculations per work item
    // to eliminate this requirement
    const int max_d = std::max(1, (int)std::ceil(conf.FD * 1.5 - 0.5));
    const int max_h = std::max(1, (int)std::ceil(conf.FH * 1.5 - 0.5));
    const int max_w = std::max(1, (int)std::ceil(conf.FW * 1.5 - 0.5));
    const int max_num_linear_calcs = 2 * (max_d + max_h + max_w);
    if (max_num_linear_calcs > conf.sub_group_size) {
        return status::unimplemented;
    }

    // Compute strides after vect_size is taken into account.
    const blocking_desc_t &blocks = diff_src_md()->format_desc.blocking;
    const dim_t c_dim = diff_src_d.padded_dims()[1];
    const dim_t stride_c = blocks.strides[1];

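    // Strides of inner dimensions (those smaller than the channel stride) and
    // the channel stride itself are kept as-is; outer strides are rescaled to
    // account for C being padded to padded_c and for the kernel addressing
    // data in vect_size-element chunks.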
    for (int i = 0; i < ndims(); i++) {
        if (blocks.strides[i] < stride_c || i == 1) {
            conf.padded_strides[i] = blocks.strides[i];
        } else {
            conf.padded_strides[i] = blocks.strides[i] * conf.padded_c / c_dim
                    / conf.vect_size; // Scale up to the newly-padded size
        }
    }

    set_offsets(diff_src_d, conf.off.src_off);
    set_offsets(diff_dst_d, conf.off.dst_off);

    return status::success;
}

status_t vectorized_resampling_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    kernel_ctx.set_data_type(diff_dst_md()->data_type);
    kernel_ctx.define_int("IS_BWD", 1);

    status_t status = init_kernel_ctx_common(kernel_ctx, conf, desc());

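    // For the backward kernel the roles are swapped: diff_dst is the kernel's
    // input tensor (SRC) and diff_src is its output tensor (DST).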
    def_data_type(kernel_ctx, diff_dst_md()->data_type, "SRC");
    def_data_type(kernel_ctx, diff_src_md()->data_type, "DST");

    return status;
}

status_t vectorized_resampling_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, diff_src);
    arg_list.set(1, diff_dst);

    const resampling_conf_t &conf = pd()->conf;
    compute::nd_range_t nd_range(conf.gws, conf.lws);

    return parallel_for(ctx, nd_range, kernel_, arg_list);
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl