#include <torch/csrc/jit/passes/guard_elimination.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <memory>
#include <unordered_set>

namespace torch {
namespace jit {

struct GuardElimination {
  GuardElimination(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}

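  // Entry point: repeatedly sink guards onto the nodes that define their
  // inputs, then merge adjacent guards on the same value, drop guards that
  // are dominated by an equivalent guard, and finally remove guards whose
  // outputs are fully determined by already-guarded inputs.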
  void run() {
    const size_t MAX_ATTEMPTS = 5;
    size_t attempts = MAX_ATTEMPTS;
    while (attempts-- && moveGuardsToDefs(graph_->block())) {
    }
    GRAPH_DUMP("After moveGuardsToDefs", graph_);
    coalesceGuards(graph_->block());
    GRAPH_DUMP("After coalesceGuards", graph_);
    removeDominatedGuards(graph_->block());
    GRAPH_DUMP("After removeDominatedGuards", graph_);
    eliminateRedundantGuards(graph_->block());
    GRAPH_DUMP("After eliminateRedundantGuards", graph_);
  }

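  // Detects the lowered form of prim::GradOf: a prim::If node whose
  // condition comes from prim::AutogradAnyNonZero. Guards sitting at the
  // top of such a block get hoisted above the If in moveGuardsToDefs.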
  static bool isLoweredGradOf(Node* n) {
    if (n->kind() != prim::If) {
      return false;
    }

    return n->input(0)->node()->kind() == prim::AutogradAnyNonZero;
  }

  bool moveGuardsToDefs(Block* b) {
    bool changed = false;
    for (auto it = b->nodes().begin(); it != b->nodes().end();) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        // grab the next node before we move this one all the way back
        it++;
        auto guardee = n->inputs().at(0)->node();
        // alias analysis will try to hoist a node out of a loop if asked.
        // if the guardee is in a loop, the guard should only be moved to
        // the beginning of the basic block, given the current
        // implementation of AliasAnalysis
        if (guardee->owningBlock() != n->owningBlock()) {
          guardee = *n->owningBlock()->nodes().begin();
        }
        bool moved = aliasDb_->moveAfterTopologicallyValid(n, guardee);
        changed |= moved;
        if (moved) {
          GRAPH_UPDATE(
              "Moved ",
              n->output()->debugName(),
              " to ",
              n->inputs().at(0)->debugName());
        }
      } else {
        it++;
        for (Block* ib : n->blocks()) {
          moveGuardsToDefs(ib);
        }
      }
    }

    if (b->owningNode() && isLoweredGradOf(b->owningNode())) {
      for (auto it = b->nodes().begin(); it != b->nodes().end();) {
        auto block_node = *it++;
        if (block_node->kind() != prim::Guard) {
          break;
        }
        block_node->moveBefore(b->owningNode());
        changed = true;
      }
    }

    return changed;
  }

  void coalesceGuards(Block* b) {
    // uses of *all* parameters are moved to the same anchor node,
    // and they may come in a different order after the anchor node,
    // e.g. (anchor, guard_x, guard_y, guard_x, guard_y)
    // this pass recognizes contiguous stretches of guards and keeps
    // track of the guards it has seen for each def; the next time it
    // sees a guard on the same def, it simply removes it.
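    // An illustrative sketch (value names and types are hypothetical):
    //   before: %gx1 = prim::Guard(%x); %gy = prim::Guard(%y);
    //           %gx2 = prim::Guard(%x)
    //   after:  %gx1 = prim::Guard(%x); %gy = prim::Guard(%y)
    //           (uses of %gx2 are rewired to %gx1)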
    std::unordered_map<Value*, Node*> inputs_to_guards;
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        if (inputs_to_guards.count(n->input())) {
          auto prev = inputs_to_guards[n->input()];
          n->output()->replaceAllUsesWith(prev->output());
          GRAPH_UPDATE(
              "Replacing ",
              n->output()->debugName(),
              " with ",
              prev->output()->debugName());
          it.destroyCurrent();
        } else {
          inputs_to_guards.insert({n->input(), n});
        }
      } else if (n->kind() != prim::Constant) {
        inputs_to_guards.clear();
        for (Block* ib : n->blocks()) {
          coalesceGuards(ib);
        }
      }
    }
  }

  void removeDominatedGuards(Block* b) {
    // If a node guards a value which isn't mutated, then that node
    // can replace all other guards on the same value that it dominates
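    // Sketch (hypothetical IR): if %g1 = prim::Guard(%x) dominates
    // %g2 = prim::Guard(%x) and both outputs have the same type, %g2 is
    // first rewired to guard %g1's output and then destroyed, leaving a
    // single guard on %x.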
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        Value* input = n->input();
        if (aliasDb_->hasWriters(input)) {
          continue;
        }
        Value* guard_output = n->output();

        // find all uses of the input that the guard node dominates
        std::vector<Use> uses = input->uses();
        while (!uses.empty()) {
          auto use = uses.back();
          uses.pop_back();

          // not all uses are guarded
          if (use.user->kind() != prim::Guard) {
            continue;
          }

          if (!use.user->isDominatedBy(n)) {
            continue;
          }

          // the dominated guard's type may differ from the dominator's
          // if it is only executed for a subtype, or if it is executed
          // in a different global context (e.g. with gradients enabled),
          // so check that the types are equal before continuing

          auto dominator_type = guard_output->type();
          auto dominated_type = use.user->output()->type();

          if (*dominator_type == *dominated_type) {
            use.user->replaceInput(use.offset, guard_output);
          }
        }

        // remove redundant dominated guards
        std::vector<Use> users = n->output()->uses();
        for (auto use : users) {
          auto user = use.user;
          if (user->kind() == prim::Guard) {
            GRAPH_UPDATE(
                "Removing dominated guard ", user, " and replacing with ", n);
            user->output()->replaceAllUsesWith(guard_output);
            user->destroy();
          }
        }
      } else {
        for (Block* ib : n->blocks()) {
          removeDominatedGuards(ib);
        }
      }
    }
  }

  // we need to make sure there are no ops between the guarded value's
  // definition and its guard, other than other guards and constants,
  // as they can invalidate the profiled shape information.
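  // Sketch (hypothetical IR): the backwards walk below accepts
  //   %x = aten::add(%a, %b, %one); %g = prim::Guard(%x)
  // but rejects
  //   %x = aten::add(%a, %b, %one); %y = aten::relu(%c);
  //   %g = prim::Guard(%x)
  // because the intervening aten::relu is neither a Guard nor a Constant.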
  bool guardsOutput(Node* guard) {
    auto output = guard->input()->node();
    auto it = guard;
    while (it != output) {
      if (it->kind() != prim::Guard && it->kind() != prim::Constant) {
        GRAPH_DEBUG(
            "found an unexpected node ",
            *it,
            " while trying to eliminate ",
            *guard);
        return false;
      }
      it = it->prev();
    }

    return true;
  }

  void eliminateRedundantGuards(Block* b) {
    // a very simple pass to eliminate redundant guards for ops
    // whose outputs are fully determined by their inputs,
    // i.e. if the inputs to such an op are guarded, we are allowed
    // to remove the guard on the op's outputs
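    // Sketch (hypothetical IR): with both inputs guarded,
    //   %ga = prim::Guard(%a); %gb = prim::Guard(%b)
    //   %c = aten::add(%ga, %gb, %one); %gc = prim::Guard(%c)
    // the guard %gc adds no information, so its uses are redirected to
    // %c, and %c takes on %gc's profiled type.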
    for (auto it = b->nodes().rbegin(); it != b->nodes().rend();) {
      auto n = *it;
      if (n->kind() == prim::Guard && guardsOutput(n) &&
          removableGuard(n->inputs().at(0)->node())) {
        auto pttp = n->output()->type();
        n->output()->replaceAllUsesWith(n->inputs().at(0));
        n->inputs().at(0)->setType(pttp);
        GRAPH_UPDATE(
            "Eliminating the redundant guard ", n->output()->debugName());
        it.destroyCurrent();
      } else {
        it++;
        for (Block* ib : n->blocks()) {
          eliminateRedundantGuards(ib);
        }
      }
    }
  }

  // `checkInputs` checks the invariants specified in `removableGuard`
  // on the inputs to `n`. The invariants must hold, or an input must
  // be a `prim::Constant` or be included as an exception in `except`
  bool checkInputs(
      Node* n,
      const std::unordered_set<size_t>& except,
      bool allow_numbers) {
    bool all_inputs_guarded = true;
    size_t i = 0;
    for (auto input : n->inputs()) {
      if ((input->node()->kind() == prim::Guard &&
           !input->type()->expectRef<TensorType>().isSummarized()) ||
          input->node()->kind() == prim::Constant ||
          (allow_numbers && input->type()->isSubtypeOf(*NumberType::get())) ||
          except.count(i) != 0) {
        AT_ASSERT(
            input->node()->kind() != prim::Guard ||
            input->type()->expect<TensorType>());
      } else {
        GRAPH_DEBUG(
            "input ",
            input->debugName(),
            " isn't guarded, type ",
            *input->type());
        all_inputs_guarded = false;
        break;
      }
      i++;
    }
    return all_inputs_guarded;
  }

 private:
  // `removableGuard` relies on the properties checked by `isSummarized()`,
  // and passes shouldn't insert nodes between a guard and its uses that
  // may alter those properties.
  // `removableGuard` expects type information to come directly from the
  // profiler; passes shouldn't try to alter the type information provided
  // by profiling.
  // While, for some categories of operations, we can derive very simple
  // rules stating when it's valid to remove a `prim::Guard` on an
  // operation's output if all of its inputs are guarded, there's no
  // comprehensive set of rules that covers all the operations available
  // in PyTorch.
  // If your operation falls into one of the categories described below,
  // you should add it to the switch statement below, next to the other
  // operations in that category.
  // Otherwise, you will need to derive the rules for your case on your own.
  // Generally, any operation that is stateful in any way, or that uses its
  // underlying data to compute any of the properties `isSummarized()`
  // checks, isn't amenable to guard elimination.
  // Categories:
  // * Functional-like (e.g. add, sub, le) operations with broadcast
  //   semantics: guards can be removed if all inputs are guarded and
  //   `isSummarized()` returns false, or the inputs are `prim::Constant`
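  // For example, a new functional op with broadcast semantics would be
  // added to the first group of cases below (a hypothetical addition):
  //   case aten::hardtanh:
  //     return checkInputs(n, no_exceptions, true);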
  bool removableGuard(Node* n) {
    const static auto no_exceptions = std::unordered_set<size_t>{};
    switch (n->kind()) {
      case aten::add:
      case aten::add_:
      case aten::sub:
      case aten::mul:
      case aten::div:
      case aten::t:
      case aten::sigmoid:
      case aten::sin:
      case aten::cos:
      case aten::tan:
      case aten::sinh:
      case aten::cosh:
      case aten::tanh:
      case aten::asin:
      case aten::acos:
      case aten::atan:
      case aten::atan2:
      case aten::floor:
      case aten::fmod:
      case aten::ceil:
      case aten::trunc:
      case aten::sqrt:
      case aten::rsqrt:
      case aten::remainder:
      case aten::mm:
      case aten::min:
      case aten::max:
      case aten::type_as:
      case aten::ge:
      case aten::gt:
      case aten::lt:
      case aten::le:
      case aten::eq:
      case aten::ne:
      case aten::neg:
      case prim::ConstantChunk:
      case aten::size:
      case aten::abs:
      case aten::sign:
      case aten::pow:
      case aten::relu:
      case aten::threshold:
      case prim::AutogradAdd:
      case prim::AutogradZero:
      case aten::rand_like:
      case aten::erf:
      case aten::erfc:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log2:
      case aten::log10:
      case aten::frac:
      case aten::lerp:
      case aten::lgamma:
      case aten::reciprocal:
      case aten::addcmul:
      case aten::where:
      case aten::_cast_Float:
      case aten::_cast_Long:
      case aten::__and__:
      case aten::__or__:
      case aten::__xor__:
      case aten::__lshift__:
      case aten::__rshift__:
      case aten::bitwise_not:
      case aten::bitwise_and:
      case aten::bitwise_or:
      case aten::bitwise_xor:
        return checkInputs(n, no_exceptions, true);
      case aten::softmax:
        return checkInputs(n, std::unordered_set<size_t>{1}, true);
      case aten::multinomial:
        return checkInputs(n, std::unordered_set<size_t>{2, 3}, false);
      case aten::flatten:
      case aten::argmax:
      case aten::squeeze:
      case aten::avg_pool2d:
        return checkInputs(n, no_exceptions, false);
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
        return checkInputs(n, std::unordered_set<size_t>{2, 6}, false);
      case aten::slice:
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            // check that the dimension argument is constant
            n->input(1)->node()->kind() == prim::Constant &&
            // the start offset is constant
            n->input(2)->node()->kind() == prim::Constant &&
            // the end offset is constant
            n->input(3)->node()->kind() == prim::Constant &&
            // the stride is constant
            n->input(4)->node()->kind() == prim::Constant;
      case aten::max_pool1d:
      case aten::max_pool2d:
      case aten::max_pool3d:
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            // check that the kernel size is constant
            n->input(1)->node()->kind() == prim::Constant &&
            // check that the stride is constant
            n->input(2)->node()->kind() == prim::Constant &&
            // check that the padding is constant
            n->input(3)->node()->kind() == prim::Constant &&
            // check that the dilation is constant
            n->input(4)->node()->kind() == prim::Constant &&
            // check that the ceil_mode is constant
            n->input(5)->node()->kind() == prim::Constant;
      case aten::unsqueeze:
        // check that the dimension argument is constant
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            n->input(1)->node()->kind() == prim::Constant;
      case aten::cat:
        // check that the dimension argument is constant
        return n->input(1)->node()->kind() == prim::Constant &&
            n->input(0)->node()->kind() == prim::ListConstruct &&
            // no extra nodes in between aten::cat and prim::ListConstruct
            n->prev() == n->input(0)->node() &&
            // check the inputs to prim::ListConstruct (not aten::cat)
            checkInputs(n->input(0)->node(), no_exceptions, false);
      case aten::clamp:
        // the second and third args do not affect shapes
        return checkInputs(n, std::unordered_set<size_t>{1, 2}, false);
      case aten::_grad_sum_to_size:
        // skip checking the size argument
        if (checkInputs(n, std::unordered_set<size_t>{1}, false)) {
          auto asize = n->input(1)->node();
          if (asize->kind() == prim::Constant) {
            return true;
          } else if (asize->matches("aten::size(Tensor self) -> int[]")) {
            // aten::size is effectively a constant
            if (asize->input()
                    ->type()
                    ->expectRef<TensorType>()
                    .sizes()
                    .concrete_sizes()) {
              return true;
            }
          }
        }
        return false;

      // this is checked by one of the tests in test_jit_fuser.py
      case prim::ListUnpack: {
        // check if the input is a constant chunk
        // used for LSTM fusions
        auto chunk = n->input(0)->node();
        if (chunk->kind() != aten::chunk) {
          return false;
        }
        return checkInputs(chunk, no_exceptions, false);
      }
      // this is checked by one of the tests in test_jit_fuser.py
      case aten::broadcast_tensors: {
        auto list_construct = n->input(0)->node();
        if (list_construct->kind() != prim::ListConstruct) {
          return false;
        }
        return checkInputs(list_construct, no_exceptions, false);
      }
      // after some optimizations we might end up with two guards
      // back-to-back, in which case we can remove the one whose input
      // is also a prim::Guard
      case prim::Guard:
      case prim::GradOf:
        return true;
      default:
        GRAPH_DEBUG("cannot remove ", n->kind().toQualString());
        return false;
    }
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_;
  static std::unordered_set<Symbol> simple_ops_;
};

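// Public entry point. A minimal usage sketch (graph construction elided;
// the pass assumes the graph already contains prim::Guard nodes inserted
// by profiling):
//   std::shared_ptr<Graph> graph = ...;
//   EliminateRedundantGuards(graph);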
void EliminateRedundantGuards(std::shared_ptr<Graph> graph) {
  GuardElimination ge(std::move(graph));
  ge.run();
}

} // namespace jit
} // namespace torch