#include <lower_loops.h>

#include <arith.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <kernel_expr_evaluator.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <transform_replay.h>

#include <algorithm>
#include <deque>
#include <numeric>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

std::vector<Expr*> LoopNestGenerator::loweredExprs(
    const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::LoopNestGenerator::loweredExprs");
  TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr);
  LoopNestGenerator generator(exprs);
  return generator.lowered_exprs_;
}

LoopNestGenerator::LoopNestGenerator(const std::vector<Expr*>& exprs) {
  generate(exprs);
}

namespace {

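// Create a kir::ForLoop for the given IterDomain and insert it at the front
// of the enclosing scope, if one is provided. When the axis has a
// halo-extended extent, the loop is built with that extent instead.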
kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) {
  auto extent_with_halo = GpuLower::current()->haloInfo()->getExtent(id);
  kir::ForLoop* new_scope = nullptr;
  if (extent_with_halo) {
    // When an axis is extended with halo, unrolling and vectorization are
    // assumed not to be used for now.
    TORCH_INTERNAL_ASSERT(
        id->getParallelType() != ParallelType::Unroll &&
        !isParallelTypeVectorize(id->getParallelType()));
    // Use the extent that's extended by halo
    new_scope = IrBuilder::create<kir::ForLoop>(
        id,
        GpuLower::current()->caMap()->getIndexVariable(id),
        nullptr,
        extent_with_halo,
        nullptr,
        false,
        nullptr,
        false,
        DoubleBufferLoopStage::NotApplicable);
  } else {
    new_scope = IrBuilder::create<kir::ForLoop>(id);
  }
  if (scope != nullptr) {
    scope->body().insert(0, new_scope);
  }
  return new_scope;
}

} // namespace

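// Open a loop for id, nested inside the current innermost loop. If no loop is
// open, the new loop becomes a top-level scope at the front of lowered_exprs_.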
void LoopNestGenerator::openFor(IterDomain* id) {
  if (for_loops_.size() > 0) {
    const auto new_scope = openForHelper(for_loops_.back(), id);
    // for_loop_allocations_.insert({new_scope, 0});
    for_loops_.push_back(new_scope);
  } else {
    for_loops_.push_back(openForHelper(nullptr, id));
    lowered_exprs_.insert(lowered_exprs_.begin(), for_loops_.back());
  }
}

void LoopNestGenerator::closeFor() {
  TORCH_INTERNAL_ASSERT(!for_loops_.empty());
  for_loops_.pop_back();
}

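// Insert expr at the front of the innermost open loop, or at the front of
// lowered_exprs_ when no loop is open.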
void LoopNestGenerator::pushFront(Expr* expr) {
  if (for_loops_.size() == 0) {
    lowered_exprs_.insert(lowered_exprs_.begin(), expr);
  } else {
    for_loops_.back()->body().insert(0, expr);
  }
}

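// Place expr into the loop nest: close open loops that are not part of expr's
// loop structure, open the ones that are missing, and push expr to the front
// of the innermost scope. Scalar (non-TV) exprs are hoisted to the top level,
// with a local allocation created for each output.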
void LoopNestGenerator::handle(Expr* expr) {
  // Check if it's a tensor view expression we need to place in the loop nest
  // structure
  if (!ir_utils::isTvOp(expr)) {
    // Close all the loops; based on expr sorting, scalar operations cannot be
    // placed inside for loops.
    while (!for_loops_.empty()) {
      closeFor();
    }
    pushFront(expr);

    for (auto out : expr->outputs()) {
      TORCH_INTERNAL_ASSERT(
          out->getValType().value() == ValType::Scalar,
          "Unrecognized output type found in expr ",
          expr,
          " cannot lower ",
          out->getValType().value());

      pushFront(IrBuilder::create<kir::Allocate>(
          out, MemoryType::Local, GpuLower::current()->kernel()->oneVal()));
    }
    return;
  }

  TensorView* out_tv = expr->output(0)->as<TensorView>();

  // Grab the loop structure
  TORCH_INTERNAL_ASSERT(
      loop_structures_.find(out_tv) != loop_structures_.end(),
      "Could not find loop structure of ",
      out_tv);

  // Figure out what the entire loop structure should look like.
  std::vector<IterDomain*> loop_structure = loop_structures_.at(out_tv);

  // Ordering of loop_structure is global, so simply close loops we don't need,
  // and open the ones we do.

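  // Close innermost loops until every open loop appears in the target loop
  // structure.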
  while (!for_loops_.empty() &&
         std::find(
             loop_structure.begin(),
             loop_structure.end(),
             for_loops_.back()->iter_domain()) == loop_structure.end()) {
    closeFor();
  }

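  // Open any loop in the target structure that is not already open.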
  for (auto loop : loop_structure) {
    auto find_it = std::find_if(
        for_loops_.begin(), for_loops_.end(), [loop](kir::ForLoop* fl) {
          return fl->iter_domain() == loop;
        });
    if (find_it == for_loops_.end()) {
      openFor(loop);
    }
  }

  pushFront(expr);
}

// Generate the loop nest structure and place it in lowered_exprs_
void LoopNestGenerator::generate(const std::vector<Expr*>& exprs) {
  TORCH_INTERNAL_ASSERT(lowered_exprs_.empty());

  // Figure out the loop structure of each expression. This can be a bit
  // convoluted; for an example of why, see FusionAdvancedLowering6.

  // Grab iteration domain dependencies, similar to the logic in
  // lower_expr_sort, EXCEPT the dependencies are in the opposite order:
  // inner loops are dependent on outer loops.

  const auto& ca_map = GpuLower::current()->caMap();

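  // Maps each concrete (loop-mapped) IterDomain to the set of concrete IDs it
  // depends on, i.e. the loops that must be opened outside of it.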
  std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
      concrete_id_dependencies;
  for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) {
    std::unordered_set<IterDomain*> dependencies;

    for (auto tv_id : tv->domain()->domain()) {
      auto concrete_id =
          ca_map->getConcreteMappedID(tv_id, IdMappingMode::LOOP);

      if (concrete_id_dependencies.find(concrete_id) ==
          concrete_id_dependencies.end()) {
        concrete_id_dependencies[concrete_id] = dependencies;
      } else {
        concrete_id_dependencies[concrete_id].insert(
            dependencies.begin(), dependencies.end());
      }

      // Loops after tv_id are dependent on tv_id
      dependencies.emplace(concrete_id);
    }
  }

  // Fill out dependencies: each ID only carries local dependency information
  // at this point, so the sets are not yet guaranteed to be globally complete.

  // If loop structure is something like:
  // T0 [I0]
  // T1 [I0, I1]
  // T2 [I1, I2]
  //
  // I0 will be marked as a dependency of I1
  // I1 will be marked as a dependency of I2
  //
  // However, I0 will not be marked as a dep of I2, so we need to fill out the
  // dependency analysis. This is done by iterating through IterDomains,
  // recursively filling out all the dependencies of dependencies.

  std::deque<IterDomain*> to_visit;
  std::unordered_set<IterDomain*> visited;

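  // Seed the worklist with every concrete ID that has a dependency entry.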
  std::transform(
      concrete_id_dependencies.begin(),
      concrete_id_dependencies.end(),
      std::back_inserter(to_visit),
      [](const auto& concrete_dep_entry) { return concrete_dep_entry.first; });

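  // Worklist pass: an ID is only processed once all of its dependencies have
  // been visited; its set is then unioned with the already-completed sets of
  // its dependencies, making the dependency relation transitive. IDs that are
  // not ready yet are pushed to the back of the queue.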
  while (!to_visit.empty()) {
    auto id = to_visit.front();
    to_visit.pop_front();

    auto& dependencies = concrete_id_dependencies.at(id);
    bool ready = std::all_of(
        dependencies.begin(), dependencies.end(), [&visited](IterDomain* id) {
          return visited.count(id);
        });

    if (!ready) {
      to_visit.push_back(id);
      continue;
    }

    for (auto dependency : dependencies) {
      auto dep_of_dep = concrete_id_dependencies.at(dependency);
      dependencies.insert(dep_of_dep.begin(), dep_of_dep.end());
    }
    visited.emplace(id);
  }

  // Generate loop structure for each tensor view
  for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) {
    // Zero dim tensor support
    if (tv->nDims() == 0) {
      loop_structures_[tv] = std::vector<IterDomain*>();
      continue;
    }

    auto last_id_concrete = ca_map->getConcreteMappedID(
        tv->axis((int)(tv->nDims() - 1)), IdMappingMode::LOOP);
    auto all_loops_it = concrete_id_dependencies.find(last_id_concrete);
    TORCH_INTERNAL_ASSERT(
        all_loops_it != concrete_id_dependencies.end(),
        "Should have processed all IDs in all TVs.");
    std::vector<IterDomain*> loop_structure(
        all_loops_it->second.begin(), all_loops_it->second.end());
    // The dependencies of the last domain don't include the last domain
    // itself, so add it manually.
    loop_structure.emplace_back(last_id_concrete);
    // reverse sort (rbegin & rend) since we want the reverse of the order
    // given by IterDomainDependencySorter
    std::sort(
        loop_structure.rbegin(),
        loop_structure.rend(),
        ir_utils::IterDomainDependencySorter(
            concrete_id_dependencies, GpuLower::current()->caMap()));
    loop_structures_[tv] = loop_structure;
  }

  // Process the carefully ordered expressions. Iterate in reverse since
  // handle() inserts each expr at the front of its scope, which preserves the
  // original order.
  for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
    handle(*it);
  }
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch