/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

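// This file implements graph transforms that fuse image resizing and/or
// mirror padding ops into the Conv2D ops that consume them, replacing the
// matched subgraphs with the fused FusedResizeAndPadConv2D and FusedPadConv2D
// ops.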
#include "tensorflow/core/common_runtime/constant_folding.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/subgraph.h"
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {

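// Fuses a ResizeBilinear -> MirrorPad -> Conv2D chain into a single
// FusedResizeAndPadConv2D node, forwarding the resize size, pad dimensions,
// and convolution weights as inputs and copying the relevant attributes from
// the original ops.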
Status FuseResizePadAndConv(const GraphDef& input_graph_def,
                            const TransformFuncContext& context,
                            GraphDef* output_graph_def) {
  GraphDef replaced_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Conv2D",
          {
              {"MirrorPad",
                  {
                      {"ResizeBilinear"},
                      {"*"}
                  }
              },
              {"*"}
          }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        // Find all the nodes we expect in the subgraph.
        const NodeDef& conv_node = match.node;
        const NodeDef& mirror_pad_node = match.inputs[0].node;
        const NodeDef& weights_node = match.inputs[1].node;
        const NodeDef& resize_node = match.inputs[0].inputs[0].node;
        const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node;

        // We'll be reusing the old weights and pad dimensions.
        new_nodes->push_back(weights_node);
        new_nodes->push_back(pad_dims_node);

        // Set up the new fused version of the convolution op.
        NodeDef fused_conv;
        fused_conv.set_op("FusedResizeAndPadConv2D");
        fused_conv.set_name(match.node.name());
        AddNodeInput(resize_node.input(0), &fused_conv);
        AddNodeInput(resize_node.input(1), &fused_conv);
        AddNodeInput(mirror_pad_node.input(1), &fused_conv);
        AddNodeInput(conv_node.input(1), &fused_conv);
        CopyNodeAttr(resize_node, "align_corners", "resize_align_corners",
                     &fused_conv);
        CopyNodeAttr(mirror_pad_node, "mode", "mode", &fused_conv);
        CopyNodeAttr(conv_node, "T", "T", &fused_conv);
        CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
        CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
        new_nodes->push_back(fused_conv);

        return OkStatus();
      },
      {}, &replaced_graph_def));
  *output_graph_def = replaced_graph_def;
  return OkStatus();
}

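// Fuses a ResizeBilinear -> Conv2D pair into a FusedResizeAndPadConv2D node.
// Because there is no padding op in the matched pattern, a constant 4x2
// tensor of zeros is synthesized as the paddings input so the fused
// MirrorPad stage leaves the resized tensor unchanged.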
Status FuseResizeAndConv(const GraphDef& input_graph_def,
                         const TransformFuncContext& context,
                         GraphDef* output_graph_def) {
  GraphDef replaced_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Conv2D",
          {
              {"ResizeBilinear"},
              {"*"}
          }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        // Find all the nodes we expect in the subgraph.
        const NodeDef& conv_node = match.node;
        const NodeDef& resize_node = match.inputs[0].node;
        const NodeDef& weights_node = match.inputs[1].node;

        // We'll be reusing the old weights.
        new_nodes->push_back(weights_node);

        // Create an all-zero paddings node so the fused mirror padding
        // has no effect.
        NodeDef pad_dims_node;
        pad_dims_node.set_op("Const");
        pad_dims_node.set_name(conv_node.name() + "_dummy_paddings");
        SetNodeAttr("dtype", DT_INT32, &pad_dims_node);
        SetNodeTensorAttr<int32>("value", {4, 2}, {0, 0, 0, 0, 0, 0, 0, 0},
                                 &pad_dims_node);
        new_nodes->push_back(pad_dims_node);

        // Set up the new fused version of the convolution op.
        NodeDef fused_conv;
        fused_conv.set_op("FusedResizeAndPadConv2D");
        fused_conv.set_name(match.node.name());
        AddNodeInput(resize_node.input(0), &fused_conv);
        AddNodeInput(resize_node.input(1), &fused_conv);
        AddNodeInput(pad_dims_node.name(), &fused_conv);
        AddNodeInput(conv_node.input(1), &fused_conv);
        CopyNodeAttr(resize_node, "align_corners", "resize_align_corners",
                     &fused_conv);
        SetNodeAttr("mode", "REFLECT", &fused_conv);
        CopyNodeAttr(conv_node, "T", "T", &fused_conv);
        CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
        CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
        new_nodes->push_back(fused_conv);

        return OkStatus();
      },
      {}, &replaced_graph_def));
  *output_graph_def = replaced_graph_def;
  return OkStatus();
}

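// Fuses a MirrorPad -> Conv2D pair into a FusedPadConv2D node, keeping the
// original input, pad dimensions, and weights as its inputs and copying the
// padding mode, type, padding, and stride attributes across.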
Status FusePadAndConv(const GraphDef& input_graph_def,
                      const TransformFuncContext& context,
                      GraphDef* output_graph_def) {
  GraphDef replaced_graph_def;
  TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
      input_graph_def,  // clang-format off
      {"Conv2D",
          {
              {"MirrorPad",
                  {
                      {"*"},
                      {"*"},
                  }
              },
              {"*"}
          }
      },  // clang-format on
      [](const NodeMatch& match, const std::set<string>& input_nodes,
         const std::set<string>& output_nodes,
         std::vector<NodeDef>* new_nodes) {
        // Find all the nodes we expect in the subgraph.
        const NodeDef& conv_node = match.node;
        CHECK_EQ("Conv2D", conv_node.op());
        const NodeDef& mirror_pad_node = match.inputs[0].node;
        CHECK_EQ("MirrorPad", mirror_pad_node.op());
        const NodeDef& weights_node = match.inputs[1].node;
        const NodeDef& input_node = match.inputs[0].inputs[0].node;
        const NodeDef& pad_dims_node = match.inputs[0].inputs[1].node;

        // We'll be reusing the old weights, input, and pad dimensions.
        new_nodes->push_back(weights_node);
        new_nodes->push_back(input_node);
        new_nodes->push_back(pad_dims_node);

        // Set up the new fused version of the convolution op.
        NodeDef fused_conv;
        fused_conv.set_op("FusedPadConv2D");
        fused_conv.set_name(match.node.name());
        AddNodeInput(mirror_pad_node.input(0), &fused_conv);
        AddNodeInput(mirror_pad_node.input(1), &fused_conv);
        AddNodeInput(conv_node.input(1), &fused_conv);
        CopyNodeAttr(mirror_pad_node, "mode", "mode", &fused_conv);
        CopyNodeAttr(conv_node, "T", "T", &fused_conv);
        CopyNodeAttr(conv_node, "padding", "padding", &fused_conv);
        CopyNodeAttr(conv_node, "strides", "strides", &fused_conv);
        new_nodes->push_back(fused_conv);

        return OkStatus();
      },
      {}, &replaced_graph_def));
  *output_graph_def = replaced_graph_def;
  return OkStatus();
}

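// The transforms below are applied by name through the transform_graph tool.
// A typical invocation looks like the following (graph paths and node names
// are illustrative):
//
//   bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
//     --in_graph=frozen_graph.pb \
//     --out_graph=fused_graph.pb \
//     --inputs='input' \
//     --outputs='output' \
//     --transforms='fuse_pad_and_conv fuse_resize_and_conv'
//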
REGISTER_GRAPH_TRANSFORM("fuse_resize_pad_and_conv", FuseResizePadAndConv);

REGISTER_GRAPH_TRANSFORM("fuse_resize_and_conv", FuseResizeAndConv);

REGISTER_GRAPH_TRANSFORM("fuse_pad_and_conv", FusePadAndConv);

}  // namespace graph_transforms
}  // namespace tensorflow