/*******************************************************************************
* Copyright 2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#include "gpu/jit/pass/overflow.hpp"
18
19#include "gpu/jit/pass/expr_scalarizer.hpp"
20#include "gpu/jit/utils/trace.hpp"
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace jit {
26
27class overflow_bound_finder_t : public bound_finder_base_t {
28public:
29 bool has_var(const expr_t &e) const {
30 ir_assert(is_var(e)) << "Expected variable, found: " << e;
31 auto it = var_bounds_.find(e);
32 return it != var_bounds_.end();
33 }
34
35 std::pair<int64_t, int64_t> find_bounds(const expr_t &e) const {
36 int64_t lo = find_low_bound(e);
37 int64_t hi = find_high_bound(e);
38 return std::make_pair(lo, hi);
39 }
40
41 int64_t get_var_bound(const expr_t &e, bool is_low) const override {
42 ir_assert(has_var(e)) << "Variable not found: " << e;
43 auto &lo_hi = var_bounds_.at(e);
44 return is_low ? lo_hi.first : lo_hi.second;
45 }
46
47 void set_var_bounds(
48 const expr_t &e, const std::pair<int64_t, int64_t> &lo_hi) {
49 ir_assert(is_good_bound(lo_hi.first))
50 << "Can't compute low bound for " << e;
51 ir_assert(is_good_bound(lo_hi.second))
52 << "Can't compute high bound for " << e;
53 var_bounds_.emplace(e, lo_hi);
54 }
55
56protected:
57 int64_t find_bound_impl(const expr_t &e, bool is_low) const override {
58 auto *cast = e.as_ptr<cast_t>();
59 if (cast) {
60 if (e.type().is_u64() && cast->expr.type().is_ptr()) {
61 return is_low ? 0 : std::numeric_limits<uint32_t>::max();
62 } else if (e.type().is_u32() && cast->expr.type().is_ptr()) {
63 return is_low ? 0 : std::numeric_limits<uint16_t>::max();
64 }
65 }
66 return bound_finder_base_t::find_bound_impl(e, is_low);
67 }
68
69private:
70 object_map_t<expr_t, std::pair<int64_t, int64_t>> var_bounds_;
71};
72
73class overflow_fixer_t : public ir_mutator_t {
74public:
75 overflow_fixer_t(ir_context_t &ir_ctx) : ir_ctx_(ir_ctx) {
76 for (auto &kv : ir_ctx.cset().relations()) {
77 int64_t lo = bound_finder_base_t::unlimited_bound(true);
78 int64_t hi = bound_finder_base_t::unlimited_bound(false);
79 for (auto &rel : kv.second) {
80 bool is_ge = (rel.op_kind() == op_kind_t::_ge);
81 bool is_le = (rel.op_kind() == op_kind_t::_le);
82 ir_assert(is_ge || is_le);
83 if (rel.op_kind() == op_kind_t::_ge) {
84 lo = std::max(to_cpp<int64_t>(rel.rhs()), lo);
85 } else if (rel.op_kind() == op_kind_t::_le) {
86 hi = std::min(to_cpp<int64_t>(rel.rhs()), hi);
87 } else {
88 ir_error_not_expected()
89 << "Only >= or <= is expected, found: "
90 << to_string(rel.op_kind());
91 }
92 }
93 bound_finder_.set_var_bounds(kv.first, {lo, hi});
94 }
95 }
96
97 object_t _mutate(const alloc_t &obj) override {
98 return ir_mutator_t::_mutate(obj);
99 }
100
101 object_t _mutate(const binary_op_t &obj) override {
102 return mutate_expr(obj);
103 }
104
105 object_t _mutate(const for_t &obj) override {
106 auto lo = to_cpp<int64_t>(obj.init);
107 auto hi = to_cpp<int64_t>(obj.bound) - 1;
108 bound_finder_.set_var_bounds(obj.var, {lo, hi});
109 return ir_mutator_t::_mutate(obj);
110 }
111
112 object_t _mutate(const let_t &obj) override {
113 bool ok = true;
114 if (!obj.var.type().is_int()) ok = false;
115 if (ok && obj.value.is_empty()) ok = false;
116 if (ok && bound_finder_.has_var(obj.var)) ok = false;
117
118 if (ok) {
119 if (contains_load(obj.value)) {
120 vars_with_load_.insert(obj.var);
121 ok = false;
122 }
123 }
124
125 if (ok) {
126 int elems = obj.var.type().elems();
127 vec_vars_[obj.var].reserve(elems);
128 for (int i = 0; i < elems; i++) {
129 auto var_i = make_vec_var(obj.var, elems, i);
130 expr_scalarizer_t scalarizer(elems, i, vec_vars_);
131 auto value_i = scalarizer.mutate(obj.value);
132 auto lo_hi = bound_finder_.find_bounds(value_i);
133 bound_finder_.set_var_bounds(var_i, lo_hi);
134 vec_vars_[obj.var].push_back(var_i);
135 }
136 }
137 expr_t var = obj.var;
138 expr_t value = mutate(obj.value);
139 stmt_t body = mutate(obj.body);
140 if (value.is_same(obj.value) && body.is_same(obj.body)) return obj;
141 if (!value.is_empty() && value.type() != obj.value.type()) {
142 auto old_var = var;
143 var = ir_ctx_.create_tmp_var(
144 value.type(), old_var.as<var_t>().name);
145 body = substitute_with_different_type(body, old_var, var);
146 }
147 return let_t::make(var, value, body);
148 }
149
150 object_t _mutate(const unary_op_t &obj) override {
151 return mutate_expr(obj);
152 }
153
154private:
155 template <typename T>
156 object_t mutate_expr(const T &obj) {
157 expr_t new_obj = ir_mutator_t::_mutate(obj);
158 if (!new_obj.type().is_x32()) return std::move(new_obj);
159 if (contains_load(new_obj)) return std::move(new_obj);
160
161 bool found_overflow = false;
162 int elems = new_obj.type().elems();
163 for (int i = 0; i < elems; i++) {
164 expr_scalarizer_t scalarizer(elems, i, vec_vars_);
165 expr_t value = scalarizer.mutate(new_obj);
166 int64_t lo = bound_finder_.find_low_bound(value);
167 int64_t hi = bound_finder_.find_high_bound(value);
168 bool ok = bound_finder_base_t::is_good_bound(lo)
169 && bound_finder_base_t::is_good_bound(hi);
170 if (ok) {
171 int64_t type_lo = value.type().is_s32()
172 ? (int64_t)std::numeric_limits<int32_t>::min()
173 : (int64_t)std::numeric_limits<uint32_t>::min();
174 int64_t type_hi = value.type().is_s32()
175 ? (int64_t)std::numeric_limits<int32_t>::max()
176 : (int64_t)std::numeric_limits<uint32_t>::max();
177
178 bool is_overflow = (lo < type_lo || hi > type_hi);
179 if (is_overflow) {
180 found_overflow = true;
181 ir_warning() << "Found overflow: " << value
182 << " low bound: " << lo
183 << " high bound: " << hi << std::endl;
184 break;
185 }
186 }
187 }
188 if (found_overflow) return fix_overflow(new_obj);
189 return std::move(new_obj);
190 }
191
192 bool contains_load(const expr_t &e) const {
193 if (!find_objects<load_t>(e).empty()) return true;
194 for (auto &v : find_objects<var_t>(e)) {
195 if (vars_with_load_.count(v) != 0) return true;
196 }
197 return false;
198 }
199
200 static expr_t make_vec_var(const expr_t &_var, int elems, int idx) {
201 if (elems == 1) return _var;
202 auto &var = _var.as<var_t>();
203 auto vec_name = var.name + "_" + std::to_string(idx) + "_";
204 return var_t::make(var.type.scalar(), vec_name);
205 }
206
207 static expr_t fix_overflow(const expr_t &e) {
208 auto *binary = e.as_ptr<binary_op_t>();
209 if (binary) {
210 return binary_op_t::make(binary->op_kind,
211 cast(binary->a, type_t::u64(e.type().elems())), binary->b);
212 }
213
214 ir_error_not_expected() << "Can't fix overflow: " << e;
215 return e;
216 }
217
218 ir_context_t &ir_ctx_;
219 overflow_bound_finder_t bound_finder_;
220 object_map_t<expr_t, std::vector<expr_t>> vec_vars_;
221 object_set_t<expr_t> vars_with_load_;
222};
223
224stmt_t fix_int32_overflow(const stmt_t &s, ir_context_t &ir_ctx) {
225 trace_start();
226 auto ret = overflow_fixer_t(ir_ctx).mutate(s);
227 trace_pass("fix_int32_overflow", ret, ir_ctx);
228 return ret;
229}
230
231} // namespace jit
232} // namespace gpu
233} // namespace impl
234} // namespace dnnl
235