1#include <torch/csrc/jit/passes/remove_inplace_ops.h>
2
3namespace torch {
4namespace jit {
5namespace {
6static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = {
7 {aten::add_, aten::add},
8 {aten::sub_, aten::sub},
9 {aten::div_, aten::div},
10 {aten::mul_, aten::mul},
11 {aten::masked_fill_, aten::masked_fill},
12 {aten::zero_, aten::zeros_like},
13 {aten::fill_, aten::full_like}};
14
15// This is a horrible no good awful hack to "fill in" the TensorOptions
16// arguments of zeros_like and full_like so that the defaults are filled
17// in. Ugh. Would be better to just run the frontend to get the correct
18// arity here.
19static const std::unordered_map<NodeKind, int> expectedInputCount = {
20 {aten::zero_, 6},
21 {aten::fill_, 7}};
22
23bool isInplaceOp(const Node* node) {
24 return inPlaceToOutOfPlace.count(node->kind()) != 0;
25}
26
27// Remove all in-place ops and replace them with out-of-place equivalents.
28// e.g.
29// %foo = aten::add_(%foo, %n)
30// becomes
31// %foo.2 = aten::add(%foo, %n)
32//
33// NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by
34// another value. This is only to avoid breaking ONNX export; when alias
35// analysis is done we can emit a warning if someone tries to export.
36void RemoveInplaceOps(Block* block) {
37 auto graph = block->owningGraph();
38 auto it = block->nodes().begin();
39 while (it != block->nodes().end()) {
40 auto node = *it;
41 ++it;
42 for (auto block : node->blocks()) {
43 RemoveInplaceOps(block);
44 }
45
46 if (isInplaceOp(node)) {
47 // create a replacement out of place op
48 auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind()));
49 newNode->insertBefore(node);
50 newNode->copyMetadata(node);
51 // copy inputs
52 for (auto input : node->inputs()) {
53 newNode->addInput(input);
54 }
55
56 int additionalInputCount = 0;
57 if (expectedInputCount.find(node->kind()) != expectedInputCount.end()) {
58 additionalInputCount = expectedInputCount.at(node->kind()) -
59 static_cast<int>(newNode->inputs().size());
60 }
61
62 for (int i = 0; i < additionalInputCount; ++i) {
63 auto noneNode = graph->createNone();
64 noneNode->insertBefore(newNode);
65 newNode->addInput(noneNode->output());
66 }
67
68 // Create a new output node and replace all uses of self with it
69 newNode->output()->copyMetadata(node->output());
70 node->replaceAllUsesWith(newNode);
71 node->inputs()[0]->replaceAllUsesAfterNodeWith(
72 newNode, newNode->output());
73 node->destroy();
74 }
75 }
76}
77} // namespace
78
79// Handles special case of binary inplace ops, where the first input node
80// has a lower type precedence than the second input node. When the
81// inplace node is converted to a regular op, this information is lost and
82// the resulting type is based on type precedence, just like regular ops.
83// To avoid this loss of information, we add a cast node before the input
84// node with the higher data type precedence, so that both the input types
85// are the same.
86// An example scenario would be:
87// Before:
88// graph(%0 : Float),
89// %1 : Half):
90// # Should result in a Half, but after translation to out-of-place,
91// # would become a Float b/c Half+Float -> Float.
92// %4 : Float = onnx::Cast[to=1](%1)
93// %5 : Float = onnx::Add(%4, %0)
94// ...
95// After:
96// graph(%0 : Float),
97// %1 : Half):
98// %4 : Half = onnx::Cast[to=10](%0)
99// %5 : Half = onnx::Add(%1, %4)
100// ...
101
102void ImplicitCastForBinaryInplaceOps(Block* b) {
103 for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) {
104 for (auto* child_block : it->blocks()) {
105 ImplicitCastForBinaryInplaceOps(child_block);
106 }
107
108 // Check type if inplace operation is a binary node
109 if ((it->kind() == aten::add_) || (it->kind() == aten::sub_) ||
110 (it->kind() == aten::mul_) || (it->kind() == aten::div_)) {
111 auto originalInputs = it->inputs();
112 if (originalInputs.at(0) == originalInputs.at(1)) {
113 continue;
114 }
115
116 auto shape_node = originalInputs.at(0)->node();
117 if ((shape_node->kind() == prim::NumToTensor) &&
118 (shape_node->inputs().at(0)->node()->kind() == aten::size)) {
119 std::cerr
120 << "In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#"
121 << "avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode"
122 << std::endl;
123 }
124
125 TensorTypePtr firstInp_tensor =
126 originalInputs.at(0)->type()->cast<TensorType>();
127 TensorTypePtr secondInp_tensor =
128 originalInputs.at(1)->type()->cast<TensorType>();
129 if (!(firstInp_tensor) || !(secondInp_tensor) ||
130 !(firstInp_tensor->scalarType().has_value())) {
131 continue;
132 }
133 auto newInputNode = it->owningGraph()->create(aten::type_as, 1);
134 newInputNode->insertBefore(*it);
135 newInputNode->addInput(originalInputs.at(1));
136 newInputNode->addInput(originalInputs.at(0));
137 it->replaceInput(1, newInputNode->outputs().at(0));
138 }
139 }
140}
141
142void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) {
143 ImplicitCastForBinaryInplaceOps(graph->block());
144 RemoveInplaceOps(graph->block());
145}
146} // namespace jit
147} // namespace torch
148