1/** This file defines API for pattern-based subgraph rewrites.
2 *
3 * The API can be used for finding concrete patterns in the model and replacing
4 * the corresponding subgraphs with another subgraph. A special case of such
5 * rewrites is fusion, where the new subgraph consists of just a single node.
6 *
 * A default set of the most common patterns is provided for general use.
 * Alternatively, arbitrary custom patterns can be registered.
9 */
10#pragma once
11
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/ir/ir.h>

#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
18
19namespace torch {
20namespace jit {
21
22// Forward declarations.
23struct RewritePatternDescr;
24struct Match;
25
26using MatchFilter = std::function<
27 bool(const Match&, const std::unordered_map<std::string, Value*>&)>;
28
29/** Run pattern-based subgraph rewrites on all methods in the module.
30 *
31 * This pass will go through all methods in the module and try to replace all
32 * recognized patterns (see SubgraphRewriter::RegisterDefaultPatterns for the
33 * list of these patterns).
34 */
35TORCH_API Module PatternBasedRewrite(const Module& module);
36
37/** A class implementing API for pattern-based subgraph rewrites.
38 *
39 * To perform pattern-based subgraph rewrites on a module using this API, one
40 * needs to create an object of such class, register rewrite patterns and run
41 * the transformation pass (`runOnModule`).
42 *
43 * To use standard patterns, one could use `RegisterDefaultPatterns`.
44 *
45 * To enable rewrites of custom patterns, the custom patterns must be registered
46 * with `RegisterRewritePattern`.
47 */
48class TORCH_API SubgraphRewriter {
49 public:
50 // Run pattern-based subgraph rewrite pass on the module.
51 Module runOnModule(const Module& module);
52
53 // Run pattern-based subgraph rewrite pass on the graph (used in testing).
54 // `filter` is a function that does extra filtering on the match. If it
55 // returns false for a given Match, we'll skip the Match. The filter
56 // function's arguments consist of a Match and a value map from parsing the
57 // pattern graph. Both the Match and the value map are necessary because we
58 // need to 1) do extra filtering on the matched result as well as 2) refer to
59 // the values in the matched result through the values in the pattern graph.
60 void runOnGraph(
61 std::shared_ptr<Graph>& graph,
62 const std::vector<MatchFilter>& filters);
63
64 void runOnGraph(
65 std::shared_ptr<Graph>& graph,
66 const MatchFilter& filter =
67 [](const Match&, const std::unordered_map<std::string, Value*>&) {
68 return true;
69 }) {
70 runOnGraph(graph, std::vector<MatchFilter>({filter}));
71 }
72
73 // Register standard rewrite patterns.
74 void RegisterDefaultPatterns();
75
76 /** Register a custom rewrite pattern.
77 *
78 * The method takes two parameters specifying the pattern:
79 * \p PATTERN - IR string representing the pattern subgraph.
80 * \p REPLACEMENT - IR string representing the replacement subgraph.
81 * \p value name map - vector of pairs mapping values in the replacement graph
82 * to the values in the pattern graph. Used for preserving source range info
83 * across graph rewrite.
84 *
85 * See examples of pattern registering in `RegisterDefaultPatterns`.
86 */
87 void RegisterRewritePattern(
88 const std::string& pattern,
89 const std::string& replacement,
90 const std::vector<std::pair<std::string, std::string>>& value_name_pair =
91 {});
92
93 private:
94 std::vector<RewritePatternDescr> patterns_;
95 std::unordered_set<Node*> nodes_to_delete_;
96
97 void rewriteSinglePatternOnGraph(
98 std::shared_ptr<Graph>& graph,
99 const RewritePatternDescr& pattern,
100 const std::vector<MatchFilter>& filters);
101
102 bool overlapsWithPreviousMatches(const Match* match);
103};
104
/** Rewrite pattern descriptor.
 *
 * This structure is used in the implementation of `SubgraphRewriter` and
 * is not supposed to be used externally.
 *
 * NOTE: this is an aggregate; member order is significant for brace
 * initialization, so do not reorder the fields.
 */
struct RewritePatternDescr {
  // IR string of the subgraph to search for (presumably the `pattern`
  // argument of SubgraphRewriter::RegisterRewritePattern).
  std::string pattern;
  // IR string of the subgraph to substitute (presumably the `replacement`
  // argument of SubgraphRewriter::RegisterRewritePattern).
  std::string replacement;
  // NOTE(review): presumably built from the `value_name_pair` argument of
  // RegisterRewritePattern (replacement-graph value name -> pattern-graph
  // value name) and used to preserve source range info; confirm in the .cpp.
  std::unordered_map<std::string, std::string> value_name_map;
};
115
116} // namespace jit
117} // namespace torch
118