1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/conv/ir_builder.hpp"
18
19namespace dnnl {
20namespace impl {
21namespace gpu {
22namespace jit {
23
24namespace {
25
// Returns true when out-of-bounds masking is required for the spatial
// access: forward needs it when the kernel footprint can step outside the
// input; backward needs it when the recomputed output offset can fall
// outside the strided output range.
bool need_src_or_dst_check(
        bool is_fwd, int o, int i, int k, int p, int s, int d) {
    if (is_fwd) {
        // Input index range touched by the kernel window.
        int first = -p;
        int last = (o - 1) * s - p + (k - 1) * (1 + d);
        return first < 0 || last >= i;
    }
    // Backward: range of the unstrided output offset (os = i + p - k_off).
    int lo = p - (k - 1) * (1 + d);
    int hi = i - 1 + p;
    return lo < 0 || hi >= o * s;
}
38
39} // namespace
40
41// Represents hierarchy of tile levels and corresponding loop/grid indices.
42//
43// | Tile level | Nesting level | Maps to |
44// |------------|---------------|------------------------|
45// | grid_dim | 0 | Thread group |
46// | loop_dim | 1 | Loop in thread |
47// | tg_dim | 2 | Thread in thread group |
48// | iter_dim | 3 | Iteration in loop |
49class dim_tile_t {
50public:
51 const expr_t &grid_idx() const { return not_empty(grid_idx_); }
52 const expr_t &tg_idx() const { return not_empty(tg_idx_); }
53 const expr_t &loop_idx() const { return not_empty(loop_idx_); }
54 const expr_t &iter_idx() const { return not_empty(iter_idx_); }
55
56 void set_grid_idx(const expr_t &idx) { grid_idx_ = idx; }
57 void set_tg_idx(const expr_t &idx) { tg_idx_ = idx; }
58 void set_loop_idx(const expr_t &idx) { loop_idx_ = idx; }
59 void set_iter_idx(const expr_t &idx) { iter_idx_ = idx; }
60
61private:
62 static const expr_t &not_empty(const expr_t &v) {
63 ir_assert(!v.is_empty()) << "Queried empty index.";
64 return v;
65 }
66
67 expr_t grid_idx_;
68 expr_t tg_idx_;
69 expr_t loop_idx_;
70 expr_t iter_idx_;
71};
72
73dim_tile_t create_tile(gemm_schedule_t &gemm_schedule, const conv_config_t &cfg,
74 const expr_t &dim) {
75 dim_tile_t tile;
76 auto &name = dim.as<var_t>().name;
77 int loop_dim = cfg.loop_dim(name);
78 int tg_dim = cfg.thread_group_dim(name);
79 int iter_dim = cfg.iter_dim(name);
80
81 std::vector<int> dims = {1, loop_dim, tg_dim, iter_dim};
82 int ndims = (int)dims.size();
83 std::vector<expr_t> idxs(ndims);
84
85 static const char *suffixes[]
86 = {"_grid_idx", "_loop_idx", "_tg_idx", "_iter_idx"};
87 auto &dim_name = dim.as<var_t>().name;
88
89 auto has_block = [&](int dim_idx) {
90 bool is_thr = (dim_idx == 1);
91 bool is_tg = (dim_idx == 2);
92 bool is_iter = (dim_idx == 3);
93 if (is_thr || is_iter) return true;
94 for (int i = 0; i < 3; i++) {
95 auto **dd = is_tg ? get_thread_group_grid_conv_dims(cfg.prb(), i)
96 : get_kernel_grid_conv_dims(cfg.prb(), i);
97 for (auto **d = dd; *d; d++)
98 if (dim_name == *d) return true;
99 }
100 return false;
101 };
102
103 expr_t idx = dim;
104 for (int i = ndims - 1; i >= 1; i--) {
105 expr_t outer;
106 expr_t inner;
107 auto outer_name = (i == 1) ? dim_name + suffixes[i] : std::string();
108 auto inner_name = dim_name + suffixes[i];
109 gemm_schedule.split(idx, dims[i], outer, inner, outer_name, inner_name);
110 if (has_block(i)) idxs[i] = inner;
111 idx = outer;
112 }
113 idxs[0] = idx;
114
115 tile.set_grid_idx(idxs[0]);
116 tile.set_loop_idx(idxs[1]);
117 tile.set_tg_idx(idxs[2]);
118 tile.set_iter_idx(idxs[3]);
119
120 return tile;
121}
122
// Initializes tensor views and the GEMM schedule for forward convolution.
// GEMM mapping (as set on gemm_schedule below):
//   A = src, B = wei, C = dst; batch vars: {g};
//   M = (mb, od/oh/ow or osp), N = (oc), K = (ic, kd, kh, kw).
// Also resolves the src/wei/dst buffer expressions from kernel arguments.
void conv_ir_builder_t::init_fwd(gemm_schedule_t &gemm_schedule,
        view_t &src_view, view_t &wei_view, view_t &dst_view, expr_t &src_buf,
        expr_t &wei_buf, expr_t &dst_buf) {
    auto &src_layout = cfg_.src_layout().compute();
    auto &wei_layout = cfg_.wei_layout().compute();
    auto &dst_layout = cfg_.dst_layout().compute();

    // Initialize views.
    auto mb = var_t::make(type_t::s32(), "mb");
    auto ic = var_t::make(type_t::s32(), "ic");
    auto oc = var_t::make(type_t::s32(), "oc");
    auto kd = var_t::make(type_t::s32(), "kd");
    auto kh = var_t::make(type_t::s32(), "kh");
    auto kw = var_t::make(type_t::s32(), "kw");
    auto g = var_t::make(type_t::s32(), "g");

    expr_t ow, oh, od, osp;
    bool check_od = false;
    bool check_oh = false;
    bool check_ow = false;
    if (cfg_.fuse_spatial()) {
        // Single fused output-spatial variable; od/oh/ow are decomposed
        // from it as osp = (od * OH + oh) * OW + ow.
        osp = var_t::make(type_t::s32(), "osp");
        ow = osp;
        oh = osp / prb_.ow;
        od = osp / (prb_.oh * prb_.ow);

        bool is_1d = (prb_.oh == 1 && prb_.od == 1);
        bool is_2d = (prb_.oh != 1 && prb_.od == 1);
        bool is_3d = !is_1d && !is_2d;

        // When osp is rounded up, only the outermost non-trivial spatial
        // component can run past its real bound, so mask only that one.
        bool check_osp = (prb_.osp < cfg_.padded_dim("osp"));
        check_ow = is_1d && check_osp;
        check_oh = is_2d && check_osp;
        check_od = is_3d && check_osp;

        // Skip the modulus where the decomposition is already exact
        // (1D: ow == osp; 1D/2D: oh needs no wrap over od).
        if (!is_1d) ow %= prb_.ow;
        if (!is_2d) oh %= prb_.oh;
    } else {
        od = var_t::make(type_t::s32(), "od");
        oh = var_t::make(type_t::s32(), "oh");
        ow = var_t::make(type_t::s32(), "ow");
        check_ow = (prb_.ow < cfg_.padded_dim("ow"));
    }

    // Initialize masks.
    expr_t id_mask, ih_mask, iw_mask;
    expr_t od_mask, oh_mask, ow_mask;

    // Input bounds need checking when the kernel window can step outside
    // the input (padding/dilation, see need_src_or_dst_check) or when an
    // output/kernel dimension is padded past its problem size.
    bool check_kw = (prb_.kw < cfg_.padded_dim("kw"));
    bool check_iw = check_kw || check_ow
            || need_src_or_dst_check(prb_.is_fwd, prb_.ow, prb_.iw, prb_.kw,
                    prb_.pw, prb_.sw, prb_.dw);
    bool check_ih = check_oh
            || need_src_or_dst_check(prb_.is_fwd, prb_.oh, prb_.ih, prb_.kh,
                    prb_.ph, prb_.sh, prb_.dh);
    bool check_id = check_od
            || need_src_or_dst_check(prb_.is_fwd, prb_.od, prb_.id, prb_.kd,
                    prb_.pd, prb_.sd, prb_.dd);

    // `x` is the view's placeholder variable; presumably it is substituted
    // with the tensor index the mask is attached to in set_tdim().
    auto &x = view_t::placeholder_var();
    if (check_id) id_mask = (x >= 0) & (x < prb_.id);
    if (check_ih) ih_mask = (x >= 0) & (x < prb_.ih);
    if (check_iw) iw_mask = (x >= 0) & (x < prb_.iw);
    if (check_od) od_mask = (x >= 0) & (x < prb_.od);
    if (check_oh) oh_mask = (x >= 0) & (x < prb_.oh);
    if (check_ow) ow_mask = (x >= 0) & (x < prb_.ow);

    // Source. 6 tensor dims (n, g, c, d, h, w); view vars additionally
    // carry the kernel offsets used in the spatial index expressions.
    if (cfg_.fuse_spatial()) {
        src_view = view_t({mb, g, ic, osp, kd, kh, kw}, 6);
    } else {
        src_view = view_t({mb, g, ic, od, oh, ow, kd, kh, kw}, 6);
    }
    src_view.set_vdim(mb, prb_.mb);
    src_view.set_vdim(g, prb_.g);
    src_view.set_vdim(ic, prb_.ic);
    if (cfg_.fuse_spatial()) {
        src_view.set_vdim(osp, prb_.osp);
    } else {
        src_view.set_vdim(od, prb_.od);
        src_view.set_vdim(oh, prb_.oh);
        src_view.set_vdim(ow, prb_.ow);
    }
    src_view.set_vdim(kd, prb_.kd);
    src_view.set_vdim(kh, prb_.kh);
    src_view.set_vdim(kw, prb_.kw);
    src_view.set_tdim(0, mb);
    src_view.set_tdim(1, g);
    src_view.set_tdim(2, ic);
    // Input spatial index: i = o * stride - pad + k * (1 + dilation).
    src_view.set_tdim(3, od * prb_.sd - prb_.pd + kd * (1 + prb_.dd), id_mask);
    src_view.set_tdim(4, oh * prb_.sh - prb_.ph + kh * (1 + prb_.dh), ih_mask);
    src_view.set_tdim(5, ow * prb_.sw - prb_.pw + kw * (1 + prb_.dw), iw_mask);
    src_view.set_tlayout(src_layout);
    // NOTE(review): presumably adds masks for dims padded beyond their
    // problem size, based on the iteration blocking — see view_t::set_tmasks.
    src_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Weights.
    wei_view = view_t({g, oc, ic, kd, kh, kw}, 6);
    wei_view.set_vdim(g, prb_.g);
    wei_view.set_vdim(oc, prb_.oc);
    wei_view.set_vdim(ic, prb_.ic);
    wei_view.set_vdim(kd, prb_.kd);
    wei_view.set_vdim(kh, prb_.kh);
    wei_view.set_vdim(kw, prb_.kw);
    wei_view.set_tdim(0, g);
    wei_view.set_tdim(1, oc);
    wei_view.set_tdim(2, ic);
    wei_view.set_tdim(3, kd);
    wei_view.set_tdim(4, kh);
    wei_view.set_tdim(5, kw);
    wei_view.set_tlayout(wei_layout);
    wei_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Destination.
    if (cfg_.fuse_spatial()) {
        dst_view = view_t({mb, g, oc, osp}, 6);
    } else {
        dst_view = view_t({mb, g, oc, od, oh, ow}, 6);
    }
    dst_view.set_vdim(mb, prb_.mb);
    dst_view.set_vdim(g, prb_.g);
    dst_view.set_vdim(oc, prb_.oc);
    if (cfg_.fuse_spatial()) {
        dst_view.set_vdim(osp, prb_.osp);
    } else {
        dst_view.set_vdim(od, prb_.od);
        dst_view.set_vdim(oh, prb_.oh);
        dst_view.set_vdim(ow, prb_.ow);
    }
    dst_view.set_tdim(0, mb);
    dst_view.set_tdim(1, g);
    dst_view.set_tdim(2, oc);
    dst_view.set_tdim(3, od, od_mask);
    dst_view.set_tdim(4, oh, oh_mask);
    dst_view.set_tdim(5, ow, ow_mask);
    dst_view.set_tlayout(dst_layout);
    dst_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Initialize GEMM schedule: dst = src x wei.
    gemm_schedule.set_a_view(src_view);
    gemm_schedule.set_b_view(wei_view);
    gemm_schedule.set_c_view(dst_view);
    gemm_schedule.set_b_vars({g});
    if (cfg_.fuse_spatial()) {
        gemm_schedule.set_m_vars({mb, osp});
    } else {
        gemm_schedule.set_m_vars({mb, od, oh, ow});
    }
    gemm_schedule.set_n_vars({oc});
    gemm_schedule.set_k_vars({ic, kd, kh, kw});

    // Bound every variable by its (possibly padded-up) dimension size.
    gemm_schedule.for_each_var([&](const expr_t &var) {
        int bound = cfg_.padded_dim(var.as<var_t>().name);
        gemm_schedule.set_var_bound(var, bound);
    });

    // Tile the blocked dimensions. kd/kh are plain loop variables (no
    // tiling); the spatial tile uses ow when spatial is not fused.
    auto g_tile = create_tile(gemm_schedule, cfg_, g);
    auto oc_tile = create_tile(gemm_schedule, cfg_, oc);
    auto mb_tile = create_tile(gemm_schedule, cfg_, mb);
    auto osp_tile = create_tile(gemm_schedule, cfg_, osp.is_empty() ? ow : osp);
    auto ic_tile = create_tile(gemm_schedule, cfg_, ic);
    auto kw_tile = create_tile(gemm_schedule, cfg_, kw);

    // Without fused spatial, od/oh are folded whole into the fused kernel
    // grid index alongside g and the ow grid component.
    auto g_osp_grid_idx = cfg_.fuse_spatial()
            ? gemm_schedule.fuse({g_tile.grid_idx(), osp_tile.grid_idx()})
            : gemm_schedule.fuse(
                    {g_tile.grid_idx(), od, oh, osp_tile.grid_idx()});
    auto mb_osp_tg_idx
            = gemm_schedule.fuse(mb_tile.tg_idx(), osp_tile.tg_idx());

    // Bind grid-level indices to the kernel grid and thread-group-level
    // indices to the thread group grid.
    gemm_schedule.bind(oc_tile.grid_idx(), cfg_.kernel_grid().idx(0));
    gemm_schedule.bind(g_osp_grid_idx, cfg_.kernel_grid().idx(1));
    gemm_schedule.bind(mb_tile.grid_idx(), cfg_.kernel_grid().idx(2));
    gemm_schedule.bind(oc_tile.tg_idx(), cfg_.thread_group_grid().idx(0));
    gemm_schedule.bind(mb_osp_tg_idx, cfg_.thread_group_grid().idx(1));
    gemm_schedule.bind(ic_tile.tg_idx(), cfg_.thread_group_grid().idx(2));

    // Iteration-level indices become the tensorized (single-instruction)
    // dimensions.
    gemm_schedule.tensorize(g_tile.iter_idx());
    gemm_schedule.tensorize(oc_tile.iter_idx());
    gemm_schedule.tensorize(mb_tile.iter_idx());
    gemm_schedule.tensorize(osp_tile.iter_idx());
    gemm_schedule.tensorize(kw_tile.iter_idx());
    gemm_schedule.tensorize(ic_tile.iter_idx());

    // Final loop nest order: reduction loops outermost-to-innermost as
    // listed, then the thread group indices.
    gemm_schedule.reorder({ic_tile.loop_idx(), kd, kh, kw_tile.loop_idx(),
            oc_tile.tg_idx(), mb_osp_tg_idx, ic_tile.tg_idx()});

    src_buf = kernel_info_.find_arg("src");
    wei_buf = kernel_info_.find_arg("wei");
    dst_buf = kernel_info_.find_arg("dst");
}
313
// Initializes tensor views and the GEMM schedule for backward-by-data
// convolution. GEMM mapping (as set on gemm_schedule below):
//   A = dst, B = wei, C = src; batch vars: {g};
//   M = (mb, id, ih, iw), N = (ic), K = (oc, kd, kh, kw).
// The schedule iterates over input (src) coordinates and accumulates over
// the output; output coordinates are recomputed from (id/ih/iw, kd/kh/kw).
void conv_ir_builder_t::init_bwd_d(gemm_schedule_t &gemm_schedule,
        view_t &dst_view, view_t &wei_view, view_t &src_view, expr_t &dst_buf,
        expr_t &wei_buf, expr_t &src_buf) {
    auto &src_layout = cfg_.src_layout().compute();
    auto &wei_layout = cfg_.wei_layout().compute();
    auto &dst_layout = cfg_.dst_layout().compute();

    // Initialize views.
    auto g = var_t::make(type_t::s32(), "g");
    auto mb = var_t::make(type_t::s32(), "mb");
    auto ic = var_t::make(type_t::s32(), "ic");
    auto oc = var_t::make(type_t::s32(), "oc");
    auto id = var_t::make(type_t::s32(), "id");
    auto ih = var_t::make(type_t::s32(), "ih");
    auto iw = var_t::make(type_t::s32(), "iw");
    auto kd = var_t::make(type_t::s32(), "kd");
    auto kh = var_t::make(type_t::s32(), "kh");
    auto kw = var_t::make(type_t::s32(), "kw");

    // Initialize masks. Start from `true` since stride-divisibility
    // conditions may be AND-ed in below.
    expr_t od_mask(true), oh_mask(true), ow_mask(true);

    bool check_iw = (prb_.iw < cfg_.padded_dim("iw"));
    bool check_ow = check_iw
            || need_src_or_dst_check(prb_.is_fwd, prb_.ow, prb_.iw, prb_.kw,
                    prb_.pw, prb_.sw, prb_.dw);
    bool check_oh = need_src_or_dst_check(
            prb_.is_fwd, prb_.oh, prb_.ih, prb_.kh, prb_.ph, prb_.sh, prb_.dh);
    bool check_od = need_src_or_dst_check(
            prb_.is_fwd, prb_.od, prb_.id, prb_.kd, prb_.pd, prb_.sd, prb_.dd);

    // `x` is the view placeholder substituted with the masked tensor index.
    auto &x = view_t::placeholder_var();
    if (check_od) od_mask = (x >= 0) & (x < prb_.od);
    if (check_oh) oh_mask = (x >= 0) & (x < prb_.oh);
    if (check_ow) ow_mask = (x >= 0) & (x < prb_.ow);

    std::function<expr_t(const expr_t &)> iw_mapping;
    if (cfg_.bwd_d_optimize_strided_iw()) {
        // Apply mapping to iw to ensure each thread group has the same
        // stride condition when evaluating skip conditions.
        iw_mapping = [&](const expr_t &e) {
            int iw_tg_blk = cfg_.thread_group_dim("iw") * cfg_.iter_dim("iw");
            int iw_bound = utils::rnd_up(prb_.iw, iw_tg_blk);
            int iw_same_mod_blk = ir_utils::safe_divide(iw_bound, prb_.sw);
            return (e % iw_same_mod_blk) * prb_.sw + (e / iw_same_mod_blk);
        };
    } else {
        // No remapping: identity.
        iw_mapping = [](const expr_t &e) { return e; };
    }

    // Destination. View vars carry the kernel offsets used to recompute
    // output coordinates from input coordinates.
    dst_view = view_t({mb, g, oc, id, ih, iw, kd, kh, kw}, 6);
    dst_view.set_vdim(mb, prb_.mb);
    dst_view.set_vdim(g, prb_.g);
    dst_view.set_vdim(oc, prb_.oc);
    dst_view.set_vdim(id, prb_.id);
    dst_view.set_vdim(ih, prb_.ih);
    dst_view.set_vdim(iw, prb_.iw);
    dst_view.set_vdim(kd, prb_.kd);
    dst_view.set_vdim(kh, prb_.kh);
    dst_view.set_vdim(kw, prb_.kw);
    dst_view.set_tdim(0, mb);
    dst_view.set_tdim(1, g);
    dst_view.set_tdim(2, oc);

    // Unstrided output offsets: o * stride = i - k * (1 + dilation) + pad.
    auto od = id - kd * (1 + prb_.dd) + prb_.pd;
    auto oh = ih - kh * (1 + prb_.dh) + prb_.ph;
    auto ow = iw_mapping(iw) - kw * (1 + prb_.dw) + prb_.pw;

    // When stride optimization is enabled, stride conditions are handled by
    // continue calls in the outer loops.
    if (!cfg_.bwd_d_optimize_strided_iw()) {
        od_mask &= (od % prb_.sd == 0);
        oh_mask &= (oh % prb_.sh == 0);
        ow_mask &= (ow % prb_.sw == 0);
    }
    dst_view.set_tdim(3, od / prb_.sd, od_mask);
    dst_view.set_tdim(4, oh / prb_.sh, oh_mask);
    dst_view.set_tdim(5, ow / prb_.sw, ow_mask);

    dst_view.set_tlayout(dst_layout);
    dst_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Weights.
    wei_view = view_t({g, oc, ic, kd, kh, kw}, 6);
    wei_view.set_vdim(g, prb_.g);
    wei_view.set_vdim(ic, prb_.ic);
    wei_view.set_vdim(oc, prb_.oc);
    wei_view.set_vdim(kd, prb_.kd);
    wei_view.set_vdim(kh, prb_.kh);
    wei_view.set_vdim(kw, prb_.kw);
    wei_view.set_tdim(0, g);
    wei_view.set_tdim(1, oc);
    wei_view.set_tdim(2, ic);
    wei_view.set_tdim(3, kd);
    wei_view.set_tdim(4, kh);
    wei_view.set_tdim(5, kw);
    wei_view.set_tlayout(wei_layout);
    wei_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Source (the GEMM output for backward-by-data).
    src_view = view_t({mb, g, ic, id, ih, iw}, 6);
    src_view.set_vdim(mb, prb_.mb);
    src_view.set_vdim(g, prb_.g);
    src_view.set_vdim(ic, prb_.ic);
    src_view.set_vdim(id, prb_.id);
    src_view.set_vdim(ih, prb_.ih);
    src_view.set_vdim(iw, prb_.iw);
    src_view.set_tdim(0, mb);
    src_view.set_tdim(1, g);
    src_view.set_tdim(2, ic);
    src_view.set_tdim(3, id);
    src_view.set_tdim(4, ih);
    // The same iw remapping must be applied on the output side so src and
    // dst agree on which element a given iw value refers to.
    src_view.set_tdim(5, iw_mapping(iw));
    src_view.set_tlayout(src_layout);
    src_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Initialize GEMM schedule: src = dst x wei.
    gemm_schedule.set_a_view(dst_view);
    gemm_schedule.set_b_view(wei_view);
    gemm_schedule.set_c_view(src_view);
    gemm_schedule.set_b_vars({g});
    gemm_schedule.set_m_vars({mb, id, ih, iw});
    gemm_schedule.set_n_vars({ic});
    gemm_schedule.set_k_vars({oc, kd, kh, kw});

    // Bound every variable by its (possibly padded-up) dimension size.
    gemm_schedule.for_each_var([&](const expr_t &var) {
        int bound = cfg_.padded_dim(var.as<var_t>().name);
        gemm_schedule.set_var_bound(var, bound);
    });

    // Tile the blocked dimensions; id/ih and kd/kh/kw stay untiled.
    auto g_tile = create_tile(gemm_schedule, cfg_, g);
    auto ic_tile = create_tile(gemm_schedule, cfg_, ic);
    auto mb_tile = create_tile(gemm_schedule, cfg_, mb);
    auto iw_tile = create_tile(gemm_schedule, cfg_, iw);
    auto oc_tile = create_tile(gemm_schedule, cfg_, oc);

    // Fuse g with full id/ih and the iw grid component into one kernel
    // grid index.
    auto g_isp_grid_idx = gemm_schedule.fuse(
            {g_tile.grid_idx(), id, ih, iw_tile.grid_idx()});
    auto mb_iw_tg_idx = gemm_schedule.fuse(mb_tile.tg_idx(), iw_tile.tg_idx());

    // Bind grid-level and thread-group-level indices.
    gemm_schedule.bind(ic_tile.grid_idx(), cfg_.kernel_grid().idx(0));
    gemm_schedule.bind(g_isp_grid_idx, cfg_.kernel_grid().idx(1));
    gemm_schedule.bind(mb_tile.grid_idx(), cfg_.kernel_grid().idx(2));
    gemm_schedule.bind(ic_tile.tg_idx(), cfg_.thread_group_grid().idx(0));
    gemm_schedule.bind(mb_iw_tg_idx, cfg_.thread_group_grid().idx(1));
    gemm_schedule.bind(oc_tile.tg_idx(), cfg_.thread_group_grid().idx(2));

    gemm_schedule.tensorize(g_tile.iter_idx());
    gemm_schedule.tensorize(ic_tile.iter_idx());
    gemm_schedule.tensorize(mb_tile.iter_idx());
    gemm_schedule.tensorize(iw_tile.iter_idx());
    gemm_schedule.tensorize(oc_tile.iter_idx());

    if (cfg_.bwd_d_optimize_strided()) {
        // Skip kernel offsets whose recomputed output offset is not a
        // multiple of the stride (replaces per-element masking).
        gemm_schedule.set_skip_condition(kd, od % prb_.sd != 0);
        gemm_schedule.set_skip_condition(kh, oh % prb_.sh != 0);
        if (cfg_.bwd_d_optimize_strided_iw())
            gemm_schedule.set_skip_condition(kw, ow % prb_.sw != 0);
        // Put kd/kh/kw outermost to allow pipelining in oc loop.
        gemm_schedule.reorder({kd, kh, kw, oc_tile.loop_idx()});
    } else {
        gemm_schedule.reorder({oc_tile.loop_idx(), kd, kh, kw});
    }

    src_buf = kernel_info_.find_arg("src");
    wei_buf = kernel_info_.find_arg("wei");
    dst_buf = kernel_info_.find_arg("dst");
}
483
// Initializes tensor views and the GEMM schedule for backward-by-weights
// convolution. GEMM mapping (as set on gemm_schedule below):
//   A = src, B = dst, C = wei (plus optional bias); batch vars: {g};
//   M = (ic, kw), N = (oc), K = (mb, od, oh, ow).
// Also computes bia_reduction_condition: the predicate under which a
// thread participates in the bias reduction (first kd/kh/kw/ic slice).
void conv_ir_builder_t::init_bwd_w(gemm_schedule_t &gemm_schedule,
        view_t &src_view, view_t &dst_view, view_t &wei_view, view_t &bia_view,
        expr_t &src_buf, expr_t &dst_buf, expr_t &wei_buf, expr_t &bia_buf,
        expr_t &bia_reduction_condition) {
    auto &src_layout = cfg_.src_layout().compute();
    auto &wei_layout = cfg_.wei_layout().compute();
    auto &dst_layout = cfg_.dst_layout().compute();
    auto &bia_layout = cfg_.bia_layout().compute();

    // Initialize thread group views.
    auto g = var_t::make(type_t::s32(), "g");
    auto mb = var_t::make(type_t::s32(), "mb");
    auto ic = var_t::make(type_t::s32(), "ic");
    auto oc = var_t::make(type_t::s32(), "oc");
    auto od = var_t::make(type_t::s32(), "od");
    auto oh = var_t::make(type_t::s32(), "oh");
    auto ow = var_t::make(type_t::s32(), "ow");
    auto kd = var_t::make(type_t::s32(), "kd");
    auto kh = var_t::make(type_t::s32(), "kh");
    auto kw = var_t::make(type_t::s32(), "kw");

    // Initialize masks. Lower and upper bound checks are tracked
    // separately: padded od/oh/ow only affect the upper bound of the
    // corresponding input index.
    expr_t id_mask(true), ih_mask(true), iw_mask(true);

    bool check_ow = (prb_.ow < cfg_.padded_dim("ow"));
    bool check_oh = (prb_.oh < cfg_.padded_dim("oh"));
    bool check_od = (prb_.od < cfg_.padded_dim("od"));
    bool check_kw = (prb_.kw < cfg_.padded_dim("kw"));
    // The access pattern follows forward convolution (src is indexed from
    // output/kernel coordinates), hence is_fwd=true here.
    bool check_iw = check_kw
            || need_src_or_dst_check(/*is_fwd=*/true, prb_.ow, prb_.iw, prb_.kw,
                    prb_.pw, prb_.sw, prb_.dw);
    bool check_ih = need_src_or_dst_check(/*is_fwd=*/true, prb_.oh, prb_.ih,
            prb_.kh, prb_.ph, prb_.sh, prb_.dh);
    bool check_id = need_src_or_dst_check(/*is_fwd=*/true, prb_.od, prb_.id,
            prb_.kd, prb_.pd, prb_.sd, prb_.dd);
    bool check_iw_min = check_iw;
    bool check_ih_min = check_ih;
    bool check_id_min = check_id;
    bool check_iw_max = (check_iw || check_ow);
    bool check_ih_max = (check_ih || check_oh);
    bool check_id_max = (check_id || check_od);

    // `x` is the view placeholder substituted with the masked tensor index.
    auto &x = view_t::placeholder_var();
    if (check_id_min) id_mask &= (x >= 0);
    if (check_ih_min) ih_mask &= (x >= 0);
    if (check_iw_min) iw_mask &= (x >= 0);
    if (check_id_max) id_mask &= (x < prb_.id);
    if (check_ih_max) ih_mask &= (x < prb_.ih);
    if (check_iw_max) iw_mask &= (x < prb_.iw);

    // Source. kd/kh are not view dimensions here; they are fused directly
    // into the kernel grid index below and appear only in the tdim
    // expressions.
    src_view = view_t({mb, g, ic, od, oh, ow, kw}, 6);
    src_view.set_vdim(mb, prb_.mb);
    src_view.set_vdim(g, prb_.g);
    src_view.set_vdim(ic, prb_.ic);
    src_view.set_vdim(od, prb_.od);
    src_view.set_vdim(oh, prb_.oh);
    src_view.set_vdim(ow, prb_.ow);
    src_view.set_vdim(kw, prb_.kw);
    src_view.set_tdim(0, mb);
    src_view.set_tdim(1, g);
    src_view.set_tdim(2, ic);
    // Input spatial index: i = o * stride - pad + k * (1 + dilation).
    src_view.set_tdim(3, od * prb_.sd - prb_.pd + kd * (1 + prb_.dd), id_mask);
    src_view.set_tdim(4, oh * prb_.sh - prb_.ph + kh * (1 + prb_.dh), ih_mask);
    src_view.set_tdim(5, ow * prb_.sw - prb_.pw + kw * (1 + prb_.dw), iw_mask);
    src_view.set_tlayout(src_layout);
    src_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Weights (the GEMM output for backward-by-weights).
    wei_view = view_t({g, oc, ic, kd, kh, kw}, 6);
    wei_view.set_vdim(g, prb_.g);
    wei_view.set_vdim(oc, prb_.oc);
    wei_view.set_vdim(ic, prb_.ic);
    wei_view.set_vdim(kd, prb_.kd);
    wei_view.set_vdim(kh, prb_.kh);
    wei_view.set_vdim(kw, prb_.kw);
    wei_view.set_tdim(0, g);
    wei_view.set_tdim(1, oc);
    wei_view.set_tdim(2, ic);
    wei_view.set_tdim(3, kd);
    wei_view.set_tdim(4, kh);
    wei_view.set_tdim(5, kw);
    wei_view.set_tlayout(wei_layout);
    wei_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Destination.
    dst_view = view_t({mb, g, oc, od, oh, ow}, 6);
    dst_view.set_vdim(mb, prb_.mb);
    dst_view.set_vdim(g, prb_.g);
    dst_view.set_vdim(oc, prb_.oc);
    dst_view.set_vdim(od, prb_.od);
    dst_view.set_vdim(oh, prb_.oh);
    dst_view.set_vdim(ow, prb_.ow);
    dst_view.set_tdim(0, mb);
    dst_view.set_tdim(1, g);
    dst_view.set_tdim(2, oc);
    dst_view.set_tdim(3, od);
    dst_view.set_tdim(4, oh);
    dst_view.set_tdim(5, ow);
    dst_view.set_tlayout(dst_layout);
    dst_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Bias (optional reduction output over g x oc).
    if (prb_.with_bias) {
        bia_view = view_t({g, oc}, 2);
        bia_view.set_vdim(g, prb_.g);
        bia_view.set_vdim(oc, prb_.oc);
        bia_view.set_tdim(0, g);
        bia_view.set_tdim(1, oc);
        bia_view.set_tlayout(bia_layout);
        bia_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());
    }

    // Initialize GEMM schedule: wei = src x dst.
    gemm_schedule.set_a_view(src_view);
    gemm_schedule.set_b_view(dst_view);
    gemm_schedule.set_c_view(wei_view);
    gemm_schedule.set_b_vars({g});
    gemm_schedule.set_m_vars({ic, kw});
    gemm_schedule.set_n_vars({oc});
    gemm_schedule.set_k_vars({mb, od, oh, ow});

    // Bound every variable by its (possibly padded-up) dimension size.
    gemm_schedule.for_each_var([&](const expr_t &var) {
        int bound = cfg_.padded_dim(var.as<var_t>().name);
        gemm_schedule.set_var_bound(var, bound);
    });

    auto g_tile = create_tile(gemm_schedule, cfg_, g);
    auto mb_tile = create_tile(gemm_schedule, cfg_, mb);
    auto ic_tile = create_tile(gemm_schedule, cfg_, ic);
    auto oc_tile = create_tile(gemm_schedule, cfg_, oc);
    auto od_tile = create_tile(gemm_schedule, cfg_, od);
    auto oh_tile = create_tile(gemm_schedule, cfg_, oh);
    auto ow_tile = create_tile(gemm_schedule, cfg_, ow);
    auto kw_tile = create_tile(gemm_schedule, cfg_, kw);

    // Fuse output spatial, kernel spatial and the ic grid component into
    // one kernel grid index; kd/kh are fused whole (untiled).
    auto osp_ksp_ic_grid_idx = gemm_schedule.fuse(
            {od_tile.grid_idx(), oh_tile.grid_idx(), ow_tile.grid_idx(), kd, kh,
                    kw_tile.grid_idx(), ic_tile.grid_idx()});

    auto g_mb_grid_idx
            = gemm_schedule.fuse({g_tile.grid_idx(), mb_tile.grid_idx()});

    gemm_schedule.bind(oc_tile.grid_idx(), cfg_.kernel_grid().idx(0));
    gemm_schedule.bind(osp_ksp_ic_grid_idx, cfg_.kernel_grid().idx(1));
    gemm_schedule.bind(g_mb_grid_idx, cfg_.kernel_grid().idx(2));

    gemm_schedule.bind(oc_tile.tg_idx(), cfg_.thread_group_grid().idx(0));
    gemm_schedule.bind(ic_tile.tg_idx(), cfg_.thread_group_grid().idx(1));

    // Reduction loop nest order: spatial loops outside the mb loop.
    gemm_schedule.reorder({od_tile.loop_idx(), oh_tile.loop_idx(),
            ow_tile.loop_idx(), mb_tile.loop_idx()});

    gemm_schedule.unroll(mb_tile.loop_idx(), cfg_.unroll("mb"));
    gemm_schedule.unroll(ow_tile.loop_idx(), cfg_.unroll("ow"));

    gemm_schedule.tensorize(g_tile.iter_idx());
    gemm_schedule.tensorize(oc_tile.iter_idx());
    gemm_schedule.tensorize(ic_tile.iter_idx());
    gemm_schedule.tensorize(mb_tile.iter_idx());
    gemm_schedule.tensorize(ow_tile.iter_idx());
    gemm_schedule.tensorize(kw_tile.iter_idx());

    src_buf = kernel_info_.find_arg("src");
    wei_buf = kernel_info_.find_arg("wei");
    dst_buf = kernel_info_.find_arg("dst");

    if (prb_.with_bias) {
        bia_buf = kernel_info_.find_arg("bia");
        // Only the threads owning the first slice of each reduced
        // dimension contribute to the bias reduction; conditions are added
        // only for dimensions that are actually split (size/grid > 1).
        bia_reduction_condition = expr_t(true);
        if (prb_.kd > 1) bia_reduction_condition &= (kd == 0);
        if (prb_.kh > 1) bia_reduction_condition &= (kh == 0);
        if (prb_.kw > 1) bia_reduction_condition &= (kw_tile.grid_idx() == 0);
        if (cfg_.grid_dim("ic") > 1)
            bia_reduction_condition &= (ic_tile.grid_idx() == 0);
        // Without SLM buffering of B, restrict to the first thread along
        // thread group dimension 1 as well.
        if (!cfg_.slm().b() && cfg_.thread_group_grid().dim(1) > 1) {
            bia_reduction_condition &= (cfg_.thread_group_grid().idx(1) == 0);
        }
    }
}
664
665} // namespace jit
666} // namespace gpu
667} // namespace impl
668} // namespace dnnl
669