1#include "triton/codegen/analysis/align.h"
2#include "triton/ir/utils.h"
3#include "triton/ir/module.h"
4#include "triton/ir/function.h"
5#include "triton/ir/basic_block.h"
6#include "triton/ir/instructions.h"
7#include "triton/ir/type.h"
8#include <iostream>
9
10namespace triton {
11namespace codegen{
12namespace analysis{
13
14
// Extended Euclidean algorithm.
// Returns the greatest common divisor g of `a` and `b`, and stores in `*x`
// and `*y` coefficients satisfying a * (*x) + b * (*y) == g.
int gcd_impl(int a, int b, int *x, int *y)
{
  if (a != 0) {
    // Solve the smaller problem (b mod a, a) first ...
    int sub_x = 0;
    int sub_y = 0;
    const int divisor = gcd_impl(b % a, a, &sub_x, &sub_y);

    // ... then lift its coefficients back to (a, b).
    *x = sub_y - (b / a) * sub_x;
    *y = sub_x;
    return divisor;
  }

  // Terminal case: gcd(0, b) == b == a * 0 + b * 1.
  *x = 0;
  *y = 1;
  return b;
}
36
37int gcd(int a, int b) {
38 int x, y;
39 return gcd_impl(a, b, &x, &y);
40}
41
42
43inline int lcm(int a, int b) {
44 return (a * b) / gcd(a, b);
45}
46
47template<class T>
48inline T add_to_cache(ir::value *i, T value, std::map<ir::value*, T> &map) {
49 return map[i] = value;
50}
51
52/*
53 * is constant
54 */
55
56std::vector<unsigned> align::get_shapes(ir::value *v) {
57 ir::type *ty = v->get_type();
58 if(ty->is_block_ty())
59 return ty->get_block_shapes();
60 else
61 return {1};
62}
63
64std::vector<align::cst_info> align::populate_is_constant_phi(ir::phi_node* x) {
65 auto shapes = get_shapes(x);
66 std::vector<cst_info> result(shapes.size(), cst_info{1, 0});
67 for(unsigned n = 0; n < x->get_num_incoming(); n++){
68 ir::value* inc = x->get_incoming_value(n);
69 auto it = is_constant_.find(inc);
70 if(it != is_constant_.end())
71 result = it->second;
72 }
73 return add_to_cache(x, result, is_constant_);
74 // recurse
75 for(unsigned n = 0; n < x->get_num_incoming(); n++){
76 ir::value* inc = x->get_incoming_value(n);
77 auto cst = populate_is_constant(inc);
78 for(size_t d = 0; d < cst.size(); d++)
79 result[d].num_cst = std::min(result[d].num_cst, cst[d].num_cst);
80 }
81 return add_to_cache(x, result, is_constant_);
82}
83
84std::vector<align::cst_info> align::populate_is_constant_splat(ir::splat_inst* x) {
85 auto shapes = get_shapes(x);
86 ir::value* op = x->get_operand(0);
87 std::vector<cst_info> result;
88 auto op_cst = populate_is_constant(op);
89 for(auto d: shapes)
90 result.push_back(cst_info{d, op_cst[0].value});
91 return add_to_cache(x, result, is_constant_);
92}
93
94std::vector<align::cst_info> align::populate_is_constant_reshape(ir::reshape_inst* x) {
95 auto x_shapes = get_shapes(x);
96 std::vector<cst_info> result;
97 ir::value *op = x->get_operand(0);
98 auto op_shapes = op->get_type()->get_block_shapes();
99 auto op_cst = populate_is_constant(op);
100 unsigned current = 0;
101 bool is_skewed = false;
102 for(size_t d = 0; d < x_shapes.size(); d ++){
103 cst_info ax ;
104 if(x_shapes[d] == 1)
105 ax = {1, op_cst[current].value};
106 else if(!is_skewed
107 && x_shapes[d] == op_shapes[current])
108 ax = {x_shapes[d], op_cst[current++].value};
109 else {
110 is_skewed = true;
111 ax = {x_shapes[d], 0};
112 }
113 result.push_back(ax);
114 }
115 return add_to_cache(x, result, is_constant_);
116}
117
118std::vector<align::cst_info> align::populate_is_constant_dequantize(ir::dequantize_inst* x) {
119 auto x_shapes = get_shapes(x);
120 std::vector<cst_info> result;
121 ir::value *op = x->get_operand(0);
122 auto op_shapes = op->get_type()->get_block_shapes();
123 auto op_cst = populate_is_constant(op);
124 for(size_t d = 0; d < x_shapes.size(); d++) {
125 result.push_back(op_cst[d]);
126 }
127 return add_to_cache(x, result, is_constant_);
128}
129
130std::vector<align::cst_info> align::populate_is_constant_broadcast(ir::broadcast_inst* x) {
131 auto x_shapes = get_shapes(x);
132 std::vector<cst_info> result;
133 ir::value *op = x->get_operand(0);
134 auto op_shapes = op->get_type()->get_block_shapes();
135 auto op_cst = populate_is_constant(op);
136 for(size_t d = 0; d < x_shapes.size(); d++)
137 if(op_shapes[d] == 1)
138 result.push_back(cst_info{x_shapes[d], op_cst[d].value});
139 else
140 result.push_back(op_cst[d]);
141 return add_to_cache(x, result, is_constant_);
142}
143
144std::vector<align::cst_info> align::populate_is_constant_cmp(ir::cmp_inst* x) {
145 auto x_shapes = get_shapes(x);
146 std::vector<cst_info> result;
147 ir::value* lhs_op = x->get_operand(0);
148 ir::value* rhs_op = x->get_operand(1);
149 auto lhs = populate_is_constant(lhs_op);
150 auto rhs = populate_is_constant(rhs_op);
151 auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
152 auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
153 auto lhs_multiple_of = populate_starting_multiple(lhs_op);
154 auto rhs_multiple_of = populate_starting_multiple(rhs_op);
155 for(size_t d = 0; d < x_shapes.size(); d++) {
156 cst_info ax = {1, 0};
157 // Examples:
158 // 16 17 18 ... 32 < 24 24 24 ... 24 => equal in groups of 8
159 // 16 17 18 ... 32 < 20 20 20 ... 20 => equal in groups of 4
160 // 16 17 18 ... 32 < 16 16 16 ... 16 => equal in groups of 16
161 //
162 // if LHS is a range of N continuous (or equal) elements that starts at M,
163 // and RHS is a set of N constants that start at K
164 // then the result in constant in groups of gcd(M, K)
165 if(rhs[d].num_cst % lhs_max_contiguous[d] == 0 ||
166 rhs[d].num_cst % lhs[d].num_cst == 0)
167 ax.num_cst = gcd(lhs_multiple_of[d], rhs_multiple_of[d]);
168 result.push_back(ax);
169 }
170 return add_to_cache(x, result, is_constant_);
171}
172
173
174std::vector<align::cst_info> align::populate_is_constant_binop(ir::binary_operator* x) {
175 auto x_shapes = get_shapes(x);
176 std::vector<cst_info> result;
177 ir::value* lhs_op = x->get_operand(0);
178 ir::value* rhs_op = x->get_operand(1);
179 auto lhs = populate_is_constant(lhs_op);
180 auto rhs = populate_is_constant(rhs_op);
181 auto lhs_max_contiguous = populate_max_contiguous(lhs_op);
182 auto rhs_max_contiguous = populate_max_contiguous(rhs_op);
183 auto lhs_multiple_of = populate_starting_multiple(lhs_op);
184 auto rhs_multiple_of = populate_starting_multiple(rhs_op);
185 for(size_t d = 0; d < x_shapes.size(); d++) {
186 cst_info ax;
187 if(lhs[d].num_cst==0 && rhs[d].value && x->is_int_div()){
188 unsigned num_constants = gcd(lhs_max_contiguous[d], rhs[d].value);
189 ax = {num_constants, 0};
190 }
191 else
192 ax = {std::min(lhs[d].num_cst, rhs[d].num_cst), 0};
193 result.push_back(ax);
194 }
195 return add_to_cache(x, result, is_constant_);
196}
197
198std::vector<align::cst_info> align::populate_is_constant_gep(ir::getelementptr_inst* x) {
199 auto x_shapes = get_shapes(x);
200 ir::value* lhs_op = x->get_operand(0);
201 ir::value* rhs_op = x->get_operand(1);
202 auto lhs = populate_is_constant(lhs_op);
203 auto rhs = populate_is_constant(rhs_op);
204 std::vector<cst_info> result;
205 for(size_t d = 0; d < x_shapes.size(); d++)
206 result.push_back({std::min(lhs[d].num_cst, rhs[d].num_cst), 0});
207 return add_to_cache(x, result, is_constant_);
208}
209
210std::vector<align::cst_info> align::populate_is_constant_default(ir::value *v) {
211 auto shapes = get_shapes(v);
212 std::vector<cst_info> result(shapes.size(), {1, 0});
213 return add_to_cache(v, result, is_constant_);
214}
215
216std::vector<align::cst_info> align::populate_is_constant(ir::value *v) {
217 if(is_constant_.find(v) != is_constant_.end())
218 return is_constant_.at(v);
219 if(auto *x = dynamic_cast<ir::constant_int*>(v))
220 return add_to_cache(v, {cst_info{true, std::min<unsigned>(x->get_value(), 128)}}, is_constant_);
221 if(auto *x = dynamic_cast<ir::phi_node*>(v))
222 return populate_is_constant_phi(x);
223 if(auto *x = dynamic_cast<ir::splat_inst*>(v))
224 return populate_is_constant_splat(x);
225 if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
226 return populate_is_constant_reshape(x);
227 if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
228 return populate_is_constant_dequantize(x);
229 if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
230 return populate_is_constant_broadcast(x);
231 if(auto *x = dynamic_cast<ir::binary_operator*>(v))
232 return populate_is_constant_binop(x);
233 if(auto *x = dynamic_cast<ir::cmp_inst*>(v))
234 return populate_is_constant_cmp(x);
235 if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
236 return populate_is_constant_gep(x);
237 return populate_is_constant_default(v);
238}
239
240
241/*
242 * max contiguous
243 */
244
245std::vector<unsigned> align::populate_max_contiguous_phi(ir::phi_node* x) {
246 auto shapes = get_shapes(x);
247 std::vector<unsigned> result(shapes.size(), 1);
248 for(unsigned n = 0; n < x->get_num_incoming(); n++){
249 ir::value* inc = x->get_incoming_value(n);
250 auto it = max_contiguous_.find(inc);
251 if(it != max_contiguous_.end())
252 result = it->second;
253 }
254 add_to_cache(x, result, max_contiguous_);
255 // recurse
256 for(unsigned n = 0; n < x->get_num_incoming(); n++){
257 ir::value* inc = x->get_incoming_value(n);
258 auto contiguous = populate_max_contiguous(inc);
259 for(size_t d = 0; d < result.size(); d++)
260 result[d] = std::min(result[d], contiguous[d]);
261 }
262 return add_to_cache(x, result, max_contiguous_);
263
264}
265
266std::vector<unsigned> align::populate_max_contiguous_splat(ir::splat_inst* x) {
267 auto x_shapes = get_shapes(x);
268 std::vector<unsigned> result;
269 for(size_t d = 0; d < x_shapes.size(); d++)
270 result.push_back({1});
271 return add_to_cache(x, result, max_contiguous_);
272}
273
274std::vector<unsigned> align::populate_max_contiguous_reshape(ir::reshape_inst* x) {
275 auto shapes = get_shapes(x);
276 std::vector<unsigned> result;
277 ir::value *op = x->get_operand(0);
278 auto op_shapes = op->get_type()->get_block_shapes();
279 auto op_mc = populate_max_contiguous(op);
280 unsigned current = 0;
281 bool is_skewed = false;
282 for(size_t d = 0; d < shapes.size(); d ++){
283 if(shapes[d] == 1)
284 result.push_back(1);
285 else if(!is_skewed
286 && shapes[d] == op_shapes[current])
287 result.push_back(op_mc[current++]);
288 else {
289 is_skewed = true;
290 result.push_back(1);
291 }
292 }
293 return add_to_cache(x, result, max_contiguous_);
294}
295
296std::vector<unsigned> align::populate_max_contiguous_dequantize(ir::dequantize_inst* x) {
297 auto shapes = get_shapes(x);
298 std::vector<unsigned> result;
299 ir::value *op = x->get_operand(0);
300 auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
301 auto op_last_dim = (op->get_type()->get_block_shapes()).back();
302 auto op_mc = populate_max_contiguous(op);
303 for(size_t d = 0; d < shapes.size(); d++) {
304 unsigned factor = 1;
305 if (d == shapes.size() - 1) {
306 factor = ret_last_dim / op_last_dim;
307 }
308 result.push_back(factor * op_mc[d]);
309 }
310 return add_to_cache(x, result, max_contiguous_);
311}
312
313std::vector<unsigned> align::populate_max_contiguous_broadcast(ir::broadcast_inst* x) {
314 auto shapes = get_shapes(x);
315 std::vector<unsigned> result;
316 ir::value *op = x->get_operand(0);
317 auto op_shapes = op->get_type()->get_block_shapes();
318 auto op_mc = populate_max_contiguous(op);
319 for(size_t d = 0; d < shapes.size(); d++)
320 if(op_shapes[d] == 1)
321 result.push_back(1);
322 else
323 result.push_back(op_mc[d]);
324 return add_to_cache(x, result, max_contiguous_);
325}
326
327std::vector<unsigned> align::populate_max_contiguous_binop(ir::binary_operator* x) {
328 auto shapes = get_shapes(x);
329 ir::value* lhs = x->get_operand(0);
330 ir::value* rhs = x->get_operand(1);
331 auto lhs_max_contiguous = populate_max_contiguous(lhs);
332 auto rhs_max_contiguous = populate_max_contiguous(rhs);
333 auto lhs_cst_info = populate_is_constant(lhs);
334 auto rhs_cst_info = populate_is_constant(rhs);
335 auto lhs_starting_multiple = populate_starting_multiple(lhs);
336 auto rhs_starting_multiple = populate_starting_multiple(rhs);
337 std::vector<unsigned> result;
338 for(size_t d = 0; d < shapes.size(); d++){
339 unsigned value = 1;
340 if(x->is_int_rem() && rhs_starting_multiple[d] > 0){
341 value = std::min(lhs_max_contiguous[d], rhs_starting_multiple[d]);
342 }
343 if(x->is_int_mult()){
344 unsigned lvalue = 1, rvalue = 1;
345 if(rhs_cst_info[d].value == 1)
346 lvalue = lhs_max_contiguous[d];
347 if(lhs_cst_info[d].value == 1)
348 rvalue = rhs_max_contiguous[d];
349 value = std::max(lvalue, rvalue);
350 }
351 if(x->is_int_add_sub()){
352 unsigned lvalue = 1, rvalue = 1;
353 lvalue = gcd(rhs_max_contiguous[d], lhs_cst_info[d].num_cst);
354 rvalue = gcd(lhs_max_contiguous[d], rhs_cst_info[d].num_cst);
355 value = std::max(lvalue, rvalue);
356 }
357 result.push_back(value);
358 }
359 return add_to_cache(x, result, max_contiguous_);
360}
361
362std::vector<unsigned> align::populate_max_contiguous_gep(ir::getelementptr_inst* x) {
363 auto shapes = get_shapes(x);
364 ir::value* lhs = x->get_operand(0);
365 ir::value* rhs = x->get_operand(1);
366 auto lhs_max_contiguous = populate_max_contiguous(lhs);
367 auto rhs_max_contiguous = populate_max_contiguous(rhs);
368 auto lhs_cst_info = populate_is_constant(lhs);
369 auto rhs_cst_info = populate_is_constant(rhs);
370 std::vector<unsigned> result(shapes.size(), 1);
371 for(size_t d = 0; d < shapes.size(); d++){
372 unsigned lvalue = 1, rvalue = 1;
373 if(lhs_cst_info[d].num_cst)
374 lvalue = rhs_max_contiguous[d];
375 if(rhs_cst_info[d].num_cst)
376 rvalue = lhs_max_contiguous[d];
377 result[d] = std::max(lvalue, rvalue);
378 }
379 return add_to_cache(x, result, max_contiguous_);
380}
381
382std::vector<unsigned> align::populate_max_contiguous_default(ir::value* v) {
383 if(!v->get_type()->is_block_ty())
384 return add_to_cache(v, {1}, max_contiguous_);
385 auto shapes = v->get_type()->get_block_shapes();
386 if(dynamic_cast<ir::make_range*>(v))
387 return add_to_cache(v, {shapes[0]}, max_contiguous_);
388 return add_to_cache(v, std::vector<unsigned>(shapes.size(), 1), max_contiguous_);
389}
390
391std::vector<unsigned> align::populate_max_contiguous_cast(ir::cast_inst* v){
392 auto result = populate_max_contiguous(v->get_operand(0));
393 return add_to_cache(v, result, max_contiguous_);
394}
395
396std::vector<unsigned> align::populate_max_contiguous(ir::value *v){
397 if(max_contiguous_.find(v) != max_contiguous_.end())
398 return max_contiguous_.at(v);
399 if(auto *x = dynamic_cast<ir::instruction*>(v)){
400 std::vector<unsigned> max_contiguous = x->get_metadata(ir::metadata::max_contiguous);
401 if(!max_contiguous.empty())
402 return add_to_cache(x, max_contiguous, max_contiguous_);
403 }
404 if(auto *x = dynamic_cast<ir::cast_inst*>(v))
405 return populate_max_contiguous_cast(x);
406 if(auto *x = dynamic_cast<ir::splat_inst*>(v))
407 return populate_max_contiguous_splat(x);
408 if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
409 return populate_max_contiguous_reshape(x);
410 if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
411 return populate_max_contiguous_dequantize(x);
412 if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
413 return populate_max_contiguous_broadcast(x);
414 if(auto *x = dynamic_cast<ir::binary_operator*>(v))
415 return populate_max_contiguous_binop(x);
416 if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
417 return populate_max_contiguous_gep(x);
418 if(auto *x = dynamic_cast<ir::phi_node*>(v))
419 return populate_max_contiguous_phi(x);
420 return populate_max_contiguous_default(v);
421}
422
423
424/*
425 * starting multiple
426 */
427
428std::vector<unsigned> align::populate_starting_multiple_splat(ir::splat_inst* x){
429 auto shapes = get_shapes(x);
430 auto op = populate_starting_multiple(x->get_operand(0));
431 std::vector<unsigned> result(shapes.size(), op[0]);
432 return add_to_cache(x, result, starting_multiple_);
433}
434
435std::vector<unsigned> align::populate_starting_multiple_reshape(ir::reshape_inst* x){
436 auto op = populate_starting_multiple(x->get_operand(0));
437 auto op_shapes = get_shapes(x->get_operand(0));
438 auto shapes = get_shapes(x);
439 std::vector<unsigned> result(shapes.size(), 1);
440 unsigned current = 0;
441 bool is_skewed = false;
442 for(size_t d = 0; d < shapes.size(); d ++){
443 if(shapes[d] == 1)
444 result[d] = 1;
445 else if(!is_skewed
446 && shapes[d] == op_shapes[current])
447 result[d] = op[current++];
448 else {
449 is_skewed = true;
450 result[d] = 1;
451 }
452 }
453 return add_to_cache(x, result, starting_multiple_);
454}
455
456std::vector<unsigned> align::populate_starting_multiple_dequantize(ir::dequantize_inst* x){
457 auto shapes = get_shapes(x);
458 std::vector<unsigned> result;
459 ir::value *op = x->get_operand(0);
460 auto ret_last_dim = (x->get_type()->get_block_shapes()).back();
461 auto op_last_dim = (op->get_type()->get_block_shapes()).back();
462 auto op_multiple = populate_starting_multiple(op);
463 for(size_t d = 0; d < shapes.size(); d++) {
464 unsigned factor = 1;
465 if (d == shapes.size() - 1) {
466 factor = ret_last_dim / op_last_dim;
467 }
468 result.push_back(factor * op_multiple[d]);
469 }
470 return add_to_cache(x, result, starting_multiple_);
471}
472
473std::vector<unsigned> align::populate_starting_multiple_broadcast(ir::broadcast_inst* x){
474 auto result = populate_starting_multiple(x->get_operand(0));
475 return add_to_cache(x, result, starting_multiple_);
476}
477
478std::vector<unsigned> align::populate_starting_multiple_binop(ir::binary_operator* x){
479 auto lhs = populate_starting_multiple(x->get_operand(0));
480 auto rhs = populate_starting_multiple(x->get_operand(1));
481 std::vector<unsigned> result(lhs.size(), 1);
482 for(size_t d = 0; d < lhs.size(); d++){
483 if(x->is_int_mult())
484 result[d] = lhs[d] * rhs[d];
485 if(x->is_int_add_sub())
486 result[d] = gcd(lhs[d], rhs[d]);
487 if(x->is_int_div())
488 result[d] = (lhs[d] == (1 << 31)) ? 1 << 31 : 1;
489 if(x->is_int_rem() && rhs[d] > 1){
490 result[d] = gcd(lhs[d], rhs[d]);
491 }
492 if(x->is_shl())
493 result[d] = lhs[d] << rhs[d];
494 if(x->is_shr())
495 result[d] = std::max<unsigned>(lhs[d] >> rhs[d], 1);
496 }
497 return add_to_cache(x, result, starting_multiple_);
498}
499
500std::vector<unsigned> align::populate_starting_multiple_gep(ir::getelementptr_inst* x){
501 auto lhs = populate_starting_multiple(x->get_operand(0));
502 auto rhs = populate_starting_multiple(x->get_operand(1));
503 std::vector<unsigned> result(lhs.size(), 1);
504 for(size_t d = 0; d < lhs.size(); d++){
505 result[d] = gcd(lhs[d], rhs[d]);
506// std::cout << "starting multiple: " << x->get_name() << " " << d << " " << result[d] << std::endl;
507 }
508 return add_to_cache(x, result, starting_multiple_);
509}
510
511std::vector<unsigned> align::populate_starting_multiple_phi(ir::phi_node* x){
512 auto shape = get_shapes(x);
513 std::vector<unsigned> result(shape.size(), 1);
514 for(unsigned n = 0; n < x->get_num_incoming(); n++){
515 ir::value* inc = x->get_incoming_value(n);
516 if(starting_multiple_.find(inc) != starting_multiple_.end())
517 result = starting_multiple_.at(inc);
518 }
519 add_to_cache(x, result, starting_multiple_);
520 // recurse
521 for(unsigned n = 0; n < x->get_num_incoming(); n++){
522 ir::value* inc = x->get_incoming_value(n);
523 auto sm = populate_starting_multiple(inc);
524 for(size_t d = 0; d < result.size(); d++)
525 result[d] = gcd(result[d], sm[d]);
526 }
527 return add_to_cache(x, result, starting_multiple_);
528}
529
530
531std::vector<unsigned> align::populate_starting_multiple_cast(ir::cast_inst* x){
532 auto result = populate_starting_multiple(x->get_operand(0));
533 return add_to_cache(x, result, starting_multiple_);
534}
535
536std::vector<unsigned> align::populate_starting_multiple_default(ir::value* v) {
537 ir::type* ty = v->get_type();
538 if(ty->is_block_ty()) {
539 return add_to_cache(v, ty->get_block_shapes(), starting_multiple_);
540 }
541 if(auto *x = dynamic_cast<ir::argument*>(v)){
542 std::set<ir::attribute> attributes = x->get_parent()->get_attributes(x);
543 for(auto attr: attributes){
544 if(attr.get_kind() == ir::multiple_of){
545 return add_to_cache(x, {attr.get_value()}, starting_multiple_);
546 }
547 if(attr.get_kind() == ir::aligned){
548 ir::type* ty = x->get_type()->get_pointer_element_ty();
549 int nbits = ty->get_primitive_size_in_bits();
550 int nbytes = std::max<int>(nbits / 8, 1);
551 return add_to_cache(x, {attr.get_value() / nbytes}, starting_multiple_);
552 }
553 }
554 }
555 return add_to_cache(v, {1}, starting_multiple_);
556}
557
// Largest power of two, capped at 128, that divides `val`.
// Zero is reported as 2^31; a value with no even divisor yields 1.
unsigned get_max_multiple(int val){
  if(val == 0)
    return 1u << 31;
  // try 128, 64, ..., 2 and return the first divisor found
  for(int candidate = 128; candidate >= 2; candidate /= 2){
    if(val % candidate == 0)
      return static_cast<unsigned>(candidate);
  }
  return 1;
}
569
570std::vector<unsigned> align::populate_starting_multiple(ir::value *v){
571 if(starting_multiple_.find(v) != starting_multiple_.end())
572 return starting_multiple_.at(v);
573 if(auto *x = dynamic_cast<ir::instruction*>(v)){
574 std::vector<unsigned> multiple_of = x->get_metadata(ir::metadata::multiple_of);
575 if(!multiple_of.empty())
576 return add_to_cache(x, multiple_of, starting_multiple_);
577 }
578 if(auto *x = dynamic_cast<ir::cast_inst*>(v))
579 return populate_starting_multiple_cast(x);
580 if(auto *x = dynamic_cast<ir::binary_operator*>(v))
581 return populate_starting_multiple_binop(x);
582 if(auto *x = dynamic_cast<ir::constant_int*>(v))
583 return add_to_cache(x, {get_max_multiple(x->get_value())}, starting_multiple_);
584 if(auto *x = dynamic_cast<ir::make_range*>(v))
585 return add_to_cache(x, {get_max_multiple(x->get_first()->get_value())}, starting_multiple_);
586 if(auto *x = dynamic_cast<ir::getelementptr_inst*>(v))
587 return populate_starting_multiple_gep(x);
588 if(auto *x = dynamic_cast<ir::splat_inst*>(v))
589 return populate_starting_multiple_splat(x);
590 if(auto *x = dynamic_cast<ir::reshape_inst*>(v))
591 return populate_starting_multiple_reshape(x);
592 if(auto *x = dynamic_cast<ir::dequantize_inst*>(v))
593 return populate_starting_multiple_dequantize(x);
594 if(auto *x = dynamic_cast<ir::broadcast_inst*>(v))
595 return populate_starting_multiple_broadcast(x);
596 if(auto *x = dynamic_cast<ir::phi_node*>(v))
597 return populate_starting_multiple_phi(x);
598 return populate_starting_multiple_default(v);
599}
600
601
602unsigned align::get(ir::value *v, unsigned ax) const {
603 unsigned starting_multiple = starting_multiple_.at(v)[ax];
604 unsigned max_contiguous = max_contiguous_.at(v)[ax];
605 return std::min(starting_multiple, max_contiguous);
606}
607
608std::vector<unsigned> align::contiguous(ir::value* v) const {
609 return max_contiguous_.at(v);
610}
611
612std::vector<align::cst_info> align::get_cst_info(ir::value* v) const {
613 return is_constant_.at(v);
614}
615
616
617void align::populate(ir::value *v) {
618 populate_is_constant(v);
619 populate_starting_multiple(v);
620 populate_max_contiguous(v);
621}
622
623void align::run(ir::module &mod) {
624 ir::for_each_value(mod, [this](ir::value* v) { populate(v); } );
625// ir::for_each_value(mod, [this](ir::value* v) {
626// if(dynamic_cast<ir::cast_inst*>(v) || dynamic_cast<ir::getelementptr_inst*>(v))
627// std::cout << "ALIGN: " << v->get_name() << " " << max_contiguous_.at(v)[0] << " " << max_contiguous_.at(v)[1] << std::endl;
628// });
629}
630
631
632}
633}
634}
635