1#pragma once
2
3#include <ATen/Config.h>
4#include <torch/csrc/jit/api/module.h>
5#include <torch/csrc/jit/ir/ir.h>
6#include <torch/csrc/jit/passes/subgraph_rewrite.h>
7
8#if AT_MKLDNN_ENABLED()
9
10#include <ideep/tensor.hpp>
11
12#endif // AT_MKLDNN_ENABLED()
13
14namespace torch {
15namespace jit {
16
17#if AT_MKLDNN_ENABLED()
18
19namespace mkldnn {
20
21const static std::map<std::string, std::vector<torch::jit::MatchFilter>>
22 fusion_rewrite_map = {
23 {"none", {}},
24 {"relu", {}},
25};
26
27} // namespace mkldnn
28
29#endif // AT_MKLDNN_ENABLED()
30
31void FuseConvWithEltwise(std::shared_ptr<Graph>& graph);
32
33} // namespace jit
34} // namespace torch
35