1 | #include <torch/csrc/jit/passes/constant_propagation.h> |
2 | |
3 | #include <ATen/core/functional.h> |
4 | #include <ATen/core/ivalue.h> |
5 | #include <c10/util/Exception.h> |
6 | #include <c10/util/irange.h> |
7 | #include <torch/csrc/autograd/variable.h> |
8 | #include <torch/csrc/jit/ir/alias_analysis.h> |
9 | #include <torch/csrc/jit/ir/constants.h> |
10 | #include <torch/csrc/jit/ir/ir.h> |
11 | #include <torch/csrc/jit/ir/node_hashing.h> |
12 | #include <torch/csrc/jit/jit_log.h> |
13 | #include <torch/csrc/jit/passes/dead_code_elimination.h> |
14 | #include <torch/csrc/jit/runtime/operator.h> |
15 | #include <torch/csrc/jit/runtime/vararg_functions.h> |
16 | #include <torch/csrc/utils/memory.h> |
17 | |
18 | #include <utility> |
19 | |
20 | namespace torch { |
21 | namespace jit { |
22 | |
// Runs node `n` once if every one of its inputs is a constant, and returns
// the resulting output values. Returns c10::nullopt when any input is not
// constant, when executing the op throws, or when an output cannot legally
// be embedded as a constant (e.g. a tensor that requires grad, or a
// non-weak object).
//
// `ignore_custom_classes`: when true, refuse to fold any node that produces
// a custom class instance.
// `db`: alias db used to decide whether a produced custom-class object may
// be demoted to a weak compilation-unit reference; may be null (callers in
// this file invoke the two-argument form — presumably the header declares a
// default for `db`; confirm against the declaration).
c10::optional<std::vector<IValue>> runNodeIfInputsAreConstant(
    const Node* n,
    bool ignore_custom_classes,
    AliasDb* db) {
  Stack stack;
  // Collect the constant inputs; any non-constant input aborts folding.
  for (auto input : n->inputs()) {
    if (auto ival = toIValue(input)) {
      stack.push_back(*ival);
    } else {
      return c10::nullopt;
    }
  }

  // A handful of prim ops are interpreted directly; everything else runs
  // through the node's registered Operation in the default case.
  switch (n->kind()) {
    case prim::ListUnpack: {
      // The list length must match the node's output arity for the unpack
      // to be valid.
      if (stack.back().toList().size() != n->outputs().size()) {
        return c10::nullopt;
      }
      listUnpack(stack, n->outputs().size());
    } break;
    case prim::TupleConstruct: {
      // Named tuples carry their type; plain tuples do not.
      auto tt = n->output()->type()->expect<TupleType>();
      if (tt->name()) {
        namedTupleConstruct(stack, std::move(tt), n->inputs().size());
      } else {
        tupleConstruct(stack, n->inputs().size());
      }
    } break;
    case prim::ListConstruct: {
      listConstruct(
          stack,
          n->output()->type()->expectRef<ListType>(),
          n->inputs().size());
    } break;
    case prim::DictConstruct: {
      dictConstruct(
          stack,
          n->output()->type()->expectRef<DictType>(),
          n->inputs().size());
    } break;
    case prim::CreateObject: {
      // Created constants hold only a weak reference to their compilation
      // unit; see [Constant Object Weak CompilationUnit Reference].
      createObject(
          stack,
          n->output()->type()->expect<ClassType>(),
          /*use_weak_ref*/ true);
    } break;
    case prim::GetAttr: {
      auto attr = pop(stack).toObject()->getAttr(n->s(attr::name));
      push(stack, attr);
    } break;
    case prim::isinstance: {
      isinstance(stack, n->tys(attr::types));
    } break;
    default: {
      const auto maybe_schema = n->maybeSchema();
      if (maybe_schema && maybe_schema->is_vararg()) {
        // vararg schemas require the number of inputs at the top of the stack
        // but this is broken in other places in constant prop, so disable it
        // for now
        return c10::nullopt;
      }

      try {
        auto op = n->getOperation();
        op(stack);
      } catch (...) {
        // Any failure while running the op simply means we cannot fold it.
        return c10::nullopt;
      }
    } break;
  }

  // Post-filter the produced values: reject anything that is not safe to
  // embed as a graph constant.
  for (IValue& v : stack) {
    if (v.isTensor()) {
      const at::Tensor& t = v.toTensor();
      if (t.defined() && t.requires_grad()) {
        // requires grad tensors cannot be constants
        return c10::nullopt;
      }
    }
    // Weak form of const propagation
    if (ignore_custom_classes) {
      if (v.isCustomClass()) {
        return c10::nullopt;
      }
    }
    // see [Constant Object Weak CompilationUnit Reference]
    if (v.isCustomClass()) {
      // Already a weak compilation ref: safe to embed as-is.
      if (v.toObject()->is_weak_compilation_ref()) {
        continue;
      }
      // Without an alias db we cannot prove anything about aliasing;
      // `continue` accepts the value and also skips the isObject check
      // below.
      if (!db) {
        continue;
      }
      // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
      Node* n_non_const = const_cast<Node*>(n);
      // If the outputs may alias the node's inputs, do not demote the
      // object's compilation-unit reference.
      if (db->mayContainAlias(
              n_non_const->inputs(), {n_non_const->outputs()})) {
        continue;
      }
      // Proven non-aliasing: demote to a weak compilation reference so the
      // object can live inside a constant.
      auto obj = v.toObject();
      obj->unsafe_make_weak_compilation_ref();
    }
    // Plain objects that still hold a strong compilation ref cannot be
    // turned into constants.
    if (v.isObject()) {
      if (!v.toObject()->is_weak_compilation_ref()) {
        return c10::nullopt;
      }
    }
  }
  return stack;
}
133 | |
134 | namespace { |
135 | |
136 | std::unordered_set<Symbol> skip_list = { |
137 | prim::If, |
138 | prim::Loop, |
139 | prim::Closure, |
140 | prim::Constant, |
141 | prim::AutogradZero, |
142 | prim::Uninitialized, |
143 | prim::Guard, |
144 | prim::profile, |
145 | prim::profile_ivalue, |
146 | prim::unchecked_unwrap_optional, // TODO remove |
147 | prim::awaitable, |
148 | aten::dequantize, |
149 | // TODO (zach): we should consider skipping tensor factories in the cases |
150 | // where the constant tensor would be large but cheap to create. |
151 | }; |
152 | |
153 | struct ConstantPropagator { |
154 | // Runs constant propagation with an aliasing db and checks if inputs or |
155 | // outputs might be mutated in the graph |
156 | static ConstantPropagator WithAliasDb( |
157 | std::shared_ptr<Graph> graph, |
158 | bool ignore_custom_classes) { |
159 | return ConstantPropagator(std::move(graph), true, ignore_custom_classes); |
160 | } |
161 | |
162 | // Runs constant propagation only on ops that clearly do not have aliased |
163 | // inputs or outputs without computing aliasing information |
164 | static ConstantPropagator NoAliasDb(std::shared_ptr<Graph> graph) { |
165 | return ConstantPropagator(std::move(graph), false, false); |
166 | } |
167 | |
168 | bool run() { |
169 | ConstantPropagation(graph_->block()); |
170 | return made_change_; |
171 | } |
172 | |
173 | private: |
174 | ConstantPropagator( |
175 | std::shared_ptr<Graph> graph, |
176 | bool aliasing_types, |
177 | bool ignore_custom_classes) |
178 | : graph_(std::move(graph)), |
179 | aliasing_types_(aliasing_types), |
180 | ignore_custom_classes_(ignore_custom_classes) {} |
181 | |
182 | void propagateNode(Node* n) { |
183 | std::vector<IValue> outputs; |
184 | if (auto outputs_opt = |
185 | runNodeIfInputsAreConstant(n, ignore_custom_classes_)) { |
186 | outputs = std::move(outputs_opt.value()); |
187 | } else { |
188 | // The op failed to run, so we cannot continue constant-prop for it. |
189 | return; |
190 | } |
191 | auto graph = n->owningGraph(); |
192 | WithInsertPoint guard(n); |
193 | for (const auto i : c10::irange(outputs.size())) { |
194 | auto new_output = tryInsertConstant(*graph, outputs[i]); |
195 | if (new_output) { |
196 | made_change_ = true; |
197 | GRAPH_UPDATE( |
198 | "Folding %" , |
199 | n->outputs()[i]->debugName(), |
200 | " with " , |
201 | getHeader((*new_output)->node())); |
202 | if (outputs[i].isNone()) { |
203 | (*new_output)->setType(n->outputs()[i]->type()); |
204 | } |
205 | n->outputs()[i]->replaceAllUsesWith(*new_output); |
206 | } |
207 | // If we cannot insert the IValue as a constant, give up replacing the |
208 | // node and let DCE remove it |
209 | } |
210 | } |
211 | |
212 | void removeLoopNode(Node* n) { |
213 | auto loop_input_offset = 2; // offset of loop carried deps in input list |
214 | for (size_t i = 0; i < n->outputs().size(); ++i) { |
215 | n->outputs().at(i)->replaceAllUsesWith( |
216 | n->inputs().at(i + loop_input_offset)); |
217 | } |
218 | made_change_ = true; |
219 | n->destroy(); |
220 | } |
221 | |
222 | bool loopWillNotRun(Node* node) { |
223 | Value* trip_count = node->inputs().at(0); |
224 | int64_t iter_len = constant_as<int64_t>(trip_count).value_or(1); |
225 | |
226 | Value* start_cond = node->inputs().at(1); |
227 | bool cond_val = constant_as<bool>(start_cond).value_or(true); |
228 | |
229 | bool loop_might_run = cond_val && iter_len > 0; |
230 | if (!loop_might_run) { |
231 | GRAPH_UPDATE( |
232 | "Removing unexecuted loop: " , |
233 | *node, |
234 | "\ntripcount: " , |
235 | trip_count, |
236 | " and start_cond: " , |
237 | getHeader(start_cond->node())); |
238 | } |
239 | return !loop_might_run; |
240 | } |
241 | |
242 | void inlineIfBody(Block* body) { |
243 | Node* n = body->owningNode(); |
244 | for (auto it = body->nodes().begin(); it != body->nodes().end();) { |
245 | Node* body_node = *it; |
246 | // advance iterator because after body_node is moved its next pointer will |
247 | // be to n |
248 | it++; |
249 | body_node->moveBefore(n); |
250 | } |
251 | for (size_t i = 0; i < n->outputs().size(); ++i) { |
252 | n->outputs().at(i)->replaceAllUsesWith(body->outputs().at(i)); |
253 | } |
254 | // NB: destroy the node here, because it might contain side effects, like |
255 | // print |
256 | n->destroy(); |
257 | } |
258 | |
259 | void inlineIf(Node* n) { |
260 | auto input_bool = constant_as<bool>(n->input()); |
261 | AT_ASSERT(input_bool); |
262 | GRAPH_UPDATE( |
263 | "Folding if " , |
264 | getHeader(n->input()->node()), |
265 | " where condition = " , |
266 | *input_bool); |
267 | size_t block_index = *input_bool ? 0 : 1; |
268 | ConstantPropagation(n->blocks().at(block_index)); |
269 | inlineIfBody(n->blocks().at(block_index)); |
270 | made_change_ = true; |
271 | } |
272 | |
273 | void replaceAndRemoveIfOutput(Node* n, size_t i, Value* replacement) { |
274 | n->outputs().at(i)->replaceAllUsesWith(replacement); |
275 | n->eraseOutput(i); |
276 | n->blocks().at(0)->eraseOutput(i); |
277 | n->blocks().at(1)->eraseOutput(i); |
278 | } |
279 | |
280 | // remove extra outputs from the node |
281 | void (Node* n) { |
282 | TORCH_CHECK(n->kind() == prim::If, "Only supported for If nodes" ); |
283 | auto true_block = n->blocks()[0]; |
284 | auto false_block = n->blocks()[1]; |
285 | auto graph = n->owningGraph(); |
286 | auto initial_outputs = true_block->outputs().size(); |
287 | WithInsertPoint guard(n); |
288 | for (size_t i = 0; i < true_block->outputs().size();) { |
289 | auto t_out = true_block->outputs().at(i); |
290 | auto f_out = false_block->outputs().at(i); |
291 | |
292 | // neither block changes the output value |
293 | if (true_block->outputs()[i] == false_block->outputs()[i]) { |
294 | replaceAndRemoveIfOutput(n, i, true_block->outputs()[i]); |
295 | continue; |
296 | } |
297 | |
298 | // true block output is constant and constant matches false block output |
299 | auto maybe_const = toIValue(t_out); |
300 | auto eq = EqualNode(); |
301 | if (maybe_const && eq(t_out->node(), f_out->node())) { |
302 | auto new_const = graph->insertConstant(*maybe_const); |
303 | replaceAndRemoveIfOutput(n, i, new_const); |
304 | continue; |
305 | } |
306 | |
307 | i++; // increment bc we didn't remove current index |
308 | } |
309 | made_change_ |= initial_outputs != true_block->outputs().size(); |
310 | } |
311 | |
312 | // remove extra outputs from the node |
313 | void (Node* node) { |
314 | auto initial_outputs = node->outputs().size(); |
315 | auto loop_body = node->blocks().at(0); |
316 | auto loop_input_offset = 2; // offset of loop carried deps in input list |
317 | auto loop_body_offset = |
318 | 1; // offset to the loop carried dependencies in block inputs/outputs |
319 | for (size_t i_1 = node->outputs().size(); i_1 > 0; --i_1) { |
320 | size_t i = i_1 - 1; |
321 | // if the value is no longer changed remove output |
322 | if (loop_body->inputs().at(loop_body_offset + i) == |
323 | loop_body->outputs().at(loop_body_offset + i)) { |
324 | auto node_input = node->inputs().at(loop_input_offset + i); |
325 | node->outputs().at(i)->replaceAllUsesWith(node_input); |
326 | loop_body->inputs() |
327 | .at(loop_body_offset + i) |
328 | ->replaceAllUsesWith(node_input); |
329 | node->eraseOutput(i); |
330 | node->removeInput(loop_input_offset + i); |
331 | loop_body->eraseInput(loop_body_offset + i); |
332 | loop_body->eraseOutput(loop_body_offset + i); |
333 | } |
334 | } |
335 | made_change_ |= initial_outputs != node->outputs().size(); |
336 | } |
337 | |
338 | bool noMutableValues(at::ArrayRef<Value*> values) { |
339 | return std::none_of(values.begin(), values.end(), [](Value* v) { |
340 | return AliasDb::isMutableType(v); |
341 | }); |
342 | } |
343 | |
344 | AliasDb* getOrCreateAliasDb() { |
345 | if (!aliasDb_) { |
346 | aliasDb_ = std::make_unique<AliasDb>(graph_); |
347 | } |
348 | return aliasDb_.get(); |
349 | } |
350 | |
351 | bool supportedNode(Node* n) { |
352 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
353 | bool no_mutation; |
354 | if (aliasing_types_) { |
355 | no_mutation = !getOrCreateAliasDb()->hasWriters(n); |
356 | } else { |
357 | no_mutation = |
358 | noMutableValues(n->inputs()) && noMutableValues(n->outputs()); |
359 | } |
360 | return no_mutation && !n->kind().is_onnx() && |
361 | skip_list.count(n->kind()) == 0 && !n->isNondeterministic() && |
362 | !n->hasSideEffects() && n->blocks().empty(); |
363 | } |
364 | |
365 | void ConstantPropagation(at::ArrayRef<Block*> blocks) { |
366 | for (Block* block : blocks) { |
367 | ConstantPropagation(block); |
368 | } |
369 | } |
370 | |
371 | void ConstantPropagation(Node* n) { |
372 | bool constant_inputs = |
373 | std::all_of(n->inputs().begin(), n->inputs().end(), [&](Value* v) { |
374 | return v->node()->kind() == prim::Constant; |
375 | }); |
376 | if (n->kind() == prim::If) { |
377 | // inline node if we can, otherwise check for simplified outputs |
378 | if (constant_inputs) { |
379 | inlineIf(n); |
380 | } else { |
381 | ConstantPropagation(n->blocks()); |
382 | removeExtraIfOutputs(n); |
383 | } |
384 | } else if (n->kind() == prim::Loop) { |
385 | if (loopWillNotRun(n)) { |
386 | removeLoopNode(n); |
387 | } else { |
388 | ConstantPropagation(n->blocks()); |
389 | removeExtraLoopOutputs(n); |
390 | } |
391 | } else if (constant_inputs && supportedNode(n)) { |
392 | propagateNode(n); |
393 | } else { |
394 | ConstantPropagation(n->blocks()); |
395 | } |
396 | } |
397 | |
398 | void ConstantPropagation(Block* block) { |
399 | for (auto it = block->nodes().begin(); it != block->nodes().end();) { |
400 | Node* n = *it; |
401 | it++; // advance iterator bc the current node may be destroyed |
402 | ConstantPropagation(n); |
403 | } |
404 | } |
405 | |
406 | std::shared_ptr<Graph> graph_; |
407 | // lazily initialized if using aliasing_types, otherwise not initialized |
408 | std::unique_ptr<AliasDb> aliasDb_ = nullptr; |
409 | bool aliasing_types_; |
410 | bool made_change_ = false; |
411 | bool ignore_custom_classes_; |
412 | }; |
413 | } // anonymous namespace |
414 | |
415 | bool ConstantPropagation( |
416 | std::shared_ptr<Graph>& graph, |
417 | bool ignore_custom_classes) { |
418 | ConstantPropagator cp = |
419 | ConstantPropagator::WithAliasDb(graph, ignore_custom_classes); |
420 | bool made_change = cp.run(); |
421 | if (made_change) { |
422 | EliminateDeadCode(graph); |
423 | } |
424 | GRAPH_DUMP("After ConstantPropagation: " , graph); |
425 | return made_change; |
426 | } |
427 | |
428 | bool ConstantPropagationImmutableTypes(std::shared_ptr<Graph>& graph) { |
429 | ConstantPropagator cp = ConstantPropagator::NoAliasDb(graph); |
430 | bool made_change = cp.run(); |
431 | if (made_change) { |
432 | EliminateDeadCode(graph); |
433 | } |
434 | GRAPH_DUMP("After ConstantPropagationImmutableTypes: " , graph); |
435 | return made_change; |
436 | } |
437 | |
438 | } // namespace jit |
439 | } // namespace torch |
440 | |