1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/constant_folding.h" |
17 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
18 | #include "tensorflow/core/graph/node_builder.h" |
19 | #include "tensorflow/core/graph/subgraph.h" |
20 | #include "tensorflow/core/platform/init_main.h" |
21 | #include "tensorflow/core/public/session.h" |
22 | #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" |
23 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace graph_transforms { |
27 | namespace { |
28 | // Ensures the tensor is the expected shape. |
29 | Status ErrorIfNotVector(const Tensor& input, const string& input_name, |
30 | int expected_width) { |
31 | if ((input.shape().dims() != 1) || |
32 | (input.shape().dim_size(0) != expected_width)) { |
33 | return errors::InvalidArgument( |
34 | input_name, |
35 | " input to batch norm has bad shape: " , input.shape().DebugString()); |
36 | } |
37 | return OkStatus(); |
38 | } |
39 | |
40 | Status GetScaleAndOffsetValues(const NodeMatch& match, |
41 | std::vector<float>* scale_values, |
42 | std::vector<float>* offset_values) { |
43 | // Find all the nodes we expect in the subgraph. |
44 | const NodeDef& batch_norm_node = match.node; |
45 | // BatchNormWithGlobalNormalization and FusedBatchNorm ops only differ |
46 | // by input order and attribute names. |
47 | CHECK(batch_norm_node.op() == "BatchNormWithGlobalNormalization" || |
48 | batch_norm_node.op() == "FusedBatchNorm" ); |
49 | const bool is_fused = batch_norm_node.op() == "FusedBatchNorm" ; |
50 | const int mean_idx = is_fused ? 3 : 1; |
51 | const int var_idx = is_fused ? 4 : 2; |
52 | const int beta_idx = is_fused ? 2 : 3; |
53 | const int gamma_idx = is_fused ? 1 : 4; |
54 | const string epsilon_attr = is_fused ? "epsilon" : "variance_epsilon" ; |
55 | // FusedBatchNorm always scales after normalization. |
56 | const bool scale_after_normalization = |
57 | is_fused || batch_norm_node.attr().at("scale_after_normalization" ).b(); |
58 | |
59 | const NodeDef& mean_node = match.inputs[mean_idx].node; |
60 | CHECK_EQ("Const" , mean_node.op()); |
61 | const NodeDef& variance_node = match.inputs[var_idx].node; |
62 | CHECK_EQ("Const" , variance_node.op()); |
63 | const NodeDef& beta_node = match.inputs[beta_idx].node; |
64 | CHECK_EQ("Const" , beta_node.op()); |
65 | const NodeDef& gamma_node = match.inputs[gamma_idx].node; |
66 | CHECK_EQ("Const" , gamma_node.op()); |
67 | |
68 | // We have a set of vectors that we want to combine into a vector of |
69 | // scale values and offset values. |
70 | Tensor mean = GetNodeTensorAttr(mean_node, "value" ); |
71 | Tensor variance = GetNodeTensorAttr(variance_node, "value" ); |
72 | Tensor beta = GetNodeTensorAttr(beta_node, "value" ); |
73 | Tensor gamma = GetNodeTensorAttr(gamma_node, "value" ); |
74 | const float variance_epsilon = batch_norm_node.attr().at(epsilon_attr).f(); |
75 | |
76 | // Make sure all the inputs really are vectors with the same shape. |
77 | const int64_t num_cols = mean.shape().dim_size(0); |
78 | TF_RETURN_IF_ERROR(ErrorIfNotVector(variance, "Variance" , num_cols)); |
79 | TF_RETURN_IF_ERROR(ErrorIfNotVector(beta, "Beta" , num_cols)); |
80 | TF_RETURN_IF_ERROR(ErrorIfNotVector(gamma, "gamma" , num_cols)); |
81 | |
82 | scale_values->resize(num_cols); |
83 | offset_values->resize(num_cols); |
84 | |
85 | // Calculate the scale and offset values to apply. |
86 | if (scale_after_normalization) { |
87 | for (int i = 0; i < num_cols; ++i) { |
88 | (*scale_values)[i] = |
89 | (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)) * |
90 | gamma.flat<float>()(i); |
91 | } |
92 | } else { |
93 | for (int i = 0; i < num_cols; ++i) { |
94 | (*scale_values)[i] = |
95 | (1.0f / sqrtf(variance.flat<float>()(i) + variance_epsilon)); |
96 | } |
97 | } |
98 | for (int i = 0; i < num_cols; ++i) { |
99 | (*offset_values)[i] = |
100 | (-mean.flat<float>()(i) * (*scale_values)[i]) + beta.flat<float>()(i); |
101 | } |
102 | return OkStatus(); |
103 | } |
104 | |
105 | Status FuseScaleOffsetToConvWeights(const std::vector<float>& scale_values, |
106 | const std::vector<float>& offset_values, |
107 | const NodeMatch& conv_node_match, |
108 | const string& conv_output_name, |
109 | std::vector<NodeDef>* new_nodes) { |
110 | const NodeDef& conv_node = conv_node_match.node; |
111 | // CHECK_EQ("Conv2D", conv_node.op()); |
112 | const NodeDef& input_node = conv_node_match.inputs[0].node; |
113 | const NodeDef& weights_node = conv_node_match.inputs[1].node; |
114 | CHECK_EQ("Const" , weights_node.op()); |
115 | |
116 | Tensor weights = GetNodeTensorAttr(weights_node, "value" ); |
117 | int64_t weights_cols; |
118 | if (conv_node.op() == "Conv2D" ) { |
119 | weights_cols = weights.shape().dim_size(3); |
120 | } else if (conv_node.op() == "DepthwiseConv2dNative" ) { |
121 | weights_cols = weights.shape().dim_size(2) * weights.shape().dim_size(3); |
122 | } else { |
123 | weights_cols = weights.shape().dim_size(1); |
124 | } |
125 | CHECK_EQ(weights_cols, scale_values.size()); |
126 | |
127 | // Multiply the original weights by the scale vector. |
128 | auto weights_vector = weights.flat<float>(); |
129 | Tensor scaled_weights(DT_FLOAT, weights.shape()); |
130 | auto scaled_weights_vector = scaled_weights.flat<float>(); |
131 | for (int64_t row = 0; row < weights_vector.dimension(0); ++row) { |
132 | scaled_weights_vector(row) = |
133 | weights_vector(row) * scale_values[row % weights_cols]; |
134 | } |
135 | // Figure out the remaining bias to add on. |
136 | Tensor bias_offset(DT_FLOAT, {weights_cols}); |
137 | auto bias_offset_vector = bias_offset.flat<float>(); |
138 | for (int64_t col = 0; col < weights_cols; ++col) { |
139 | bias_offset_vector(col) = offset_values[col]; |
140 | } |
141 | |
142 | // Construct the new nodes. |
143 | NodeDef scaled_weights_node; |
144 | scaled_weights_node.set_op("Const" ); |
145 | scaled_weights_node.set_name(weights_node.name()); |
146 | SetNodeAttr("dtype" , DT_FLOAT, &scaled_weights_node); |
147 | SetNodeTensorAttr<float>("value" , scaled_weights, &scaled_weights_node); |
148 | new_nodes->push_back(scaled_weights_node); |
149 | |
150 | // The input and convolution can be copied straight over, since the |
151 | // name of the scaled weights constant is the same as the original. |
152 | new_nodes->push_back(input_node); |
153 | new_nodes->push_back(conv_node); |
154 | |
155 | NodeDef bias_offset_node; |
156 | bias_offset_node.set_op("Const" ); |
157 | bias_offset_node.set_name(conv_node.name() + "_bn_offset" ); |
158 | SetNodeAttr("dtype" , DT_FLOAT, &bias_offset_node); |
159 | SetNodeTensorAttr<float>("value" , bias_offset, &bias_offset_node); |
160 | new_nodes->push_back(bias_offset_node); |
161 | |
162 | NodeDef bias_add_node; |
163 | bias_add_node.set_op("BiasAdd" ); |
164 | bias_add_node.set_name(conv_output_name); |
165 | if (conv_node.attr().count("data_format" )) { |
166 | CopyNodeAttr(conv_node, "data_format" , "data_format" , &bias_add_node); |
167 | } |
168 | CopyNodeAttr(conv_node, "T" , "T" , &bias_add_node); |
169 | AddNodeInput(conv_node.name(), &bias_add_node); |
170 | AddNodeInput(bias_offset_node.name(), &bias_add_node); |
171 | new_nodes->push_back(bias_add_node); |
172 | return OkStatus(); |
173 | } |
174 | |
175 | Status FuseBatchNormWithConv(const NodeMatch& match, |
176 | std::vector<NodeDef>* new_nodes) { |
177 | // Calculate the scale and offset values to apply. |
178 | std::vector<float> scale_values; |
179 | std::vector<float> offset_values; |
180 | TF_RETURN_IF_ERROR( |
181 | GetScaleAndOffsetValues(match, &scale_values, &offset_values)); |
182 | |
183 | // Fuse conv weights, and set the final output node name as batch_norm_node. |
184 | const NodeDef& batch_norm_node = match.node; |
185 | TF_RETURN_IF_ERROR( |
186 | FuseScaleOffsetToConvWeights(scale_values, offset_values, match.inputs[0], |
187 | batch_norm_node.name(), new_nodes)); |
188 | return OkStatus(); |
189 | } |
190 | |
191 | Status FuseBatchNormWithBatchToSpace(const NodeMatch& match, |
192 | std::vector<NodeDef>* new_nodes) { |
193 | // Calculate the scale and offset values to apply. |
194 | std::vector<float> scale_values; |
195 | std::vector<float> offset_values; |
196 | TF_RETURN_IF_ERROR( |
197 | GetScaleAndOffsetValues(match, &scale_values, &offset_values)); |
198 | |
199 | // Fuse conv weights, and set the final output node name as batch_norm_node. |
200 | const NodeDef& batch_norm_node = match.node; |
201 | const NodeMatch& batch_to_space_node_match = match.inputs[0]; |
202 | const NodeMatch& conv_node_match = batch_to_space_node_match.inputs[0]; |
203 | const NodeDef& batch_to_space_node = batch_to_space_node_match.node; |
204 | const NodeDef& conv_node = conv_node_match.node; |
205 | |
206 | string biasadd_name = conv_node.name() + "/biasadd" ; |
207 | TF_RETURN_IF_ERROR(FuseScaleOffsetToConvWeights( |
208 | scale_values, offset_values, conv_node_match, biasadd_name, new_nodes)); |
209 | |
210 | NodeDef new_batch_to_space_node = batch_to_space_node; |
211 | // reuse batch_norm node name |
212 | new_batch_to_space_node.set_name(batch_norm_node.name()); |
213 | new_batch_to_space_node.set_input(0, biasadd_name); |
214 | new_nodes->push_back(batch_to_space_node_match.inputs[1].node); |
215 | new_nodes->push_back(batch_to_space_node_match.inputs[2].node); |
216 | new_nodes->push_back(new_batch_to_space_node); |
217 | return OkStatus(); |
218 | } |
219 | |
// Fuses a batch norm that sits on top of a ConcatV2 of two convolutions: the
// scale/offset vectors are (possibly split and) folded into each conv's
// weights, then the concat is re-emitted under the batch norm node's name.
Status FuseBatchNormWithConvConcat(const NodeMatch& match,
                                   std::vector<NodeDef>* new_nodes) {
  // Calculate the scale and offset values to apply.
  std::vector<float> scale_values;
  std::vector<float> offset_values;
  TF_RETURN_IF_ERROR(
      GetScaleAndOffsetValues(match, &scale_values, &offset_values));

  // Find all the nodes we expect in the subgraph.
  const NodeDef& batch_norm_node = match.node;
  const NodeMatch& concat_node_match = match.inputs[0];
  NodeDef concat_node = concat_node_match.node;
  // NOTE(review): the pattern in FoldOldBatchNorms matches "ConcatV2|Concat",
  // but plain Concat takes the axis as input 0 (not input 2) and would fail
  // this CHECK — confirm whether "Concat" is intended to be supported here.
  CHECK_EQ("ConcatV2", concat_node.op());

  // First process the axis. For ConcatV2 the axis is the last input (index 2
  // when concatenating two tensors).
  NodeDef axis_node = concat_node_match.inputs[2].node;
  CHECK_EQ("Const", axis_node.op());
  Tensor axis = GetNodeTensorAttr(axis_node, "value");
  int32_t axis_scalar = (axis.scalar<int32>())();

  // Set both conv0 and conv1 have the same scale and offset in default.
  std::vector<float> scale0(scale_values);
  std::vector<float> offset0(offset_values);
  std::vector<float> scale1(scale_values);
  std::vector<float> offset1(offset_values);
  if (axis_scalar == 3) {
    // If axis is 3 (the channel dimension), then scale and offset will be
    // split into two halves: the first conv contributes the first
    // `split_cols` channels of the concatenated output, the second conv the
    // rest. split_cols is read from conv0's weights (dim 3 = output
    // channels).
    const NodeDef& weights0_node = concat_node_match.inputs[0].inputs[1].node;
    Tensor weights0 = GetNodeTensorAttr(weights0_node, "value");
    const int64_t split_cols = weights0.shape().dim_size(3);
    // Only keep the first half for scale0/offset0.
    scale0.erase(scale0.begin() + split_cols, scale0.end());
    offset0.erase(offset0.begin() + split_cols, offset0.end());
    // Only keep the second half for scale1/offset1.
    scale1.erase(scale1.begin(), scale1.begin() + split_cols);
    offset1.erase(offset1.begin(), offset1.begin() + split_cols);
  }

  // Fuse the weights for input0 of conv2d.
  const string concat0_output_name = concat_node.name() + "_bn_in0";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale0, offset0, concat_node_match.inputs[0],
                                   concat0_output_name, new_nodes));

  // Fuse the weights for input1 of conv2d.
  const string concat1_output_name = concat_node.name() + "_bn_in1";
  TF_RETURN_IF_ERROR(
      FuseScaleOffsetToConvWeights(scale1, offset1, concat_node_match.inputs[1],
                                   concat1_output_name, new_nodes));

  // Push the axis node through unchanged.
  new_nodes->push_back(concat_node_match.inputs[2].node);

  // Set the final output op name to batch_normal_node, and rewire the concat
  // to read from the two fused BiasAdds.
  concat_node.set_name(batch_norm_node.name());
  concat_node.set_input(0, concat0_output_name);
  concat_node.set_input(1, concat1_output_name);
  new_nodes->push_back(concat_node);
  return OkStatus();
}
280 | } // namespace |
281 | |
282 | // Finds monolithic batch norm ops (as used in early versions of TensorFlow) and |
283 | // converts them into premultiplied weight inputs to convolutions. |
Status FoldOldBatchNorms(const GraphDef& input_graph_def,
                         const TransformFuncContext& context,
                         GraphDef* output_graph_def) {
  GraphDef current_graph_def = input_graph_def;
  // We have to do several passes to catch all the old BN nodes, since many of
  // them may share inputs and so be excluded from replacement in one pass.
  bool did_graph_change;
  // Pass 1: batch norm directly on top of a convolution.
  // NOTE: the nesting of each pattern below must stay in sync with the
  // match.inputs[...] indexing inside the corresponding Fuse* helper.
  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
      {"BatchNormWithGlobalNormalization|FusedBatchNorm",    // batch_norm_node
        {
          {"Conv2D|DepthwiseConv2dNative",          // conv_node
            {
              {"*"},                                // input_node
              {"Const"},                            // weights_node
            }
          },
          {"Const"},                                // mean_node
          {"Const"},                                // variance_node
          {"Const"},                                // beta_node
          {"Const"},                                // gamma_node
        }
      },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConv(match, new_nodes));
          did_graph_change = true;
          return OkStatus();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  // Pass 2: batch norm on top of BatchToSpaceND (dilated convolutions).
  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
      {"BatchNormWithGlobalNormalization|FusedBatchNorm",    // batch_norm_node
        {
          {"BatchToSpaceND",                        // batch_to_space_node
            {
              {"Conv2D|DepthwiseConv2dNative",      // conv_node
                {
                  {"*"},                            // input_node
                  {"Const"},                        // weights_node
                }
              },
              {"Const"},                            // block_shape
              {"Const"},                            // crops
            }
          },
          {"Const"},                                // mean_node
          {"Const"},                                // variance_node
          {"Const"},                                // beta_node
          {"Const"},                                // gamma_node
        }
      },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithBatchToSpace(match, new_nodes));
          did_graph_change = true;
          return OkStatus();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  // Pass 3: batch norm on top of a concat of two convolutions.
  do {
    did_graph_change = false;
    GraphDef replaced_graph_def;
    // Replace BatchNorm with concat as input.
    TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
        current_graph_def,  // clang-format off
      {"BatchNormWithGlobalNormalization|FusedBatchNorm",    // batch_norm_node
        {
          {"ConcatV2|Concat",                       // concat two conv2d.
            {
              {"Conv2D|DepthwiseConv2dNative",      // conv_node
                {
                  {"*"},                            // input_node
                  {"Const"},                        // weights_node
                }
              },
              {"Conv2D|DepthwiseConv2dNative",      // conv_node
                {
                  {"*"},                            // input_node
                  {"Const"},                        // weights_node
                }
              },
              {"Const"},                            // axis
            },
          },
          {"Const"},                                // mean_node
          {"Const"},                                // variance_node
          {"Const"},                                // beta_node
          {"Const"},                                // gamma_node
        }
      },  // clang-format on
        [&did_graph_change](const NodeMatch& match,
                            const std::set<string>& input_nodes,
                            const std::set<string>& output_nodes,
                            std::vector<NodeDef>* new_nodes) {
          TF_RETURN_IF_ERROR(FuseBatchNormWithConvConcat(match, new_nodes));
          did_graph_change = true;
          return OkStatus();
        },
        {}, &replaced_graph_def));
    current_graph_def = replaced_graph_def;
  } while (did_graph_change);

  *output_graph_def = current_graph_def;
  return OkStatus();
}
405 | |
// Registers this transform under the name used by the Graph Transform Tool's
// --transforms flag.
REGISTER_GRAPH_TRANSFORM("fold_old_batch_norms", FoldOldBatchNorms);
407 | |
408 | } // namespace graph_transforms |
409 | } // namespace tensorflow |
410 | |