1 | #pragma once |
---|---|
2 | |
#include <torch/csrc/jit/api/module.h>

#include <tuple>
4 | |
5 | namespace torch { |
6 | namespace jit { |
7 | |
8 | struct TORCH_API LinearBNParameters { |
9 | at::Tensor linear_w; |
10 | at::Tensor linear_b; |
11 | at::Tensor bn_rm; |
12 | at::Tensor bn_rv; |
13 | double bn_eps = 0.0; |
14 | at::Tensor bn_w; |
15 | at::Tensor bn_b; |
16 | }; |
17 | |
18 | /** |
19 | * Given the current weight and bias tensors of a Linear module and parameters |
20 | * of the BatchNorm module we're folding with, compute the updated values |
21 | * for the weight and bias. |
22 | * |
23 | * The function is basically copied from torch/nn/utils/fusion.py |
24 | */ |
25 | TORCH_API std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias( |
26 | const LinearBNParameters& p); |
27 | |
28 | } // namespace jit |
29 | } // namespace torch |
30 |