1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/example/example_parser_configuration.h" |
16 | |
17 | #include <vector> |
18 | |
19 | #include "tensorflow/core/example/feature.pb.h" |
20 | #include "tensorflow/core/framework/attr_value.pb.h" |
21 | #include "tensorflow/core/framework/node_def.pb.h" |
22 | #include "tensorflow/core/framework/numeric_op.h" |
23 | #include "tensorflow/core/framework/register_types.h" |
24 | #include "tensorflow/core/framework/tensor.pb.h" |
25 | #include "tensorflow/core/lib/core/errors.h" |
26 | #include "tensorflow/core/lib/strings/strcat.h" |
27 | #include "tensorflow/core/platform/logging.h" |
28 | #include "tensorflow/core/platform/protobuf.h" |
29 | |
30 | namespace tensorflow { |
31 | |
32 | Status FindNodeIndexByName(const tensorflow::GraphDef& graph, |
33 | const string& node_name, int* node_idx) { |
34 | for (int i = 0; i < graph.node_size(); ++i) { |
35 | const auto& node = graph.node(i); |
36 | if (node.name() == node_name) { |
37 | *node_idx = i; |
38 | return OkStatus(); |
39 | } |
40 | } |
41 | return errors::InvalidArgument(node_name, " not found in GraphDef" ); |
42 | } |
43 | |
44 | Status ( |
45 | const tensorflow::GraphDef& graph, const string& node_name, |
46 | tensorflow::Session* session, |
47 | std::vector<FixedLenFeature>* fixed_len_features, |
48 | std::vector<VarLenFeature>* var_len_features) { |
49 | int node_idx; |
50 | TF_RETURN_IF_ERROR(FindNodeIndexByName(graph, node_name, &node_idx)); |
51 | |
52 | const auto& node = graph.node(node_idx); |
53 | if (node.op() != "ParseExample" ) { |
54 | return errors::InvalidArgument(node_name, " node is not a ParseExample op" ); |
55 | } |
56 | |
57 | auto& attr_map = node.attr(); |
58 | auto num_sparse = attr_map.at("Nsparse" ).i(); |
59 | auto num_dense = attr_map.at("Ndense" ).i(); |
60 | fixed_len_features->resize(num_dense); |
61 | var_len_features->resize(num_sparse); |
62 | |
63 | auto tdense = attr_map.at("Tdense" ); |
64 | auto dense_shapes = attr_map.at("dense_shapes" ); |
65 | auto sparse_types = attr_map.at("sparse_types" ); |
66 | |
67 | // Consistency check attributes. |
68 | if (tdense.list().type_size() != num_dense) { |
69 | return errors::InvalidArgument("Node attr Tdense has " , |
70 | tdense.list().type_size(), |
71 | " elements != Ndense attr: " , num_dense); |
72 | } |
73 | |
74 | if (dense_shapes.list().shape_size() != num_dense) { |
75 | return errors::InvalidArgument("Node attr dense_shapes has " , |
76 | dense_shapes.list().shape_size(), |
77 | " elements != Ndense attr: " , num_dense); |
78 | } |
79 | |
80 | if (sparse_types.list().type_size() != num_sparse) { |
81 | return errors::InvalidArgument("Node attr sparse_types has " , |
82 | sparse_types.list().type_size(), |
83 | " elements != NSparse attr: " , num_sparse); |
84 | } |
85 | |
86 | for (int i = 0; i < tdense.list().type_size(); ++i) { |
87 | (*fixed_len_features)[i].dtype = tdense.list().type(i); |
88 | // Convert TensorShapeProto to TensorShape. |
89 | (*fixed_len_features)[i].shape = TensorShape(dense_shapes.list().shape(i)); |
90 | } |
91 | |
92 | for (int i = 0; i < sparse_types.list().type_size(); ++i) { |
93 | (*var_len_features)[i].dtype = sparse_types.list().type(i); |
94 | } |
95 | |
96 | // We must fetch the configuration input tensors to the ParseExample op. |
97 | // Skipping index = 0, which is the serialized proto input. |
98 | std::vector<string> fetch_names(node.input_size() - 1); |
99 | for (int i = 1; i < node.input_size(); ++i) { |
100 | fetch_names[i - 1] = node.input(i); |
101 | } |
102 | |
103 | std::vector<Tensor> op_input_tensors; |
104 | |
105 | TF_RETURN_IF_ERROR(session->Run({}, // no_inputs, |
106 | fetch_names, {}, // no target_node_names, |
107 | &op_input_tensors)); |
108 | |
109 | // The input tensors are laid out sequentially in a flat manner. |
110 | // Here are the various start offsets. |
111 | int sparse_keys_start = 1; |
112 | int dense_keys_start = sparse_keys_start + num_sparse; |
113 | int dense_defaults_start = dense_keys_start + num_dense; |
114 | |
115 | for (int i = 0; i < num_sparse; ++i) { |
116 | int input_idx = sparse_keys_start + i; |
117 | (*var_len_features)[i].key = |
118 | op_input_tensors[input_idx].scalar<tstring>()(); |
119 | } |
120 | |
121 | for (int i = 0; i < num_dense; ++i) { |
122 | FixedLenFeature& config = (*fixed_len_features)[i]; |
123 | int dense_keys_offset = dense_keys_start + i; |
124 | config.key = op_input_tensors[dense_keys_offset].scalar<tstring>()(); |
125 | |
126 | int defaults_offset = dense_defaults_start + i; |
127 | config.default_value = op_input_tensors[defaults_offset]; |
128 | } |
129 | |
130 | // The output tensors are laid out sequentially in a flat manner. |
131 | // Here are the various start offsets. |
132 | int sparse_indices_output_start = 0; |
133 | int sparse_values_output_start = sparse_indices_output_start + num_sparse; |
134 | int sparse_shapes_output_start = sparse_values_output_start + num_sparse; |
135 | int dense_values_output_start = sparse_shapes_output_start + num_sparse; |
136 | |
137 | string node_output_prefix = strings::StrCat(node_name, ":" ); |
138 | |
139 | for (int i = 0; i < num_sparse; ++i) { |
140 | VarLenFeature& config = (*var_len_features)[i]; |
141 | |
142 | int indices_offset = sparse_indices_output_start + i; |
143 | config.indices_output_tensor_name = |
144 | strings::StrCat(node_output_prefix, indices_offset); |
145 | |
146 | int values_offset = sparse_values_output_start + i; |
147 | config.values_output_tensor_name = |
148 | strings::StrCat(node_output_prefix, values_offset); |
149 | |
150 | int shapes_offset = sparse_shapes_output_start + i; |
151 | config.shapes_output_tensor_name = |
152 | strings::StrCat(node_output_prefix, shapes_offset); |
153 | } |
154 | |
155 | for (int i = 0; i < num_dense; ++i) { |
156 | int output_idx = dense_values_output_start + i; |
157 | (*fixed_len_features)[i].values_output_tensor_name = |
158 | strings::StrCat(node_output_prefix, output_idx); |
159 | } |
160 | return OkStatus(); |
161 | } |
162 | |
163 | Status ExampleParserConfigurationProtoToFeatureVectors( |
164 | const ExampleParserConfiguration& config_proto, |
165 | std::vector<FixedLenFeature>* fixed_len_features, |
166 | std::vector<VarLenFeature>* var_len_features) { |
167 | const auto& feature_map = config_proto.feature_map(); |
168 | for (auto it = feature_map.cbegin(); it != feature_map.cend(); ++it) { |
169 | string key = it->first; |
170 | const auto& config = it->second; |
171 | if (config.has_fixed_len_feature()) { |
172 | const auto& fixed_config = config.fixed_len_feature(); |
173 | FixedLenFeature f; |
174 | f.key = key; |
175 | f.dtype = fixed_config.dtype(); |
176 | f.shape = TensorShape(fixed_config.shape()); |
177 | Tensor default_value(f.dtype, f.shape); |
178 | if (!default_value.FromProto(fixed_config.default_value())) { |
179 | return errors::InvalidArgument( |
180 | "Invalid default_value in config proto " , |
181 | fixed_config.default_value().DebugString()); |
182 | } |
183 | f.default_value = default_value; |
184 | f.values_output_tensor_name = fixed_config.values_output_tensor_name(); |
185 | fixed_len_features->push_back(f); |
186 | } else { |
187 | const auto& var_len_config = config.var_len_feature(); |
188 | VarLenFeature v; |
189 | v.key = key; |
190 | v.dtype = var_len_config.dtype(); |
191 | v.values_output_tensor_name = var_len_config.values_output_tensor_name(); |
192 | v.indices_output_tensor_name = |
193 | var_len_config.indices_output_tensor_name(); |
194 | v.shapes_output_tensor_name = var_len_config.shapes_output_tensor_name(); |
195 | var_len_features->push_back(v); |
196 | } |
197 | } |
198 | return OkStatus(); |
199 | } |
200 | |
201 | } // namespace tensorflow |
202 | |