/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "gpu/ocl/vectorized_resampling.hpp"
#include "common/c_types_map.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// -------- Common functions ----------- //

static status_t init_kernel_ctx_common(compute::kernel_ctx_t &kernel_ctx,
        const resampling_conf_t &conf, const resampling_desc_t *desc) {

    using namespace alg_kind;

    status_t status = status::success;

    kernel_ctx.define_int("IS_BWD", 1);

    switch (desc->alg_kind) {
        case resampling_nearest:
            kernel_ctx.define_int("RESAMPLING_ALG_NEAREST", 1);
            break;
        case resampling_linear:
            kernel_ctx.define_int("RESAMPLING_ALG_LINEAR", 1);
            break;
        default: status = status::unimplemented;
    }
    if (status != status::success) return status;

    const memory_desc_wrapper diff_src_d(desc->diff_src_desc);
    const memory_desc_wrapper diff_dst_d(desc->diff_dst_desc);
    const int ndims = diff_dst_d.ndims();

    // Compute strides and set variables
    kernel_ctx.define_int("MB_STRIDE", conf.padded_strides[0]);
    kernel_ctx.define_int("C_STRIDE", conf.padded_strides[1]);
    kernel_ctx.define_int(
            "ID_STRIDE", ndims < 5 ? 1 : conf.padded_strides[ndims - 3]);
    kernel_ctx.define_int(
            "IH_STRIDE", ndims < 4 ? 1 : conf.padded_strides[ndims - 2]);
    kernel_ctx.define_int(
            "IW_STRIDE", ndims < 3 ? 1 : conf.padded_strides[ndims - 1]);

62 | kernel_ctx.define_int("VECT_DT_N" , conf.vect_size); |
63 | kernel_ctx.define_int("GWS_WITH_SG_DEFAULT" , 1); |
64 | kernel_ctx.define_int("GWS_LWS0_DEFAULT" , conf.lws[0]); |
65 | kernel_ctx.define_int("GWS_LWS1_DEFAULT" , conf.lws[1]); |
66 | kernel_ctx.define_int("GWS_LWS2_DEFAULT" , conf.lws[2]); |
67 | kernel_ctx.define_int("GWS_SGS_DEFAULT" , conf.sub_group_size); |
68 | |
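    // ALIGNED_READ/ALIGNED_WRITE: set when the channel extent in bytes is a
    // multiple of the vectorized access width, which presumably allows the
    // kernel to use block (aligned) subgroup reads/writes instead of
    // scattered accesses.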
    const size_t dst_size = types::data_type_size(diff_dst_d.data_type());
    kernel_ctx.define_int("ALIGNED_READ",
            conf.C * dst_size % (4 * conf.vect_size) == 0 ? 1 : 0);
    const size_t src_size = types::data_type_size(diff_src_d.data_type());
    kernel_ctx.define_int("ALIGNED_WRITE",
            conf.C * src_size % (4 * conf.vect_size) == 0 ? 1 : 0);

    kernel_ctx.define_int("NDIMS", ndims);
    kernel_ctx.define_int("MB", conf.MB);
    kernel_ctx.define_int("C", conf.C);
    kernel_ctx.define_int("PADDED_C", conf.padded_c);
    kernel_ctx.define_int("ID", conf.ID);
    kernel_ctx.define_int("IH", conf.IH);
    kernel_ctx.define_int("IW", conf.IW);
    kernel_ctx.define_int("OD", conf.OD);
    kernel_ctx.define_int("OH", conf.OH);
    kernel_ctx.define_int("OW", conf.OW);
    kernel_ctx.define_float("FD", conf.FD);
    kernel_ctx.define_float("FH", conf.FH);
    kernel_ctx.define_float("FW", conf.FW);

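    // Upper bound on how many dst elements along each spatial dim can
    // contribute to a given src element under linear interpolation; mirrors
    // the supported-size check in init_conf below.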
    const int max_d = std::max(1, (int)std::ceil(conf.FD * 1.5 - 0.5));
    const int max_h = std::max(1, (int)std::ceil(conf.FH * 1.5 - 0.5));
    const int max_w = std::max(1, (int)std::ceil(conf.FW * 1.5 - 0.5));
    kernel_ctx.define_int("MAX_NUM_D", max_d);
    kernel_ctx.define_int("MAX_NUM_H", max_h);
    kernel_ctx.define_int("MAX_NUM_W", max_w);

    def_offsets(conf.off.src_off, kernel_ctx, "SRC", ndims);
    def_offsets(conf.off.dst_off, kernel_ctx, "DST", ndims);
    return status::success;
}

status_t vectorized_resampling_bwd_t::pd_t::init_conf(engine_t *engine) {
    using namespace data_type;
    assert(engine->kind() == engine_kind::gpu);
    bool ok = !is_fwd() && set_default_params() == status::success
            && attr()->has_default_values();
    if (!ok) return status::unimplemented;

    const memory_desc_wrapper diff_src_d(diff_src_md());
    const memory_desc_wrapper diff_dst_d(diff_dst_md());

    // Restriction: Only works for axb cases
    using namespace dnnl::impl::format_tag;
    const bool is_axb = (diff_src_d.matches_one_of_tag(acb, acdb, acdeb)
            != format_tag::undef);
    if (!is_axb) { return status::unimplemented; }

    // ------- Heuristics -------- //
    // Tuned for PVC
    // TODO: Use hw config to determine optimal heuristics

    conf.vect_size = 4;
    conf.lws[0] = 512;
    conf.sub_group_size = 32;

    // For large cases where cache reuse is less likely, use a smaller lws
    // to increase parallelism via thread dispatching
    dim_t num_wi_estimate = diff_src_md()->padded_dims[0]
            * diff_src_md()->padded_dims[1] * ID() * IH() * IW()
            / conf.vect_size;
    if (num_wi_estimate > (1 << 20)) { conf.lws[0] = 256; }

    // ------ End of Heuristics ------- //

    // Padded C: multiple of sub_group_size and vect_size (subgroup
    // padding), and at least vect_size * sub_group_size
    const int c_divisor = math::lcm(conf.sub_group_size, conf.vect_size);
    conf.padded_c = utils::rnd_up(diff_src_md()->padded_dims[1], c_divisor);
    conf.padded_c
            = std::max(conf.padded_c, conf.vect_size * conf.sub_group_size);
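    // e.g. padded_dims[1] = 67 with sub_group_size = 32 and vect_size = 4:
    // lcm = 32, rnd_up(67, 32) = 96, max(96, 4 * 32) = 128 -> padded_c = 128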

    // lws: Multiple of sub_group_size
    conf.lws[0] = utils::rnd_up(conf.lws[0], conf.sub_group_size);
    conf.lws[1] = conf.lws[2] = 1;

    // gws: spans all remaining dims (one work item per vect_size channels)
    // and must be a multiple of both lws and padded C
    const int gws_divisor = math::lcm((int)conf.lws[0], (int)conf.padded_c);
    conf.gws[0] = diff_src_md()->padded_dims[0] * conf.padded_c * ID() * IH()
            * IW() / conf.vect_size;
    conf.gws[0] = utils::rnd_up(conf.gws[0], gws_divisor);

    conf.gws[1] = conf.gws[2] = 1;

    // Copy src/dst shapes to conf
    conf.MB = MB();
    conf.C = C();
    conf.ID = ID();
    conf.IH = IH();
    conf.IW = IW();
    conf.OD = OD();
    conf.OH = OH();
    conf.OW = OW();
    conf.FD = FD();
    conf.FH = FH();
    conf.FW = FW();

    // Highly-upsampled cases are not supported
    // TODO: Implement multiple linear calculations per work item
    // to eliminate this requirement
    const int max_d = std::max(1, (int)std::ceil(conf.FD * 1.5 - 0.5));
    const int max_h = std::max(1, (int)std::ceil(conf.FH * 1.5 - 0.5));
    const int max_w = std::max(1, (int)std::ceil(conf.FW * 1.5 - 0.5));
    const int max_num_linear_calcs = 2 * (max_d + max_h + max_w);
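    // e.g. FD = FH = FW = 2.0f: max_* = ceil(2.5) = 3 each, giving
    // max_num_linear_calcs = 2 * (3 + 3 + 3) = 18 <= 32 -> supported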
    if (max_num_linear_calcs > conf.sub_group_size) {
        return status::unimplemented;
    }

    // Compute strides after vect_size is taken into account.
    const blocking_desc_t &blocks = diff_src_md()->format_desc.blocking;
    const dim_t c_dim = diff_src_d.padded_dims()[1];
    const dim_t stride_c = blocks.strides[1];

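    // C itself (i == 1) and dims laid out inside C (stride < stride_c) keep
    // their original strides; outer dims are rescaled to account for the
    // subgroup-padded, vectorized channel size.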
    for (int i = 0; i < ndims(); i++) {
        if (blocks.strides[i] < stride_c || i == 1) {
            conf.padded_strides[i] = blocks.strides[i];
        } else {
            conf.padded_strides[i] = blocks.strides[i] * conf.padded_c / c_dim
                    / conf.vect_size; // Scale up to the newly-padded size
        }
    }

    set_offsets(diff_src_d, conf.off.src_off);
    set_offsets(diff_dst_d, conf.off.dst_off);

    return status::success;
}

status_t vectorized_resampling_bwd_t::pd_t::init_kernel_ctx(
        compute::kernel_ctx_t &kernel_ctx) const {
    kernel_ctx.set_data_type(diff_dst_md()->data_type);
199 | kernel_ctx.define_int("IS_BWD" , 1); |

    status_t status = init_kernel_ctx_common(kernel_ctx, conf, desc());

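    // The backward kernel reads diff_dst as its SRC and writes diff_src as
    // its DST.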
    def_data_type(kernel_ctx, diff_dst_md()->data_type, "SRC");
    def_data_type(kernel_ctx, diff_src_md()->data_type, "DST");

    return status;
}

status_t vectorized_resampling_bwd_t::execute_backward(
        const exec_ctx_t &ctx) const {
    auto &diff_dst = CTX_IN_STORAGE(DNNL_ARG_DIFF_DST);
    auto &diff_src = CTX_OUT_STORAGE(DNNL_ARG_DIFF_SRC);

    compute::kernel_arg_list_t arg_list;
    arg_list.set(0, diff_src);
    arg_list.set(1, diff_dst);

    const resampling_conf_t &conf = pd()->conf;
    compute::nd_range_t nd_range(conf.gws, conf.lws);

    return parallel_for(ctx, nd_range, kernel_, arg_list);
}

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl