1#pragma once
2#include <c10/macros/Export.h>
3
4#include <kernel_ir.h>
5#include <kernel_ir_dispatch.h>
6#include <lower_thread_predicate.h>
7#include <lower_utils.h>
8#include <root_domain_map.h>
9
10#include <bitset>
11#include <unordered_map>
12
13namespace torch {
14namespace jit {
15namespace fuser {
16namespace cuda {
17
18//! Unroll pass
19//!
//! Somewhat deceptively named: UnrollPass also adds all predicates, so it
//! needs to be run even if we don't unroll any loops.
22//!
23//! Unrolling pass will get IR that looks something like:
24//! for( i : I0o{ceil(I0/4)} ) {
25//! for( j : I1o{ceil(I1/128)} ) {
26//! for( k : I0i{4} )
27//! for( l : I1i{128} )
28//! T0[I0o{ceil(I0/4)}, I1o{ceil(I1/128)}, I0iU{4}, I1i{128}] = ...
29//!
30//! And it will return the following:
31//! for( i : I0o{ceil(I0/4)} ) {
32//! for( j : I1o{ceil(I1/128)} ) {
33//!
34//! if( i * 4 + 3 < I && j * 128 + 127 < J ){
35//! for( k : I0i{4} )
36//! for( l : I1i{128} )
37//! T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
38//! } else {
39//! for( k : I0i{4} )
40//! for( l : I1i{128} )
41//! if( i * 4 + k < I && j * 128 + l < J)
42//! T0[ ( i * 4 + k ) * J + j * 128 + l ] = ...
43//! }
44//!
45//! }
46//! }
47//!
//! As can be seen, it generates two sets of loops for I0i{4} and I1i{128}. The
//! first set is protected by a predicate that makes sure there's a full
//! internal tile we can iterate over. This way we remove the predicate nested
//! in the innermost loop. There is, of course, a second set of loops, which
//! still has a predicate in the innermost loop, making sure that we cover
//! edges and corners.
54//!
55class TORCH_CUDA_CU_API UnrollPass : kir::ExprMutator {
56 public:
57 // Take the incoming exprs and run loop unrolling, returning the new IR
58 static std::vector<Expr*> runPass(
59 Fusion* fusion,
60 const std::vector<Expr*>& exprs);
61
62 static bool canOmitElseClause(kir::ForLoop* fl);
63
64 private:
65 void registerReplace(Expr* reference, Expr* new_expr, kir::Scope* scope);
66
67 // Generate the for Expr replacement map
68 UnrollPass(const std::vector<Expr*>& exprs);
69
70 const std::unordered_map<Expr*, Expr*>& replacementMap() const {
71 return expr_replacement_map_;
72 }
73
74 using OptOutDispatch::handle;
75
76 void handle(kir::ForLoop* fl) final;
77
78 void handle(Expr* expr) final;
79
80 private:
81 // We will track which loops in the incoming IR will be replaced and by what
82 std::unordered_map<Expr*, Expr*> expr_replacement_map_;
83
84 // keep track if we're within an unrolled loop
85 bool look_for_unroll_ = true;
86
87 // Indicates if the currently visited expression is inside a
88 // unswitched path
89 bool unswitched_loop_ = false;
90
91 // As we generate inline predicates check if we actually generated a
92 // non-trivial one.
93 bool non_trivial_pred_found_ = false;
94};
95
96} // namespace cuda
97} // namespace fuser
98} // namespace jit
99} // namespace torch
100