#pragma once

#include <torch/csrc/jit/api/module.h>

namespace torch {
namespace jit {

/** \brief Fold Conv2d-BatchNorm2d into Conv2d in all methods of this
 * module and all its submodules; forward is included by default.
 *
 * The weight and bias of the Conv2d are updated accordingly. This pass should
 * only be used on modules in eval mode, since it relies on the BatchNorm's
 * fixed running statistics.
 */
TORCH_API Module FoldConvBatchNorm(const Module& module);
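
// Example usage (a minimal sketch; "model.pt" is a hypothetical path to a
// serialized TorchScript module, and torch::jit::load comes from
// <torch/script.h>):
//
//   torch::jit::Module m = torch::jit::load("model.pt");
//   m.eval();  // folding requires fixed BatchNorm statistics
//   torch::jit::Module folded = torch::jit::FoldConvBatchNorm(m);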

struct TORCH_API ConvBNParameters {
  at::Tensor conv_w; // Conv weight
  at::Tensor conv_b; // Conv bias
  at::Tensor bn_rm; // BatchNorm running mean
  at::Tensor bn_rv; // BatchNorm running variance
  double bn_eps = 0.0; // BatchNorm epsilon
  at::Tensor bn_w; // BatchNorm weight (gamma)
  at::Tensor bn_b; // BatchNorm bias (beta)
};

/**
 * Given the current weight and bias tensors of a Conv module and the
 * parameters of the BatchNorm module being folded into it, compute the
 * updated values for the weight and bias.
 *
 * The computation mirrors torch/nn/utils/fusion.py.
 */
TORCH_API std::tuple<at::Tensor, at::Tensor> computeUpdatedConvWeightAndBias(
    const ConvBNParameters& p);
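
// The folded values follow the same computation as torch/nn/utils/fusion.py
// (a sketch; names refer to the ConvBNParameters fields above):
//
//   scale = bn_w / sqrt(bn_rv + bn_eps)      // one factor per output channel
//   new_w = conv_w * scale                   // broadcast over out-channels
//   new_b = (conv_b - bn_rm) * scale + bn_b
//
// Example call (hypothetical tensors taken from a Conv2d/BatchNorm2d pair):
//
//   ConvBNParameters p;
//   p.conv_w = conv_weight;
//   p.conv_b = conv_bias;
//   p.bn_rm = running_mean;
//   p.bn_rv = running_var;
//   p.bn_eps = 1e-5;
//   p.bn_w = bn_weight;
//   p.bn_b = bn_bias;
//   at::Tensor new_w, new_b;
//   std::tie(new_w, new_b) = computeUpdatedConvWeightAndBias(p);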

} // namespace jit
} // namespace torch