1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/ir/core.hpp" |
18 | |
19 | #include <algorithm> |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace gpu { |
24 | namespace jit { |
25 | |
// Constant-folding entry points used by the expression builders below.
// Only declared here, not defined; presumably they live in another IR
// translation unit (TODO confirm).
// - const_fold_non_recursive: folds only the top-level node of `expr`.
// - const_fold: folds `obj` (presumably recursively, per the name).
expr_t const_fold_non_recursive(const expr_t &expr);
object_t const_fold(const object_t &obj);
28 | |
29 | std::string to_string(type_kind_t kind) { |
30 | #define CASE(_kind) \ |
31 | case type_kind_t::_kind: return #_kind |
32 | switch (kind) { |
33 | CASE(undef); |
34 | CASE(u8); |
35 | CASE(s8); |
36 | CASE(u16); |
37 | CASE(s16); |
38 | CASE(u32); |
39 | CASE(s32); |
40 | CASE(u64); |
41 | CASE(s64); |
42 | CASE(bf16); |
43 | CASE(f16); |
44 | CASE(tf32); |
45 | CASE(f32); |
46 | CASE(f64); |
47 | CASE(byte); |
48 | CASE(dword); |
49 | CASE(qword); |
50 | CASE(oword); |
51 | CASE(hword); |
52 | case type_kind_t::_bool: return "bool" ; |
53 | default: ir_error_not_expected(); |
54 | } |
55 | #undef CASE |
56 | return {}; |
57 | } |
58 | |
59 | int type_t::size() const { |
60 | if (is_ptr()) return sizeof(uint64_t); |
61 | |
62 | if (is_bool()) return utils::div_up(elems(), 8); |
63 | |
64 | if (elems() != 1) return elems() * scalar().size(); |
65 | |
66 | switch (kind()) { |
67 | case type_kind_t::u8: |
68 | case type_kind_t::s8: |
69 | case type_kind_t::byte: return 1; |
70 | case type_kind_t::u16: |
71 | case type_kind_t::s16: |
72 | case type_kind_t::bf16: |
73 | case type_kind_t::f16: return 2; |
74 | case type_kind_t::u32: |
75 | case type_kind_t::s32: |
76 | case type_kind_t::tf32: |
77 | case type_kind_t::f32: |
78 | case type_kind_t::dword: return 4; |
79 | case type_kind_t::f64: |
80 | case type_kind_t::u64: |
81 | case type_kind_t::s64: |
82 | case type_kind_t::qword: return 8; |
83 | case type_kind_t::oword: return 16; |
84 | case type_kind_t::hword: return 32; |
85 | default: ir_error_not_expected(); |
86 | } |
87 | return 0; |
88 | } |
89 | |
90 | data_type_t to_dnnl(const type_t &type) { |
91 | ir_assert(type.elems() == 1) << type; |
92 | ir_assert(!type.is_ptr() == 1) << type; |
93 | switch (type.kind()) { |
94 | case type_kind_t::bf16: return data_type::bf16; |
95 | case type_kind_t::f16: return data_type::f16; |
96 | case type_kind_t::tf32: return data_type::tf32; |
97 | case type_kind_t::f32: return data_type::f32; |
98 | case type_kind_t::f64: return data_type::f64; |
99 | case type_kind_t::s32: return data_type::s32; |
100 | case type_kind_t::s8: return data_type::s8; |
101 | case type_kind_t::u8: return data_type::u8; |
102 | default: ir_error_not_expected(); |
103 | } |
104 | return data_type::undef; |
105 | } |
106 | |
107 | std::string to_string(op_kind_t kind) { |
108 | switch (kind) { |
109 | case op_kind_t::_minus: return "-" ; |
110 | |
111 | case op_kind_t::_add: return "+" ; |
112 | case op_kind_t::_sub: return "-" ; |
113 | case op_kind_t::_mul: return "*" ; |
114 | case op_kind_t::_div: return "/" ; |
115 | case op_kind_t::_mod: return "%" ; |
116 | case op_kind_t::_shl: return "<<" ; |
117 | case op_kind_t::_shr: return ">>" ; |
118 | case op_kind_t::_min: return "min" ; |
119 | case op_kind_t::_max: return "max" ; |
120 | |
121 | case op_kind_t::_lt: return "<" ; |
122 | case op_kind_t::_le: return "<=" ; |
123 | case op_kind_t::_gt: return ">" ; |
124 | case op_kind_t::_ge: return ">=" ; |
125 | case op_kind_t::_eq: return "==" ; |
126 | case op_kind_t::_ne: return "!=" ; |
127 | |
128 | case op_kind_t::_and: return "&&" ; |
129 | |
130 | case op_kind_t::_add3: return "add3" ; |
131 | case op_kind_t::_mad: return "mad" ; |
132 | case op_kind_t::_prelu: return "prelu" ; |
133 | |
134 | case op_kind_t::_dp4a: return "dp4a" ; |
135 | |
136 | default: ir_error_not_expected() << "Unknown op_kind_t value." ; |
137 | } |
138 | return "" ; |
139 | } |
140 | |
141 | bool is_cmp_op(op_kind_t op_kind) { |
142 | switch (op_kind) { |
143 | case op_kind_t::_ge: |
144 | case op_kind_t::_gt: |
145 | case op_kind_t::_le: |
146 | case op_kind_t::_lt: |
147 | case op_kind_t::_eq: |
148 | case op_kind_t::_ne: return true; |
149 | default: return false; |
150 | } |
151 | } |
152 | |
153 | op_kind_t negate_cmp_op(op_kind_t op_kind) { |
154 | switch (op_kind) { |
155 | case op_kind_t::_ge: return op_kind_t::_le; |
156 | case op_kind_t::_gt: return op_kind_t::_lt; |
157 | case op_kind_t::_le: return op_kind_t::_ge; |
158 | case op_kind_t::_lt: return op_kind_t::_gt; |
159 | case op_kind_t::_eq: return op_kind_t::_eq; |
160 | case op_kind_t::_ne: return op_kind_t::_ne; |
161 | default: ir_error_not_expected(); |
162 | } |
163 | return op_kind_t::undef; |
164 | } |
165 | |
166 | type_t unary_op_type(op_kind_t op_kind, const expr_t &a) { |
167 | switch (op_kind) { |
168 | case op_kind_t::_minus: { |
169 | auto &t = a.type(); |
170 | if (!t.is_int()) return t; |
171 | if (t.size() < int(sizeof(int32_t))) return type_t::s32(t.elems()); |
172 | return t; |
173 | } |
174 | default: |
175 | ir_error_not_expected() << "Unknown op_kind_t value: " << op_kind; |
176 | } |
177 | return type_t::undef(); |
178 | } |
179 | |
180 | type_t common_int_type(const type_t &_a, const type_t &_b) { |
181 | ir_assert(_a.is_int() && _b.is_int()) << "Unexpected types." ; |
182 | |
183 | int elems = _a.elems(); |
184 | |
185 | // Promote to s32 first. |
186 | type_t a = _a.size() < int(sizeof(int32_t)) ? type_t::s32() : _a; |
187 | type_t b = _b.size() < int(sizeof(int32_t)) ? type_t::s32() : _b; |
188 | a = a.scalar(); |
189 | b = b.scalar(); |
190 | |
191 | // Integer promotion, follow C++ rules. |
192 | int common_bits = 8 * std::max(a.size(), b.size()); |
193 | if (a.is_signed() == b.is_signed()) { |
194 | if (a.is_signed()) return type_t::s(common_bits, elems); |
195 | return type_t::u(common_bits, elems); |
196 | } |
197 | |
198 | if (a.size() >= b.size() && a.is_unsigned()) |
199 | return type_t::u(common_bits, elems); |
200 | if (b.size() >= a.size() && b.is_unsigned()) |
201 | return type_t::u(common_bits, elems); |
202 | if (a.size() > b.size() && a.is_signed()) |
203 | return type_t::s(common_bits, elems); |
204 | if (b.size() > a.size() && b.is_signed()) |
205 | return type_t::s(common_bits, elems); |
206 | |
207 | return type_t::u(common_bits, elems); |
208 | } |
209 | |
210 | type_t common_type(const type_t &a, const type_t &b) { |
211 | ir_assert(a.elems() == b.elems()) |
212 | << "Types must have the same number of components." ; |
213 | if (a.is_undef() || b.is_undef()) return type_t::undef(); |
214 | if (a.is_fp() && !b.is_fp()) return a; |
215 | if (!a.is_fp() && b.is_fp()) return b; |
216 | if (a.is_fp() && b.is_fp()) return (a.size() > b.size() ? a : b); |
217 | if (a.is_bool() && b.is_bool()) return a; |
218 | return common_int_type(a, b); |
219 | } |
220 | |
221 | type_t common_type(const expr_t &a, const expr_t &b) { |
222 | return common_type(a.type(), b.type()); |
223 | } |
224 | |
225 | type_t binary_op_type(op_kind_t op_kind, const type_t &a, const type_t &b, |
226 | const expr_t &a_expr = expr_t(), const expr_t &b_expr = expr_t()) { |
227 | if (a.is_undef() || b.is_undef()) return type_t::undef(); |
228 | ir_assert(a.elems() == b.elems()) |
229 | << "Types must have the same number of components." ; |
230 | if (is_cmp_op(op_kind)) return type_t::_bool(a.elems()); |
231 | if (utils::one_of(op_kind, op_kind_t::_shl, op_kind_t::_shr)) { |
232 | ir_assert(a.is_unsigned()) |
233 | << "a must be unsigned for shift left/right." ; |
234 | return type_t::u32(a.elems()); |
235 | } |
236 | if (op_kind == op_kind_t::_and) { |
237 | if (a == b) return a; |
238 | if (is_const(a_expr)) return b; |
239 | if (is_const(b_expr)) return a; |
240 | return (a.size() >= b.size()) ? a : b; |
241 | } |
242 | return common_type(a, b); |
243 | } |
244 | |
245 | type_t binary_op_type(op_kind_t op_kind, const expr_t &a, const expr_t &b) { |
246 | return binary_op_type(op_kind, a.type(), b.type(), a, b); |
247 | } |
248 | |
249 | type_t ternary_op_type( |
250 | op_kind_t op_kind, const expr_t &a, const expr_t &b, const expr_t &c) { |
251 | switch (op_kind) { |
252 | case op_kind_t::_add3: |
253 | return binary_op_type(op_kind_t::_add, a.type(), |
254 | binary_op_type(op_kind_t::_add, b, c)); |
255 | case op_kind_t::_mad: |
256 | return binary_op_type(op_kind_t::_add, a.type(), |
257 | binary_op_type(op_kind_t::_mul, b, c)); |
258 | default: ir_error_not_expected(); |
259 | } |
260 | return type_t::undef(); |
261 | } |
262 | |
263 | type_t nary_op_type(op_kind_t op_kind, const std::vector<expr_t> &args) { |
264 | ir_assert(!args.empty()); |
265 | if (args.size() == 1) return args[0].type(); |
266 | |
267 | auto type = args[0].type(); |
268 | for (size_t i = 1; i < args.size(); i++) |
269 | type = common_type(type, args[i].type()); |
270 | |
271 | return type; |
272 | } |
273 | |
274 | void ptr_t::normalize(expr_t &base, expr_t &off, op_kind_t op_kind) { |
275 | ir_assert(base.type().is_ptr()) << "base is not a pointer: " << base; |
276 | ir_assert(off.type().is_int()) << "off is not an integer: " << off; |
277 | ir_assert(utils::one_of(op_kind, op_kind_t::_add, op_kind_t::_sub)) |
278 | << "Can't apply this operation to pointer: " << to_string(op_kind); |
279 | |
280 | if (!base.is<ptr_t>()) { |
281 | if (op_kind == op_kind_t::_sub) off = const_fold(-off); |
282 | return; |
283 | } |
284 | |
285 | auto &base_off = base.as<ptr_t>().off; |
286 | base = base.as<ptr_t>().base; |
287 | off = const_fold_non_recursive(binary_op_t::make(op_kind, base_off, off)); |
288 | } |
289 | |
290 | expr_t shift_ptr(op_kind_t op_kind, const expr_t &a, const expr_t &b) { |
291 | expr_t base = a; |
292 | expr_t off = b; |
293 | ptr_t::normalize(base, off, op_kind); |
294 | return ptr_t::make(base, off); |
295 | } |
296 | |
297 | void normalize_ptr(const type_t &type, expr_t &base_expr, expr_t &off) { |
298 | if (base_expr.is<ptr_t>()) { |
299 | auto &base = base_expr.as<ptr_t>().base; |
300 | auto &base_off = base_expr.as<ptr_t>().off; |
301 | |
302 | base_expr = base; |
303 | off = const_fold_non_recursive(base_off + off); |
304 | } |
305 | ir_assert(to_cpp<int64_t>(off) % type.scalar().size() == 0) |
306 | << "Incompatible offset: " << off; |
307 | } |
308 | |
309 | expr_t expr_t::operator[](const expr_t &off) const { |
310 | if (is<shuffle_t>()) { |
311 | ir_assert(is_const(off)) << "Offset is not constant." ; |
312 | auto &shuffle = as<shuffle_t>(); |
313 | int idx = shuffle.idx[to_cpp<int>(off)]; |
314 | return shuffle.vec[idx]; |
315 | } |
316 | return shift_ptr(op_kind_t::_add, *this, off); |
317 | } |
318 | |
// Constructors wrapping C++ scalar values into immediate IR expressions:
// bool -> bool_imm_t; float -> float_imm_t (type chosen by float_imm_t's
// constructor); double -> float_imm_t tagged explicitly as f64; 16/32/64-bit
// signed and unsigned integers -> int_imm_t (type chosen by int_imm_t's
// constructor overload for each argument type).
expr_t::expr_t(bool value) : object_t(new bool_imm_t(value)) {}
expr_t::expr_t(float value) : object_t(new float_imm_t(value)) {}
expr_t::expr_t(double value)
    : object_t(new float_imm_t(value, type_t::f64())) {}
expr_t::expr_t(int16_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(int32_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(int64_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(uint16_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(uint32_t value) : object_t(new int_imm_t(value)) {}
expr_t::expr_t(uint64_t value) : object_t(new int_imm_t(value)) {}
329 | |
330 | expr_t operator-(const expr_t &a) { |
331 | return const_fold_non_recursive(unary_op_t::make(op_kind_t::_minus, a)); |
332 | } |
333 | |
334 | #define DEFINE_BINARY_OPERATOR(op, op_kind) \ |
335 | expr_t operator op(const expr_t &a, const expr_t &b) { \ |
336 | if (a.type().is_ptr()) return shift_ptr(op_kind, a, b); \ |
337 | return const_fold_non_recursive(binary_op_t::make(op_kind, a, b)); \ |
338 | } |
339 | |
340 | DEFINE_BINARY_OPERATOR(+, op_kind_t::_add) |
341 | DEFINE_BINARY_OPERATOR(-, op_kind_t::_sub) |
342 | DEFINE_BINARY_OPERATOR(*, op_kind_t::_mul) |
343 | DEFINE_BINARY_OPERATOR(/, op_kind_t::_div) |
344 | DEFINE_BINARY_OPERATOR(%, op_kind_t::_mod) |
345 | DEFINE_BINARY_OPERATOR(<<, op_kind_t::_shl) |
346 | DEFINE_BINARY_OPERATOR(>>, op_kind_t::_shr) |
347 | |
348 | DEFINE_BINARY_OPERATOR(==, op_kind_t::_eq) |
349 | DEFINE_BINARY_OPERATOR(!=, op_kind_t::_ne) |
350 | DEFINE_BINARY_OPERATOR(>, op_kind_t::_gt) |
351 | DEFINE_BINARY_OPERATOR(>=, op_kind_t::_ge) |
352 | DEFINE_BINARY_OPERATOR(<, op_kind_t::_lt) |
353 | DEFINE_BINARY_OPERATOR(<=, op_kind_t::_le) |
354 | |
355 | DEFINE_BINARY_OPERATOR(&, op_kind_t::_and) |
356 | |
357 | #undef DEFINE_BINARY_OPERATOR |
358 | |
359 | #define DEFINE_BINARY_ASSIGN_OPERATOR(op) \ |
360 | expr_t &expr_t::operator op##=(const expr_t &rhs) { \ |
361 | auto tmp = (*this)op rhs; \ |
362 | *this = tmp; \ |
363 | return *this; \ |
364 | } |
365 | |
366 | DEFINE_BINARY_ASSIGN_OPERATOR(+) |
367 | DEFINE_BINARY_ASSIGN_OPERATOR(-) |
368 | DEFINE_BINARY_ASSIGN_OPERATOR(*) |
369 | DEFINE_BINARY_ASSIGN_OPERATOR(/) |
370 | DEFINE_BINARY_ASSIGN_OPERATOR(%) |
371 | DEFINE_BINARY_ASSIGN_OPERATOR(&) |
372 | |
373 | #undef DEFINE_BINARY_ASSIGN_OPERATOR |
374 | |
375 | object_t object_impl_t::_mutate(ir_mutator_t &mutator) const { |
376 | return *this; |
377 | } |
378 | void object_impl_t::_visit(ir_visitor_t &visitor) const {} |
379 | |
380 | #define DECL_TRAVERSE_LEAF(name) \ |
381 | object_t ir_mutator_t::_mutate(const name &obj) { return obj; } \ |
382 | void ir_visitor_t::_visit(const name &obj) {} |
383 | |
384 | DECL_TRAVERSE_LEAF(bool_imm_t) |
385 | DECL_TRAVERSE_LEAF(float_imm_t) |
386 | DECL_TRAVERSE_LEAF(func_impl_t) |
387 | DECL_TRAVERSE_LEAF(int_imm_t) |
388 | DECL_TRAVERSE_LEAF(var_t) |
389 | |
390 | #undef DECL_TRAVERSE_LEAF |
391 | |
392 | object_t ir_mutator_t::_mutate(const alloc_t &obj) { |
393 | auto buf = mutate(obj.buf); |
394 | auto body = mutate(obj.body); |
395 | |
396 | if (buf.is_same(obj.buf) && body.is_same(obj.body)) return obj; |
397 | |
398 | return alloc_t::make(buf, obj.size, obj.kind, obj.attrs, body); |
399 | } |
400 | |
401 | void ir_visitor_t::_visit(const alloc_t &obj) { |
402 | visit(obj.buf); |
403 | visit(obj.body); |
404 | } |
405 | |
406 | object_t ir_mutator_t::_mutate(const binary_op_t &obj) { |
407 | auto a = mutate(obj.a); |
408 | auto b = mutate(obj.b); |
409 | |
410 | if (a.is_same(obj.a) && b.is_same(obj.b)) return obj; |
411 | |
412 | return binary_op_t::make(obj.op_kind, a, b); |
413 | } |
414 | |
415 | void ir_visitor_t::_visit(const binary_op_t &obj) { |
416 | visit(obj.a); |
417 | visit(obj.b); |
418 | } |
419 | |
420 | object_t ir_mutator_t::_mutate(const cast_t &obj) { |
421 | auto expr = mutate(obj.expr); |
422 | |
423 | if (expr.is_same(obj.expr)) return obj; |
424 | |
425 | return cast_t::make(obj.type, expr, obj.saturate); |
426 | } |
427 | |
428 | void ir_visitor_t::_visit(const cast_t &obj) { |
429 | visit(obj.expr); |
430 | } |
431 | |
432 | object_t ir_mutator_t::_mutate(const for_t &obj) { |
433 | auto var = mutate(obj.var); |
434 | auto init = mutate(obj.init); |
435 | auto bound = mutate(obj.bound); |
436 | auto body = mutate(obj.body); |
437 | |
438 | if (var.is_same(obj.var) && init.is_same(obj.init) |
439 | && bound.is_same(obj.bound) && body.is_same(obj.body)) |
440 | return obj; |
441 | |
442 | return for_t::make(var, init, bound, body, obj.unroll); |
443 | } |
444 | |
445 | void ir_visitor_t::_visit(const for_t &obj) { |
446 | visit(obj.var); |
447 | visit(obj.init); |
448 | visit(obj.bound); |
449 | visit(obj.body); |
450 | } |
451 | |
452 | object_t ir_mutator_t::_mutate(const func_call_t &obj) { |
453 | auto func = mutate(obj.func); |
454 | auto args = mutate(obj.args); |
455 | |
456 | if (func.is_same(obj.func) && ir_utils::is_same(args, obj.args)) return obj; |
457 | |
458 | return func_call_t::make(func, args, obj.attr); |
459 | } |
460 | |
461 | void ir_visitor_t::_visit(const func_call_t &obj) { |
462 | visit(obj.func); |
463 | visit(obj.args); |
464 | } |
465 | |
466 | object_t ir_mutator_t::_mutate(const if_t &obj) { |
467 | auto cond = mutate(obj.cond); |
468 | auto body = mutate(obj.body); |
469 | auto else_body = mutate(obj.else_body); |
470 | |
471 | if (cond.is_same(obj.cond) && body.is_same(obj.body) |
472 | && else_body.is_same(obj.else_body)) |
473 | return obj; |
474 | |
475 | return if_t::make(cond, body, else_body); |
476 | } |
477 | |
478 | void ir_visitor_t::_visit(const if_t &obj) { |
479 | visit(obj.cond); |
480 | visit(obj.body); |
481 | visit(obj.else_body); |
482 | } |
483 | |
484 | object_t ir_mutator_t::_mutate(const iif_t &obj) { |
485 | auto cond = mutate(obj.cond); |
486 | auto true_expr = mutate(obj.true_expr); |
487 | auto false_expr = mutate(obj.false_expr); |
488 | |
489 | if (cond.is_same(obj.cond) && true_expr.is_same(obj.true_expr) |
490 | && false_expr.is_same(obj.false_expr)) |
491 | return obj; |
492 | |
493 | return iif_t::make(cond, true_expr, false_expr); |
494 | } |
495 | |
496 | void ir_visitor_t::_visit(const iif_t &obj) { |
497 | visit(obj.cond); |
498 | visit(obj.true_expr); |
499 | visit(obj.false_expr); |
500 | } |
501 | |
502 | object_t ir_mutator_t::_mutate(const let_t &obj) { |
503 | auto var = mutate(obj.var); |
504 | auto value = mutate(obj.value); |
505 | auto body = mutate(obj.body); |
506 | |
507 | if (var.is_same(obj.var) && value.is_same(obj.value) |
508 | && body.is_same(obj.body)) |
509 | return obj; |
510 | |
511 | return let_t::make(var, value, body); |
512 | } |
513 | |
514 | void ir_visitor_t::_visit(const let_t &obj) { |
515 | visit(obj.var); |
516 | visit(obj.value); |
517 | visit(obj.body); |
518 | } |
519 | |
520 | object_t ir_mutator_t::_mutate(const load_t &obj) { |
521 | auto buf = mutate(obj.buf); |
522 | auto off = mutate(obj.off); |
523 | |
524 | if (buf.is_same(obj.buf) && off.is_same(obj.off)) return obj; |
525 | |
526 | return load_t::make(obj.type, buf, off, obj.stride); |
527 | } |
528 | |
529 | void ir_visitor_t::_visit(const load_t &obj) { |
530 | visit(obj.buf); |
531 | visit(obj.off); |
532 | } |
533 | |
534 | object_t ir_mutator_t::_mutate(const ptr_t &obj) { |
535 | auto base = mutate(obj.base); |
536 | auto off = mutate(obj.off); |
537 | |
538 | if (base.is_same(obj.base) && off.is_same(obj.off)) return obj; |
539 | |
540 | return ptr_t::make(base, off); |
541 | } |
542 | |
543 | void ir_visitor_t::_visit(const ptr_t &obj) { |
544 | visit(obj.base); |
545 | visit(obj.off); |
546 | } |
547 | |
548 | object_t ir_mutator_t::_mutate(const shuffle_t &obj) { |
549 | auto vec = mutate(obj.vec); |
550 | |
551 | if (ir_utils::is_same(vec, obj.vec)) return obj; |
552 | |
553 | return shuffle_t::make(vec, obj.idx); |
554 | } |
555 | |
556 | void ir_visitor_t::_visit(const shuffle_t &obj) { |
557 | visit(obj.vec); |
558 | } |
559 | |
560 | object_t ir_mutator_t::_mutate(const stmt_group_t &obj) { |
561 | auto body = mutate(obj.body); |
562 | |
563 | if (body.is_same(obj.body)) return obj; |
564 | |
565 | return stmt_group_t::make(obj.label, body); |
566 | } |
567 | |
568 | void ir_visitor_t::_visit(const stmt_group_t &obj) { |
569 | visit(obj.body); |
570 | } |
571 | |
572 | object_t ir_mutator_t::_mutate(const stmt_seq_t &obj) { |
573 | auto head = mutate(obj.head); |
574 | auto tail = mutate(obj.tail); |
575 | |
576 | if (head.is_same(obj.head) && tail.is_same(obj.tail)) return obj; |
577 | |
578 | return stmt_seq_t::make(head, tail); |
579 | } |
580 | |
581 | void ir_visitor_t::_visit(const stmt_seq_t &obj) { |
582 | visit(obj.head); |
583 | visit(obj.tail); |
584 | } |
585 | |
586 | object_t ir_mutator_t::_mutate(const store_t &obj) { |
587 | auto buf = mutate(obj.buf); |
588 | auto off = mutate(obj.off); |
589 | auto value = mutate(obj.value); |
590 | auto mask = mutate(obj.mask); |
591 | |
592 | if (buf.is_same(obj.buf) && off.is_same(obj.off) && value.is_same(obj.value) |
593 | && mask.is_same(obj.mask)) |
594 | return obj; |
595 | |
596 | return store_t::make(buf, off, value, obj.stride, mask, obj.fill_mask0); |
597 | } |
598 | |
599 | void ir_visitor_t::_visit(const store_t &obj) { |
600 | visit(obj.buf); |
601 | visit(obj.off); |
602 | visit(obj.value); |
603 | visit(obj.mask); |
604 | } |
605 | |
606 | object_t ir_mutator_t::_mutate(const ternary_op_t &obj) { |
607 | auto a = mutate(obj.a); |
608 | auto b = mutate(obj.b); |
609 | auto c = mutate(obj.c); |
610 | |
611 | if (a.is_same(obj.a) && b.is_same(obj.b) && c.is_same(obj.c)) return obj; |
612 | |
613 | return ternary_op_t::make(obj.op_kind, a, b, c); |
614 | } |
615 | |
616 | void ir_visitor_t::_visit(const ternary_op_t &obj) { |
617 | visit(obj.a); |
618 | visit(obj.b); |
619 | visit(obj.c); |
620 | } |
621 | |
622 | object_t ir_mutator_t::_mutate(const unary_op_t &obj) { |
623 | auto a = mutate(obj.a); |
624 | if (a.is_same(obj.a)) return obj; |
625 | return unary_op_t::make(obj.op_kind, a); |
626 | } |
627 | |
628 | void ir_visitor_t::_visit(const unary_op_t &obj) { |
629 | visit(obj.a); |
630 | } |
631 | |
632 | // Catch missing mutates that are not expected to dispatch to the base |
633 | // mutator |
634 | object_t ir_mutator_t::_mutate(const nary_op_t &obj) { |
635 | ir_error_not_expected() << "Can't handle type: nary_op_t" ; |
636 | return {}; |
637 | } |
638 | void ir_visitor_t::_visit(const nary_op_t &obj) { |
639 | ir_error_not_expected() << "Can't handle type: nary_op_t" ; |
640 | } |
641 | object_t ir_mutator_t::_mutate(const pexpr_t &obj) { |
642 | ir_error_not_expected() << "Can't handle type: pexpr_t" ; |
643 | return {}; |
644 | } |
645 | void ir_visitor_t::_visit(const pexpr_t &obj) { |
646 | ir_error_not_expected() << "Can't handle type: pexpr_t" ; |
647 | } |
648 | |
649 | } // namespace jit |
650 | } // namespace gpu |
651 | } // namespace impl |
652 | } // namespace dnnl |
653 | |