1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/constant_folding.h" |
17 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
18 | #include "tensorflow/core/graph/node_builder.h" |
19 | #include "tensorflow/core/graph/subgraph.h" |
20 | #include "tensorflow/core/platform/init_main.h" |
21 | #include "tensorflow/core/public/session.h" |
22 | #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" |
23 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace graph_transforms { |
27 | |
28 | Status FuseResizePadAndConv(const GraphDef& input_graph_def, |
29 | const TransformFuncContext& context, |
30 | GraphDef* output_graph_def) { |
31 | GraphDef replaced_graph_def; |
32 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
33 | input_graph_def, // clang-format off |
34 | {"Conv2D" , |
35 | { |
36 | {"MirrorPad" , |
37 | { |
38 | {"ResizeBilinear" }, |
39 | {"*" } |
40 | } |
41 | }, |
42 | {"*" } |
43 | } |
44 | }, // clang-format on |
45 | [](const NodeMatch& match, const std::set<string>& input_nodes, |
46 | const std::set<string>& output_nodes, |
47 | std::vector<NodeDef>* new_nodes) { |
48 | // Find all the nodes we expect in the subgraph. |
49 | const NodeDef& conv_node = match.node; |
50 | const NodeDef& mirror_pad_node = match.inputs[0].node; |
51 | const NodeDef& weights_node = match.inputs[1].node; |
52 | const NodeDef& resize_node = match.inputs[0].inputs[0].node; |
53 | const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node; |
54 | |
55 | // We'll be reusing the old weights and pad dimensions. |
56 | new_nodes->push_back(weights_node); |
57 | new_nodes->push_back(pad_dims_node); |
58 | |
59 | // Set up the new fused version of the convolution op. |
60 | NodeDef fused_conv; |
61 | fused_conv.set_op("FusedResizeAndPadConv2D" ); |
62 | fused_conv.set_name(match.node.name()); |
63 | AddNodeInput(resize_node.input(0), &fused_conv); |
64 | AddNodeInput(resize_node.input(1), &fused_conv); |
65 | AddNodeInput(mirror_pad_node.input(1), &fused_conv); |
66 | AddNodeInput(conv_node.input(1), &fused_conv); |
67 | CopyNodeAttr(resize_node, "align_corners" , "resize_align_corners" , |
68 | &fused_conv); |
69 | CopyNodeAttr(mirror_pad_node, "mode" , "mode" , &fused_conv); |
70 | CopyNodeAttr(conv_node, "T" , "T" , &fused_conv); |
71 | CopyNodeAttr(conv_node, "padding" , "padding" , &fused_conv); |
72 | CopyNodeAttr(conv_node, "strides" , "strides" , &fused_conv); |
73 | new_nodes->push_back(fused_conv); |
74 | |
75 | return OkStatus(); |
76 | }, |
77 | {}, &replaced_graph_def)); |
78 | *output_graph_def = replaced_graph_def; |
79 | return OkStatus(); |
80 | } |
81 | |
82 | Status FuseResizeAndConv(const GraphDef& input_graph_def, |
83 | const TransformFuncContext& context, |
84 | GraphDef* output_graph_def) { |
85 | GraphDef replaced_graph_def; |
86 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
87 | input_graph_def, // clang-format off |
88 | {"Conv2D" , |
89 | { |
90 | {"ResizeBilinear" }, |
91 | {"*" } |
92 | } |
93 | }, // clang-format on |
94 | [](const NodeMatch& match, const std::set<string>& input_nodes, |
95 | const std::set<string>& output_nodes, |
96 | std::vector<NodeDef>* new_nodes) { |
97 | // Find all the nodes we expect in the subgraph. |
98 | const NodeDef& conv_node = match.node; |
99 | const NodeDef& resize_node = match.inputs[0].node; |
100 | const NodeDef& weights_node = match.inputs[1].node; |
101 | |
102 | // We'll be reusing the old weights. |
103 | new_nodes->push_back(weights_node); |
104 | |
105 | // Create a 'no-op' mirror padding node that has no effect. |
106 | NodeDef pad_dims_node; |
107 | pad_dims_node.set_op("Const" ); |
108 | pad_dims_node.set_name(conv_node.name() + "_dummy_paddings" ); |
109 | SetNodeAttr("dtype" , DT_INT32, &pad_dims_node); |
110 | SetNodeTensorAttr<int32>("value" , {4, 2}, {0, 0, 0, 0, 0, 0, 0, 0}, |
111 | &pad_dims_node); |
112 | new_nodes->push_back(pad_dims_node); |
113 | |
114 | // Set up the new fused version of the convolution op. |
115 | NodeDef fused_conv; |
116 | fused_conv.set_op("FusedResizeAndPadConv2D" ); |
117 | fused_conv.set_name(match.node.name()); |
118 | AddNodeInput(resize_node.input(0), &fused_conv); |
119 | AddNodeInput(resize_node.input(1), &fused_conv); |
120 | AddNodeInput(pad_dims_node.name(), &fused_conv); |
121 | AddNodeInput(conv_node.input(1), &fused_conv); |
122 | CopyNodeAttr(resize_node, "align_corners" , "resize_align_corners" , |
123 | &fused_conv); |
124 | SetNodeAttr("mode" , "REFLECT" , &fused_conv); |
125 | CopyNodeAttr(conv_node, "T" , "T" , &fused_conv); |
126 | CopyNodeAttr(conv_node, "padding" , "padding" , &fused_conv); |
127 | CopyNodeAttr(conv_node, "strides" , "strides" , &fused_conv); |
128 | new_nodes->push_back(fused_conv); |
129 | |
130 | return OkStatus(); |
131 | }, |
132 | {}, &replaced_graph_def)); |
133 | *output_graph_def = replaced_graph_def; |
134 | return OkStatus(); |
135 | } |
136 | |
137 | Status FusePadAndConv(const GraphDef& input_graph_def, |
138 | const TransformFuncContext& context, |
139 | GraphDef* output_graph_def) { |
140 | GraphDef replaced_graph_def; |
141 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
142 | input_graph_def, // clang-format off |
143 | {"Conv2D" , |
144 | { |
145 | {"MirrorPad" , |
146 | { |
147 | {"*" }, |
148 | {"*" }, |
149 | } |
150 | }, |
151 | {"*" } |
152 | } |
153 | }, // clang-format on |
154 | [](const NodeMatch& match, const std::set<string>& input_nodes, |
155 | const std::set<string>& output_nodes, |
156 | std::vector<NodeDef>* new_nodes) { |
157 | // Find all the nodes we expect in the subgraph. |
158 | const NodeDef& conv_node = match.node; |
159 | CHECK_EQ("Conv2D" , conv_node.op()); |
160 | const NodeDef& mirror_pad_node = match.inputs[0].node; |
161 | CHECK_EQ("MirrorPad" , mirror_pad_node.op()); |
162 | const NodeDef& weights_node = match.inputs[1].node; |
163 | const NodeDef& input_node = match.inputs[0].inputs[0].node; |
164 | const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node; |
165 | |
166 | // We'll be reusing the old weights and pad dimensions. |
167 | new_nodes->push_back(weights_node); |
168 | new_nodes->push_back(input_node); |
169 | new_nodes->push_back(pad_dims_node); |
170 | |
171 | // Set up the new fused version of the convolution op. |
172 | NodeDef fused_conv; |
173 | fused_conv.set_op("FusedPadConv2D" ); |
174 | fused_conv.set_name(match.node.name()); |
175 | AddNodeInput(mirror_pad_node.input(0), &fused_conv); |
176 | AddNodeInput(mirror_pad_node.input(1), &fused_conv); |
177 | AddNodeInput(conv_node.input(1), &fused_conv); |
178 | CopyNodeAttr(mirror_pad_node, "mode" , "mode" , &fused_conv); |
179 | CopyNodeAttr(conv_node, "T" , "T" , &fused_conv); |
180 | CopyNodeAttr(conv_node, "padding" , "padding" , &fused_conv); |
181 | CopyNodeAttr(conv_node, "strides" , "strides" , &fused_conv); |
182 | new_nodes->push_back(fused_conv); |
183 | |
184 | return OkStatus(); |
185 | }, |
186 | {}, &replaced_graph_def)); |
187 | *output_graph_def = replaced_graph_def; |
188 | return OkStatus(); |
189 | } |
190 | |
191 | REGISTER_GRAPH_TRANSFORM("fuse_resize_pad_and_conv" , FuseResizePadAndConv); |
192 | |
193 | REGISTER_GRAPH_TRANSFORM("fuse_resize_and_conv" , FuseResizeAndConv); |
194 | |
195 | REGISTER_GRAPH_TRANSFORM("fuse_pad_and_conv" , FusePadAndConv); |
196 | |
197 | } // namespace graph_transforms |
198 | } // namespace tensorflow |
199 | |