1#include <ATen/Utils.h>
2
3#include <torch/csrc/jit/ir/constants.h>
4#include <torch/csrc/jit/ir/ir.h>
5#include <torch/csrc/jit/ir/subgraph_matcher.h>
6#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
7#include <torch/csrc/jit/passes/graph_rewrite_helper.h>
8#include <torch/csrc/jit/passes/remove_mutation.h>
9#include <torch/csrc/jit/passes/subgraph_rewrite.h>
10#ifdef USE_CUDA
11#include <ATen/cuda/CUDAConfig.h>
12#endif
13
14namespace torch {
15namespace jit {
16
17std::function<void(std::shared_ptr<Graph>&)>& getFuseFrozenConvAddReluImpl() {
18 static std::function<void(std::shared_ptr<Graph>&)> impl;
19 return impl;
20}
21
22// Implementation is in frozen_conv_add_relu_fusion.cpp; at runtime the
23// implementation is registered in _fuseFrozenConvAddReluImpl. This allows
24// the GPU code to be built separately from CPU-only code. If you're
25// expecting conv-add-relu fusion to occur but it's not happening, it's
26// possible that the GPU code isn't being built or linked properly.
27void FuseFrozenConvAddRelu(std::shared_ptr<Graph>& graph) {
28 if (getFuseFrozenConvAddReluImpl()) {
29 getFuseFrozenConvAddReluImpl()(graph);
30 }
31}
32
33} // namespace jit
34} // namespace torch
35