1 | #include <torch/csrc/jit/passes/create_autodiff_subgraphs.h> |
2 | |
3 | #include <c10/util/Exception.h> |
4 | #include <torch/csrc/jit/ir/alias_analysis.h> |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <torch/csrc/jit/jit_log.h> |
7 | #include <torch/csrc/jit/passes/canonicalize.h> |
8 | #include <torch/csrc/jit/passes/common_subexpression_elimination.h> |
9 | #include <torch/csrc/jit/passes/remove_redundant_profiles.h> |
10 | #include <torch/csrc/jit/passes/utils/subgraph_utils.h> |
11 | #include <torch/csrc/jit/runtime/autodiff.h> |
12 | |
13 | namespace torch { |
14 | namespace jit { |
15 | |
16 | namespace { |
17 | |
18 | struct WorkBlock : public std::pair<Node*, Node*> { |
19 | using pair::pair; |
20 | |
21 | Node* begin() { |
22 | return this->first; |
23 | } |
24 | Node* end() { |
25 | return this->second; |
26 | } |
27 | }; |
28 | |
// Fuses runs of differentiable nodes in a block (and, recursively, its
// sub-blocks) into prim::DifferentiableGraph subgraph nodes. Alias-db
// correctness is maintained in place while merging; subgraphs that end up
// too small are inlined back in a separate cleanup phase.
class SubgraphSlicer {
 public:
  // block           - the block to slice; sub-blocks are handled recursively
  // graph           - the graph that owns `block`
  // minSubgraphSize - minimum number of executed ops a subgraph must contain
  //                   to survive cleanup
  // aliasDb         - alias analysis for `graph`, updated as nodes move
  // diff_nodes      - out-param collecting the surviving
  //                   prim::DifferentiableGraph nodes
  SubgraphSlicer(
      Block* block,
      std::shared_ptr<Graph> graph,
      size_t minSubgraphSize,
      AliasDb& aliasDb,
      std::vector<Node*>& diff_nodes)
      : block_(block),
        graph_(std::move(graph)),
        minSubgraphSize_(minSubgraphSize),
        aliasDb_(aliasDb),
        diff_nodes_(diff_nodes) {}

  // Entry point: build subgraphs, unfuse aliased outputs, then re-inline
  // subgraphs that are too small.
  void run() {
    // We maintain alias db correctness in-place while building up the autodiff
    // subgraphs, however it is difficult to preserve correctness when
    // un-inlining autodiff subgraphs. We first recursively construct all
    // subgraphs and then recursively cleanup & unmerge the small subgraphs
    buildupSubgraphs();
    GRAPH_DUMP("before unfuseAliasedOutputs", graph_);
    unfuseAliasedOutputs(block_);
    cleanupSubgraphs();
    // Run CSE globally once to eliminate duplicates that may have occurred
    // while inlining subgraphs.
    EliminateCommonSubexpression(graph_);
  }

  // Walks block_ in reverse, re-inlining DifferentiableGraph nodes that are
  // too small and recording the survivors into diff_nodes_; then recurses
  // into sub-blocks.
  void cleanupSubgraphs() {
    auto curNode = *block_->nodes().rbegin();
    while (curNode != *block_->nodes().rend()) {
      // Save the previous node, since we might delete `curNode` in next block
      auto prevNode = curNode->prev();
      if (curNode->kind() == prim::DifferentiableGraph) {
        // Inlining nodes may cause some subexpression to come back in the
        // subgraphs (for example, copying constants in repeatedly will generate
        // redundant prim::Constants). Run CSE to clean them up.
        EliminateCommonSubexpression(curNode->g(attr::Subgraph));

        if (!inlineIfTooSmall(curNode)) {
          diff_nodes_.push_back(curNode);
        }
      }
      curNode = prevNode;
    }

    for (Node* n : block_->nodes()) {
      for (Block* b : n->blocks()) {
        SubgraphSlicer(b, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .cleanupSubgraphs();
      }
    }
  }

  void buildupSubgraphs() {
    // We need to run the slicer multiple times in order to get all merge
    // opportunities. This is because moveBeforeTopologicalValid may reorder
    // nodes to be AFTER the current iteration point. In order to properly
    // consider those nodes for merging, we need run the pass until no changes
    // have been made.
    //
    // Example:
    //   c = f(a, b)
    //   d = f(c)
    //   e = f(d)  <- iter is here, moving upward
    // After c.moveBeforeTopologicallyValid(e), we have:
    //   c = f(a, b)
    //   e = f(d)  <- iter still here
    //   d = f(c)  <- this was node moved on the other side.

    // see [workblocks]
    auto workblocks = buildWorkBlocks();
    for (auto& workblock : workblocks) {
      bool any_changed = true;
      while (any_changed) {
        any_changed = false;
        for (auto it = workblock.end()->reverseIterator();
             it != workblock.begin()->reverseIterator();) {
          // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
          bool changed;
          std::tie(it, changed) = scanNode(*it);
          any_changed |= changed;
        }
      }
    }

    // Construct Subgraphs Recursively
    for (Node* n : block_->nodes()) {
      for (auto subBlock : n->blocks()) {
        SubgraphSlicer(
            subBlock, graph_, minSubgraphSize_, aliasDb_, diff_nodes_)
            .buildupSubgraphs();
      }
    }
  }

 private:
  // Autodiff cannot correctly handle DifferentiableGraph outputs that alias
  // each other or alias the subgraph's inputs, so such outputs are repeatedly
  // unmerged until a fixed point is reached; then recurses into sub-blocks.
  void unfuseAliasedOutputs(Block* b) {
    bool any_changed = true;
    while (any_changed) {
      any_changed = false;
      // we walk in the reverse order, so we can skip
      // nodes that might get unfused after the current
      // prim::DifferentiableGraph
      for (auto n : b->nodes().reverse()) {
        if (n->kind() == prim::DifferentiableGraph) {
          // aliased outputs in DifferentiableGraphs must be unfused
          // since autodiff doesn't know how to handle them correctly
          // N.B. Note, |= since we don't want `unfuseAliasedOutputs`
          // to short-circuit
          any_changed |= SubgraphUtils::unmergeAliasedOutputs(n);
          any_changed |= SubgraphUtils::unmergeOutputsAlisingInputs(n);
          GRAPH_DEBUG(
              "any_changed on ",
              any_changed,
              " ",
              n->g(attr::Subgraph)->toString(false));
        }
      }
    }

    for (Node* n : b->nodes()) {
      for (Block* ib : n->blocks()) {
        unfuseAliasedOutputs(ib);
      }
    }
  }

  std::vector<WorkBlock> buildWorkBlocks() {
    // [workblocks]
    // the IR has many nodes which can never be reordered around, such as a
    // prim::Bailout. if a node N is surrounded by two nodes which cannot be
    // reordered, A and B, then a differentiable subgraph that is created from N
    // can only contain nodes from (A, B) The nodes from A to B represent one
    // work block for the subgraph slicer to work on. By creating these up
    // front, we avoid retraversing the whole graph block any time scanNode
    // returns, and we can also avoid attempting to create differentiable
    // subgraphs in work blocks that do not contain a # of differentiable nodes
    // >= minSubgraphSize_

    Node* end_bound_node = block_->return_node();
    Node* curr = end_bound_node->prev();

    std::vector<WorkBlock> worklist;
    size_t differentiable_nodes = 0;

    while (curr != block_->param_node()) {
      differentiable_nodes += shouldConsiderForMerge(curr);

      // cannot reorder around side effectful nodes
      if (curr->hasSideEffects()) {
        // only record a work block if it contains enough differentiable
        // nodes to possibly reach minSubgraphSize_
        if (differentiable_nodes >= minSubgraphSize_) {
          worklist.emplace_back(curr, end_bound_node);
        }
        differentiable_nodes = 0;
        end_bound_node = curr;
      }
      curr = curr->prev();
    }

    // handle the trailing (top-most) region between param_node and the last
    // side-effectful node seen
    if (differentiable_nodes >= minSubgraphSize_) {
      worklist.emplace_back(curr, end_bound_node);
    }

    return worklist;
  }

  // Inline this node's group subgraph into the outer graph if it's smaller
  // than the specified minimum size.
  //
  // Returns true if an inlining has occurred, false otherwise.
  bool inlineIfTooSmall(Node* n) {
    AT_ASSERT(n->kind() == prim::DifferentiableGraph);
    auto subgraph = SubgraphUtils::getSubgraph(n);
    size_t i = 0;
    for (auto it = subgraph->nodes().begin(); it != subgraph->nodes().end();
         ++it) {
      // only count ops that will actually execute
      i += !it->notExecutedOp();
      if (i >= minSubgraphSize_) {
        return false;
      }
    }

    SubgraphUtils::unmergeSubgraph(n);
    return true;
  }

  // Returns the subset of `inputs` that are produced inside block_, ordered
  // so that later-defined values come first.
  value_list sortReverseTopological(ArrayRef<Value*> inputs) {
    value_list result;
    for (auto i : inputs) {
      if (i->node()->owningBlock() == block_) {
        result.push_back(i);
      }
    }
    // Sort in reverse topological order
    std::sort(result.begin(), result.end(), [&](Value* a, Value* b) {
      return a->node()->isAfter(b->node());
    });
    return result;
  }

  // True for ops that return a view of their input tensor.
  bool isViewOp(Node* n) {
    switch (n->kind()) {
      case aten::view:
      case aten::view_as:
      case aten::reshape:
      case aten::reshape_as:
      case aten::transpose:
      case aten::expand:
      case aten::expand_as:
        return true;
    }
    return false;
  }

  bool shouldConsiderForMerge(Node* node) {
    // if we're already in the process of merging
    if (node->kind() == prim::DifferentiableGraph) {
      return true;
    }
    if (node->kind() == prim::Constant) {
      return false;
    }

    // view ops as outputs of differentiable subgraphs can cause incorrect
    // differentiation; for now, do not include them in the subgraph
    if (isViewOp(node)) {
      return false;
    }

    return isDifferentiable(node);
  }

  // Attempts to turn `consumer` into (or grow) a DifferentiableGraph by
  // merging in the producers of its inputs. Returns the next reverse-iterator
  // position and whether anything changed.
  std::pair<graph_node_list::iterator, bool> scanNode(Node* consumer) {
    if (shouldConsiderForMerge(consumer)) {
      if (consumer->kind() != prim::DifferentiableGraph) {
        consumer = SubgraphUtils::createSingletonSubgraphAndUpdateAliasing(
            consumer, prim::DifferentiableGraph, aliasDb_);
      }
      auto inputs = sortReverseTopological(consumer->inputs());
      for (auto input : inputs) {
        if (auto group = tryMerge(consumer, input->node())) {
          // we successfully merged, so the new group's `inputs` may have
          // changed. So rescan the new group for more merging opportunities.
          return std::make_pair(group.value()->reverseIterator(), true);
        }
      }
    }

    return std::make_pair(++consumer->reverseIterator(), false);
  }

  // Try to merge `producer` into `consumer`. If successful, this destroys
  // `producer` and returns the `consumer` group.
  c10::optional<Node*> tryMerge(Node* consumer, Node* producer) {
    AT_ASSERT(consumer->kind() == prim::DifferentiableGraph);
    bool canMerge = shouldConsiderForMerge(producer) &&
        aliasDb_.moveBeforeTopologicallyValid(producer, consumer);

    if (!canMerge) {
      return c10::nullopt;
    }

    SubgraphUtils::mergeNodeIntoSubgraphAndUpdateAliasing(
        producer, consumer, aliasDb_);
    return consumer;
  }

  Block* block_;
  std::shared_ptr<Graph> graph_;
  size_t minSubgraphSize_;
  AliasDb& aliasDb_;
  std::vector<Node*>& diff_nodes_;
};
304 | |
305 | c10::optional<bool> getProfileNodeRequiresGrad(Node* n) { |
306 | TORCH_INTERNAL_ASSERT(n->kind() == prim::profile); |
307 | if (!n->hasAttribute(attr::profiled_type)) { |
308 | return c10::nullopt; |
309 | } |
310 | auto& type = n->ty(attr::profiled_type); |
311 | if (type->castRaw<TensorType>() == nullptr) { |
312 | return c10::nullopt; |
313 | } |
314 | return type->expectRef<TensorType>().requiresGrad(); |
315 | } |
316 | |
317 | struct ContextMapping { |
318 | std::vector<const Node*> ctx_stack_; |
319 | std::unordered_map<const Node*, const Node*> node_to_ctx_; |
320 | |
321 | void processNode(Node* n) { |
322 | node_to_ctx_[n] = ctx_stack_.back(); |
323 | |
324 | if (n->kind() == prim::Enter) { |
325 | ctx_stack_.push_back(n); |
326 | } else if (n->kind() == prim::Exit) { |
327 | ctx_stack_.pop_back(); |
328 | } |
329 | } |
330 | |
331 | void processBlock(Block* block) { |
332 | for (Node* n : block->nodes()) { |
333 | processNode(n); |
334 | for (Block* b : n->blocks()) { |
335 | processBlock(b); |
336 | } |
337 | if (n->kind() == prim::DifferentiableGraph) { |
338 | const auto& subgraph = n->g(attr::Subgraph); |
339 | processBlock(subgraph->block()); |
340 | } |
341 | } |
342 | } |
343 | |
344 | ContextMapping(const std::shared_ptr<Graph>& graph) { |
345 | ctx_stack_.push_back(nullptr); |
346 | processBlock(graph->block()); |
347 | } |
348 | |
349 | const Node* get(const Node* n) const { |
350 | auto it = node_to_ctx_.find(n); |
351 | TORCH_INTERNAL_ASSERT( |
352 | it != node_to_ctx_.end(), |
353 | "Cannot find node in node-to-context mapping." ); |
354 | return it->second; |
355 | } |
356 | |
357 | bool has(const Node* n) const { |
358 | return node_to_ctx_.find(n) != node_to_ctx_.end(); |
359 | } |
360 | }; |
361 | |
362 | c10::optional<bool> findRequiresGradForOutput( |
363 | Node* diff_graph, |
364 | Value* output, |
365 | const ContextMapping& ctx_mapping) { |
366 | for (auto& use : output->uses()) { |
367 | // [Only consider profiles in the same context] |
368 | // Ignore profiled uses if the use is within a different context. |
369 | // For example, a profile node within a no_grad() context will record the |
370 | // wrong requires_grad information. |
371 | if (ctx_mapping.has(use.user) && |
372 | ctx_mapping.get(use.user) != ctx_mapping.get(diff_graph)) { |
373 | continue; |
374 | } |
375 | |
376 | if (use.user->kind() == prim::profile) { |
377 | c10::optional<bool> req_grad_use; |
378 | if ((req_grad_use = getProfileNodeRequiresGrad(use.user)).has_value()) { |
379 | return req_grad_use.value(); |
380 | } |
381 | } |
382 | |
383 | // maybe the profile node got absorbed into a differentiable graph |
384 | if (use.user->kind() == prim::DifferentiableGraph) { |
385 | const auto& dg = use.user->g(attr::Subgraph); |
386 | // check all the uses of this graph input to look for profile nodes. |
387 | Value* dg_value = dg->inputs()[use.offset]; |
388 | for (auto& dg_use : dg_value->uses()) { |
389 | // See [Only consider profiles in the same context] |
390 | if (ctx_mapping.has(dg_use.user) && |
391 | ctx_mapping.get(dg_use.user) != ctx_mapping.get(diff_graph)) { |
392 | continue; |
393 | } |
394 | |
395 | if (dg_use.user->kind() == prim::profile) { |
396 | c10::optional<bool> req_grad_use; |
397 | if ((req_grad_use = getProfileNodeRequiresGrad(dg_use.user)) |
398 | .has_value()) { |
399 | return req_grad_use.value(); |
400 | } |
401 | } |
402 | } |
403 | } |
404 | } |
405 | |
406 | return c10::nullopt; |
407 | } |
408 | |
409 | void AddRequiresGradToDifferentiableGraph( |
410 | Node* diff_graph, |
411 | const ContextMapping& ctx_mapping) { |
412 | TORCH_INTERNAL_ASSERT(diff_graph->kind() == prim::DifferentiableGraph); |
413 | const auto& subgraph = diff_graph->g(attr::Subgraph); |
414 | for (auto i : c10::irange(subgraph->outputs().size())) { |
415 | Value* output = subgraph->outputs()[i]; |
416 | if (output->node()->kind() == prim::profile) { |
417 | // already have requires_grad info from this profile node |
418 | continue; |
419 | } |
420 | if (output->type()->castRaw<TensorType>() == nullptr) { |
421 | // non-tensors don't get profiled. |
422 | continue; |
423 | } |
424 | if (output->type()->expectRef<TensorType>().requiresGrad().has_value()) { |
425 | continue; |
426 | } |
427 | |
428 | // this node doesn't have any requires_grad info. |
429 | // look at its uses to try to find a profile node. |
430 | auto requires_grad = findRequiresGradForOutput( |
431 | diff_graph, diff_graph->output(i), ctx_mapping); |
432 | |
433 | output->setType(output->type()->expectRef<TensorType>().withRequiresGrad( |
434 | requires_grad)); |
435 | } |
436 | } |
437 | |
438 | void AddRequiresGradOnOutputNodes( |
439 | Block* block, |
440 | const ContextMapping& ctx_mapping) { |
441 | for (Node* n : block->nodes()) { |
442 | if (n->kind() == prim::DifferentiableGraph) { |
443 | AddRequiresGradToDifferentiableGraph(n, ctx_mapping); |
444 | } |
445 | for (Block* b : n->blocks()) { |
446 | AddRequiresGradOnOutputNodes(b, ctx_mapping); |
447 | } |
448 | } |
449 | } |
450 | |
451 | // autodiff.cpp needs to know, for each output, whether or not it requires |
452 | // grad. Sometimes a profile node will be present on the output, but sometimes |
453 | // it won't be present. This might happen if there's a node with side effects |
454 | // in between the definition of the output node and the profile node; in this |
455 | // case the profile node and output node would be in different workblocks and |
456 | // couldn't be merged into the same DifferentiableGraph. (see [workblocks]) |
// Or it could happen if the output is profiled twice and the profile nodes get
// removed by unfuseAliasedOutputs.
459 | void AddRequiresGradOnOutputNodes(const std::shared_ptr<Graph>& graph) { |
460 | ContextMapping ctx_mapping(graph); |
461 | AddRequiresGradOnOutputNodes(graph->block(), ctx_mapping); |
462 | } |
463 | } // anonymous namespace |
464 | |
465 | std::vector<Node*> CreateAutodiffSubgraphs( |
466 | const std::shared_ptr<Graph>& graph, |
467 | size_t threshold) { |
468 | std::vector<Node*> diff_nodes; |
469 | AliasDb db(graph); |
470 | GRAPH_DEBUG("Before creating autodiff subgraphs" , *graph); |
471 | SubgraphSlicer(graph->block(), graph, threshold, db, diff_nodes).run(); |
472 | GRAPH_DEBUG("After creating autodiff subgraphs" , *graph); |
473 | AddRequiresGradOnOutputNodes(graph); |
474 | GRAPH_DEBUG("diff_nodes.size() " , diff_nodes.size()); |
475 | return diff_nodes; |
476 | } |
477 | } // namespace jit |
478 | } // namespace torch |
479 | |