1 | /******************************************************************************* |
2 | * Copyright 2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/pass/shuffle_splitter.hpp" |
18 | |
19 | #include "gpu/jit/utils/trace.hpp" |
20 | |
21 | namespace dnnl { |
22 | namespace impl { |
23 | namespace gpu { |
24 | namespace jit { |
25 | |
26 | class shuffle_splitter_t : public ir_mutator_t { |
27 | public: |
28 | static expr_t add(const expr_t &e, const expr_t &ee) { |
29 | if (e.is_empty()) { |
30 | return ee; |
31 | } else if (ee.is_empty()) { |
32 | return e; |
33 | } else { |
34 | return e + ee; |
35 | } |
36 | }; |
37 | |
38 | object_t _mutate(const binary_op_t &obj) override { |
39 | if (obj.op_kind != op_kind_t::_add) return ir_mutator_t::_mutate(obj); |
40 | |
41 | // Aggregate bcast expressions together |
42 | auto new_obj = ir_mutator_t::_mutate(obj); |
43 | auto args = split_by_add(new_obj, obj.type.elems()); |
44 | if (args.size() <= 1) return new_obj; |
45 | |
46 | std::vector<expr_t> bcasts; |
47 | std::vector<expr_t> non_bcasts; |
48 | for (auto &a : args) { |
49 | if (a.type().elems() != obj.type.elems()) { |
50 | bcasts.push_back(a); |
51 | } else { |
52 | non_bcasts.push_back(a); |
53 | } |
54 | } |
55 | |
56 | if (bcasts.size() <= 1) return new_obj; |
57 | |
58 | int elems = obj.type.elems(); |
59 | expr_t e = shuffle_t::make_broadcast(make_add(bcasts), elems); |
60 | if (!non_bcasts.empty()) e = add(e, make_add(non_bcasts)); |
61 | |
62 | ir_assert(!e.is_empty()); |
63 | return std::move(e); |
64 | } |
65 | |
66 | object_t _mutate(const shuffle_t &obj) override { |
67 | object_t new_obj = ir_mutator_t::_mutate(obj); |
68 | if (obj.is_broadcast() || !new_obj.is<shuffle_t>()) return new_obj; |
69 | |
70 | auto &o = new_obj.as<shuffle_t>(); |
71 | |
72 | // Split shuffle to bcast(expr) + vector(exprs) + vector(constants). Use |
73 | // existing vector(constants) to improve the effect of common |
74 | // subexpression elimnation. |
75 | |
76 | expr_t vec_bcast; |
77 | std::vector<expr_t> vec_const; |
78 | std::vector<expr_t> vec_off; |
79 | |
80 | std::vector<object_eq_map_t<expr_t, int>> args; |
81 | bool can_split = false; |
82 | const expr_t zero = cast(0, o.type.scalar()); |
83 | |
84 | for (auto &v : o.vec) { |
85 | // Only supports integer arithmetic |
86 | if (!v.type().is_int()) return new_obj; |
87 | auto v_args = split_by_add(v, v.type().elems()); |
88 | if (v_args.size() > 1) can_split = true; |
89 | expr_t e_const = zero; |
90 | args.emplace_back(); |
91 | for (auto &a : v_args) { |
92 | if (is_const(a)) { |
93 | e_const += a; |
94 | } else { |
95 | args.back()[a] += 1; |
96 | } |
97 | } |
98 | vec_const.push_back(const_fold(e_const)); |
99 | } |
100 | |
101 | if (!can_split) return new_obj; |
102 | |
103 | // Multiset Intersection |
104 | auto intersect = [](object_eq_map_t<expr_t, int> &a, |
105 | object_eq_map_t<expr_t, int> &b) { |
106 | object_eq_map_t<expr_t, int> c; |
107 | for (auto &kv : a) { |
108 | auto &key = kv.first; |
109 | int rep_a = kv.second; |
110 | int rep_b = b[key]; |
111 | int rep_c = std::min(rep_a, rep_b); |
112 | if (rep_c > 0) c[key] = rep_c; |
113 | } |
114 | return c; |
115 | }; |
116 | // Multiset Difference |
117 | auto difference = [](object_eq_map_t<expr_t, int> &a, |
118 | object_eq_map_t<expr_t, int> &b) { |
119 | object_eq_map_t<expr_t, int> c; |
120 | for (auto &kv : a) { |
121 | auto key = kv.first; |
122 | int rep_a = kv.second; |
123 | int rep_b = b[key]; |
124 | int rep_c = rep_a - rep_b; |
125 | if (rep_c > 0) c[key] = rep_c; |
126 | } |
127 | return c; |
128 | }; |
129 | |
130 | auto is_empty_or_fill = [&](std::vector<expr_t> &vec) { |
131 | for (auto &c : vec) { |
132 | if (!c.is_empty() && !c.is_equal(zero)) { return false; } |
133 | if (c.is_empty()) c = zero; |
134 | } |
135 | return true; |
136 | }; |
137 | |
138 | auto is_bcast = [](const std::vector<expr_t> &vec) { |
139 | for (auto &c : vec) { |
140 | if (!c.is_equal(vec[0])) { return false; } |
141 | } |
142 | return true; |
143 | }; |
144 | |
145 | auto get_bcast_difference = [](expr_t expr_a, expr_t expr_b) { |
146 | if (!expr_a.is<shuffle_t>() || !expr_b.is<shuffle_t>()) |
147 | return expr_t(); |
148 | |
149 | auto &a = expr_a.as<shuffle_t>(); |
150 | auto &b = expr_b.as<shuffle_t>(); |
151 | if (a.idx.size() != b.idx.size()) return expr_t(); |
152 | if (a.vec.size() != b.vec.size()) return expr_t(); |
153 | |
154 | for (size_t i = 0; i < a.idx.size(); i++) { |
155 | if (a.idx[i] != b.idx[i]) return expr_t(); |
156 | } |
157 | |
158 | if (a.vec.size() <= 0) return expr_t(); |
159 | expr_t offset = const_fold(a.vec[0] - b.vec[0]); |
160 | for (size_t i = 0; i < a.vec.size(); i++) { |
161 | expr_t new_offset = const_fold(a.vec[i] - b.vec[i]); |
162 | if (!offset.is_equal(new_offset)) return expr_t(); |
163 | } |
164 | return offset; |
165 | }; |
166 | |
167 | auto base_args = args[0]; |
168 | for (int i = 1; i < (int)args.size(); i++) { |
169 | base_args = intersect(base_args, args[i]); |
170 | } |
171 | |
172 | vec_bcast = make_add(base_args); |
173 | for (auto &a : args) |
174 | vec_off.push_back(make_add(difference(a, base_args))); |
175 | |
176 | bool is_bcast_empty = base_args.size() == 0; |
177 | bool is_consts_empty = is_empty_or_fill(vec_const); |
178 | bool is_consts_bcast = is_bcast(vec_const); |
179 | bool is_off_empty = is_empty_or_fill(vec_off); |
180 | |
181 | expr_t const_shuffle; |
182 | if (!is_consts_empty) { |
183 | const_shuffle = shuffle_t::make(vec_const, o.idx); |
184 | if (!is_consts_bcast) { |
185 | expr_t offset; |
186 | for (auto &k : const_shuffles_) { |
187 | offset = get_bcast_difference(const_shuffle, k); |
188 | if (!offset.is_empty()) { |
189 | vec_bcast = add(vec_bcast, offset); |
190 | const_shuffle = k; |
191 | is_consts_bcast |
192 | = is_bcast(const_shuffle.as<shuffle_t>().vec); |
193 | break; |
194 | } |
195 | } |
196 | |
197 | if (offset.is_empty()) { |
198 | const_shuffles_.emplace(const_shuffle); |
199 | } |
200 | } |
201 | |
202 | if (is_consts_bcast) { |
203 | const_shuffle = shuffle_t::make_broadcast( |
204 | const_shuffle.as<shuffle_t>().vec[0], o.type.elems()); |
205 | } |
206 | } |
207 | |
208 | expr_t e; |
209 | if (!is_bcast_empty) |
210 | e = add(e, shuffle_t::make_broadcast(vec_bcast, o.type.elems())); |
211 | if (!is_off_empty) e = add(e, shuffle_t::make(vec_off, o.idx)); |
212 | e = add(e, const_shuffle); |
213 | |
214 | return std::move(e); |
215 | } |
216 | |
217 | private: |
218 | object_eq_set_t<expr_t> const_shuffles_; |
219 | static std::vector<expr_t> split_by_add(const expr_t &e, int elems) { |
220 | auto *shuffle = e.as_ptr<shuffle_t>(); |
221 | if (shuffle && shuffle->is_broadcast() && shuffle->elems() == elems) { |
222 | return split_by_add(shuffle->vec[0], elems); |
223 | } |
224 | auto *op = e.as_ptr<binary_op_t>(); |
225 | if (!op || op->op_kind != op_kind_t::_add) return {e}; |
226 | auto a_args = split_by_add(op->a, elems); |
227 | auto b_args = split_by_add(op->b, elems); |
228 | std::vector<expr_t> args; |
229 | args.insert(args.end(), a_args.begin(), a_args.end()); |
230 | args.insert(args.end(), b_args.begin(), b_args.end()); |
231 | return args; |
232 | } |
233 | |
234 | static expr_t make_add(const std::vector<expr_t> &args) { |
235 | if (args.empty()) return 0; |
236 | expr_t e = args[0]; |
237 | for (int i = 1; i < (int)args.size(); i++) |
238 | e = e + args[i]; |
239 | return e; |
240 | } |
241 | static expr_t make_add(const object_eq_map_t<expr_t, int> &args) { |
242 | if (args.empty()) return 0; |
243 | expr_t e; |
244 | for (auto &kv : args) |
245 | if (kv.second == 0) |
246 | continue; |
247 | else if (kv.second == 1) |
248 | e = add(e, kv.first); |
249 | else |
250 | e = add(e, kv.second * kv.first); |
251 | return e; |
252 | } |
253 | }; |
254 | |
255 | stmt_t split_shuffle(const stmt_t &s, ir_context_t &ir_ctx) { |
256 | trace_start(); |
257 | auto ret = shuffle_splitter_t().mutate(s); |
258 | trace_pass("split_shuffle" , ret, ir_ctx); |
259 | return ret; |
260 | } |
261 | |
262 | } // namespace jit |
263 | } // namespace gpu |
264 | } // namespace impl |
265 | } // namespace dnnl |
266 | |