1/*******************************************************************************
2* Copyright 2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "gpu/jit/pass/shuffle_splitter.hpp"
18
19#include "gpu/jit/utils/trace.hpp"
20
21namespace dnnl {
22namespace impl {
23namespace gpu {
24namespace jit {
25
26class shuffle_splitter_t : public ir_mutator_t {
27public:
28 static expr_t add(const expr_t &e, const expr_t &ee) {
29 if (e.is_empty()) {
30 return ee;
31 } else if (ee.is_empty()) {
32 return e;
33 } else {
34 return e + ee;
35 }
36 };
37
38 object_t _mutate(const binary_op_t &obj) override {
39 if (obj.op_kind != op_kind_t::_add) return ir_mutator_t::_mutate(obj);
40
41 // Aggregate bcast expressions together
42 auto new_obj = ir_mutator_t::_mutate(obj);
43 auto args = split_by_add(new_obj, obj.type.elems());
44 if (args.size() <= 1) return new_obj;
45
46 std::vector<expr_t> bcasts;
47 std::vector<expr_t> non_bcasts;
48 for (auto &a : args) {
49 if (a.type().elems() != obj.type.elems()) {
50 bcasts.push_back(a);
51 } else {
52 non_bcasts.push_back(a);
53 }
54 }
55
56 if (bcasts.size() <= 1) return new_obj;
57
58 int elems = obj.type.elems();
59 expr_t e = shuffle_t::make_broadcast(make_add(bcasts), elems);
60 if (!non_bcasts.empty()) e = add(e, make_add(non_bcasts));
61
62 ir_assert(!e.is_empty());
63 return std::move(e);
64 }
65
66 object_t _mutate(const shuffle_t &obj) override {
67 object_t new_obj = ir_mutator_t::_mutate(obj);
68 if (obj.is_broadcast() || !new_obj.is<shuffle_t>()) return new_obj;
69
70 auto &o = new_obj.as<shuffle_t>();
71
72 // Split shuffle to bcast(expr) + vector(exprs) + vector(constants). Use
73 // existing vector(constants) to improve the effect of common
74 // subexpression elimnation.
75
76 expr_t vec_bcast;
77 std::vector<expr_t> vec_const;
78 std::vector<expr_t> vec_off;
79
80 std::vector<object_eq_map_t<expr_t, int>> args;
81 bool can_split = false;
82 const expr_t zero = cast(0, o.type.scalar());
83
84 for (auto &v : o.vec) {
85 // Only supports integer arithmetic
86 if (!v.type().is_int()) return new_obj;
87 auto v_args = split_by_add(v, v.type().elems());
88 if (v_args.size() > 1) can_split = true;
89 expr_t e_const = zero;
90 args.emplace_back();
91 for (auto &a : v_args) {
92 if (is_const(a)) {
93 e_const += a;
94 } else {
95 args.back()[a] += 1;
96 }
97 }
98 vec_const.push_back(const_fold(e_const));
99 }
100
101 if (!can_split) return new_obj;
102
103 // Multiset Intersection
104 auto intersect = [](object_eq_map_t<expr_t, int> &a,
105 object_eq_map_t<expr_t, int> &b) {
106 object_eq_map_t<expr_t, int> c;
107 for (auto &kv : a) {
108 auto &key = kv.first;
109 int rep_a = kv.second;
110 int rep_b = b[key];
111 int rep_c = std::min(rep_a, rep_b);
112 if (rep_c > 0) c[key] = rep_c;
113 }
114 return c;
115 };
116 // Multiset Difference
117 auto difference = [](object_eq_map_t<expr_t, int> &a,
118 object_eq_map_t<expr_t, int> &b) {
119 object_eq_map_t<expr_t, int> c;
120 for (auto &kv : a) {
121 auto key = kv.first;
122 int rep_a = kv.second;
123 int rep_b = b[key];
124 int rep_c = rep_a - rep_b;
125 if (rep_c > 0) c[key] = rep_c;
126 }
127 return c;
128 };
129
130 auto is_empty_or_fill = [&](std::vector<expr_t> &vec) {
131 for (auto &c : vec) {
132 if (!c.is_empty() && !c.is_equal(zero)) { return false; }
133 if (c.is_empty()) c = zero;
134 }
135 return true;
136 };
137
138 auto is_bcast = [](const std::vector<expr_t> &vec) {
139 for (auto &c : vec) {
140 if (!c.is_equal(vec[0])) { return false; }
141 }
142 return true;
143 };
144
145 auto get_bcast_difference = [](expr_t expr_a, expr_t expr_b) {
146 if (!expr_a.is<shuffle_t>() || !expr_b.is<shuffle_t>())
147 return expr_t();
148
149 auto &a = expr_a.as<shuffle_t>();
150 auto &b = expr_b.as<shuffle_t>();
151 if (a.idx.size() != b.idx.size()) return expr_t();
152 if (a.vec.size() != b.vec.size()) return expr_t();
153
154 for (size_t i = 0; i < a.idx.size(); i++) {
155 if (a.idx[i] != b.idx[i]) return expr_t();
156 }
157
158 if (a.vec.size() <= 0) return expr_t();
159 expr_t offset = const_fold(a.vec[0] - b.vec[0]);
160 for (size_t i = 0; i < a.vec.size(); i++) {
161 expr_t new_offset = const_fold(a.vec[i] - b.vec[i]);
162 if (!offset.is_equal(new_offset)) return expr_t();
163 }
164 return offset;
165 };
166
167 auto base_args = args[0];
168 for (int i = 1; i < (int)args.size(); i++) {
169 base_args = intersect(base_args, args[i]);
170 }
171
172 vec_bcast = make_add(base_args);
173 for (auto &a : args)
174 vec_off.push_back(make_add(difference(a, base_args)));
175
176 bool is_bcast_empty = base_args.size() == 0;
177 bool is_consts_empty = is_empty_or_fill(vec_const);
178 bool is_consts_bcast = is_bcast(vec_const);
179 bool is_off_empty = is_empty_or_fill(vec_off);
180
181 expr_t const_shuffle;
182 if (!is_consts_empty) {
183 const_shuffle = shuffle_t::make(vec_const, o.idx);
184 if (!is_consts_bcast) {
185 expr_t offset;
186 for (auto &k : const_shuffles_) {
187 offset = get_bcast_difference(const_shuffle, k);
188 if (!offset.is_empty()) {
189 vec_bcast = add(vec_bcast, offset);
190 const_shuffle = k;
191 is_consts_bcast
192 = is_bcast(const_shuffle.as<shuffle_t>().vec);
193 break;
194 }
195 }
196
197 if (offset.is_empty()) {
198 const_shuffles_.emplace(const_shuffle);
199 }
200 }
201
202 if (is_consts_bcast) {
203 const_shuffle = shuffle_t::make_broadcast(
204 const_shuffle.as<shuffle_t>().vec[0], o.type.elems());
205 }
206 }
207
208 expr_t e;
209 if (!is_bcast_empty)
210 e = add(e, shuffle_t::make_broadcast(vec_bcast, o.type.elems()));
211 if (!is_off_empty) e = add(e, shuffle_t::make(vec_off, o.idx));
212 e = add(e, const_shuffle);
213
214 return std::move(e);
215 }
216
217private:
218 object_eq_set_t<expr_t> const_shuffles_;
219 static std::vector<expr_t> split_by_add(const expr_t &e, int elems) {
220 auto *shuffle = e.as_ptr<shuffle_t>();
221 if (shuffle && shuffle->is_broadcast() && shuffle->elems() == elems) {
222 return split_by_add(shuffle->vec[0], elems);
223 }
224 auto *op = e.as_ptr<binary_op_t>();
225 if (!op || op->op_kind != op_kind_t::_add) return {e};
226 auto a_args = split_by_add(op->a, elems);
227 auto b_args = split_by_add(op->b, elems);
228 std::vector<expr_t> args;
229 args.insert(args.end(), a_args.begin(), a_args.end());
230 args.insert(args.end(), b_args.begin(), b_args.end());
231 return args;
232 }
233
234 static expr_t make_add(const std::vector<expr_t> &args) {
235 if (args.empty()) return 0;
236 expr_t e = args[0];
237 for (int i = 1; i < (int)args.size(); i++)
238 e = e + args[i];
239 return e;
240 }
241 static expr_t make_add(const object_eq_map_t<expr_t, int> &args) {
242 if (args.empty()) return 0;
243 expr_t e;
244 for (auto &kv : args)
245 if (kv.second == 0)
246 continue;
247 else if (kv.second == 1)
248 e = add(e, kv.first);
249 else
250 e = add(e, kv.second * kv.first);
251 return e;
252 }
253};
254
255stmt_t split_shuffle(const stmt_t &s, ir_context_t &ir_ctx) {
256 trace_start();
257 auto ret = shuffle_splitter_t().mutate(s);
258 trace_pass("split_shuffle", ret, ir_ctx);
259 return ret;
260}
261
262} // namespace jit
263} // namespace gpu
264} // namespace impl
265} // namespace dnnl
266