/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/example/example_parser_configuration.h"

#include <vector>

#include "tensorflow/core/example/feature.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"

namespace tensorflow {
32Status FindNodeIndexByName(const tensorflow::GraphDef& graph,
33 const string& node_name, int* node_idx) {
34 for (int i = 0; i < graph.node_size(); ++i) {
35 const auto& node = graph.node(i);
36 if (node.name() == node_name) {
37 *node_idx = i;
38 return OkStatus();
39 }
40 }
41 return errors::InvalidArgument(node_name, " not found in GraphDef");
42}
44Status ExtractExampleParserConfiguration(
45 const tensorflow::GraphDef& graph, const string& node_name,
46 tensorflow::Session* session,
47 std::vector<FixedLenFeature>* fixed_len_features,
48 std::vector<VarLenFeature>* var_len_features) {
49 int node_idx;
50 TF_RETURN_IF_ERROR(FindNodeIndexByName(graph, node_name, &node_idx));
51
52 const auto& node = graph.node(node_idx);
53 if (node.op() != "ParseExample") {
54 return errors::InvalidArgument(node_name, " node is not a ParseExample op");
55 }
56
57 auto& attr_map = node.attr();
58 auto num_sparse = attr_map.at("Nsparse").i();
59 auto num_dense = attr_map.at("Ndense").i();
60 fixed_len_features->resize(num_dense);
61 var_len_features->resize(num_sparse);
62
63 auto tdense = attr_map.at("Tdense");
64 auto dense_shapes = attr_map.at("dense_shapes");
65 auto sparse_types = attr_map.at("sparse_types");
66
67 // Consistency check attributes.
68 if (tdense.list().type_size() != num_dense) {
69 return errors::InvalidArgument("Node attr Tdense has ",
70 tdense.list().type_size(),
71 " elements != Ndense attr: ", num_dense);
72 }
73
74 if (dense_shapes.list().shape_size() != num_dense) {
75 return errors::InvalidArgument("Node attr dense_shapes has ",
76 dense_shapes.list().shape_size(),
77 " elements != Ndense attr: ", num_dense);
78 }
79
80 if (sparse_types.list().type_size() != num_sparse) {
81 return errors::InvalidArgument("Node attr sparse_types has ",
82 sparse_types.list().type_size(),
83 " elements != NSparse attr: ", num_sparse);
84 }
85
86 for (int i = 0; i < tdense.list().type_size(); ++i) {
87 (*fixed_len_features)[i].dtype = tdense.list().type(i);
88 // Convert TensorShapeProto to TensorShape.
89 (*fixed_len_features)[i].shape = TensorShape(dense_shapes.list().shape(i));
90 }
91
92 for (int i = 0; i < sparse_types.list().type_size(); ++i) {
93 (*var_len_features)[i].dtype = sparse_types.list().type(i);
94 }
95
96 // We must fetch the configuration input tensors to the ParseExample op.
97 // Skipping index = 0, which is the serialized proto input.
98 std::vector<string> fetch_names(node.input_size() - 1);
99 for (int i = 1; i < node.input_size(); ++i) {
100 fetch_names[i - 1] = node.input(i);
101 }
102
103 std::vector<Tensor> op_input_tensors;
104
105 TF_RETURN_IF_ERROR(session->Run({}, // no_inputs,
106 fetch_names, {}, // no target_node_names,
107 &op_input_tensors));
108
109 // The input tensors are laid out sequentially in a flat manner.
110 // Here are the various start offsets.
111 int sparse_keys_start = 1;
112 int dense_keys_start = sparse_keys_start + num_sparse;
113 int dense_defaults_start = dense_keys_start + num_dense;
114
115 for (int i = 0; i < num_sparse; ++i) {
116 int input_idx = sparse_keys_start + i;
117 (*var_len_features)[i].key =
118 op_input_tensors[input_idx].scalar<tstring>()();
119 }
120
121 for (int i = 0; i < num_dense; ++i) {
122 FixedLenFeature& config = (*fixed_len_features)[i];
123 int dense_keys_offset = dense_keys_start + i;
124 config.key = op_input_tensors[dense_keys_offset].scalar<tstring>()();
125
126 int defaults_offset = dense_defaults_start + i;
127 config.default_value = op_input_tensors[defaults_offset];
128 }
129
130 // The output tensors are laid out sequentially in a flat manner.
131 // Here are the various start offsets.
132 int sparse_indices_output_start = 0;
133 int sparse_values_output_start = sparse_indices_output_start + num_sparse;
134 int sparse_shapes_output_start = sparse_values_output_start + num_sparse;
135 int dense_values_output_start = sparse_shapes_output_start + num_sparse;
136
137 string node_output_prefix = strings::StrCat(node_name, ":");
138
139 for (int i = 0; i < num_sparse; ++i) {
140 VarLenFeature& config = (*var_len_features)[i];
141
142 int indices_offset = sparse_indices_output_start + i;
143 config.indices_output_tensor_name =
144 strings::StrCat(node_output_prefix, indices_offset);
145
146 int values_offset = sparse_values_output_start + i;
147 config.values_output_tensor_name =
148 strings::StrCat(node_output_prefix, values_offset);
149
150 int shapes_offset = sparse_shapes_output_start + i;
151 config.shapes_output_tensor_name =
152 strings::StrCat(node_output_prefix, shapes_offset);
153 }
154
155 for (int i = 0; i < num_dense; ++i) {
156 int output_idx = dense_values_output_start + i;
157 (*fixed_len_features)[i].values_output_tensor_name =
158 strings::StrCat(node_output_prefix, output_idx);
159 }
160 return OkStatus();
161}
163Status ExampleParserConfigurationProtoToFeatureVectors(
164 const ExampleParserConfiguration& config_proto,
165 std::vector<FixedLenFeature>* fixed_len_features,
166 std::vector<VarLenFeature>* var_len_features) {
167 const auto& feature_map = config_proto.feature_map();
168 for (auto it = feature_map.cbegin(); it != feature_map.cend(); ++it) {
169 string key = it->first;
170 const auto& config = it->second;
171 if (config.has_fixed_len_feature()) {
172 const auto& fixed_config = config.fixed_len_feature();
173 FixedLenFeature f;
174 f.key = key;
175 f.dtype = fixed_config.dtype();
176 f.shape = TensorShape(fixed_config.shape());
177 Tensor default_value(f.dtype, f.shape);
178 if (!default_value.FromProto(fixed_config.default_value())) {
179 return errors::InvalidArgument(
180 "Invalid default_value in config proto ",
181 fixed_config.default_value().DebugString());
182 }
183 f.default_value = default_value;
184 f.values_output_tensor_name = fixed_config.values_output_tensor_name();
185 fixed_len_features->push_back(f);
186 } else {
187 const auto& var_len_config = config.var_len_feature();
188 VarLenFeature v;
189 v.key = key;
190 v.dtype = var_len_config.dtype();
191 v.values_output_tensor_name = var_len_config.values_output_tensor_name();
192 v.indices_output_tensor_name =
193 var_len_config.indices_output_tensor_name();
194 v.shapes_output_tensor_name = var_len_config.shapes_output_tensor_name();
195 var_len_features->push_back(v);
196 }
197 }
198 return OkStatus();
199}

}  // namespace tensorflow