1 | #include <torch/csrc/jit/passes/fold_linear_bn.h> |
---|---|
2 | |
3 | #include <ATen/TensorOperators.h> |
4 | |
5 | #ifndef AT_PER_OPERATOR_HEADERS |
6 | #include <ATen/Functions.h> |
7 | #else |
8 | #include <ATen/ops/rsqrt.h> |
9 | #endif |
10 | |
11 | namespace torch { |
12 | namespace jit { |
13 | |
14 | std::tuple<at::Tensor, at::Tensor> computeUpdatedLinearWeightAndBias( |
15 | const LinearBNParameters& p) { |
16 | at::Tensor bn_scale = p.bn_w * at::rsqrt(p.bn_rv + p.bn_eps); |
17 | at::Tensor fused_w = p.linear_w * bn_scale.unsqueeze(-1); |
18 | at::Tensor fused_b = (p.linear_b - p.bn_rm) * bn_scale + p.bn_b; |
19 | |
20 | auto linear_w_dtype = p.linear_w.dtype(); |
21 | auto linear_b_dtype = p.linear_b.dtype(); |
22 | |
23 | return std::make_tuple( |
24 | fused_w.to(linear_w_dtype), fused_b.to(linear_b_dtype)); |
25 | } |
26 | |
27 | } // namespace jit |
28 | } // namespace torch |
29 |