1 | #include <lower_predicate.h> |
2 | |
3 | #include <arith.h> |
4 | #include <index_compute.h> |
5 | #include <instrumentation.h> |
6 | #include <ir_iostream.h> |
7 | #include <ir_utils.h> |
8 | #include <kernel_ir.h> |
9 | #include <kernel_ir_dispatch.h> |
10 | #include <lower2device.h> |
11 | #include <lower_utils.h> |
12 | #include <predicate_compute.h> |
13 | #include <transform_iter.h> |
14 | #include <transform_replay.h> |
15 | |
16 | namespace torch { |
17 | namespace jit { |
18 | namespace fuser { |
19 | namespace cuda { |
20 | |
21 | namespace { |
22 | |
23 | class ConditionalFromPredicateModifier : public kir::ExprMutator { |
24 | public: |
25 | ConditionalFromPredicateModifier() = delete; |
26 | |
27 | static std::vector<Expr*> fillPredicates(const std::vector<Expr*>& exprs) { |
28 | ConditionalFromPredicateModifier cfpm(exprs); |
29 | return cfpm.exprs_; |
30 | } |
31 | |
32 | private: |
33 | ConditionalFromPredicateModifier(const std::vector<Expr*>& exprs) { |
34 | FUSER_PERF_SCOPE( |
35 | "ConditionalFromPredicateModifier::ConditionalFromPredicateModifier" ); |
36 | traverseAndInsert(exprs); |
37 | } |
38 | |
39 | using kir::ExprMutator::handle; |
40 | |
41 | void handle(Expr* expr) final { |
42 | if (expr != nullptr && expr->predicate() != nullptr) { |
43 | // Replace expr predicate with bool conditional |
44 | auto conditional = generateConditional(expr->predicate()); |
45 | if (expr->predicate()->predicate_type() == PredicateType::Vectorize) { |
46 | if (expr->isA<kir::IfThenElse>()) { |
47 | // TODO: This logic doesn't seem to fit well here, for unswitch the |
48 | // logic is in the unroll loop to set the thread predicate to the |
49 | // expr. I didn't have a quick way to do that so placing this here for |
50 | // now. |
51 | auto ite = expr->as<kir::IfThenElse>(); |
52 | |
53 | TORCH_INTERNAL_ASSERT( |
54 | ite->thenBody().size() == 1, |
55 | "Expecting predicated body to only have one vectorized expression." ); |
56 | auto vec_expr = ite->thenBody()[0]; |
57 | TORCH_INTERNAL_ASSERT( |
58 | vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>(), |
59 | "Vectorize predicate exprs only supported on set operations." ); |
60 | TORCH_INTERNAL_ASSERT( |
61 | ir_utils::isTvOp(vec_expr), |
62 | "Vectorize predicate exprs only supported on tensor view operations." ); |
63 | if (!vec_expr->inputs()[0]->isConstScalar()) { |
64 | conditional = SimplifyingIrBuilder::andExpr( |
65 | conditional, |
66 | GpuLower::current()->threadPredMap().getPredicate( |
67 | ir_utils::getTvOutput(vec_expr))) |
68 | ->as<Bool>(); |
69 | } |
70 | } else { |
71 | TORCH_INTERNAL_ASSERT(lower_utils::supportInlinePredicate(expr)); |
72 | auto thread_pred = GpuLower::current()->threadPredMap().getPredicate( |
73 | ir_utils::getTvOutput(expr)); |
74 | TORCH_INTERNAL_ASSERT( |
75 | thread_pred->isConst() && thread_pred->value().value()); |
76 | conditional = SimplifyingIrBuilder::andExpr( |
77 | conditional, |
78 | GpuLower::current()->threadPredMap().getPredicate( |
79 | ir_utils::getTvOutput(expr))) |
80 | ->as<Bool>(); |
81 | } |
82 | } |
83 | TORCH_INTERNAL_ASSERT(conditional != nullptr); |
84 | expr->predicate()->setValue(conditional); |
85 | TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr); |
86 | setWritePredicate(expr); |
87 | } |
88 | |
89 | // Note: [Predicate Inversion for CpAsync] |
90 | // Today for vectorized support the pattern is: |
91 | // Initialize buffer -> predicated load |
92 | // For memcpy async: |
93 | // If we initialized and then loaded (without sync) it would be undefined |
94 | // behavior. |
95 | // Initialize only the "virtual out of boundary" accesses. |
96 | // Memory allocated, but outside the virtual tensor space. |
97 | // Virtual tensor space today is effectively what would be allocated in |
98 | // global memory. Then only copy the "within bound" accesses. |
99 | // This is a WAR today based on how our system is set up. |
100 | // We would want to have a separate concept of SMEM space from Virtual or |
101 | // GMEM space, so that we know we're only working with the allocated |
102 | // SMEM. |
103 | // If we hit outside the allocated SMEM bad things happen. |
104 | // Today asserting in predicate removal making sure that the virtual and |
105 | // SMEM boundaries line up based on the IterDomains. |
106 | // |
107 | // TODO: in a follow up we need to extend the predicate |
108 | // infrastructure to generate predicate for both gmem |
109 | // and smem, and the predicate removal will need to |
110 | // be extended as well for the perf critical regions. |
111 | if (isPredicatedInitForCpAsync(expr)) { |
112 | invertPredicateForGmemToSharedMemInitialize(expr); |
113 | } |
114 | |
115 | kir::ExprMutator::handle(expr); |
116 | } |
117 | |
118 | // Invert the predicate of given expr. |
119 | void invertPredicateForGmemToSharedMemInitialize(Expr* expr) { |
120 | auto pred = expr->predicate()->value(); |
121 | auto invert = SimplifyingIrBuilder::notExpr(pred); |
122 | expr->predicate()->setValue(invert->as<Bool>()); |
123 | } |
124 | |
125 | // Detect if this expr is an initialization for vectorized |
126 | // cp asyc with predicates. |
127 | bool isPredicatedInitForCpAsync(Expr* expr) { |
128 | // Match the pattern: |
129 | // If(pred) |
130 | // TV = 0; |
131 | // where TV is the output of cp async. |
132 | auto maybe_init = ir_utils::getMaybePredicatedSingleton(expr); |
133 | return maybe_init.has_value() && |
134 | ir_utils::isCpAsyncInit(maybe_init.value()); |
135 | } |
136 | |
137 | void setWritePredicate(Expr* expr) { |
138 | if (expr->writePredicate() != nullptr) { |
139 | auto write_cond = generateConditional(expr->writePredicate()); |
140 | if (write_cond) { |
141 | expr->writePredicate()->setValue(write_cond); |
142 | } else { |
143 | // If generateConditional returns null, it means no specific |
144 | // predicate needs to be used. |
145 | registerReplace(expr, expr->withWritePredicate(nullptr)); |
146 | } |
147 | } |
148 | } |
149 | |
150 | void handle(kir::IfThenElse* ite) final { |
151 | TORCH_INTERNAL_ASSERT(ite->predicate() != nullptr); |
152 | |
153 | // If ite already has Bool conditional, handle internal expressions |
154 | // Otherwise, generate conditional and update predicate |
155 | if (!ite->predicate()->hasValue()) { |
156 | auto conditional = generateConditional(ite->predicate()); |
157 | TORCH_INTERNAL_ASSERT(conditional != nullptr); |
158 | TORCH_INTERNAL_ASSERT(conditional->isA<Bool>()); |
159 | |
160 | // Update bool conditional in-place |
161 | ite->predicate()->setValue(conditional); |
162 | TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr); |
163 | } |
164 | kir::ExprMutator::handle(ite); |
165 | } |
166 | |
167 | // Generate conditional according to PredicateType |
168 | Bool* generateConditional(kir::Predicate* pred) { |
169 | switch (pred->predicate_type()) { |
170 | case PredicateType::Inline: |
171 | case PredicateType::ReductionWrite: |
172 | case PredicateType::Misaligned: |
173 | case PredicateType::Shift: |
174 | case PredicateType::Padding: { |
175 | return PredicateCompute::getInlinePredicate( |
176 | pred->expr(), |
177 | for_loops_, |
178 | pred->thread_pred(), |
179 | pred->predicate_type()); |
180 | } |
181 | case PredicateType::Vectorize: { |
182 | std::vector<kir::ForLoop*> outer_loops; |
183 | kir::ForLoop* vectorized_loop = nullptr; |
184 | for (auto loop : for_loops_) { |
185 | if (loop->iter_domain()->getParallelType() == |
186 | ParallelType::Vectorize) { |
187 | vectorized_loop = loop; |
188 | break; |
189 | } else { |
190 | outer_loops.emplace_back(loop); |
191 | } |
192 | } |
193 | TORCH_INTERNAL_ASSERT( |
194 | vectorized_loop != nullptr, "Should be unreachable." ); |
195 | return UnswitchPredicate::get(outer_loops, vectorized_loop); |
196 | } |
197 | case PredicateType::Unswitch: { |
198 | return UnswitchPredicate::get(for_loops_, pred->unrolled_loop()); |
199 | } |
200 | case PredicateType::Manual: { |
201 | return pred->value(); |
202 | } |
203 | default: |
204 | break; |
205 | } |
206 | return nullptr; |
207 | } |
208 | }; |
209 | |
210 | } // namespace |
211 | |
212 | std::vector<Expr*> generateConditionalFromPredicate( |
213 | const std::vector<Expr*>& exprs) { |
214 | return ConditionalFromPredicateModifier::fillPredicates(exprs); |
215 | } |
216 | |
217 | } // namespace cuda |
218 | } // namespace fuser |
219 | } // namespace jit |
220 | } // namespace torch |
221 | |