1 | #include <torch/csrc/jit/passes/loop_unrolling.h> |
2 | |
3 | #include <ATen/core/symbol.h> |
4 | #include <c10/util/Exception.h> |
5 | #include <c10/util/irange.h> |
6 | |
7 | #include <torch/csrc/jit/ir/constants.h> |
8 | #include <torch/csrc/jit/ir/ir_views.h> |
9 | #include <torch/csrc/jit/jit_log.h> |
10 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
15 | namespace { |
16 | |
// Number of body copies made when unrolling a loop with a non-constant
// (or too-large) trip count; the remainder runs in an epilogue loop.
static constexpr int64_t kUnrollFactor = 8;
// Loops whose body exceeds this many executed nodes are not unrolled.
static constexpr int64_t kMaxBodySize = 32;
// Constant-trip-count loops are fully unrolled only if they run fewer than
// this many iterations.
static constexpr int64_t kMaxBodyRepeats = 64;
20 | |
21 | bool isTrueConstant(Value* val) { |
22 | c10::optional<bool> maybe_value = constant_as<bool>(val); |
23 | return maybe_value && *maybe_value; |
24 | } |
25 | |
26 | bool isForLoop(Node* node) { |
27 | if (node->kind() != prim::Loop) |
28 | return false; |
29 | Value* start_cond = node->inputs().at(1); |
30 | Value* continue_cond = node->blocks().at(0)->outputs().at(0); |
31 | return isTrueConstant(start_cond) && isTrueConstant(continue_cond); |
32 | } |
33 | |
34 | // Counts the size of this block, stopping and returning once reaches limit |
35 | // instructions. |
36 | int64_t limitedBlockSize(Block* body, int64_t limit) { |
37 | auto it = body->nodes().begin(); |
38 | auto end = body->nodes().end(); |
39 | for (int64_t i = 0; i < limit; ++it) { |
40 | for (Block* subblock : it->blocks()) { |
41 | i += limitedBlockSize(subblock, limit - i); |
42 | } |
43 | if (!it->notExecutedOp()) { |
44 | ++i; |
45 | } |
46 | if (it == end) { |
47 | return i; |
48 | } |
49 | } |
50 | return limit; |
51 | } |
52 | |
53 | bool isSmallBlock(Block* body) { |
54 | return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize; |
55 | } |
56 | |
// XXX: This function can only be called with a loop that is guaranteed to
// execute EXACTLY ONCE. It splices one copy of the loop body in front of the
// loop node, rewires the loop's outputs to the copied values, and destroys
// the loop node.
void inlineBody(Node* loop) {
  auto graph = loop->owningGraph();
  auto body = loop->blocks().at(0);
  WithInsertPoint insert_point_guard{loop};

  // Maps body-internal values to their replacements outside the loop; values
  // not present in the map (defined outside the loop) pass through unchanged.
  std::unordered_map<Value*, Value*> value_map;
  auto get_value = [&](Value* v) {
    auto it = value_map.find(v);
    if (it != value_map.end())
      return it->second;
    return v;
  };

  // Loop node has extra (max_iters, initial_cond) inputs,
  // body has an extra (loop_counter) input.
  for (size_t i = 2; i < loop->inputs().size(); ++i) {
    value_map[body->inputs()[i - 1]] = loop->inputs()[i];
  }

  // Clone each body node right before the loop, recording output mappings as
  // we go so later clones see the earlier clones' values.
  for (Node* orig : body->nodes()) {
    Node* clone = graph->insertNode(graph->createClone(orig, get_value));
    for (size_t i = 0; i < orig->outputs().size(); ++i) {
      value_map[orig->outputs()[i]] = clone->outputs()[i];
    }
  }
  // Body output 0 is the continue-condition, so loop output i corresponds to
  // body output i + 1.
  for (size_t i = 0; i < loop->outputs().size(); ++i) {
    loop->outputs().at(i)->replaceAllUsesWith(
        get_value(body->outputs().at(i + 1)));
  }
  // XXX: it is extremely important to destroy the loop in here. DCE might not
  // be able to conclude that it's safe, because the loop might contain side
  // effects.
  loop->destroy();
}
93 | |
94 | // inserts a copy of body, passing inputs to the inputs of the block |
95 | // it returns the a list of the Values for the output of the block |
96 | std::vector<Value*> insertBlockCopy( |
97 | Graph& graph, |
98 | Block* body, |
99 | at::ArrayRef<Value*> inputs) { |
100 | TORCH_INTERNAL_ASSERT(inputs.size() == body->inputs().size()); |
101 | std::unordered_map<Value*, Value*> value_map; |
102 | auto get_value = [&](Value* v) { |
103 | auto it = value_map.find(v); |
104 | if (it != value_map.end()) |
105 | return it->second; |
106 | return v; |
107 | }; |
108 | auto inputs_it = inputs.begin(); |
109 | for (Value* input : body->inputs()) { |
110 | value_map[input] = *inputs_it++; |
111 | } |
112 | for (Node* node : body->nodes()) { |
113 | Node* new_node = graph.insertNode(graph.createClone(node, get_value)); |
114 | auto outputs_it = new_node->outputs().begin(); |
115 | for (Value* output : node->outputs()) { |
116 | value_map[output] = *outputs_it++; |
117 | } |
118 | } |
119 | return fmap(body->outputs(), get_value); |
120 | } |
121 | |
// Appends `times` back-to-back copies of `body` into `dest`, chaining the
// loop-carried values: copy k's block outputs feed copy k+1's inputs. `dest`
// gets fresh inputs/outputs mirroring `body`'s. Requires the loop counter
// (body input 0) to be unused, since it is not incremented between copies.
void repeatBody(Block* body, size_t times, Block* dest) {
  auto graph = body->owningGraph();
  WithInsertPoint insert_point_guard(dest);
  // Mirror the body's inputs on the destination block.
  for (Value* input : body->inputs()) {
    dest->addInput()->copyMetadata(input);
  }

  std::vector<Value*> io = dest->inputs().vec();
  TORCH_INTERNAL_ASSERT(
      !body->inputs().at(0)->hasUses(), "loop counter should be unused" );
  for (const auto i : c10::irange(times)) {
    (void)i; // Suppress unused variable warning
    // After each copy, io[0] holds that copy's continue-condition output;
    // overwrite it with the (unused) counter placeholder so the value list
    // lines up with the body's inputs for the next copy.
    io[0] = body->inputs().at(0);
    io = insertBlockCopy(*graph, body, io);
  }
  for (Value* output : io) {
    dest->registerOutput(output);
  }

  // It's likely that we have some dead nodes now - for example the "true"
  // constant that prevents the loop from breaking. We shouldn't wait too long
  // before removing them because they might artificially increase the loop size
  // and prevent outer loop unrolling.
  EliminateDeadCode(dest, false);
}
147 | |
// Replaces the builtin loop counter with a "mutable" variable outside of the
// loop: the counter becomes an explicit loop-carried value, initialized to 0
// before the loop, incremented by 1 at the end of each iteration, and
// surfaced as the loop's first output so an epilogue loop can continue it.
void replaceLoopCounter(Node* loop) {
  Graph* graph = loop->owningGraph();
  Block* body = loop->blocks().at(0);
  WithInsertPoint guard(loop);
  Value* init_counter = graph->insertConstant(0);

  // Loop inputs are (max_trip_count, cond, carried...): slot 2 is the first
  // carried value, and output 0 mirrors it.
  loop->insertInput(2, init_counter);
  loop->insertOutput(0)->setType(IntType::get());

  // Body inputs are (iter_counter, carried...): the new carried counter goes
  // into slot 1 and takes over every use of the builtin counter.
  Value* internal_counter = body->insertInput(1)->setType(init_counter->type());
  body->inputs()[0]->replaceAllUsesWith(internal_counter);

  // Emit counter + 1 just before the body returns, and carry it out in
  // output slot 1 (slot 0 is the continue-condition).
  WithInsertPoint insertPointGuard{body->return_node()};
  Value* result = graph->insert(aten::add, {internal_counter, 1});
  body->insertOutput(1, result);
}
166 | |
// Unrolls a for-loop. Short constant-trip-count loops are fully unrolled and
// inlined; otherwise the body is repeated kUnrollFactor times and a cloned
// epilogue loop runs the leftover iterations.
void unroll(Node* loop) {
  Graph* graph = loop->owningGraph();
  Block* body = loop->blocks().at(0);

  // We will be using a "mutable" counter outside of the loop instead of the
  // default one, because this will allow us to share it between the unrolled
  // loop and its epilogue. This is necessary only if the loop counter is
  // actually used in the body.
  if (!body->inputs()[0]->uses().empty())
    replaceLoopCounter(loop);

  // Some optimization for constant-length loops. If we know they won't run too
  // many times, then we can unroll them entirely.
  Value* trip_count = loop->inputs().at(0);
  c10::optional<int64_t> const_len = constant_as<int64_t>(trip_count);
  if (const_len && *const_len < kMaxBodyRepeats) {
    // Replace the body with `const_len` chained copies, then inline — the
    // resulting loop trivially executes exactly once.
    Block* dest = loop->addBlock();
    repeatBody(body, *const_len, dest);
    loop->eraseBlock(0);
    inlineBody(loop);
    return;
  }

  WithInsertPoint insert_point_guard{loop};

  // Clone the loop before we unroll it. The clone will become the epilogue.
  Node* loop_epilogue =
      graph->createClone(loop, [](Value* v) { return v; })->insertAfter(loop);
  // Rewire: downstream users read the epilogue's outputs, and the epilogue's
  // carried inputs (offset 2 past max_trip_count/cond) come from the
  // unrolled loop's outputs.
  for (size_t i = 0; i < loop->outputs().size(); ++i) {
    loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
    loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
  }

  // Swap the original body for one containing kUnrollFactor chained copies.
  Block* dest = loop->addBlock();
  repeatBody(body, kUnrollFactor, dest);
  loop->eraseBlock(0);

  // Change the iteration counts of both loops: the unrolled loop runs
  // iter_count / kUnrollFactor times (truncating), and the epilogue picks up
  // the remaining iter_count - unrolled_iter_count * kUnrollFactor.
  Value* iter_count = loop->inputs().at(0);
  Value* unrolled_iter_count = graph->insert(
      aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
  loop->replaceInput(0, unrolled_iter_count);
  loop_epilogue->replaceInput(
      0,
      graph->insert(
          aten::sub,
          {iter_count,
           graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
}
216 | |
217 | bool UnrollLoops(Block* block, bool constant_only) { |
218 | bool changed = false; |
219 | for (auto it = block->nodes().begin(); it != block->nodes().end();) { |
220 | // XXX: unroll might destroy the current node, so we need to pre-increment |
221 | // the iterator |
222 | Node* node = *it; |
223 | ++it; |
224 | for (Block* subblock : node->blocks()) { |
225 | changed |= UnrollLoops(subblock, constant_only); |
226 | } |
227 | if (!isForLoop(node)) { |
228 | continue; |
229 | } |
230 | if (constant_only) { |
231 | if (node->inputs().at(0)->node()->kind() != prim::Constant) { |
232 | continue; |
233 | } |
234 | } else if (!isSmallBlock(node->blocks().at(0))) { |
235 | continue; |
236 | } |
237 | |
238 | unroll(node); |
239 | changed = true; |
240 | } |
241 | return changed; |
242 | } |
243 | |
244 | } // anonymous namespace |
245 | |
246 | static void addCondAsOutput(Node* loop) { |
247 | LoopView loop_view(loop); |
248 | loop->addInput(loop_view.inputCond()); |
249 | auto block_cond_input = loop_view.bodyBlock()->addInput(); |
250 | block_cond_input->copyMetadata(loop_view.inputCond()); |
251 | auto cond_output_index = |
252 | loop_view.bodyBlock()->registerOutput(loop_view.nextCond()); |
253 | loop_view.bodyBlock()->outputs()[cond_output_index]->copyMetadata( |
254 | loop_view.nextCond()); |
255 | auto cond_output = loop->addOutput(); |
256 | cond_output->copyMetadata(loop_view.nextCond()); |
257 | } |
258 | |
259 | bool LoopsPeeler::run(const std::shared_ptr<Graph>& graph) { |
260 | GRAPH_DUMP("Before LoopsPeeler" , graph); |
261 | collectLoops(graph->block()); |
262 | peelLoops(); |
263 | GRAPH_DUMP("After LoopsPeeler" , graph); |
264 | return true; |
265 | } |
266 | |
267 | void LoopsPeeler::collectLoop(Node* n) { |
268 | if (callback_(n)) { |
269 | if (in_loop_) { |
270 | GRAPH_DEBUG("Loop " , getHeader(in_loop_), " will be unrolled" ); |
271 | loops_to_peel_.push_back(in_loop_); |
272 | in_loop_ = nullptr; |
273 | } |
274 | } |
275 | } |
276 | |
277 | void LoopsPeeler::collectLoops(Block* block) { |
278 | // we do a pre-order traversal to reduce the number |
279 | // of peeled loops. |
280 | for (auto n : block->nodes()) { |
281 | collectLoop(n); |
282 | } |
283 | collectLoop(block->return_node()); |
284 | |
285 | // process child blocks |
286 | for (auto n : block->nodes()) { |
287 | auto old_in_loop_ = in_loop_; |
288 | if (n->kind() == prim::Loop) { |
289 | in_loop_ = n; |
290 | } |
291 | for (auto b : n->blocks()) { |
292 | collectLoops(b); |
293 | } |
294 | in_loop_ = old_in_loop_; |
295 | } |
296 | } |
297 | |
298 | void LoopsPeeler::peelLoops() { |
299 | for (auto loop : loops_to_peel_) { |
300 | PeelLoop(loop, num_iterations_); |
301 | } |
302 | } |
303 | |
304 | bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph) { |
305 | auto peel_predicate = [](Node* n) { |
306 | for (auto i : n->inputs()) { |
307 | if (i->type()->isSubtypeOf(*TensorType::get())) { |
308 | return true; |
309 | } |
310 | } |
311 | |
312 | return false; |
313 | }; |
314 | |
315 | LoopsPeeler lp(peel_predicate); |
316 | return lp.run(graph); |
317 | } |
318 | |
// Peels `times` iterations off the front of loop `n`: a cloned copy of the
// loop runs min(times, maxTripCount) iterations first, then the original
// loop runs whatever remains. Returns the peeled (cloned) loop node.
Node* PeelLoop(Node* n, size_t times) {
  GRAPH_DEBUG("Peeling the loop " , getHeader(n), " " , times, " times" );

  auto graph = n->owningGraph();
  auto orig_loop = LoopView(n);

  WithInsertPoint wip(n);
  auto times_const = graph->insertConstant(static_cast<int64_t>(times));
  // N.B. even though a caller may request to peel `times` iterations
  // `maxTripCount` of the original loop might be less than that
  // so we should take the minimum of the two
  auto min_trip_count =
      graph->insert(prim::min, {orig_loop.maxTripCount(), times_const});

  // make the peeled clone; its final continue-condition is surfaced as an
  // extra output so the original loop can key off it
  auto peeled_copy = graph->createClone(n, [](Value* v) { return v; });
  addCondAsOutput(peeled_copy);

  LoopView new_lv(peeled_copy);
  graph->insertNode(peeled_copy);
  // only run until the peeled count
  new_lv.replaceMaxTripCount(min_trip_count);

  // subtract from `maxTripCount` of the original loop the number of
  // iterations the peeled loop runs
  auto new_max_trip_count =
      graph->insert(aten::sub, {orig_loop.maxTripCount(), min_trip_count});
  orig_loop.replaceMaxTripCount(new_max_trip_count);
  // update the termination condition: the original loop only proceeds if the
  // peeled copy's last output (its final continue-condition) was true
  auto cond_index = peeled_copy->outputs().size() - 1;
  orig_loop.replaceInputCondition(peeled_copy->output(cond_index));

  // Feed the peeled copy's carried outputs into the original loop's carried
  // inputs, which start after (max_trip_count, cond).
  static const size_t LOOP_DEPS_WITH_COND_OFFSET = 2;
  for (size_t i = 0; i < peeled_copy->outputs().size() -
           1 /* leave off the termination condition */;
       i++) {
    n->replaceInput(LOOP_DEPS_WITH_COND_OFFSET + i, peeled_copy->output(i));
  }

  // the induction variable also needs to be adjusted by the number of
  // iterations the peeled loop runs
  {
    WithInsertPoint peeled_wip(*orig_loop.bodyBlock()->nodes().begin());
    // we can't create the expression: `new_counter` = `old_counter` + 1 yet
    // because when we
    // run `old_counter->replaceAllUsesWith(new_counter)`, we will get
    // `new_counter = new_counter + 1`
    // So build `min_trip_count + min_trip_count` as a placeholder, redirect
    // all uses of the counter to it, then patch operand 0 back to the
    // (now use-free) original counter.
    auto adjusted_iter_counter =
        graph->insert(aten::add, {min_trip_count, min_trip_count});
    orig_loop.currentTripCount()->replaceAllUsesWith(adjusted_iter_counter);
    adjusted_iter_counter->node()->replaceInput(
        0, orig_loop.currentTripCount());
  }

  return peeled_copy;
}
375 | |
376 | bool UnrollLoops(std::shared_ptr<Graph>& graph) { |
377 | bool changed = UnrollLoops(graph->block(), false); |
378 | if (changed) { |
379 | EliminateDeadCode(graph); |
380 | } |
381 | return changed; |
382 | } |
383 | |
384 | bool UnrollConstantLoops(std::shared_ptr<Graph>& graph) { |
385 | bool changed = UnrollLoops(graph->block(), true); |
386 | if (changed) { |
387 | EliminateDeadCode(graph); |
388 | } |
389 | return changed; |
390 | } |
391 | |
392 | } // namespace jit |
393 | } // namespace torch |
394 | |