1 | #pragma once |
2 | |
#include <tuple>

#include <torch/csrc/jit/api/module.h>
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | /** \brief Fold Conv2d-BatchNorm2d into Conv2d in all methods of this |
9 | * module and all its submodules, forward is included by default. |
10 | * |
11 | * The weight and bias of the Conv2d are correspondingly updated. Should only be |
12 | * used on modules in eval mode. |
13 | */ |
14 | TORCH_API Module FoldConvBatchNorm(const Module& module); |
15 | |
16 | struct TORCH_API ConvBNParameters { |
17 | at::Tensor conv_w; |
18 | at::Tensor conv_b; |
19 | at::Tensor bn_rm; |
20 | at::Tensor bn_rv; |
21 | double bn_eps = 0.0; |
22 | at::Tensor bn_w; |
23 | at::Tensor bn_b; |
24 | }; |
25 | |
26 | /** |
27 | * Given the current weight and bias tensors of a Conv module and parameters |
28 | * of the BatchNorm module we're folding with, compute the updated values |
29 | * for the weight and bias. |
30 | * |
31 | * The function is basically copied from torch/nn/utils/fusion.py |
32 | */ |
33 | TORCH_API std::tuple<at::Tensor, at::Tensor> computeUpdatedConvWeightAndBias( |
34 | const ConvBNParameters& p); |
35 | |
36 | } // namespace jit |
37 | } // namespace torch |
38 | |