1#include <torch/csrc/jit/ir/constants.h>
2#include <torch/csrc/jit/ir/ir.h>
3#include <torch/csrc/jit/passes/dead_code_elimination.h>
4#include <torch/csrc/jit/passes/fold_linear_bn.h>
5#include <torch/csrc/jit/passes/frozen_linear_folding.h>
6#include <torch/csrc/jit/passes/utils/optimization_utils.h>
7
8#ifndef AT_PER_OPERATOR_HEADERS
9#include <ATen/Functions.h>
10#else
11#include <ATen/ops/ones_like.h>
12#include <ATen/ops/zeros_like.h>
13#endif
14
15namespace torch {
16namespace jit {
17
18namespace {
19
20using Tensor = at::Tensor;
21
22bool supportedLinearNode(Node* n) {
23 if (n->kind() == aten::linear) {
24 return true;
25 } else {
26 return false;
27 }
28}
29
30bool FoldFrozenLinearBatchnorm(Block* b) {
31 bool graph_modified = false;
32 for (Node* n : b->nodes()) {
33 for (Block* block : n->blocks()) {
34 graph_modified |= FoldFrozenLinearBatchnorm(block);
35 }
36
37 if (n->kind() == aten::batch_norm &&
38 supportedLinearNode(n->inputs().at(0)->node())) {
39 auto linear = n->inputs().at(0)->node();
40 auto bn = n;
41
42 if (nonConstantParameters(linear) || nonConstantParameters(bn)) {
43 continue;
44 }
45
46 auto bn_rm_ivalue = bn->namedInput("running_mean");
47 auto bn_rv_ivalue = bn->namedInput("running_var");
48
49 // check running_mean and running_var has value, if they are
50 // None(track_running_stats=False), skiping the folding path.
51 if (bn_rm_ivalue->type() == NoneType::get() &&
52 bn_rv_ivalue->type() == NoneType::get()) {
53 continue;
54 }
55
56 auto bn_rm = constant_as<Tensor>(bn->namedInput("running_mean")).value();
57 auto bn_rv = constant_as<Tensor>(bn->namedInput("running_var")).value();
58 auto bn_eps = constant_as<double>(bn->namedInput("eps")).value();
59 auto linear_w = constant_as<Tensor>(linear->namedInput("weight")).value();
60
61 // implementation taken from torch/nn/utils/fusion.py
62 Tensor linear_b;
63 if (linear->namedInput("bias")->type() == NoneType::get()) {
64 at::ScalarType bias_dtype = bn_rm.scalar_type();
65 at::ScalarType weight_dtype = linear_w.scalar_type();
66 at::DeviceType weight_device = linear_w.device().type();
67 if (weight_device == at::kCUDA &&
68 (weight_dtype == at::kHalf || weight_dtype == at::kBFloat16) &&
69 bias_dtype == at::kFloat) {
70 bias_dtype = weight_dtype;
71 }
72 linear_b = at::zeros_like(bn_rm, at::TensorOptions().dtype(bias_dtype));
73 } else {
74 linear_b = constant_as<Tensor>(linear->namedInput("bias")).value();
75 }
76 Tensor bn_w;
77 if (bn->namedInput("weight")->type() == NoneType::get()) {
78 bn_w = at::ones_like(bn_rm);
79 } else {
80 bn_w = constant_as<Tensor>(bn->namedInput("weight")).value();
81 }
82 Tensor bn_b;
83 if (n->namedInput("bias")->type() == NoneType::get()) {
84 bn_b = at::zeros_like(bn_rm);
85 } else {
86 bn_b = constant_as<Tensor>(bn->namedInput("bias")).value();
87 }
88
89 LinearBNParameters params;
90 params.linear_w = linear_w;
91 params.linear_b = linear_b;
92 params.bn_rm = bn_rm;
93 params.bn_rv = bn_rv;
94 params.bn_eps = bn_eps;
95 params.bn_w = bn_w;
96 params.bn_b = bn_b;
97 std::tuple<Tensor, Tensor> out =
98 computeUpdatedLinearWeightAndBias(params);
99 WithInsertPoint guard(linear);
100 auto fused_linear_w = b->owningGraph()->insertConstant(std::get<0>(out));
101 auto fused_linear_b = b->owningGraph()->insertConstant(std::get<1>(out));
102 auto linear_w_value = linear->namedInput("weight");
103 auto linear_b_value = linear->namedInput("bias");
104
105 fused_linear_w->setDebugName(linear_w_value->debugName() + "_fused_bn");
106 fused_linear_b->setDebugName(linear_b_value->debugName() + "_fused_bn");
107
108 linear->replaceInputWith(linear_w_value, fused_linear_w);
109 linear->replaceInputWith(linear_b_value, fused_linear_b);
110
111 bn->output()->replaceAllUsesWith(linear->output());
112 graph_modified = true;
113 }
114 }
115 return graph_modified;
116}
117
118} // namespace
119
120bool FoldFrozenLinearBatchnorm(std::shared_ptr<Graph>& graph) {
121 bool graph_modified = FoldFrozenLinearBatchnorm(graph->block());
122 EliminateDeadCode(graph);
123 return graph_modified;
124}
125
126} // namespace jit
127} // namespace torch
128