1 | #include <ATen/core/jit_type.h> |
---|---|
2 | #include <ATen/core/symbol.h> |
3 | #include <torch/csrc/jit/passes/remove_mutation.h> |
4 | #include <torch/csrc/jit/passes/restore_mutation.h> |
5 | |
6 | namespace torch { |
7 | namespace jit { |
8 | |
// Constructs a rewriter over `graph`. The alias database starts out null and
// is built on demand via getOrCreateAliasDb(), so construction is cheap when
// no rewrite ends up happening.
FunctionalToInplaceRewriter::FunctionalToInplaceRewriter(
    std::shared_ptr<Graph> graph)
    : aliasDb_(nullptr), graph_(std::move(graph)) {}
12 | |
13 | bool FunctionalToInplaceRewriter::CanBeInplace(Node* node) { |
14 | if (activation_type_promotion_mapping.find(node->kind()) == |
15 | activation_type_promotion_mapping.end()) { |
16 | return false; |
17 | } |
18 | |
19 | Symbol inplace_op = |
20 | Symbol::fromQualString(std::string(node->kind().toQualString()) + "_"); |
21 | if (!inplace_op) { |
22 | return false; |
23 | } |
24 | |
25 | // If type promotion is allowed, then perform dtype check |
26 | bool check_dtype = activation_type_promotion_mapping.at(node->kind()); |
27 | |
28 | Value* input = node->inputs().at(0); |
29 | Value* output = node->outputs().at(0); |
30 | auto inputDtype = input->type()->expect<TensorType>()->scalarType(); |
31 | auto outputDtype = output->type()->expect<TensorType>()->scalarType(); |
32 | |
33 | // In general, we don't need to check shape for activation ops as they |
34 | // element-wise. But for those where type promotion could happen, we need to |
35 | // make sure the dtype of input and output are the same. For now the dtype |
36 | // checking will always fail until the type inference is ready. |
37 | if (check_dtype && |
38 | (!inputDtype || !outputDtype || |
39 | inputDtype.value() != outputDtype.value())) { |
40 | return false; |
41 | } |
42 | |
43 | // Skip if input's def node has side effect or input has alias |
44 | if (MutationRemover::hasSideEffectOrAlias(input, getOrCreateAliasDb())) { |
45 | return false; |
46 | } |
47 | |
48 | // If x has more than one use, skip the converson. |
49 | // TODO: Use liveness analysis to catch more general scenario |
50 | return (input->uses().size() == 1); |
51 | } |
52 | |
// Walks `block`, recursing into sub-blocks first, and rewrites every node
// that CanBeInplace approves into its in-place variant. Returns true iff at
// least one node (here or in a nested block) was rewritten.
bool FunctionalToInplaceRewriter::FunctionalToInplace(Block* block) {
  bool changed = false;
  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
    auto* node = *it;
    // Advance the iterator before any mutation so that destroying `node`
    // below cannot invalidate it.
    it++;

    // Handle nested blocks (e.g. if/loop bodies) before the node itself.
    for (Block* sub_block : node->blocks()) {
      changed |= FunctionalToInplace(sub_block);
    }

    if (!CanBeInplace(node)) {
      continue;
    }

    changed = true;
    // Re-emit the node under the in-place symbol (schema name + "_").
    Node* inplace_node = node->replaceWithNewSymbol(
        Symbol::fromQualString(node->schema().name() + "_"));
    // The in-place op writes into its first input, so consumers of the new
    // output are redirected to read that input value directly.
    inplace_node->output()->replaceAllUsesWith(node->inputs().at(0));
    // Keep the alias database consistent before the old node goes away.
    getOrCreateAliasDb()->replaceWithNewValue(
        node->output(), inplace_node->output());

    node->destroy();
  }
  return changed;
}
78 | |
79 | bool FunctionalToInplaceActivation(const std::shared_ptr<Graph>& graph) { |
80 | FunctionalToInplaceRewriter rewriter(graph); |
81 | return rewriter.FunctionalToInplace(graph->block()); |
82 | } |
83 | |
84 | } // namespace jit |
85 | } // namespace torch |
86 |