#include <ATen/core/jit_type.h>
#include <ATen/core/symbol.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/restore_mutation.h>
#include <torch/csrc/jit/runtime/operator.h>
5
6namespace torch {
7namespace jit {
8
9FunctionalToInplaceRewriter::FunctionalToInplaceRewriter(
10 std::shared_ptr<Graph> graph)
11 : aliasDb_(nullptr), graph_(std::move(graph)) {}
12
13bool FunctionalToInplaceRewriter::CanBeInplace(Node* node) {
14 if (activation_type_promotion_mapping.find(node->kind()) ==
15 activation_type_promotion_mapping.end()) {
16 return false;
17 }
18
19 Symbol inplace_op =
20 Symbol::fromQualString(std::string(node->kind().toQualString()) + "_");
21 if (!inplace_op) {
22 return false;
23 }
24
25 // If type promotion is allowed, then perform dtype check
26 bool check_dtype = activation_type_promotion_mapping.at(node->kind());
27
28 Value* input = node->inputs().at(0);
29 Value* output = node->outputs().at(0);
30 auto inputDtype = input->type()->expect<TensorType>()->scalarType();
31 auto outputDtype = output->type()->expect<TensorType>()->scalarType();
32
33 // In general, we don't need to check shape for activation ops as they
34 // element-wise. But for those where type promotion could happen, we need to
35 // make sure the dtype of input and output are the same. For now the dtype
36 // checking will always fail until the type inference is ready.
37 if (check_dtype &&
38 (!inputDtype || !outputDtype ||
39 inputDtype.value() != outputDtype.value())) {
40 return false;
41 }
42
43 // Skip if input's def node has side effect or input has alias
44 if (MutationRemover::hasSideEffectOrAlias(input, getOrCreateAliasDb())) {
45 return false;
46 }
47
48 // If x has more than one use, skip the converson.
49 // TODO: Use liveness analysis to catch more general scenario
50 return (input->uses().size() == 1);
51}
52
53bool FunctionalToInplaceRewriter::FunctionalToInplace(Block* block) {
54 bool changed = false;
55 for (auto it = block->nodes().begin(); it != block->nodes().end();) {
56 auto* node = *it;
57 it++;
58
59 for (Block* sub_block : node->blocks()) {
60 changed |= FunctionalToInplace(sub_block);
61 }
62
63 if (!CanBeInplace(node)) {
64 continue;
65 }
66
67 changed = true;
68 Node* inplace_node = node->replaceWithNewSymbol(
69 Symbol::fromQualString(node->schema().name() + "_"));
70 inplace_node->output()->replaceAllUsesWith(node->inputs().at(0));
71 getOrCreateAliasDb()->replaceWithNewValue(
72 node->output(), inplace_node->output());
73
74 node->destroy();
75 }
76 return changed;
77}
78
79bool FunctionalToInplaceActivation(const std::shared_ptr<Graph>& graph) {
80 FunctionalToInplaceRewriter rewriter(graph);
81 return rewriter.FunctionalToInplace(graph->block());
82}
83
84} // namespace jit
85} // namespace torch
86