#include <torch/csrc/jit/passes/specialize_autogradzero.h>

#include <c10/util/Exception.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/clear_undefinedness.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/profiling_record.h>

#include <ATen/core/symbol.h>
#include <c10/util/irange.h>

namespace torch {
namespace jit {

static const auto countsAttribute = Symbol::attr("none_counts");

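// Returns true if `v` has at least one use as an input to
// `aten::_grad_sum_to_size`.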
bool hasGradSumToSizeUses(Value* v) {
  return std::any_of(v->uses().begin(), v->uses().end(), [](const Use& use) {
    return use.user->kind() == aten::_grad_sum_to_size;
  });
}

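// For every Optional value consumed by `aten::_grad_sum_to_size`, insert a
// `prim::profile_ivalue` node that counts at runtime how often the value is
// None vs. present. The counts live in the node's `none_counts` attribute and
// are consumed later by `guardSpecializations`.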
void insertProfileNodesForSpecializeAutogradZero(
    Block* block,
    ProfilingRecord* pr) {
  for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
    auto n = *it;
    for (const auto offset : c10::irange(n->inputs().size())) {
      auto i = n->input(offset);
      if (i->type()->cast<OptionalType>() && hasGradSumToSizeUses(i)) {
        // here we profile the definition instead of the use,
        // because we only optimize in the case of a None value,
        // which is immutable
        auto opt_pn = pr->createProfileIValueNode(i);

        c10::Dict<std::string, int64_t> noneCountsDict;
        noneCountsDict.insert("num_none", 0);
        noneCountsDict.insert("num_present", 0);
        IValue init_val(noneCountsDict);

        opt_pn->ival_(countsAttribute, init_val);

        std::function<void(Stack&)> optional_profiler = [pr,
                                                         opt_pn](Stack& stack) {
          std::lock_guard<std::mutex> lock(pr->mutex_);

          TORCH_INTERNAL_ASSERT(opt_pn->hasAttribute(countsAttribute));
          // frame_id is unused
          int64_t frame_id = 0;
          pop(stack, frame_id);

          const auto& counts_attr = opt_pn->ival(countsAttribute);
          auto noneCounts = c10::impl::toTypedDict<std::string, int64_t>(
              counts_attr.toGenericDict());
          IValue value;
          pop(stack, value);
          if (value.isNone()) {
            noneCounts.insert_or_assign(
                "num_none", noneCounts.at("num_none") + 1);
          } else {
            noneCounts.insert_or_assign(
                "num_present", noneCounts.at("num_present") + 1);
          }
          push(stack, value);
        };
        opt_pn->setCallback(optional_profiler);
        opt_pn->insertAfter(i->node());
        i->replaceAllUsesAfterNodeWith(opt_pn, opt_pn->output());
      }
    }

    for (auto ib : n->blocks()) {
      insertProfileNodesForSpecializeAutogradZero(ib, pr);
    }
  }
}

void InsertProfileNodesForSpecializeAutogradZero(ProfilingRecord* pr) {
  insertProfileNodesForSpecializeAutogradZero(pr->profiled_graph_->block(), pr);
}

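// Specializes a backward (gradient) graph based on which gradients are known
// to be undefined (autograd zero). With the profiling executor, the graph is
// cloned into a versioning `prim::If`: the true block is specialized under
// runtime checks, while the false block falls back to the unspecialized
// graph. Otherwise, specialization relies on undefinedness info already
// present on the graph inputs' types.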
struct AutogradZeroSpecializer {
  enum class State { Nonzero, Zero, Unknown };

  AutogradZeroSpecializer(std::shared_ptr<Graph> graph)
      : graph_(std::move(graph)) {}

  void run() {
    if (!isBackwardGraph()) {
      return;
    }
    if (getExecutorMode()) {
      if (auto versioning_if = guardSpecializations()) {
        specializeAutogradOps(versioning_if->blocks()[0]);
        GRAPH_DUMP("After versioning graph", graph_);
      }
    } else {
      setStatesOnGraphInputs();
      specializeAutogradOps(graph_->block());
    }
    GRAPH_DUMP("After specializeAutogradOps graph", graph_);
  }

 private:
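  // A graph is treated as a backward graph if it contains any of the node
  // kinds emitted by symbolic autodiff.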
  bool isBackwardGraph() {
    return std::any_of(
        graph_->nodes().begin(), graph_->nodes().end(), [](Node* n) {
          switch (n->kind()) {
            case prim::AutogradAnyNonZero:
            case prim::AutogradAdd:
            case aten::_grad_sum_to_size:
              return true;
            default:
              return false;
          }
        });
  }

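  // Redirects uses of a cloned block's inputs to the corresponding graph
  // inputs and erases the now-unused block inputs.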
  void replaceBlockInputsWithGraphInputs(Block* b) {
    TORCH_INTERNAL_ASSERT(graph_->inputs().size() == b->inputs().size());
    size_t num_inputs = graph_->inputs().size();
    for (const auto i : c10::irange(num_inputs)) {
      b->inputs().at(i)->replaceAllUsesWith(graph_->inputs().at(i));
    }
    for (const auto i : c10::irange(num_inputs)) {
      b->eraseInput(num_inputs - (1 + i));
    }
  }

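  // Seeds the Zero/Nonzero/Unknown state of each graph input from the
  // undefinedness info recorded on its TensorType (used when the profiling
  // executor is off).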
  void setStatesOnGraphInputs() {
    for (Value* input : graph_->inputs()) {
      const auto& tp = input->type();
      if (auto tt = tp->cast<TensorType>()) {
        if (tt->undefined()) {
          if (*tt->undefined()) {
            state_[input] = State::Zero;
          } else {
            state_[input] = State::Nonzero;
          }
        } else {
          state_[input] = State::Unknown;
        }
      } else if (
          tp->isSubtypeOf(*TensorType::get()) ||
          tp->isSubtypeOf(*ListType::ofTensors())) {
        state_[input] = State::Nonzero;
      } else {
        state_[input] = State::Unknown;
      }
    }
  }

  static void getUsesWithAttribute_(
      Value* inp,
      Symbol attr,
      std::vector<Node*>& uses) {
    for (auto use : inp->uses()) {
      if (use.user->kind() != prim::profile_ivalue) {
        continue;
      }

      if (use.user->hasAttribute(attr)) {
        uses.push_back(use.user);
      }

      getUsesWithAttribute_(use.user->output(), attr, uses);
    }
  }

  // other passes may profile this exact same value, so there can be a chain
  // of `prim::profile_ivalue` nodes on it. this helper walks that chain to
  // locate the node(s) inserted by/for `specializeAutogradZero`
  static std::vector<Node*> getUsesWithAttribute(Value* inp, Symbol attr) {
    std::vector<Node*> uses;
    getUsesWithAttribute_(inp, attr, uses);
    return uses;
  }

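  // Returns the first user of `inp` with the given kind, or nullptr if there
  // is none.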
  static Node* getUse(Value* inp, Symbol kind) {
    for (auto use : inp->uses()) {
      if (use.user->kind() == kind) {
        return use.user;
      }
    }

    return nullptr;
  }

  void removeProfiledOptionalUses(const std::vector<Node*>& uses) {
    TORCH_INTERNAL_ASSERT(!uses.empty());
    auto inp = uses[0]->input();
    // this removes `prim::profile_ivalue` from the original and to-specialize
    // blocks. N.B. the false block isn't impacted as it has already been
    // encapsulated in a fallback function
    for (auto u : uses) {
      u->output()->replaceAllUsesWith(inp);
    }
  }

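  // Clones the graph into the two blocks of a versioning `prim::If`, inserts
  // the runtime checks that gate the specialized block (None checks for
  // profiled optionals, AutogradAllNonZero / AutogradAllZero for profiled
  // tensors), and removes the now-dead original nodes. Returns nullptr if
  // none of the tensor inputs could be specialized.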
  Node* guardSpecializations() {
    auto versioning_if =
        graph_->create(prim::If, {}, graph_->outputs().size());
    auto value_map = [](Value* v) { return v; };
    auto true_block = versioning_if->addBlock();
    auto false_block = versioning_if->addBlock();

    // we will optimize true_block
    true_block->cloneFrom(graph_->block(), value_map);
    replaceBlockInputsWithGraphInputs(true_block);
    false_block->cloneFrom(graph_->block(), value_map);
    replaceBlockInputsWithGraphInputs(false_block);
    replaceBlockWithFallbackGraph(false_block, graph_->inputs());

    WithInsertPoint wip{graph_->block()->param_node()->next()};
    Value* none_val = graph_->insertConstant(IValue());
    std::vector<Value*> checks;
    std::vector<Value*> zero_values;
    std::vector<Value*> nonzero_values;

    for (auto inp : graph_->inputs()) {
      std::vector<Node*> iprofile_counts_nodes =
          getUsesWithAttribute(inp, countsAttribute);
      if (!iprofile_counts_nodes.empty()) {
        // the original `prim::profile_ivalue[num_present=0,...]` on `inp` is
        // copied into `true_block` and `false_block`.
        auto profile_ivalue_node = iprofile_counts_nodes[0];
        TORCH_INTERNAL_ASSERT(
            profile_ivalue_node->hasAttribute(countsAttribute));
        const auto& counts_attr =
            profile_ivalue_node->ival(countsAttribute).toGenericDict();
        auto num_present = counts_attr.at(IValue{"num_present"}).toInt();
        auto num_none = counts_attr.at(IValue{"num_none"}).toInt();
        if (num_present == 0 && num_none != 0) {
          auto check = graph_->insert(aten::__is__, {inp, none_val})->node();
          checks.push_back(check->output());
          profiled_none_.insert(inp);
        }
        removeProfiledOptionalUses(iprofile_counts_nodes);
        continue;
      }

      if (inp->uses().empty() || !inp->type()->cast<TensorType>()) {
        continue;
      }

      // TODO: check multiple uses ?
      auto pout = getUse(inp, prim::profile);
      if (!pout) {
        continue;
      }

      auto pttp = pout->ty(attr::profiled_type)->expect<TensorType>();
      if (!pttp->undefined().has_value()) {
        continue;
      }

      state_[inp] = *pttp->undefined() ? State::Zero : State::Nonzero;

      if (*pttp->undefined()) {
        zero_values.push_back(inp);
      } else {
        nonzero_values.push_back(inp);
      }
    }
    GRAPH_DUMP("After for loop", graph_);
    // unable to specialize any of the inputs
    if (nonzero_values.empty() && zero_values.empty()) {
      GRAPH_DUMP("Unable to add any specialization guards", graph_);
      versioning_if->destroy();
      // the checks we inserted will be cleaned up
      // by any subsequent DCE pass
      return nullptr;
    }

    Node* nonzero_check = graph_->insert(prim::AutogradAllNonZero, {})->node();
    for (Value* v : nonzero_values) {
      nonzero_check->addInput(v);
    }
    checks.push_back(nonzero_check->output());

    Node* zero_check = graph_->insert(prim::AutogradAllZero, {})->node();
    for (Value* v : zero_values) {
      zero_check->addInput(v);
    }
    checks.push_back(zero_check->output());

    Value* bool_list =
        graph_->insertNode(graph_->createList(BoolType::get(), checks))
            ->output();
    Value* conjunction = graph_->insert(aten::all, {bool_list});

    versioning_if->addInput(conjunction);
    graph_->insertNode(versioning_if);

    auto ret = graph_->return_node();
    for (const auto i : c10::irange(ret->inputs().size())) {
      auto ogo = ret->input(i);
      auto ngo = versioning_if->output(i);
      ngo->copyMetadata(ogo);
      ret->replaceInput(i, ngo);
    }

    // We've created:
    // successful_checks = Guards(...)
    // if (successful_checks)
    // -> optimized graph
    // else:
    // -> fallback graph
    // original graph
    //
    // Remove the dead original graph
    for (auto it = graph_->block()->nodes().reverse().begin();
         *it != versioning_if;) {
      Node* n = *it;
      it++;
      n->destroy();
    }

    GRAPH_DUMP("After guardSpecializations", graph_);
    return versioning_if;
  }

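  // Rewrites autodiff-generated ops using the known states: AutogradAdd is
  // reduced to its nonzero operand or to a plain aten::add, and lowered
  // GradOf blocks (prim::If over prim::AutogradAnyNonZero) are replaced by an
  // AutogradZero or inlined when all their inputs are known to be Zero or
  // Nonzero, respectively.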
  void specializeAutogradOps(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) {
      auto n = *it;
      switch (n->kind()) {
        case prim::AutogradAdd: {
          auto a = n->input(0);
          auto b = n->input(1);
          // if one is Autograd zero, we can just drop the add
          if (state_[a] == State::Zero) {
            // Zero + b == b
            n->output()->replaceAllUsesWith(b);
            it.destroyCurrent();
          } else if (state_[b] == State::Zero) {
            // a + Zero == a
            n->output()->replaceAllUsesWith(a);
            it.destroyCurrent();
          } else if (
              state_[a] == State::Nonzero && state_[b] == State::Nonzero) {
            // when both are Nonzero, we can use a normal, optimizable add
            // instruction
            WithInsertPoint guard(n);
            auto* cOne = graph_->insertConstant(1);
            auto* add_node = graph_->insertNode(graph_->create(aten::add, 1));
            add_node->addInput(a);
            add_node->addInput(b);
            add_node->addInput(cOne);
            auto* add_output = add_node->output();
            add_output->setType(n->output()->type());
            state_[add_output] = State::Nonzero;
            n->output()->replaceAllUsesWith(add_output);
            it.destroyCurrent();
          } else {
            // otherwise we have conditionally-Nonzero things, and we need
            // to actually run an AutogradAdd which will guard for Zeros
            // so we leave the op as is
            state_[n->output()] = State::Unknown;
          }
        } break;
        case prim::AutogradZero: {
          state_[n->output()] = State::Zero;
        } break;
        case prim::profile: {
          // this is a profile node on a tensor use.
          // if we decided to specialize this graph,
          // its input may have undefinedness info;
          // otherwise it should be Unknown
          if (!n->inputs().empty()) {
            state_[n->output()] = !state_.count(n->input())
                ? State::Unknown
                : state_[n->input()];
          }
          break;
        }
        case prim::BailOut: {
          if (auto ptt = n->output()->type()->expect<TensorType>()) {
            state_[n->output()] = ptt->undefined()
                ? (*ptt->undefined() ? State::Zero : State::Nonzero)
                : State::Unknown;
          }
        } break;
        // Lowered GradOf block
        case prim::If: {
          auto if_input = n->input(0)->node();
          if (if_input->kind() == prim::AutogradAnyNonZero) {
            auto all_zeros = std::all_of(
                if_input->inputs().begin(),
                if_input->inputs().end(),
                [&](Value* v) { return state_[v] == State::Zero; });

            auto all_nonzeros = std::all_of(
                if_input->inputs().begin(),
                if_input->inputs().end(),
                [&](Value* v) { return state_[v] == State::Nonzero; });
            // Property 1: if all the gradInputs to the GradOf are Zero
            // then the gradOutputs are also zero and will be represented as
            // AutogradZero nodes
            if (all_zeros) {
              auto zero =
                  graph_->createAutogradZero()->insertAfter(n)->output();
              state_[zero] = State::Zero;
              for (auto o : n->outputs()) {
                o->replaceAllUsesWith(zero);
              }
              it.destroyCurrent();
              break;
            }

            specializeGradSumToSize(n->blocks().at(0));
            if (all_nonzeros) {
              auto body = n->blocks().at(0);
              // hoist the nodes in the GradOf body to be before the linear
              // block
              for (auto it = body->nodes().begin();
                   it != body->nodes().end();) {
                auto block_node = *it++;
                block_node->moveBefore(n);
              }

              for (size_t i = 0; i < n->outputs().size(); ++i) {
                n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
                state_[body->outputs().at(i)] = State::Nonzero;
              }
              it.destroyCurrent();
              break;
            }
          }

          for (auto o : n->outputs()) {
            state_[o] = State::Unknown;
          }
          break;
        }
        default:
          for (auto o : n->outputs()) {
            state_[o] = State::Unknown;
          }
          break;
      }
    }
  }

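  // Removes `aten::_grad_sum_to_size` calls whose size argument is statically
  // None or was only ever observed to be None during profiling; the gradient
  // is then passed through unchanged.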
  void specializeGradSumToSize(Block* b) {
    for (auto it = b->nodes().begin(); it != b->nodes().end(); ++it) {
      Node* n = *it;
      if (n->kind() == aten::_grad_sum_to_size) {
        bool profiled_none_flag = profiled_none_.count(n->input(1));
        const Node* node = n->input(1)->node();
        // propagate profiled none through other profile_ivalue nodes
        while (!profiled_none_flag && node->kind() == prim::profile_ivalue) {
          profiled_none_flag =
              profiled_none_flag || profiled_none_.count(node->input(0));
          node = node->input(0)->node();
        }
        if (n->input(1)->mustBeNone() || profiled_none_flag) {
          n->output()->replaceAllUsesWith(n->input(0));
          it.destroyCurrent();
        }
      }
    }
  }

  std::shared_ptr<Graph> graph_;
  std::unordered_set<Value*> profiled_none_;
  std::unordered_map<Value*, State> state_;
};

// propagate autograd zero information through a gradient graph and
// remove grad_of blocks if present.
// Note: this is a very limited pass. It only propagates autograd zeros for
// operations generated by the symbolic autodiff code and cleans up
// AutogradAdds when possible. Outputs of other nodes are conservatively
// marked Unknown and not optimized.
void specializeAutogradZero(std::shared_ptr<Graph> g) {
  AutogradZeroSpecializer azs(std::move(g));
  azs.run();
}

} // namespace jit
} // namespace torch