1 | #pragma once |
---|---|
2 | |
3 | #include <ATen/Config.h> |
4 | #include <torch/csrc/jit/api/module.h> |
5 | #include <torch/csrc/jit/ir/ir.h> |
6 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
7 | |
8 | #if AT_MKLDNN_ENABLED() |
9 | |
10 | #include <ideep/tensor.hpp> |
11 | |
12 | #endif // AT_MKLDNN_ENABLED() |
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | |
17 | #if AT_MKLDNN_ENABLED() |
18 | |
19 | namespace mkldnn { |
20 | |
21 | const static std::map<std::string, std::vector<torch::jit::MatchFilter>> |
22 | fusion_rewrite_map = { |
23 | {"none", {}}, |
24 | {"relu", {}}, |
25 | }; |
26 | |
27 | } // namespace mkldnn |
28 | |
29 | #endif // AT_MKLDNN_ENABLED() |
30 | |
// Declaration of a pass over a JIT graph; the definition is not part of this
// header. The graph is received as a non-const reference to the caller's
// std::shared_ptr<Graph>, and nothing is returned.
// NOTE(review): judging by the name, this fuses a convolution with a following
// element-wise op - confirm against the implementation.
// Unlike the mkldnn namespace above, this declaration sits outside the
// AT_MKLDNN_ENABLED() guard, so it is visible in every build configuration.
31 | void FuseConvWithEltwise(std::shared_ptr<Graph>& graph); |
32 | |
33 | } // namespace jit |
34 | } // namespace torch |
35 |