1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/peephole.hpp"
18
19#include "gpu/jit/pass/simplify.hpp"
20#include "gpu/jit/utils/trace.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace jit {
26
27class peephole_optimizer_t : public ir_mutator_t {
28public:
29 object_t _mutate(const binary_op_t &obj) override {
30 auto old_obj = ir_mutator_t::_mutate(obj);
31 auto new_obj
32 = simplify_rewrite_with_ternary(old_obj, /*recursive=*/false);
33 auto *ternary = new_obj.as_ptr<ternary_op_t>();
34 if (!ternary) return std::move(new_obj);
35
36 switch (ternary->op_kind) {
37 case op_kind_t::_add3: {
38 bool ok = true;
39 // Allowed form: add3(dword/word, dword/word, dword/word).
40 ok &= add3_type_ok(ternary->a);
41 ok &= add3_type_ok(ternary->b);
42 ok &= add3_type_ok(ternary->c);
43 ok &= !is_const(ternary->a);
44 ok &= !is_const(ternary->b);
45 if (!ok) new_obj = old_obj;
46 break;
47 }
48 case op_kind_t::_mad: {
49 bool ok = false;
50 if (try_int_mad(ternary))
51 ok = true;
52 else if (try_float_mad(ternary))
53 ok = true;
54 if (!ok) new_obj = old_obj;
55 break;
56 }
57 default: ir_error_not_expected();
58 }
59 return std::move(new_obj);
60 }
61
62private:
63 static type_t real_type(const expr_t &e) {
64 auto *imm = e.as_ptr<int_imm_t>();
65 if (!imm) return e.type();
66 if (int_imm_t::try_shrink_type<int16_t>(imm->value))
67 return type_t::s16();
68 if (int_imm_t::try_shrink_type<int32_t>(imm->value))
69 return type_t::s32();
70 return type_t::s64();
71 }
72
73 static bool try_int_mad(const ternary_op_t *ternary) {
74 auto a_type = real_type(ternary->a);
75 auto b_type = real_type(ternary->b);
76 auto c_type = real_type(ternary->c);
77 bool ok = true;
78 // Allowed form: mad(dword, dword, word).
79 ok &= utils::one_of(a_type, type_t::s32(), type_t::u32());
80 ok &= utils::one_of(b_type, type_t::s32(), type_t::u32());
81 ok &= utils::one_of(c_type, type_t::s16(), type_t::u16());
82 return ok;
83 }
84
85 static bool try_float_mad(const ternary_op_t *ternary) {
86 auto op_ok = [](const expr_t &e) {
87 if (is_const(e) || is_const_broadcast(e)) return false;
88 if (!e.type().is_f32()) return false;
89 return true;
90 };
91 if (!op_ok(ternary->a)) return false;
92 if (!op_ok(ternary->b)) return false;
93 if (!op_ok(ternary->c)) return false;
94 return true;
95 }
96
97 static bool add3_type_ok(const expr_t &e) {
98 auto t = real_type(e);
99 if (!t.is_scalar()) return false;
100 switch (t.kind()) {
101 case type_kind_t::s32:
102 case type_kind_t::u32: return !is_const(e);
103 case type_kind_t::s16:
104 case type_kind_t::u16: return true;
105 default: return false;
106 }
107 }
108};
109
110stmt_t optimize_peephole(const stmt_t &s, ir_context_t &ir_ctx) {
111 trace_start();
112 auto ret = peephole_optimizer_t().mutate(s);
113 trace_pass("optimize_peephole", ret, ir_ctx);
114 return ret;
115}
116
117} // namespace jit
118} // namespace gpu
119} // namespace impl
120} // namespace dnnl
121