1#pragma once
2
3#include <torch/csrc/jit/ir/ir.h>
4#include <torch/csrc/jit/ir/irparser.h>
5#include <torch/csrc/jit/ir/subgraph_matcher.h>
6#include <torch/csrc/jit/passes/subgraph_rewrite.h>
7
8namespace torch {
9namespace jit {
10namespace graph_rewrite_helper {
11
12std::string getFuncName(Value* func_value);
13Value* getValue(
14 const std::string& name,
15 const std::unordered_map<const Value*, Value*>& match_vmap,
16 const std::unordered_map<std::string, Value*>& vmap);
17c10::optional<IValue> getIValue(
18 const std::string& name,
19 const std::unordered_map<const Value*, Value*>& match_vmap,
20 const std::unordered_map<std::string, Value*>& vmap);
21TORCH_API void replaceConvolutionWithAtenConv(std::shared_ptr<Graph>& graph);
22
23bool isClampFusable(
24 const Match& match,
25 const std::unordered_map<std::string, Value*>& vmap);
26
// This struct contains a compiled IR pattern slated for use in the
// findPatternMatches function. The struct encapsulates the common
// information from parseIR that is used in conjunction with the
// pattern matching facility. A const instance of this struct can
// also be stored away to cache the compiled IR pattern and reduce
// runtime cost.
33struct PatternInfo {
34 std::string pattern_string;
35 std::unique_ptr<Graph> pattern_graph;
36 std::unordered_map<std::string, Value*> vmap;
37 std::vector<MatchFilter> filters;
38
39 static PatternInfo parse_from_str(
40 std::string pattern_string,
41 const std::vector<MatchFilter>& filters = {}) {
42 PatternInfo rv{
43 std::move(pattern_string),
44 std::make_unique<Graph>(),
45 decltype(vmap){},
46 filters};
47 parseIR(rv.pattern_string, rv.pattern_graph.get(), rv.vmap);
48 return rv;
49 }
50};
51
52} // namespace graph_rewrite_helper
53} // namespace jit
54} // namespace torch
55