/** This file defines an API for pattern-based subgraph rewrites.
 *
 * The API can be used to find concrete patterns in a model and to replace
 * the corresponding subgraphs with other subgraphs. A special case of such
 * rewrites is fusion, where the new subgraph consists of just a single node.
 *
 * A default set of the most common patterns is provided for general use.
 * Alternatively, an arbitrary pattern can be registered.
 */
10 | #pragma once |
11 | |
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
18 | |
19 | namespace torch { |
20 | namespace jit { |
21 | |
22 | // Forward declarations. |
23 | struct RewritePatternDescr; |
24 | struct Match; |
25 | |
26 | using MatchFilter = std::function< |
27 | bool(const Match&, const std::unordered_map<std::string, Value*>&)>; |
28 | |
29 | /** Run pattern-based subgraph rewrites on all methods in the module. |
30 | * |
31 | * This pass will go through all methods in the module and try to replace all |
32 | * recognized patterns (see SubgraphRewriter::RegisterDefaultPatterns for the |
33 | * list of these patterns). |
34 | */ |
35 | TORCH_API Module PatternBasedRewrite(const Module& module); |
36 | |
37 | /** A class implementing API for pattern-based subgraph rewrites. |
38 | * |
39 | * To perform pattern-based subgraph rewrites on a module using this API, one |
40 | * needs to create an object of such class, register rewrite patterns and run |
41 | * the transformation pass (`runOnModule`). |
42 | * |
43 | * To use standard patterns, one could use `RegisterDefaultPatterns`. |
44 | * |
45 | * To enable rewrites of custom patterns, the custom patterns must be registered |
46 | * with `RegisterRewritePattern`. |
47 | */ |
48 | class TORCH_API SubgraphRewriter { |
49 | public: |
50 | // Run pattern-based subgraph rewrite pass on the module. |
51 | Module runOnModule(const Module& module); |
52 | |
53 | // Run pattern-based subgraph rewrite pass on the graph (used in testing). |
54 | // `filter` is a function that does extra filtering on the match. If it |
55 | // returns false for a given Match, we'll skip the Match. The filter |
56 | // function's arguments consist of a Match and a value map from parsing the |
57 | // pattern graph. Both the Match and the value map are necessary because we |
58 | // need to 1) do extra filtering on the matched result as well as 2) refer to |
59 | // the values in the matched result through the values in the pattern graph. |
60 | void runOnGraph( |
61 | std::shared_ptr<Graph>& graph, |
62 | const std::vector<MatchFilter>& filters); |
63 | |
64 | void runOnGraph( |
65 | std::shared_ptr<Graph>& graph, |
66 | const MatchFilter& filter = |
67 | [](const Match&, const std::unordered_map<std::string, Value*>&) { |
68 | return true; |
69 | }) { |
70 | runOnGraph(graph, std::vector<MatchFilter>({filter})); |
71 | } |
72 | |
73 | // Register standard rewrite patterns. |
74 | void RegisterDefaultPatterns(); |
75 | |
76 | /** Register a custom rewrite pattern. |
77 | * |
78 | * The method takes two parameters specifying the pattern: |
79 | * \p PATTERN - IR string representing the pattern subgraph. |
80 | * \p REPLACEMENT - IR string representing the replacement subgraph. |
81 | * \p value name map - vector of pairs mapping values in the replacement graph |
82 | * to the values in the pattern graph. Used for preserving source range info |
83 | * across graph rewrite. |
84 | * |
85 | * See examples of pattern registering in `RegisterDefaultPatterns`. |
86 | */ |
87 | void RegisterRewritePattern( |
88 | const std::string& pattern, |
89 | const std::string& replacement, |
90 | const std::vector<std::pair<std::string, std::string>>& value_name_pair = |
91 | {}); |
92 | |
93 | private: |
94 | std::vector<RewritePatternDescr> patterns_; |
95 | std::unordered_set<Node*> nodes_to_delete_; |
96 | |
97 | void rewriteSinglePatternOnGraph( |
98 | std::shared_ptr<Graph>& graph, |
99 | const RewritePatternDescr& pattern, |
100 | const std::vector<MatchFilter>& filters); |
101 | |
102 | bool overlapsWithPreviousMatches(const Match* match); |
103 | }; |
104 | |
/** Descriptor of a single registered rewrite pattern.
 *
 * Internal to the `SubgraphRewriter` implementation (which keeps these in its
 * `patterns_` vector); not intended for external use.
 */
struct RewritePatternDescr {
  // Value-name correspondence between the replacement and pattern graphs.
  // NOTE(review): presumably built from the `value_name_pair` argument of
  // `SubgraphRewriter::RegisterRewritePattern`; confirm key/value direction.
  using ValueNameMap = std::unordered_map<std::string, std::string>;

  std::string pattern;         // IR string of the subgraph to look for.
  std::string replacement;     // IR string of the subgraph to substitute.
  ValueNameMap value_name_map; // See ValueNameMap above.
};
115 | |
116 | } // namespace jit |
117 | } // namespace torch |
118 | |