#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>

namespace torch {
namespace jit {

bool MutationRemover::removeListMutation() {
  return RemoveListMutation(graph_->block());
}

bool MutationRemover::removeTensorMutation() {
  return RemoveTensorMutation(graph_->block());
}

bool MutationRemover::hasSideEffectOrAlias(Value* v, AliasDb* aliasDb) {
  // bail on nodes with side effects, blocks, or graph / graph inputs
  Node* n = v->node();
  bool unhandled_node = !n->blocks().empty() ||
      n->hasAttribute(attr::Subgraph) || n->hasSideEffects() ||
      (v->node()->kind() == prim::Param);

  // if the output isn't contained in or aliased by its node's inputs, it is
  // unique. No need to check for aliasing if the node is a ListConstruct.
  bool mayAliasInputs = (v->node()->kind() != prim::ListConstruct) &&
      aliasDb->mayContainAlias(v->node()->inputs(), v);
  // prim::Param is already covered by unhandled_node above
  return unhandled_node || mayAliasInputs;
}
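
// For example, the output of `x = y[0]` may be contained in its input y, and
// a graph input (prim::Param) may alias state owned by the caller, so neither
// is a unique value; the output of a fresh prim::ListConstruct or tensor
// constructor is.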

Node* MutationRemover::createSpecialMappedOp(Node* n) {
  WithInsertPoint guard(n);
  auto inputs = n->inputs();
  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
  Node* new_node;
  if (n->matches(
          "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)")) {
    auto dtype = graph_->insert(prim::dtype, {inputs.at(0)});
    new_node = graph_
                   ->insert(
                       aten::full_like,
                       {inputs.at(0), inputs.at(1)},
                       {NamedValue("dtype", dtype)})
                   ->node();
  } else if (n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)")) {
    new_node = graph_->insert(aten::zeros_like, {n->inputs().at(0)})->node();
  } else if (
      n->matches(
          "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)")) {
    // TODO: we should have a normal_like operator
    // normal(float mean, float std, int[] size, *, Generator? generator=None,
    // ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool?
    // pin_memory=None) -> Tensor
    auto size = graph_->insert(aten::size, {n->inputs().at(0)});
    auto dtype = graph_->insert(prim::dtype, {n->inputs().at(0)});
    auto layout = graph_->insert(prim::layout, {n->inputs().at(0)});
    auto device = graph_->insert(prim::device, {n->inputs().at(0)});
    auto pin_memory = graph_->insert(aten::is_pinned, {n->inputs().at(0)});
    auto generator = graph_->insertConstant(IValue());
    new_node = graph_->insertNode(graph_->create(
        aten::normal,
        {n->inputs().at(1),
         n->inputs().at(2),
         size,
         generator,
         dtype,
         layout,
         device,
         pin_memory}));
  } else {
    TORCH_INTERNAL_ASSERT(false);
  }
  new_node->copyMetadata(n);
  new_node->output()->setType(n->output()->type());
  return new_node;
}
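
// Illustrative sketch (not part of the pass itself): at the TorchScript
// level, the special mappings above correspond to
//   x.fill_(v)   ->  torch.full_like(x, v, dtype=x.dtype)
//   x.zero_()    ->  torch.zeros_like(x)
//   x.normal_()  ->  torch.normal(0., 1., x.size(), generator=None,
//                        dtype=x.dtype, layout=x.layout, device=x.device,
//                        pin_memory=x.is_pinned())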

bool removableSetItem(Node* n) {
  if (n->kind() != aten::_set_item ||
      n->input(1)->node()->kind() != prim::Constant) {
    return false;
  }
  if (n->inputs().at(0)->node()->kind() != prim::ListConstruct) {
    return false;
  }
  auto li_node = n->inputs().at(0)->node();
  int64_t index = *constant_as<int64_t>(n->input(1));
  if (index < 0) {
    index += li_node->inputs().size();
  }
  auto li_len = static_cast<int64_t>(li_node->inputs().size());
  return index < li_len && index >= 0;
}
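
// E.g. the following set_item has a constant, in-bounds index into a list
// built by a visible ListConstruct, so it can be folded into the construct
// (IR sketch):
//   %li : Tensor[] = prim::ListConstruct(%a, %b)
//   %li2 : Tensor[] = aten::_set_item(%li, 0, %c)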

bool MutationRemover::listMutationFollowingListConstruct(Node* n) {
  return (
      (n->kind() == aten::append ||
       (n->kind() == aten::insert &&
        n->inputs().at(1)->node()->kind() == prim::Constant) ||
       (removableSetItem(n))) &&
      n->inputs().at(0)->node()->kind() == prim::ListConstruct);
}
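
// E.g. this matches the append in (IR sketch)
//   %li : Tensor[] = prim::ListConstruct(%a)
//   %li2 : Tensor[] = aten::append(%li, %b)
// but not an append to a list that merely flows in as a graph input.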

bool MutationRemover::tryMakeCreationAndMutationAtomic(
    Value* mutated_value,
    Node* mutating_op) {
  // We can only remove mutation to values that are unique aliases in the
  // graph. If x = y[0] or y = self.y, then removing the mutation could
  // change observable semantics.
  if (hasSideEffectOrAlias(mutated_value, getOrCreateAliasDb())) {
    return false;
  }

  // In order to safely remove a mutation, the creation of a tensor and its
  // subsequent mutation need to be one atomic operation
  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
      mutated_value->node(), mutating_op);
}
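
// E.g. in
//   x = torch.zeros(...)
//   y = x + 1
//   x.add_(2)
// the zeros call cannot be moved directly before add_ without reordering it
// past a use of x, so the creation/mutation pair is not atomic and the
// mutation is left in place.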

bool MutationRemover::tryMakeUnaliasedIfOutputAndMutationAtomic(
    Value* mutated_value,
    Node* mutating_op) {
  // if cond:
  //   x = op()
  // else:
  //   x = op()
  // x.add_(1)
  // if x in both blocks has no other uses and is unaliased in the graph,
  // and we make the if node and the mutation atomic,
  // then removing the mutation add_ does not change observable semantics

  if (mutated_value->node()->kind() != prim::If) {
    return false;
  }

  auto if_node = mutated_value->node();
  auto offset = mutated_value->offset();
  auto true_value = if_node->blocks().at(0)->outputs().at(offset);
  auto false_value = if_node->blocks().at(1)->outputs().at(offset);

  if (true_value->uses().size() > 1 || false_value->uses().size() > 1) {
    return false;
  }

  if (hasSideEffectOrAlias(true_value, getOrCreateAliasDb()) ||
      hasSideEffectOrAlias(false_value, getOrCreateAliasDb())) {
    return false;
  }

  return getOrCreateAliasDb()->moveBeforeTopologicallyValid(
      if_node, mutating_op);
}

bool MutationRemover::RemoveListMutation(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveListMutation(sub_block);
    }

    if (!listMutationFollowingListConstruct(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    changed = true;

    // We rewrite something like:
    // x = {v0}
    // x.append(v1) (or x.insert(0, v1))
    // to:
    // x = {v0, v1} (or x = {v1, v0})
    // We can remove x.append from the alias db's list of writes.
    // All other aliasing properties remain valid.
    Node* list_construct = mutated_value->node();
    switch (node->kind()) {
      case aten::append:
        list_construct->addInput(node->inputs().at(1));
        break;
      case aten::insert: {
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        // inserting at a negative position is equivalent to inserting at
        // std::max(pos + size, 0)
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        // inserting beyond the current list length is the same as appending
        pos = std::min(pos, size);
        list_construct->insertInput(pos, node->inputs().at(2));
        break;
      }
      case aten::_set_item: {
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        list_construct->replaceInput(pos, node->input(2));
        break;
      }
      default:
        TORCH_INTERNAL_ASSERT(false);
    }

    // process use-chain and aliasing of node output
    bool has_output = (!node->outputs().empty());
    if (has_output) {
      node->output()->replaceAllUsesWith(mutated_value);
      getOrCreateAliasDb()->writeIndex_->erase(node);
    }

    node->destroy();

    // TODO: don't strictly need to reset write cache, evaluate on models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }

  return changed;
}

bool MutationRemover::RemoveTensorMutation(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveTensorMutation(sub_block);
    }

    if (mutation_filter_) {
      const auto& mutation_filter = *mutation_filter_;
      if (!mutation_filter(node)) {
        continue;
      }
    }

    // TODO: out op variants
    if (!inplaceOpVariant(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node) &&
        !tryMakeUnaliasedIfOutputAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Node* new_node;
    if (isSpecialMappedOp(node)) {
      new_node = createSpecialMappedOp(node);
    } else {
      auto schema_name = node->schema().name();
      auto new_schema = schema_name.substr(0, schema_name.size() - 1);
      new_node = graph_->create(Symbol::fromQualString(new_schema), 1);
      new_node->copyMetadata(node);
      new_node->insertBefore(node);
      for (Value* input : node->inputs()) {
        new_node->addInput(input);
      }
      new_node->output()->setType(node->output()->type());

      // handle the rare case where an in-place op and an equivalent
      // functional op share a symbol but have different schemas
      if (!new_node->maybeOperator()) {
        new_node->destroy();
        continue;
      }
    }

    changed = true;
    mutated_value->replaceAllUsesAfterNodeWith(node, new_node->output());
    node->output()->replaceAllUsesWith(new_node->output());

    // We rewrite something like:
    // x = torch.zeros()
    // x.add_(1)
    // x.add_(2)
    // to:
    // x = torch.zeros()
    // x0 = x.add(1)
    // x0.add_(2)
    // For the remainder of the function, x0 will have the
    // same aliasing relationships as the original x.
    // To avoid rebuilding the entire alias db, we can replace
    // the memory DAG element of x with x0.
    getOrCreateAliasDb()->replaceWithNewValue(
        mutated_value, new_node->output());

    // It is an invariant that all mutable types have an element in the memory
    // DAG, so we must give x a new alias db element. We have already verified
    // that the mutated value is a fresh alias with a single use.
    getOrCreateAliasDb()->createValue(mutated_value);

    // We must erase the destroyed node from the AliasDb lists of writes
    getOrCreateAliasDb()->writeIndex_->erase(node);
    node->destroy();

    // now that we have removed a mutating op, the write cache is stale
    // TODO: don't strictly need to reset write cache, evaluate on models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }

  return changed;
}

bool MutationRemover::inplaceOpVariant(Node* n) {
  if (!n->kind().is_aten()) {
    return false;
  }

  if (isSpecialMappedOp(n)) {
    return true;
  }

  auto name = n->schema().name();
  bool inplace_op = name.at(name.size() - 1) == '_';
  if (!inplace_op) {
    return false;
  }

  // needs to have alias analysis by schema
  auto op = n->maybeOperator();
  if (!op) {
    return false;
  }
  if (op->aliasAnalysisKind() != AliasAnalysisKind::FROM_SCHEMA) {
    return false;
  }

  // All in-place ops at the time of writing mutate and return their first
  // input and have a single output. Check that this holds; anything else
  // could have strange semantics.
  if (n->outputs().size() != 1 || n->inputs().empty()) {
    return false;
  }
  auto inputs = n->inputs();
  if (!getOrCreateAliasDb()->writesToAlias(n, {inputs.at(0)}) ||
      getOrCreateAliasDb()->writesToAlias(
          n, {inputs.slice(1).begin(), inputs.slice(1).end()})) {
    return false;
  }

  auto new_schema = name.substr(0, name.size() - 1);
  return !getAllOperatorsFor(Symbol::fromQualString(new_schema)).empty();
}
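
// E.g. aten::add_ qualifies: it is an aten op with FROM_SCHEMA alias
// analysis that writes only to its first input, and stripping the trailing
// '_' yields the registered functional variant aten::add.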

bool RemoveListMutation(const std::shared_ptr<Graph>& graph) {
  MutationRemover mr(graph);
  return mr.removeListMutation();
}

bool RemoveTensorMutation(
    const std::shared_ptr<Graph>& graph,
    c10::optional<std::function<bool(Node*)>> mutation_filter) {
  MutationRemover mr(graph, std::move(mutation_filter));
  return mr.removeTensorMutation();
}
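
// Usage sketch: rewrite every removable mutation, or restrict the pass with a
// node filter, e.g.
//   RemoveTensorMutation(graph);
//   RemoveTensorMutation(graph, [](Node* n) {
//     return n->kind() == Symbol::fromQualString("aten::add_");
//   });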

static const std::unordered_set<Symbol> activation_ops = []() {
  std::unordered_set<Symbol> target_ops;
  for (const auto& iter : activation_type_promotion_mapping) {
    std::string name = std::string(iter.first.toQualString()) + "_";
    target_ops.insert(Symbol::fromQualString(name));
  }
  return target_ops;
}();

bool InplaceToFunctionalActivation(const std::shared_ptr<Graph>& graph) {
  return RemoveTensorMutation(graph, [](Node* node) {
    return activation_ops.count(node->kind()) != 0;
  });
}
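
// E.g. with aten::sigmoid in activation_type_promotion_mapping, this rewrites
// x.sigmoid_() into a functional x.sigmoid() wherever the creation/mutation
// pair can be made atomic.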

} // namespace jit
} // namespace torch