1#include <torch/csrc/jit/passes/loop_unrolling.h>
2
3#include <ATen/core/symbol.h>
4#include <c10/util/Exception.h>
5#include <c10/util/irange.h>
6
7#include <torch/csrc/jit/ir/constants.h>
8#include <torch/csrc/jit/ir/ir_views.h>
9#include <torch/csrc/jit/jit_log.h>
10#include <torch/csrc/jit/passes/dead_code_elimination.h>
11
12namespace torch {
13namespace jit {
14
15namespace {
16
// How many copies of the body are chained into one iteration of a loop that
// is partially unrolled.
constexpr int64_t kUnrollFactor = 8;
// Size threshold (as measured by limitedBlockSize) above which the general,
// non-constant-only unrolling mode leaves a loop alone.
constexpr int64_t kMaxBodySize = 32;
// Loops whose constant trip count is below this are unrolled completely.
constexpr int64_t kMaxBodyRepeats = 64;
20
21bool isTrueConstant(Value* val) {
22 c10::optional<bool> maybe_value = constant_as<bool>(val);
23 return maybe_value && *maybe_value;
24}
25
26bool isForLoop(Node* node) {
27 if (node->kind() != prim::Loop)
28 return false;
29 Value* start_cond = node->inputs().at(1);
30 Value* continue_cond = node->blocks().at(0)->outputs().at(0);
31 return isTrueConstant(start_cond) && isTrueConstant(continue_cond);
32}
33
// Returns the number of counted nodes in `body`, including nodes inside any
// nested blocks, but gives up early once the running count reaches `limit`.
// Nodes for which notExecutedOp() returns true are not counted.
//
// The early exit makes the result roughly capped at `limit`: when the budget
// runs out mid-block the loop falls through to `return limit`. It can still
// slightly exceed `limit` when the budget is exhausted while processing the
// final node visited (`end`). isSmallBlock only uses it as a threshold test.
//
// NOTE(review): each iteration dereferences `it` before checking
// `it == end`, so `end` itself (presumably the block's return node) is also
// visited and counted unless notExecutedOp() is true for it. An empty block
// therefore likely reports 1, not 0 — confirm this is intended.
int64_t limitedBlockSize(Block* body, int64_t limit) {
  auto it = body->nodes().begin();
  auto end = body->nodes().end();
  for (int64_t i = 0; i < limit; ++it) {
    // Nested blocks are only given the remaining budget.
    for (Block* subblock : it->blocks()) {
      i += limitedBlockSize(subblock, limit - i);
    }
    if (!it->notExecutedOp()) {
      ++i;
    }
    // Stop only after `end` has itself been processed (see NOTE above).
    if (it == end) {
      return i;
    }
  }
  return limit;
}
52
53bool isSmallBlock(Block* body) {
54 return limitedBlockSize(body, kMaxBodySize + 1) <= kMaxBodySize;
55}
56
57// XXX: This function can only be called with a loop that is guaranteed to
58// execute EXACTLY ONCE.
59void inlineBody(Node* loop) {
60 auto graph = loop->owningGraph();
61 auto body = loop->blocks().at(0);
62 WithInsertPoint insert_point_guard{loop};
63
64 std::unordered_map<Value*, Value*> value_map;
65 auto get_value = [&](Value* v) {
66 auto it = value_map.find(v);
67 if (it != value_map.end())
68 return it->second;
69 return v;
70 };
71
72 // Loop node has extra (max_iters, initial_cond) inputs,
73 // body has an extra (loop_counter) input.
74 for (size_t i = 2; i < loop->inputs().size(); ++i) {
75 value_map[body->inputs()[i - 1]] = loop->inputs()[i];
76 }
77
78 for (Node* orig : body->nodes()) {
79 Node* clone = graph->insertNode(graph->createClone(orig, get_value));
80 for (size_t i = 0; i < orig->outputs().size(); ++i) {
81 value_map[orig->outputs()[i]] = clone->outputs()[i];
82 }
83 }
84 for (size_t i = 0; i < loop->outputs().size(); ++i) {
85 loop->outputs().at(i)->replaceAllUsesWith(
86 get_value(body->outputs().at(i + 1)));
87 }
88 // XXX: it is extremely important to destroy the loop in here. DCE might not
89 // be able to conclude that it's safe, because the loop might contain side
90 // effects.
91 loop->destroy();
92}
93
94// inserts a copy of body, passing inputs to the inputs of the block
95// it returns the a list of the Values for the output of the block
96std::vector<Value*> insertBlockCopy(
97 Graph& graph,
98 Block* body,
99 at::ArrayRef<Value*> inputs) {
100 TORCH_INTERNAL_ASSERT(inputs.size() == body->inputs().size());
101 std::unordered_map<Value*, Value*> value_map;
102 auto get_value = [&](Value* v) {
103 auto it = value_map.find(v);
104 if (it != value_map.end())
105 return it->second;
106 return v;
107 };
108 auto inputs_it = inputs.begin();
109 for (Value* input : body->inputs()) {
110 value_map[input] = *inputs_it++;
111 }
112 for (Node* node : body->nodes()) {
113 Node* new_node = graph.insertNode(graph.createClone(node, get_value));
114 auto outputs_it = new_node->outputs().begin();
115 for (Value* output : node->outputs()) {
116 value_map[output] = *outputs_it++;
117 }
118 }
119 return fmap(body->outputs(), get_value);
120}
121
122void repeatBody(Block* body, size_t times, Block* dest) {
123 auto graph = body->owningGraph();
124 WithInsertPoint insert_point_guard(dest);
125 for (Value* input : body->inputs()) {
126 dest->addInput()->copyMetadata(input);
127 }
128
129 std::vector<Value*> io = dest->inputs().vec();
130 TORCH_INTERNAL_ASSERT(
131 !body->inputs().at(0)->hasUses(), "loop counter should be unused");
132 for (const auto i : c10::irange(times)) {
133 (void)i; // Suppress unused variable warning
134 io[0] = body->inputs().at(0);
135 io = insertBlockCopy(*graph, body, io);
136 }
137 for (Value* output : io) {
138 dest->registerOutput(output);
139 }
140
141 // It's likely that we have some dead nodes now - for example the "true"
142 // constant that prevents the loop from breaking. We shouldn't wait too long
143 // before removing them because they might artificially increase the loop size
144 // and prevent outer loop unrolling.
145 EliminateDeadCode(dest, false);
146}
147
148// Replaces the builtin loop counter with a "mutable" variable outside of the
149// loop.
150void replaceLoopCounter(Node* loop) {
151 Graph* graph = loop->owningGraph();
152 Block* body = loop->blocks().at(0);
153 WithInsertPoint guard(loop);
154 Value* init_counter = graph->insertConstant(0);
155
156 loop->insertInput(2, init_counter);
157 loop->insertOutput(0)->setType(IntType::get());
158
159 Value* internal_counter = body->insertInput(1)->setType(init_counter->type());
160 body->inputs()[0]->replaceAllUsesWith(internal_counter);
161
162 WithInsertPoint insertPointGuard{body->return_node()};
163 Value* result = graph->insert(aten::add, {internal_counter, 1});
164 body->insertOutput(1, result);
165}
166
167void unroll(Node* loop) {
168 Graph* graph = loop->owningGraph();
169 Block* body = loop->blocks().at(0);
170
171 // We will be using a "mutable" counter outside of the loop instead of the
172 // default one, because this will allow us to share it between the unrolled
173 // loop and its epilogue. This is necessary only if the loop counter is
174 // actually used in the body.
175 if (!body->inputs()[0]->uses().empty())
176 replaceLoopCounter(loop);
177
178 // Some optimization for constant-length loops. If we know they won't run too
179 // many times, then we can unroll them entirely.
180 Value* trip_count = loop->inputs().at(0);
181 c10::optional<int64_t> const_len = constant_as<int64_t>(trip_count);
182 if (const_len && *const_len < kMaxBodyRepeats) {
183 Block* dest = loop->addBlock();
184 repeatBody(body, *const_len, dest);
185 loop->eraseBlock(0);
186 inlineBody(loop);
187 return;
188 }
189
190 WithInsertPoint insert_point_guard{loop};
191
192 // Clone the loop before we unroll it. The clone will become the epilogue.
193 Node* loop_epilogue =
194 graph->createClone(loop, [](Value* v) { return v; })->insertAfter(loop);
195 for (size_t i = 0; i < loop->outputs().size(); ++i) {
196 loop->outputs()[i]->replaceAllUsesWith(loop_epilogue->outputs()[i]);
197 loop_epilogue->replaceInput(i + 2, loop->outputs()[i]);
198 }
199
200 Block* dest = loop->addBlock();
201 repeatBody(body, kUnrollFactor, dest);
202 loop->eraseBlock(0);
203
204 // Change the iteration counts of both loops
205 Value* iter_count = loop->inputs().at(0);
206 Value* unrolled_iter_count = graph->insert(
207 aten::__round_to_zero_floordiv, {iter_count, kUnrollFactor});
208 loop->replaceInput(0, unrolled_iter_count);
209 loop_epilogue->replaceInput(
210 0,
211 graph->insert(
212 aten::sub,
213 {iter_count,
214 graph->insert(aten::mul, {unrolled_iter_count, kUnrollFactor})}));
215}
216
217bool UnrollLoops(Block* block, bool constant_only) {
218 bool changed = false;
219 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
220 // XXX: unroll might destroy the current node, so we need to pre-increment
221 // the iterator
222 Node* node = *it;
223 ++it;
224 for (Block* subblock : node->blocks()) {
225 changed |= UnrollLoops(subblock, constant_only);
226 }
227 if (!isForLoop(node)) {
228 continue;
229 }
230 if (constant_only) {
231 if (node->inputs().at(0)->node()->kind() != prim::Constant) {
232 continue;
233 }
234 } else if (!isSmallBlock(node->blocks().at(0))) {
235 continue;
236 }
237
238 unroll(node);
239 changed = true;
240 }
241 return changed;
242}
243
244} // anonymous namespace
245
246static void addCondAsOutput(Node* loop) {
247 LoopView loop_view(loop);
248 loop->addInput(loop_view.inputCond());
249 auto block_cond_input = loop_view.bodyBlock()->addInput();
250 block_cond_input->copyMetadata(loop_view.inputCond());
251 auto cond_output_index =
252 loop_view.bodyBlock()->registerOutput(loop_view.nextCond());
253 loop_view.bodyBlock()->outputs()[cond_output_index]->copyMetadata(
254 loop_view.nextCond());
255 auto cond_output = loop->addOutput();
256 cond_output->copyMetadata(loop_view.nextCond());
257}
258
259bool LoopsPeeler::run(const std::shared_ptr<Graph>& graph) {
260 GRAPH_DUMP("Before LoopsPeeler", graph);
261 collectLoops(graph->block());
262 peelLoops();
263 GRAPH_DUMP("After LoopsPeeler", graph);
264 return true;
265}
266
267void LoopsPeeler::collectLoop(Node* n) {
268 if (callback_(n)) {
269 if (in_loop_) {
270 GRAPH_DEBUG("Loop ", getHeader(in_loop_), " will be unrolled");
271 loops_to_peel_.push_back(in_loop_);
272 in_loop_ = nullptr;
273 }
274 }
275}
276
277void LoopsPeeler::collectLoops(Block* block) {
278 // we do a pre-order traversal to reduce the number
279 // of peeled loops.
280 for (auto n : block->nodes()) {
281 collectLoop(n);
282 }
283 collectLoop(block->return_node());
284
285 // process child blocks
286 for (auto n : block->nodes()) {
287 auto old_in_loop_ = in_loop_;
288 if (n->kind() == prim::Loop) {
289 in_loop_ = n;
290 }
291 for (auto b : n->blocks()) {
292 collectLoops(b);
293 }
294 in_loop_ = old_in_loop_;
295 }
296}
297
298void LoopsPeeler::peelLoops() {
299 for (auto loop : loops_to_peel_) {
300 PeelLoop(loop, num_iterations_);
301 }
302}
303
304bool PeelProfilingLoops(const std::shared_ptr<Graph>& graph) {
305 auto peel_predicate = [](Node* n) {
306 for (auto i : n->inputs()) {
307 if (i->type()->isSubtypeOf(*TensorType::get())) {
308 return true;
309 }
310 }
311
312 return false;
313 };
314
315 LoopsPeeler lp(peel_predicate);
316 return lp.run(graph);
317}
318
// Peels up to `times` iterations off the front of the loop `n`.
//
// A clone of `n` is inserted directly before it and capped at
// min(maxTripCount, times) iterations. The clone gets one extra loop-carried
// value (addCondAsOutput), so its final continue condition is available as
// its last output. The original loop is then rewired to:
//   * run maxTripCount - min(maxTripCount, times) iterations,
//   * start from the clone's final continue condition,
//   * take the clone's loop-carried results as its initial values, and
//   * see its trip counter offset by min(maxTripCount, times).
//
// Returns the peeled clone.
Node* PeelLoop(Node* n, size_t times) {
  GRAPH_DEBUG("Peeling the loop ", getHeader(n), " ", times, " times");

  auto graph = n->owningGraph();
  auto orig_loop = LoopView(n);

  WithInsertPoint wip(n);
  auto times_const = graph->insertConstant(static_cast<int64_t>(times));
  // N.B. even though a caller may request to peel `times` iterations
  // `maxTripCount` of the original loop might be less than that
  // so we should take the minimum of the two
  auto min_trip_count =
      graph->insert(prim::min, {orig_loop.maxTripCount(), times_const});

  // make the peeled clone
  auto peeled_copy = graph->createClone(n, [](Value* v) { return v; });
  addCondAsOutput(peeled_copy);

  LoopView new_lv(peeled_copy);
  graph->insertNode(peeled_copy);
  // only run until the peeled count
  new_lv.replaceMaxTripCount(min_trip_count);

  // subtract the number of iterations the peeled loop runs from the
  // `maxTripCount` of the original loop
  auto new_max_trip_count =
      graph->insert(aten::sub, {orig_loop.maxTripCount(), min_trip_count});
  orig_loop.replaceMaxTripCount(new_max_trip_count);
  // the original loop starts from the peeled loop's final continue condition,
  // which addCondAsOutput exposed as the peeled loop's last output
  auto cond_index = peeled_copy->outputs().size() - 1;
  orig_loop.replaceInputCondition(peeled_copy->output(cond_index));

  // Loop-carried inputs of a prim::Loop node come after
  // (max_trip_count, initial_cond).
  static const size_t LOOP_DEPS_WITH_COND_OFFSET = 2;
  // feed the peeled loop's loop-carried results into the original loop
  for (size_t i = 0; i < peeled_copy->outputs().size() -
           1 /* leave off the termination condition */;
       i++) {
    n->replaceInput(LOOP_DEPS_WITH_COND_OFFSET + i, peeled_copy->output(i));
  }

  // the induction variable also needs to be adjusted by the number of
  // iterations the peeled loop runs
  {
    WithInsertPoint peeled_wip(*orig_loop.bodyBlock()->nodes().begin());
    // We want `new_counter = old_counter + min_trip_count`, but we can't build
    // that expression directly. Running
    // `old_counter->replaceAllUsesWith(new_counter)` afterwards would also
    // rewrite its own operand, giving `new_counter = new_counter +
    // min_trip_count`. So build `min_trip_count + min_trip_count` first,
    // redirect the uses, and only then patch operand 0 back to the original
    // counter.
    auto adjusted_iter_counter =
        graph->insert(aten::add, {min_trip_count, min_trip_count});
    orig_loop.currentTripCount()->replaceAllUsesWith(adjusted_iter_counter);
    adjusted_iter_counter->node()->replaceInput(
        0, orig_loop.currentTripCount());
  }

  return peeled_copy;
}
375
376bool UnrollLoops(std::shared_ptr<Graph>& graph) {
377 bool changed = UnrollLoops(graph->block(), false);
378 if (changed) {
379 EliminateDeadCode(graph);
380 }
381 return changed;
382}
383
384bool UnrollConstantLoops(std::shared_ptr<Graph>& graph) {
385 bool changed = UnrollLoops(graph->block(), true);
386 if (changed) {
387 EliminateDeadCode(graph);
388 }
389 return changed;
390}
391
392} // namespace jit
393} // namespace torch
394