/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {
namespace {
// Returns an error unless the tensor is a vector of the expected width.
Status ErrorIfNotVector(const Tensor& input, const string& input_name,
                        int expected_width) {
  if ((input.shape().dims() != 1) ||
      (input.shape().dim_size(0) != expected_width)) {
    return errors::InvalidArgument(
        input_name,
        " input to batch norm has bad shape: ", input.shape().DebugString());
  }
  return OkStatus();
}

Status GetScaleAndOffsetValues(const NodeMatch& match,
                               std::vector<float>* scale_values,
                               std::vector<float>* offset_values) {
  // Find all the nodes we expect in the subgraph.
  const NodeDef& batch_norm_node = match.node;
  // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ
  // by input order and attribute names.
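  // (For reference: BatchNormWithGlobalNormalization takes
  // [input, mean, variance, beta, gamma], while FusedBatchNorm takes
  // [input, gamma, beta, mean, variance]; the index constants below encode
  // this difference.)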
  CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" ||
        batch_norm_node.op() == "FusedBatchNorm");
  const bool is_fused = batch_norm_node.op() == "FusedBatchNorm";
  const int mean_idx = is_fused ? 3 : 1;
  const int var_idx = is_fused ? 4 : 2;
  const int beta_idx = is_fused ? 2 : 3;
  const int gamma_idx = is_fused ? 1 : 4;
  const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon";
  // FusedBatchNorm always scales after normalization.
  const bool scale_after_normalization =
      is_fused || batch_norm_node.attr().at("scale_after_normalization").b();

  const NodeDef& mean_node = match.inputs[mean_idx].node;
  CHECK_EQ("Const", mean_node.op());
  const NodeDef& variance_node = match.inputs[var_idx].node;
  CHECK_EQ("Const", variance_node.op());
  const NodeDef& beta_node = match.inputs[beta_idx].node;
  CHECK_EQ("Const", beta_node.op());
  const NodeDef& gamma_node = match.inputs[gamma_idx].node;
  CHECK_EQ("Const", gamma_node.op());

  // We have a set of vectors that we want to combine into vectors of
  // scale and offset values.
  Tensor mean = GetNodeTensorAttr(mean_node, "value");
  Tensor variance = GetNodeTensorAttr(variance_node, "value");
  Tensor beta = GetNodeTensorAttr(beta_node, "value");
  Tensor gamma = GetNodeTensorAttr(gamma_node, "value");
  const float variance_epsilon = batch_norm_node.attr().at(epsilon_attr).f();

  // Make sure all the inputs really are vectors with the same shape.
  const int64_t num_cols = mean.shape().dim_size(0);
  TF_RETURN_IF_ERROR(ErrorIfNotVector(variance, "Variance", num_cols));
  TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta", num_cols));
  TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "Gamma", num_cols));

  scale_values->resize(num_cols);
  offset_values->resize(num_cols);

  // Calculate the scale and offset values to apply.
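  // The math below follows from the batch norm definition
  //   y = gamma * (x - mean) / sqrt(variance + epsilon) + beta
  // rearranged into y = x * scale + offset, where
  //   scale  = gamma / sqrt(variance + epsilon)
  //            (with gamma omitted when scale_after_normalization is false)
  //   offset = beta - mean * scale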
  if (scale_after_normalization) {
    for (int i = 0; i < num_cols; ++i) {
      (*scale_values)[i] =
          (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) *
          gamma.flat<float>()(i);
    }
  } else {
    for (int i = 0; i < num_cols; ++i) {
      (*scale_values)[i] =
          (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon));
    }
  }
  for (int i = 0; i < num_cols; ++i) {
    (*offset_values)[i] =
        (-mean.flat<float>()(i) * (*scale_values)[i]) + beta.flat<float>()(i);
  }
  return OkStatus();
}

Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values,
                                    const std::vector<float>& offset_values,
                                    const NodeMatch& conv_node_match,
                                    const string& conv_output_name,
                                    std::vector<NodeDef>* new_nodes) {
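  // Rewrites Conv(input, weights) followed by batch norm into
  // BiasAdd(Conv(input, weights * scale), offset). The scaled weights keep
  // the original constant's name, so the conv node can be copied unchanged.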
  const NodeDef& conv_node = conv_node_match.node;
  // The matched op is either Conv2D or DepthwiseConv2dNative.
  const NodeDef& input_node = conv_node_match.inputs[0].node;
  const NodeDef& weights_node = conv_node_match.inputs[1].node;
  CHECK_EQ("Const", weights_node.op());

  Tensor weights = GetNodeTensorAttr(weights_node, "value");
  int64_t weights_cols;
  if (conv_node.op() == "Conv2D") {
    weights_cols = weights.shape().dim_size(3);
  } else if (conv_node.op() == "DepthwiseConv2dNative") {
    weights_cols = weights.shape().dim_size(2) * weights.shape().dim_size(3);
  } else {
    weights_cols = weights.shape().dim_size(1);
  }
  CHECK_EQ(weights_cols, scale_values.size());

  // Multiply the original weights by the scale vector.
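  // The output channel is the innermost dimension of the weights (e.g.
  // [filter_height, filter_width, in_channels, out_channels] for Conv2D), so
  // indexing the row-major flattened tensor modulo weights_cols selects the
  // per-output-channel scale.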
  auto weights_vector = weights.flat<float>();
  Tensor scaled_weights(DT_FLOAT, weights.shape());
  auto scaled_weights_vector = scaled_weights.flat<float>();
  for (int64_t row = 0; row < weights_vector.dimension(0); ++row) {
    scaled_weights_vector(row) =
        weights_vector(row) * scale_values[row % weights_cols];
  }
  // Figure out the remaining bias to add on.
  Tensor bias_offset(DT_FLOAT, {weights_cols});
  auto bias_offset_vector = bias_offset.flat<float>();
  for (int64_t col = 0; col < weights_cols; ++col) {
    bias_offset_vector(col) = offset_values[col];
  }

  // Construct the new nodes.
  NodeDef scaled_weights_node;
  scaled_weights_node.set_op("Const");
  scaled_weights_node.set_name(weights_node.name());
  SetNodeAttr("dtype", DT_FLOAT, &scaled_weights_node);
  SetNodeTensorAttr<float>("value", scaled_weights, &scaled_weights_node);
  new_nodes->push_back(scaled_weights_node);

  // The input and convolution can be copied straight over, since the
  // name of the scaled weights constant is the same as the original.
  new_nodes->push_back(input_node);
  new_nodes->push_back(conv_node);

  NodeDef bias_offset_node;
  bias_offset_node.set_op("Const");
  bias_offset_node.set_name(conv_node.name() + "_bn_offset");
  SetNodeAttr("dtype", DT_FLOAT, &bias_offset_node);
  SetNodeTensorAttr<float>("value", bias_offset, &bias_offset_node);
  new_nodes->push_back(bias_offset_node);

  NodeDef bias_add_node;
  bias_add_node.set_op("BiasAdd");
  bias_add_node.set_name(conv_output_name);
  if (conv_node.attr().count("data_format")) {
    CopyNodeAttr(conv_node, "data_format", "data_format", &bias_add_node);
  }
  CopyNodeAttr(conv_node, "T", "T", &bias_add_node);
  AddNodeInput(conv_node.name(), &bias_add_node);
  AddNodeInput(bias_offset_node.name(), &bias_add_node);
  new_nodes->push_back(bias_add_node);
  return OkStatus();
}

Status FuseBatchNormWithConv(const NodeMatch& match,
                             std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Fuse the scale and offset into the conv weights, and give the final
  // output node the batch norm node's name.
  const NodeDef& batch_norm_node = match.node;
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale_values, offset_values, match.inputs[0],
                                   batch_norm_node.name(), new_nodes));
  return OkStatus();
}

Status FuseBatchNormWithBatchToSpace(const NodeMatch& match,
                                     std::vector<NodeDef>* new_nodes) {
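  // BatchToSpaceND only rearranges the batch and spatial dimensions, leaving
  // channels intact, so the per-channel scale and offset can be folded into
  // the convolution that feeds it without changing the result.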
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Fuse the scale and offset into the weights of the conv that feeds the
  // BatchToSpaceND node.
  const NodeDef& batch_norm_node = match.node;
  const NodeMatch& batch_to_space_node_match = match.inputs[0];
  const NodeMatch& conv_node_match = batch_to_space_node_match.inputs[0];
  const NodeDef& batch_to_space_node = batch_to_space_node_match.node;
  const NodeDef& conv_node = conv_node_match.node;

  string biasadd_name = conv_node.name() + "/biasadd";
  TF_RETURN_IF_ERROR(FuseScaleOffsetToConvWeights(
      scale_values, offset_values, conv_node_match, biasadd_name, new_nodes));

  NodeDef new_batch_to_space_node = batch_to_space_node;
  // Reuse the batch norm node's name so downstream consumers stay connected.
  new_batch_to_space_node.set_name(batch_norm_node.name());
  new_batch_to_space_node.set_input(0, biasadd_name);
  new_nodes->push_back(batch_to_space_node_match.inputs[1].node);
  new_nodes->push_back(batch_to_space_node_match.inputs[2].node);
  new_nodes->push_back(new_batch_to_space_node);
  return OkStatus();
}

Status FuseBatchNormWithConvConcat(const NodeMatch& match,
                                   std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Find all the nodes we expect in the subgraph.
  const NodeDef& batch_norm_node = match.node;
  const NodeMatch& concat_node_match = match.inputs[0];
  NodeDef concat_node = concat_node_match.node;
  CHECK_EQ("ConcatV2", concat_node.op());

  // First process the axis.
  NodeDef axis_node = concat_node_match.inputs[2].node;
  CHECK_EQ("Const", axis_node.op());
  Tensor axis = GetNodeTensorAttr(axis_node, "value");
  int32_t axis_scalar = (axis.scalar<int32>())();

  // By default, both conv0 and conv1 use the same scale and offset values.
  std::vector<float> scale0(scale_values);
  std::vector<float> offset0(offset_values);
  std::vector<float> scale1(scale_values);
  std::vector<float> offset1(offset_values);
  if (axis_scalar == 3) {
    // If the concat axis is 3 (the channel dimension), each conv produces a
    // distinct slice of channels, so split the scale and offset vectors at
    // conv0's output-channel count.
    const NodeDef& weights0_node = concat_node_match.inputs[0].inputs[1].node;
    Tensor weights0 = GetNodeTensorAttr(weights0_node, "value");
    const int64_t split_cols = weights0.shape().dim_size(3);
    // Only keep the first part for scale0/offset0.
    scale0.erase(scale0.begin() + split_cols, scale0.end());
    offset0.erase(offset0.begin() + split_cols, offset0.end());
    // Only keep the second part for scale1/offset1.
    scale1.erase(scale1.begin(), scale1.begin() + split_cols);
    offset1.erase(offset1.begin(), offset1.begin() + split_cols);
  }

  // Fuse the weights for input0 of the concat.
  const string concat0_output_name = concat_node.name() + "_bn_in0";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale0, offset0, concat_node_match.inputs[0],
                                   concat0_output_name, new_nodes));

  // Fuse the weights for input1 of the concat.
  const string concat1_output_name = concat_node.name() + "_bn_in1";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale1, offset1, concat_node_match.inputs[1],
                                   concat1_output_name, new_nodes));

  // Push the axis node.
  new_nodes->push_back(concat_node_match.inputs[2].node);

  // Set the final output node's name to the batch norm node's name.
  concat_node.set_name(batch_norm_node.name());
  concat_node.set_input(0, concat0_output_name);
  concat_node.set_input(1, concat1_output_name);
  new_nodes->push_back(concat_node);
  return OkStatus();
}
}  // namespace

// Finds monolithic batch norm ops (as used in early versions of TensorFlow)
// and converts them into premultiplied weight inputs to convolutions.
Status FoldOldBatchNorms(const GraphDef& input_graph_def,
                         const TransformFuncContext& context,
                         GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  // We have to do several passes to catch all the old BN nodes, since many of
  // them may share inputs and so be excluded from replacement in one pass.
  bool did_graph_change;
  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
        {"BatchNormWithGlobalNormalization|FusedBatchNorm",  // batch_norm_node
          {
            {"Conv2D|DepthwiseConv2dNative",  // conv_node
              {
                {"*"},  // input_node
                {"Const"},  // weights_node
              }
            },
            {"Const"},  // mean_node
            {"Const"},  // variance_node
            {"Const"},  // beta_node
            {"Const"},  // gamma_node
          }
        },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConv(match, new_nodes));
          did_graph_change = true;
          return OkStatus();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
        {"BatchNormWithGlobalNormalization|FusedBatchNorm",  // batch_norm_node
          {
            {"BatchToSpaceND",  // batch_to_space_node
              {
                {"Conv2D|DepthwiseConv2dNative",  // conv_node
                  {
                    {"*"},  // input_node
                    {"Const"},  // weights_node
                  }
                },
                {"Const"},  // block_shape
                {"Const"},  // crops
              }
            },
            {"Const"},  // mean_node
            {"Const"},  // variance_node
            {"Const"},  // beta_node
            {"Const"},  // gamma_node
          }
        },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithBatchToSpace(match, new_nodes));
          did_graph_change = true;
          return OkStatus();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    // Replace BatchNorm with concat as input.
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
        {"BatchNormWithGlobalNormalization|FusedBatchNorm",  // batch_norm_node
          {
            {"ConcatV2|Concat",  // concat of two convs
              {
                {"Conv2D|DepthwiseConv2dNative",  // conv_node
                  {
                    {"*"},  // input_node
                    {"Const"},  // weights_node
                  }
                },
                {"Conv2D|DepthwiseConv2dNative",  // conv_node
                  {
                    {"*"},  // input_node
                    {"Const"},  // weights_node
                  }
                },
                {"Const"},  // axis
              },
            },
            {"Const"},  // mean_node
            {"Const"},  // variance_node
            {"Const"},  // beta_node
            {"Const"},  // gamma_node
          }
        },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConvConcat(match, new_nodes));
          did_graph_change = true;
          return OkStatus();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  *output_graph_def = current_graph_def;
  return OkStatus();
}

REGISTER_GRAPH_TRANSFORM("fold_old_batch_norms", FoldOldBatchNorms);
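
// A typical invocation of this transform (a sketch; graph file names and
// input/output node names are placeholders) via the Graph Transform Tool:
//   bazel run tensorflow/tools/graph_transforms:transform_graph -- \
//     --in_graph=frozen_graph.pb --out_graph=optimized_graph.pb \
//     --inputs=input --outputs=output --transforms="fold_old_batch_norms"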

}  // namespace graph_transforms
}  // namespace tensorflow