1 | #pragma once |
2 | |
3 | #include <torch/csrc/jit/ir/ir.h> |
4 | |
5 | #include <unordered_map> |
6 | #include <vector> |
7 | |
8 | namespace torch { |
9 | namespace jit { |
10 | |
11 | /** |
12 | * \brief A structure describing a match of a pattern in a graph. |
13 | * |
14 | * The structure contains an anchor node, from which the match was found, and |
15 | * match-maps for nodes and values. A match-map specifies the correspondence |
16 | * between nodes in the pattern graph (match-map keys) and nodes in the actual |
17 | * graph (match-map values). We keep such maps for both nodes and values. |
18 | */ |
19 | struct Match { |
20 | Node* anchor = nullptr; ///< Node the match was found from; null until set. |
21 | std::unordered_map<const Node*, Node*> nodes_map; ///< Pattern node -> graph node. |
22 | std::unordered_map<const Value*, Value*> values_map; ///< Pattern value -> graph value. |
23 | }; |
24 | |
25 | /** |
26 | * \brief Find all matches of a \p PATTERN in a \p GRAPH. |
27 | * |
28 | * The function returns a vector of match-descriptors (see description of |
29 | * `struct Match`). |
30 | * |
31 | * Matching rules: |
32 | * - Pattern graph must contain a single block. |
33 | * - Matched subgraphs do not span across different blocks. |
34 | * - No uses outside the match are allowed, except for Param and Return nodes. |
35 | * Basically, we're matching hammocks, not arbitrary subgraphs. |
36 | * - The pattern graph must return only one value (i.e. it must have a single |
37 | * node leading to return). |
38 | * - Nodes that are not used in computation of the return value in the pattern |
39 | * graph are ignored during matching (IOW, we're essentially performing DCE on |
40 | * the pattern). |
41 | * - Pattern graph nodes cannot alias. TODO: the check not implemented yet. |
42 | * - Aliasing nodes in the graph cannot constitute a match (i.e. through all |
43 | * found matches, no nodes in the subgraph alias with each other). TODO: check |
44 | * not implemented yet. |
45 | * - The matcher will not mutate either the pattern graph or the matched graph. |
46 | * The matched graph is taken as non-const so that Match may contain non-const |
47 | * pointers. This enables clients of this API to use Match to drive mutations. |
48 | * |
49 | * Note [Multi-output Patterns] |
50 | * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
51 | * Subgraph matcher provides limited support for multi-output patterns. With a |
52 | * single output pattern, a single scan through the graph is sufficient to |
53 | * find all the matches: given a starting node (an "anchor"), we can |
54 | * deterministically check whether a pattern matches a subgraph corresponding to |
55 | * this anchor node. For a general case of multi-output patterns, we would have |
56 | * N anchors, which would result in M^N comparisons (M is the size of the |
57 | * graph). Clearly this is computationally prohibitive. |
58 | * |
59 | * To overcome this, we impose some constraints on the multi-output patterns |
60 | * that we accept. We require that checking whether the pattern matches a |
61 | * subgraph would still be fully determined by a single node in the graph. To |
62 | * achieve this, we designate the first output in the pattern as the "main" |
63 | * output and assume that we can traverse up from this node to match the |
64 | * entire pattern. |
65 | * |
66 | * Corollary 1: the order of outputs in the pattern matters! |
67 | * Corollary 2: patterns cannot contain any nodes not participating in the main |
68 | * output computation. |
69 | */ |
70 | TORCH_API std::vector<Match> findPatternMatches( |
71 | const Graph& pattern, Graph& graph); |
72 | |
73 | } // namespace jit |
74 | } // namespace torch |
75 | |