1#include <torch/csrc/jit/passes/inline_autodiff_subgraphs.h>
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/passes/dead_code_elimination.h>
5#include <torch/csrc/jit/passes/update_differentiable_graph_requires_grad.h>
6#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
7
8namespace torch {
9namespace jit {
10
11// aten and prim nodes (except FusionGroup) are guaranteed to work
12// with Autograd, other nodes (e.g. user-defined nodes) are not necessarily
13// Autograd-aware
14bool canRunWithAutograd(Node* node) {
15 auto kind = node->kind();
16 for (Block* block : node->blocks()) {
17 if (!std::all_of(
18 block->nodes().begin(), block->nodes().end(), canRunWithAutograd)) {
19 return false;
20 }
21 }
22 return kind != prim::FusionGroup && kind != prim::CudaFusionGroup &&
23 kind != prim::TypeCheck && kind != prim::TensorExprGroup &&
24 kind != prim::CudaFusionGuard && kind != prim::oneDNNFusionGroup &&
25 kind != prim::oneDNNFusionGuard && (kind.is_aten() || kind.is_prim());
26}
27
28namespace {
29
30void InlineAutodiffSubgraphs(Block* block, size_t threshold);
31
32size_t blockSize(Block* block) {
33 size_t num = 0;
34 for (Node* n : block->nodes()) {
35 for (Block* b : n->blocks()) {
36 num += blockSize(b);
37 }
38 num++;
39 }
40 return num;
41}
42
43graph_node_list::iterator scanNode(Node* node, size_t threshold) {
44 auto next_node = ++node->iterator();
45
46 for (Block* block : node->blocks()) {
47 InlineAutodiffSubgraphs(block, threshold);
48 }
49
50 if (node->kind() != prim::DifferentiableGraph) {
51 return next_node;
52 }
53
54 auto subgraph = node->g(attr::Subgraph);
55 size_t subgraph_size = blockSize(subgraph->block());
56 if (subgraph_size >= threshold) {
57 return next_node;
58 }
59
60 if (!std::all_of(
61 subgraph->nodes().begin(),
62 subgraph->nodes().end(),
63 canRunWithAutograd)) {
64 return next_node;
65 }
66
67 // now that we inline the graph, we are no longer detaching input tensors,
68 // so the profiles will have outdated requires_grad=False.
69 // conservatively update them to maybe requiring grad, bc we might create
70 // autodiff graphs when the tensors maybe require grad
71 UpdateDifferentiableGraphRequiresGrad(subgraph, c10::nullopt);
72 SubgraphUtils::unmergeSubgraph(node);
73 return next_node;
74}
75
76void InlineAutodiffSubgraphs(Block* block, size_t threshold) {
77 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
78 it = scanNode(*it, threshold);
79 }
80}
81
82} // anonymous namespace
83
84void InlineAutodiffSubgraphs(std::shared_ptr<Graph>& graph, size_t threshold) {
85 InlineAutodiffSubgraphs(graph->block(), threshold);
86 EliminateDeadCode(graph);
87}
88
89} // namespace jit
90} // namespace torch
91