/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_CONV_POST_OPS_HPP
#define GPU_JIT_CONV_POST_OPS_HPP

#include <string>
#include <vector>

#include "common/convolution_pd.hpp"
#include "common/eltwise_pd.hpp"
#include "gpu/jit/ir/eltwise.hpp"
#include "gpu/jit/ir/gemm_schedule.hpp"
#include "gpu/jit/ir/ir.hpp"
#include "gpu/jit/ir/kernel_info.hpp"
#include "gpu/jit/ir/post_ops.hpp"
#include "gpu/jit/ir/tensor.hpp"
#include "gpu/jit/utils/utils.hpp"

#include "gpu/jit/conv/config.hpp"
#include "gpu/jit/conv/normalization.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

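// Describes the post-convolution epilogue: source/weights/destination scales,
// bias, attribute post-ops (eltwise, sum, binary, prelu) and destination zero
// points. Each step is recorded as a post_op_t expression over the virtual
// output tensor "c" and the extra input tensors registered in tensor_infos_.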
class post_op_context_t {
public:
    post_op_context_t() = default;

    post_op_context_t(const conv_config_t &cfg,
            const gemm_schedule_t &gemm_schedule,
            const kernel_info_t &kernel_info)
        : prb_(&cfg.prb()), cfg_(&cfg), cp_view_(gemm_schedule.c_view()) {

        auto *pd = prb_->conv_pd;
        auto *attr = prb_->attr;

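        // The epilogue is built as a chain of updates to the virtual output
        // "c": c *= src_scale * wei_scale, c += bias, the attribute post-ops
        // in order, c /= dst_scale, and finally the dst zero point is added.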
        auto c = add_tensor(/*is_input=*/false, /*is_output=*/false, cp_view_,
                expr_t(), var_t::make(type_t::f32(), "c"));

        // Prepare src/weights/dst scale expressions.
        std::vector<expr_t> scales(3, expr_t(1.0f));
        auto &src_scales = scales[0];
        auto &wei_scales = scales[1];
        auto &dst_scales = scales[2];
        if ((prb_->is_fwd || prb_->is_bwd_d)
                && !attr->scales_.has_default_values()) {
            const char *names[] = {"src_scales", "wei_scales", "dst_scales"};
            for (int i = 0; i < 3; i++) {
                auto buf = kernel_info.find_arg(names[i], /*allow_empty=*/true);
                if (buf.is_empty()) continue;
                int key = kernel_info.key(names[i]) & ~DNNL_ARG_ATTR_SCALES;
                int mask = attr->scales_.get(key).mask_;
                if (i == 1) {
                    // Convert o/i weights mask to src/dst.
                    // XXX: per_oc for BWD_D is treated as per_ic assuming
                    // it's called from deconvolution.
                    int c_idx = prb_->with_groups;
                    ir_assert(utils::one_of(mask, 0, 1 << c_idx));
                    if (mask != 0) mask = (1 << 1);
                } else {
                    ir_assert(mask == 0);
                }
                auto view = create_view(type_t::f32(), normalize_mask(mask));
                scales[i] = add_input_tensor(view, buf);
            }
        }

        // Handle input and weights scales.
        if (!is_one(src_scales) || !is_one(wei_scales)) {
            auto c_scaled = c * src_scales * wei_scales;
            post_ops_.emplace_back(c, c_scaled);
        }

        // Handle bias.
        if ((pd->is_fwd() || pd->is_bwd_d()) && pd->with_bias()) {
            uint32_t mask = normalize_mask(1 << 1); // Per-channel mask.
            auto view = create_view(pd->invariant_bia_md()->data_type, mask);
            auto buf = kernel_info.find_arg("bia");
            auto bia = add_input_tensor(view, buf);
            post_ops_.emplace_back(c, c + bia);
        }

        // Handle post-ops.
        for (int i = 0; i < attr->post_ops_.len(); i++) {
            auto &po = attr->post_ops_.entry_[i];
            if (po.is_eltwise()) {
                auto func = eltwise_t::make(po.eltwise.alg, po.eltwise.scale,
                        po.eltwise.alpha, po.eltwise.beta);
                post_ops_.emplace_back(c, c, func);
            } else if (po.is_sum(/*require_scale_one=*/false,
                               /*require_zp_zero=*/false)) {
                float scale = po.sum.scale;
                int32_t zp = po.sum.zero_point;
                if (pd->is_bwd_w()) {
                    ir_assert(scale == 1) << "BWD_W doesn't support "
                                             "non-default scale for sum.";
                    continue;
                }
                auto view = cp_view_;
                if (po.sum.dt != data_type::undef)
                    view = view.retype(po.sum.dt);
                auto buf = kernel_info.find_arg(pd->is_fwd() ? "dst" : "src");
                auto c_old = add_input_tensor(view, buf);
                post_ops_.emplace_back(c, c + scale * (c_old - zp));
            } else if (po.is_prelu()) {
                uint32_t rhs_mask = normalize_mask(po.prelu.mask);
                auto rhs_view = create_view(type_t::f32(), rhs_mask);
                auto buf_name = "prelu_rhs_" + std::to_string(i);
                auto rhs_buf = kernel_info.find_arg(buf_name);
                auto rhs = add_input_tensor(rhs_view, rhs_buf);
                post_ops_.emplace_back(
                        c, binary_op_t::make(op_kind_t::_prelu, c, rhs));
            } else if (po.is_binary()) {
                auto buf_name = "binary_rhs_" + std::to_string(i);
                auto view = create_view(po.binary.src1_desc);
                auto buf = kernel_info.find_arg(buf_name);
                auto rhs = add_input_tensor(view, buf);
                auto op_kind = alg_kind_to_op_kind(po.binary.alg);
                post_ops_.emplace_back(c, binary_op_t::make(op_kind, c, rhs));
            } else {
                ir_error_not_expected();
            }
        }

        // Handle dst scale.
        if (!is_one(dst_scales)) {
            auto c_scaled = c / dst_scales;
            post_ops_.emplace_back(c, c_scaled);
        }

        // Handle dst zero points.
        auto &zp_cfg = prb_->zp_cfg;
        if (zp_cfg.do_dst_compensation) {
            if (zp_cfg.is_runtime_dst_zero_points) {
                uint32_t mask = normalize_mask(
                        zp_cfg.is_common_dst_zero_point ? 0 : 1 << 1);
                auto view = create_view(type_t::s32(), mask);
                auto buf = kernel_info.find_arg("dst_zero_points");
                auto in = add_input_tensor(view, buf);
                post_ops_.emplace_back(c, c + in);
            } else {
                auto func = eltwise_t::make(alg_kind::eltwise_linear,
                        /*scale=*/1.f,
                        /*alpha=*/1.f,
                        /*beta=*/float(zp_cfg.common_dst_zero_point));
                post_ops_.emplace_back(c, c, func);
            }
        }

        need_to_restore_zero_padding_ = init_need_to_restore_zero_padding();

        // Require masked updates when needed.
        for (auto &info : tensor_infos_) {
            if (!info.is_output()) continue;

            if (need_to_restore_zero_padding_) {
                info.require_masked_update();
                continue;
            }

            for (int i = 0; i < cp_ndims(); i++) {
                if ((info.mask() & (1 << i)) != 0) continue;
                if (is_spurious_spatial(gemm_schedule, i)) {
                    info.require_masked_update();
                    break;
                }
            }
        }
    }

    const view_t &cp_view() const { return cp_view_; }

    const std::vector<post_op_t> &post_ops() const { return post_ops_; }

    const std::vector<post_op_tensor_info_t> &post_op_tensor_infos() const {
        return tensor_infos_;
    }

    bool need_to_restore_zero_padding() const {
        return need_to_restore_zero_padding_;
    }

private:
    bool init_need_to_restore_zero_padding() const {
        auto *pd = prb_->conv_pd;
        auto *attr = prb_->attr;
        if (prb_->with_bias) return true;
        for (int i = 0; i < attr->post_ops_.len(); i++) {
            auto &po = attr->post_ops_.entry_[i];
            if (po.is_eltwise()) {
                if (!eltwise_fwd_pd_t::eltwise_preserves_zero(po.eltwise))
                    return true;
            } else if (po.is_sum(/*require_scale_one=*/false,
                               /*require_zp_zero=*/false)) {
                if (po.sum.zero_point != 0) return true;
                for (int j = 0; j < cp_ndims(); j++) {
                    if (!is_cp_dim_zero_padded(j)) continue;
                    // Size-one dimensions are treated as broadcast, which
                    // does not preserve zero padding with block updates.
                    if (cp_view_.vdims()[j] == 1) return true;
                }
            } else if (po.is_binary()) {
                for (int j = 0; j < cp_ndims(); j++) {
                    if (!is_cp_dim_zero_padded(j)) continue;
                    // Check if binary preserves zeros:
                    // (0 op X == 0) or (0 op 0 == 0).
                    bool zero_op_x_ok = (po.binary.alg == alg_kind::binary_mul);
                    bool zero_op_zero_ok = zero_op_x_ok
                            || utils::one_of(po.binary.alg,
                                    alg_kind::binary_add, alg_kind::binary_sub,
                                    alg_kind::binary_min, alg_kind::binary_max,
                                    alg_kind::binary_gt, alg_kind::binary_lt,
                                    alg_kind::binary_ne);

                    uint32_t rhs_mask
                            = utils::get_dims_mask(cp_view_.vdims().data(),
                                    po.binary.src1_desc.dims, cp_ndims());
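                    // E.g. binary_mul keeps padded zeros even when the rhs
                    // broadcasts over a zero-padded dim (0 * X == 0), while
                    // binary_add keeps them only when the rhs is not
                    // broadcast there (0 + 0 == 0).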
                    if ((rhs_mask & (1 << j)) == 0 && !zero_op_x_ok)
                        return true;
                    if (!zero_op_zero_ok) return true;
                }
            } else if (po.is_prelu()) {
                // PReLU maps zero to zero (0 * slope == 0), so it preserves
                // zero padding; keep checking the remaining post-ops.
            } else {
                ir_error_not_expected();
            }
        }
        if (prb_->zp_cfg.do_src_compensation
                && pd->dst_md()->dims[0] != pd->dst_md()->padded_dims[0])
            return true;
        if (prb_->zp_cfg.do_dst_compensation
                && prb_->zp_cfg.is_common_dst_zero_point
                && pd->dst_md()->dims[1] != pd->dst_md()->padded_dims[1])
            return true;
        return false;
    }

    // Checks whether the convolution computes output elements that are out of
    // bounds in the output tensor. This can happen due to spatial padding.
    //
    // For example, for forward convolution OW is padded to OW_PADDED. Then if
    // ow >= OW (out of bounds) and iw = ow * SW - PW + kw * (DW + 1) < IW (in
    // bounds), the convolution computes an out-of-bounds element which is not
    // generally zero. This requires special handling when post-ops follow the
    // convolution.
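    //
    // Illustrative example (hypothetical sizes): OW = 5 padded to
    // OW_PADDED = 8 with SW = 1, PW = 0, IW = 7. For ow = 5 the smallest
    // accessed input index is iw = 5 * 1 - 0 = 5 < 7, so the kernel computes
    // a value for an output element outside the logical OW range.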
    bool is_spurious_spatial(
            const gemm_schedule_t &gemm_schedule, int dim_idx) const {
        auto &var = cp_view_.vvars()[dim_idx].as<var_t>();

        int sp_idx = -1;
        if (utils::one_of(var.name, "od", "id")) {
            sp_idx = 0;
        } else if (utils::one_of(var.name, "oh", "ih")) {
            sp_idx = 1;
        } else if (utils::one_of(var.name, "ow", "iw")) {
            sp_idx = 2;
        } else {
            return false;
        }

        int p = utils::pick(sp_idx, prb_->pd, prb_->ph, prb_->pw);
        int s = utils::pick(sp_idx, prb_->sd, prb_->sh, prb_->sw);
        int k = utils::pick(sp_idx, prb_->kd, prb_->kh, prb_->kw);
        int d = utils::pick(sp_idx, prb_->dd, prb_->dh, prb_->dw);

        if (prb_->is_fwd) {
            int o_value = utils::pick(sp_idx, prb_->od, prb_->oh, prb_->ow);
            int o_bound = gemm_schedule.var_bound(var);
            int i = utils::pick(sp_idx, prb_->id, prb_->ih, prb_->iw);

            for (int o = o_value; o < o_bound; o++) {
                int i_min = o * s - p;
                if (i_min < i) return true;
            }
            return false;
        }

        if (prb_->is_bwd_d) {
            int i_value = utils::pick(sp_idx, prb_->id, prb_->ih, prb_->iw);
            int i_bound = gemm_schedule.var_bound(var);
            int o = utils::pick(sp_idx, prb_->od, prb_->oh, prb_->ow);

            for (int i = i_value; i < i_bound; i++) {
                int os_min = i - (k - 1) * (d + 1) + p;
                if (os_min < o * s) return true;
            }
            return false;
        }

        return false;
    }

    int cp_ndims() const { return cp_view_.nvdims(); }

    dim_t cp_dim(int idx) const { return cp_view_.vdims()[idx]; }

    dim_t cp_padded_dim(int idx) const { return cp_view_.tlayout().dim(idx); }

    bool has_cp_mask(int idx) const { return cp_view_.has_tmask(idx); }

    bool is_cp_dim_zero_padded(int idx) const {
        return cp_view_.is_masked_vdim(idx);
    }

    const expr_t &add_input_tensor(const view_t &view, const expr_t &buf) {
        return add_tensor(/*is_input=*/true, /*is_output=*/false, view, buf);
    }

    const expr_t &add_output_tensor(
            const view_t &view, const expr_t &buf, float scale = 1.0f) {
        return add_tensor(/*is_input=*/false, /*is_output=*/true, view, buf,
                expr_t(), scale);
    }

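    // Registers a tensor participating in the epilogue and returns the IR
    // variable used to reference it in post-op expressions. Tensors without
    // a buffer (e.g. the virtual output "c") get a full dimension mask.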
    const expr_t &add_tensor(bool is_input, bool is_output, const view_t &view,
            const expr_t &buf, const expr_t &op_var = expr_t(),
            float scale = 1.0f) {
        ir_assert(view.nvdims() == cp_view_.nvdims());
        uint32_t mask
                = (buf.is_empty() ? ~(1u << cp_ndims()) : compute_mask(view));
        tensor_infos_.emplace_back(
                is_input, is_output, view, buf, mask, op_var, scale);
        return tensor_infos_.back().op_var();
    }

    uint32_t compute_mask(const view_t &view) const {
        ir_assert(cp_view_.nvdims() == view.nvdims());
        uint32_t mask = 0;
        for (int i = 0; i < view.nvdims(); i++) {
            if (view.vdims()[i] != 1) mask |= (1 << i);
        }
        return mask;
    }

    // rhs tensor has plain layout.
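    // Dims whose mask bit is not set are collapsed to size 1 (broadcast);
    // non-broadcast dims that are zero-padded in the C view get bound checks.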
    view_t create_view(const type_t &type, uint32_t rhs_mask) const {
        std::vector<dim_t> rhs_dims = cp_view_.vdims();
        uint32_t bound_check_mask = 0;
        for (int i = 0; i < cp_ndims(); i++) {
            if ((rhs_mask & (1 << i)) == 0) {
                // Broadcast dimension.
                rhs_dims[i] = 1;
            } else if (is_cp_dim_zero_padded(i)) {
                bound_check_mask |= (1 << i);
            }
        }
        return view_t(layout_t(type, 0, rhs_dims, /*do_normalize=*/false),
                cp_view_.vvars(), bound_check_mask);
    }

    // rhs tensor layout is defined by md memory descriptor.
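    // The md dims are reshaped to the problem ndims and then normalized the
    // same way as the convolution shapes so that the resulting view matches
    // the C view dimension by dimension.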
    view_t create_view(const memory_desc_t &md) const {
        ir_assert(cp_ndims() >= 3);
        // Add groups to match ngcdhw layout.
        bool add_groups = (cp_view_.vvars()[1].as<var_t>().name == "g");
        layout_t layout(md, /*do_normalize=*/false);
        std::vector<dim_t> dims(md.dims, md.dims + md.ndims);
        std::vector<dim_t> padded_dims(
                md.padded_dims, md.padded_dims + md.ndims);
        maybe_reshape_dims(prb_->ndims, layout, dims, padded_dims);
        layout = normalize_conv_layout(layout, /*with_groups=*/false, prb_->g,
                prb_->is_dw, prb_->reduced_dim, cfg_->fuse_spatial(),
                add_groups,
                /*is_wei=*/false);
        dims = normalize_conv_dims(dims, /*with_groups=*/false, prb_->g,
                prb_->is_dw, prb_->reduced_dim, cfg_->fuse_spatial(),
                add_groups,
                /*is_wei=*/false);
        padded_dims = normalize_conv_dims(padded_dims,
                /*with_groups=*/false, prb_->g, prb_->is_dw, prb_->reduced_dim,
                cfg_->fuse_spatial(), add_groups, /*is_wei=*/false);
        ir_assert(layout.ndims() == cp_ndims()) << "Incompatible dimensions.";
        uint32_t bound_check_mask = 0;
        for (int i = 0; i < cp_ndims(); i++) {
            if (dims[i] == 1) continue; // Broadcast, no bound check needed.
            if (padded_dims[i] != cp_padded_dim(i)) {
                bound_check_mask |= (1 << i);
            } else if (has_cp_mask(i)) {
                bound_check_mask |= (1 << i);
            }
        }
        return view_t(layout, cp_view_.vvars(), dims, bound_check_mask);
    }

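    // Extends the layout and the dim/padded-dim vectors with trailing unit
    // dimensions up to ndims.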
    static void maybe_reshape_dims(int ndims, layout_t &layout,
            std::vector<dim_t> &dims, std::vector<dim_t> &padded_dims) {
        ir_assert(layout.ndims() == int(dims.size()));
        if (layout.ndims() < ndims) {
            layout = layout_t(layout.type(), ndims, layout.offset(),
                    layout.blocks(), /*do_normalize=*/false);
            dims.resize(ndims, 1);
            padded_dims.resize(ndims, 1);
        }
    }

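    // Converts a mask defined over the original (pre-normalization) dims to
    // the normalized C view: the mask is encoded as dummy dims (2 for set
    // bits, 1 otherwise), the dims are normalized like the convolution
    // shapes, and the mask is read back from the result.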
    uint32_t normalize_mask(uint32_t orig_mask) const {
        ir_assert(cp_ndims() >= 3);
        // Add groups to match ngcdhw layout.
        bool add_groups = (cp_view_.vvars()[1].as<var_t>().name == "g");
        // Number of dimensions before normalization.
        int orig_ndims = 2 + prb_->ndims;
        std::vector<dim_t> dummy_dims(orig_ndims, 1);
        dim_t mask_set_value = 2;
        for (int i = 0; i < orig_ndims; i++) {
            if ((orig_mask & (1 << i)) != 0) dummy_dims[i] = mask_set_value;
        }
        auto cvt_dims = normalize_conv_dims(dummy_dims, /*with_groups=*/false,
                prb_->g, prb_->is_dw, prb_->reduced_dim, cfg_->fuse_spatial(),
                /*add_groups=*/false,
                /*is_wei=*/false);
        // Split channels into groups and channels to match ngcdhw layout.
        if (add_groups) cvt_dims.insert(cvt_dims.begin() + 1, cvt_dims[1]);
        ir_assert(int(cvt_dims.size()) == cp_ndims());

        uint32_t mask = 0;
        for (int i = 0; i < cp_ndims(); i++) {
            if (cvt_dims[i] == mask_set_value) mask = mask | (1 << i);
        }
        return mask;
    }

    const conv_problem_t *prb_;
    const conv_config_t *cfg_;

    bool need_to_restore_zero_padding_ = false;

    view_t cp_view_;
    std::vector<post_op_t> post_ops_;
    std::vector<post_op_tensor_info_t> tensor_infos_;
};

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif