1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/constant_folding.h" |
17 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
18 | #include "tensorflow/core/graph/node_builder.h" |
19 | #include "tensorflow/core/graph/subgraph.h" |
20 | #include "tensorflow/core/platform/init_main.h" |
21 | #include "tensorflow/core/public/session.h" |
22 | #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" |
23 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace graph_transforms { |
27 | |
28 | // Converts Conv2D or MatMul ops followed by column-wise Muls into equivalent |
29 | // ops with the Mul baked into the convolution weights, to save computation |
30 | // during inference. |
31 | Status FoldBatchNorms(const GraphDef& input_graph_def, |
32 | const TransformFuncContext& context, |
33 | GraphDef* output_graph_def) { |
34 | GraphDef replaced_graph_def; |
35 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
36 | input_graph_def, // clang-format off |
37 | {"Mul" , // mul_node |
38 | { |
39 | {"Conv2D|MatMul|DepthwiseConv2dNative" , // conv_node |
40 | { |
41 | {"*" }, // input_node |
42 | {"Const" }, // weights_node |
43 | } |
44 | }, |
45 | {"Const" }, // mul_values_node |
46 | } |
47 | }, // clang-format on |
48 | [](const NodeMatch& match, const std::set<string>& input_nodes, |
49 | const std::set<string>& output_nodes, |
50 | std::vector<NodeDef>* new_nodes) { |
51 | // Find all the nodes we expect in the subgraph. |
52 | const NodeDef& mul_node = match.node; |
53 | const NodeDef& conv_node = match.inputs[0].node; |
54 | const NodeDef& input_node = match.inputs[0].inputs[0].node; |
55 | const NodeDef& weights_node = match.inputs[0].inputs[1].node; |
56 | const NodeDef& mul_values_node = match.inputs[1].node; |
57 | |
58 | // Check that nodes that we use are not used somewhere else. |
59 | for (const auto& node : {conv_node, weights_node, mul_values_node}) { |
60 | if (output_nodes.count(node.name())) { |
61 | // Return original nodes. |
62 | new_nodes->insert(new_nodes->end(), |
63 | {mul_node, conv_node, input_node, weights_node, |
64 | mul_values_node}); |
65 | return OkStatus(); |
66 | } |
67 | } |
68 | |
69 | Tensor weights = GetNodeTensorAttr(weights_node, "value" ); |
70 | Tensor mul_values = GetNodeTensorAttr(mul_values_node, "value" ); |
71 | |
72 | // Make sure all the inputs really are vectors, with as many entries as |
73 | // there are columns in the weights. |
74 | int64_t weights_cols; |
75 | if (conv_node.op() == "Conv2D" ) { |
76 | weights_cols = weights.shape().dim_size(3); |
77 | } else if (conv_node.op() == "DepthwiseConv2dNative" ) { |
78 | weights_cols = |
79 | weights.shape().dim_size(2) * weights.shape().dim_size(3); |
80 | } else { |
81 | weights_cols = weights.shape().dim_size(1); |
82 | } |
83 | if ((mul_values.shape().dims() != 1) || |
84 | (mul_values.shape().dim_size(0) != weights_cols)) { |
85 | return errors::InvalidArgument( |
86 | "Mul constant input to batch norm has bad shape: " , |
87 | mul_values.shape().DebugString()); |
88 | } |
89 | |
90 | // Multiply the original weights by the scale vector. |
91 | auto weights_vector = weights.flat<float>(); |
92 | Tensor scaled_weights(DT_FLOAT, weights.shape()); |
93 | auto scaled_weights_vector = scaled_weights.flat<float>(); |
94 | for (int64_t row = 0; row < weights_vector.dimension(0); ++row) { |
95 | scaled_weights_vector(row) = |
96 | weights_vector(row) * |
97 | mul_values.flat<float>()(row % weights_cols); |
98 | } |
99 | |
100 | // Construct the new nodes. |
101 | NodeDef scaled_weights_node; |
102 | scaled_weights_node.set_op("Const" ); |
103 | scaled_weights_node.set_name(weights_node.name()); |
104 | SetNodeAttr("dtype" , DT_FLOAT, &scaled_weights_node); |
105 | SetNodeTensorAttr<float>("value" , scaled_weights, &scaled_weights_node); |
106 | new_nodes->push_back(scaled_weights_node); |
107 | |
108 | new_nodes->push_back(input_node); |
109 | |
110 | NodeDef new_conv_node; |
111 | new_conv_node = conv_node; |
112 | new_conv_node.set_name(mul_node.name()); |
113 | new_nodes->push_back(new_conv_node); |
114 | |
115 | return OkStatus(); |
116 | }, |
117 | {}, &replaced_graph_def)); |
118 | *output_graph_def = replaced_graph_def; |
119 | return OkStatus(); |
120 | } |
121 | |
122 | REGISTER_GRAPH_TRANSFORM("fold_batch_norms" , FoldBatchNorms); |
123 | |
124 | } // namespace graph_transforms |
125 | } // namespace tensorflow |
126 | |