1#include <torch/csrc/jit/passes/constant_propagation.h>
2
3#include <ATen/core/functional.h>
4#include <ATen/core/ivalue.h>
5#include <c10/util/Exception.h>
6#include <c10/util/irange.h>
7#include <torch/csrc/autograd/variable.h>
8#include <torch/csrc/jit/ir/alias_analysis.h>
9#include <torch/csrc/jit/ir/constants.h>
10#include <torch/csrc/jit/ir/ir.h>
11#include <torch/csrc/jit/ir/node_hashing.h>
12#include <torch/csrc/jit/jit_log.h>
13#include <torch/csrc/jit/passes/dead_code_elimination.h>
14#include <torch/csrc/jit/runtime/operator.h>
15#include <torch/csrc/jit/runtime/vararg_functions.h>
16#include <torch/csrc/utils/memory.h>
17
18#include <utility>
19
20namespace torch {
21namespace jit {
22
// Attempts to evaluate node `n` at compile time.
//
// Succeeds only if every input of `n` is already a constant; in that case the
// node is executed (either via a hand-rolled emulation for the prim::* ops
// below, or via the node's registered Operation) and the resulting output
// stack is returned. Returns c10::nullopt when evaluation is not possible or
// when an output may not legally become a constant (e.g. a tensor that
// requires grad, or a strongly-referenced object).
//
// `ignore_custom_classes`: when true, refuse to fold any custom-class output.
// `db`: optional alias db; when provided, freshly produced custom-class
// objects that provably do not alias the node's inputs are downgraded to weak
// compilation references so they can be embedded as constants.
c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
    const Node* n,
    bool ignore_custom_classes,
    AliasDb* db) {
  Stack stack;
  // Gather the constant inputs; bail out on the first non-constant one.
  for (auto input : n->inputs()) {
    if (auto ival = toIValue(input)) {
      stack.push_back(*ival);
    } else {
      return c10::nullopt;
    }
  }

  // Evaluate the node without the interpreter: container/object primitives
  // are emulated directly, everything else runs through its Operation.
  switch (n->kind()) {
    case prim::ListUnpack: {
      // Unpacking is only valid when the runtime list length matches the
      // node's static output arity.
      if (stack.back().toList().size() != n->outputs().size()) {
        return c10::nullopt;
      }
      listUnpack(stack, n->outputs().size());
    } break;
    case prim::TupleConstruct: {
      auto tt = n->output()->type()->expect<TupleType>();
      // Named tuples need their type preserved; plain tuples do not.
      if (tt->name()) {
        namedTupleConstruct(stack, std::move(tt), n->inputs().size());
      } else {
        tupleConstruct(stack, n->inputs().size());
      }
    } break;
    case prim::ListConstruct: {
      listConstruct(
          stack,
          n->output()->type()->expectRef<ListType>(),
          n->inputs().size());
    } break;
    case prim::DictConstruct: {
      dictConstruct(
          stack,
          n->output()->type()->expectRef<DictType>(),
          n->inputs().size());
    } break;
    case prim::CreateObject: {
      // Create the object with a weak compilation-unit reference up front so
      // the final isObject() check below accepts it as a constant.
      createObject(
          stack,
          n->output()->type()->expect<ClassType>(),
          /*use_weak_ref*/ true);
    } break;
    case prim::GetAttr: {
      auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
      push(stack, attr);
    } break;
    case prim::isinstance: {
      isinstance(stack, n->tys(attr::types));
    } break;
    default: {
      const auto maybe_schema = n->maybeSchema();
      if (maybe_schema && maybe_schema->is_vararg()) {
        // vararg schemas require the number of inputs at the top of the stack
        // but this is broken in other places in constant prop, so disable it
        // for now
        return c10::nullopt;
      }

      // Any failure while running the op (missing operator, runtime error)
      // simply means we cannot fold this node.
      try {
        auto op = n->getOperation();
        op(stack);
      } catch (...) {
        return c10::nullopt;
      }
    } break;
  }

  // Post-process the outputs: reject anything that must not become a
  // constant. Note that `continue` here means "this value is acceptable".
  for (IValue& v : stack) {
    if (v.isTensor()) {
      const at::Tensor& t = v.toTensor();
      if (t.defined() && t.requires_grad()) {
        // requires grad tensors cannot be constants
        return c10::nullopt;
      }
    }
    // Weak form of const propagation
    if (ignore_custom_classes) {
      if (v.isCustomClass()) {
        return c10::nullopt;
      }
    }
    // see [Constant Object Weak CompilationUnit Reference]
    if (v.isCustomClass()) {
      if (v.toObject()->is_weak_compilation_ref()) {
        // Already weak; safe to embed as a constant.
        continue;
      }
      if (!db) {
        // No aliasing information: leave the strong reference untouched and
        // accept the value as-is.
        continue;
      }
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      Node* n_non_const = const_cast<Node*>(n);
      // If the produced object may alias an input, it could be observed or
      // mutated elsewhere, so do not downgrade its reference.
      if (db->mayContainAlias(
              n_non_const->inputs(), {n_non_const->outputs()})) {
        continue;
      }
      // Freshly created, non-aliasing object: make its compilation-unit
      // reference weak so it can live inside a constant.
      auto obj = v.toObject();
      obj->unsafe_make_weak_compilation_ref();
    }
    // Any remaining object must hold only a weak compilation-unit reference
    // to be safely baked into a constant.
    if (v.isObject()) {
      if (!v.toObject()->is_weak_compilation_ref()) {
        return c10::nullopt;
      }
    }
  }
  return stack;
}
133
134namespace {
135
136std::unordered_set<Symbol> skip_list = {
137 prim::If,
138 prim::Loop,
139 prim::Closure,
140 prim::Constant,
141 prim::AutogradZero,
142 prim::Uninitialized,
143 prim::Guard,
144 prim::profile,
145 prim::profile_ivalue,
146 prim::unchecked_unwrap_optional, // TODO remove
147 prim::awaitable,
148 aten::dequantize,
149 // TODO (zach): we should consider skipping tensor factories in the cases
150 // where the constant tensor would be large but cheap to create.
151};
152
// Drives constant propagation over a single Graph. Recursively walks every
// block, folds nodes whose inputs are all prim::Constant, inlines
// statically-decidable prim::If branches, removes loops that provably never
// run, and prunes If/Loop outputs that are never actually changed.
// `made_change_` records whether any of these rewrites happened.
struct ConstantPropagator {
  // Runs constant propagation with an aliasing db and checks if inputs or
  // outputs might be mutated in the graph
  static ConstantPropagator WithAliasDb(
      std::shared_ptr<Graph> graph,
      bool ignore_custom_classes) {
    return ConstantPropagator(std::move(graph), true, ignore_custom_classes);
  }

  // Runs constant propagation only on ops that clearly do not have aliased
  // inputs or outputs without computing aliasing information
  static ConstantPropagator NoAliasDb(std::shared_ptr<Graph> graph) {
    return ConstantPropagator(std::move(graph), false, false);
  }

  // Runs the pass over the whole graph; returns true if anything changed.
  bool run() {
    ConstantPropagation(graph_->block());
    return made_change_;
  }

 private:
  // Use the WithAliasDb / NoAliasDb factories instead of this constructor.
  ConstantPropagator(
      std::shared_ptr<Graph> graph,
      bool aliasing_types,
      bool ignore_custom_classes)
      : graph_(std::move(graph)),
        aliasing_types_(aliasing_types),
        ignore_custom_classes_(ignore_custom_classes) {}

  // Evaluates `n` (whose inputs are all constants) at compile time and
  // replaces each output's uses with a freshly inserted constant holding the
  // computed value. Leaves `n` itself in place for DCE to remove.
  void propagateNode(Node* n) {
    std::vector<IValue> outputs;
    if (auto outputs_opt =
            runNodeIfInputsAreConstant(n, ignore_custom_classes_)) {
      outputs = std::move(outputs_opt.value());
    } else {
      // The op failed to run, so we cannot continue constant-prop for it.
      return;
    }
    auto graph = n->owningGraph();
    WithInsertPoint guard(n);
    for (const auto i : c10::irange(outputs.size())) {
      auto new_output = tryInsertConstant(*graph, outputs[i]);
      if (new_output) {
        made_change_ = true;
        GRAPH_UPDATE(
            "Folding %",
            n->outputs()[i]->debugName(),
            " with ",
            getHeader((*new_output)->node()));
        if (outputs[i].isNone()) {
          // Keep the original static type (e.g. an Optional) rather than the
          // constant's default NoneType, so downstream uses still type-check.
          (*new_output)->setType(n->outputs()[i]->type());
        }
        n->outputs()[i]->replaceAllUsesWith(*new_output);
      }
      // If we cannot insert the IValue as a constant, give up replacing the
      // node and let DCE remove it
    }
  }

  // Removes a loop that will never execute, wiring each loop output directly
  // to the corresponding loop-carried input value.
  void removeLoopNode(Node* n) {
    auto loop_input_offset = 2; // offset of loop carried deps in input list
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      n->outputs().at(i)->replaceAllUsesWith(
          n->inputs().at(i + loop_input_offset));
    }
    made_change_ = true;
    n->destroy();
  }

  // True when the loop provably runs zero iterations: a constant trip count
  // <= 0 or a constant-false start condition. Non-constant values are
  // conservatively assumed to permit at least one iteration.
  bool loopWillNotRun(Node* node) {
    Value* trip_count = node->inputs().at(0);
    int64_t iter_len = constant_as<int64_t>(trip_count).value_or(1);

    Value* start_cond = node->inputs().at(1);
    bool cond_val = constant_as<bool>(start_cond).value_or(true);

    bool loop_might_run = cond_val && iter_len > 0;
    if (!loop_might_run) {
      GRAPH_UPDATE(
          "Removing unexecuted loop: ",
          *node,
          "\ntripcount: ",
          trip_count,
          " and start_cond: ",
          getHeader(start_cond->node()));
    }
    return !loop_might_run;
  }

  // Hoists every node of `body` out in front of its owning node, rewires the
  // owning node's outputs to the body's outputs, then destroys the owner.
  void inlineIfBody(Block* body) {
    Node* n = body->owningNode();
    for (auto it = body->nodes().begin(); it != body->nodes().end();) {
      Node* body_node = *it;
      // advance iterator because after body_node is moved its next pointer will
      // be to n
      it++;
      body_node->moveBefore(n);
    }
    for (size_t i = 0; i < n->outputs().size(); ++i) {
      n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i));
    }
    // NB: destroy the node here, because it might contain side effects, like
    // print
    n->destroy();
  }

  // Inlines a prim::If whose condition is a constant: constant-propagates the
  // taken branch, then splices it into the enclosing block.
  void inlineIf(Node* n) {
    auto input_bool = constant_as<bool>(n->input());
    AT_ASSERT(input_bool);
    GRAPH_UPDATE(
        "Folding if ",
        getHeader(n->input()->node()),
        " where condition = ",
        *input_bool);
    // Block 0 is the true branch, block 1 the false branch.
    size_t block_index = *input_bool ? 0 : 1;
    ConstantPropagation(n->blocks().at(block_index));
    inlineIfBody(n->blocks().at(block_index));
    made_change_ = true;
  }

  // Replaces all uses of If output `i` with `replacement` and erases that
  // output from the node and from both branch blocks.
  void replaceAndRemoveIfOutput(Node* n, size_t i, Value* replacement) {
    n->outputs().at(i)->replaceAllUsesWith(replacement);
    n->eraseOutput(i);
    n->blocks().at(0)->eraseOutput(i);
    n->blocks().at(1)->eraseOutput(i);
  }

  // remove extra outputs from the node
  void removeExtraIfOutputs(Node* n) {
    TORCH_CHECK(n->kind() == prim::If, "Only supported for If nodes");
    auto true_block = n->blocks()[0];
    auto false_block = n->blocks()[1];
    auto graph = n->owningGraph();
    auto initial_outputs = true_block->outputs().size();
    WithInsertPoint guard(n);
    // `i` is only advanced when nothing is removed, since erasing an output
    // shifts the remaining ones down.
    for (size_t i = 0; i < true_block->outputs().size();) {
      auto t_out = true_block->outputs().at(i);
      auto f_out = false_block->outputs().at(i);

      // neither block changes the output value
      if (true_block->outputs()[i] == false_block->outputs()[i]) {
        replaceAndRemoveIfOutput(n, i, true_block->outputs()[i]);
        continue;
      }

      // true block output is constant and constant matches false block output
      auto maybe_const = toIValue(t_out);
      auto eq = EqualNode();
      if (maybe_const && eq(t_out->node(), f_out->node())) {
        auto new_const = graph->insertConstant(*maybe_const);
        replaceAndRemoveIfOutput(n, i, new_const);
        continue;
      }

      i++; // increment bc we didn't remove current index
    }
    made_change_ |= initial_outputs != true_block->outputs().size();
  }

  // remove extra outputs from the node
  void removeExtraLoopOutputs(Node* node) {
    auto initial_outputs = node->outputs().size();
    auto loop_body = node->blocks().at(0);
    auto loop_input_offset = 2; // offset of loop carried deps in input list
    auto loop_body_offset =
        1; // offset to the loop carried dependencies in block inputs/outputs
    // Iterate backwards so erasing index i does not shift the indices that
    // are still to be visited.
    for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) {
      size_t i = i_1 - 1;
      // if the value is no longer changed remove output
      if (loop_body->inputs().at(loop_body_offset + i) ==
          loop_body->outputs().at(loop_body_offset + i)) {
        auto node_input = node->inputs().at(loop_input_offset + i);
        node->outputs().at(i)->replaceAllUsesWith(node_input);
        loop_body->inputs()
            .at(loop_body_offset + i)
            ->replaceAllUsesWith(node_input);
        node->eraseOutput(i);
        node->removeInput(loop_input_offset + i);
        loop_body->eraseInput(loop_body_offset + i);
        loop_body->eraseOutput(loop_body_offset + i);
      }
    }
    made_change_ |= initial_outputs != node->outputs().size();
  }

  // True when none of `values` has a mutable type per the alias analysis.
  bool noMutableValues(at::ArrayRef<Value*> values) {
    return std::none_of(values.begin(), values.end(), [](Value* v) {
      return AliasDb::isMutableType(v);
    });
  }

  // Lazily constructs the alias db on first use (only reached when
  // aliasing_types_ is true).
  AliasDb* getOrCreateAliasDb() {
    if (!aliasDb_) {
      aliasDb_ = std::make_unique<AliasDb>(graph_);
    }
    return aliasDb_.get();
  }

  // Decides whether `n` is eligible for constant folding: no possible
  // mutation of its values, not in the skip list, deterministic,
  // side-effect-free, and without sub-blocks.
  bool supportedNode(Node* n) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    bool no_mutation;
    if (aliasing_types_) {
      // With an alias db: fold unless something writes to n's values.
      no_mutation = !getOrCreateAliasDb()->hasWriters(n);
    } else {
      // Without one: only fold when all inputs/outputs are immutable types.
      no_mutation =
          noMutableValues(n->inputs()) && noMutableValues(n->outputs());
    }
    return no_mutation && !n->kind().is_onnx() &&
        skip_list.count(n->kind()) == 0 && !n->isNondeterministic() &&
        !n->hasSideEffects() && n->blocks().empty();
  }

  // Recurses into each block in turn.
  void ConstantPropagation(at::ArrayRef<Block*> blocks) {
    for (Block* block : blocks) {
      ConstantPropagation(block);
    }
  }

  // Dispatches per-node: If/Loop get structural simplification; everything
  // else is either folded (when supported and fully constant) or recursed
  // into.
  void ConstantPropagation(Node* n) {
    bool constant_inputs =
        std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) {
          return v->node()->kind() == prim::Constant;
        });
    if (n->kind() == prim::If) {
      // inline node if we can, otherwise check for simplified outputs
      if (constant_inputs) {
        inlineIf(n);
      } else {
        ConstantPropagation(n->blocks());
        removeExtraIfOutputs(n);
      }
    } else if (n->kind() == prim::Loop) {
      if (loopWillNotRun(n)) {
        removeLoopNode(n);
      } else {
        ConstantPropagation(n->blocks());
        removeExtraLoopOutputs(n);
      }
    } else if (constant_inputs && supportedNode(n)) {
      propagateNode(n);
    } else {
      ConstantPropagation(n->blocks());
    }
  }

  // Walks a block's nodes, propagating each one.
  void ConstantPropagation(Block* block) {
    for (auto it = block->nodes().begin(); it != block->nodes().end();) {
      Node* n = *it;
      it++; // advance iterator bc the current node may be destroyed
      ConstantPropagation(n);
    }
  }

  // Graph being optimized.
  std::shared_ptr<Graph> graph_;
  // lazily initialized if using aliasing_types, otherwise not initialized
  std::unique_ptr<AliasDb> aliasDb_ = nullptr;
  // Whether to build/use an AliasDb (true for WithAliasDb construction).
  bool aliasing_types_;
  // Set whenever any rewrite is performed; returned by run().
  bool made_change_ = false;
  // Forwarded to runNodeIfInputsAreConstant: refuse custom-class constants.
  bool ignore_custom_classes_;
};
413} // anonymous namespace
414
415bool ConstantPropagation(
416 std::shared_ptr<Graph>& graph,
417 bool ignore_custom_classes) {
418 ConstantPropagator cp =
419 ConstantPropagator::WithAliasDb(graph, ignore_custom_classes);
420 bool made_change = cp.run();
421 if (made_change) {
422 EliminateDeadCode(graph);
423 }
424 GRAPH_DUMP("After ConstantPropagation: ", graph);
425 return made_change;
426}
427
428bool ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph) {
429 ConstantPropagator cp = ConstantPropagator::NoAliasDb(graph);
430 bool made_change = cp.run();
431 if (made_change) {
432 EliminateDeadCode(graph);
433 }
434 GRAPH_DUMP("After ConstantPropagationImmutableTypes: ", graph);
435 return made_change;
436}
437
438} // namespace jit
439} // namespace torch
440