1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/common_runtime/constant_folding.h"
17#include "tensorflow/core/common_runtime/graph_constructor.h"
18#include "tensorflow/core/graph/node_builder.h"
19#include "tensorflow/core/graph/subgraph.h"
20#include "tensorflow/core/platform/init_main.h"
21#include "tensorflow/core/public/session.h"
22#include "tensorflow/tools/graph_transforms/fold_constants_lib.h"
23#include "tensorflow/tools/graph_transforms/transform_utils.h"
24
25namespace tensorflow {
26namespace graph_transforms {
27
28// Switch any ConcatV2 nodes to the v1 version, swapping the input order.
29Status BackportConcatV2Transform(const GraphDef& input_graph_def,
30 const TransformFuncContext& context,
31 GraphDef* output_graph_def) {
32 TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
33 input_graph_def, {"ConcatV2"},
34 [](const NodeMatch& match, const std::set<string>& input_nodes,
35 const std::set<string>& output_nodes,
36 std::vector<NodeDef>* new_nodes) {
37 const NodeDef& concat_v2_node = match.node;
38 NodeDef concat_node = concat_v2_node;
39 concat_node.set_op("Concat");
40 // The last input is inserted at the head of the inputs, because Concat
41 // expects the dimension as the first input (not the last as in
42 // ConcatV2).
43 concat_node.mutable_input()->Clear();
44 const string& dim_input =
45 concat_v2_node.input(concat_v2_node.input_size() - 1);
46 concat_node.add_input(dim_input);
47 for (int i = 0; i < (concat_v2_node.input_size() - 1); ++i) {
48 concat_node.add_input(concat_v2_node.input(i));
49 }
50 // Tidx attribute must be deleted because it's not used in Concat.
51 concat_node.mutable_attr()->erase("Tidx");
52 new_nodes->push_back(concat_node);
53 return OkStatus();
54 },
55 {true}, output_graph_def));
56
57 return OkStatus();
58}
59
60REGISTER_GRAPH_TRANSFORM("backport_concatv2", BackportConcatV2Transform);
61
62// Switch any TensorArrayV3 nodes to the v2 version, removing the second output.
63Status BackportTensorArrayV3Transform(const GraphDef& input_graph_def,
64 const TransformFuncContext& context,
65 GraphDef* output_graph_def) {
66 std::map<string, string> inputs_to_rename;
67 GraphDef replaced_graph_def;
68 TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
69 input_graph_def, {"TensorArrayV3|TensorArrayGradV3"},
70 [&inputs_to_rename](const NodeMatch& match,
71 const std::set<string>& input_nodes,
72 const std::set<string>& output_nodes,
73 std::vector<NodeDef>* new_nodes) {
74 const NodeDef& tensor_array_v3_node = match.node;
75
76 // All we need to do here is rename the op type, since the attributes
77 // remain the same.
78 NodeDef tensor_array_v2_node = tensor_array_v3_node;
79 if (tensor_array_v3_node.op() == "TensorArrayV3") {
80 tensor_array_v2_node.set_op("TensorArrayV2");
81 } else {
82 tensor_array_v2_node.set_op("TensorArrayGradV2");
83 }
84
85 // The v3 version has a second 'flow' output that's not present in v2,
86 // so substitute a dummy constant instead in any places that use it.
87 NodeDef replacement_flow_node;
88 replacement_flow_node.set_op("Const");
89 SetNodeAttr("dtype", DT_FLOAT, &replacement_flow_node);
90 replacement_flow_node.set_name(tensor_array_v3_node.name() +
91 "/replacement_flow_node");
92 Tensor replacement_flow_tensor(DT_FLOAT, {});
93 // I'm picking an arbitrary value for the gradient flow here, for lack
94 // of a better alternative.
95 replacement_flow_tensor.flat<float>()(0) = 1.0f;
96 SetNodeTensorAttr<float>("value", replacement_flow_tensor,
97 &replacement_flow_node);
98 inputs_to_rename[tensor_array_v3_node.name() + ":1"] =
99 replacement_flow_node.name();
100
101 new_nodes->push_back(tensor_array_v2_node);
102 new_nodes->push_back(replacement_flow_node);
103 return OkStatus();
104 },
105 {true}, &replaced_graph_def));
106 // Update the graph so that any nodes that referred to removed inputs now
107 // pull from the substitute constants we've added.
108 GraphDef renamed_graph_def;
109 TF_RETURN_IF_ERROR(RenameNodeInputs(replaced_graph_def, inputs_to_rename,
110 std::unordered_set<string>(),
111 &renamed_graph_def));
112 TF_RETURN_IF_ERROR(ReplaceMatchingOpTypes(
113 renamed_graph_def,
114 {"TensorArrayWriteV3|TensorArrayReadV3|TensorArrayGatherV3|"
115 "TensorArrayScatterV3|TensorArrayConcatV3|TensorArraySplitV3|"
116 "TensorArraySizeV3|TensorArrayCloseV3"},
117 [](const NodeMatch& match, const std::set<string>& input_nodes,
118 const std::set<string>& output_nodes,
119 std::vector<NodeDef>* new_nodes) {
120 const NodeDef& v3_node = match.node;
121 NodeDef v2_node = v3_node;
122 v2_node.set_op(v3_node.op().substr(0, v3_node.op().size() - 1) + "2");
123 new_nodes->push_back(v2_node);
124 return OkStatus();
125 },
126 {true}, output_graph_def));
127 return OkStatus();
128}
129
130REGISTER_GRAPH_TRANSFORM("backport_tensor_array_v3",
131 BackportTensorArrayV3Transform);
132
133} // namespace graph_transforms
134} // namespace tensorflow
135