1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
17 | #include "tensorflow/core/graph/node_builder.h" |
18 | #include "tensorflow/core/graph/subgraph.h" |
19 | #include "tensorflow/core/platform/init_main.h" |
20 | #include "tensorflow/core/public/session.h" |
21 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
22 | |
23 | namespace tensorflow { |
24 | namespace graph_transforms { |
25 | |
26 | Status FlattenAtrousConv(const GraphDef& input_graph_def, |
27 | const TransformFuncContext& context, |
28 | GraphDef* output_graph_def) { |
29 | GraphDef replaced_graph_def; |
30 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
31 | input_graph_def, // clang-format off |
32 | {"BatchToSpaceND" , |
33 | { |
34 | {"Conv2D|DepthwiseConv2dNative" , |
35 | { |
36 | {"SpaceToBatchND" , |
37 | { |
38 | {"*" }, // Input to the flattened op. |
39 | {"*" }, // block_shape |
40 | {"*" } // paddings |
41 | } |
42 | }, |
43 | {"*" } // filter |
44 | } |
45 | }, |
46 | {"*" }, // block_shape |
47 | {"*" } // crops |
48 | } |
49 | }, // clang-format on |
50 | [](const NodeMatch& match, const std::set<string>& input_nodes, |
51 | const std::set<string>& output_nodes, |
52 | std::vector<NodeDef>* new_nodes) { |
53 | // Find all the nodes we expect in the subgraph. |
54 | const NodeDef& batch_to_space_node = match.node; |
55 | const NodeDef& conv_node = match.inputs[0].node; |
56 | const NodeDef& filter_node = match.inputs[0].inputs[1].node; |
57 | const NodeDef& input_node = match.inputs[0].inputs[0].inputs[0].node; |
58 | const NodeDef& space_to_batch_block_shape_node = |
59 | match.inputs[0].inputs[0].inputs[1].node; |
60 | |
61 | // The atrous rate value is inferred from the block shape. |
62 | Tensor block_shape = |
63 | GetNodeTensorAttr(space_to_batch_block_shape_node, "value" ); |
64 | const int32_t block_height = block_shape.flat<int32>()(0); |
65 | const int32_t block_width = block_shape.flat<int32>()(1); |
66 | |
67 | // Compute the upsampled filter. |
68 | const Tensor& filter = GetNodeTensorAttr(filter_node, "value" ); |
69 | const int32_t filter_height = filter.dim_size(0); |
70 | const int32_t filter_width = filter.dim_size(1); |
71 | const int32_t in_channels = filter.dim_size(2); |
72 | const int32_t out_channels = filter.dim_size(3); |
73 | |
74 | const int32_t upsampled_filter_height = |
75 | (filter_height - 1) * block_height + 1; |
76 | const int32_t upsampled_filter_width = |
77 | (filter_width - 1) * block_width + 1; |
78 | Tensor upsampled_filter( |
79 | DT_FLOAT, |
80 | TensorShape({upsampled_filter_height, upsampled_filter_width, |
81 | in_channels, out_channels})); |
82 | |
83 | auto filter_eigen = filter.tensor<float, 4>(); |
84 | auto upsampled_filter_eigen = upsampled_filter.tensor<float, 4>(); |
85 | |
86 | upsampled_filter_eigen.setZero(); |
87 | for (int h = 0; h < filter_height; ++h) { |
88 | for (int w = 0; w < filter_width; ++w) { |
89 | for (int c_in = 0; c_in < in_channels; ++c_in) { |
90 | for (int c_out = 0; c_out < out_channels; ++c_out) { |
91 | upsampled_filter_eigen(block_height * h, block_width * w, c_in, |
92 | c_out) = filter_eigen(h, w, c_in, c_out); |
93 | } |
94 | } |
95 | } |
96 | } |
97 | |
98 | NodeDef upsampled_filter_node; |
99 | upsampled_filter_node.set_op("Const" ); |
100 | upsampled_filter_node.set_name(filter_node.name()); |
101 | SetNodeAttr("dtype" , DT_FLOAT, &upsampled_filter_node); |
102 | SetNodeTensorAttr<float>("value" , upsampled_filter, |
103 | &upsampled_filter_node); |
104 | |
105 | // Set up the new flattened version of the convolution op. |
106 | NodeDef flattened_conv_node; |
107 | |
108 | flattened_conv_node.set_name(batch_to_space_node.name()); |
109 | flattened_conv_node.set_op(conv_node.op()); |
110 | flattened_conv_node.set_device(conv_node.device()); |
111 | |
112 | AddNodeInput(input_node.name(), &flattened_conv_node); |
113 | AddNodeInput(upsampled_filter_node.name(), &flattened_conv_node); |
114 | |
115 | CopyNodeAttr(conv_node, "T" , "T" , &flattened_conv_node); |
116 | CopyNodeAttr(conv_node, "strides" , "strides" , &flattened_conv_node); |
117 | SetNodeAttr("padding" , "SAME" , &flattened_conv_node); |
118 | CopyNodeAttr(conv_node, "data_format" , "data_format" , |
119 | &flattened_conv_node); |
120 | |
121 | if (conv_node.op() == "Conv2D" ) { |
122 | CopyNodeAttr(conv_node, "use_cudnn_on_gpu" , "use_cudnn_on_gpu" , |
123 | &flattened_conv_node); |
124 | } |
125 | |
126 | new_nodes->push_back(input_node); |
127 | new_nodes->push_back(upsampled_filter_node); |
128 | new_nodes->push_back(flattened_conv_node); |
129 | |
130 | return OkStatus(); |
131 | }, |
132 | {}, &replaced_graph_def)); |
133 | *output_graph_def = replaced_graph_def; |
134 | return OkStatus(); |
135 | } |
136 | |
// Associates FlattenAtrousConv with the transform name "flatten_atrous_conv".
// NOTE(review): presumably this is the name used to select the transform from
// the graph transform tool - confirm against REGISTER_GRAPH_TRANSFORM.
REGISTER_GRAPH_TRANSFORM("flatten_atrous_conv" , FlattenAtrousConv);
138 | |
139 | } // namespace graph_transforms |
140 | } // namespace tensorflow |
141 | |