#include <lower_predicate.h>

#include <arith.h>
#include <index_compute.h>
#include <instrumentation.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <kernel_ir.h>
#include <kernel_ir_dispatch.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <predicate_compute.h>
#include <transform_iter.h>
#include <transform_replay.h>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

namespace {

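// Traverses the kernel IR and replaces each expression's kir::Predicate
// placeholder with a concrete Bool conditional, computed according to the
// predicate type (see generateConditional below).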
class ConditionalFromPredicateModifier : public kir::ExprMutator {
 public:
  ConditionalFromPredicateModifier() = delete;

  static std::vector<Expr*> fillPredicates(const std::vector<Expr*>& exprs) {
    ConditionalFromPredicateModifier cfpm(exprs);
    return cfpm.exprs_;
  }

 private:
  ConditionalFromPredicateModifier(const std::vector<Expr*>& exprs) {
    FUSER_PERF_SCOPE(
        "ConditionalFromPredicateModifier::ConditionalFromPredicateModifier");
    traverseAndInsert(exprs);
  }

  using kir::ExprMutator::handle;

  void handle(Expr* expr) final {
    if (expr != nullptr && expr->predicate() != nullptr) {
      // Replace expr predicate with bool conditional
      auto conditional = generateConditional(expr->predicate());
      if (expr->predicate()->predicate_type() == PredicateType::Vectorize) {
        if (expr->isA<kir::IfThenElse>()) {
          // TODO: This logic doesn't seem to fit well here; for unswitch,
          // the logic is in the unroll loop to set the thread predicate to
          // the expr. I didn't have a quick way to do that, so placing this
          // here for now.
          auto ite = expr->as<kir::IfThenElse>();

          TORCH_INTERNAL_ASSERT(
              ite->thenBody().size() == 1,
              "Expecting predicated body to only have one vectorized expression.");
          auto vec_expr = ite->thenBody()[0];
          TORCH_INTERNAL_ASSERT(
              vec_expr->isA<UnaryOp>() || vec_expr->isA<LoadStoreOp>(),
              "Vectorize predicate exprs only supported on set operations.");
          TORCH_INTERNAL_ASSERT(
              ir_utils::isTvOp(vec_expr),
              "Vectorize predicate exprs only supported on tensor view operations.");
          if (!vec_expr->inputs()[0]->isConstScalar()) {
            conditional = SimplifyingIrBuilder::andExpr(
                              conditional,
                              GpuLower::current()->threadPredMap().getPredicate(
                                  ir_utils::getTvOutput(vec_expr)))
                              ->as<Bool>();
          }
        } else {
          TORCH_INTERNAL_ASSERT(lower_utils::supportInlinePredicate(expr));
          auto thread_pred = GpuLower::current()->threadPredMap().getPredicate(
              ir_utils::getTvOutput(expr));
          TORCH_INTERNAL_ASSERT(
              thread_pred->isConst() && thread_pred->value().value());
          conditional = SimplifyingIrBuilder::andExpr(
                            conditional,
                            GpuLower::current()->threadPredMap().getPredicate(
                                ir_utils::getTvOutput(expr)))
                            ->as<Bool>();
        }
      }
      TORCH_INTERNAL_ASSERT(conditional != nullptr);
      expr->predicate()->setValue(conditional);
      TORCH_INTERNAL_ASSERT(expr->predicate()->value() != nullptr);
      setWritePredicate(expr);
    }

    // Note: [Predicate Inversion for CpAsync]
    // Today the pattern for vectorized support is:
    //   initialize buffer -> predicated load.
    // For memcpy async, initializing and then loading (without a sync)
    // would be undefined behavior. Instead, initialize only the "virtual
    // out of boundary" accesses, i.e. memory that is allocated but lies
    // outside the virtual tensor space (which today is effectively what
    // would be allocated in global memory), and then copy only the
    // "within bound" accesses.
    // This is a WAR based on how our system is currently set up. We would
    // want a concept of SMEM space separate from the virtual or GMEM
    // space, so that we know we are only working with the allocated SMEM;
    // if we hit outside the allocated SMEM, bad things happen. Today,
    // predicate removal asserts that the virtual and SMEM boundaries line
    // up based on the IterDomains.
    //
    // TODO: in a follow-up we need to extend the predicate infrastructure
    // to generate predicates for both gmem and smem, and predicate removal
    // will need to be extended as well for the perf-critical regions.
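    // Example (illustrative sketch only): an init such as
    //   if (pred) { T_smem_out = 0; }
    // is rewritten below to
    //   if (!pred) { T_smem_out = 0; }
    // so only the out-of-bound region of the shared memory buffer is
    // zero-initialized, while the in-bound region is filled by the
    // predicated cp.async load. "T_smem_out" is a hypothetical cp.async
    // output buffer named here for illustration only.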
    if (isPredicatedInitForCpAsync(expr)) {
      invertPredicateForGmemToSharedMemInitialize(expr);
    }

    kir::ExprMutator::handle(expr);
  }

  // Invert the predicate of the given expr.
  void invertPredicateForGmemToSharedMemInitialize(Expr* expr) {
    auto pred = expr->predicate()->value();
    auto invert = SimplifyingIrBuilder::notExpr(pred);
    expr->predicate()->setValue(invert->as<Bool>());
  }

  // Detect if this expr is an initialization for vectorized
  // cp.async with predicates.
  bool isPredicatedInitForCpAsync(Expr* expr) {
    // Match the pattern:
    //   If(pred)
    //     TV = 0;
    // where TV is the output of cp.async.
    auto maybe_init = ir_utils::getMaybePredicatedSingleton(expr);
    return maybe_init.has_value() &&
        ir_utils::isCpAsyncInit(maybe_init.value());
  }

  void setWritePredicate(Expr* expr) {
    if (expr->writePredicate() != nullptr) {
      auto write_cond = generateConditional(expr->writePredicate());
      if (write_cond) {
        expr->writePredicate()->setValue(write_cond);
      } else {
        // If generateConditional returns null, it means no specific
        // predicate needs to be used.
        registerReplace(expr, expr->withWritePredicate(nullptr));
      }
    }
  }

  void handle(kir::IfThenElse* ite) final {
    TORCH_INTERNAL_ASSERT(ite->predicate() != nullptr);

    // If ite already has a Bool conditional, handle internal expressions.
    // Otherwise, generate the conditional and update the predicate.
    if (!ite->predicate()->hasValue()) {
      auto conditional = generateConditional(ite->predicate());
      TORCH_INTERNAL_ASSERT(conditional != nullptr);
      TORCH_INTERNAL_ASSERT(conditional->isA<Bool>());

      // Update bool conditional in-place
      ite->predicate()->setValue(conditional);
      TORCH_INTERNAL_ASSERT(ite->predicate()->value() != nullptr);
    }
    kir::ExprMutator::handle(ite);
  }

  // Generate conditional according to PredicateType
  Bool* generateConditional(kir::Predicate* pred) {
    switch (pred->predicate_type()) {
      case PredicateType::Inline:
      case PredicateType::ReductionWrite:
      case PredicateType::Misaligned:
      case PredicateType::Shift:
      case PredicateType::Padding: {
        return PredicateCompute::getInlinePredicate(
            pred->expr(),
            for_loops_,
            pred->thread_pred(),
            pred->predicate_type());
      }
      case PredicateType::Vectorize: {
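        // Predicate the vectorized access as a whole: collect the loops
        // enclosing the vectorized loop and build an unswitch-style
        // predicate over them, rather than an inline per-element predicate.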
        std::vector<kir::ForLoop*> outer_loops;
        kir::ForLoop* vectorized_loop = nullptr;
        for (auto loop : for_loops_) {
          if (loop->iter_domain()->getParallelType() ==
              ParallelType::Vectorize) {
            vectorized_loop = loop;
            break;
          } else {
            outer_loops.emplace_back(loop);
          }
        }
        TORCH_INTERNAL_ASSERT(
            vectorized_loop != nullptr, "Should be unreachable.");
        return UnswitchPredicate::get(outer_loops, vectorized_loop);
      }
      case PredicateType::Unswitch: {
        return UnswitchPredicate::get(for_loops_, pred->unrolled_loop());
      }
      case PredicateType::Manual: {
        return pred->value();
      }
      default:
        break;
    }
    return nullptr;
  }
};

} // namespace

std::vector<Expr*> generateConditionalFromPredicate(
    const std::vector<Expr*>& exprs) {
  return ConditionalFromPredicateModifier::fillPredicates(exprs);
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch