#include <torch/csrc/jit/passes/guard_elimination.h>

#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/constant_propagation.h>
#include <torch/csrc/jit/passes/peephole.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <memory>
#include <unordered_set>

namespace torch {
namespace jit {

struct GuardElimination {
  GuardElimination(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)), aliasDb_(std::make_unique<AliasDb>(graph_)) {}

  void run() {
    const size_t MAX_ATTEMPTS = 5;
    size_t attempts = MAX_ATTEMPTS;
    while (attempts-- && moveGuardsToDefs(graph_->block())) {
    }
    GRAPH_DUMP("After moveGuardsToDefs", graph_);
    coalesceGuards(graph_->block());
    GRAPH_DUMP("After coalesceGuards", graph_);
    removeDominatedGuards(graph_->block());
    GRAPH_DUMP("After removeDominatedGuards", graph_);
    eliminateRedundantGuards(graph_->block());
    GRAPH_DUMP("After eliminateRedundantGuards", graph_);
  }

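  // A prim::GradOf lowered by the LowerGradOf pass becomes a prim::If
  // whose condition comes from prim::AutogradAnyNonZero; this helper
  // recognizes that shape.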
  static bool isLoweredGradOf(Node* n) {
    if (n->kind() != prim::If) {
      return false;
    }

    return n->input(0)->node()->kind() == prim::AutogradAnyNonZero;
  }

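  // Moves each guard as close as possible to the definition of the value
  // it guards. An illustrative IR sketch (value names are hypothetical):
  //
  //   %x = aten::add(%a, %b)
  //   %y = aten::mul(%c, %d)
  //   %gx = prim::Guard(%x)
  //
  // becomes
  //
  //   %x = aten::add(%a, %b)
  //   %gx = prim::Guard(%x)
  //   %y = aten::mul(%c, %d)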
  bool moveGuardsToDefs(Block* b) {
    bool changed = false;
    for (auto it = b->nodes().begin(); it != b->nodes().end();) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        // grab the next node before we move this one all the way back
        it++;
        auto guardee = n->inputs().at(0)->node();
        // alias analysis will try to hoist a node out of a loop if asked.
        // if the guardee is defined outside the guard's block (e.g. the
        // guard sits in a loop body), the guard should only be moved to
        // the beginning of its own basic block, given the current
        // implementation of AliasDb
        if (guardee->owningBlock() != n->owningBlock()) {
          guardee = *n->owningBlock()->nodes().begin();
        }
        bool moved = aliasDb_->moveAfterTopologicallyValid(n, guardee);
        changed |= moved;
        if (moved) {
          GRAPH_UPDATE(
              "Moved ",
              n->output()->debugName(),
              " to ",
              n->inputs().at(0)->debugName());
        }
      } else {
        it++;
        for (Block* ib : n->blocks()) {
          moveGuardsToDefs(ib);
        }
      }
    }

    if (b->owningNode() && isLoweredGradOf(b->owningNode())) {
      for (auto it = b->nodes().begin(); it != b->nodes().end();) {
        auto block_node = *it++;
        if (block_node->kind() != prim::Guard) {
          break;
        }
        block_node->moveBefore(b->owningNode());
        changed = true;
      }
    }

    return changed;
  }

  void coalesceGuards(Block* b) {
    // uses on *all* parameters are moved to the same anchor node, and
    // they may come in a different order after the anchor node, e.g.
    // (anchor, guard_x, guard_y, guard_x, guard_y). this pass recognizes
    // contiguous stretches of guards, keeping track of the guards it has
    // seen for each def; the next time it sees a guard on the same def,
    // it simply removes it.
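    // An illustrative IR sketch (value names are hypothetical):
    //
    //   %gx1 = prim::Guard(%x)
    //   %gy1 = prim::Guard(%y)
    //   %gx2 = prim::Guard(%x) // uses rewritten to %gx1, node destroyed
    //   %gy2 = prim::Guard(%y) // uses rewritten to %gy1, node destroyed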
    std::unordered_map<Value*, Node*> inputs_to_guards;
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        if (inputs_to_guards.count(n->input())) {
          auto prev = inputs_to_guards[n->input()];
          n->output()->replaceAllUsesWith(prev->output());
          GRAPH_UPDATE(
              "Replacing ",
              n->output()->debugName(),
              " with ",
              prev->output()->debugName());
          it.destroyCurrent();
        } else {
          inputs_to_guards.insert({n->input(), n});
        }
      } else if (n->kind() != prim::Constant) {
        inputs_to_guards.clear();
        for (Block* ib : n->blocks()) {
          coalesceGuards(ib);
        }
      }
    }
  }

  void removeDominatedGuards(Block* b) {
    // If a node guards a value which isn't mutated, then its output can
    // replace the outputs of all other guards on that value which it
    // dominates
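    // An illustrative IR sketch (value names are hypothetical), assuming
    // %x has no writers and both guards produce the same type:
    //
    //   %g1 = prim::Guard(%x) // dominates %g2
    //   ...
    //   %g2 = prim::Guard(%x) // uses rewritten to %g1, node destroyed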
    for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
      auto n = *it;
      if (n->kind() == prim::Guard) {
        Value* input = n->input();
        if (aliasDb_->hasWriters(input)) {
          continue;
        }
        Value* guard_output = n->output();

        // find all uses of the input that the guard node dominates
        std::vector<Use> uses = input->uses();
        while (!uses.empty()) {
          auto use = uses.back();
          uses.pop_back();

          // not all uses are guarded
          if (use.user->kind() != prim::Guard) {
            continue;
          }

          if (!use.user->isDominatedBy(n)) {
            continue;
          }

          // the dominated guard's type may differ from the dominator's if
          // it is only executed for a subtype, or if it is executed in a
          // different global context (e.g. with grad enabled), so check
          // that the types are equal before continuing

          auto dominator_type = guard_output->type();
          auto dominated_type = use.user->output()->type();

          if (*dominator_type == *dominated_type) {
            use.user->replaceInput(use.offset, guard_output);
          }
        }

        // remove redundant dominated guards
        std::vector<Use> users = n->output()->uses();
        for (auto use : users) {
          auto user = use.user;
          if (user->kind() == prim::Guard) {
            GRAPH_UPDATE(
                "Removing dominated guard ", user, " and replacing with ", n);
            user->output()->replaceAllUsesWith(guard_output);
            user->destroy();
          }
        }
      } else {
        for (Block* ib : n->blocks()) {
          removeDominatedGuards(ib);
        }
      }
    }
  }

  // we need to make sure there are no ops in between the guardee's output
  // and its guard other than other guards and constants, as such ops can
  // invalidate shape information.
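  // An illustrative sketch of the backwards walk (names are hypothetical):
  //
  //   %x = aten::add(%a, %b)
  //   %c = prim::Constant()  // allowed in between
  //   %g = prim::Guard(%x)   // guardsOutput(%g) == true
  //
  // any other kind of node between %x and %g makes this return false.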
  bool guardsOutput(Node* guard) {
    auto output = guard->input()->node();
    auto it = guard;
    while (it != output) {
      if (it->kind() != prim::Guard && it->kind() != prim::Constant) {
        GRAPH_DEBUG(
            "found an unexpected node ",
            *it,
            " while trying to eliminate ",
            *guard);
        return false;
      }
      it = it->prev();
    }

    return true;
  }

  void eliminateRedundantGuards(Block* b) {
    // a very simple pass to eliminate redundant guards for ops whose
    // outputs are fully determined by their inputs, i.e. if the inputs
    // to such ops are guarded, we are allowed to remove the guard on
    // the ops' outputs
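    // An illustrative IR sketch (value names are hypothetical, schemas
    // simplified); with both inputs guarded and unsummarized, aten::add's
    // output type is fully determined, so the output guard is redundant:
    //
    //   %ga = prim::Guard(%a)
    //   %gb = prim::Guard(%b)
    //   %x = aten::add(%ga, %gb)
    //   %gx = prim::Guard(%x) // removed; uses rewritten to %x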
    for (auto it = b->nodes().rbegin(); it != b->nodes().rend();) {
      auto n = *it;
      if (n->kind() == prim::Guard && guardsOutput(n) &&
          removableGuard(n->inputs().at(0)->node())) {
        auto pttp = n->output()->type();
        n->output()->replaceAllUsesWith(n->inputs().at(0));
        n->inputs().at(0)->setType(pttp);
        GRAPH_UPDATE(
            "Eliminating the redundant guard ", n->output()->debugName());
        it.destroyCurrent();
      } else {
        it++;
        for (Block* ib : n->blocks()) {
          eliminateRedundantGuards(ib);
        }
      }
    }
  }

  // `checkInputs` checks the invariants specified in `removableGuard` on
  // the inputs to `n`. Each input must satisfy the invariants, be a
  // `prim::Constant`, or be listed as an exception in `except`.
  bool checkInputs(
      Node* n,
      const std::unordered_set<size_t>& except,
      bool allow_numbers) {
    bool all_inputs_guarded = true;
    size_t i = 0;
    for (auto input : n->inputs()) {
      if ((input->node()->kind() == prim::Guard &&
           !input->type()->expectRef<TensorType>().isSummarized()) ||
          input->node()->kind() == prim::Constant ||
          (allow_numbers && input->type()->isSubtypeOf(*NumberType::get())) ||
          except.count(i) != 0) {
        AT_ASSERT(
            input->node()->kind() != prim::Guard ||
            input->type()->expect<TensorType>());
      } else {
        GRAPH_DEBUG(
            "input ",
            input->debugName(),
            " isn't guarded, type ",
            *input->type());
        all_inputs_guarded = false;
        break;
      }
      i++;
    }
    return all_inputs_guarded;
  }

 private:
  // `removableGuard` relies on the properties checked by `isSummarized()`,
  // and passes shouldn't insert nodes between a guard and its uses that
  // may alter those properties.
  // `removableGuard` expects type information to come directly from the
  // profiler; passes shouldn't try to alter type information provided by
  // profiling.
  // While for some categories of operations we can derive very simple
  // rules stating when it's valid to remove a `prim::Guard` on an
  // operation's output if all of its inputs are guarded, there's no
  // comprehensive set of rules that covers all the operations available
  // in PyTorch.
  // If your operation falls into one of the categories described below,
  // add it to the switch statement below, alongside the other operations
  // in that category. Otherwise, you will need to derive the rules for
  // your case on your own.
  // Generally, any operation that is stateful in any way, or that uses
  // its underlying data to compute any of the properties `isSummarized()`
  // checks, isn't amenable to guard elimination.
  // Categories:
  // * Functional-like (e.g. add, sub, le) operations with broadcast
  //   semantics: guards can be removed if all inputs are guarded and
  //   `isSummarized()` returns false, or inputs are `prim::Constant`s
  bool removableGuard(Node* n) {
    const static auto no_exceptions = std::unordered_set<size_t>{};
    switch (n->kind()) {
      case aten::add:
      case aten::add_:
      case aten::sub:
      case aten::mul:
      case aten::div:
      case aten::t:
      case aten::sigmoid:
      case aten::sin:
      case aten::cos:
      case aten::tan:
      case aten::sinh:
      case aten::cosh:
      case aten::tanh:
      case aten::asin:
      case aten::acos:
      case aten::atan:
      case aten::atan2:
      case aten::floor:
      case aten::fmod:
      case aten::ceil:
      case aten::trunc:
      case aten::sqrt:
      case aten::rsqrt:
      case aten::remainder:
      case aten::mm:
      case aten::min:
      case aten::max:
      case aten::type_as:
      case aten::ge:
      case aten::gt:
      case aten::lt:
      case aten::le:
      case aten::eq:
      case aten::ne:
      case aten::neg:
      case prim::ConstantChunk:
      case aten::size:
      case aten::abs:
      case aten::sign:
      case aten::pow:
      case aten::relu:
      case aten::threshold:
      case prim::AutogradAdd:
      case prim::AutogradZero:
      case aten::rand_like:
      case aten::erf:
      case aten::erfc:
      case aten::exp:
      case aten::expm1:
      case aten::log:
      case aten::log2:
      case aten::log10:
      case aten::frac:
      case aten::lerp:
      case aten::lgamma:
      case aten::reciprocal:
      case aten::addcmul:
      case aten::where:
      case aten::_cast_Float:
      case aten::_cast_Long:
      case aten::__and__:
      case aten::__or__:
      case aten::__xor__:
      case aten::__lshift__:
      case aten::__rshift__:
      case aten::bitwise_not:
      case aten::bitwise_and:
      case aten::bitwise_or:
      case aten::bitwise_xor:
        return checkInputs(n, no_exceptions, true);
      case aten::softmax:
        return checkInputs(n, std::unordered_set<size_t>{1}, true);
      case aten::multinomial:
        return checkInputs(n, std::unordered_set<size_t>{2, 3}, false);
      case aten::flatten:
      case aten::argmax:
      case aten::squeeze:
      case aten::avg_pool2d:
        return checkInputs(n, no_exceptions, false);
      case aten::conv1d:
      case aten::conv2d:
      case aten::conv3d:
        return checkInputs(n, std::unordered_set<size_t>{2, 6}, false);
      case aten::slice:
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            // check that the dimension argument is constant
            n->input(1)->node()->kind() == prim::Constant &&
            // the start offset is constant
            n->input(2)->node()->kind() == prim::Constant &&
            // the end offset is constant
            n->input(3)->node()->kind() == prim::Constant &&
            // the stride is constant
            n->input(4)->node()->kind() == prim::Constant;
      case aten::max_pool1d:
      case aten::max_pool2d:
      case aten::max_pool3d:
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            // check that the kernel size is constant
            n->input(1)->node()->kind() == prim::Constant &&
            // check that the stride is constant
            n->input(2)->node()->kind() == prim::Constant &&
            // check that the padding is constant
            n->input(3)->node()->kind() == prim::Constant &&
            // check that the dilation is constant
            n->input(4)->node()->kind() == prim::Constant &&
            // check that the ceil_mode is constant
            n->input(5)->node()->kind() == prim::Constant;
      case aten::unsqueeze:
        // check that the dimension argument is constant
        return !n->input(0)->type()->expectRef<TensorType>().isSummarized() &&
            n->input(1)->node()->kind() == prim::Constant;
      case aten::cat:
        // check that the dimension argument is constant
        return n->input(1)->node()->kind() == prim::Constant &&
            n->input(0)->node()->kind() == prim::ListConstruct &&
            // no extra nodes in between aten::cat and prim::ListConstruct
            n->prev() == n->input(0)->node() &&
            // check the inputs to prim::ListConstruct (not aten::cat)
            checkInputs(n->input(0)->node(), no_exceptions, false);
      case aten::clamp:
        // the second and third args do not affect shapes
        return checkInputs(n, std::unordered_set<size_t>{1, 2}, false);
      case aten::_grad_sum_to_size:
        // skip checking the size argument
        if (checkInputs(n, std::unordered_set<size_t>{1}, false)) {
          auto asize = n->input(1)->node();
          if (asize->kind() == prim::Constant) {
            return true;
          } else if (asize->matches("aten::size(Tensor self) -> int[]")) {
            // aten::size is effectively constant when the input's profiled
            // sizes are fully concrete
            if (asize->input()
                    ->type()
                    ->expectRef<TensorType>()
                    .sizes()
                    .concrete_sizes()) {
              return true;
            }
          }
        }
        return false;

      // this is checked by one of the tests in test_jit_fuser.py
      case prim::ListUnpack: {
        // check if the input is a constant chunk
        // used for LSTM fusions
        auto chunk = n->input(0)->node();
        if (chunk->kind() != aten::chunk) {
          return false;
        }
        return checkInputs(chunk, no_exceptions, false);
      }
      // this is checked by one of the tests in test_jit_fuser.py
      case aten::broadcast_tensors: {
        auto list_construct = n->input(0)->node();
        if (list_construct->kind() != prim::ListConstruct) {
          return false;
        }
        return checkInputs(list_construct, no_exceptions, false);
      }
      // after some optimizations we might end up with two Guards
      // back-to-back, in which case we can remove the one whose input
      // is itself a prim::Guard
      case prim::Guard:
      case prim::GradOf:
        return true;
      default:
        GRAPH_DEBUG("cannot remove ", n->kind().toQualString());
        return false;
    }
  }

  std::shared_ptr<Graph> graph_;
  std::unique_ptr<AliasDb> aliasDb_;
  static std::unordered_set<Symbol> simple_ops_;
};

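// Entry point: runs the full guard-elimination pipeline on `graph`.
// A minimal usage sketch (assumes a graph produced by the profiling
// executor, i.e. one containing prim::Guard nodes):
//
//   std::shared_ptr<Graph> graph = ...; // profiled graph
//   EliminateRedundantGuards(graph);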
void EliminateRedundantGuards(std::shared_ptr<Graph> graph) {
  GuardElimination ge(std::move(graph));
  ge.run();
}

} // namespace jit
} // namespace torch