1
2#pragma once
3
4#include <c10/macros/Export.h>
5
6#include <compute_at_map.h>
7#include <instrumentation.h>
8#include <ir_all_nodes.h>
9#include <kernel_ir.h>
10#include <lower_thread_predicate.h>
11
12namespace torch {
13namespace jit {
14namespace fuser {
15namespace cuda {
16
//! Loop nest generator pass will get IR that looks something like:
//!
//!   T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
//!
//! and will generate the loop nest structure for these exprs like:
//!
//!   for( i : I0o{ceil(I0/4)} ) {
//!     for( j : I1o{ceil(I1/128)} ) {
//!       for( k : I0iU{4} )
//!         for( l : I1i{128} )
//!           T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
//!     }
//!   }
//!
//! It does not generate predicates, but it will generate allocations, and loop
//! nests to initialize reduction buffers.
30class TORCH_CUDA_CU_API LoopNestGenerator {
31 public:
32 static std::vector<Expr*> loweredExprs(const std::vector<Expr*>& exprs);
33
34 private:
35 LoopNestGenerator(const std::vector<Expr*>& exprs);
36
37 // Open a new inner most for loop, track which TV it was constructed from
38 // according to the computeAt chain.
39 void openFor(IterDomain*);
40
41 // Close the inner most for loop
42 void closeFor();
43
44 // Appends an expression to the current scope
45 void pushFront(Expr* expr);
46
47 void handle(Expr* expr);
48
49 // Run the pass and accumulate output in lowered_exprs_
50 void generate(const std::vector<Expr*>& exprs);
51
52 private:
53 // Lowered exprs to return
54 std::vector<Expr*> lowered_exprs_;
55
56 // Keep all for loops conveniently to make unrolling easier, basically just a
57 // stack of the active for_loops
58 std::vector<kir::ForLoop*> for_loops_;
59
60 // Loop structure of each expression
61 std::unordered_map<TensorView*, std::vector<IterDomain*>> loop_structures_;
62};
63
64} // namespace cuda
65} // namespace fuser
66} // namespace jit
67} // namespace torch
68