1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/constant_folding.h" |
17 | #include "tensorflow/core/common_runtime/graph_constructor.h" |
18 | #include "tensorflow/core/graph/node_builder.h" |
19 | #include "tensorflow/core/graph/subgraph.h" |
20 | #include "tensorflow/core/platform/init_main.h" |
21 | #include "tensorflow/core/public/session.h" |
22 | #include "tensorflow/tools/graph_transforms/fold_constants_lib.h" |
23 | #include "tensorflow/tools/graph_transforms/transform_utils.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace graph_transforms { |
27 | |
28 | // Switch any ConcatV2 nodes to the v1 version, swapping the input order. |
29 | Status BackportConcatV2Transform(const GraphDef& input_graph_def, |
30 | const TransformFuncContext& context, |
31 | GraphDef* output_graph_def) { |
32 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
33 | input_graph_def, {"ConcatV2" }, |
34 | [](const NodeMatch& match, const std::set<string>& input_nodes, |
35 | const std::set<string>& output_nodes, |
36 | std::vector<NodeDef>* new_nodes) { |
37 | const NodeDef& concat_v2_node = match.node; |
38 | NodeDef concat_node = concat_v2_node; |
39 | concat_node.set_op("Concat" ); |
40 | // The last input is inserted at the head of the inputs, because Concat |
41 | // expects the dimension as the first input (not the last as in |
42 | // ConcatV2). |
43 | concat_node.mutable_input()->Clear(); |
44 | const string& dim_input = |
45 | concat_v2_node.input(concat_v2_node.input_size() - 1); |
46 | concat_node.add_input(dim_input); |
47 | for (int i = 0; i < (concat_v2_node.input_size() - 1); ++i) { |
48 | concat_node.add_input(concat_v2_node.input(i)); |
49 | } |
50 | // Tidx attribute must be deleted because it's not used in Concat. |
51 | concat_node.mutable_attr()->erase("Tidx" ); |
52 | new_nodes->push_back(concat_node); |
53 | return OkStatus(); |
54 | }, |
55 | {true}, output_graph_def)); |
56 | |
57 | return OkStatus(); |
58 | } |
59 | |
60 | REGISTER_GRAPH_TRANSFORM("backport_concatv2" , BackportConcatV2Transform); |
61 | |
62 | // Switch any TensorArrayV3 nodes to the v2 version, removing the second output. |
63 | Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def, |
64 | const TransformFuncContext& context, |
65 | GraphDef* output_graph_def) { |
66 | std::map<string, string> inputs_to_rename; |
67 | GraphDef replaced_graph_def; |
68 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
69 | input_graph_def, {"TensorArrayV3|TensorArrayGradV3" }, |
70 | [&inputs_to_rename](const NodeMatch& match, |
71 | const std::set<string>& input_nodes, |
72 | const std::set<string>& output_nodes, |
73 | std::vector<NodeDef>* new_nodes) { |
74 | const NodeDef& tensor_array_v3_node = match.node; |
75 | |
76 | // All we need to do here is rename the op type, since the attributes |
77 | // remain the same. |
78 | NodeDef tensor_array_v2_node = tensor_array_v3_node; |
79 | if (tensor_array_v3_node.op() == "TensorArrayV3" ) { |
80 | tensor_array_v2_node.set_op("TensorArrayV2" ); |
81 | } else { |
82 | tensor_array_v2_node.set_op("TensorArrayGradV2" ); |
83 | } |
84 | |
85 | // The v3 version has a second 'flow' output that's not present in v2, |
86 | // so substitute a dummy constant instead in any places that use it. |
87 | NodeDef replacement_flow_node; |
88 | replacement_flow_node.set_op("Const" ); |
89 | SetNodeAttr("dtype" , DT_FLOAT, &replacement_flow_node); |
90 | replacement_flow_node.set_name(tensor_array_v3_node.name() + |
91 | "/replacement_flow_node" ); |
92 | Tensor replacement_flow_tensor(DT_FLOAT, {}); |
93 | // I'm picking an arbitrary value for the gradient flow here, for lack |
94 | // of a better alternative. |
95 | replacement_flow_tensor.flat<float>()(0) = 1.0f; |
96 | SetNodeTensorAttr<float>("value" , replacement_flow_tensor, |
97 | &replacement_flow_node); |
98 | inputs_to_rename[tensor_array_v3_node.name() + ":1" ] = |
99 | replacement_flow_node.name(); |
100 | |
101 | new_nodes->push_back(tensor_array_v2_node); |
102 | new_nodes->push_back(replacement_flow_node); |
103 | return OkStatus(); |
104 | }, |
105 | {true}, &replaced_graph_def)); |
106 | // Update the graph so that any nodes that referred to removed inputs now |
107 | // pull from the substitute constants we've added. |
108 | GraphDef renamed_graph_def; |
109 | TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename, |
110 | std::unordered_set<string>(), |
111 | &renamed_graph_def)); |
112 | TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes( |
113 | renamed_graph_def, |
114 | {"TensorArrayWriteV3|TensorArrayReadV3|TensorArrayGatherV3|" |
115 | "TensorArrayScatterV3|TensorArrayConcatV3|TensorArraySplitV3|" |
116 | "TensorArraySizeV3|TensorArrayCloseV3" }, |
117 | [](const NodeMatch& match, const std::set<string>& input_nodes, |
118 | const std::set<string>& output_nodes, |
119 | std::vector<NodeDef>* new_nodes) { |
120 | const NodeDef& v3_node = match.node; |
121 | NodeDef v2_node = v3_node; |
122 | v2_node.set_op(v3_node.op().substr(0, v3_node.op().size() - 1) + "2" ); |
123 | new_nodes->push_back(v2_node); |
124 | return OkStatus(); |
125 | }, |
126 | {true}, output_graph_def)); |
127 | return OkStatus(); |
128 | } |
129 | |
130 | REGISTER_GRAPH_TRANSFORM("backport_tensor_array_v3" , |
131 | BackportTensorArrayV3Transform); |
132 | |
133 | } // namespace graph_transforms |
134 | } // namespace tensorflow |
135 | |