1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/peephole.hpp" |
18 | |
19 | #include "gpu/jit/pass/simplify.hpp" |
20 | #include "gpu/jit/utils/trace.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace jit { |
26 | |
27 | class peephole_optimizer_t : public ir_mutator_t { |
28 | public: |
29 | object_t _mutate(const binary_op_t &obj) override { |
30 | auto old_obj = ir_mutator_t::_mutate(obj); |
31 | auto new_obj |
32 | = simplify_rewrite_with_ternary(old_obj, /*recursive=*/false); |
33 | auto *ternary = new_obj.as_ptr<ternary_op_t>(); |
34 | if (!ternary) return std::move(new_obj); |
35 | |
36 | switch (ternary->op_kind) { |
37 | case op_kind_t::_add3: { |
38 | bool ok = true; |
39 | // Allowed form: add3(dword/word, dword/word, dword/word). |
40 | ok &= add3_type_ok(ternary->a); |
41 | ok &= add3_type_ok(ternary->b); |
42 | ok &= add3_type_ok(ternary->c); |
43 | ok &= !is_const(ternary->a); |
44 | ok &= !is_const(ternary->b); |
45 | if (!ok) new_obj = old_obj; |
46 | break; |
47 | } |
48 | case op_kind_t::_mad: { |
49 | bool ok = false; |
50 | if (try_int_mad(ternary)) |
51 | ok = true; |
52 | else if (try_float_mad(ternary)) |
53 | ok = true; |
54 | if (!ok) new_obj = old_obj; |
55 | break; |
56 | } |
57 | default: ir_error_not_expected(); |
58 | } |
59 | return std::move(new_obj); |
60 | } |
61 | |
62 | private: |
63 | static type_t real_type(const expr_t &e) { |
64 | auto *imm = e.as_ptr<int_imm_t>(); |
65 | if (!imm) return e.type(); |
66 | if (int_imm_t::try_shrink_type<int16_t>(imm->value)) |
67 | return type_t::s16(); |
68 | if (int_imm_t::try_shrink_type<int32_t>(imm->value)) |
69 | return type_t::s32(); |
70 | return type_t::s64(); |
71 | } |
72 | |
73 | static bool try_int_mad(const ternary_op_t *ternary) { |
74 | auto a_type = real_type(ternary->a); |
75 | auto b_type = real_type(ternary->b); |
76 | auto c_type = real_type(ternary->c); |
77 | bool ok = true; |
78 | // Allowed form: mad(dword, dword, word). |
79 | ok &= utils::one_of(a_type, type_t::s32(), type_t::u32()); |
80 | ok &= utils::one_of(b_type, type_t::s32(), type_t::u32()); |
81 | ok &= utils::one_of(c_type, type_t::s16(), type_t::u16()); |
82 | return ok; |
83 | } |
84 | |
85 | static bool try_float_mad(const ternary_op_t *ternary) { |
86 | auto op_ok = [](const expr_t &e) { |
87 | if (is_const(e) || is_const_broadcast(e)) return false; |
88 | if (!e.type().is_f32()) return false; |
89 | return true; |
90 | }; |
91 | if (!op_ok(ternary->a)) return false; |
92 | if (!op_ok(ternary->b)) return false; |
93 | if (!op_ok(ternary->c)) return false; |
94 | return true; |
95 | } |
96 | |
97 | static bool add3_type_ok(const expr_t &e) { |
98 | auto t = real_type(e); |
99 | if (!t.is_scalar()) return false; |
100 | switch (t.kind()) { |
101 | case type_kind_t::s32: |
102 | case type_kind_t::u32: return !is_const(e); |
103 | case type_kind_t::s16: |
104 | case type_kind_t::u16: return true; |
105 | default: return false; |
106 | } |
107 | } |
108 | }; |
109 | |
110 | stmt_t optimize_peephole(const stmt_t &s, ir_context_t &ir_ctx) { |
111 | trace_start(); |
112 | auto ret = peephole_optimizer_t().mutate(s); |
113 | trace_pass("optimize_peephole" , ret, ir_ctx); |
114 | return ret; |
115 | } |
116 | |
117 | } // namespace jit |
118 | } // namespace gpu |
119 | } // namespace impl |
120 | } // namespace dnnl |
121 | |