#include <torch/csrc/jit/passes/specialize_autogradzero.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/profiling_record.h>

#include <ATen/core/symbol.h>
#include <c10/util/irange.h>

namespace torch {
namespace jit {

static const auto countsAttribute = Symbol::attr("none_counts");

static bool hasGradSumToSizeUses(Value* v) {
  return std::any_of(v->uses().begin(), v->uses().end(), [](const Use& use) {
    return use.user->kind() == aten::_grad_sum_to_size;
  });
}

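// A sketch of the instrumentation performed below (value names are
// illustrative, not from a real graph): for an Optional input that feeds
// aten::_grad_sum_to_size, such as
//   %size : int[]? = ...
//   %grad : Tensor = aten::_grad_sum_to_size(%dx, %size)
// we insert a profiling node right after the definition:
//   %size.profiled : int[]? = prim::profile_ivalue[none_counts=...](%size)
//   %grad : Tensor = aten::_grad_sum_to_size(%dx, %size.profiled)
// where the `none_counts` attribute accumulates how often the value was
// None ("num_none") versus present ("num_present") at runtime.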
static void insertProfileNodesForSpecializeAutogradZero(
    Block* block,
    ProfilingRecord* pr) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    auto n = *it;
    for (const auto offset : c10::irange(n->inputs().size())) {
      auto i = n->input(offset);
      if (i->type()->cast<OptionalType>() && hasGradSumToSizeUses(i)) {
        // Here we profile the definition instead of the use,
        // because we only optimize in the case of a None value, which is
        // immutable.
        auto opt_pn = pr->createProfileIValueNode(i);

        c10::Dict<std::string, int64_t> noneCountsDict;
        noneCountsDict.insert("num_none", 0);
        noneCountsDict.insert("num_present", 0);
        IValue init_val(noneCountsDict);

        opt_pn->ival_(countsAttribute, init_val);

        std::function<void(Stack&)> optional_profiler = [pr,
                                                         opt_pn](Stack& stack) {
          std::lock_guard<std::mutex> lock(pr->mutex_);

          TORCH_INTERNAL_ASSERT(opt_pn->hasAttribute(countsAttribute));
          // frame_id is unused
          int64_t frame_id = 0;
          pop(stack, frame_id);

          const auto& counts_attr = opt_pn->ival(countsAttribute);
          auto noneCounts = c10::impl::toTypedDict<std::string, int64_t>(
              counts_attr.toGenericDict());
          IValue value;
          pop(stack, value);
          if (value.isNone()) {
            noneCounts.insert_or_assign(
                "num_none", noneCounts.at("num_none") + 1);
          } else {
            noneCounts.insert_or_assign(
                "num_present", noneCounts.at("num_present") + 1);
          }
          push(stack, value);
        };
        opt_pn->setCallback(optional_profiler);
        opt_pn->insertAfter(i->node());
        i->replaceAllUsesAfterNodeWith(opt_pn, opt_pn->output());
      }
    }

    for (auto ib : n->blocks()) {
      insertProfileNodesForSpecializeAutogradZero(ib, pr);
    }
  }
}

void InsertProfileNodesForSpecializeAutogradZero(ProfilingRecord* pr) {
  insertProfileNodesForSpecializeAutogradZero(pr->profiled_graph_->block(), pr);
}

struct AutogradZeroSpecializer {
  enum class State { Nonzero, Zero, Unknown };

  explicit AutogradZeroSpecializer(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void run() {
    if (!isBackwardGraph()) {
      return;
    }
    if (getExecutorMode()) {
      if (auto versioning_if = guardSpecializations()) {
        specializeAutogradOps(versioning_if->blocks()[0]);
        GRAPH_DUMP("After versioning graph", graph_);
      }
    } else {
      setStatesOnGraphInputs();
      specializeAutogradOps(graph_->block());
    }
    GRAPH_DUMP("After specializeAutogradOps graph", graph_);
  }

 private:
  bool isBackwardGraph() {
    return std::any_of(
        graph_->nodes().begin(), graph_->nodes().end(), [](Node* n) {
          switch (n->kind()) {
            case prim::AutogradAnyNonZero:
            case prim::AutogradAdd:
            case aten::_grad_sum_to_size:
              return true;
            default:
              return false;
          }
        });
  }

  void replaceBlockInputsWithGraphInputs(Block* b) {
    TORCH_INTERNAL_ASSERT(graph_->inputs().size() == b->inputs().size());
    size_t num_inputs = graph_->inputs().size();
    for (const auto i : c10::irange(num_inputs)) {
      b->inputs().at(i)->replaceAllUsesWith(graph_->inputs().at(i));
    }
    for (const auto i : c10::irange(num_inputs)) {
      b->eraseInput(num_inputs - (1 + i));
    }
  }

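  // Seeds state_ for each graph input; a summary of the rules below:
  //   TensorType with undefined() == true   -> Zero
  //   TensorType with undefined() == false  -> Nonzero
  //   TensorType with no undefinedness info -> Unknown
  //   other Tensor / List[Tensor] subtypes  -> Nonzero
  //   anything else                         -> Unknown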
  void setStatesOnGraphInputs() {
    for (Value* input : graph_->inputs()) {
      const auto& tp = input->type();
      if (auto tt = tp->cast<TensorType>()) {
        if (tt->undefined()) {
          if (*tt->undefined()) {
            state_[input] = State::Zero;
          } else {
            state_[input] = State::Nonzero;
          }
        } else {
          state_[input] = State::Unknown;
        }
      } else if (
          tp->isSubtypeOf(*TensorType::get()) ||
          tp->isSubtypeOf(*ListType::ofTensors())) {
        state_[input] = State::Nonzero;
      } else {
        state_[input] = State::Unknown;
      }
    }
  }

  static void getUsesWithAttribute_(
      Value* inp,
      Symbol attr,
      std::vector<Node*>& uses) {
    for (auto use : inp->uses()) {
      if (use.user->kind() != prim::profile_ivalue) {
        continue;
      }

      if (use.user->hasAttribute(attr)) {
        uses.push_back(use.user);
      }

      getUsesWithAttribute_(use.user->output(), attr, uses);
    }
  }

  // Other passes may want to profile this exact same value, producing
  // chains of `prim::profile_ivalue` nodes. This helper walks those chains
  // to locate the nodes inserted by/for `specializeAutogradZero`
  // (identified by `attr`).
  static std::vector<Node*> getUsesWithAttribute(Value* inp, Symbol attr) {
    std::vector<Node*> uses;
    getUsesWithAttribute_(inp, attr, uses);
    return uses;
  }

  static Node* getUse(Value* inp, Symbol kind) {
    for (auto use : inp->uses()) {
      if (use.user->kind() == kind) {
        return use.user;
      }
    }

    return nullptr;
  }

  void removeProfiledOptionalUses(const std::vector<Node*>& uses) {
    TORCH_INTERNAL_ASSERT(!uses.empty());
    auto inp = uses[0]->input();
    // This removes `prim::profile_ivalue` from the original and
    // to-specialize blocks. N.B.: the false block is not affected, as it
    // has already been encapsulated in a fallback function.
    for (auto u : uses) {
      u->output()->replaceAllUsesWith(inp);
    }
  }

  Node* guardSpecializations() {
    auto versioning_if = graph_->create(prim::If, {}, graph_->outputs().size());
    auto value_map = [](Value* v) { return v; };
    auto true_block = versioning_if->addBlock();
    auto false_block = versioning_if->addBlock();

    // we will optimize true_block
    true_block->cloneFrom(graph_->block(), value_map);
    replaceBlockInputsWithGraphInputs(true_block);
    false_block->cloneFrom(graph_->block(), value_map);
    replaceBlockInputsWithGraphInputs(false_block);
    replaceBlockWithFallbackGraph(false_block, graph_->inputs());

    WithInsertPoint wip{graph_->block()->param_node()->next()};
    Value* none_val = graph_->insertConstant(IValue());
    std::vector<Value*> checks;
    std::vector<Value*> zero_values;
    std::vector<Value*> nonzero_values;

    for (auto inp : graph_->inputs()) {
      std::vector<Node*> iprofile_counts_nodes =
          getUsesWithAttribute(inp, countsAttribute);
      if (!iprofile_counts_nodes.empty()) {
        // the original `prim::profile_ivalue[num_present=0,...]` on `inp` is
        // copied into `true_block` and `false_block`.
        auto profile_ivalue_node = iprofile_counts_nodes[0];
        TORCH_INTERNAL_ASSERT(
            profile_ivalue_node->hasAttribute(countsAttribute));
        const auto& counts_attr =
            profile_ivalue_node->ival(countsAttribute).toGenericDict();
        auto num_present = counts_attr.at(IValue{"num_present"}).toInt();
        auto num_none = counts_attr.at(IValue{"num_none"}).toInt();
        if (num_present == 0 && num_none != 0) {
          auto check = graph_->insert(aten::__is__, {inp, none_val})->node();
          checks.push_back(check->output());
          profiled_none_.insert(inp);
        }
        removeProfiledOptionalUses(iprofile_counts_nodes);
        continue;
      }

      if (inp->uses().empty() || !inp->type()->cast<TensorType>()) {
        continue;
      }

      // TODO: check multiple uses ?
      auto pout = getUse(inp, prim::profile);
      if (!pout) {
        continue;
      }

      auto pttp = pout->ty(attr::profiled_type)->expect<TensorType>();
      if (!pttp->undefined().has_value()) {
        continue;
      }

      state_[inp] = *pttp->undefined() ? State::Zero : State::Nonzero;

      if (*pttp->undefined()) {
        zero_values.push_back(inp);
      } else {
        nonzero_values.push_back(inp);
      }
    }
    GRAPH_DUMP("After input scan", graph_);
    // unable to specialize any of the inputs
    if (nonzero_values.empty() && zero_values.empty()) {
      GRAPH_DUMP("Unable to add any specialization guards", graph_);
      versioning_if->destroy();
      // the checks we inserted will be cleaned up
      // by any subsequent DCE pass
      return nullptr;
    }

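    // The guard condition assembled below looks like this (a sketch; value
    // names are illustrative):
    //   %all_nonzero : bool = prim::AutogradAllNonZero(%n1, %n2, ...)
    //   %all_zero : bool = prim::AutogradAllZero(%z1, ...)
    //   %checks : bool[] = prim::ListConstruct(..., %all_nonzero, %all_zero)
    //   %ok : bool = aten::all(%checks)
    // and %ok becomes the condition of the versioning prim::If.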
    Node* nonzero_check = graph_->insert(prim::AutogradAllNonZero, {})->node();
    for (Value* v : nonzero_values) {
      nonzero_check->addInput(v);
    }
    checks.push_back(nonzero_check->output());

    Node* zero_check = graph_->insert(prim::AutogradAllZero, {})->node();
    for (Value* v : zero_values) {
      zero_check->addInput(v);
    }
    checks.push_back(zero_check->output());

    Value* bool_list =
        graph_->insertNode(graph_->createList(BoolType::get(), checks))
            ->output();
    Value* conjunction = graph_->insert(aten::all, {bool_list});

    versioning_if->addInput(conjunction);
    graph_->insertNode(versioning_if);

    auto ret = graph_->return_node();
    for (const auto i : c10::irange(ret->inputs().size())) {
      auto ogo = ret->input(i);
      auto ngo = versioning_if->output(i);
      ngo->copyMetadata(ogo);
      ret->replaceInput(i, ngo);
    }

    // We've created:
    //   successful_checks = Guards(...)
    //   if (successful_checks)
    //     -> optimized graph
    //   else:
    //     -> fallback graph
    //   original graph
    //
    // Remove the dead original graph
    for (auto it = graph_->block()->nodes().reverse().begin();
         *it != versioning_if;) {
      Node* n = *it;
      it++;
      n->destroy();
    }

    GRAPH_DUMP("After guardSpecializations", graph_);
    return versioning_if;
  }

  void specializeAutogradOps(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
      auto n = *it;
      switch (n->kind()) {
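        // Rewrites performed for AutogradAdd below (a sketch; value names
        // are illustrative):
        //   %c = prim::AutogradAdd(%a, %b) with %a known Zero
        //     -> uses of %c replaced by %b
        //   %c = prim::AutogradAdd(%a, %b) with %a, %b known Nonzero
        //     -> %c = aten::add(%a, %b, 1)
        // Mixed/unknown cases keep the guarded AutogradAdd.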
        case prim::AutogradAdd: {
          auto a = n->input(0);
          auto b = n->input(1);
          // if one input is AutogradZero, we can just drop the add
          if (state_[a] == State::Zero) {
            // Zero + b == b
            n->output()->replaceAllUsesWith(b);
            it.destroyCurrent();
          } else if (state_[b] == State::Zero) {
            // a + Zero == a
            n->output()->replaceAllUsesWith(a);
            it.destroyCurrent();
          } else if (
              state_[a] == State::Nonzero && state_[b] == State::Nonzero) {
            // when both are Nonzero, we can use a normal, optimizable add
            // instruction
            WithInsertPoint guard(n);
            auto* cOne = graph_->insertConstant(1);
            auto* add_node = graph_->insertNode(graph_->create(aten::add, 1));
            add_node->addInput(a);
            add_node->addInput(b);
            add_node->addInput(cOne);
            auto* add_output = add_node->output();
            add_output->setType(n->output()->type());
            state_[add_output] = State::Nonzero;
            n->output()->replaceAllUsesWith(add_output);
            it.destroyCurrent();
          } else {
            // otherwise we have conditionally-Nonzero things, and we need
            // to actually run an AutogradAdd, which will guard for Zeros,
            // so we leave the op as is
            state_[n->output()] = State::Unknown;
          }
        } break;
        case prim::AutogradZero: {
          state_[n->output()] = State::Zero;
        } break;
        case prim::profile: {
          // this is a profile node on a tensor use;
          // if we decided to specialize this graph,
          // its input may have undefinedness info,
          // otherwise it should be Unknown
          if (!n->inputs().empty()) {
            state_[n->output()] = !state_.count(n->input())
                ? State::Unknown
                : state_[n->input()];
          }
          break;
        }
        case prim::BailOut: {
          if (auto ptt = n->output()->type()->expect<TensorType>()) {
            state_[n->output()] = ptt->undefined()
                ? *ptt->undefined() ? State::Zero : State::Nonzero
                : State::Unknown;
          }
        } break;
        // Lowered GradOf block
        case prim::If: {
          auto if_input = n->input(0)->node();
          if (if_input->kind() == prim::AutogradAnyNonZero) {
            auto all_zeros = std::all_of(
                if_input->inputs().begin(),
                if_input->inputs().end(),
                [&](Value* v) { return state_[v] == State::Zero; });

            auto all_nonzeros = std::all_of(
                if_input->inputs().begin(),
                if_input->inputs().end(),
                [&](Value* v) { return state_[v] == State::Nonzero; });
            // Property 1: if all the gradInputs to the GradOf are Zero,
            // then the gradOutputs are also zero and will be represented
            // as AutogradZero nodes
            if (all_zeros) {
              auto zero =
                  graph_->createAutogradZero()->insertAfter(n)->output();
              state_[zero] = State::Zero;
              for (auto o : n->outputs()) {
                o->replaceAllUsesWith(zero);
              }
              it.destroyCurrent();
              break;
            }

            specializeGradSumToSize(n->blocks().at(0));
            if (all_nonzeros) {
              auto body = n->blocks().at(0);
              // hoist the nodes in the GradOf body to be before the linear
              // block
              for (auto it = body->nodes().begin();
                   it != body->nodes().end();) {
                auto block_node = *it++;
                block_node->moveBefore(n);
              }

              for (size_t i = 0; i < n->outputs().size(); ++i) {
                n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
                state_[body->outputs().at(i)] = State::Nonzero;
              }
              it.destroyCurrent();
              break;
            }
          }

          for (auto o : n->outputs()) {
            state_[o] = State::Unknown;
          }
          break;
        }
        default:
          for (auto o : n->outputs()) {
            state_[o] = State::Unknown;
          }
          break;
      }
    }
  }

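  // Rewrite sketch for _grad_sum_to_size (value names illustrative): when
  // %size is statically None, or profiling only ever observed None,
  //   %y : Tensor = aten::_grad_sum_to_size(%x, %size)
  // is replaced by %x, since summing to an absent size is the identity.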
  void specializeGradSumToSize(Block* b) {
    for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
      Node* n = *it;
      if (n->kind() == aten::_grad_sum_to_size) {
        bool profiled_none_flag = profiled_none_.count(n->input(1));
        const Node* node = n->input(1)->node();
        // propagate the profiled-None flag through chains of
        // profile_ivalue nodes
        while (!profiled_none_flag && node->kind() == prim::profile_ivalue) {
          profiled_none_flag =
              profiled_none_flag || profiled_none_.count(node->input(0));
          node = node->input(0)->node();
        }
        if (n->input(1)->mustBeNone() || profiled_none_flag) {
          n->output()->replaceAllUsesWith(n->input(0));
          it.destroyCurrent();
        }
      }
    }
  }

  std::shared_ptr<Graph> graph_;
  std::unordered_set<Value*> profiled_none_;
  std::unordered_map<Value*, State> state_;
};

// Propagate autograd-zero information through a gradient graph and
// remove GradOf blocks if present.
// Note: this is a very limited pass. It only propagates autograd zeros for
// operations generated by the symbolic autodiff code and cleans up
// AutogradAdds when possible. Outputs of other nodes are conservatively
// marked Unknown and not optimized.
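// Usage sketch (illustrative, not part of this file): the pass is
// typically run on the backward graph produced by the autodiff pass, e.g.
//   Gradient grad_spec = differentiate(forward_graph);
//   specializeAutogradZero(grad_spec.df);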
void specializeAutogradZero(std::shared_ptr<Graph> g) {
  AutogradZeroSpecializer azs(std::move(g));
  azs.run();
}

} // namespace jit
} // namespace torch