1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/overflow.hpp" |
18 | |
19 | #include "gpu/jit/pass/expr_scalarizer.hpp" |
20 | #include "gpu/jit/utils/trace.hpp" |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace jit { |
26 | |
27 | class overflow_bound_finder_t : public bound_finder_base_t { |
28 | public: |
29 | bool has_var(const expr_t &e) const { |
30 | ir_assert(is_var(e)) << "Expected variable, found: " << e; |
31 | auto it = var_bounds_.find(e); |
32 | return it != var_bounds_.end(); |
33 | } |
34 | |
35 | std::pair<int64_t, int64_t> find_bounds(const expr_t &e) const { |
36 | int64_t lo = find_low_bound(e); |
37 | int64_t hi = find_high_bound(e); |
38 | return std::make_pair(lo, hi); |
39 | } |
40 | |
41 | int64_t get_var_bound(const expr_t &e, bool is_low) const override { |
42 | ir_assert(has_var(e)) << "Variable not found: " << e; |
43 | auto &lo_hi = var_bounds_.at(e); |
44 | return is_low ? lo_hi.first : lo_hi.second; |
45 | } |
46 | |
47 | void set_var_bounds( |
48 | const expr_t &e, const std::pair<int64_t, int64_t> &lo_hi) { |
49 | ir_assert(is_good_bound(lo_hi.first)) |
50 | << "Can't compute low bound for " << e; |
51 | ir_assert(is_good_bound(lo_hi.second)) |
52 | << "Can't compute high bound for " << e; |
53 | var_bounds_.emplace(e, lo_hi); |
54 | } |
55 | |
56 | protected: |
57 | int64_t find_bound_impl(const expr_t &e, bool is_low) const override { |
58 | auto *cast = e.as_ptr<cast_t>(); |
59 | if (cast) { |
60 | if (e.type().is_u64() && cast->expr.type().is_ptr()) { |
61 | return is_low ? 0 : std::numeric_limits<uint32_t>::max(); |
62 | } else if (e.type().is_u32() && cast->expr.type().is_ptr()) { |
63 | return is_low ? 0 : std::numeric_limits<uint16_t>::max(); |
64 | } |
65 | } |
66 | return bound_finder_base_t::find_bound_impl(e, is_low); |
67 | } |
68 | |
69 | private: |
70 | object_map_t<expr_t, std::pair<int64_t, int64_t>> var_bounds_; |
71 | }; |
72 | |
73 | class overflow_fixer_t : public ir_mutator_t { |
74 | public: |
75 | overflow_fixer_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) { |
76 | for (auto &kv : ir_ctx.cset().relations()) { |
77 | int64_t lo = bound_finder_base_t::unlimited_bound(true); |
78 | int64_t hi = bound_finder_base_t::unlimited_bound(false); |
79 | for (auto &rel : kv.second) { |
80 | bool is_ge = (rel.op_kind() == op_kind_t::_ge); |
81 | bool is_le = (rel.op_kind() == op_kind_t::_le); |
82 | ir_assert(is_ge || is_le); |
83 | if (rel.op_kind() == op_kind_t::_ge) { |
84 | lo = std::max(to_cpp<int64_t>(rel.rhs()), lo); |
85 | } else if (rel.op_kind() == op_kind_t::_le) { |
86 | hi = std::min(to_cpp<int64_t>(rel.rhs()), hi); |
87 | } else { |
88 | ir_error_not_expected() |
89 | << "Only >= or <= is expected, found: " |
90 | << to_string(rel.op_kind()); |
91 | } |
92 | } |
93 | bound_finder_.set_var_bounds(kv.first, {lo, hi}); |
94 | } |
95 | } |
96 | |
97 | object_t _mutate(const alloc_t &obj) override { |
98 | return ir_mutator_t::_mutate(obj); |
99 | } |
100 | |
101 | object_t _mutate(const binary_op_t &obj) override { |
102 | return mutate_expr(obj); |
103 | } |
104 | |
105 | object_t _mutate(const for_t &obj) override { |
106 | auto lo = to_cpp<int64_t>(obj.init); |
107 | auto hi = to_cpp<int64_t>(obj.bound) - 1; |
108 | bound_finder_.set_var_bounds(obj.var, {lo, hi}); |
109 | return ir_mutator_t::_mutate(obj); |
110 | } |
111 | |
112 | object_t _mutate(const let_t &obj) override { |
113 | bool ok = true; |
114 | if (!obj.var.type().is_int()) ok = false; |
115 | if (ok && obj.value.is_empty()) ok = false; |
116 | if (ok && bound_finder_.has_var(obj.var)) ok = false; |
117 | |
118 | if (ok) { |
119 | if (contains_load(obj.value)) { |
120 | vars_with_load_.insert(obj.var); |
121 | ok = false; |
122 | } |
123 | } |
124 | |
125 | if (ok) { |
126 | int elems = obj.var.type().elems(); |
127 | vec_vars_[obj.var].reserve(elems); |
128 | for (int i = 0; i < elems; i++) { |
129 | auto var_i = make_vec_var(obj.var, elems, i); |
130 | expr_scalarizer_t scalarizer(elems, i, vec_vars_); |
131 | auto value_i = scalarizer.mutate(obj.value); |
132 | auto lo_hi = bound_finder_.find_bounds(value_i); |
133 | bound_finder_.set_var_bounds(var_i, lo_hi); |
134 | vec_vars_[obj.var].push_back(var_i); |
135 | } |
136 | } |
137 | expr_t var = obj.var; |
138 | expr_t value = mutate(obj.value); |
139 | stmt_t body = mutate(obj.body); |
140 | if (value.is_same(obj.value) && body.is_same(obj.body)) return obj; |
141 | if (!value.is_empty() && value.type() != obj.value.type()) { |
142 | auto old_var = var; |
143 | var = ir_ctx_.create_tmp_var( |
144 | value.type(), old_var.as<var_t>().name); |
145 | body = substitute_with_different_type(body, old_var, var); |
146 | } |
147 | return let_t::make(var, value, body); |
148 | } |
149 | |
150 | object_t _mutate(const unary_op_t &obj) override { |
151 | return mutate_expr(obj); |
152 | } |
153 | |
154 | private: |
155 | template <typename T> |
156 | object_t mutate_expr(const T &obj) { |
157 | expr_t new_obj = ir_mutator_t::_mutate(obj); |
158 | if (!new_obj.type().is_x32()) return std::move(new_obj); |
159 | if (contains_load(new_obj)) return std::move(new_obj); |
160 | |
161 | bool found_overflow = false; |
162 | int elems = new_obj.type().elems(); |
163 | for (int i = 0; i < elems; i++) { |
164 | expr_scalarizer_t scalarizer(elems, i, vec_vars_); |
165 | expr_t value = scalarizer.mutate(new_obj); |
166 | int64_t lo = bound_finder_.find_low_bound(value); |
167 | int64_t hi = bound_finder_.find_high_bound(value); |
168 | bool ok = bound_finder_base_t::is_good_bound(lo) |
169 | && bound_finder_base_t::is_good_bound(hi); |
170 | if (ok) { |
171 | int64_t type_lo = value.type().is_s32() |
172 | ? (int64_t)std::numeric_limits<int32_t>::min() |
173 | : (int64_t)std::numeric_limits<uint32_t>::min(); |
174 | int64_t type_hi = value.type().is_s32() |
175 | ? (int64_t)std::numeric_limits<int32_t>::max() |
176 | : (int64_t)std::numeric_limits<uint32_t>::max(); |
177 | |
178 | bool is_overflow = (lo < type_lo || hi > type_hi); |
179 | if (is_overflow) { |
180 | found_overflow = true; |
181 | ir_warning() << "Found overflow: " << value |
182 | << " low bound: " << lo |
183 | << " high bound: " << hi << std::endl; |
184 | break; |
185 | } |
186 | } |
187 | } |
188 | if (found_overflow) return fix_overflow(new_obj); |
189 | return std::move(new_obj); |
190 | } |
191 | |
192 | bool contains_load(const expr_t &e) const { |
193 | if (!find_objects<load_t>(e).empty()) return true; |
194 | for (auto &v : find_objects<var_t>(e)) { |
195 | if (vars_with_load_.count(v) != 0) return true; |
196 | } |
197 | return false; |
198 | } |
199 | |
200 | static expr_t make_vec_var(const expr_t &_var, int elems, int idx) { |
201 | if (elems == 1) return _var; |
202 | auto &var = _var.as<var_t>(); |
203 | auto vec_name = var.name + "_" + std::to_string(idx) + "_" ; |
204 | return var_t::make(var.type.scalar(), vec_name); |
205 | } |
206 | |
207 | static expr_t fix_overflow(const expr_t &e) { |
208 | auto *binary = e.as_ptr<binary_op_t>(); |
209 | if (binary) { |
210 | return binary_op_t::make(binary->op_kind, |
211 | cast(binary->a, type_t::u64(e.type().elems())), binary->b); |
212 | } |
213 | |
214 | ir_error_not_expected() << "Can't fix overflow: " << e; |
215 | return e; |
216 | } |
217 | |
218 | ir_context_t &ir_ctx_; |
219 | overflow_bound_finder_t bound_finder_; |
220 | object_map_t<expr_t, std::vector<expr_t>> vec_vars_; |
221 | object_set_t<expr_t> vars_with_load_; |
222 | }; |
223 | |
224 | stmt_t fix_int32_overflow(const stmt_t &s, ir_context_t &ir_ctx) { |
225 | trace_start(); |
226 | auto ret = overflow_fixer_t(ir_ctx).mutate(s); |
227 | trace_pass("fix_int32_overflow" , ret, ir_ctx); |
228 | return ret; |
229 | } |
230 | |
231 | } // namespace jit |
232 | } // namespace gpu |
233 | } // namespace impl |
234 | } // namespace dnnl |
235 | |