/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef GPU_JIT_CONV_POST_OPS_HPP
#define GPU_JIT_CONV_POST_OPS_HPP

#include <string>
#include <vector>

#include "common/convolution_pd.hpp"
#include "common/eltwise_pd.hpp"
#include "gpu/jit/ir/eltwise.hpp"
#include "gpu/jit/ir/gemm_schedule.hpp"
#include "gpu/jit/ir/ir.hpp"
#include "gpu/jit/ir/kernel_info.hpp"
#include "gpu/jit/ir/post_ops.hpp"
#include "gpu/jit/ir/tensor.hpp"
#include "gpu/jit/utils/utils.hpp"

#include "gpu/jit/conv/config.hpp"
#include "gpu/jit/conv/normalization.hpp"

namespace dnnl {
namespace impl {
namespace gpu {
namespace jit {

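// Post-op context for convolution kernels. It collects everything that has to
// be applied to the computed output: source/weights/destination scales, bias,
// attribute post-ops (eltwise, sum, prelu, binary) and destination
// zero-points. For every participating tensor it builds a view over the
// kernel's output (C) view and decides whether masked updates are required to
// keep zero padding of the destination intact.
//
// Typical usage (sketch):
//     post_op_context_t post_op_ctx(cfg, gemm_schedule, kernel_info);
//     // Query post_op_ctx.post_ops() and post_op_ctx.post_op_tensor_infos()
//     // when generating the kernel epilogue.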
class post_op_context_t {
public:
    post_op_context_t() = default;

    post_op_context_t(const conv_config_t &cfg,
            const gemm_schedule_t &gemm_schedule,
            const kernel_info_t &kernel_info)
        : prb_(&cfg.prb()), cfg_(&cfg), cp_view_(gemm_schedule.c_view()) {

        auto *pd = prb_->conv_pd;
        auto *attr = prb_->attr;

        auto c = add_tensor(/*is_input=*/false, /*is_output=*/false, cp_view_,
                expr_t(), var_t::make(type_t::f32(), "c"));

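        // Effective scaling: c is multiplied by src_scale * wei_scale before
        // the attribute post-ops and divided by dst_scale after them. Scales
        // that are not passed as kernel arguments default to 1 and the
        // corresponding post-ops are skipped.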
        // Prepare src/weights/dst scale expressions.
        std::vector<expr_t> scales(3, expr_t(1.0f));
        auto &src_scales = scales[0];
        auto &wei_scales = scales[1];
        auto &dst_scales = scales[2];
        if ((prb_->is_fwd || prb_->is_bwd_d)
                && !attr->scales_.has_default_values()) {
            const char *names[] = {"src_scales", "wei_scales", "dst_scales"};
            expr_t c_scaled = c;
            for (int i = 0; i < 3; i++) {
                auto buf = kernel_info.find_arg(names[i], /*allow_empty=*/true);
                if (buf.is_empty()) continue;
                int key = kernel_info.key(names[i]) & ~DNNL_ARG_ATTR_SCALES;
                int mask = attr->scales_.get(key).mask_;
                if (i == 1) {
                    // Convert o/i weights mask to src/dst.
                    // XXX: per_oc for BWD_D is treated as per_ic assuming
                    // it's called from deconvolution.
                    int c_idx = prb_->with_groups;
                    ir_assert(utils::one_of(mask, 0, 1 << c_idx));
                    if (mask != 0) mask = (1 << 1);
                } else {
                    ir_assert(mask == 0);
                }
                auto view = create_view(type_t::f32(), normalize_mask(mask));
                scales[i] = add_input_tensor(view, buf);
            }
        }

        // Handle input and weights scales.
        if (!is_one(src_scales) || !is_one(wei_scales)) {
            auto c_scaled = c * src_scales * wei_scales;
            post_ops_.emplace_back(c, c_scaled);
        }

        // Handle bias.
        if ((pd->is_fwd() || pd->is_bwd_d()) && pd->with_bias()) {
            uint32_t mask = normalize_mask(1 << 1); // Per-channel mask.
            auto view = create_view(pd->invariant_bia_md()->data_type, mask);
            auto buf = kernel_info.find_arg("bia");
            auto bia = add_input_tensor(view, buf);
            post_ops_.emplace_back(c, c + bia);
        }

        // Handle post-ops.
        for (int i = 0; i < attr->post_ops_.len(); i++) {
            auto &po = attr->post_ops_.entry_[i];
            if (po.is_eltwise()) {
                auto func = eltwise_t::make(po.eltwise.alg, po.eltwise.scale,
                        po.eltwise.alpha, po.eltwise.beta);
                post_ops_.emplace_back(c, c, func);
            } else if (po.is_sum(/*require_scale_one=*/false,
                               /*require_zp_zero=*/false)) {
                float scale = po.sum.scale;
                int32_t zp = po.sum.zero_point;
                if (pd->is_bwd_w()) {
                    ir_assert(scale == 1) << "BWD_W doesn't support "
                                             "non-default scale for sum.";
                    continue;
                }
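                // Sum post-op: c += scale * (c_old - zero_point), where c_old
                // is read back from the destination buffer ("dst" for forward,
                // "src" for backward-by-data).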
                auto view = cp_view_;
                if (po.sum.dt != data_type::undef)
                    view = view.retype(po.sum.dt);
                auto buf = kernel_info.find_arg(pd->is_fwd() ? "dst" : "src");
                auto c_old = add_input_tensor(view, buf);
                post_ops_.emplace_back(c, c + scale * (c_old - zp));
            } else if (po.is_prelu()) {
                uint32_t rhs_mask = normalize_mask(po.prelu.mask);
                auto rhs_view = create_view(type_t::f32(), rhs_mask);
                auto buf_name = "prelu_rhs_" + std::to_string(i);
                auto rhs_buf = kernel_info.find_arg(buf_name);
                auto rhs = add_input_tensor(rhs_view, rhs_buf);
                post_ops_.emplace_back(
                        c, binary_op_t::make(op_kind_t::_prelu, c, rhs));
            } else if (po.is_binary()) {
                auto buf_name = "binary_rhs_" + std::to_string(i);
                auto view = create_view(po.binary.src1_desc);
                auto buf = kernel_info.find_arg(buf_name);
                auto rhs = add_input_tensor(view, buf);
                auto op_kind = alg_kind_to_op_kind(po.binary.alg);
                post_ops_.emplace_back(c, binary_op_t::make(op_kind, c, rhs));
            } else {
                ir_error_not_expected();
            }
        }

        // Handle dst scale.
        if (!is_one(dst_scales)) {
            auto c_scaled = c / dst_scales;
            post_ops_.emplace_back(c, c_scaled);
        }

        // Handle dst zero points.
        auto &zp_cfg = prb_->zp_cfg;
        if (zp_cfg.do_dst_compensation) {
            if (zp_cfg.is_runtime_dst_zero_points) {
                uint32_t mask = normalize_mask(
                        zp_cfg.is_common_dst_zero_point ? 0 : 1 << 1);
                auto view = create_view(type_t::s32(), mask);
                auto buf = kernel_info.find_arg("dst_zero_points");
                auto in = add_input_tensor(view, buf);
                post_ops_.emplace_back(c, c + in);
            } else {
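                // Compile-time common zero point: add it via a linear eltwise
                // (alpha * c + beta) with alpha = 1 and beta = zero point.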
                auto func = eltwise_t::make(alg_kind::eltwise_linear,
                        /*scale=*/1.f,
                        /*alpha=*/1.f,
                        /*beta=*/float(zp_cfg.common_dst_zero_point));
                post_ops_.emplace_back(c, c, func);
            }
        }

        need_to_restore_zero_padding_ = init_need_to_restore_zero_padding();

        // Require masked updates when needed.
        for (auto &info : tensor_infos_) {
            if (!info.is_output()) continue;

            if (need_to_restore_zero_padding_) {
                info.require_masked_update();
                continue;
            }

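            // If an output tensor broadcasts over a dimension that contains
            // spurious (out-of-bounds) spatial elements, those extra
            // iterations would overwrite valid elements, so masked updates
            // are required.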
            for (int i = 0; i < cp_ndims(); i++) {
                if ((info.mask() & (1 << i)) != 0) continue;
                if (is_spurious_spatial(gemm_schedule, i)) {
                    info.require_masked_update();
                    break;
                }
            }
        }
    }

    const view_t &cp_view() const { return cp_view_; }

    const std::vector<post_op_t> &post_ops() const { return post_ops_; }

    const std::vector<post_op_tensor_info_t> &post_op_tensor_infos() const {
        return tensor_infos_;
    }

    bool need_to_restore_zero_padding() const {
        return need_to_restore_zero_padding_;
    }

private:
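    // Returns true when some post-op (or zero-point compensation) may write
    // non-zero values into the zero-padded area of the destination, in which
    // case the padding has to be restored after post-ops are applied.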
    bool init_need_to_restore_zero_padding() const {
        auto *pd = prb_->conv_pd;
        auto *attr = prb_->attr;
        if (prb_->with_bias) return true;
        for (int i = 0; i < attr->post_ops_.len(); i++) {
            auto &po = attr->post_ops_.entry_[i];
            if (po.is_eltwise()) {
                if (!eltwise_fwd_pd_t::eltwise_preserves_zero(po.eltwise))
                    return true;
            } else if (po.is_sum(/*require_scale_one=*/false,
                               /*require_zp_zero=*/false)) {
                if (po.sum.zero_point != 0) return true;
                for (int j = 0; j < cp_ndims(); j++) {
                    if (!is_cp_dim_zero_padded(j)) continue;
                    // Size-one dimensions are treated as broadcast, which
                    // does not preserve zero padding with block updates.
                    if (cp_view_.vdims()[j] == 1) return true;
                }
            } else if (po.is_binary()) {
                for (int j = 0; j < cp_ndims(); j++) {
                    if (!is_cp_dim_zero_padded(j)) continue;
                    // Check if binary preserves zeros: (0 op X == 0) or
                    // (0 op 0 == 0).
                    bool zero_op_x_ok = (po.binary.alg == alg_kind::binary_mul);
                    bool zero_op_zero_ok = zero_op_x_ok
                            || utils::one_of(po.binary.alg,
                                    alg_kind::binary_add, alg_kind::binary_sub,
                                    alg_kind::binary_min, alg_kind::binary_max,
                                    alg_kind::binary_gt, alg_kind::binary_lt,
                                    alg_kind::binary_ne);

                    uint32_t rhs_mask
                            = utils::get_dims_mask(cp_view_.vdims().data(),
                                    po.binary.src1_desc.dims, cp_ndims());
                    if ((rhs_mask & (1 << j)) == 0 && !zero_op_x_ok)
                        return true;
                    if (!zero_op_zero_ok) return true;
                }
            } else if (po.is_prelu()) {
                // PReLU maps zero to zero, so it preserves zero padding; keep
                // checking the remaining post-ops.
            } else {
                ir_error_not_expected();
            }
        }
        if (prb_->zp_cfg.do_src_compensation
                && pd->dst_md()->dims[0] != pd->dst_md()->padded_dims[0])
            return true;
        if (prb_->zp_cfg.do_dst_compensation
                && prb_->zp_cfg.is_common_dst_zero_point
                && pd->dst_md()->dims[1] != pd->dst_md()->padded_dims[1])
            return true;
        return false;
    }

    // Checks if the convolution computes output elements that are out of
    // bounds in the output tensor. This can happen due to spatial padding.
    //
    // For example, for forward convolution OW is padded to OW_PADDED. Then,
    // if ow >= OW (out of bounds) and iw = ow * SW - PW + kw * (DW + 1) < IW
    // (in bounds), the convolution computes an out-of-bounds element which is
    // not generally zero. This requires special handling if there are
    // post-ops following the convolution.
    bool is_spurious_spatial(
            const gemm_schedule_t &gemm_schedule, int dim_idx) const {
        auto &var = cp_view_.vvars()[dim_idx].as<var_t>();

        int sp_idx = -1;
        if (utils::one_of(var.name, "od", "id")) {
            sp_idx = 0;
        } else if (utils::one_of(var.name, "oh", "ih")) {
            sp_idx = 1;
        } else if (utils::one_of(var.name, "ow", "iw")) {
            sp_idx = 2;
        } else {
            return false;
        }

        int p = utils::pick(sp_idx, prb_->pd, prb_->ph, prb_->pw);
        int s = utils::pick(sp_idx, prb_->sd, prb_->sh, prb_->sw);
        int k = utils::pick(sp_idx, prb_->kd, prb_->kh, prb_->kw);
        int d = utils::pick(sp_idx, prb_->dd, prb_->dh, prb_->dw);

        if (prb_->is_fwd) {
            int o_value = utils::pick(sp_idx, prb_->od, prb_->oh, prb_->ow);
            int o_bound = gemm_schedule.var_bound(var);
            int i = utils::pick(sp_idx, prb_->id, prb_->ih, prb_->iw);

            for (int o = o_value; o < o_bound; o++) {
                int i_min = o * s - p;
                if (i_min < i) return true;
            }
            return false;
        }

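        // For backward by data the roles are swapped: padded input positions
        // (id/ih/iw beyond ID/IH/IW) are spurious if at least one in-bounds
        // output element can contribute to them through the filter.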
        if (prb_->is_bwd_d) {
            int i_value = utils::pick(sp_idx, prb_->id, prb_->ih, prb_->iw);
            int i_bound = gemm_schedule.var_bound(var);
            int o = utils::pick(sp_idx, prb_->od, prb_->oh, prb_->ow);

            for (int i = i_value; i < i_bound; i++) {
                int os_min = i - (k - 1) * (d + 1) + p;
                if (os_min < o * s) return true;
            }
            return false;
        }

        return false;
    }

    int cp_ndims() const { return cp_view_.nvdims(); }

    dim_t cp_dim(int idx) const { return cp_view_.vdims()[idx]; }

    dim_t cp_padded_dim(int idx) const { return cp_view_.tlayout().dim(idx); }

    bool has_cp_mask(int idx) const { return cp_view_.has_tmask(idx); }

    bool is_cp_dim_zero_padded(int idx) const {
        return cp_view_.is_masked_vdim(idx);
    }

    const expr_t &add_input_tensor(const view_t &view, const expr_t &buf) {
        return add_tensor(/*is_input=*/true, /*is_output=*/false, view, buf);
    }

    const expr_t &add_output_tensor(
            const view_t &view, const expr_t &buf, float scale = 1.0f) {
        return add_tensor(/*is_input=*/false, /*is_output=*/true, view, buf,
                expr_t(), scale);
    }

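    // Registers a tensor participating in post-op evaluation and returns the
    // IR variable used to refer to it in post-op expressions. An empty buffer
    // denotes the in-register "c" value; it gets a mask with all dimension
    // bits set.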
    const expr_t &add_tensor(bool is_input, bool is_output, const view_t &view,
            const expr_t &buf, const expr_t &op_var = expr_t(),
            float scale = 1.0f) {
        ir_assert(view.nvdims() == cp_view_.nvdims());
        uint32_t mask
                = (buf.is_empty() ? ~(1u << cp_ndims()) : compute_mask(view));
        tensor_infos_.emplace_back(
                is_input, is_output, view, buf, mask, op_var, scale);
        return tensor_infos_.back().op_var();
    }

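    // Returns a bitmask with bit i set iff the i-th view dimension is not
    // broadcast (i.e. its size is not 1).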
    uint32_t compute_mask(const view_t &view) const {
        ir_assert(cp_view_.nvdims() == view.nvdims());
        uint32_t mask = 0;
        for (int i = 0; i < view.nvdims(); i++) {
            if (view.vdims()[i] != 1) mask |= (1 << i);
        }
        return mask;
    }

    // rhs tensor has plain layout.
    view_t create_view(const type_t &type, uint32_t rhs_mask) const {
        std::vector<dim_t> rhs_dims = cp_view_.vdims();
        uint32_t bound_check_mask = 0;
        for (int i = 0; i < cp_ndims(); i++) {
            if ((rhs_mask & (1 << i)) == 0) {
                // Broadcast dimension.
                rhs_dims[i] = 1;
            } else if (is_cp_dim_zero_padded(i)) {
                bound_check_mask |= (1 << i);
            }
        }
        return view_t(layout_t(type, 0, rhs_dims, /*do_normalize=*/false),
                cp_view_.vvars(), bound_check_mask);
    }

    // rhs tensor layout is defined by md memory descriptor.
    view_t create_view(const memory_desc_t &md) const {
        ir_assert(cp_ndims() >= 3);
        // Add groups to match ngcdhw layout.
        bool add_groups = (cp_view_.vvars()[1].as<var_t>().name == "g");
        layout_t layout(md, /*do_normalize=*/false);
        std::vector<dim_t> dims(md.dims, md.dims + md.ndims);
        std::vector<dim_t> padded_dims(
                md.padded_dims, md.padded_dims + md.ndims);
        maybe_reshape_dims(prb_->ndims, layout, dims, padded_dims);
        layout = normalize_conv_layout(layout, /*with_groups=*/false, prb_->g,
                prb_->is_dw, prb_->reduced_dim, cfg_->fuse_spatial(),
                add_groups,
                /*is_wei=*/false);
        dims = normalize_conv_dims(dims, /*with_groups=*/false, prb_->g,
                prb_->is_dw, prb_->reduced_dim, cfg_->fuse_spatial(),
                add_groups,
                /*is_wei=*/false);
        padded_dims = normalize_conv_dims(padded_dims,
                /*with_groups=*/false, prb_->g, prb_->is_dw, prb_->reduced_dim,
                cfg_->fuse_spatial(), add_groups, /*is_wei=*/false);
        ir_assert(layout.ndims() == cp_ndims()) << "Incompatible dimensions.";
        uint32_t bound_check_mask = 0;
        for (int i = 0; i < cp_ndims(); i++) {
            if (dims[i] == 1) continue; // Broadcast, no bound check needed.
            if (padded_dims[i] != cp_padded_dim(i)) {
                bound_check_mask |= (1 << i);
            } else if (has_cp_mask(i)) {
                bound_check_mask |= (1 << i);
            }
        }
        return view_t(layout, cp_view_.vvars(), dims, bound_check_mask);
    }

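    // Pads layout/dims/padded_dims with trailing unit dimensions up to ndims.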
    static void maybe_reshape_dims(int ndims, layout_t &layout,
            std::vector<dim_t> &dims, std::vector<dim_t> &padded_dims) {
        ir_assert(layout.ndims() == int(dims.size()));
        if (layout.ndims() < ndims) {
            layout = layout_t(layout.type(), ndims, layout.offset(),
                    layout.blocks(), /*do_normalize=*/false);
            dims.resize(ndims, 1);
            padded_dims.resize(ndims, 1);
        }
    }

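    // Converts a mask defined over the original (pre-normalization) problem
    // dimensions into a mask over the normalized cp view. This is done by
    // pushing marker dimensions through the same normalization that is
    // applied to the output layout; with grouped layouts the channel marker
    // is duplicated into the groups dimension.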
    uint32_t normalize_mask(uint32_t orig_mask) const {
        ir_assert(cp_ndims() >= 3);
        // Add groups to match ngcdhw layout.
        bool add_groups = (cp_view_.vvars()[1].as<var_t>().name == "g");
        // Number of dimensions before normalization.
        int orig_ndims = 2 + prb_->ndims;
        std::vector<dim_t> dummy_dims(orig_ndims, 1);
        dim_t mask_set_value = 2;
        for (int i = 0; i < orig_ndims; i++) {
            if ((orig_mask & (1 << i)) != 0) dummy_dims[i] = mask_set_value;
        }
        auto cvt_dims = normalize_conv_dims(dummy_dims, /*with_groups=*/false,
                prb_->g, prb_->is_dw, prb_->reduced_dim, cfg_->fuse_spatial(),
                /*add_groups=*/false,
                /*is_wei=*/false);
        // Split channels into groups and channels to match ngcdhw layout.
        if (add_groups) cvt_dims.insert(cvt_dims.begin() + 1, cvt_dims[1]);
        ir_assert(int(cvt_dims.size()) == cp_ndims());

        uint32_t mask = 0;
        for (int i = 0; i < cp_ndims(); i++) {
            if (cvt_dims[i] == mask_set_value) mask = mask | (1 << i);
        }
        return mask;
    }

    const conv_problem_t *prb_;
    const conv_config_t *cfg_;

    bool need_to_restore_zero_padding_ = false;

    view_t cp_view_;
    std::vector<post_op_t> post_ops_;
    std::vector<post_op_tensor_info_t> tensor_infos_;
};

} // namespace jit
} // namespace gpu
} // namespace impl
} // namespace dnnl

#endif