1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ |
17 | #define TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ |
18 | |
19 | #include <string> |
20 | #include <unordered_set> |
21 | #include <vector> |
22 | |
23 | #include "tensorflow/core/example/example.pb.h" |
24 | #include "tensorflow/core/example/feature.pb.h" |
25 | #include "tensorflow/core/framework/allocator.h" |
26 | #include "tensorflow/core/framework/graph.pb.h" |
27 | #include "tensorflow/core/framework/partial_tensor_shape.h" |
28 | #include "tensorflow/core/framework/tensor.h" |
29 | #include "tensorflow/core/framework/types.h" |
30 | #include "tensorflow/core/lib/core/errors.h" |
31 | #include "tensorflow/core/platform/types.h" |
32 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
33 | |
34 | // This is a set of helper methods that will make it possible to share |
35 | // tensorflow::Example proto Tensor conversion code inside the ExampleParserOp |
36 | // OpKernel as well as in external code. |
37 | namespace tensorflow { |
38 | |
39 | // "Dense" feature configuration. |
40 | struct FixedLenFeature { |
41 | string key; |
42 | DataType dtype; |
43 | TensorShape shape; |
44 | Tensor default_value; |
45 | string values_output_tensor_name; |
46 | }; |
47 | |
48 | // "Sparse" feature configuration. |
49 | struct VarLenFeature { |
50 | string key; |
51 | DataType dtype; |
52 | string values_output_tensor_name; |
53 | string indices_output_tensor_name; |
54 | string shapes_output_tensor_name; |
55 | }; |
56 | |
// Converts one tensorflow::Example into dense and (temporary) sparse values.
//
// Given a single tensorflow::Example, with an optional example name
// at a particular index within a batch, and dense and sparse feature
// configurations from fixed_len_features, var_len_features, this method
// updates the dense value tensor and the sparse values temporary vector
// of tensors. The indexing of the output vectors corresponds 1:1 to the
// indexing of the feature configuration vectors.
//
// The fixed_len_features and var_len_features maps are assumed to
// have disjoint key fields from the Feature map in the tensorflow.Example
// proto.
//
// For each sparse feature, the sparse values temporary vector holds a
// tensor for each Example. Each tensor is either empty or filled, depending
// on whether the sparse feature value is set for the Example. This
// temporary structure is needed because we need to know the total number
// of filled elements in the batch to get the proper final sparse tensor
// shapes allocated. After the entire batch is processed,
// GetSparseTensorShapes can be used to calculate the final shapes and
// CopyIntoSparseTensor can be used to copy from the temporary vector
// into the final allocated tensors.
//
// Args:
//   example: the Example proto to convert.
//   name: optional name of the example.
//     NOTE(review): presumably only used for error reporting - confirm.
//   batch_index: index of `example` within the batch.
//   fixed_len_features: dense feature configurations.
//   var_len_features: sparse feature configurations.
//   dense_values: output; indexed like fixed_len_features.
//   sparse_values_temporary_vector: output; indexed like var_len_features,
//     each entry holding one tensor per Example.
//
// Returns a Status; the specific error conditions are defined by the
// implementation, which is not part of this header.
Status SingleExampleProtoToTensors(
    const Example& example, const string& name, const int batch_index,
    const std::vector<FixedLenFeature>& fixed_len_features,
    const std::vector<VarLenFeature>& var_len_features,
    std::vector<Tensor*>* dense_values,
    std::vector<std::vector<Tensor>>* sparse_values_temporary_vector);
83 | |
84 | // The shape of the indices and values tensors associated with a SparseTensor |
85 | // are dependent on the contents of the batch. |
86 | struct VarLenFeatureBatchShapes { |
87 | TensorShape indices_shape; |
88 | TensorShape values_shape; |
89 | int max_num_features; |
90 | }; |
91 | |
// Get the shapes of the sparse values and indices tensors for the batch,
// given how many of the tensors in the temporary sparse values vector
// are actually filled.
//
// Args:
//   var_len_feature: configuration of the sparse feature being sized.
//   sparse_values_tmp: temporary per-example tensors for that feature.
//   batch_size: number of examples in the batch.
//   output_shapes: output; receives the computed shapes.
//
// Returns a Status; error conditions are defined by the implementation.
Status GetSparseTensorShapes(const VarLenFeature& var_len_feature,
                             const std::vector<Tensor>& sparse_values_tmp,
                             const int batch_size,
                             VarLenFeatureBatchShapes* output_shapes);
99 | |
// A method to convert a batch of tensorflow::Example protos into output
// tensors. This method is useful if there already is a batch of deserialized
// Example protos in memory (such as a serving use-case) and we do not wish
// to incur an extraneous serialize/deserialize. It is intended
// as an outside-of-OpKernel compatible replacement for the functionality of
// ExampleParserOp. In a serving setting, this method could be used to produce
// a feed_dict of Tensors that could bypass the ExampleParserOp.
//
// Note that unlike SingleExampleProtoToTensors, output tensors are
// allocated using a provided Allocator within this method.
//
// Args:
//   examples: the batch of already-deserialized Example protos.
//   names: example names.
//     NOTE(review): presumably parallel to `examples` (or empty) - confirm.
//   fixed_len_features: dense feature configurations.
//   var_len_features: sparse feature configurations.
//   allocator: allocator used for the output tensors.
//   output_dense_values_tensor: output dense value tensors.
//   output_sparse_indices_tensor: output sparse indices tensors.
//   output_sparse_values_tensor: output sparse values tensors.
//   output_sparse_shapes_tensor: output sparse shapes tensors.
//
// Returns a Status; error conditions are defined by the implementation.
Status BatchExampleProtoToTensors(
    const std::vector<const Example*>& examples,
    const std::vector<string>& names,
    const std::vector<FixedLenFeature>& fixed_len_features,
    const std::vector<VarLenFeature>& var_len_features, Allocator* allocator,
    std::vector<Tensor>* output_dense_values_tensor,
    std::vector<Tensor>* output_sparse_indices_tensor,
    std::vector<Tensor>* output_sparse_values_tensor,
    std::vector<Tensor>* output_sparse_shapes_tensor);
119 | |
// Check that the given dtype is one that is compatible with
// tensorflow::Example protocol buffer feature values.
// NOTE(review): presumably returns a non-OK Status for an incompatible
// dtype - confirm against the implementation.
Status CheckValidType(const DataType& dtype);
123 | |
// Check that the provided Feature proto message's oneof value
// matches that of the provided dtype. The result is reported through the
// `match` output parameter.
// NOTE(review): presumably *match is true when the types agree, and the
// returned Status is only non-OK for an unsupported dtype - confirm against
// the implementation.
Status CheckTypesMatch(const Feature& feature, const DataType& dtype,
                       bool* match);
128 | |
// For a single Example, copy a dense feature value into the output
// dense value tensor `out` at the provided out_index offset.
//
// Args:
//   out_index: offset within `out` at which to write.
//   name: name of the example.
//     NOTE(review): presumably used for error messages only - confirm.
//   key: key of the feature being copied.
//   dtype: expected element type of the feature.
//   shape: expected shape of the feature.
//   feature: the Feature proto holding the source values.
//   out: output dense value tensor.
//
// Returns a Status; error conditions are defined by the implementation.
Status FeatureDenseCopy(const std::size_t out_index, const string& name,
                        const string& key, const DataType& dtype,
                        const TensorShape& shape, const Feature& feature,
                        Tensor* out);
135 | |
// Copy the values of the provided Tensor `in` into the output dense value
// tensor `out` at the provided out_index offset. `dtype` is the element type
// of the data being copied.
void RowDenseCopy(const std::size_t& out_index, const DataType& dtype,
                  const Tensor& in, Tensor* out);
140 | |
// For a single Example and a given sparse feature, return a temporary output
// Tensor suitable for being collected in the temporary sparse value vector.
//
// Args:
//   batch: index of the example within the batch.
//   key: key of the sparse feature.
//   dtype: element type of the sparse feature.
//   feature: the Feature proto holding the source values.
Tensor FeatureSparseCopy(const std::size_t batch, const string& key,
                         const DataType& dtype, const Feature& feature);
145 | |
// Copy a temporary Tensor into the final sparse indices and values
// tensors at a given batch index and element offset. This method
// assumes that the indices/values Tensors have been properly allocated
// for the batch.
//
// Args:
//   in: temporary tensor to copy from.
//   batch: batch index of the example `in` belongs to.
//   offset: element offset into `indices` / `values` at which to write.
//   indices: output sparse indices tensor.
//   values: output sparse values tensor.
//
// NOTE(review): the return value is presumably the number of elements
// copied (so the caller can advance `offset`) - confirm against the
// implementation.
int64_t CopyIntoSparseTensor(const Tensor& in, const int batch,
                             const int64_t offset, Tensor* indices,
                             Tensor* values);
153 | |
// Check that each dense_shape has known rank and inner dimensions; and
// update variable_length (whether the outer dimension is None) and
// elements_per_stride for each dense_shape.
//
// Args:
//   dense_shapes: the (possibly partially known) dense shapes to inspect.
//   variable_length: output; one entry per dense shape.
//   elements_per_stride: output; one entry per dense shape.
//
// Returns a Status; presumably non-OK when a shape fails the check above.
Status GetDenseShapes(const std::vector<PartialTensorShape>& dense_shapes,
                      std::vector<bool>* variable_length,
                      std::vector<std::size_t>* elements_per_stride);
160 | |
161 | // Parses the attributes passed to ParseExample. |
162 | // REQUIRES: Init must be called after construction. |
163 | struct ParseExampleAttrs { |
164 | public: |
165 | template <typename ContextType> |
166 | Status Init(ContextType* ctx, int op_version = 1) { |
167 | TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_types" , &sparse_types)); |
168 | TF_RETURN_IF_ERROR(ctx->GetAttr("Tdense" , &dense_types)); |
169 | TF_RETURN_IF_ERROR(ctx->GetAttr("dense_shapes" , &dense_shapes)); |
170 | TF_RETURN_IF_ERROR( |
171 | GetDenseShapes(dense_shapes, &variable_length, &elements_per_stride)); |
172 | switch (op_version) { |
173 | case 1: |
174 | TF_RETURN_IF_ERROR(ctx->GetAttr("Nsparse" , &num_sparse)); |
175 | TF_RETURN_IF_ERROR(ctx->GetAttr("Ndense" , &num_dense)); |
176 | break; |
177 | case 2: |
178 | TF_RETURN_IF_ERROR( |
179 | ctx->GetAttr("ragged_value_types" , &ragged_value_types)); |
180 | TF_RETURN_IF_ERROR(ctx->GetAttr("num_sparse" , &num_sparse)); |
181 | TF_RETURN_IF_ERROR( |
182 | ctx->GetAttr("ragged_split_types" , &ragged_split_types)); |
183 | break; |
184 | default: |
185 | return errors::InvalidArgument("Unexpected op_version" , op_version); |
186 | } |
187 | return FinishInit(op_version); |
188 | } |
189 | |
190 | int64_t num_sparse; |
191 | int64_t num_dense; |
192 | int64_t num_ragged; |
193 | std::vector<DataType> sparse_types; |
194 | std::vector<DataType> dense_types; |
195 | std::vector<DataType> ragged_value_types; |
196 | std::vector<DataType> ragged_split_types; |
197 | std::vector<PartialTensorShape> dense_shapes; |
198 | std::vector<bool> variable_length; |
199 | std::vector<std::size_t> elements_per_stride; |
200 | |
201 | private: |
202 | Status FinishInit(int op_version); // for context-independent parts of Init. |
203 | }; |
204 | |
205 | // Parses the attributes passed to ParseSingleExample. |
206 | // REQUIRES: Init must be called after construction. |
207 | struct ParseSingleExampleAttrs { |
208 | public: |
209 | template <typename ContextType> |
210 | Status Init(ContextType* ctx) { |
211 | TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_keys" , &sparse_keys)); |
212 | TF_RETURN_IF_ERROR(ctx->GetAttr("sparse_types" , &sparse_types)); |
213 | TF_RETURN_IF_ERROR(ctx->GetAttr("dense_keys" , &dense_keys)); |
214 | TF_RETURN_IF_ERROR(ctx->GetAttr("Tdense" , &dense_types)); |
215 | TF_RETURN_IF_ERROR(ctx->GetAttr("dense_shapes" , &dense_shapes)); |
216 | |
217 | int num_sparse; |
218 | TF_RETURN_IF_ERROR(ctx->GetAttr("num_sparse" , &num_sparse)); |
219 | if (num_sparse != sparse_keys.size() || num_sparse != sparse_types.size()) { |
220 | return errors::InvalidArgument( |
221 | "num_sparse (" , num_sparse, ") must match the size of sparse_keys (" , |
222 | sparse_keys.size(), ") and sparse_types (" , sparse_types.size(), ")" ); |
223 | } |
224 | |
225 | TF_RETURN_IF_ERROR( |
226 | GetDenseShapes(dense_shapes, &variable_length, &elements_per_stride)); |
227 | return FinishInit(); |
228 | } |
229 | |
230 | std::vector<tstring> sparse_keys; |
231 | std::vector<DataType> sparse_types; |
232 | std::vector<tstring> dense_keys; |
233 | std::vector<DataType> dense_types; |
234 | std::vector<PartialTensorShape> dense_shapes; |
235 | std::vector<bool> variable_length; |
236 | std::vector<std::size_t> elements_per_stride; |
237 | |
238 | private: |
239 | Status FinishInit(); // for context-independent parts of Init. |
240 | }; |
241 | |
242 | // Parses the attributes passed to ParseSequenceExample. |
243 | // REQUIRES: Init must be called after construction. |
244 | struct ParseSequenceExampleAttrs { |
245 | public: |
246 | template <typename ContextType> |
247 | Status Init(ContextType* ctx, int op_version = 1) { |
248 | switch (op_version) { |
249 | case 1: { |
250 | std::vector<string> missing_empty_vector; |
251 | TF_RETURN_IF_ERROR(ctx->GetAttr( |
252 | "feature_list_dense_missing_assumed_empty" , &missing_empty_vector)); |
253 | for (const string& feature : missing_empty_vector) { |
254 | feature_list_dense_missing_assumed_empty.insert(feature); |
255 | } |
256 | } |
257 | TF_RETURN_IF_ERROR( |
258 | ctx->GetAttr("context_sparse_keys" , &context_sparse_keys)); |
259 | TF_RETURN_IF_ERROR( |
260 | ctx->GetAttr("context_dense_keys" , &context_dense_keys)); |
261 | TF_RETURN_IF_ERROR(ctx->GetAttr("feature_list_sparse_keys" , |
262 | &feature_list_sparse_keys)); |
263 | TF_RETURN_IF_ERROR( |
264 | ctx->GetAttr("feature_list_dense_keys" , &feature_list_dense_keys)); |
265 | TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense" , &num_context_dense)); |
266 | break; |
267 | case 2: |
268 | TF_RETURN_IF_ERROR(ctx->GetAttr("context_ragged_value_types" , |
269 | &context_ragged_value_types)); |
270 | TF_RETURN_IF_ERROR(ctx->GetAttr("context_ragged_split_types" , |
271 | &context_ragged_split_types)); |
272 | TF_RETURN_IF_ERROR(ctx->GetAttr("feature_list_ragged_value_types" , |
273 | &feature_list_ragged_value_types)); |
274 | TF_RETURN_IF_ERROR(ctx->GetAttr("feature_list_ragged_split_types" , |
275 | &feature_list_ragged_split_types)); |
276 | break; |
277 | default: |
278 | return errors::InvalidArgument("Unexpected op_version" , op_version); |
279 | } |
280 | TF_RETURN_IF_ERROR( |
281 | ctx->GetAttr("context_sparse_types" , &context_sparse_types)); |
282 | TF_RETURN_IF_ERROR( |
283 | ctx->GetAttr("Nfeature_list_dense" , &num_feature_list_dense)); |
284 | TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse" , &num_context_sparse)); |
285 | TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense" , &context_dense_types)); |
286 | TF_RETURN_IF_ERROR( |
287 | ctx->GetAttr("feature_list_sparse_types" , &feature_list_sparse_types)); |
288 | TF_RETURN_IF_ERROR( |
289 | ctx->GetAttr("feature_list_dense_types" , &feature_list_dense_types)); |
290 | TF_RETURN_IF_ERROR( |
291 | ctx->GetAttr("Nfeature_list_sparse" , &num_feature_list_sparse)); |
292 | TF_RETURN_IF_ERROR( |
293 | ctx->GetAttr("context_dense_shapes" , &context_dense_shapes)); |
294 | TF_RETURN_IF_ERROR( |
295 | ctx->GetAttr("feature_list_dense_shapes" , &feature_list_dense_shapes)); |
296 | return FinishInit(op_version); |
297 | } |
298 | |
299 | std::unordered_set<string> feature_list_dense_missing_assumed_empty; |
300 | int64_t num_context_sparse; |
301 | int64_t num_context_dense; |
302 | int64_t num_context_ragged; |
303 | int64_t num_feature_list_sparse; |
304 | int64_t num_feature_list_dense; |
305 | int64_t num_feature_list_ragged; |
306 | std::vector<tstring> context_sparse_keys; |
307 | std::vector<tstring> context_dense_keys; |
308 | std::vector<tstring> feature_list_sparse_keys; |
309 | std::vector<tstring> feature_list_dense_keys; |
310 | std::vector<DataType> context_sparse_types; |
311 | std::vector<DataType> context_dense_types; |
312 | std::vector<TensorShape> context_dense_shapes; |
313 | std::vector<DataType> feature_list_sparse_types; |
314 | std::vector<DataType> feature_list_dense_types; |
315 | std::vector<TensorShape> feature_list_dense_shapes; |
316 | std::vector<DataType> context_ragged_value_types; |
317 | std::vector<DataType> context_ragged_split_types; |
318 | std::vector<DataType> feature_list_ragged_value_types; |
319 | std::vector<DataType> feature_list_ragged_split_types; |
320 | |
321 | private: |
322 | Status FinishInit(int op_version); // for context-independent parts of Init. |
323 | }; |
324 | |
325 | // Parses the attributes passed to ParseSingleSequenceExample. |
326 | // REQUIRES: Init must be called after construction. |
327 | struct ParseSingleSequenceExampleAttrs { |
328 | public: |
329 | template <typename ContextType> |
330 | Status Init(ContextType* ctx) { |
331 | TF_RETURN_IF_ERROR( |
332 | ctx->GetAttr("context_sparse_types" , &context_sparse_types)); |
333 | TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_dense" , &num_context_dense)); |
334 | TF_RETURN_IF_ERROR( |
335 | ctx->GetAttr("Nfeature_list_dense" , &num_feature_list_dense)); |
336 | TF_RETURN_IF_ERROR(ctx->GetAttr("Ncontext_sparse" , &num_context_sparse)); |
337 | TF_RETURN_IF_ERROR(ctx->GetAttr("Tcontext_dense" , &context_dense_types)); |
338 | TF_RETURN_IF_ERROR( |
339 | ctx->GetAttr("feature_list_sparse_types" , &feature_list_sparse_types)); |
340 | TF_RETURN_IF_ERROR( |
341 | ctx->GetAttr("feature_list_dense_types" , &feature_list_dense_types)); |
342 | TF_RETURN_IF_ERROR( |
343 | ctx->GetAttr("Nfeature_list_sparse" , &num_feature_list_sparse)); |
344 | TF_RETURN_IF_ERROR( |
345 | ctx->GetAttr("context_dense_shapes" , &context_dense_shapes)); |
346 | TF_RETURN_IF_ERROR( |
347 | ctx->GetAttr("feature_list_dense_shapes" , &feature_list_dense_shapes)); |
348 | return FinishInit(); |
349 | } |
350 | |
351 | int64_t num_context_sparse; |
352 | int64_t num_context_dense; |
353 | int64_t num_feature_list_sparse; |
354 | int64_t num_feature_list_dense; |
355 | std::vector<DataType> context_sparse_types; |
356 | std::vector<DataType> context_dense_types; |
357 | std::vector<TensorShape> context_dense_shapes; |
358 | std::vector<DataType> feature_list_sparse_types; |
359 | std::vector<DataType> feature_list_dense_types; |
360 | std::vector<TensorShape> feature_list_dense_shapes; |
361 | |
362 | private: |
363 | Status FinishInit(); // for context-independent parts of Init. |
364 | }; |
365 | |
366 | } // namespace tensorflow |
367 | |
368 | #endif // TENSORFLOW_CORE_UTIL_EXAMPLE_PROTO_HELPER_H_ |
369 | |