#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>

#include <algorithm>

namespace torch {
namespace jit {

// aten and prim nodes (except the fusion groups and their guard nodes, which
// are excluded explicitly below) are guaranteed to work with Autograd; other
// nodes (e.g. user-defined nodes) are not necessarily Autograd-aware.
bool canRunWithAutograd(Node* node) {
  auto kind = node->kind();
  for (Block* block : node->blocks()) {
    if (!std::all_of(
            block->nodes().begin(), block->nodes().end(), canRunWithAutograd)) {
      return false;
    }
  }
  return kind != prim::FusionGroup && kind != prim::CudaFusionGroup &&
      kind != prim::TypeCheck && kind != prim::TensorExprGroup &&
      kind != prim::CudaFusionGuard && kind != prim::oneDNNFusionGroup &&
      kind != prim::oneDNNFusionGuard && (kind.is_aten() || kind.is_prim());
}

namespace {

void InlineAutodiffSubgraphs(Block* block, size_t threshold);

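// Recursively counts the nodes in a block, including nodes inside nested
// sub-blocks (e.g. the bodies of prim::If and prim::Loop nodes).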
size_t blockSize(Block* block) {
  size_t num = 0;
  for (Node* n : block->nodes()) {
    for (Block* b : n->blocks()) {
      num += blockSize(b);
    }
    num++;
  }
  return num;
}

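// Visits a single node: recurses into its sub-blocks first, then, if the node
// is a prim::DifferentiableGraph whose subgraph is smaller than `threshold`
// and consists only of Autograd-aware nodes, unmerges the subgraph back into
// the surrounding graph. Returns an iterator to the next node to visit.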
graph_node_list::iterator scanNode(Node* node, size_t threshold) {
  auto next_node = ++node->iterator();

  for (Block* block : node->blocks()) {
    InlineAutodiffSubgraphs(block, threshold);
  }

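  // Only prim::DifferentiableGraph nodes are candidates for inlining.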
  if (node->kind() != prim::DifferentiableGraph) {
    return next_node;
  }

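  // Subgraphs at or above the size threshold stay fused.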
  auto subgraph = node->g(attr::Subgraph);
  size_t subgraph_size = blockSize(subgraph->block());
  if (subgraph_size >= threshold) {
    return next_node;
  }

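  // Inline only if every node in the subgraph (including nodes in nested
  // blocks) is known to work under Autograd.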
  if (!std::all_of(
          subgraph->nodes().begin(),
          subgraph->nodes().end(),
          canRunWithAutograd)) {
    return next_node;
  }

  // Now that we are inlining the subgraph, its input tensors are no longer
  // detached, so any profiled requires_grad=False information is outdated.
  // Conservatively update those profiles to "may require grad", because we
  // might later create autodiff graphs for tensors that may require grad.
  UpdateDifferentiableGraphRequiresGrad(subgraph, c10::nullopt);
  SubgraphUtils::unmergeSubgraph(node);
  return next_node;
}

void InlineAutodiffSubgraphs(Block* block, size_t threshold) {
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    it = scanNode(*it, threshold);
  }
}

} // anonymous namespace

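// Entry point: inlines every sufficiently small differentiable subgraph in
// `graph`, then runs dead code elimination to clean up anything that became
// dead in the process. A sketch of a call site (the threshold value here is
// purely illustrative):
//
//   std::shared_ptr<Graph> graph = ...;
//   InlineAutodiffSubgraphs(graph, /*threshold=*/5);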
void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph, size_t threshold) {
  InlineAutodiffSubgraphs(graph->block(), threshold);
  EliminateDeadCode(graph);
}

} // namespace jit
} // namespace torch