#include <lower_loops.h>

#include <arith.h>
#include <ir_iostream.h>
#include <ir_utils.h>
#include <iter_visitor.h>
#include <kernel_expr_evaluator.h>
#include <lower2device.h>
#include <lower_utils.h>
#include <transform_replay.h>

#include <algorithm>
#include <deque>
#include <numeric>

namespace torch {
namespace jit {
namespace fuser {
namespace cuda {

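// Entry point: builds the loop nest structure around the (already sorted)
// expressions and returns the lowered top-level expressions.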
std::vector<Expr*> LoopNestGenerator::loweredExprs(
    const std::vector<Expr*>& exprs) {
  FUSER_PERF_SCOPE("GpuLower::Lower::LoopNestGenerator::loweredExprs");
  TORCH_INTERNAL_ASSERT(FusionGuard::getCurFusion() != nullptr);
  LoopNestGenerator generator(exprs);
  return generator.lowered_exprs_;
}

LoopNestGenerator::LoopNestGenerator(const std::vector<Expr*>& exprs) {
  generate(exprs);
}

namespace {

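// Create a kir::ForLoop for id and prepend it to scope (when scope is
// non-null). An axis extended with halo gets a loop over the halo-extended
// extent instead of its nominal extent.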
kir::ForLoop* openForHelper(kir::ForLoop* scope, IterDomain* id) {
  auto extent_with_halo = GpuLower::current()->haloInfo()->getExtent(id);
  kir::ForLoop* new_scope = nullptr;
  if (extent_with_halo) {
    // When an axis is extended with halo, unrolling and vectorization
    // are assumed not to be used for now.
    TORCH_INTERNAL_ASSERT(
        id->getParallelType() != ParallelType::Unroll &&
        !isParallelTypeVectorize(id->getParallelType()));
    // Use the extent that's extended by halo
    new_scope = IrBuilder::create<kir::ForLoop>(
        id,
        GpuLower::current()->caMap()->getIndexVariable(id),
        nullptr, // start
        extent_with_halo, // stop
        nullptr, // step
        false, // vectorize
        nullptr, // vectorize shift
        false, // unroll required
        DoubleBufferLoopStage::NotApplicable);
  } else {
    new_scope = IrBuilder::create<kir::ForLoop>(id);
  }
  if (scope != nullptr) {
    // Prepend the new loop: expressions are visited in reverse order, so
    // later-visited structures belong earlier in the scope.
    scope->body().insert(0, new_scope);
  }
  return new_scope;
}

} // namespace

// Open a new loop nested inside the current innermost loop, if any. The
// outermost loop of a fresh nest is also prepended to lowered_exprs_.
void LoopNestGenerator::openFor(IterDomain* id) {
  if (!for_loops_.empty()) {
    const auto new_scope = openForHelper(for_loops_.back(), id);
    for_loops_.push_back(new_scope);
  } else {
    for_loops_.push_back(openForHelper(nullptr, id));
    lowered_exprs_.insert(lowered_exprs_.begin(), for_loops_.back());
  }
}

void LoopNestGenerator::closeFor() {
  TORCH_INTERNAL_ASSERT(!for_loops_.empty());
  for_loops_.pop_back();
}

// Prepend an expression to the innermost open scope, or to the top level if
// no loop is open.
void LoopNestGenerator::pushFront(Expr* expr) {
  if (for_loops_.empty()) {
    lowered_exprs_.insert(lowered_exprs_.begin(), expr);
  } else {
    for_loops_.back()->body().insert(0, expr);
  }
}

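// Place a single expression into the loop nest. Scalar expressions are
// hoisted above all loops; tensor expressions are nested under the loop
// structure computed in generate().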
void LoopNestGenerator::handle(Expr* expr) {
  // Check if it's a tensor view expression we need to place in the loop nest
  // structure
  if (!ir_utils::isTvOp(expr)) {
    // Close all open loops; expression sorting guarantees that scalar
    // operations are never placed inside for-loops.
    while (!for_loops_.empty()) {
      closeFor();
    }
    pushFront(expr);

    for (auto out : expr->outputs()) {
      TORCH_INTERNAL_ASSERT(
          out->getValType().value() == ValType::Scalar,
          "Unrecognized output type found in expr ",
          expr,
          " cannot lower ",
          out->getValType().value());

      pushFront(IrBuilder::create<kir::Allocate>(
          out, MemoryType::Local, GpuLower::current()->kernel()->oneVal()));
    }
    return;
  }

  TensorView* out_tv = expr->output(0)->as<TensorView>();

  // Grab the loop structure
  TORCH_INTERNAL_ASSERT(
      loop_structures_.find(out_tv) != loop_structures_.end(),
      "Could not find loop structure of ",
      out_tv);

  // Figure out what the entire loop structure should look like.
  std::vector<IterDomain*> loop_structure = loop_structures_.at(out_tv);

  // The ordering of loop_structure is globally consistent, so simply close
  // the loops we don't need and open the ones we do.

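  // For example, if the open nest is [I0, I1, I2] and loop_structure is
  // [I0, I3], the first loop below closes I2 and then I1, and the second
  // opens I3 nested inside I0.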
  while (!for_loops_.empty() &&
         std::find(
             loop_structure.begin(),
             loop_structure.end(),
             for_loops_.back()->iter_domain()) == loop_structure.end()) {
    closeFor();
  }

  for (auto loop : loop_structure) {
    auto find_it = std::find_if(
        for_loops_.begin(), for_loops_.end(), [loop](kir::ForLoop* fl) {
          return fl->iter_domain() == loop;
        });
    if (find_it == for_loops_.end()) {
      openFor(loop);
    }
  }

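  // Every loop in loop_structure is now open; prepend the expression to the
  // body of the innermost one (or to the top level for 0-dim tensors).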
  pushFront(expr);
}

// Generate the loop nest structure and place it in lowered_exprs_
void LoopNestGenerator::generate(const std::vector<Expr*>& exprs) {
  TORCH_INTERNAL_ASSERT(lowered_exprs_.empty());

  // Figure out the loop structure of each expression. This can be a bit
  // convoluted; for an example of why, see FusionAdvancedLowering6.

  // Grab iteration domain dependencies, similar to the logic in
  // lower_expr_sort, EXCEPT that dependencies are in the opposite order:
  // inner loops are dependent on outer loops.

  const auto& ca_map = GpuLower::current()->caMap();

  std::unordered_map<IterDomain*, std::unordered_set<IterDomain*>>
      concrete_id_dependencies;
  for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) {
    std::unordered_set<IterDomain*> dependencies;

    for (auto tv_id : tv->domain()->domain()) {
      auto concrete_id =
          ca_map->getConcreteMappedID(tv_id, IdMappingMode::LOOP);

      if (concrete_id_dependencies.find(concrete_id) ==
          concrete_id_dependencies.end()) {
        concrete_id_dependencies[concrete_id] = dependencies;
      } else {
        concrete_id_dependencies[concrete_id].insert(
            dependencies.begin(), dependencies.end());
      }

      // Loops after tv_id are dependent on tv_id
      dependencies.emplace(concrete_id);
    }
  }

  // Fill out the dependencies: each ID has local dependency information, but
  // that is still not guaranteed to be globally complete.

  // If the loop structure is something like:
  // T0 [I0]
  // T1 [I0, I1]
  // T2 [I1, I2]
  //
  // I0 will be marked as a dependency of I1
  // I1 will be marked as a dependency of I2
  //
  // However, I0 will not be marked as a dep of I2, so we need to fill out the
  // dependency analysis. This is done by iterating through IterDomains,
  // recursively filling out all the dependencies of dependencies.

  std::deque<IterDomain*> to_visit;
  std::unordered_set<IterDomain*> visited;

  // Seed the work list with every concrete ID
  std::transform(
      concrete_id_dependencies.begin(),
      concrete_id_dependencies.end(),
      std::back_inserter(to_visit),
      [](const auto& concrete_dep_entry) { return concrete_dep_entry.first; });

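  // Propagate transitive dependencies with a work list: an ID is expanded
  // only after all of its direct dependencies have been expanded, so folding
  // in each direct dependency's (already complete) set once yields the full
  // transitive set.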
  while (!to_visit.empty()) {
    auto id = to_visit.front();
    to_visit.pop_front();

    auto& dependencies = concrete_id_dependencies.at(id);
    bool ready = std::all_of(
        dependencies.begin(), dependencies.end(), [&visited](IterDomain* id) {
          return visited.count(id);
        });

    if (!ready) {
      to_visit.push_back(id);
      continue;
    }

    // Iterate over a copy: inserting into dependencies while iterating over
    // it directly would invalidate the iterators.
    const auto direct_dependencies = dependencies;
    for (auto dependency : direct_dependencies) {
      auto dep_of_dep = concrete_id_dependencies.at(dependency);
      dependencies.insert(dep_of_dep.begin(), dep_of_dep.end());
    }
    visited.emplace(id);
  }
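
  // With the example above, I2's dependency set now also contains I0, since
  // I1's dependencies were folded into it.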

  // Generate the loop structure for each tensor view
  for (auto tv : ir_utils::allTvs(FusionGuard::getCurFusion())) {
    // Zero-dim tensor support
    if (tv->nDims() == 0) {
      loop_structures_[tv] = std::vector<IterDomain*>();
      continue;
    }

    auto last_id_concrete = ca_map->getConcreteMappedID(
        tv->axis((int)(tv->nDims() - 1)), IdMappingMode::LOOP);
    auto all_loops_it = concrete_id_dependencies.find(last_id_concrete);
    TORCH_INTERNAL_ASSERT(
        all_loops_it != concrete_id_dependencies.end(),
        "Should have processed all ids in all tvs.");
    std::vector<IterDomain*> loop_structure(
        all_loops_it->second.begin(), all_loops_it->second.end());
    // The dependencies of the last domain don't include the last domain
    // itself, so include it manually
    loop_structure.emplace_back(last_id_concrete);
    // Reverse sort (rbegin & rend) since we want the reverse of the order
    // given by IterDomainDependencySorter
    std::sort(
        loop_structure.rbegin(),
        loop_structure.rend(),
        ir_utils::IterDomainDependencySorter(
            concrete_id_dependencies, GpuLower::current()->caMap()));
    loop_structures_[tv] = loop_structure;
  }
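
  // E.g., with the T0/T1/T2 example above, T2's loop structure ends up as
  // [I0, I1, I2]: the transitive dependencies of its innermost ID plus the
  // ID itself, ordered outermost to innermost.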

  // Process the carefully ordered expressions in reverse; since handle()
  // prepends everything it generates, the original order is preserved.
  for (auto it = exprs.rbegin(); it != exprs.rend(); ++it) {
    handle(*it);
  }
}

} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch