1 | #include <torch/csrc/jit/passes/remove_mutation.h> |
2 | #include <torch/csrc/jit/passes/restore_mutation.h> |
3 | |
4 | namespace torch { |
5 | namespace jit { |
6 | |
7 | bool MutationRemover::removeListMutation() { |
8 | return RemoveListMutation(graph_->block()); |
9 | } |
10 | |
11 | bool MutationRemover::removeTensorMutation() { |
12 | return RemoveTensorMutation(graph_->block()); |
13 | } |
14 | |
15 | bool MutationRemover::hasSideEffectOrAlias(Value* v, AliasDb* aliasDb) { |
16 | // bail on nodes with side effects, blocks, or graph / graph inputs |
17 | Node* n = v->node(); |
18 | bool unhandled_node = !n->blocks().empty() || |
19 | n->hasAttribute(attr::Subgraph) || n->hasSideEffects() || |
20 | (v->node()->kind() == prim::Param); |
21 | |
22 | // if the output isn't contained or alias by the inputs to its node, it's |
23 | // unique. No need to check for alias if the node is a ListConstruct. |
24 | bool mayAliasInputs = (v->node()->kind() != prim::ListConstruct) && |
25 | aliasDb->mayContainAlias(v->node()->inputs(), v); |
26 | return unhandled_node || mayAliasInputs || (v->node()->kind() == prim::Param); |
27 | } |
28 | |
29 | Node* MutationRemover::createSpecialMappedOp(Node* n) { |
30 | WithInsertPoint guard(n); |
31 | auto inputs = n->inputs(); |
32 | // NOLINTNEXTLINE(cppcoreguidelines-init-variables) |
33 | Node* new_node; |
34 | if (n->matches( |
35 | "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)" )) { |
36 | auto dtype = graph_->insert(prim::dtype, {inputs.at(0)}); |
37 | new_node = graph_ |
38 | ->insert( |
39 | aten::full_like, |
40 | {inputs.at(0), inputs.at(1)}, |
41 | {NamedValue("dtype" , dtype)}) |
42 | ->node(); |
43 | new_node->copyMetadata(n); |
44 | new_node->output()->setType(n->output()->type()); |
45 | } else if (n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)" )) { |
46 | new_node = graph_->insert(aten::zeros_like, {n->inputs().at(0)})->node(); |
47 | } else if ( |
48 | n->matches( |
49 | "aten::normal_(Tensor(a!) self, float mean=0, float std=1, *, Generator? generator=None) -> Tensor(a!)" )) { |
50 | // TODO: we should have normal_like operator |
51 | // normal(float mean, float std, int[] size, *, Generator? generator=None, |
52 | // ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? |
53 | // pin_memory=None) -> Tensor |
54 | auto size = graph_->insert(aten::size, {n->inputs().at(0)}); |
55 | auto dtype = graph_->insert(prim::dtype, {n->inputs().at(0)}); |
56 | auto layout = graph_->insert(prim::layout, {n->inputs().at(0)}); |
57 | auto device = graph_->insert(prim::device, {n->inputs().at(0)}); |
58 | auto pin_memory = graph_->insert(aten::is_pinned, {n->inputs().at(0)}); |
59 | auto generator = graph_->insertConstant(IValue()); |
60 | new_node = graph_->insertNode(graph_->create( |
61 | aten::normal, |
62 | {n->inputs().at(1), |
63 | n->inputs().at(2), |
64 | size, |
65 | generator, |
66 | dtype, |
67 | layout, |
68 | device, |
69 | pin_memory})); |
70 | } else { |
71 | TORCH_INTERNAL_ASSERT(false); |
72 | } |
73 | new_node->copyMetadata(n); |
74 | new_node->output()->setType(n->output()->type()); |
75 | return new_node; |
76 | } |
77 | |
78 | bool removableSetItem(Node* n) { |
79 | if (n->kind() != aten::_set_item || |
80 | n->input(1)->node()->kind() != prim::Constant) { |
81 | return false; |
82 | } |
83 | if (n->inputs().at(0)->node()->kind() != prim::ListConstruct) { |
84 | return false; |
85 | } |
86 | auto li_node = n->inputs().at(0)->node(); |
87 | int64_t index = *constant_as<int64_t>(n->input(1)); |
88 | if (index < 0) { |
89 | index += li_node->inputs().size(); |
90 | } |
91 | auto li_len = static_cast<int64_t>(li_node->inputs().size()); |
92 | return index < li_len && index >= 0; |
93 | } |
94 | |
95 | bool MutationRemover::listMutationFollowingListConstruct(Node* n) { |
96 | return ( |
97 | (n->kind() == aten::append || |
98 | (n->kind() == aten::insert && |
99 | n->inputs().at(1)->node()->kind() == prim::Constant) || |
100 | (removableSetItem(n))) && |
101 | n->inputs().at(0)->node()->kind() == prim::ListConstruct); |
102 | } |
103 | |
104 | bool MutationRemover::tryMakeCreationAndMutationAtomic( |
105 | Value* mutated_value, |
106 | Node* mutating_op) { |
107 | // We can only remove mutation to values that are unique aliases in the |
108 | // graph. if x = y[0] or y = self.y, then removing the mutation could |
109 | // change observable semantics |
110 | if (hasSideEffectOrAlias(mutated_value, getOrCreateAliasDb())) { |
111 | return false; |
112 | } |
113 | |
114 | // In order to safely remove a mutation, the creation of a tensor and its |
115 | // subsequent mutation need to be one atomic operation |
116 | return getOrCreateAliasDb()->moveBeforeTopologicallyValid( |
117 | mutated_value->node(), mutating_op); |
118 | } |
119 | |
120 | bool MutationRemover::tryMakeUnaliasedIfOutputAndMutationAtomic( |
121 | Value* mutated_value, |
122 | Node* mutating_op) { |
123 | // if cond: |
124 | // x = op() |
125 | // else: |
126 | // x = op() |
127 | // x = add_(1) |
128 | // if x in both blocks have no other uses and are unaliased in the graph, |
129 | // and we make the if node and the mutation atomic, |
130 | // then removing mutation add_ does not change observable semantics |
131 | |
132 | if (mutated_value->node()->kind() != prim::If) { |
133 | return false; |
134 | } |
135 | |
136 | auto if_node = mutated_value->node(); |
137 | auto offset = mutated_value->offset(); |
138 | auto true_value = if_node->blocks().at(0)->outputs().at(offset); |
139 | auto false_value = if_node->blocks().at(1)->outputs().at(offset); |
140 | |
141 | if (true_value->uses().size() > 1 || false_value->uses().size() > 1) { |
142 | return false; |
143 | } |
144 | |
145 | if (hasSideEffectOrAlias(true_value, getOrCreateAliasDb()) || |
146 | hasSideEffectOrAlias(false_value, getOrCreateAliasDb())) { |
147 | return false; |
148 | } |
149 | |
150 | return getOrCreateAliasDb()->moveBeforeTopologicallyValid( |
151 | if_node, mutating_op); |
152 | } |
153 | |
bool MutationRemover::RemoveListMutation(Block* block) {
  // Folds list mutations that directly follow a prim::ListConstruct
  // (append / insert / _set_item) into the construction itself, destroying
  // the mutating node. Recurses into sub-blocks. Returns true if any
  // change was made.
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    // Advance before any rewrite so the iterator survives node->destroy().
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveListMutation(sub_block);
    }

    if (!listMutationFollowingListConstruct(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    changed = true;

    // We rewrite something like:
    // x = {v0}
    // x.append(v1) (or x.insert(0, v1))
    // to:
    // x = {v0, v1} (or x = {v1, v0})
    // We can remove x.append from the alias db list of writes.
    // All other aliasing properties remain valid.
    Node* list_construct = mutated_value->node();
    switch (node->kind()) {
      case aten::append:
        list_construct->addInput(node->inputs().at(1));
        break;
      case aten::insert: {
        // Index is a prim::Constant (guaranteed by the filter above).
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        // insert to neg position equals insert to std::max(pos+size, 0)
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        // insert beyond current list length is the same as append
        pos = std::min(pos, size);
        list_construct->insertInput(pos, node->inputs().at(2));
        break;
      }
      case aten::_set_item: {
        // Bounds were already validated by removableSetItem.
        int pos = toIValue(node->inputs().at(1))->toInt();
        int size = list_construct->inputs().size();
        if (pos < 0) {
          pos = std::max(pos + size, 0);
        }
        list_construct->replaceInput(pos, node->input(2));
        break;
      }
      default:
        TORCH_INTERNAL_ASSERT(false);
    }

    // Rewire uses of the mutation's output back to the list itself, and
    // drop the destroyed writer from the alias db's write index.
    bool has_output = (!node->outputs().empty());
    if (has_output) {
      node->output()->replaceAllUsesWith(mutated_value);
      getOrCreateAliasDb()->writeIndex_->erase(node);
    }

    node->destroy();

    // TODO: don't strictly need to reset write cache, evaluate on models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }

  return changed;
}
227 | |
bool MutationRemover::RemoveTensorMutation(Block* block) {
  // Replaces removable in-place tensor ops (e.g. add_) with their
  // functional equivalents (add), rewiring downstream uses and keeping the
  // alias db consistent. Recurses into sub-blocks. Returns true if any
  // change was made.
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    // Advance before any rewrite so the iterator survives node->destroy().
    it++;

    for (Block* sub_block : node->blocks()) {
      changed |= RemoveTensorMutation(sub_block);
    }

    // An optional user-supplied predicate restricts which mutating nodes
    // are eligible (used e.g. to target only activation ops).
    if (mutation_filter_) {
      const auto& mutation_filter = *mutation_filter_;
      if (!mutation_filter(node)) {
        continue;
      }
    }

    // TODO: out op variants
    if (!inplaceOpVariant(node)) {
      continue;
    }

    Value* mutated_value = node->inputs().at(0);
    if (!tryMakeCreationAndMutationAtomic(mutated_value, node) &&
        !tryMakeUnaliasedIfOutputAndMutationAtomic(mutated_value, node)) {
      continue;
    }

    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Node* new_node;
    if (isSpecialMappedOp(node)) {
      new_node = createSpecialMappedOp(node);
    } else {
      // Functional variant: same qualified name minus trailing underscore.
      auto schema_name = node->schema().name();
      auto new_schema = schema_name.substr(0, schema_name.size() - 1);
      new_node = graph_->create(Symbol::fromQualString(new_schema), 1);
      new_node->copyMetadata(node);
      new_node->insertBefore(node);
      for (Value* input : node->inputs()) {
        new_node->addInput(input);
      }
      new_node->output()->setType(node->output()->type());

      // weird case where there is an inplace op and an equivalent functional op
      // of the same symbol, but they have different schemas
      if (!new_node->maybeOperator()) {
        new_node->destroy();
        continue;
      }
    }

    changed = true;
    mutated_value->replaceAllUsesAfterNodeWith(node, new_node->output());
    node->output()->replaceAllUsesWith(new_node->output());

    // We rewrite something like:
    // x = torch.zeros()
    // x.add_(1)
    // x.add_(2)
    // to:
    // x = torch.zeros()
    // x0 = x.add(1)
    // x0.add_(2)
    // For the remainder of the function, x0 will have the
    // same aliasing relationships as the original x.
    // To avoid rebuilding the entire alias db, we can replace
    // the memory DAG element of x with x0.
    getOrCreateAliasDb()->replaceWithNewValue(
        mutated_value, new_node->output());

    // it is an invariant that all mutable types have an element in the memory
    // DAG so we must regive x an alias db element. We have already verified
    // that the mutated value is a fresh alias with a single use.
    getOrCreateAliasDb()->createValue(mutated_value);

    // We must erase the destroyed node from the AliasDb lists of writes
    getOrCreateAliasDb()->writeIndex_->erase(node);
    node->destroy();

    // now that we have removed a mutating op, the write cache is stale
    // TODO: don't strictly need to reset write cache, evaluate on models
    getOrCreateAliasDb()->buildWrittenToLocationsIndex();
  }

  return changed;
}
314 | |
315 | bool MutationRemover::inplaceOpVariant(Node* n) { |
316 | if (!n->kind().is_aten()) { |
317 | return false; |
318 | } |
319 | |
320 | if (isSpecialMappedOp(n)) { |
321 | return true; |
322 | } |
323 | |
324 | auto name = n->schema().name(); |
325 | bool inplace_op = name.at(name.size() - 1) == '_'; |
326 | if (!inplace_op) { |
327 | return false; |
328 | } |
329 | |
330 | // needs to have alias analysis by schema |
331 | auto op = n->maybeOperator(); |
332 | if (!op) { |
333 | return false; |
334 | } |
335 | if (op->aliasAnalysisKind() != AliasAnalysisKind::FROM_SCHEMA) { |
336 | return false; |
337 | } |
338 | |
339 | // all inplace ops at time of writing have a single input that is mutated |
340 | // and returned. check that this is true, anything else could have strange |
341 | // semantics, |
342 | if (n->outputs().size() != 1 || n->inputs().empty()) { |
343 | return false; |
344 | } |
345 | auto inputs = n->inputs(); |
346 | if (!getOrCreateAliasDb()->writesToAlias(n, {inputs.at(0)}) || |
347 | getOrCreateAliasDb()->writesToAlias( |
348 | n, {inputs.slice(1).begin(), inputs.slice(1).end()})) { |
349 | return false; |
350 | } |
351 | |
352 | auto new_schema = name.substr(0, name.size() - 1); |
353 | return !getAllOperatorsFor(Symbol::fromQualString(new_schema)).empty(); |
354 | } |
355 | |
356 | bool RemoveListMutation(const std::shared_ptr<Graph>& graph) { |
357 | MutationRemover mr(graph); |
358 | return mr.removeListMutation(); |
359 | } |
360 | |
361 | bool RemoveTensorMutation( |
362 | const std::shared_ptr<Graph>& graph, |
363 | c10::optional<std::function<bool(Node*)>> mutation_filter) { |
364 | MutationRemover mr(graph, std::move(mutation_filter)); |
365 | return mr.removeTensorMutation(); |
366 | } |
367 | |
368 | static const std::unordered_set<Symbol> activation_ops = []() { |
369 | std::unordered_set<Symbol> target_ops; |
370 | for (const auto& iter : activation_type_promotion_mapping) { |
371 | std::string name = std::string(iter.first.toQualString()) + "_" ; |
372 | target_ops.insert(Symbol::fromQualString(name)); |
373 | } |
374 | return target_ops; |
375 | }(); |
376 | |
377 | bool InplaceToFunctionalActivation(const std::shared_ptr<Graph>& graph) { |
378 | return RemoveTensorMutation(graph, [](Node* node) { |
379 | return activation_ops.count(node->kind()) != 0; |
380 | }); |
381 | } |
382 | |
383 | } // namespace jit |
384 | } // namespace torch |
385 | |