#include <torch/csrc/jit/passes/lower_grad_of.h>

#include <torch/csrc/jit/jit_log.h>

namespace torch {
namespace jit {

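// Removes each prim::GradOf node by rewriting it into a prim::If that only
// runs the wrapped gradient computation when at least one of its input
// gradients is defined.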
void LowerGradOf(Graph& g) {
  for (auto it = g.nodes().begin(); it != g.nodes().end(); ++it) {
    if (it->kind() == prim::GradOf) {
      // if any_defined(inputs):
      //   outputs = <original_computation>
      // else:
      //   outputs = autograd zero tensors
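      //
      // A rough sketch of the rewrite (IR is illustrative, value names are
      // made up):
      //   %gx = prim::GradOf[name="aten::mul"](%gz)
      // becomes:
      //   %any_defined = prim::AutogradAnyNonZero(%gz)
      //   %gx = prim::If(%any_defined)
      //     block0(): <original GradOf body>
      //     block1(): %undef = prim::AutogradZero() -> (%undef)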
      WithInsertPoint guard(*it);
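      // AutogradAnyNonZero produces the `any_defined(inputs)` condition: it
      // is true when at least one input is a defined gradient tensor.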
      auto cond = g.insertNode(g.create(prim::AutogradAnyNonZero, it->inputs()))
                      ->output()
                      ->setType(IntType::get());
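      // Build an If with one output per GradOf output. The true branch is a
      // clone of the GradOf body; the identity value map suffices because
      // the body only closes over values that are still in scope here.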
      auto if_stat =
          g.insertNode(g.create(prim::If, {cond}, it->outputs().size()));
      if_stat->addBlock()->cloneFrom(
          it->blocks().at(0), [](Value* v) { return v; });
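      // The false branch yields a single shared AutogradZero placeholder for
      // every output.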
      auto else_block = if_stat->addBlock();
      auto undef = g.createAutogradZero()
                       ->insertBefore(else_block->return_node())
                       ->output();
      for (size_t i = 0; i < it->outputs().size(); ++i) {
        // The else block returns an undefined tensor for each output of the
        // GradOf node, i.e. it assumes that all the outputs are tensors.
        // This might not be true: e.g. the backward of cat() returns a list
        // of gradient tensors. This is fixed up in
        // DifferentiableGraphBackward, where the list sizes are recorded
        // during the forward pass, and undefined tensors are then turned
        // into lists of undefined tensors where necessary.
        else_block->registerOutput(undef);
        if_stat->outputs().at(i)->copyMetadata(it->outputs().at(i));
      }
      GRAPH_UPDATE("Replacing ", getHeader(*it), " with ", getHeader(if_stat));
      it->replaceAllUsesWith(if_stat);
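      // destroyCurrent() erases the GradOf node while keeping the iterator
      // valid, so the loop's ++it can advance safely.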
      it.destroyCurrent();
    }
  }
}

} // namespace jit
} // namespace torch