#include <torch/csrc/jit/passes/create_autodiff_subgraphs.h>

#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/canonicalize.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
#include <torch/csrc/jit/passes/remove_redundant_profiles.h>
#include <torch/csrc/jit/passes/utils/subgraph_utils.h>
#include <torch/csrc/jit/runtime/autodiff.h>

namespace torch {
namespace jit {

namespace {
struct WorkBlock : public std::pair<Node*, Node*> {
  using pair::pair;

  Node* begin() {
    return this->first;
  }
  Node* end() {
    return this->second;
  }
};

class SubgraphSlicer {
 public:
  SubgraphSlicer(
      Block* block,
      std::shared_ptr<Graph> graph,
      size_t minSubgraphSize,
      AliasDb& aliasDb,
      std::vector<Node*>& diff_nodes)
      : block_(block),
        graph_(std::move(graph)),
        minSubgraphSize_(minSubgraphSize),
        aliasDb_(aliasDb),
        diff_nodes_(diff_nodes) {}

  void run() {
    // We maintain alias db correctness in-place while building up the autodiff
    // subgraphs; however, it is difficult to preserve correctness when
    // un-inlining autodiff subgraphs. We first recursively construct all
    // subgraphs and then recursively clean up & unmerge the small subgraphs.
    buildupSubgraphs();
    GRAPH_DUMP("before unfuseAliasedOutputs", graph_);
    unfuseAliasedOutputs(block_);
    cleanupSubgraphs();
    // Run CSE globally once to eliminate duplicates that may have occurred
    // while inlining subgraphs.
    EliminateCommonSubexpression(graph_);
  }

  void cleanupSubgraphs() {
    auto curNode = *block_->nodes().rbegin();
    while (curNode != *block_->nodes().rend()) {
      // Save the previous node, since we might delete `curNode` in the next
      // block
      auto prevNode = curNode->prev();
      if (curNode->kind() == prim::DifferentiableGraph) {
        // Inlining nodes may cause some subexpressions to come back in the
        // subgraphs (for example, repeatedly copying constants in will
        // generate redundant prim::Constants). Run CSE to clean them up.
        EliminateCommonSubexpression(curNode->g(attr::Subgraph));

        if (!inlineIfTooSmall(curNode)) {
          diff_nodes_.push_back(curNode);
        }
      }
      curNode = prevNode;
    }

    for (Node* n : block_->nodes()) {
      for (Block* b : n->blocks()) {
        SubgraphSlicer(b, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .cleanupSubgraphs();
      }
    }
  }

  void buildupSubgraphs() {
    // We need to run the slicer multiple times in order to get all merge
    // opportunities. This is because moveBeforeTopologicallyValid may reorder
    // nodes to be AFTER the current iteration point. In order to properly
    // consider those nodes for merging, we need to run the pass until no
    // changes have been made.
    //
    // Example:
    //   c = f(a, b)
    //   d = f(c)
    //   e = f(d)  <- iter is here, moving upward
    // After c.moveBeforeTopologicallyValid(e), we have:
    //   c = f(a, b)
    //   e = f(d)  <- iter still here
    //   d = f(c)  <- this is the node that was moved to the other side.

    // see [workblocks]
    auto workblocks = buildWorkBlocks();
    for (auto& workblock : workblocks) {
      bool any_changed = true;
      while (any_changed) {
        any_changed = false;
        for (auto it = workblock.end()->reverseIterator();
             it != workblock.begin()->reverseIterator();) {
          // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
          bool changed;
          std::tie(it, changed) = scanNode(*it);
          any_changed |= changed;
        }
      }
    }

    // Construct Subgraphs Recursively
    for (Node* n : block_->nodes()) {
      for (auto subBlock : n->blocks()) {
        SubgraphSlicer(
            subBlock, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .buildupSubgraphs();
      }
    }
  }

 private:
  void unfuseAliasedOutputs(Block* b) {
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      // we walk in the reverse order, so we can skip
      // nodes that might get unfused after the current
      // prim::DifferentiableGraph
      for (auto n : b->nodes().reverse()) {
        if (n->kind() == prim::DifferentiableGraph) {
          // aliased outputs in DifferentiableGraphs must be unfused
          // since autodiff doesn't know how to handle them correctly
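          //
          // Illustrative sketch (hypothetical IR, not from an actual trace):
          // a subgraph that ends with
          //   %v = aten::view(%y, %sizes)
          //   return (%y, %v)
          // exports two outputs that alias each other, and a subgraph that
          // returns one of its own inputs unchanged exports an output that
          // aliases an input; both cases are split back out of the
          // DifferentiableGraph below.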
          // N.B. we use |= since we don't want `unfuseAliasedOutputs`
          // to short-circuit
          any_changed |= SubgraphUtils::unmergeAliasedOutputs(n);
          any_changed |= SubgraphUtils::unmergeOutputsAlisingInputs(n);
          GRAPH_DEBUG(
              "any_changed on ",
              any_changed,
              " ",
              n->g(attr::Subgraph)->toString(false));
        }
      }
    }

    for (Node* n : b->nodes()) {
      for (Block* ib : n->blocks()) {
        unfuseAliasedOutputs(ib);
      }
    }
  }

  std::vector<WorkBlock> buildWorkBlocks() {
    // [workblocks]
    // The IR has many nodes which can never be reordered around, such as a
    // prim::Bailout. If a node N is surrounded by two nodes which cannot be
    // reordered, A and B, then a differentiable subgraph that is created from
    // N can only contain nodes from (A, B). The nodes from A to B represent
    // one work block for the subgraph slicer to work on. By creating these up
    // front, we avoid retraversing the whole graph block any time scanNode
    // returns, and we can also avoid attempting to create differentiable
    // subgraphs in work blocks that contain fewer than minSubgraphSize_
    // differentiable nodes.
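    //
    // Illustrative sketch (hypothetical IR): in a block such as
    //   %a = aten::mul(%x, %y)
    //   %b = aten::add(%a, %y)
    //    = prim::Print(%b)   <- side effectful, cannot be reordered around
    //   %c = aten::mul(%b, %b)
    //   %d = aten::add(%c, %x)
    // the side-effectful node splits the block into two work blocks, one
    // above it and one below it, and each work block is sliced for
    // differentiable subgraphs independently.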

    Node* end_bound_node = block_->return_node();
    Node* curr = end_bound_node->prev();

    std::vector<WorkBlock> worklist;
    size_t differentiable_nodes = 0;

    while (curr != block_->param_node()) {
      differentiable_nodes += shouldConsiderForMerge(curr);

      // cannot reorder around side effectful nodes
      if (curr->hasSideEffects()) {
        // only record a work block if it contains enough differentiable nodes
        // to be worth creating a differentiable subgraph for
        if (differentiable_nodes >= minSubgraphSize_) {
          worklist.emplace_back(curr, end_bound_node);
        }
        differentiable_nodes = 0;
        end_bound_node = curr;
      }
      curr = curr->prev();
    }

    if (differentiable_nodes >= minSubgraphSize_) {
      worklist.emplace_back(curr, end_bound_node);
    }

    return worklist;
  }

  // Inline this node's group subgraph into the outer graph if it's smaller
  // than the specified minimum size.
  //
  // Returns true if an inlining has occurred, false otherwise.
  bool inlineIfTooSmall(Node* n) {
    AT_ASSERT(n->kind() == prim::DifferentiableGraph);
    auto subgraph = SubgraphUtils::getSubgraph(n);
    size_t i = 0;
    for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
         ++it) {
      // only count nodes that will actually be executed
      i += !it->notExecutedOp();
      if (i >= minSubgraphSize_) {
        return false;
      }
    }

    SubgraphUtils::unmergeSubgraph(n);
    return true;
  }

  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == block_) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }

  bool isViewOp(Node* n) {
    switch (n->kind()) {
      case aten::view:
      case aten::view_as:
      case aten::reshape:
      case aten::reshape_as:
      case aten::transpose:
      case aten::expand:
      case aten::expand_as:
        return true;
    }
    return false;
  }

  bool shouldConsiderForMerge(Node* node) {
    // if we're already in the process of merging
    if (node->kind() == prim::DifferentiableGraph) {
      return true;
    }
    if (node->kind() == prim::Constant) {
      return false;
    }

    // view ops as outputs of differentiable subgraphs can cause incorrect
    // differentiation; for now, do not include them in the subgraph
    if (isViewOp(node)) {
      return false;
    }

    return isDifferentiable(node);
  }

  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
    if (shouldConsiderForMerge(consumer)) {
      if (consumer->kind() != prim::DifferentiableGraph) {
        consumer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
            consumer, prim::DifferentiableGraph, aliasDb_);
      }
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto input : inputs) {
        if (auto group = tryMerge(consumer, input->node())) {
          // we successfully merged, so the new group's `inputs` may have
          // changed; rescan the new group for more merging opportunities.
          return std::make_pair(group.value()->reverseIterator(), true);
        }
      }
    }

    return std::make_pair(++consumer->reverseIterator(), false);
  }

  // Try to merge `producer` into `consumer`. If successful, this destroys
  // `producer` and returns the `consumer` group.
  c10::optional<Node*> tryMerge(Node* consumer, Node* producer) {
    AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
    bool canMerge = shouldConsiderForMerge(producer) &&
        aliasDb_.moveBeforeTopologicallyValid(producer, consumer);

    if (!canMerge) {
      return c10::nullopt;
    }

    SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
        producer, consumer, aliasDb_);
    return consumer;
  }

  Block* block_;
  std::shared_ptr<Graph> graph_;
  size_t minSubgraphSize_;
  AliasDb& aliasDb_;
  std::vector<Node*>& diff_nodes_;
};

c10::optional<bool> getProfileNodeRequiresGrad(Node* n) {
  TORCH_INTERNAL_ASSERT(n->kind() == prim::profile);
  if (!n->hasAttribute(attr::profiled_type)) {
    return c10::nullopt;
  }
  auto& type = n->ty(attr::profiled_type);
  if (type->castRaw<TensorType>() == nullptr) {
    return c10::nullopt;
  }
  return type->expectRef<TensorType>().requiresGrad();
}

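// ContextMapping maps every node to the innermost context-manager node
// (prim::Enter) that encloses it, or to nullptr for nodes at the top level.
// It walks nested blocks and DifferentiableGraph subgraphs as well, so that
// profile nodes can later be matched against the context of the
// DifferentiableGraph that consumes them
// (see [Only consider profiles in the same context]).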
struct ContextMapping {
  std::vector<const Node*> ctx_stack_;
  std::unordered_map<const Node*, const Node*> node_to_ctx_;

  void processNode(Node* n) {
    node_to_ctx_[n] = ctx_stack_.back();

    if (n->kind() == prim::Enter) {
      ctx_stack_.push_back(n);
    } else if (n->kind() == prim::Exit) {
      ctx_stack_.pop_back();
    }
  }

  void processBlock(Block* block) {
    for (Node* n : block->nodes()) {
      processNode(n);
      for (Block* b : n->blocks()) {
        processBlock(b);
      }
      if (n->kind() == prim::DifferentiableGraph) {
        const auto& subgraph = n->g(attr::Subgraph);
        processBlock(subgraph->block());
      }
    }
  }

  ContextMapping(const std::shared_ptr<Graph>& graph) {
    ctx_stack_.push_back(nullptr);
    processBlock(graph->block());
  }

  const Node* get(const Node* n) const {
    auto it = node_to_ctx_.find(n);
    TORCH_INTERNAL_ASSERT(
        it != node_to_ctx_.end(),
        "Cannot find node in node-to-context mapping.");
    return it->second;
  }

  bool has(const Node* n) const {
    return node_to_ctx_.find(n) != node_to_ctx_.end();
  }
};

c10::optional<bool> findRequiresGradForOutput(
    Node* diff_graph,
    Value* output,
    const ContextMapping& ctx_mapping) {
  for (auto& use : output->uses()) {
    // [Only consider profiles in the same context]
    // Ignore profiled uses if the use is within a different context.
    // For example, a profile node within a no_grad() context will record the
    // wrong requires_grad information.
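    //
    // Illustrative sketch (hypothetical IR): if %out is used both directly
    // and inside a prim::Enter/prim::Exit region created by no_grad(),
    //   %p1 = prim::profile(%out)   <- same context as the diff graph
    //   ... prim::Enter(...)
    //   %p2 = prim::profile(%out)   <- different context, skipped below
    //   ... prim::Exit(...)
    // only %p1 is consulted; %p2 is rejected by the context check.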
    if (ctx_mapping.has(use.user) &&
        ctx_mapping.get(use.user) != ctx_mapping.get(diff_graph)) {
      continue;
    }

    if (use.user->kind() == prim::profile) {
      c10::optional<bool> req_grad_use;
      if ((req_grad_use = getProfileNodeRequiresGrad(use.user)).has_value()) {
        return req_grad_use.value();
      }
    }

    // maybe the profile node got absorbed into a differentiable graph
    if (use.user->kind() == prim::DifferentiableGraph) {
      const auto& dg = use.user->g(attr::Subgraph);
      // check all the uses of this graph input to look for profile nodes.
      Value* dg_value = dg->inputs()[use.offset];
      for (auto& dg_use : dg_value->uses()) {
        // See [Only consider profiles in the same context]
        if (ctx_mapping.has(dg_use.user) &&
            ctx_mapping.get(dg_use.user) != ctx_mapping.get(diff_graph)) {
          continue;
        }

        if (dg_use.user->kind() == prim::profile) {
          c10::optional<bool> req_grad_use;
          if ((req_grad_use = getProfileNodeRequiresGrad(dg_use.user))
                  .has_value()) {
            return req_grad_use.value();
          }
        }
      }
    }
  }

  return c10::nullopt;
}

void AddRequiresGradToDifferentiableGraph(
    Node* diff_graph,
    const ContextMapping& ctx_mapping) {
  TORCH_INTERNAL_ASSERT(diff_graph->kind() == prim::DifferentiableGraph);
  const auto& subgraph = diff_graph->g(attr::Subgraph);
  for (auto i : c10::irange(subgraph->outputs().size())) {
    Value* output = subgraph->outputs()[i];
    if (output->node()->kind() == prim::profile) {
      // already have requires_grad info from this profile node
      continue;
    }
    if (output->type()->castRaw<TensorType>() == nullptr) {
      // non-tensors don't get profiled.
      continue;
    }
    if (output->type()->expectRef<TensorType>().requiresGrad().has_value()) {
      continue;
    }

    // this node doesn't have any requires_grad info.
    // look at its uses to try to find a profile node.
    auto requires_grad = findRequiresGradForOutput(
        diff_graph, diff_graph->output(i), ctx_mapping);

    output->setType(output->type()->expectRef<TensorType>().withRequiresGrad(
        requires_grad));
  }
}

void AddRequiresGradOnOutputNodes(
    Block* block,
    const ContextMapping& ctx_mapping) {
  for (Node* n : block->nodes()) {
    if (n->kind() == prim::DifferentiableGraph) {
      AddRequiresGradToDifferentiableGraph(n, ctx_mapping);
    }
    for (Block* b : n->blocks()) {
      AddRequiresGradOnOutputNodes(b, ctx_mapping);
    }
  }
}

// autodiff.cpp needs to know, for each output, whether or not it requires
// grad. Sometimes a profile node will be present on the output, and sometimes
// it won't be. This might happen if there's a node with side effects in
// between the definition of the output node and the profile node; in this
// case the profile node and output node would be in different workblocks and
// couldn't be merged into the same DifferentiableGraph. (see [workblocks])
// Or it could happen if the output is profiled twice and the profile nodes
// get removed by unfuseAliasedOutputs.
void AddRequiresGradOnOutputNodes(const std::shared_ptr<Graph>& graph) {
  ContextMapping ctx_mapping(graph);
  AddRequiresGradOnOutputNodes(graph->block(), ctx_mapping);
}
} // anonymous namespace

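// A minimal usage sketch (illustrative only; the real call sites live in the
// graph executors):
//
//   std::shared_ptr<Graph> graph = ...;  // graph carrying prim::profile info
//   std::vector<Node*> diff_nodes =
//       CreateAutodiffSubgraphs(graph, /*threshold=*/2);
//   // each returned node is a prim::DifferentiableGraph whose subgraph can
//   // be handed to differentiation in autodiff.cpp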
std::vector<Node*> CreateAutodiffSubgraphs(
    const std::shared_ptr<Graph>& graph,
    size_t threshold) {
  std::vector<Node*> diff_nodes;
  AliasDb db(graph);
  GRAPH_DEBUG("Before creating autodiff subgraphs", *graph);
  SubgraphSlicer(graph->block(), graph, threshold, db, diff_nodes).run();
  GRAPH_DEBUG("After creating autodiff subgraphs", *graph);
  AddRequiresGradOnOutputNodes(graph);
  GRAPH_DEBUG("diff_nodes.size() ", diff_nodes.size());
  return diff_nodes;
}
} // namespace jit
} // namespace torch