1 | #include "triton/codegen/analysis/align.h" |
2 | #include "triton/ir/utils.h" |
3 | #include "triton/ir/module.h" |
4 | #include "triton/ir/function.h" |
5 | #include "triton/ir/basic_block.h" |
6 | #include "triton/ir/instructions.h" |
7 | #include "triton/ir/type.h" |
8 | #include <iostream> |
9 | |
10 | namespace triton { |
11 | namespace codegen{ |
12 | namespace analysis{ |
13 | |
14 | |
// Extended Euclidean algorithm.
// Returns g = gcd(a, b) and stores coefficients in *x and *y such that
// a * (*x) + b * (*y) == g.
int gcd_impl(int a, int b, int *x, int *y)
{
  if (a != 0)
  {
    // Solve the reduced problem (b mod a, a) first ...
    int sub_x, sub_y;
    const int divisor = gcd_impl(b % a, a, &sub_x, &sub_y);

    // ... then lift its coefficients back to the pair (a, b).
    const int quotient = b / a;
    *x = sub_y - quotient * sub_x;
    *y = sub_x;
    return divisor;
  }

  // Terminal case: gcd(0, b) = b = 0 * a + 1 * b
  *x = 0;
  *y = 1;
  return b;
}
36 | |
// Greatest common divisor of a and b.
//
// The computation is done on the magnitudes of the inputs using unsigned
// arithmetic, so the result never carries a spurious negative sign (the
// truncating `%` of a signed Euclid could otherwise return e.g. -8 for
// gcd(INT_MIN, 24)). Callers in this file pass unsigned quantities, including
// 1u << 31, which arrives here as INT_MIN; a result of 2^31 is converted back
// to int, so such callers still read 2^31 after converting to unsigned.
// gcd(0, n) == gcd(n, 0) == |n| and gcd(0, 0) == 0.
int gcd(int a, int b) {
  // 0u - unsigned(v) yields |v| even for INT_MIN, without signed overflow.
  unsigned hi = a < 0 ? 0u - static_cast<unsigned>(a) : static_cast<unsigned>(a);
  unsigned lo = b < 0 ? 0u - static_cast<unsigned>(b) : static_cast<unsigned>(b);
  while (lo != 0) {
    const unsigned remainder = hi % lo;
    hi = lo;
    lo = remainder;
  }
  return static_cast<int>(hi);
}
41 | |
42 | |
43 | inline int lcm(int a, int b) { |
44 | return (a * b) / gcd(a, b); |
45 | } |
46 | |
47 | template<class T> |
48 | inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) { |
49 | return map[i] = value; |
50 | } |
51 | |
52 | /* |
53 | * is constant |
54 | */ |
55 | |
56 | std::vector<unsigned> align::get_shapes(ir::value *v) { |
57 | ir::type *ty = v->get_type(); |
58 | if(ty->is_block_ty()) |
59 | return ty->get_block_shapes(); |
60 | else |
61 | return {1}; |
62 | } |
63 | |
64 | std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) { |
65 | auto shapes = get_shapes(x); |
66 | std::vector<cst_info> result(shapes.size(), cst_info{1, 0}); |
67 | for(unsigned n = 0; n < x->get_num_incoming(); n++){ |
68 | ir::value* inc = x->get_incoming_value(n); |
69 | auto it = is_constant_.find(inc); |
70 | if(it != is_constant_.end()) |
71 | result = it->second; |
72 | } |
73 | return add_to_cache(x, result, is_constant_); |
74 | // recurse |
75 | for(unsigned n = 0; n < x->get_num_incoming(); n++){ |
76 | ir::value* inc = x->get_incoming_value(n); |
77 | auto cst = populate_is_constant(inc); |
78 | for(size_t d = 0; d < cst.size(); d++) |
79 | result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst); |
80 | } |
81 | return add_to_cache(x, result, is_constant_); |
82 | } |
83 | |
84 | std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) { |
85 | auto shapes = get_shapes(x); |
86 | ir::value* op = x->get_operand(0); |
87 | std::vector<cst_info> result; |
88 | auto op_cst = populate_is_constant(op); |
89 | for(auto d: shapes) |
90 | result.push_back(cst_info{d, op_cst[0].value}); |
91 | return add_to_cache(x, result, is_constant_); |
92 | } |
93 | |
94 | std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) { |
95 | auto x_shapes = get_shapes(x); |
96 | std::vector<cst_info> result; |
97 | ir::value *op = x->get_operand(0); |
98 | auto op_shapes = op->get_type()->get_block_shapes(); |
99 | auto op_cst = populate_is_constant(op); |
100 | unsigned current = 0; |
101 | bool is_skewed = false; |
102 | for(size_t d = 0; d < x_shapes.size(); d ++){ |
103 | cst_info ax ; |
104 | if(x_shapes[d] == 1) |
105 | ax = {1, op_cst[current].value}; |
106 | else if(!is_skewed |
107 | && x_shapes[d] == op_shapes[current]) |
108 | ax = {x_shapes[d], op_cst[current++].value}; |
109 | else { |
110 | is_skewed = true; |
111 | ax = {x_shapes[d], 0}; |
112 | } |
113 | result.push_back(ax); |
114 | } |
115 | return add_to_cache(x, result, is_constant_); |
116 | } |
117 | |
118 | std::vector<align::cst_info> align::populate_is_constant_dequantize(ir::dequantize_inst* x) { |
119 | auto x_shapes = get_shapes(x); |
120 | std::vector<cst_info> result; |
121 | ir::value *op = x->get_operand(0); |
122 | auto op_shapes = op->get_type()->get_block_shapes(); |
123 | auto op_cst = populate_is_constant(op); |
124 | for(size_t d = 0; d < x_shapes.size(); d++) { |
125 | result.push_back(op_cst[d]); |
126 | } |
127 | return add_to_cache(x, result, is_constant_); |
128 | } |
129 | |
130 | std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) { |
131 | auto x_shapes = get_shapes(x); |
132 | std::vector<cst_info> result; |
133 | ir::value *op = x->get_operand(0); |
134 | auto op_shapes = op->get_type()->get_block_shapes(); |
135 | auto op_cst = populate_is_constant(op); |
136 | for(size_t d = 0; d < x_shapes.size(); d++) |
137 | if(op_shapes[d] == 1) |
138 | result.push_back(cst_info{x_shapes[d], op_cst[d].value}); |
139 | else |
140 | result.push_back(op_cst[d]); |
141 | return add_to_cache(x, result, is_constant_); |
142 | } |
143 | |
144 | std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) { |
145 | auto x_shapes = get_shapes(x); |
146 | std::vector<cst_info> result; |
147 | ir::value* lhs_op = x->get_operand(0); |
148 | ir::value* rhs_op = x->get_operand(1); |
149 | auto lhs = populate_is_constant(lhs_op); |
150 | auto rhs = populate_is_constant(rhs_op); |
151 | auto lhs_max_contiguous = populate_max_contiguous(lhs_op); |
152 | auto rhs_max_contiguous = populate_max_contiguous(rhs_op); |
153 | auto lhs_multiple_of = populate_starting_multiple(lhs_op); |
154 | auto rhs_multiple_of = populate_starting_multiple(rhs_op); |
155 | for(size_t d = 0; d < x_shapes.size(); d++) { |
156 | cst_info ax = {1, 0}; |
157 | // Examples: |
158 | // 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8 |
159 | // 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4 |
160 | // 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16 |
161 | // |
162 | // if LHS is a range of N continuous (or equal) elements that starts at M, |
163 | // and RHS is a set of N constants that start at K |
164 | // then the result in constant in groups of gcd(M, K) |
165 | if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 || |
166 | rhs[d].num_cst % lhs[d].num_cst == 0) |
167 | ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]); |
168 | result.push_back(ax); |
169 | } |
170 | return add_to_cache(x, result, is_constant_); |
171 | } |
172 | |
173 | |
174 | std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) { |
175 | auto x_shapes = get_shapes(x); |
176 | std::vector<cst_info> result; |
177 | ir::value* lhs_op = x->get_operand(0); |
178 | ir::value* rhs_op = x->get_operand(1); |
179 | auto lhs = populate_is_constant(lhs_op); |
180 | auto rhs = populate_is_constant(rhs_op); |
181 | auto lhs_max_contiguous = populate_max_contiguous(lhs_op); |
182 | auto rhs_max_contiguous = populate_max_contiguous(rhs_op); |
183 | auto lhs_multiple_of = populate_starting_multiple(lhs_op); |
184 | auto rhs_multiple_of = populate_starting_multiple(rhs_op); |
185 | for(size_t d = 0; d < x_shapes.size(); d++) { |
186 | cst_info ax; |
187 | if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){ |
188 | unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value); |
189 | ax = {num_constants, 0}; |
190 | } |
191 | else |
192 | ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0}; |
193 | result.push_back(ax); |
194 | } |
195 | return add_to_cache(x, result, is_constant_); |
196 | } |
197 | |
198 | std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) { |
199 | auto x_shapes = get_shapes(x); |
200 | ir::value* lhs_op = x->get_operand(0); |
201 | ir::value* rhs_op = x->get_operand(1); |
202 | auto lhs = populate_is_constant(lhs_op); |
203 | auto rhs = populate_is_constant(rhs_op); |
204 | std::vector<cst_info> result; |
205 | for(size_t d = 0; d < x_shapes.size(); d++) |
206 | result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0}); |
207 | return add_to_cache(x, result, is_constant_); |
208 | } |
209 | |
210 | std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) { |
211 | auto shapes = get_shapes(v); |
212 | std::vector<cst_info> result(shapes.size(), {1, 0}); |
213 | return add_to_cache(v, result, is_constant_); |
214 | } |
215 | |
216 | std::vector<align::cst_info> align::populate_is_constant(ir::value *v) { |
217 | if(is_constant_.find(v) != is_constant_.end()) |
218 | return is_constant_.at(v); |
219 | if(auto *x = dynamic_cast<ir::constant_int*>(v)) |
220 | return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_); |
221 | if(auto *x = dynamic_cast<ir::phi_node*>(v)) |
222 | return populate_is_constant_phi(x); |
223 | if(auto *x = dynamic_cast<ir::splat_inst*>(v)) |
224 | return populate_is_constant_splat(x); |
225 | if(auto *x = dynamic_cast<ir::reshape_inst*>(v)) |
226 | return populate_is_constant_reshape(x); |
227 | if(auto *x = dynamic_cast<ir::dequantize_inst*>(v)) |
228 | return populate_is_constant_dequantize(x); |
229 | if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)) |
230 | return populate_is_constant_broadcast(x); |
231 | if(auto *x = dynamic_cast<ir::binary_operator*>(v)) |
232 | return populate_is_constant_binop(x); |
233 | if(auto *x = dynamic_cast<ir::cmp_inst*>(v)) |
234 | return populate_is_constant_cmp(x); |
235 | if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)) |
236 | return populate_is_constant_gep(x); |
237 | return populate_is_constant_default(v); |
238 | } |
239 | |
240 | |
241 | /* |
242 | * max contiguous |
243 | */ |
244 | |
245 | std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) { |
246 | auto shapes = get_shapes(x); |
247 | std::vector<unsigned> result(shapes.size(), 1); |
248 | for(unsigned n = 0; n < x->get_num_incoming(); n++){ |
249 | ir::value* inc = x->get_incoming_value(n); |
250 | auto it = max_contiguous_.find(inc); |
251 | if(it != max_contiguous_.end()) |
252 | result = it->second; |
253 | } |
254 | add_to_cache(x, result, max_contiguous_); |
255 | // recurse |
256 | for(unsigned n = 0; n < x->get_num_incoming(); n++){ |
257 | ir::value* inc = x->get_incoming_value(n); |
258 | auto contiguous = populate_max_contiguous(inc); |
259 | for(size_t d = 0; d < result.size(); d++) |
260 | result[d] = std::min(result[d], contiguous[d]); |
261 | } |
262 | return add_to_cache(x, result, max_contiguous_); |
263 | |
264 | } |
265 | |
266 | std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) { |
267 | auto x_shapes = get_shapes(x); |
268 | std::vector<unsigned> result; |
269 | for(size_t d = 0; d < x_shapes.size(); d++) |
270 | result.push_back({1}); |
271 | return add_to_cache(x, result, max_contiguous_); |
272 | } |
273 | |
274 | std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) { |
275 | auto shapes = get_shapes(x); |
276 | std::vector<unsigned> result; |
277 | ir::value *op = x->get_operand(0); |
278 | auto op_shapes = op->get_type()->get_block_shapes(); |
279 | auto op_mc = populate_max_contiguous(op); |
280 | unsigned current = 0; |
281 | bool is_skewed = false; |
282 | for(size_t d = 0; d < shapes.size(); d ++){ |
283 | if(shapes[d] == 1) |
284 | result.push_back(1); |
285 | else if(!is_skewed |
286 | && shapes[d] == op_shapes[current]) |
287 | result.push_back(op_mc[current++]); |
288 | else { |
289 | is_skewed = true; |
290 | result.push_back(1); |
291 | } |
292 | } |
293 | return add_to_cache(x, result, max_contiguous_); |
294 | } |
295 | |
296 | std::vector<unsigned> align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) { |
297 | auto shapes = get_shapes(x); |
298 | std::vector<unsigned> result; |
299 | ir::value *op = x->get_operand(0); |
300 | auto ret_last_dim = (x->get_type()->get_block_shapes()).back(); |
301 | auto op_last_dim = (op->get_type()->get_block_shapes()).back(); |
302 | auto op_mc = populate_max_contiguous(op); |
303 | for(size_t d = 0; d < shapes.size(); d++) { |
304 | unsigned factor = 1; |
305 | if (d == shapes.size() - 1) { |
306 | factor = ret_last_dim / op_last_dim; |
307 | } |
308 | result.push_back(factor * op_mc[d]); |
309 | } |
310 | return add_to_cache(x, result, max_contiguous_); |
311 | } |
312 | |
313 | std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) { |
314 | auto shapes = get_shapes(x); |
315 | std::vector<unsigned> result; |
316 | ir::value *op = x->get_operand(0); |
317 | auto op_shapes = op->get_type()->get_block_shapes(); |
318 | auto op_mc = populate_max_contiguous(op); |
319 | for(size_t d = 0; d < shapes.size(); d++) |
320 | if(op_shapes[d] == 1) |
321 | result.push_back(1); |
322 | else |
323 | result.push_back(op_mc[d]); |
324 | return add_to_cache(x, result, max_contiguous_); |
325 | } |
326 | |
327 | std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) { |
328 | auto shapes = get_shapes(x); |
329 | ir::value* lhs = x->get_operand(0); |
330 | ir::value* rhs = x->get_operand(1); |
331 | auto lhs_max_contiguous = populate_max_contiguous(lhs); |
332 | auto rhs_max_contiguous = populate_max_contiguous(rhs); |
333 | auto lhs_cst_info = populate_is_constant(lhs); |
334 | auto rhs_cst_info = populate_is_constant(rhs); |
335 | auto lhs_starting_multiple = populate_starting_multiple(lhs); |
336 | auto rhs_starting_multiple = populate_starting_multiple(rhs); |
337 | std::vector<unsigned> result; |
338 | for(size_t d = 0; d < shapes.size(); d++){ |
339 | unsigned value = 1; |
340 | if(x->is_int_rem() && rhs_starting_multiple[d] > 0){ |
341 | value = std::min(lhs_max_contiguous[d], rhs_starting_multiple[d]); |
342 | } |
343 | if(x->is_int_mult()){ |
344 | unsigned lvalue = 1, rvalue = 1; |
345 | if(rhs_cst_info[d].value == 1) |
346 | lvalue = lhs_max_contiguous[d]; |
347 | if(lhs_cst_info[d].value == 1) |
348 | rvalue = rhs_max_contiguous[d]; |
349 | value = std::max(lvalue, rvalue); |
350 | } |
351 | if(x->is_int_add_sub()){ |
352 | unsigned lvalue = 1, rvalue = 1; |
353 | lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst); |
354 | rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst); |
355 | value = std::max(lvalue, rvalue); |
356 | } |
357 | result.push_back(value); |
358 | } |
359 | return add_to_cache(x, result, max_contiguous_); |
360 | } |
361 | |
362 | std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) { |
363 | auto shapes = get_shapes(x); |
364 | ir::value* lhs = x->get_operand(0); |
365 | ir::value* rhs = x->get_operand(1); |
366 | auto lhs_max_contiguous = populate_max_contiguous(lhs); |
367 | auto rhs_max_contiguous = populate_max_contiguous(rhs); |
368 | auto lhs_cst_info = populate_is_constant(lhs); |
369 | auto rhs_cst_info = populate_is_constant(rhs); |
370 | std::vector<unsigned> result(shapes.size(), 1); |
371 | for(size_t d = 0; d < shapes.size(); d++){ |
372 | unsigned lvalue = 1, rvalue = 1; |
373 | if(lhs_cst_info[d].num_cst) |
374 | lvalue = rhs_max_contiguous[d]; |
375 | if(rhs_cst_info[d].num_cst) |
376 | rvalue = lhs_max_contiguous[d]; |
377 | result[d] = std::max(lvalue, rvalue); |
378 | } |
379 | return add_to_cache(x, result, max_contiguous_); |
380 | } |
381 | |
382 | std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) { |
383 | if(!v->get_type()->is_block_ty()) |
384 | return add_to_cache(v, {1}, max_contiguous_); |
385 | auto shapes = v->get_type()->get_block_shapes(); |
386 | if(dynamic_cast<ir::make_range*>(v)) |
387 | return add_to_cache(v, {shapes[0]}, max_contiguous_); |
388 | return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_); |
389 | } |
390 | |
391 | std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){ |
392 | auto result = populate_max_contiguous(v->get_operand(0)); |
393 | return add_to_cache(v, result, max_contiguous_); |
394 | } |
395 | |
396 | std::vector<unsigned> align::populate_max_contiguous(ir::value *v){ |
397 | if(max_contiguous_.find(v) != max_contiguous_.end()) |
398 | return max_contiguous_.at(v); |
399 | if(auto *x = dynamic_cast<ir::instruction*>(v)){ |
400 | std::vector<unsigned> max_contiguous = x->get_metadata(ir::metadata::max_contiguous); |
401 | if(!max_contiguous.empty()) |
402 | return add_to_cache(x, max_contiguous, max_contiguous_); |
403 | } |
404 | if(auto *x = dynamic_cast<ir::cast_inst*>(v)) |
405 | return populate_max_contiguous_cast(x); |
406 | if(auto *x = dynamic_cast<ir::splat_inst*>(v)) |
407 | return populate_max_contiguous_splat(x); |
408 | if(auto *x = dynamic_cast<ir::reshape_inst*>(v)) |
409 | return populate_max_contiguous_reshape(x); |
410 | if(auto *x = dynamic_cast<ir::dequantize_inst*>(v)) |
411 | return populate_max_contiguous_dequantize(x); |
412 | if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)) |
413 | return populate_max_contiguous_broadcast(x); |
414 | if(auto *x = dynamic_cast<ir::binary_operator*>(v)) |
415 | return populate_max_contiguous_binop(x); |
416 | if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)) |
417 | return populate_max_contiguous_gep(x); |
418 | if(auto *x = dynamic_cast<ir::phi_node*>(v)) |
419 | return populate_max_contiguous_phi(x); |
420 | return populate_max_contiguous_default(v); |
421 | } |
422 | |
423 | |
424 | /* |
425 | * starting multiple |
426 | */ |
427 | |
428 | std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){ |
429 | auto shapes = get_shapes(x); |
430 | auto op = populate_starting_multiple(x->get_operand(0)); |
431 | std::vector<unsigned> result(shapes.size(), op[0]); |
432 | return add_to_cache(x, result, starting_multiple_); |
433 | } |
434 | |
435 | std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){ |
436 | auto op = populate_starting_multiple(x->get_operand(0)); |
437 | auto op_shapes = get_shapes(x->get_operand(0)); |
438 | auto shapes = get_shapes(x); |
439 | std::vector<unsigned> result(shapes.size(), 1); |
440 | unsigned current = 0; |
441 | bool is_skewed = false; |
442 | for(size_t d = 0; d < shapes.size(); d ++){ |
443 | if(shapes[d] == 1) |
444 | result[d] = 1; |
445 | else if(!is_skewed |
446 | && shapes[d] == op_shapes[current]) |
447 | result[d] = op[current++]; |
448 | else { |
449 | is_skewed = true; |
450 | result[d] = 1; |
451 | } |
452 | } |
453 | return add_to_cache(x, result, starting_multiple_); |
454 | } |
455 | |
456 | std::vector<unsigned> align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){ |
457 | auto shapes = get_shapes(x); |
458 | std::vector<unsigned> result; |
459 | ir::value *op = x->get_operand(0); |
460 | auto ret_last_dim = (x->get_type()->get_block_shapes()).back(); |
461 | auto op_last_dim = (op->get_type()->get_block_shapes()).back(); |
462 | auto op_multiple = populate_starting_multiple(op); |
463 | for(size_t d = 0; d < shapes.size(); d++) { |
464 | unsigned factor = 1; |
465 | if (d == shapes.size() - 1) { |
466 | factor = ret_last_dim / op_last_dim; |
467 | } |
468 | result.push_back(factor * op_multiple[d]); |
469 | } |
470 | return add_to_cache(x, result, starting_multiple_); |
471 | } |
472 | |
473 | std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){ |
474 | auto result = populate_starting_multiple(x->get_operand(0)); |
475 | return add_to_cache(x, result, starting_multiple_); |
476 | } |
477 | |
478 | std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){ |
479 | auto lhs = populate_starting_multiple(x->get_operand(0)); |
480 | auto rhs = populate_starting_multiple(x->get_operand(1)); |
481 | std::vector<unsigned> result(lhs.size(), 1); |
482 | for(size_t d = 0; d < lhs.size(); d++){ |
483 | if(x->is_int_mult()) |
484 | result[d] = lhs[d] * rhs[d]; |
485 | if(x->is_int_add_sub()) |
486 | result[d] = gcd(lhs[d], rhs[d]); |
487 | if(x->is_int_div()) |
488 | result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1; |
489 | if(x->is_int_rem() && rhs[d] > 1){ |
490 | result[d] = gcd(lhs[d], rhs[d]); |
491 | } |
492 | if(x->is_shl()) |
493 | result[d] = lhs[d] << rhs[d]; |
494 | if(x->is_shr()) |
495 | result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1); |
496 | } |
497 | return add_to_cache(x, result, starting_multiple_); |
498 | } |
499 | |
500 | std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){ |
501 | auto lhs = populate_starting_multiple(x->get_operand(0)); |
502 | auto rhs = populate_starting_multiple(x->get_operand(1)); |
503 | std::vector<unsigned> result(lhs.size(), 1); |
504 | for(size_t d = 0; d < lhs.size(); d++){ |
505 | result[d] = gcd(lhs[d], rhs[d]); |
506 | // std::cout << "starting multiple: " << x->get_name() << " " << d << " " << result[d] << std::endl; |
507 | } |
508 | return add_to_cache(x, result, starting_multiple_); |
509 | } |
510 | |
511 | std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){ |
512 | auto shape = get_shapes(x); |
513 | std::vector<unsigned> result(shape.size(), 1); |
514 | for(unsigned n = 0; n < x->get_num_incoming(); n++){ |
515 | ir::value* inc = x->get_incoming_value(n); |
516 | if(starting_multiple_.find(inc) != starting_multiple_.end()) |
517 | result = starting_multiple_.at(inc); |
518 | } |
519 | add_to_cache(x, result, starting_multiple_); |
520 | // recurse |
521 | for(unsigned n = 0; n < x->get_num_incoming(); n++){ |
522 | ir::value* inc = x->get_incoming_value(n); |
523 | auto sm = populate_starting_multiple(inc); |
524 | for(size_t d = 0; d < result.size(); d++) |
525 | result[d] = gcd(result[d], sm[d]); |
526 | } |
527 | return add_to_cache(x, result, starting_multiple_); |
528 | } |
529 | |
530 | |
531 | std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){ |
532 | auto result = populate_starting_multiple(x->get_operand(0)); |
533 | return add_to_cache(x, result, starting_multiple_); |
534 | } |
535 | |
536 | std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) { |
537 | ir::type* ty = v->get_type(); |
538 | if(ty->is_block_ty()) { |
539 | return add_to_cache(v, ty->get_block_shapes(), starting_multiple_); |
540 | } |
541 | if(auto *x = dynamic_cast<ir::argument*>(v)){ |
542 | std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x); |
543 | for(auto attr: attributes){ |
544 | if(attr.get_kind() == ir::multiple_of){ |
545 | return add_to_cache(x, {attr.get_value()}, starting_multiple_); |
546 | } |
547 | if(attr.get_kind() == ir::aligned){ |
548 | ir::type* ty = x->get_type()->get_pointer_element_ty(); |
549 | int nbits = ty->get_primitive_size_in_bits(); |
550 | int nbytes = std::max<int>(nbits / 8, 1); |
551 | return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_); |
552 | } |
553 | } |
554 | } |
555 | return add_to_cache(v, {1}, starting_multiple_); |
556 | } |
557 | |
// Returns the largest power of two, capped at 128, that divides `val`.
// Zero is divisible by everything and is reported as 2^31.
// The 2^31 marker is built with an unsigned shift; the former `1 << 31`
// shifted a signed 1 into the sign bit.
unsigned get_max_multiple(int val){
  if(val == 0) return 1u << 31;
  // the candidate stays signed so that `%` behaves for negative `val`
  for(int pow2 = 128; pow2 > 1; pow2 /= 2)
    if(val % pow2 == 0) return pow2;
  return 1;
}
569 | |
570 | std::vector<unsigned> align::populate_starting_multiple(ir::value *v){ |
571 | if(starting_multiple_.find(v) != starting_multiple_.end()) |
572 | return starting_multiple_.at(v); |
573 | if(auto *x = dynamic_cast<ir::instruction*>(v)){ |
574 | std::vector<unsigned> multiple_of = x->get_metadata(ir::metadata::multiple_of); |
575 | if(!multiple_of.empty()) |
576 | return add_to_cache(x, multiple_of, starting_multiple_); |
577 | } |
578 | if(auto *x = dynamic_cast<ir::cast_inst*>(v)) |
579 | return populate_starting_multiple_cast(x); |
580 | if(auto *x = dynamic_cast<ir::binary_operator*>(v)) |
581 | return populate_starting_multiple_binop(x); |
582 | if(auto *x = dynamic_cast<ir::constant_int*>(v)) |
583 | return add_to_cache(x, {get_max_multiple(x->get_value())}, starting_multiple_); |
584 | if(auto *x = dynamic_cast<ir::make_range*>(v)) |
585 | return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_); |
586 | if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v)) |
587 | return populate_starting_multiple_gep(x); |
588 | if(auto *x = dynamic_cast<ir::splat_inst*>(v)) |
589 | return populate_starting_multiple_splat(x); |
590 | if(auto *x = dynamic_cast<ir::reshape_inst*>(v)) |
591 | return populate_starting_multiple_reshape(x); |
592 | if(auto *x = dynamic_cast<ir::dequantize_inst*>(v)) |
593 | return populate_starting_multiple_dequantize(x); |
594 | if(auto *x = dynamic_cast<ir::broadcast_inst*>(v)) |
595 | return populate_starting_multiple_broadcast(x); |
596 | if(auto *x = dynamic_cast<ir::phi_node*>(v)) |
597 | return populate_starting_multiple_phi(x); |
598 | return populate_starting_multiple_default(v); |
599 | } |
600 | |
601 | |
602 | unsigned align::get(ir::value *v, unsigned ax) const { |
603 | unsigned starting_multiple = starting_multiple_.at(v)[ax]; |
604 | unsigned max_contiguous = max_contiguous_.at(v)[ax]; |
605 | return std::min(starting_multiple, max_contiguous); |
606 | } |
607 | |
608 | std::vector<unsigned> align::contiguous(ir::value* v) const { |
609 | return max_contiguous_.at(v); |
610 | } |
611 | |
612 | std::vector<align::cst_info> align::get_cst_info(ir::value* v) const { |
613 | return is_constant_.at(v); |
614 | } |
615 | |
616 | |
617 | void align::populate(ir::value *v) { |
618 | populate_is_constant(v); |
619 | populate_starting_multiple(v); |
620 | populate_max_contiguous(v); |
621 | } |
622 | |
623 | void align::run(ir::module &mod) { |
624 | ir::for_each_value(mod, [this](ir::value* v) { populate(v); } ); |
625 | // ir::for_each_value(mod, [this](ir::value* v) { |
626 | // if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v)) |
627 | // std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl; |
628 | // }); |
629 | } |
630 | |
631 | |
632 | } |
633 | } |
634 | } |
635 | |