1 | #include <ATen/Utils.h> |
---|---|
2 | |
3 | #include <torch/csrc/jit/ir/constants.h> |
4 | #include <torch/csrc/jit/ir/ir.h> |
5 | #include <torch/csrc/jit/ir/subgraph_matcher.h> |
6 | #include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h> |
7 | #include <torch/csrc/jit/passes/graph_rewrite_helper.h> |
8 | #include <torch/csrc/jit/passes/remove_mutation.h> |
9 | #include <torch/csrc/jit/passes/subgraph_rewrite.h> |
10 | #ifdef USE_CUDA |
11 | #include <ATen/cuda/CUDAConfig.h> |
12 | #endif |
13 | |
14 | namespace torch { |
15 | namespace jit { |
16 | |
17 | std::function<void(std::shared_ptr<Graph>&)>& getFuseFrozenConvAddReluImpl() { |
18 | static std::function<void(std::shared_ptr<Graph>&)> impl; |
19 | return impl; |
20 | } |
21 | |
22 | // Implementation is in frozen_conv_add_relu_fusion.cpp; at runtime the |
23 | // implementation is registered in _fuseFrozenConvAddReluImpl. This allows |
24 | // the GPU code to be built separately from CPU-only code. If you're |
25 | // expecting conv-add-relu fusion to occur but it's not happening, it's |
26 | // possible that the GPU code isn't being built or linked properly. |
27 | void FuseFrozenConvAddRelu(std::shared_ptr<Graph>& graph) { |
28 | if (getFuseFrozenConvAddReluImpl()) { |
29 | getFuseFrozenConvAddReluImpl()(graph); |
30 | } |
31 | } |
32 | |
33 | } // namespace jit |
34 | } // namespace torch |
35 |