1 | #pragma once |
2 | |
#include <memory>
#include <unordered_map>

#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <torch/csrc/Export.h>
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/utils/memory.h>
9 | |
10 | namespace torch { |
11 | namespace jit { |
12 | |
13 | // A map which stores if an activation operator can perform type promotion |
14 | const std::unordered_map<Symbol, bool> activation_type_promotion_mapping = { |
15 | {aten::sigmoid, true}, |
16 | {aten::tanh, true}, |
17 | {aten::celu, false}, |
18 | {aten::elu, false}, |
19 | {aten::gelu, false}, |
20 | {aten::glu, false}, |
21 | {aten::hardshrink, false}, |
22 | {aten::hardsigmoid, false}, |
23 | {aten::hardswish, false}, |
24 | {aten::hardtanh, false}, |
25 | {aten::leaky_relu, false}, |
26 | {aten::prelu, false}, |
27 | {aten::relu6, false}, |
28 | {aten::relu, false}, |
29 | {aten::rrelu, false}, |
30 | {aten::selu, false}, |
31 | {aten::silu, false}}; |
32 | |
33 | class FunctionalToInplaceRewriter { |
34 | public: |
35 | FunctionalToInplaceRewriter(std::shared_ptr<Graph> graph); |
36 | |
37 | bool FunctionalToInplace(Block* block); |
38 | |
39 | private: |
40 | AliasDb* getOrCreateAliasDb() { |
41 | if (!aliasDb_) { |
42 | aliasDb_ = std::make_unique<AliasDb>(graph_); |
43 | } |
44 | return aliasDb_.get(); |
45 | } |
46 | |
47 | bool CanBeInplace(Node* node); |
48 | |
49 | std::unique_ptr<AliasDb> aliasDb_ = nullptr; |
50 | std::shared_ptr<Graph> graph_; |
51 | }; |
52 | |
// A common usage pattern is to run InplaceToFunctionalActivation before
// certain JIT optimization passes, so that those passes are not constrained by
// in-place ops. Once those passes have run, call FunctionalToInplaceActivation
// to restore the in-place activation ops and keep the memory savings they
// provide.
58 | |
59 | // Replaces functional aten activation ops with their in-place equivalents |
60 | TORCH_API bool FunctionalToInplaceActivation( |
61 | const std::shared_ptr<Graph>& graph); |
62 | |
63 | } // namespace jit |
64 | } // namespace torch |
65 | |