1 | #pragma once |
2 | #include <c10/macros/Export.h> |
3 | |
4 | #include <ir_all_nodes.h> |
5 | |
6 | #include <vector> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | namespace fuser { |
11 | namespace cuda { |
12 | |
13 | //! Transform for-loop structure to handle misaligned addresses |
14 | //! |
15 | //! Sections of misaligned addresses are handled sequentially |
16 | //! while aligned addresses use vectorized memory accesses. |
17 | //! |
18 | //! --------------------------------------------------------------------------- |
19 | //! Before Misaligned Vectorization: |
20 | //! |
21 | //! Inputs: T0 |
22 | //! Outputs: T3 |
23 | //! |
24 | //! for(...) { |
25 | //! T1[vector_size]; |
26 | //! for( i : vector_size ) { |
27 | //! T1[i] = T0[...] |
28 | //! } |
29 | //! |
30 | //! T2[vector_size]; |
31 | //! for( i : vector_size ) { |
32 | //! T2[i] = unaryOp(T1[i]) |
33 | //! } |
34 | //! |
35 | //! for( i : vector_size ) { |
36 | //! T3[...] = T2[i] |
37 | //! } |
38 | //! } |
39 | //! |
40 | //! --------------------------------------------------------------------------- |
41 | //! After Misaligned Vectorization: |
42 | //! |
43 | //! Inputs: T0 |
44 | //! Outputs: T3 |
45 | //! |
46 | //! for(...) { |
47 | //! T1[vector_size]; |
48 | //! T2[vector_size]; |
49 | //! |
50 | //! if (inline_predicate_except_last_root_domain) { |
51 | //! index_except_last_root_domain = ... |
52 | //! address = (int64_t) &T1[index_except_last_root_domain] |
53 | //! |
54 | //! offset_size = (address % vector_size_bytes) / data_type_size_bytes |
55 | //! shift_init = vector_size - offset_size |
56 | //! shift = (shift_init == vector_size) ? 0 : shift_init |
57 | //! |
58 | //! // size of the last root domain |
59 | //! extent = ... |
60 | //! remainder = (extent - shift) % vector_size |
61 | //! |
62 | //! last_root_domain_index = ... |
63 | //! |
64 | //! // Vectorize Section |
65 | //! if ( (last_root_domain_index + shift) < (extent - remainder) ) { |
66 | //! T1[0] = vectorize_load( T0[index + shift] ); |
67 | //! |
68 | //! for( i : vector_size ) { |
69 | //! T2[i] = unaryOp(T1[i]) |
70 | //! } |
71 | //! |
72 | //! T3[index + shift] = vectorize_store( T2[0] ); |
73 | //! } |
74 | //! |
75 | //! // Initial Section |
76 | //! if ( last_root_domain_index == 0 ) { |
77 | //! for( i : shift ) { |
78 | //! T1[i] = T0[...] |
79 | //! } |
80 | //! |
81 | //! for( i : shift ) { |
82 | //! T2[i] = unaryOp(T1[i]) |
83 | //! } |
84 | //! |
85 | //! for( i : shift ) { |
86 | //! T3[...] = T2[i] |
87 | //! } |
88 | //! } |
89 | //! |
90 | //! // Remainder Section |
91 | //! if ( (last_root_domain_index + shift) >= (extent - remainder) && |
92 | //! (last_root_domain_index + shift) < extent) { |
93 | //! |
94 | //! for( i : remainder ) { |
95 | //! T1[i] = T0[index + shift] |
96 | //! } |
97 | //! |
98 | //! for( i : remainder ) { |
99 | //! T2[i] = unaryOp(T1[i]) |
100 | //! } |
101 | //! |
102 | //! for( i : remainder ) { |
103 | //! T3[index + shift] = T2[i] |
104 | //! } |
105 | //! } |
106 | //! } |
107 | //! } |
108 | //! |
//! \param exprs List of expressions to process; taken by const reference,
//!   so the caller's vector is not modified through this parameter.
//! \return A new vector of Expr* returned by value. Presumably this is the
//!   input list with the for-loop restructuring sketched above applied --
//!   NOTE(review): confirm against the definition, which is not in this header.
109 | std::vector<Expr*> processMisalignedVectorization( |
110 | const std::vector<Expr*>& exprs); |
111 | |
//! Boolean query on a `kir::ForLoop`, which is passed as a pointer-to-const.
//!
//! NOTE(review): judging by the name, this presumably reports whether any
//! immediate (non-nested) child of `fl` is a misaligned-vectorize loop. The
//! definition is not visible in this header, so confirm the exact criterion
//! there, along with whether a null `fl` is tolerated.
112 | bool containsAnyDirectChildMisalignedVectorize(const kir::ForLoop* fl); |
113 | |
114 | } // namespace cuda |
115 | } // namespace fuser |
116 | } // namespace jit |
117 | } // namespace torch |
118 | |