1#pragma once
2
3#include <torch/csrc/jit/ir/ir.h>
4
5#include <unordered_map>
6#include <vector>
7
8namespace torch {
9namespace jit {
10
11/**
12 * \brief A structure describing a match of a pattern in a graph.
13 *
 * The structure contains an anchor node, from which the match was found, and
 * match-maps for nodes and values. A match-map specifies the correspondence
 * between nodes in the pattern graph (match-map keys) and nodes in the actual
 * graph (match-map values). We keep such maps for both nodes and values.
18 */
19struct Match {
20 Node* anchor;
21 std::unordered_map<const Node*, Node*> nodes_map;
22 std::unordered_map<const Value*, Value*> values_map;
23};
24
25/**
26 * \brief Find all matches of a \p PATTERN in a \p GRAPH.
27 *
28 * The function returns a vector of match-descriptors (see description of
29 * `struct Match`).
30 *
31 * Matching rules:
32 * - Pattern graph must contain a single block.
33 * - Matched subgraphs do not span across different blocks.
34 * - No uses outside the match are allowed, except for Param and Return nodes.
35 * Basically, we're matching hammocks, not arbitrary subgraphs.
 * - The pattern graph must return only one value (i.e. it must have a single
 *   node leading to return); see Note [Multi-output Patterns] below for the
 *   limited support that relaxes this rule.
 * - Nodes that are not used in the computation of the return value in the
 *   pattern graph are ignored during matching (in other words, we're
 *   essentially performing DCE on the pattern).
 * - Pattern graph nodes cannot alias. TODO: this check is not implemented yet.
 * - Aliasing nodes in the graph cannot constitute a match (i.e. across all
 *   found matches, no nodes in the subgraph alias with each other). TODO: this
 *   check is not implemented yet.
45 * - The matcher will not mutate either the pattern graph or the matched graph.
46 * The matched graph is taken as non-const so that Match may contain non-const
47 * pointers. This enables clients of this API to use Match to drive mutations.
48 *
49 * Note [Multi-output Patterns]
50 * ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 * Subgraph matcher provides limited support for multi-output patterns. With a
 * single-output pattern, a single scan through the graph is sufficient to
 * find all the matches: given a starting node (an "anchor"), we can
 * deterministically check whether a pattern matches a subgraph corresponding to
 * this anchor node. In the general case of a pattern with N outputs, we would
 * need N anchors, which would result in M^N comparisons (M is the size of the
 * graph). Clearly this is computationally prohibitive.
58 *
59 * To overcome this, we impose some constraints on the multi-output patterns
60 * that we accept. We require that checking whether the pattern matches a
61 * subgraph would still be fully determined by a single node in the graph. To
62 * achieve this, we designate the first output in the pattern as the "main"
63 * output and assume that we can traverse up from this node to match the
64 * entire pattern.
65 *
 * Corollary 1: the order of outputs in the pattern matters!
67 * Corollary 2: patterns cannot contain any nodes not participating in the main
68 * output computation.
69 */
70std::vector<Match> TORCH_API
71findPatternMatches(const Graph& pattern, Graph& graph);
72
73} // namespace jit
74} // namespace torch
75