1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/conv/ir_builder.hpp" |
18 | |
19 | namespace dnnl { |
20 | namespace impl { |
21 | namespace gpu { |
22 | namespace jit { |
23 | |
24 | namespace { |
25 | |
// Returns true when the convolution window can address elements outside the
// valid [0, size) range, i.e. when out-of-bounds masking is required.
// Forward: checks the source (input) index range produced by the window.
// Backward: checks the strided output index range reconstructed from input.
bool need_src_or_dst_check(
        bool is_fwd, int o, int i, int k, int p, int s, int d) {
    if (is_fwd) {
        // Smallest and largest input indices touched by the kernel window.
        int lo = -p;
        int hi = (o - 1) * s - p + (k - 1) * (1 + d);
        return (lo < 0) || (hi >= i);
    }
    // Backward: range of (output * stride) values implied by the input range.
    int lo = p - (k - 1) * (1 + d);
    int hi = (i - 1) + p;
    return (lo < 0) || (hi >= o * s);
}
38 | |
39 | } // namespace |
40 | |
41 | // Represents hierarchy of tile levels and corresponding loop/grid indices. |
42 | // |
43 | // | Tile level | Nesting level | Maps to | |
44 | // |------------|---------------|------------------------| |
45 | // | grid_dim | 0 | Thread group | |
46 | // | loop_dim | 1 | Loop in thread | |
47 | // | tg_dim | 2 | Thread in thread group | |
48 | // | iter_dim | 3 | Iteration in loop | |
49 | class dim_tile_t { |
50 | public: |
51 | const expr_t &grid_idx() const { return not_empty(grid_idx_); } |
52 | const expr_t &tg_idx() const { return not_empty(tg_idx_); } |
53 | const expr_t &loop_idx() const { return not_empty(loop_idx_); } |
54 | const expr_t &iter_idx() const { return not_empty(iter_idx_); } |
55 | |
56 | void set_grid_idx(const expr_t &idx) { grid_idx_ = idx; } |
57 | void set_tg_idx(const expr_t &idx) { tg_idx_ = idx; } |
58 | void set_loop_idx(const expr_t &idx) { loop_idx_ = idx; } |
59 | void set_iter_idx(const expr_t &idx) { iter_idx_ = idx; } |
60 | |
61 | private: |
62 | static const expr_t ¬_empty(const expr_t &v) { |
63 | ir_assert(!v.is_empty()) << "Queried empty index." ; |
64 | return v; |
65 | } |
66 | |
67 | expr_t grid_idx_; |
68 | expr_t tg_idx_; |
69 | expr_t loop_idx_; |
70 | expr_t iter_idx_; |
71 | }; |
72 | |
73 | dim_tile_t create_tile(gemm_schedule_t &gemm_schedule, const conv_config_t &cfg, |
74 | const expr_t &dim) { |
75 | dim_tile_t tile; |
76 | auto &name = dim.as<var_t>().name; |
77 | int loop_dim = cfg.loop_dim(name); |
78 | int tg_dim = cfg.thread_group_dim(name); |
79 | int iter_dim = cfg.iter_dim(name); |
80 | |
81 | std::vector<int> dims = {1, loop_dim, tg_dim, iter_dim}; |
82 | int ndims = (int)dims.size(); |
83 | std::vector<expr_t> idxs(ndims); |
84 | |
85 | static const char *suffixes[] |
86 | = {"_grid_idx" , "_loop_idx" , "_tg_idx" , "_iter_idx" }; |
87 | auto &dim_name = dim.as<var_t>().name; |
88 | |
89 | auto has_block = [&](int dim_idx) { |
90 | bool is_thr = (dim_idx == 1); |
91 | bool is_tg = (dim_idx == 2); |
92 | bool is_iter = (dim_idx == 3); |
93 | if (is_thr || is_iter) return true; |
94 | for (int i = 0; i < 3; i++) { |
95 | auto **dd = is_tg ? get_thread_group_grid_conv_dims(cfg.prb(), i) |
96 | : get_kernel_grid_conv_dims(cfg.prb(), i); |
97 | for (auto **d = dd; *d; d++) |
98 | if (dim_name == *d) return true; |
99 | } |
100 | return false; |
101 | }; |
102 | |
103 | expr_t idx = dim; |
104 | for (int i = ndims - 1; i >= 1; i--) { |
105 | expr_t outer; |
106 | expr_t inner; |
107 | auto outer_name = (i == 1) ? dim_name + suffixes[i] : std::string(); |
108 | auto inner_name = dim_name + suffixes[i]; |
109 | gemm_schedule.split(idx, dims[i], outer, inner, outer_name, inner_name); |
110 | if (has_block(i)) idxs[i] = inner; |
111 | idx = outer; |
112 | } |
113 | idxs[0] = idx; |
114 | |
115 | tile.set_grid_idx(idxs[0]); |
116 | tile.set_loop_idx(idxs[1]); |
117 | tile.set_tg_idx(idxs[2]); |
118 | tile.set_iter_idx(idxs[3]); |
119 | |
120 | return tile; |
121 | } |
122 | |
// Builds tensor views and the GEMM schedule for forward convolution.
// GEMM mapping: A = src, B = wei, C = dst with M = {mb, spatial},
// N = {oc}, K = {ic, kd, kh, kw}; g is a batch dimension.
// Outputs: the three views, the tiled/bound schedule in `gemm_schedule`,
// and the kernel argument buffers for src/wei/dst.
void conv_ir_builder_t::init_fwd(gemm_schedule_t &gemm_schedule,
        view_t &src_view, view_t &wei_view, view_t &dst_view, expr_t &src_buf,
        expr_t &wei_buf, expr_t &dst_buf) {
    auto &src_layout = cfg_.src_layout().compute();
    auto &wei_layout = cfg_.wei_layout().compute();
    auto &dst_layout = cfg_.dst_layout().compute();

    // Initialize views.
    auto mb = var_t::make(type_t::s32(), "mb");
    auto ic = var_t::make(type_t::s32(), "ic");
    auto oc = var_t::make(type_t::s32(), "oc");
    auto kd = var_t::make(type_t::s32(), "kd");
    auto kh = var_t::make(type_t::s32(), "kh");
    auto kw = var_t::make(type_t::s32(), "kw");
    auto g = var_t::make(type_t::s32(), "g");

    expr_t ow, oh, od, osp;
    bool check_od = false;
    bool check_oh = false;
    bool check_ow = false;
    if (cfg_.fuse_spatial()) {
        // Spatial output dimensions are fused into a single "osp" variable;
        // od/oh/ow are recovered from it via division and modulus below.
        osp = var_t::make(type_t::s32(), "osp");
        ow = osp;
        oh = osp / prb_.ow;
        od = osp / (prb_.oh * prb_.ow);

        bool is_1d = (prb_.oh == 1 && prb_.od == 1);
        bool is_2d = (prb_.oh != 1 && prb_.od == 1);
        bool is_3d = !is_1d && !is_2d;

        // When osp is padded (rounded up), only the outermost non-trivial
        // spatial component needs a bound check.
        bool check_osp = (prb_.osp < cfg_.padded_dim("osp"));
        check_ow = is_1d && check_osp;
        check_oh = is_2d && check_osp;
        check_od = is_3d && check_osp;

        // Inner spatial components wrap around their problem extents; the
        // outermost one keeps the plain quotient.
        if (!is_1d) ow %= prb_.ow;
        if (!is_2d) oh %= prb_.oh;
    } else {
        od = var_t::make(type_t::s32(), "od");
        oh = var_t::make(type_t::s32(), "oh");
        ow = var_t::make(type_t::s32(), "ow");
        check_ow = (prb_.ow < cfg_.padded_dim("ow"));
    }

    // Initialize masks.
    expr_t id_mask, ih_mask, iw_mask;
    expr_t od_mask, oh_mask, ow_mask;

    // Input-side checks are needed when the kernel dimension is padded, the
    // corresponding output is padded, or the convolution window itself can
    // step out of bounds (padding/stride/dilation).
    bool check_kw = (prb_.kw < cfg_.padded_dim("kw"));
    bool check_iw = check_kw || check_ow
            || need_src_or_dst_check(prb_.is_fwd, prb_.ow, prb_.iw, prb_.kw,
                    prb_.pw, prb_.sw, prb_.dw);
    bool check_ih = check_oh
            || need_src_or_dst_check(prb_.is_fwd, prb_.oh, prb_.ih, prb_.kh,
                    prb_.ph, prb_.sh, prb_.dh);
    bool check_id = check_od
            || need_src_or_dst_check(prb_.is_fwd, prb_.od, prb_.id, prb_.kd,
                    prb_.pd, prb_.sd, prb_.dd);

    // Masks are expressed over the view placeholder variable.
    auto &x = view_t::placeholder_var();
    if (check_id) id_mask = (x >= 0) & (x < prb_.id);
    if (check_ih) ih_mask = (x >= 0) & (x < prb_.ih);
    if (check_iw) iw_mask = (x >= 0) & (x < prb_.iw);
    if (check_od) od_mask = (x >= 0) & (x < prb_.od);
    if (check_oh) oh_mask = (x >= 0) & (x < prb_.oh);
    if (check_ow) ow_mask = (x >= 0) & (x < prb_.ow);

    // Source.
    if (cfg_.fuse_spatial()) {
        src_view = view_t({mb, g, ic, osp, kd, kh, kw}, 6);
    } else {
        src_view = view_t({mb, g, ic, od, oh, ow, kd, kh, kw}, 6);
    }
    src_view.set_vdim(mb, prb_.mb);
    src_view.set_vdim(g, prb_.g);
    src_view.set_vdim(ic, prb_.ic);
    if (cfg_.fuse_spatial()) {
        src_view.set_vdim(osp, prb_.osp);
    } else {
        src_view.set_vdim(od, prb_.od);
        src_view.set_vdim(oh, prb_.oh);
        src_view.set_vdim(ow, prb_.ow);
    }
    src_view.set_vdim(kd, prb_.kd);
    src_view.set_vdim(kh, prb_.kh);
    src_view.set_vdim(kw, prb_.kw);
    src_view.set_tdim(0, mb);
    src_view.set_tdim(1, g);
    src_view.set_tdim(2, ic);
    // Input spatial index: i = o * stride - padding + k * (1 + dilation).
    src_view.set_tdim(3, od * prb_.sd - prb_.pd + kd * (1 + prb_.dd), id_mask);
    src_view.set_tdim(4, oh * prb_.sh - prb_.ph + kh * (1 + prb_.dh), ih_mask);
    src_view.set_tdim(5, ow * prb_.sw - prb_.pw + kw * (1 + prb_.dw), iw_mask);
    src_view.set_tlayout(src_layout);
    src_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Weights.
    wei_view = view_t({g, oc, ic, kd, kh, kw}, 6);
    wei_view.set_vdim(g, prb_.g);
    wei_view.set_vdim(oc, prb_.oc);
    wei_view.set_vdim(ic, prb_.ic);
    wei_view.set_vdim(kd, prb_.kd);
    wei_view.set_vdim(kh, prb_.kh);
    wei_view.set_vdim(kw, prb_.kw);
    wei_view.set_tdim(0, g);
    wei_view.set_tdim(1, oc);
    wei_view.set_tdim(2, ic);
    wei_view.set_tdim(3, kd);
    wei_view.set_tdim(4, kh);
    wei_view.set_tdim(5, kw);
    wei_view.set_tlayout(wei_layout);
    wei_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Destination.
    if (cfg_.fuse_spatial()) {
        dst_view = view_t({mb, g, oc, osp}, 6);
    } else {
        dst_view = view_t({mb, g, oc, od, oh, ow}, 6);
    }
    dst_view.set_vdim(mb, prb_.mb);
    dst_view.set_vdim(g, prb_.g);
    dst_view.set_vdim(oc, prb_.oc);
    if (cfg_.fuse_spatial()) {
        dst_view.set_vdim(osp, prb_.osp);
    } else {
        dst_view.set_vdim(od, prb_.od);
        dst_view.set_vdim(oh, prb_.oh);
        dst_view.set_vdim(ow, prb_.ow);
    }
    dst_view.set_tdim(0, mb);
    dst_view.set_tdim(1, g);
    dst_view.set_tdim(2, oc);
    dst_view.set_tdim(3, od, od_mask);
    dst_view.set_tdim(4, oh, oh_mask);
    dst_view.set_tdim(5, ow, ow_mask);
    dst_view.set_tlayout(dst_layout);
    dst_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Initialize GEMM schedule.
    gemm_schedule.set_a_view(src_view);
    gemm_schedule.set_b_view(wei_view);
    gemm_schedule.set_c_view(dst_view);
    gemm_schedule.set_b_vars({g});
    if (cfg_.fuse_spatial()) {
        gemm_schedule.set_m_vars({mb, osp});
    } else {
        gemm_schedule.set_m_vars({mb, od, oh, ow});
    }
    gemm_schedule.set_n_vars({oc});
    gemm_schedule.set_k_vars({ic, kd, kh, kw});

    // Bound every variable by its padded (rounded-up) extent.
    gemm_schedule.for_each_var([&](const expr_t &var) {
        int bound = cfg_.padded_dim(var.as<var_t>().name);
        gemm_schedule.set_var_bound(var, bound);
    });

    // Tile the dimensions that are distributed across the grid / thread
    // group / loop / iteration levels.
    auto g_tile = create_tile(gemm_schedule, cfg_, g);
    auto oc_tile = create_tile(gemm_schedule, cfg_, oc);
    auto mb_tile = create_tile(gemm_schedule, cfg_, mb);
    // With fused spatial, "osp" stands in for the innermost spatial dim.
    auto osp_tile = create_tile(gemm_schedule, cfg_, osp.is_empty() ? ow : osp);
    auto ic_tile = create_tile(gemm_schedule, cfg_, ic);
    auto kw_tile = create_tile(gemm_schedule, cfg_, kw);

    // Fuse groups + spatial into one kernel grid index.
    auto g_osp_grid_idx = cfg_.fuse_spatial()
            ? gemm_schedule.fuse({g_tile.grid_idx(), osp_tile.grid_idx()})
            : gemm_schedule.fuse(
                    {g_tile.grid_idx(), od, oh, osp_tile.grid_idx()});
    auto mb_osp_tg_idx
            = gemm_schedule.fuse(mb_tile.tg_idx(), osp_tile.tg_idx());

    // Bind fused indices to the kernel grid and thread group grid.
    gemm_schedule.bind(oc_tile.grid_idx(), cfg_.kernel_grid().idx(0));
    gemm_schedule.bind(g_osp_grid_idx, cfg_.kernel_grid().idx(1));
    gemm_schedule.bind(mb_tile.grid_idx(), cfg_.kernel_grid().idx(2));
    gemm_schedule.bind(oc_tile.tg_idx(), cfg_.thread_group_grid().idx(0));
    gemm_schedule.bind(mb_osp_tg_idx, cfg_.thread_group_grid().idx(1));
    gemm_schedule.bind(ic_tile.tg_idx(), cfg_.thread_group_grid().idx(2));

    // Iteration-level indices become dimensions of the computed block.
    gemm_schedule.tensorize(g_tile.iter_idx());
    gemm_schedule.tensorize(oc_tile.iter_idx());
    gemm_schedule.tensorize(mb_tile.iter_idx());
    gemm_schedule.tensorize(osp_tile.iter_idx());
    gemm_schedule.tensorize(kw_tile.iter_idx());
    gemm_schedule.tensorize(ic_tile.iter_idx());

    // Loop nest order: reduction loops outermost (ic, kd, kh, kw).
    gemm_schedule.reorder({ic_tile.loop_idx(), kd, kh, kw_tile.loop_idx(),
            oc_tile.tg_idx(), mb_osp_tg_idx, ic_tile.tg_idx()});

    src_buf = kernel_info_.find_arg("src");
    wei_buf = kernel_info_.find_arg("wei");
    dst_buf = kernel_info_.find_arg("dst");
}
313 | |
// Builds tensor views and the GEMM schedule for backward-by-data
// convolution (computing the source gradient from the destination gradient
// and weights). GEMM mapping: A = dst, B = wei, C = src with
// M = {mb, id, ih, iw}, N = {ic}, K = {oc, kd, kh, kw}; g is a batch dim.
void conv_ir_builder_t::init_bwd_d(gemm_schedule_t &gemm_schedule,
        view_t &dst_view, view_t &wei_view, view_t &src_view, expr_t &dst_buf,
        expr_t &wei_buf, expr_t &src_buf) {
    auto &src_layout = cfg_.src_layout().compute();
    auto &wei_layout = cfg_.wei_layout().compute();
    auto &dst_layout = cfg_.dst_layout().compute();

    // Initialize views.
    auto g = var_t::make(type_t::s32(), "g");
    auto mb = var_t::make(type_t::s32(), "mb");
    auto ic = var_t::make(type_t::s32(), "ic");
    auto oc = var_t::make(type_t::s32(), "oc");
    auto id = var_t::make(type_t::s32(), "id");
    auto ih = var_t::make(type_t::s32(), "ih");
    auto iw = var_t::make(type_t::s32(), "iw");
    auto kd = var_t::make(type_t::s32(), "kd");
    auto kh = var_t::make(type_t::s32(), "kh");
    auto kw = var_t::make(type_t::s32(), "kw");

    // Initialize masks. Masks start as "true" (no check) and are refined
    // below only where an out-of-bounds access is possible.
    expr_t od_mask(true), oh_mask(true), ow_mask(true);

    bool check_iw = (prb_.iw < cfg_.padded_dim("iw"));
    bool check_ow = check_iw
            || need_src_or_dst_check(prb_.is_fwd, prb_.ow, prb_.iw, prb_.kw,
                    prb_.pw, prb_.sw, prb_.dw);
    bool check_oh = need_src_or_dst_check(
            prb_.is_fwd, prb_.oh, prb_.ih, prb_.kh, prb_.ph, prb_.sh, prb_.dh);
    bool check_od = need_src_or_dst_check(
            prb_.is_fwd, prb_.od, prb_.id, prb_.kd, prb_.pd, prb_.sd, prb_.dd);

    // Masks are expressed over the view placeholder variable.
    auto &x = view_t::placeholder_var();
    if (check_od) od_mask = (x >= 0) & (x < prb_.od);
    if (check_oh) oh_mask = (x >= 0) & (x < prb_.oh);
    if (check_ow) ow_mask = (x >= 0) & (x < prb_.ow);

    std::function<expr_t(const expr_t &)> iw_mapping;
    if (cfg_.bwd_d_optimize_strided_iw()) {
        // Apply mapping to iw to ensure each thread group has the same
        // stride condition when evaluating skip conditions.
        iw_mapping = [&](const expr_t &e) {
            int iw_tg_blk = cfg_.thread_group_dim("iw") * cfg_.iter_dim("iw");
            int iw_bound = utils::rnd_up(prb_.iw, iw_tg_blk);
            int iw_same_mod_blk = ir_utils::safe_divide(iw_bound, prb_.sw);
            return (e % iw_same_mod_blk) * prb_.sw + (e / iw_same_mod_blk);
        };
    } else {
        // Identity mapping when the optimization is disabled.
        iw_mapping = [](const expr_t &e) { return e; };
    }

    // Destination.
    dst_view = view_t({mb, g, oc, id, ih, iw, kd, kh, kw}, 6);
    dst_view.set_vdim(mb, prb_.mb);
    dst_view.set_vdim(g, prb_.g);
    dst_view.set_vdim(oc, prb_.oc);
    dst_view.set_vdim(id, prb_.id);
    dst_view.set_vdim(ih, prb_.ih);
    dst_view.set_vdim(iw, prb_.iw);
    dst_view.set_vdim(kd, prb_.kd);
    dst_view.set_vdim(kh, prb_.kh);
    dst_view.set_vdim(kw, prb_.kw);
    dst_view.set_tdim(0, mb);
    dst_view.set_tdim(1, g);
    dst_view.set_tdim(2, oc);

    // Output coordinate (before dividing by stride) for a given input point
    // and kernel offset: o * s = i - k * (1 + dilation) + padding.
    auto od = id - kd * (1 + prb_.dd) + prb_.pd;
    auto oh = ih - kh * (1 + prb_.dh) + prb_.ph;
    auto ow = iw_mapping(iw) - kw * (1 + prb_.dw) + prb_.pw;

    // When stride optimization is enabled, stride conditions are handled by
    // continue calls in the outer loops.
    if (!cfg_.bwd_d_optimize_strided_iw()) {
        od_mask &= (od % prb_.sd == 0);
        oh_mask &= (oh % prb_.sh == 0);
        ow_mask &= (ow % prb_.sw == 0);
    }
    dst_view.set_tdim(3, od / prb_.sd, od_mask);
    dst_view.set_tdim(4, oh / prb_.sh, oh_mask);
    dst_view.set_tdim(5, ow / prb_.sw, ow_mask);

    dst_view.set_tlayout(dst_layout);
    dst_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Weights.
    wei_view = view_t({g, oc, ic, kd, kh, kw}, 6);
    wei_view.set_vdim(g, prb_.g);
    wei_view.set_vdim(ic, prb_.ic);
    wei_view.set_vdim(oc, prb_.oc);
    wei_view.set_vdim(kd, prb_.kd);
    wei_view.set_vdim(kh, prb_.kh);
    wei_view.set_vdim(kw, prb_.kw);
    wei_view.set_tdim(0, g);
    wei_view.set_tdim(1, oc);
    wei_view.set_tdim(2, ic);
    wei_view.set_tdim(3, kd);
    wei_view.set_tdim(4, kh);
    wei_view.set_tdim(5, kw);
    wei_view.set_tlayout(wei_layout);
    wei_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Source.
    src_view = view_t({mb, g, ic, id, ih, iw}, 6);
    src_view.set_vdim(mb, prb_.mb);
    src_view.set_vdim(g, prb_.g);
    src_view.set_vdim(ic, prb_.ic);
    src_view.set_vdim(id, prb_.id);
    src_view.set_vdim(ih, prb_.ih);
    src_view.set_vdim(iw, prb_.iw);
    src_view.set_tdim(0, mb);
    src_view.set_tdim(1, g);
    src_view.set_tdim(2, ic);
    src_view.set_tdim(3, id);
    src_view.set_tdim(4, ih);
    src_view.set_tdim(5, iw_mapping(iw));
    src_view.set_tlayout(src_layout);
    src_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Initialize GEMM schedule.
    gemm_schedule.set_a_view(dst_view);
    gemm_schedule.set_b_view(wei_view);
    gemm_schedule.set_c_view(src_view);
    gemm_schedule.set_b_vars({g});
    gemm_schedule.set_m_vars({mb, id, ih, iw});
    gemm_schedule.set_n_vars({ic});
    gemm_schedule.set_k_vars({oc, kd, kh, kw});

    // Bound every variable by its padded (rounded-up) extent.
    gemm_schedule.for_each_var([&](const expr_t &var) {
        int bound = cfg_.padded_dim(var.as<var_t>().name);
        gemm_schedule.set_var_bound(var, bound);
    });

    auto g_tile = create_tile(gemm_schedule, cfg_, g);
    auto ic_tile = create_tile(gemm_schedule, cfg_, ic);
    auto mb_tile = create_tile(gemm_schedule, cfg_, mb);
    auto iw_tile = create_tile(gemm_schedule, cfg_, iw);
    auto oc_tile = create_tile(gemm_schedule, cfg_, oc);

    // Fuse groups + input spatial into one kernel grid index.
    auto g_isp_grid_idx = gemm_schedule.fuse(
            {g_tile.grid_idx(), id, ih, iw_tile.grid_idx()});
    auto mb_iw_tg_idx = gemm_schedule.fuse(mb_tile.tg_idx(), iw_tile.tg_idx());

    // Bind fused indices to the kernel grid and thread group grid.
    gemm_schedule.bind(ic_tile.grid_idx(), cfg_.kernel_grid().idx(0));
    gemm_schedule.bind(g_isp_grid_idx, cfg_.kernel_grid().idx(1));
    gemm_schedule.bind(mb_tile.grid_idx(), cfg_.kernel_grid().idx(2));
    gemm_schedule.bind(ic_tile.tg_idx(), cfg_.thread_group_grid().idx(0));
    gemm_schedule.bind(mb_iw_tg_idx, cfg_.thread_group_grid().idx(1));
    gemm_schedule.bind(oc_tile.tg_idx(), cfg_.thread_group_grid().idx(2));

    // Iteration-level indices become dimensions of the computed block.
    gemm_schedule.tensorize(g_tile.iter_idx());
    gemm_schedule.tensorize(ic_tile.iter_idx());
    gemm_schedule.tensorize(mb_tile.iter_idx());
    gemm_schedule.tensorize(iw_tile.iter_idx());
    gemm_schedule.tensorize(oc_tile.iter_idx());

    if (cfg_.bwd_d_optimize_strided()) {
        // Skip kernel offsets that do not land on a valid strided output.
        gemm_schedule.set_skip_condition(kd, od % prb_.sd != 0);
        gemm_schedule.set_skip_condition(kh, oh % prb_.sh != 0);
        if (cfg_.bwd_d_optimize_strided_iw())
            gemm_schedule.set_skip_condition(kw, ow % prb_.sw != 0);
        // Put kd/kh/kw outermost to allow pipelining in oc loop.
        gemm_schedule.reorder({kd, kh, kw, oc_tile.loop_idx()});
    } else {
        gemm_schedule.reorder({oc_tile.loop_idx(), kd, kh, kw});
    }

    src_buf = kernel_info_.find_arg("src");
    wei_buf = kernel_info_.find_arg("wei");
    dst_buf = kernel_info_.find_arg("dst");
}
483 | |
// Builds tensor views and the GEMM schedule for backward-by-weights
// convolution (computing the weights gradient, and optionally the bias
// gradient, from source and destination gradient). GEMM mapping:
// A = src, B = dst, C = wei with M = {ic, kw}, N = {oc},
// K = {mb, od, oh, ow}; g is a batch dim and kd/kh go to the kernel grid.
// `bia_reduction_condition` selects the single thread group slice that
// performs the bias reduction (to avoid duplicated accumulation).
void conv_ir_builder_t::init_bwd_w(gemm_schedule_t &gemm_schedule,
        view_t &src_view, view_t &dst_view, view_t &wei_view, view_t &bia_view,
        expr_t &src_buf, expr_t &dst_buf, expr_t &wei_buf, expr_t &bia_buf,
        expr_t &bia_reduction_condition) {
    auto &src_layout = cfg_.src_layout().compute();
    auto &wei_layout = cfg_.wei_layout().compute();
    auto &dst_layout = cfg_.dst_layout().compute();
    auto &bia_layout = cfg_.bia_layout().compute();

    // Initialize thread group views.
    auto g = var_t::make(type_t::s32(), "g");
    auto mb = var_t::make(type_t::s32(), "mb");
    auto ic = var_t::make(type_t::s32(), "ic");
    auto oc = var_t::make(type_t::s32(), "oc");
    auto od = var_t::make(type_t::s32(), "od");
    auto oh = var_t::make(type_t::s32(), "oh");
    auto ow = var_t::make(type_t::s32(), "ow");
    auto kd = var_t::make(type_t::s32(), "kd");
    auto kh = var_t::make(type_t::s32(), "kh");
    auto kw = var_t::make(type_t::s32(), "kw");

    // Initialize masks. Masks start as "true" (no check) and lower/upper
    // bound checks are added independently below.
    expr_t id_mask(true), ih_mask(true), iw_mask(true);

    bool check_ow = (prb_.ow < cfg_.padded_dim("ow"));
    bool check_oh = (prb_.oh < cfg_.padded_dim("oh"));
    bool check_od = (prb_.od < cfg_.padded_dim("od"));
    bool check_kw = (prb_.kw < cfg_.padded_dim("kw"));
    // Input accesses use the forward-direction index formula here, hence
    // is_fwd=true regardless of propagation kind.
    bool check_iw = check_kw
            || need_src_or_dst_check(/*is_fwd=*/true, prb_.ow, prb_.iw, prb_.kw,
                    prb_.pw, prb_.sw, prb_.dw);
    bool check_ih = need_src_or_dst_check(/*is_fwd=*/true, prb_.oh, prb_.ih,
            prb_.kh, prb_.ph, prb_.sh, prb_.dh);
    bool check_id = need_src_or_dst_check(/*is_fwd=*/true, prb_.od, prb_.id,
            prb_.kd, prb_.pd, prb_.sd, prb_.dd);
    // Lower bound depends only on the window; upper bound also on padded
    // output dimensions.
    bool check_iw_min = check_iw;
    bool check_ih_min = check_ih;
    bool check_id_min = check_id;
    bool check_iw_max = (check_iw || check_ow);
    bool check_ih_max = (check_ih || check_oh);
    bool check_id_max = (check_id || check_od);

    // Masks are expressed over the view placeholder variable.
    auto &x = view_t::placeholder_var();
    if (check_id_min) id_mask &= (x >= 0);
    if (check_ih_min) ih_mask &= (x >= 0);
    if (check_iw_min) iw_mask &= (x >= 0);
    if (check_id_max) id_mask &= (x < prb_.id);
    if (check_ih_max) ih_mask &= (x < prb_.ih);
    if (check_iw_max) iw_mask &= (x < prb_.iw);

    // Source.
    src_view = view_t({mb, g, ic, od, oh, ow, kw}, 6);
    src_view.set_vdim(mb, prb_.mb);
    src_view.set_vdim(g, prb_.g);
    src_view.set_vdim(ic, prb_.ic);
    src_view.set_vdim(od, prb_.od);
    src_view.set_vdim(oh, prb_.oh);
    src_view.set_vdim(ow, prb_.ow);
    src_view.set_vdim(kw, prb_.kw);
    src_view.set_tdim(0, mb);
    src_view.set_tdim(1, g);
    src_view.set_tdim(2, ic);
    // Input spatial index: i = o * stride - padding + k * (1 + dilation).
    src_view.set_tdim(3, od * prb_.sd - prb_.pd + kd * (1 + prb_.dd), id_mask);
    src_view.set_tdim(4, oh * prb_.sh - prb_.ph + kh * (1 + prb_.dh), ih_mask);
    src_view.set_tdim(5, ow * prb_.sw - prb_.pw + kw * (1 + prb_.dw), iw_mask);
    src_view.set_tlayout(src_layout);
    src_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Weights.
    wei_view = view_t({g, oc, ic, kd, kh, kw}, 6);
    wei_view.set_vdim(g, prb_.g);
    wei_view.set_vdim(oc, prb_.oc);
    wei_view.set_vdim(ic, prb_.ic);
    wei_view.set_vdim(kd, prb_.kd);
    wei_view.set_vdim(kh, prb_.kh);
    wei_view.set_vdim(kw, prb_.kw);
    wei_view.set_tdim(0, g);
    wei_view.set_tdim(1, oc);
    wei_view.set_tdim(2, ic);
    wei_view.set_tdim(3, kd);
    wei_view.set_tdim(4, kh);
    wei_view.set_tdim(5, kw);
    wei_view.set_tlayout(wei_layout);
    wei_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Destination.
    dst_view = view_t({mb, g, oc, od, oh, ow}, 6);
    dst_view.set_vdim(mb, prb_.mb);
    dst_view.set_vdim(g, prb_.g);
    dst_view.set_vdim(oc, prb_.oc);
    dst_view.set_vdim(od, prb_.od);
    dst_view.set_vdim(oh, prb_.oh);
    dst_view.set_vdim(ow, prb_.ow);
    dst_view.set_tdim(0, mb);
    dst_view.set_tdim(1, g);
    dst_view.set_tdim(2, oc);
    dst_view.set_tdim(3, od);
    dst_view.set_tdim(4, oh);
    dst_view.set_tdim(5, ow);
    dst_view.set_tlayout(dst_layout);
    dst_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());

    // Bias.
    if (prb_.with_bias) {
        bia_view = view_t({g, oc}, 2);
        bia_view.set_vdim(g, prb_.g);
        bia_view.set_vdim(oc, prb_.oc);
        bia_view.set_tdim(0, g);
        bia_view.set_tdim(1, oc);
        bia_view.set_tlayout(bia_layout);
        bia_view.set_tmasks(cfg_.padded_dims().get(), cfg_.iter_dims().get());
    }

    // Initialize GEMM schedule.
    gemm_schedule.set_a_view(src_view);
    gemm_schedule.set_b_view(dst_view);
    gemm_schedule.set_c_view(wei_view);
    gemm_schedule.set_b_vars({g});
    gemm_schedule.set_m_vars({ic, kw});
    gemm_schedule.set_n_vars({oc});
    gemm_schedule.set_k_vars({mb, od, oh, ow});

    // Bound every variable by its padded (rounded-up) extent.
    gemm_schedule.for_each_var([&](const expr_t &var) {
        int bound = cfg_.padded_dim(var.as<var_t>().name);
        gemm_schedule.set_var_bound(var, bound);
    });

    auto g_tile = create_tile(gemm_schedule, cfg_, g);
    auto mb_tile = create_tile(gemm_schedule, cfg_, mb);
    auto ic_tile = create_tile(gemm_schedule, cfg_, ic);
    auto oc_tile = create_tile(gemm_schedule, cfg_, oc);
    auto od_tile = create_tile(gemm_schedule, cfg_, od);
    auto oh_tile = create_tile(gemm_schedule, cfg_, oh);
    auto ow_tile = create_tile(gemm_schedule, cfg_, ow);
    auto kw_tile = create_tile(gemm_schedule, cfg_, kw);

    // Output spatial, kernel spatial and ic grid parts are fused into one
    // kernel grid index (reduction is split across thread groups).
    auto osp_ksp_ic_grid_idx = gemm_schedule.fuse(
            {od_tile.grid_idx(), oh_tile.grid_idx(), ow_tile.grid_idx(), kd, kh,
                    kw_tile.grid_idx(), ic_tile.grid_idx()});

    auto g_mb_grid_idx
            = gemm_schedule.fuse({g_tile.grid_idx(), mb_tile.grid_idx()});

    gemm_schedule.bind(oc_tile.grid_idx(), cfg_.kernel_grid().idx(0));
    gemm_schedule.bind(osp_ksp_ic_grid_idx, cfg_.kernel_grid().idx(1));
    gemm_schedule.bind(g_mb_grid_idx, cfg_.kernel_grid().idx(2));

    gemm_schedule.bind(oc_tile.tg_idx(), cfg_.thread_group_grid().idx(0));
    gemm_schedule.bind(ic_tile.tg_idx(), cfg_.thread_group_grid().idx(1));

    // Reduction loop order: spatial outermost, minibatch innermost.
    gemm_schedule.reorder({od_tile.loop_idx(), oh_tile.loop_idx(),
            ow_tile.loop_idx(), mb_tile.loop_idx()});

    gemm_schedule.unroll(mb_tile.loop_idx(), cfg_.unroll("mb"));
    gemm_schedule.unroll(ow_tile.loop_idx(), cfg_.unroll("ow"));

    // Iteration-level indices become dimensions of the computed block.
    gemm_schedule.tensorize(g_tile.iter_idx());
    gemm_schedule.tensorize(oc_tile.iter_idx());
    gemm_schedule.tensorize(ic_tile.iter_idx());
    gemm_schedule.tensorize(mb_tile.iter_idx());
    gemm_schedule.tensorize(ow_tile.iter_idx());
    gemm_schedule.tensorize(kw_tile.iter_idx());

    src_buf = kernel_info_.find_arg("src");
    wei_buf = kernel_info_.find_arg("wei");
    dst_buf = kernel_info_.find_arg("dst");

    if (prb_.with_bias) {
        bia_buf = kernel_info_.find_arg("bia");
        // Only one slice along each reduced kernel/grid dimension may
        // accumulate the bias, otherwise it would be counted multiple times.
        bia_reduction_condition = expr_t(true);
        if (prb_.kd > 1) bia_reduction_condition &= (kd == 0);
        if (prb_.kh > 1) bia_reduction_condition &= (kh == 0);
        if (prb_.kw > 1) bia_reduction_condition &= (kw_tile.grid_idx() == 0);
        if (cfg_.grid_dim("ic") > 1)
            bia_reduction_condition &= (ic_tile.grid_idx() == 0);
        if (!cfg_.slm().b() && cfg_.thread_group_grid().dim(1) > 1) {
            bia_reduction_condition &= (cfg_.thread_group_grid().idx(1) == 0);
        }
    }
}
664 | |
665 | } // namespace jit |
666 | } // namespace gpu |
667 | } // namespace impl |
668 | } // namespace dnnl |
669 | |