1 | |
2 | #pragma once |
3 | |
4 | #include <c10/macros/Export.h> |
5 | |
6 | #include <compute_at_map.h> |
7 | #include <instrumentation.h> |
8 | #include <ir_all_nodes.h> |
9 | #include <kernel_ir.h> |
10 | #include <lower_thread_predicate.h> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | namespace fuser { |
15 | namespace cuda { |
16 | |
17 | //! Loop nest generator pass will get IR that looks something like: |
18 | //! T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ... |
19 | |
20 | // and will generate the loop nest structure for these exprs like: |
21 | //! |
22 | //! for( i : I0o{ceil(I0/4)} ) { |
23 | //! for( j : I1o{ceil(I1/128)} ) { |
24 | //! for( k : I0i{4} ) |
25 | //! for( l : I1i{128} ) |
26 | //! T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ... |
27 | //! |
28 | //! It does not generate predicates, but it will generate allocations, and loop |
29 | //! nests to initialize reduction buffers. |
30 | class TORCH_CUDA_CU_API LoopNestGenerator { |
31 | public: |
32 | static std::vector<Expr*> loweredExprs(const std::vector<Expr*>& exprs); |
33 | |
34 | private: |
35 | LoopNestGenerator(const std::vector<Expr*>& exprs); |
36 | |
37 | // Open a new inner most for loop, track which TV it was constructed from |
38 | // according to the computeAt chain. |
39 | void openFor(IterDomain*); |
40 | |
41 | // Close the inner most for loop |
42 | void closeFor(); |
43 | |
44 | // Appends an expression to the current scope |
45 | void pushFront(Expr* expr); |
46 | |
47 | void handle(Expr* expr); |
48 | |
49 | // Run the pass and accumulate output in lowered_exprs_ |
50 | void generate(const std::vector<Expr*>& exprs); |
51 | |
52 | private: |
53 | // Lowered exprs to return |
54 | std::vector<Expr*> lowered_exprs_; |
55 | |
56 | // Keep all for loops conveniently to make unrolling easier, basically just a |
57 | // stack of the active for_loops |
58 | std::vector<kir::ForLoop*> for_loops_; |
59 | |
60 | // Loop structure of each expression |
61 | std::unordered_map<TensorView*, std::vector<IterDomain*>> loop_structures_; |
62 | }; |
63 | |
64 | } // namespace cuda |
65 | } // namespace fuser |
66 | } // namespace jit |
67 | } // namespace torch |
68 | |