1 | #include <torch/csrc/jit/passes/remove_inplace_ops.h> |
2 | |
3 | namespace torch { |
4 | namespace jit { |
5 | namespace { |
6 | static const std::unordered_map<NodeKind, NodeKind> inPlaceToOutOfPlace = { |
7 | {aten::add_, aten::add}, |
8 | {aten::sub_, aten::sub}, |
9 | {aten::div_, aten::div}, |
10 | {aten::mul_, aten::mul}, |
11 | {aten::masked_fill_, aten::masked_fill}, |
12 | {aten::zero_, aten::zeros_like}, |
13 | {aten::fill_, aten::full_like}}; |
14 | |
15 | // This is a horrible no good awful hack to "fill in" the TensorOptions |
16 | // arguments of zeros_like and full_like so that the defaults are filled |
17 | // in. Ugh. Would be better to just run the frontend to get the correct |
18 | // arity here. |
19 | static const std::unordered_map<NodeKind, int> expectedInputCount = { |
20 | {aten::zero_, 6}, |
21 | {aten::fill_, 7}}; |
22 | |
23 | bool isInplaceOp(const Node* node) { |
24 | return inPlaceToOutOfPlace.count(node->kind()) != 0; |
25 | } |
26 | |
27 | // Remove all in-place ops and replace them with out-of-place equivalents. |
28 | // e.g. |
29 | // %foo = aten::add_(%foo, %n) |
30 | // becomes |
31 | // %foo.2 = aten::add(%foo, %n) |
32 | // |
33 | // NOTE: this is NOT SAFE, since it assumes that the LHS is not aliased by |
34 | // another value. This is only to avoid breaking ONNX export; when alias |
35 | // analysis is done we can emit a warning if someone tries to export. |
36 | void RemoveInplaceOps(Block* block) { |
37 | auto graph = block->owningGraph(); |
38 | auto it = block->nodes().begin(); |
39 | while (it != block->nodes().end()) { |
40 | auto node = *it; |
41 | ++it; |
42 | for (auto block : node->blocks()) { |
43 | RemoveInplaceOps(block); |
44 | } |
45 | |
46 | if (isInplaceOp(node)) { |
47 | // create a replacement out of place op |
48 | auto newNode = graph->create(inPlaceToOutOfPlace.at(node->kind())); |
49 | newNode->insertBefore(node); |
50 | newNode->copyMetadata(node); |
51 | // copy inputs |
52 | for (auto input : node->inputs()) { |
53 | newNode->addInput(input); |
54 | } |
55 | |
56 | int additionalInputCount = 0; |
57 | if (expectedInputCount.find(node->kind()) != expectedInputCount.end()) { |
58 | additionalInputCount = expectedInputCount.at(node->kind()) - |
59 | static_cast<int>(newNode->inputs().size()); |
60 | } |
61 | |
62 | for (int i = 0; i < additionalInputCount; ++i) { |
63 | auto noneNode = graph->createNone(); |
64 | noneNode->insertBefore(newNode); |
65 | newNode->addInput(noneNode->output()); |
66 | } |
67 | |
68 | // Create a new output node and replace all uses of self with it |
69 | newNode->output()->copyMetadata(node->output()); |
70 | node->replaceAllUsesWith(newNode); |
71 | node->inputs()[0]->replaceAllUsesAfterNodeWith( |
72 | newNode, newNode->output()); |
73 | node->destroy(); |
74 | } |
75 | } |
76 | } |
77 | } // namespace |
78 | |
79 | // Handles special case of binary inplace ops, where the first input node |
80 | // has a lower type precedence than the second input node. When the |
81 | // inplace node is converted to a regular op, this information is lost and |
82 | // the resulting type is based on type precedence, just like regular ops. |
83 | // To avoid this loss of information, we add a cast node before the input |
84 | // node with the higher data type precedence, so that both the input types |
85 | // are the same. |
86 | // An example scenario would be: |
87 | // Before: |
88 | // graph(%0 : Float), |
89 | // %1 : Half): |
90 | // # Should result in a Half, but after translation to out-of-place, |
91 | // # would become a Float b/c Half+Float -> Float. |
92 | // %4 : Float = onnx::Cast[to=1](%1) |
93 | // %5 : Float = onnx::Add(%4, %0) |
94 | // ... |
95 | // After: |
96 | // graph(%0 : Float), |
97 | // %1 : Half): |
98 | // %4 : Half = onnx::Cast[to=10](%0) |
99 | // %5 : Half = onnx::Add(%1, %4) |
100 | // ... |
101 | |
102 | void ImplicitCastForBinaryInplaceOps(Block* b) { |
103 | for (auto it = b->nodes().begin(), end = b->nodes().end(); it != end; ++it) { |
104 | for (auto* child_block : it->blocks()) { |
105 | ImplicitCastForBinaryInplaceOps(child_block); |
106 | } |
107 | |
108 | // Check type if inplace operation is a binary node |
109 | if ((it->kind() == aten::add_) || (it->kind() == aten::sub_) || |
110 | (it->kind() == aten::mul_) || (it->kind() == aten::div_)) { |
111 | auto originalInputs = it->inputs(); |
112 | if (originalInputs.at(0) == originalInputs.at(1)) { |
113 | continue; |
114 | } |
115 | |
116 | auto shape_node = originalInputs.at(0)->node(); |
117 | if ((shape_node->kind() == prim::NumToTensor) && |
118 | (shape_node->inputs().at(0)->node()->kind() == aten::size)) { |
119 | std::cerr |
120 | << "In-place op on output of tensor.shape. See https://pytorch.org/docs/master/onnx.html#" |
121 | << "avoid-inplace-operations-when-using-tensor-shape-in-tracing-mode" |
122 | << std::endl; |
123 | } |
124 | |
125 | TensorTypePtr firstInp_tensor = |
126 | originalInputs.at(0)->type()->cast<TensorType>(); |
127 | TensorTypePtr secondInp_tensor = |
128 | originalInputs.at(1)->type()->cast<TensorType>(); |
129 | if (!(firstInp_tensor) || !(secondInp_tensor) || |
130 | !(firstInp_tensor->scalarType().has_value())) { |
131 | continue; |
132 | } |
133 | auto newInputNode = it->owningGraph()->create(aten::type_as, 1); |
134 | newInputNode->insertBefore(*it); |
135 | newInputNode->addInput(originalInputs.at(1)); |
136 | newInputNode->addInput(originalInputs.at(0)); |
137 | it->replaceInput(1, newInputNode->outputs().at(0)); |
138 | } |
139 | } |
140 | } |
141 | |
142 | void RemoveInplaceOps(const std::shared_ptr<Graph>& graph) { |
143 | ImplicitCastForBinaryInplaceOps(graph->block()); |
144 | RemoveInplaceOps(graph->block()); |
145 | } |
146 | } // namespace jit |
147 | } // namespace torch |
148 | |