1 | #pragma once |
2 | #include <c10/macros/Export.h> |
3 | |
4 | #include <kernel_ir.h> |
5 | #include <kernel_ir_dispatch.h> |
6 | #include <lower_thread_predicate.h> |
7 | #include <lower_utils.h> |
8 | #include <root_domain_map.h> |
9 | |
10 | #include <bitset> |
11 | #include <unordered_map> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | namespace fuser { |
16 | namespace cuda { |
17 | |
18 | //! Unroll pass |
19 | //! |
//! Despite its name, UnrollPass is also the pass that adds all predicates, so
//! it needs to be run even if we don't unroll any loops.
22 | //! |
23 | //! Unrolling pass will get IR that looks something like: |
24 | //! for( i : I0o{ceil(I0/4)} ) { |
25 | //! for( j : I1o{ceil(I1/128)} ) { |
26 | //! for( k : I0i{4} ) |
27 | //! for( l : I1i{128} ) |
28 | //! T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ... |
29 | //! |
30 | //! And it will return the following: |
31 | //! for( i : I0o{ceil(I0/4)} ) { |
32 | //! for( j : I1o{ceil(I1/128)} ) { |
33 | //! |
34 | //! if( i * 4 + 3 < I && j * 128 + 127 < J ){ |
35 | //! for( k : I0i{4} ) |
36 | //! for( l : I1i{128} ) |
37 | //! T0[ ( i * 4 + k ) * J + j * 128 + l ] = ... |
38 | //! } else { |
39 | //! for( k : I0i{4} ) |
40 | //! for( l : I1i{128} ) |
41 | //! if( i * 4 + k < I && j * 128 + l < J) |
42 | //! T0[ ( i * 4 + k ) * J + j * 128 + l ] = ... |
43 | //! } |
44 | //! |
45 | //! } |
46 | //! } |
47 | //! |
//! As can be seen, it generates two sets of loops for I0i{4} and I1i{128}. The
//! first set is protected by a predicate that makes sure there's a full
//! internal tile we can iterate over. This way we remove the predicate nested
//! in the innermost loop. The second set of loops still has a predicate in the
//! innermost loop, making sure that we cover edges and corners.
54 | //! |
55 | class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator { |
56 | public: |
57 | // Take the incoming exprs and run loop unrolling, returning the new IR |
58 | static std::vector<Expr*> runPass( |
59 | Fusion* fusion, |
60 | const std::vector<Expr*>& exprs); |
61 | |
62 | static bool canOmitElseClause(kir::ForLoop* fl); |
63 | |
64 | private: |
65 | void registerReplace(Expr* reference, Expr* new_expr, kir::Scope* scope); |
66 | |
67 | // Generate the for Expr replacement map |
68 | UnrollPass(const std::vector<Expr*>& exprs); |
69 | |
70 | const std::unordered_map<Expr*, Expr*>& replacementMap() const { |
71 | return expr_replacement_map_; |
72 | } |
73 | |
74 | using OptOutDispatch::handle; |
75 | |
76 | void handle(kir::ForLoop* fl) final; |
77 | |
78 | void handle(Expr* expr) final; |
79 | |
80 | private: |
81 | // We will track which loops in the incoming IR will be replaced and by what |
82 | std::unordered_map<Expr*, Expr*> expr_replacement_map_; |
83 | |
84 | // keep track if we're within an unrolled loop |
85 | bool look_for_unroll_ = true; |
86 | |
87 | // Indicates if the currently visited expression is inside a |
88 | // unswitched path |
89 | bool unswitched_loop_ = false; |
90 | |
91 | // As we generate inline predicates check if we actually generated a |
92 | // non-trivial one. |
93 | bool non_trivial_pred_found_ = false; |
94 | }; |
95 | |
96 | } // namespace cuda |
97 | } // namespace fuser |
98 | } // namespace jit |
99 | } // namespace torch |
100 | |