1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/parsing_ops.cc. |
17 | |
18 | #include <numeric> |
19 | #include <unordered_set> |
20 | #include <vector> |
21 | |
22 | #include "absl/base/call_once.h" |
23 | #include "tensorflow/core/example/example.pb.h" |
24 | #include "tensorflow/core/example/feature.pb.h" |
25 | #include "tensorflow/core/framework/common_shape_fns.h" |
26 | #include "tensorflow/core/framework/metrics.h" |
27 | #include "tensorflow/core/framework/numeric_op.h" |
28 | #include "tensorflow/core/framework/register_types.h" |
29 | #include "tensorflow/core/lib/core/errors.h" |
30 | #include "tensorflow/core/lib/gtl/array_slice.h" |
31 | #include "tensorflow/core/platform/logging.h" |
32 | #include "tensorflow/core/platform/protobuf.h" |
33 | #include "tensorflow/core/util/example_proto_fast_parsing.h" |
34 | #include "tensorflow/core/util/example_proto_helper.h" |
35 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
36 | #include "tensorflow/core/util/work_sharder.h" |
37 | |
38 | namespace tensorflow { |
39 | |
namespace {
// Op names used to distinguish the V2 ops from the V1 ops: the same kernel
// class implements both versions and selects behavior from the node's op name.
constexpr char kParseExampleV2[] = "ParseExampleV2";
constexpr char kParseSequenceExampleV2[] = "ParseSequenceExampleV2";
}  // namespace
44 | |
45 | // Note: this kernel is used by both the ParseExample op and the ParseExampleV2 |
46 | // op. It automatically determines which op was used by checking if the |
47 | // "ragged_value_types" attribute exists. |
class ParseExampleOp : public OpKernel {
 public:
  // Determines from the node's op name whether this instance implements the
  // V1 or V2 op, then parses the version-specific attributes.
  explicit ParseExampleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx), op_version_(ctx->def().op() == kParseExampleV2 ? 2 : 1) {
    OP_REQUIRES_OK(ctx, attrs_.Init(ctx, op_version_));
  }

  // Parses a scalar (V2 only) or vector of serialized Example protos into the
  // dense/sparse (and, for V2, ragged) feature output tensors.
  void Compute(OpKernelContext* ctx) override {
    const Tensor* names;
    const Tensor* serialized;
    std::vector<StringPiece> dense_keys_t;
    std::vector<StringPiece> sparse_keys_t;
    std::vector<StringPiece> ragged_keys_t;
    OpInputList dense_defaults;

    // Grab the inputs.
    OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
    OP_REQUIRES_OK(ctx, ctx->input("names", &names));
    if (op_version_ == 2) {
      // In V2 the feature keys arrive as string tensors.
      OP_REQUIRES_OK(ctx, GetTensorKeys(ctx, "dense_keys", &dense_keys_t));
      OP_REQUIRES_OK(ctx, GetTensorKeys(ctx, "sparse_keys", &sparse_keys_t));
      OP_REQUIRES_OK(ctx, GetTensorKeys(ctx, "ragged_keys", &ragged_keys_t));
    } else {
      // In V1 the keys arrive as lists of scalar string tensors. V1 has no
      // ragged features, so ragged_keys_t stays empty.
      OP_REQUIRES_OK(ctx, GetInputListKeys(ctx, "dense_keys", &dense_keys_t));
      OP_REQUIRES_OK(ctx, GetInputListKeys(ctx, "sparse_keys", &sparse_keys_t));
    }
    // Record feature-count metrics only on the first call to this kernel
    // instance (the keys are inputs, so they are unavailable in the ctor).
    absl::call_once(flag_, [&dense_keys_t, &sparse_keys_t, &ragged_keys_t]() {
      metrics::RecordParseDenseFeature(dense_keys_t.size());
      metrics::RecordParseSparseFeature(sparse_keys_t.size());
      metrics::RecordParseRaggedFeature(ragged_keys_t.size());
    });
    OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));

    // Validate input tensor shapes.
    OP_REQUIRES_OK(
        ctx, CheckInputShapes(serialized, names, dense_defaults, dense_keys_t,
                              sparse_keys_t, ragged_keys_t));

    example::FastParseExampleConfig config =
        MakeConfig(dense_keys_t, sparse_keys_t, ragged_keys_t, dense_defaults);

    example::Result result;
    if (TensorShapeUtils::IsVector(serialized->shape())) {
      OP_REQUIRES_OK(
          ctx, ParseExampleVector(config, serialized, names, ctx, &result));
    } else {
      // Only reachable for V2: CheckInputShapes rejects scalars for V1.
      OP_REQUIRES_OK(ctx, ParseExampleScalar(config, serialized, ctx, &result));
    }
    OP_REQUIRES_OK(ctx, WriteOutput(result, ctx));
  }

 protected:
  // Copies keys from the string tensor `input_name` into `*keys`.
  // NOTE: the StringPieces alias the tensor's string storage, so they are
  // only valid while the input tensor is alive (true within Compute).
  Status GetTensorKeys(OpKernelContext* ctx, StringPiece input_name,
                       std::vector<StringPiece>* keys) const {
    const Tensor* key_t;
    TF_RETURN_IF_ERROR(ctx->input(input_name, &key_t));
    keys->reserve(key_t->NumElements());
    auto keys_flat = key_t->flat<tstring>();
    for (int i = 0; i < keys_flat.size(); ++i) {
      keys->push_back(keys_flat(i));
    }
    return OkStatus();
  }

  // Copies keys from an OpInputList of scalar string tensors into `*keys`.
  // Same aliasing caveat as GetTensorKeys above.
  Status GetInputListKeys(OpKernelContext* ctx, StringPiece input_name,
                          std::vector<StringPiece>* keys) const {
    OpInputList key_list;
    TF_RETURN_IF_ERROR(ctx->input_list(input_name, &key_list));
    keys->reserve(key_list.size());
    for (const auto& key : key_list) {
      keys->push_back(key.scalar<tstring>()());
    }
    return OkStatus();
  }

  // Validates the shapes of input tensors against each other and against the
  // counts/shapes/dtypes declared via attrs. Returns InvalidArgument on any
  // mismatch.
  Status CheckInputShapes(const Tensor* serialized, const Tensor* names,
                          const OpInputList& dense_defaults,
                          const std::vector<StringPiece>& dense_keys_t,
                          const std::vector<StringPiece>& sparse_keys_t,
                          const std::vector<StringPiece>& ragged_keys_t) const {
    // V2 accepts a scalar or vector of serialized protos; V1 only a vector.
    if (op_version_ == 2) {
      if (TensorShapeUtils::IsMatrixOrHigher(serialized->shape())) {
        return errors::InvalidArgument(
            "Expected serialized to be a scalar or vector, got shape: ",
            serialized->shape().DebugString());
      }
    } else {
      if (!TensorShapeUtils::IsVector(serialized->shape())) {
        return errors::InvalidArgument(
            "Expected serialized to be a vector, got shape: ",
            serialized->shape().DebugString());
      }
    }
    // `names` is optional (may be empty); when provided it must parallel
    // `serialized` element-for-element.
    if (names->NumElements() > 0 && names->shape() != serialized->shape()) {
      return errors::InvalidArgument(
          "Expected names have the same shape as serialized: name.shape=",
          names->shape().DebugString(),
          ", serialized.shape=", serialized->shape().DebugString());
    }
    if (op_version_ == 2) {
      // V2 key tensors are runtime inputs; their element counts must agree
      // with the attr-declared counts.
      if (dense_keys_t.size() != attrs_.num_dense) {
        return errors::InvalidArgument(
            "Expected len(dense_keys) == len(dense_types) but got: ",
            dense_keys_t.size(), " vs. ", attrs_.num_dense);
      }
      if (sparse_keys_t.size() != attrs_.num_sparse) {
        return errors::InvalidArgument(
            "Expected len(sparse_keys) == num_sparse but got: ",
            sparse_keys_t.size(), " vs. ", attrs_.num_sparse);
      }
      if (ragged_keys_t.size() != attrs_.num_ragged) {
        return errors::InvalidArgument(
            "Expected len(ragged_keys) == len(ragged_value_types) but got: ",
            ragged_keys_t.size(), " vs. ", attrs_.num_ragged);
      }
    }

    if (dense_defaults.size() != attrs_.num_dense) {
      return errors::InvalidArgument(
          "Expected len(dense_defaults) == len(dense_keys) but got: ",
          dense_defaults.size(), " vs. ", attrs_.num_dense);
    }

    for (int d = 0; d < static_cast<int>(attrs_.num_dense); ++d) {
      const Tensor& def_value = dense_defaults[d];
      if (attrs_.variable_length[d]) {
        // A variable-length dense feature uses its single default element as
        // the padding value.
        if (def_value.NumElements() != 1) {
          return errors::InvalidArgument(
              "dense_shape[", d, "] is a variable length shape: ",
              attrs_.dense_shapes[d].DebugString(),
              ", therefore "
              "def_value[",
              d,
              "] must contain a single element ("
              "the padding element). But its shape is: ",
              def_value.shape().DebugString());
        }
      } else if (def_value.NumElements() > 0) {
        // A non-empty fixed-shape default must be compatible with the
        // declared dense shape.
        if (!attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape())) {
          return errors::InvalidArgument(
              "def_value[", d, "].shape() == ", def_value.shape().DebugString(),
              " is not compatible with dense_shapes_[", d,
              "] == ", attrs_.dense_shapes[d].DebugString());
        }
      }
      if (def_value.dtype() != attrs_.dense_types[d]) {
        return errors::InvalidArgument(
            "dense_defaults[", d,
            "].dtype() == ", DataTypeString(def_value.dtype()),
            " != dense_types_[", d,
            "] == ", DataTypeString(attrs_.dense_types[d]));
      }
    }
    return OkStatus();
  }

  // Populates the FastParseExampleConfig from keys & defaults.
  example::FastParseExampleConfig MakeConfig(
      const std::vector<StringPiece>& dense_keys_t,
      const std::vector<StringPiece>& sparse_keys_t,
      const std::vector<StringPiece>& ragged_keys_t,
      const OpInputList& dense_defaults) const {
    example::FastParseExampleConfig config;
    config.dense.reserve(attrs_.num_dense);
    for (int d = 0; d < attrs_.num_dense; ++d) {
      config.dense.emplace_back(dense_keys_t[d], attrs_.dense_types[d],
                                attrs_.dense_shapes[d], dense_defaults[d],
                                attrs_.variable_length[d],
                                attrs_.elements_per_stride[d]);
    }
    config.sparse.reserve(attrs_.num_sparse);
    for (int d = 0; d < attrs_.num_sparse; ++d) {
      config.sparse.emplace_back(sparse_keys_t[d], attrs_.sparse_types[d]);
    }
    config.ragged.reserve(attrs_.num_ragged);
    for (int d = 0; d < attrs_.num_ragged; ++d) {
      config.ragged.emplace_back(ragged_keys_t[d], attrs_.ragged_value_types[d],
                                 attrs_.ragged_split_types[d]);
    }
    return config;
  }

  // Parses a single example (V2 scalar input path).
  Status ParseExampleScalar(const example::FastParseExampleConfig& config,
                            const Tensor* serialized, OpKernelContext* ctx,
                            example::Result* result) const {
    const tstring& serialized_proto = serialized->scalar<tstring>()();
    return FastParseSingleExample(config, serialized_proto, result);
  }

  // Parses a vector of examples, sharding the work across the CPU worker
  // threadpool.
  Status ParseExampleVector(const example::FastParseExampleConfig& config,
                            const Tensor* serialized, const Tensor* names,
                            OpKernelContext* ctx,
                            example::Result* result) const {
    auto serialized_t = serialized->flat<tstring>();
    auto names_t = names->flat<tstring>();
    gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
    gtl::ArraySlice<tstring> names_slice(names_t.data(), names_t.size());
    return FastParseExample(
        config, slice, names_slice,
        ctx->device()->tensorflow_cpu_worker_threads()->workers, result);
  }

  // Copies the parse results into the op's output lists. Ragged outputs only
  // exist on the V2 op.
  Status WriteOutput(const example::Result& result,
                     OpKernelContext* ctx) const {
    OpOutputList dense_values;
    OpOutputList sparse_indices;
    OpOutputList sparse_values;
    OpOutputList sparse_shapes;
    TF_RETURN_IF_ERROR(ctx->output_list("dense_values", &dense_values));
    TF_RETURN_IF_ERROR(ctx->output_list("sparse_indices", &sparse_indices));
    TF_RETURN_IF_ERROR(ctx->output_list("sparse_values", &sparse_values));
    TF_RETURN_IF_ERROR(ctx->output_list("sparse_shapes", &sparse_shapes));
    for (int d = 0; d < attrs_.num_dense; ++d) {
      dense_values.set(d, result.dense_values[d]);
    }
    for (int d = 0; d < attrs_.num_sparse; ++d) {
      sparse_indices.set(d, result.sparse_indices[d]);
      sparse_values.set(d, result.sparse_values[d]);
      sparse_shapes.set(d, result.sparse_shapes[d]);
    }
    if (op_version_ == 2) {
      OpOutputList ragged_values;
      OpOutputList ragged_splits;
      TF_RETURN_IF_ERROR(ctx->output_list("ragged_values", &ragged_values));
      TF_RETURN_IF_ERROR(ctx->output_list("ragged_row_splits", &ragged_splits));
      for (int d = 0; d < attrs_.num_ragged; ++d) {
        ragged_values.set(d, result.ragged_values[d]);
        ragged_splits.set(d, result.ragged_splits[d]);
      }
    }
    return OkStatus();
  }

  ParseExampleAttrs attrs_;  // Parsed op attributes (types, shapes, counts).
  int op_version_;           // 1 = ParseExample, 2 = ParseExampleV2.
  absl::once_flag flag_;     // Guards one-time metrics recording in Compute.
};
290 | |
// Both the V1 and V2 ops are served by the same kernel class, which
// distinguishes them at construction time via the node's op name.
REGISTER_KERNEL_BUILDER(Name("ParseExample").Device(DEVICE_CPU),
                        ParseExampleOp);
REGISTER_KERNEL_BUILDER(Name("ParseExampleV2").Device(DEVICE_CPU),
                        ParseExampleOp);
295 | |
// Kernel for ParseSingleExample: parses one scalar serialized Example proto
// into dense and sparse feature outputs. Unlike ParseExampleOp, all keys are
// attrs, so validation metrics can be recorded at construction time.
class ParseSingleExampleOp : public OpKernel {
 public:
  explicit ParseSingleExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
    OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
    // Keys are known from attrs, so metrics can be recorded once here.
    metrics::RecordParseDenseFeature(attrs_.dense_keys.size());
    metrics::RecordParseSparseFeature(attrs_.sparse_keys.size());
  }

  // Validates inputs, parses the single example, and writes the outputs.
  void Compute(OpKernelContext* ctx) override {
    const Tensor* serialized;
    OpInputList dense_defaults;

    // Grab the input list arguments.
    OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
    OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));

    OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
                errors::InvalidArgument(
                    "Expected serialized to be a scalar, got shape: ",
                    serialized->shape().DebugString()));
    OP_REQUIRES(ctx, dense_defaults.size() == attrs_.dense_keys.size(),
                errors::InvalidArgument(
                    "Expected len(dense_defaults) == len(dense_keys) but got: ",
                    dense_defaults.size(), " vs. ", attrs_.dense_keys.size()));

    // Validate each dense default against the declared shape and dtype.
    for (size_t d = 0; d < attrs_.dense_keys.size(); ++d) {
      const Tensor& def_value = dense_defaults[d];
      if (attrs_.variable_length[d]) {
        // A variable-length dense feature uses its single default element as
        // the padding value.
        OP_REQUIRES(ctx, def_value.NumElements() == 1,
                    errors::InvalidArgument(
                        "dense_shape[", d, "] is a variable length shape: ",
                        attrs_.dense_shapes[d].DebugString(),
                        ", therefore "
                        "def_value[",
                        d,
                        "] must contain a single element ("
                        "the padding element). But its shape is: ",
                        def_value.shape().DebugString()));
      } else if (def_value.NumElements() > 0) {
        OP_REQUIRES(ctx,
                    attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()),
                    errors::InvalidArgument(
                        "def_value[", d,
                        "].shape() == ", def_value.shape().DebugString(),
                        " is not compatible with dense_shapes_[", d,
                        "] == ", attrs_.dense_shapes[d].DebugString()));
      }
      OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d],
                  errors::InvalidArgument(
                      "dense_defaults[", d, "].dtype() == ",
                      DataTypeString(def_value.dtype()), " != dense_types_[", d,
                      "] == ", DataTypeString(attrs_.dense_types[d])));
    }

    example::Result result;

    // TODO(mrry): Build the configuration once and cache it.
    example::FastParseExampleConfig config;
    for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
      config.dense.push_back({attrs_.dense_keys[d], attrs_.dense_types[d],
                              attrs_.dense_shapes[d], dense_defaults[d],
                              attrs_.variable_length[d],
                              attrs_.elements_per_stride[d]});
    }
    for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
      config.sparse.push_back({attrs_.sparse_keys[d], attrs_.sparse_types[d]});
    }

    const tstring& serialized_proto = serialized->scalar<tstring>()();

    OP_REQUIRES_OK(ctx,
                   FastParseSingleExample(config, serialized_proto, &result));

    // Copy the parse results into the op's output lists.
    OpOutputList dense_values;
    OpOutputList sparse_indices;
    OpOutputList sparse_values;
    OpOutputList sparse_shapes;
    OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
    OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices));
    OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values));
    OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes));
    for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
      dense_values.set(d, result.dense_values[d]);
    }
    for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
      sparse_indices.set(d, result.sparse_indices[d]);
      sparse_values.set(d, result.sparse_values[d]);
      sparse_shapes.set(d, result.sparse_shapes[d]);
    }
  }

 protected:
  ParseSingleExampleAttrs attrs_;  // Attr-declared keys, types, and shapes.
};
390 | |
// ParseSingleExample only runs on CPU; parsing happens on host memory.
REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU),
                        ParseSingleExampleOp);
393 | |
394 | class ParseSequenceExampleOp : public OpKernel { |
395 | public: |
  // Determines the op version from the node's op name and parses the
  // version-specific attributes. Metrics here are recorded from the attr
  // key lists (V2's key input tensors are unavailable until Compute, where
  // they are recorded separately under a once-flag).
  explicit ParseSequenceExampleOp(OpKernelConstruction* ctx)
      : OpKernel(ctx),
        op_version_(ctx->def().op() == kParseSequenceExampleV2 ? 2 : 1) {
    OP_REQUIRES_OK(ctx, attrs_.Init(ctx, op_version_));
    metrics::RecordParseDenseFeature(attrs_.context_dense_keys.size() +
                                     attrs_.feature_list_dense_keys.size());
    metrics::RecordParseSparseFeature(attrs_.context_sparse_keys.size() +
                                      attrs_.feature_list_sparse_keys.size());
  }
405 | |
  // Parses a scalar or vector of serialized SequenceExample protos into
  // context and feature_list outputs (dense, sparse, and — for V2 — ragged).
  void Compute(OpKernelContext* ctx) override {
    const Tensor* debug_name;
    const Tensor* serialized;
    OpInputList context_dense_defaults;

    OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
    OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
    OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
                                        &context_dense_defaults));
    // The key tensors below stay null for the V1 op, whose keys come from
    // attrs instead; downstream helpers treat null as "use attrs_".
    const Tensor* context_dense_keys = nullptr;
    const Tensor* context_sparse_keys = nullptr;
    const Tensor* context_ragged_keys = nullptr;
    const Tensor* feature_list_dense_keys = nullptr;
    const Tensor* feature_list_sparse_keys = nullptr;
    const Tensor* feature_list_ragged_keys = nullptr;
    const Tensor* feature_list_dense_missing_assumed_empty = nullptr;
    if (op_version_ == 2) {
      OP_REQUIRES_OK(ctx,
                     ctx->input("feature_list_dense_missing_assumed_empty",
                                &feature_list_dense_missing_assumed_empty));
      OP_REQUIRES_OK(ctx,
                     ctx->input("context_dense_keys", &context_dense_keys));
      OP_REQUIRES_OK(ctx,
                     ctx->input("context_sparse_keys", &context_sparse_keys));
      OP_REQUIRES_OK(ctx,
                     ctx->input("context_ragged_keys", &context_ragged_keys));
      OP_REQUIRES_OK(
          ctx, ctx->input("feature_list_dense_keys", &feature_list_dense_keys));
      OP_REQUIRES_OK(ctx, ctx->input("feature_list_sparse_keys",
                                     &feature_list_sparse_keys));
      OP_REQUIRES_OK(ctx, ctx->input("feature_list_ragged_keys",
                                     &feature_list_ragged_keys));
      // V2 keys are runtime inputs, so metrics are recorded on first call.
      absl::call_once(flag_, [&]() {
        metrics::RecordParseDenseFeature(
            context_dense_keys->NumElements() +
            feature_list_dense_keys->NumElements());
        metrics::RecordParseSparseFeature(
            context_sparse_keys->NumElements() +
            feature_list_sparse_keys->NumElements());
        metrics::RecordParseRaggedFeature(
            context_ragged_keys->NumElements() +
            feature_list_ragged_keys->NumElements());
      });
    }

    // Validate input tensor shapes.
    OP_REQUIRES_OK(ctx, CheckInputShapes(
                            serialized, debug_name, context_dense_defaults,
                            context_dense_keys, context_sparse_keys,
                            context_ragged_keys, feature_list_dense_keys,
                            feature_list_sparse_keys, feature_list_ragged_keys,
                            feature_list_dense_missing_assumed_empty));

    example::FastParseExampleConfig context_config =
        MakeContextConfig(context_dense_keys, context_sparse_keys,
                          context_ragged_keys, context_dense_defaults);
    example::FastParseExampleConfig feature_list_config = MakeFeatureListConfig(
        feature_list_dense_keys, feature_list_sparse_keys,
        feature_list_ragged_keys, feature_list_dense_missing_assumed_empty);

    // Scalar input parses a single example; vector input parses a batch.
    bool is_batch = TensorShapeUtils::IsVector(serialized->shape());
    auto serialized_t = serialized->flat<tstring>();
    auto debug_name_t = debug_name->flat<tstring>();
    gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
    gtl::ArraySlice<tstring> names_slice(debug_name_t.data(),
                                         debug_name_t.size());

    example::Result context_result, feature_list_result;
    std::vector<Tensor> dense_feature_lengths;
    OP_REQUIRES_OK(
        ctx, FastParseSequenceExample(
                 context_config, feature_list_config, slice, names_slice,
                 ctx->device()->tensorflow_cpu_worker_threads()->workers,
                 &context_result, &feature_list_result, &dense_feature_lengths,
                 is_batch));

    OP_REQUIRES_OK(ctx, WriteOutput(context_result, feature_list_result,
                                    dense_feature_lengths, ctx));
  }
485 | |
486 | protected: |
487 | Status CheckInputShapes( |
488 | const Tensor* serialized, const Tensor* names, |
489 | const OpInputList& context_dense_defaults, |
490 | |
491 | const Tensor* context_dense_keys, const Tensor* context_sparse_keys, |
492 | const Tensor* context_ragged_keys, const Tensor* feature_list_dense_keys, |
493 | const Tensor* feature_list_sparse_keys, |
494 | const Tensor* feature_list_ragged_keys, |
495 | const Tensor* feature_list_dense_missing_assumed_empty) const { |
496 | if (TensorShapeUtils::IsMatrixOrHigher(serialized->shape())) { |
497 | return errors::InvalidArgument( |
498 | "Expected serialized to be a scalar or vector, got shape: " , |
499 | serialized->shape().DebugString()); |
500 | } |
501 | if (op_version_ > 1) { |
502 | if (context_dense_keys->NumElements() != attrs_.num_context_dense) { |
503 | return errors::InvalidArgument( |
504 | "Expected len(context_dense_keys) to match len(Tcontext_dense)" ); |
505 | } |
506 | if (context_sparse_keys->NumElements() != attrs_.num_context_sparse) { |
507 | return errors::InvalidArgument( |
508 | "Expected len(context_sparse_keys) to match Ncontext_sparse" ); |
509 | } |
510 | if (context_ragged_keys->NumElements() != attrs_.num_context_ragged) { |
511 | return errors::InvalidArgument( |
512 | "Expected len(context_ragged_keys) to match " |
513 | "len(context_ragged_value_types)" ); |
514 | } |
515 | if (feature_list_dense_keys->NumElements() != |
516 | attrs_.num_feature_list_dense) { |
517 | return errors::InvalidArgument( |
518 | "Expected len(feature_list_dense_keys) to match " |
519 | "Nfeature_list_dense" ); |
520 | } |
521 | if (feature_list_dense_missing_assumed_empty->NumElements() != |
522 | attrs_.num_feature_list_dense) { |
523 | return errors::InvalidArgument( |
524 | "Expected len(feature_list_dense_missing_assumed_empty to match " |
525 | "Nfeature_list_dense" ); |
526 | } |
527 | if (feature_list_sparse_keys->NumElements() != |
528 | attrs_.num_feature_list_sparse) { |
529 | return errors::InvalidArgument( |
530 | "Expected len(feature_list_sparse_keys) to match " |
531 | "Nfeature_list_sparse" ); |
532 | } |
533 | if (feature_list_ragged_keys->NumElements() != |
534 | attrs_.num_feature_list_ragged) { |
535 | return errors::InvalidArgument( |
536 | "Expected len(feature_list_ragged_keys) to match " |
537 | "len(feature_list_ragged_value_types)" ); |
538 | } |
539 | } |
540 | if (context_dense_defaults.size() != attrs_.num_context_dense) { |
541 | return errors::InvalidArgument( |
542 | "Expected len(context_dense_defaults) " |
543 | "== len(context_dense_keys) but got: " , |
544 | context_dense_defaults.size(), " vs. " , attrs_.num_context_dense); |
545 | } |
546 | for (int d = 0; d < attrs_.num_context_dense; ++d) { |
547 | const Tensor& def_value = context_dense_defaults[d]; |
548 | if (def_value.NumElements() > 0) { |
549 | if (def_value.shape() != attrs_.context_dense_shapes[d]) { |
550 | return errors::InvalidArgument( |
551 | "default_value[" , d, |
552 | "].shape() == " , def_value.shape().DebugString(), |
553 | " != context_dense_shapes[" , d, |
554 | "] == " , attrs_.context_dense_shapes[d].DebugString()); |
555 | } |
556 | if (def_value.dtype() != attrs_.context_dense_types[d]) { |
557 | return errors::InvalidArgument( |
558 | "context_dense_defaults[" , d, |
559 | "].dtype() == " , DataTypeString(def_value.dtype()), |
560 | " != context_dense_types[" , d, |
561 | "] == " , DataTypeString(attrs_.context_dense_types[d])); |
562 | } |
563 | } |
564 | } |
565 | return OkStatus(); |
566 | } |
567 | |
  // Builds the FastParseExampleConfig for the context features. Non-null key
  // tensors (V2) are read directly; null key tensors (V1) fall back to the
  // keys stored in attrs_. V1 has no ragged context features, so an empty
  // slice is used for ragged keys in that case.
  example::FastParseExampleConfig MakeContextConfig(
      const Tensor* dense_keys, const Tensor* sparse_keys,
      const Tensor* ragged_keys,
      const OpInputList& context_dense_defaults) const {
    // Convert the tensors/attrs to ArraySlices once, instead of re-evaluating
    // them in each loop iteration.
    gtl::ArraySlice<tstring> dense_keys_slice =
        dense_keys
            ? gtl::ArraySlice<tstring>(dense_keys->flat<tstring>().data(),
                                       attrs_.num_context_dense)
            : attrs_.context_dense_keys;
    gtl::ArraySlice<tstring> sparse_keys_slice =
        sparse_keys
            ? gtl::ArraySlice<tstring>(sparse_keys->flat<tstring>().data(),
                                       attrs_.num_context_sparse)
            : attrs_.context_sparse_keys;
    gtl::ArraySlice<tstring> ragged_keys_slice =
        ragged_keys
            ? gtl::ArraySlice<tstring>(ragged_keys->flat<tstring>().data(),
                                       attrs_.num_context_ragged)
            : gtl::ArraySlice<tstring>(nullptr, 0);

    example::FastParseExampleConfig config;
    config.dense.reserve(attrs_.num_context_dense);
    for (int d = 0; d < attrs_.num_context_dense; ++d) {
      const tstring& key = dense_keys_slice[d];
      // Context dense features are never variable-length and have no
      // elements-per-stride, hence the hard-coded false/0.
      config.dense.emplace_back(key, attrs_.context_dense_types[d],
                                attrs_.context_dense_shapes[d],
                                context_dense_defaults[d],
                                false /* attrs_.context_variable_length[d] */,
                                0 /*attrs_.context_elements_per_stride[d] */);
    }
    config.sparse.reserve(attrs_.num_context_sparse);
    for (int d = 0; d < attrs_.num_context_sparse; ++d) {
      const tstring& key = sparse_keys_slice[d];
      config.sparse.emplace_back(key, attrs_.context_sparse_types[d]);
    }
    config.ragged.reserve(attrs_.num_context_ragged);
    for (int d = 0; d < attrs_.num_context_ragged; ++d) {
      config.ragged.emplace_back(ragged_keys_slice[d],
                                 attrs_.context_ragged_value_types[d],
                                 attrs_.context_ragged_split_types[d]);
    }
    return config;
  }
613 | |
614 | static Tensor ConstructDefaultScalar(DataType dtype) { |
615 | switch (dtype) { |
616 | case DT_INT64: |
617 | return Tensor(static_cast<int64_t>(0)); |
618 | case DT_FLOAT: |
619 | return Tensor(static_cast<float>(0.0)); |
620 | case DT_STRING: |
621 | return Tensor("" ); |
622 | default: |
623 | return Tensor(DT_INVALID); |
624 | } |
625 | } |
626 | |
  // Builds the FastParseExampleConfig for the feature_list features. Non-null
  // key tensors (V2) are read directly; null key tensors (V1) fall back to
  // the keys stored in attrs_.
  example::FastParseExampleConfig MakeFeatureListConfig(
      const Tensor* dense_keys, const Tensor* sparse_keys,
      const Tensor* ragged_keys,
      const Tensor* feature_list_dense_missing_assumed_empty) const {
    // Convert the tensors/attrs to ArraySlices once, instead of re-evaluating
    // them in each loop iteration.
    gtl::ArraySlice<tstring> dense_keys_slice =
        dense_keys
            ? gtl::ArraySlice<tstring>(dense_keys->flat<tstring>().data(),
                                       attrs_.num_feature_list_dense)
            : attrs_.feature_list_dense_keys;
    gtl::ArraySlice<tstring> sparse_keys_slice =
        sparse_keys
            ? gtl::ArraySlice<tstring>(sparse_keys->flat<tstring>().data(),
                                       attrs_.num_feature_list_sparse)
            : attrs_.feature_list_sparse_keys;
    gtl::ArraySlice<tstring> ragged_keys_slice =
        ragged_keys
            ? gtl::ArraySlice<tstring>(ragged_keys->flat<tstring>().data(),
                                       attrs_.num_feature_list_ragged)
            : gtl::ArraySlice<tstring>(nullptr, 0);
    // Use an empty slice to indicate that the map in attrs_ should be used
    // instead.
    gtl::ArraySlice<bool> feature_list_dense_missing_assumed_empty_slice =
        feature_list_dense_missing_assumed_empty
            ? gtl::ArraySlice<bool>(
                  feature_list_dense_missing_assumed_empty->flat<bool>().data(),
                  attrs_.num_feature_list_dense)
            : gtl::ArraySlice<bool>(nullptr, 0);

    example::FastParseExampleConfig config;
    config.dense.reserve(attrs_.num_feature_list_dense);
    for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
      const tstring& key = dense_keys_slice[d];
      // V2 reads the missing-assumed-empty flag from the input tensor; V1
      // looks the key up in the attr-declared set.
      bool missing_assumed_empty =
          !feature_list_dense_missing_assumed_empty_slice.empty()
              ? feature_list_dense_missing_assumed_empty_slice[d]
              : attrs_.feature_list_dense_missing_assumed_empty.count(key) > 0;
      DataType dtype = attrs_.feature_list_dense_types[d];
      // Feature lists have no user-supplied defaults, so a zero-like scalar
      // default is synthesized per dtype.
      config.dense.emplace_back(
          key, dtype, attrs_.feature_list_dense_shapes[d],
          ConstructDefaultScalar(dtype), missing_assumed_empty,
          0 /*attrs_.feature_list_elements_per_stride[d] */);
    }
    config.sparse.reserve(attrs_.num_feature_list_sparse);
    for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
      const tstring& key = sparse_keys_slice[d];
      config.sparse.emplace_back(key, attrs_.feature_list_sparse_types[d]);
    }
    config.ragged.reserve(attrs_.num_feature_list_ragged);
    for (int d = 0; d < attrs_.num_feature_list_ragged; ++d) {
      config.ragged.emplace_back(ragged_keys_slice[d],
                                 attrs_.feature_list_ragged_value_types[d],
                                 attrs_.feature_list_ragged_split_types[d]);
    }
    return config;
  }
684 | |
685 | Status WriteOutput(const example::Result& context_result, |
686 | const example::Result& feature_list_result, |
687 | const std::vector<Tensor>& dense_feature_lengths, |
688 | OpKernelContext* ctx) const { |
689 | OpOutputList context_sparse_indices; |
690 | OpOutputList context_sparse_values; |
691 | OpOutputList context_sparse_shapes; |
692 | OpOutputList context_dense_values; |
693 | OpOutputList feature_list_sparse_indices; |
694 | OpOutputList feature_list_sparse_values; |
695 | OpOutputList feature_list_sparse_shapes; |
696 | OpOutputList feature_list_dense_values; |
697 | OpOutputList feature_list_dense_lengths; |
698 | |
699 | TF_RETURN_IF_ERROR( |
700 | ctx->output_list("context_sparse_indices" , &context_sparse_indices)); |
701 | TF_RETURN_IF_ERROR( |
702 | ctx->output_list("context_sparse_values" , &context_sparse_values)); |
703 | TF_RETURN_IF_ERROR( |
704 | ctx->output_list("context_sparse_shapes" , &context_sparse_shapes)); |
705 | TF_RETURN_IF_ERROR( |
706 | ctx->output_list("context_dense_values" , &context_dense_values)); |
707 | TF_RETURN_IF_ERROR( |
708 | ctx->output_list("context_sparse_indices" , &context_sparse_indices)); |
709 | TF_RETURN_IF_ERROR(ctx->output_list("feature_list_sparse_indices" , |
710 | &feature_list_sparse_indices)); |
711 | TF_RETURN_IF_ERROR(ctx->output_list("feature_list_sparse_values" , |
712 | &feature_list_sparse_values)); |
713 | TF_RETURN_IF_ERROR(ctx->output_list("feature_list_sparse_shapes" , |
714 | &feature_list_sparse_shapes)); |
715 | TF_RETURN_IF_ERROR(ctx->output_list("feature_list_dense_values" , |
716 | &feature_list_dense_values)); |
717 | TF_RETURN_IF_ERROR(ctx->output_list("feature_list_dense_lengths" , |
718 | &feature_list_dense_lengths)); |
719 | for (int d = 0; d < attrs_.num_context_dense; ++d) { |
720 | context_dense_values.set(d, context_result.dense_values[d]); |
721 | } |
722 | for (int d = 0; d < attrs_.num_feature_list_dense; ++d) { |
723 | feature_list_dense_values.set(d, feature_list_result.dense_values[d]); |
724 | feature_list_dense_lengths.set(d, dense_feature_lengths[d]); |
725 | } |
726 | for (int d = 0; d < attrs_.num_context_sparse; ++d) { |
727 | context_sparse_indices.set(d, context_result.sparse_indices[d]); |
728 | context_sparse_values.set(d, context_result.sparse_values[d]); |
729 | context_sparse_shapes.set(d, context_result.sparse_shapes[d]); |
730 | } |
731 | for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) { |
732 | feature_list_sparse_indices.set(d, feature_list_result.sparse_indices[d]); |
733 | feature_list_sparse_values.set(d, feature_list_result.sparse_values[d]); |
734 | feature_list_sparse_shapes.set(d, feature_list_result.sparse_shapes[d]); |
735 | } |
736 | if (op_version_ == 2) { |
737 | OpOutputList context_ragged_values; |
738 | OpOutputList context_ragged_splits; |
739 | OpOutputList feature_list_ragged_values; |
740 | OpOutputList feature_list_ragged_inner_splits; |
741 | OpOutputList feature_list_ragged_outer_splits; |
742 | TF_RETURN_IF_ERROR( |
743 | ctx->output_list("context_ragged_values" , &context_ragged_values)); |
744 | TF_RETURN_IF_ERROR(ctx->output_list("context_ragged_row_splits" , |
745 | &context_ragged_splits)); |
746 | TF_RETURN_IF_ERROR(ctx->output_list("feature_list_ragged_values" , |
747 | &feature_list_ragged_values)); |
748 | TF_RETURN_IF_ERROR(ctx->output_list("feature_list_ragged_inner_splits" , |
749 | &feature_list_ragged_inner_splits)); |
750 | TF_RETURN_IF_ERROR(ctx->output_list("feature_list_ragged_outer_splits" , |
751 | &feature_list_ragged_outer_splits)); |
752 | for (int d = 0; d < attrs_.num_context_ragged; ++d) { |
753 | context_ragged_values.set(d, context_result.ragged_values[d]); |
754 | context_ragged_splits.set(d, context_result.ragged_splits[d]); |
755 | } |
756 | for (int d = 0; d < attrs_.num_feature_list_ragged; ++d) { |
757 | feature_list_ragged_values.set(d, feature_list_result.ragged_values[d]); |
758 | feature_list_ragged_outer_splits.set( |
759 | d, feature_list_result.ragged_outer_splits[d]); |
760 | feature_list_ragged_inner_splits.set( |
761 | d, feature_list_result.ragged_splits[d]); |
762 | } |
763 | } |
764 | return OkStatus(); |
765 | } |
766 | |
767 | ParseSequenceExampleAttrs attrs_; |
768 | int op_version_; |
769 | absl::once_flag flag_; |
770 | }; |
771 | |
// Both the V1 and V2 sequence-example ops are served by the same kernel
// class; version-specific behavior is selected inside ParseSequenceExampleOp.
REGISTER_KERNEL_BUILDER(Name("ParseSequenceExample" ).Device(DEVICE_CPU),
                        ParseSequenceExampleOp);
REGISTER_KERNEL_BUILDER(Name("ParseSequenceExampleV2" ).Device(DEVICE_CPU),
                        ParseSequenceExampleOp);
776 | |
777 | class ParseSingleSequenceExampleOp : public OpKernel { |
778 | public: |
779 | explicit ParseSingleSequenceExampleOp(OpKernelConstruction* ctx) |
780 | : OpKernel(ctx) { |
781 | OP_REQUIRES_OK(ctx, attrs_.Init(ctx)); |
782 | } |
783 | |
784 | void Compute(OpKernelContext* ctx) override { |
785 | const Tensor* debug_name; |
786 | const Tensor* serialized; |
787 | OpInputList context_dense_keys; |
788 | OpInputList context_sparse_keys; |
789 | OpInputList context_dense_defaults; |
790 | OpInputList feature_list_dense_keys; |
791 | OpInputList feature_list_sparse_keys; |
792 | const Tensor* feature_list_dense_missing_assumed_empty; |
793 | |
794 | OP_REQUIRES_OK(ctx, ctx->input("debug_name" , &debug_name)); |
795 | OP_REQUIRES_OK(ctx, ctx->input("serialized" , &serialized)); |
796 | OP_REQUIRES_OK(ctx, ctx->input("feature_list_dense_missing_assumed_empty" , |
797 | &feature_list_dense_missing_assumed_empty)); |
798 | OP_REQUIRES_OK(ctx, |
799 | ctx->input_list("context_dense_keys" , &context_dense_keys)); |
800 | OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_dense_keys" , |
801 | &feature_list_dense_keys)); |
802 | OP_REQUIRES_OK( |
803 | ctx, ctx->input_list("context_sparse_keys" , &context_sparse_keys)); |
804 | OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_sparse_keys" , |
805 | &feature_list_sparse_keys)); |
806 | OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults" , |
807 | &context_dense_defaults)); |
808 | |
809 | std::vector<string> context_dense_keys_t(attrs_.num_context_dense); |
810 | std::vector<string> context_sparse_keys_t(attrs_.num_context_sparse); |
811 | std::vector<string> feature_list_dense_keys_t( |
812 | attrs_.num_feature_list_dense); |
813 | std::vector<string> feature_list_sparse_keys_t( |
814 | attrs_.num_feature_list_sparse); |
815 | absl::call_once( |
816 | flag_, [&context_dense_keys_t, &context_sparse_keys_t, |
817 | &feature_list_dense_keys_t, &feature_list_sparse_keys_t]() { |
818 | metrics::RecordParseDenseFeature(context_dense_keys_t.size() + |
819 | feature_list_dense_keys_t.size()); |
820 | metrics::RecordParseSparseFeature(context_sparse_keys_t.size() + |
821 | feature_list_sparse_keys_t.size()); |
822 | }); |
823 | std::unordered_set<string> feature_list_dense_missing_assumed_empty_set; |
824 | CHECK_EQ(context_dense_keys.size(), attrs_.num_context_dense); |
825 | CHECK_EQ(context_sparse_keys.size(), attrs_.num_context_sparse); |
826 | CHECK_EQ(feature_list_dense_keys.size(), attrs_.num_feature_list_dense); |
827 | CHECK_EQ(feature_list_sparse_keys.size(), attrs_.num_feature_list_sparse); |
828 | for (int di = 0; di < attrs_.num_context_dense; ++di) { |
829 | OP_REQUIRES(ctx, |
830 | TensorShapeUtils::IsScalar(context_dense_keys[di].shape()), |
831 | errors::InvalidArgument( |
832 | "Expected context_dense_keys[" , di, |
833 | "] to be a scalar, got shape: " , |
834 | context_dense_keys[di].shape().DebugString())); |
835 | context_dense_keys_t[di] = context_dense_keys[di].scalar<tstring>()(); |
836 | } |
837 | for (int di = 0; di < attrs_.num_context_sparse; ++di) { |
838 | OP_REQUIRES(ctx, |
839 | TensorShapeUtils::IsScalar(context_sparse_keys[di].shape()), |
840 | errors::InvalidArgument( |
841 | "Expected context_sparse_keys[" , di, |
842 | "] to be a scalar, got shape: " , |
843 | context_sparse_keys[di].shape().DebugString())); |
844 | context_sparse_keys_t[di] = context_sparse_keys[di].scalar<tstring>()(); |
845 | } |
846 | for (int di = 0; di < attrs_.num_feature_list_dense; ++di) { |
847 | OP_REQUIRES( |
848 | ctx, TensorShapeUtils::IsScalar(feature_list_dense_keys[di].shape()), |
849 | errors::InvalidArgument( |
850 | "Expected feature_list_dense_keys[" , di, |
851 | "] to be a scalar, got shape: " , |
852 | feature_list_dense_keys[di].shape().DebugString())); |
853 | feature_list_dense_keys_t[di] = |
854 | feature_list_dense_keys[di].scalar<tstring>()(); |
855 | } |
856 | for (int di = 0; di < attrs_.num_feature_list_sparse; ++di) { |
857 | OP_REQUIRES( |
858 | ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()), |
859 | errors::InvalidArgument( |
860 | "Expected feature_list_sparse_keys[" , di, |
861 | "] to be a scalar, got shape: " , |
862 | feature_list_sparse_keys[di].shape().DebugString())); |
863 | feature_list_sparse_keys_t[di] = |
864 | feature_list_sparse_keys[di].scalar<tstring>()(); |
865 | } |
866 | OP_REQUIRES( |
867 | ctx, |
868 | TensorShapeUtils::IsVector( |
869 | feature_list_dense_missing_assumed_empty->shape()), |
870 | errors::InvalidArgument( |
871 | "Expected feature_list_dense_missing_assumed_empty " , |
872 | "to be a vector, got shape: " , |
873 | feature_list_dense_missing_assumed_empty->shape().DebugString())); |
874 | auto feature_list_dense_missing_assumped_empty_t = |
875 | feature_list_dense_missing_assumed_empty->vec<tstring>(); |
876 | for (int de = 0; |
877 | de < feature_list_dense_missing_assumed_empty->NumElements(); ++de) { |
878 | feature_list_dense_missing_assumed_empty_set.insert( |
879 | feature_list_dense_missing_assumped_empty_t(de)); |
880 | } |
881 | |
882 | bool has_debug_name = (debug_name->NumElements() > 0); |
883 | if (has_debug_name) { |
884 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(debug_name->shape()), |
885 | errors::InvalidArgument( |
886 | "Expected debug_name to be a scalar, got shape: " , |
887 | debug_name->shape().DebugString())); |
888 | } |
889 | auto debug_name_t = debug_name->scalar<tstring>(); |
890 | |
891 | OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()), |
892 | errors::InvalidArgument( |
893 | "Expected serialized to be a scalar, got shape: " , |
894 | serialized->shape().DebugString())); |
895 | |
896 | OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense, |
897 | errors::InvalidArgument("Expected len(context_dense_defaults) " |
898 | "== len(context_dense_keys) but got: " , |
899 | context_dense_defaults.size(), " vs. " , |
900 | attrs_.num_context_dense)); |
901 | |
902 | std::vector<bool> required(attrs_.num_context_dense); |
903 | for (int d = 0; d < attrs_.num_context_dense; ++d) { |
904 | const Tensor& def_value = context_dense_defaults[d]; |
905 | required[d] = (def_value.NumElements() == 0); // No default provided. |
906 | |
907 | if (def_value.NumElements() > 0) { |
908 | OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d], |
909 | errors::InvalidArgument( |
910 | "def_value[" , d, |
911 | "].shape() == " , def_value.shape().DebugString(), |
912 | " != context_dense_shapes_[" , d, |
913 | "] == " , attrs_.context_dense_shapes[d].DebugString())); |
914 | OP_REQUIRES( |
915 | ctx, def_value.dtype() == attrs_.context_dense_types[d], |
916 | errors::InvalidArgument( |
917 | "context_dense_defaults[" , d, "].dtype() == " , |
918 | DataTypeString(def_value.dtype()), " != context_dense_types_[" , |
919 | d, "] == " , DataTypeString(attrs_.context_dense_types[d]))); |
920 | } |
921 | } |
922 | |
923 | auto serialized_t = serialized->scalar<tstring>(); |
924 | |
925 | OpOutputList context_sparse_indices; |
926 | OpOutputList context_sparse_values; |
927 | OpOutputList context_sparse_shapes; |
928 | OpOutputList context_dense_values; |
929 | OpOutputList feature_list_sparse_indices; |
930 | OpOutputList feature_list_sparse_values; |
931 | OpOutputList feature_list_sparse_shapes; |
932 | OpOutputList feature_list_dense_values; |
933 | |
934 | OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices" , |
935 | &context_sparse_indices)); |
936 | OP_REQUIRES_OK( |
937 | ctx, ctx->output_list("context_sparse_values" , &context_sparse_values)); |
938 | OP_REQUIRES_OK( |
939 | ctx, ctx->output_list("context_sparse_shapes" , &context_sparse_shapes)); |
940 | OP_REQUIRES_OK( |
941 | ctx, ctx->output_list("context_dense_values" , &context_dense_values)); |
942 | OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices" , |
943 | &context_sparse_indices)); |
944 | OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices" , |
945 | &feature_list_sparse_indices)); |
946 | OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values" , |
947 | &feature_list_sparse_values)); |
948 | OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes" , |
949 | &feature_list_sparse_shapes)); |
950 | OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values" , |
951 | &feature_list_dense_values)); |
952 | |
953 | // Allocate the SequenceExample on an arena. Provides better memory locality |
954 | // and greatly speeds up destruction. |
955 | protobuf::ArenaOptions options; |
956 | // We have some hint of what the final proto size will be based on the size |
957 | // of the serialized bytes- use this to set a custom allocation strategy. |
958 | // Note that the default allocation strategy is quite conservative (min |
959 | // block size of 256 bytes, and a max of 8 kilobytes). |
960 | const size_t block_size = serialized_t().size() * 1.1; |
961 | options.start_block_size = std::max(options.start_block_size, block_size); |
962 | options.max_block_size = std::max(options.max_block_size, block_size); |
963 | protobuf::Arena arena(options); |
964 | auto& ex = *protobuf::Arena::CreateMessage<SequenceExample>(&arena); |
965 | |
966 | OP_REQUIRES( |
967 | ctx, ParseProtoUnlimited(&ex, serialized_t()), |
968 | errors::InvalidArgument("Could not parse example input, value: '" , |
969 | serialized_t(), "'" )); |
970 | |
971 | const tstring& name = (has_debug_name) ? debug_name_t() : "<unknown>" ; |
972 | const Features& context = ex.context(); |
973 | const auto& context_dict = context.feature(); |
974 | |
975 | // Context Dense ----------------------------------------------------------- |
976 | |
977 | // Preallocate context_dense_values, since we know their sizes |
978 | for (int d = 0; d < attrs_.num_context_dense; ++d) { |
979 | TensorShape out_shape; |
980 | for (const int dim : attrs_.context_dense_shapes[d].dim_sizes()) |
981 | out_shape.AddDim(dim); |
982 | Tensor* out = nullptr; |
983 | OP_REQUIRES_OK(ctx, context_dense_values.allocate(d, out_shape, &out)); |
984 | } |
985 | |
986 | for (int d = 0; d < attrs_.num_context_dense; ++d) { |
987 | const tstring& key = context_dense_keys_t[d]; |
988 | const DataType& dtype = attrs_.context_dense_types[d]; |
989 | const TensorShape& shape = attrs_.context_dense_shapes[d]; |
990 | |
991 | const auto& feature_found = context_dict.find(key); |
992 | OP_REQUIRES( |
993 | ctx, (feature_found != context_dict.end()) || !required[d], |
994 | errors::InvalidArgument("Name: " , name, ", Context feature '" , key, |
995 | "' is required but could not be found." )); |
996 | if (feature_found != context_dict.end()) { |
997 | const Feature& f = feature_found->second; |
998 | bool types_match; |
999 | OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); |
1000 | OP_REQUIRES( |
1001 | ctx, types_match, |
1002 | errors::InvalidArgument("Name: " , name, ", Context feature: " , key, |
1003 | ". Data types don't match. " , |
1004 | "Expected type: " , DataTypeString(dtype), |
1005 | " Feature is: " , f.DebugString())); |
1006 | |
1007 | OP_REQUIRES_OK(ctx, FeatureDenseCopy(0, name, key, dtype, shape, f, |
1008 | context_dense_values[d])); |
1009 | } else { |
1010 | RowDenseCopy(0, dtype, context_dense_defaults[d], |
1011 | context_dense_values[d]); |
1012 | } |
1013 | } |
1014 | |
1015 | // Context Sparse ---------------------------------------------------------- |
1016 | for (int d = 0; d < attrs_.num_context_sparse; ++d) { |
1017 | const tstring& key = context_sparse_keys_t[d]; |
1018 | const DataType& dtype = attrs_.context_sparse_types[d]; |
1019 | |
1020 | const auto& feature_found = context_dict.find(key); |
1021 | bool feature_has_data = // Found key & data type is set |
1022 | (feature_found != context_dict.end() && |
1023 | (feature_found->second.kind_case() != Feature::KIND_NOT_SET)); |
1024 | |
1025 | if (feature_has_data) { |
1026 | const Feature& f = feature_found->second; |
1027 | bool types_match; |
1028 | OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); |
1029 | OP_REQUIRES( |
1030 | ctx, types_match, |
1031 | errors::InvalidArgument("Name: " , name, ", Context feature: " , key, |
1032 | ". Data types don't match. " , |
1033 | "Expected type: " , DataTypeString(dtype), |
1034 | " Feature is: " , f.DebugString())); |
1035 | |
1036 | Tensor feature_values = FeatureSparseCopy(0, key, dtype, f); |
1037 | const int64_t num_elements = feature_values.NumElements(); |
1038 | TensorShape indices_shape({num_elements, 1}); |
1039 | Tensor* sp_indices_d = nullptr; |
1040 | Tensor* sp_shape_d = nullptr; |
1041 | OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape, |
1042 | &sp_indices_d)); |
1043 | context_sparse_values.set(d, feature_values); |
1044 | OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}), |
1045 | &sp_shape_d)); |
1046 | auto shape_t = sp_shape_d->vec<int64_t>(); |
1047 | shape_t(0) = num_elements; |
1048 | auto indices_t = sp_indices_d->matrix<int64_t>(); |
1049 | std::iota(indices_t.data(), indices_t.data() + num_elements, 0); |
1050 | } else { |
1051 | TensorShape indices_shape({0, 1}); |
1052 | TensorShape values_shape({0}); |
1053 | Tensor* sp_indices_d = nullptr; |
1054 | Tensor* sp_values_d = nullptr; |
1055 | Tensor* sp_shape_d = nullptr; |
1056 | OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape, |
1057 | &sp_indices_d)); |
1058 | OP_REQUIRES_OK( |
1059 | ctx, context_sparse_values.allocate(d, values_shape, &sp_values_d)); |
1060 | OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}), |
1061 | &sp_shape_d)); |
1062 | auto shape_t = sp_shape_d->vec<int64_t>(); |
1063 | shape_t(0) = 0; |
1064 | } |
1065 | } |
1066 | |
1067 | // Feature List Dense ------------------------------------------------------ |
1068 | |
1069 | // Preallocate context_dense_values, since we can infer their |
1070 | // sizes |
1071 | const FeatureLists& feature_lists = ex.feature_lists(); |
1072 | const auto& feature_list_dict = feature_lists.feature_list(); |
1073 | FeatureList empty_feature_list; // Placeholder for missing FLs |
1074 | |
1075 | for (int d = 0; d < attrs_.num_feature_list_dense; ++d) { |
1076 | const tstring& key = feature_list_dense_keys_t[d]; |
1077 | const DataType& dtype = attrs_.feature_list_dense_types[d]; |
1078 | const TensorShape& shape = attrs_.feature_list_dense_shapes[d]; |
1079 | |
1080 | const auto& feature_list_found = feature_list_dict.find(key); |
1081 | bool feature_list_missing = |
1082 | (feature_list_found == feature_list_dict.end()); |
1083 | bool feature_list_allowed_missing = |
1084 | (feature_list_dense_missing_assumed_empty_set.count(key) > 0); |
1085 | |
1086 | OP_REQUIRES( |
1087 | ctx, !feature_list_missing || feature_list_allowed_missing, |
1088 | errors::InvalidArgument("Name: " , name, ", Feature list '" , key, |
1089 | "' is required but could not be found. " |
1090 | "Did you mean to include it in " |
1091 | "feature_list_dense_missing_assumed_empty or " |
1092 | "feature_list_dense_defaults?" )); |
1093 | |
1094 | TensorShape out_shape; |
1095 | const FeatureList& fl = (feature_list_missing) |
1096 | ? empty_feature_list |
1097 | : feature_list_found->second; |
1098 | out_shape.AddDim(fl.feature_size()); |
1099 | for (const int dim : attrs_.feature_list_dense_shapes[d].dim_sizes()) { |
1100 | out_shape.AddDim(dim); |
1101 | } |
1102 | Tensor* out = nullptr; |
1103 | OP_REQUIRES_OK(ctx, |
1104 | feature_list_dense_values.allocate(d, out_shape, &out)); |
1105 | |
1106 | for (int64_t t = 0; t < fl.feature_size(); ++t) { |
1107 | const Feature& f = fl.feature(t); |
1108 | bool types_match; |
1109 | OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); |
1110 | OP_REQUIRES(ctx, types_match, |
1111 | errors::InvalidArgument( |
1112 | "Name: " , name, ", Feature list: " , key, ", Index: " , t, |
1113 | ". Data types don't match. " , |
1114 | "Expected type: " , DataTypeString(dtype), |
1115 | " Feature is: " , f.DebugString())); |
1116 | OP_REQUIRES_OK(ctx, FeatureDenseCopy(t, name, key, dtype, shape, f, |
1117 | feature_list_dense_values[d])); |
1118 | } |
1119 | } |
1120 | |
1121 | // Feature List Sparse ----------------------------------------------------- |
1122 | for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) { |
1123 | const tstring& key = feature_list_sparse_keys_t[d]; |
1124 | const DataType& dtype = attrs_.feature_list_sparse_types[d]; |
1125 | |
1126 | const auto& feature_list_found = feature_list_dict.find(key); |
1127 | bool feature_list_has_data = // Found key |
1128 | (feature_list_found != feature_list_dict.end()); |
1129 | |
1130 | std::vector<Tensor> sparse_values_tmp; |
1131 | int64_t feature_list_size = 0; |
1132 | if (feature_list_has_data) { |
1133 | const FeatureList& fl = feature_list_found->second; |
1134 | feature_list_size = fl.feature_size(); |
1135 | for (int64_t t = 0; t < feature_list_size; ++t) { |
1136 | const Feature& f = fl.feature(t); |
1137 | bool types_match; |
1138 | OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match)); |
1139 | OP_REQUIRES( |
1140 | ctx, f.kind_case() == Feature::KIND_NOT_SET || types_match, |
1141 | errors::InvalidArgument("Name: " , name, ", Feature List: " , key, |
1142 | ", Index: " , t, |
1143 | ". Data types don't match. " , |
1144 | "Expected type: " , DataTypeString(dtype), |
1145 | " Feature is: " , f.DebugString())); |
1146 | sparse_values_tmp.push_back(FeatureSparseCopy(t, key, dtype, f)); |
1147 | } |
1148 | } else { |
1149 | sparse_values_tmp.push_back(Tensor(dtype, TensorShape({0}))); |
1150 | } |
1151 | |
1152 | int64_t total_num_features = 0; |
1153 | int64_t max_num_features = 0; |
1154 | for (int t = 0; t < feature_list_size; ++t) { |
1155 | const Tensor& v = sparse_values_tmp[t]; |
1156 | const int64_t num_elements = v.shape().num_elements(); |
1157 | total_num_features += num_elements; |
1158 | max_num_features = std::max(max_num_features, num_elements); |
1159 | } |
1160 | |
1161 | TensorShape indices_shape({total_num_features, 2}); |
1162 | TensorShape values_shape({total_num_features}); |
1163 | Tensor* sp_indices_d = nullptr; |
1164 | Tensor* sp_values_d = nullptr; |
1165 | Tensor* sp_shape_d = nullptr; |
1166 | OP_REQUIRES_OK(ctx, feature_list_sparse_indices.allocate(d, indices_shape, |
1167 | &sp_indices_d)); |
1168 | OP_REQUIRES_OK(ctx, feature_list_sparse_values.allocate(d, values_shape, |
1169 | &sp_values_d)); |
1170 | OP_REQUIRES_OK(ctx, feature_list_sparse_shapes.allocate( |
1171 | d, TensorShape({2}), &sp_shape_d)); |
1172 | auto shape_t = sp_shape_d->vec<int64_t>(); |
1173 | shape_t(0) = feature_list_size; |
1174 | shape_t(1) = max_num_features; |
1175 | |
1176 | int64_t offset = 0; |
1177 | |
1178 | for (int t = 0; t < feature_list_size; ++t) { |
1179 | const int64_t num_elements = CopyIntoSparseTensor( |
1180 | sparse_values_tmp[t], t, offset, sp_indices_d, sp_values_d); |
1181 | offset += num_elements; |
1182 | } |
1183 | } |
1184 | } |
1185 | |
1186 | protected: |
1187 | ParseSingleSequenceExampleAttrs attrs_; |
1188 | absl::once_flag flag_; |
1189 | }; |
1190 | |
// CPU-only registration; this op has no device-specific variants.
REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample" ).Device(DEVICE_CPU),
                        ParseSingleSequenceExampleOp);
1193 | |
1194 | #ifndef IS_MOBILE_PLATFORM |
1195 | // when using lite protos on mobile, decoding JSON is not available. |
1196 | |
1197 | class DecodeJSONExampleOp : public OpKernel { |
1198 | public: |
1199 | explicit DecodeJSONExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) { |
1200 | resolver_.reset(protobuf::util::NewTypeResolverForDescriptorPool( |
1201 | "type.googleapis.com" , protobuf::DescriptorPool::generated_pool())); |
1202 | } |
1203 | |
1204 | void Compute(OpKernelContext* ctx) override { |
1205 | const Tensor* json_examples; |
1206 | OP_REQUIRES_OK(ctx, ctx->input("json_examples" , &json_examples)); |
1207 | Tensor* binary_examples; |
1208 | OP_REQUIRES_OK( |
1209 | ctx, ctx->allocate_output("binary_examples" , json_examples->shape(), |
1210 | &binary_examples)); |
1211 | |
1212 | for (int i = 0; i < json_examples->NumElements(); ++i) { |
1213 | const tstring& json_example = json_examples->flat<tstring>()(i); |
1214 | protobuf::io::ArrayInputStream in(json_example.data(), |
1215 | json_example.size()); |
1216 | TStringOutputStream out(&binary_examples->flat<tstring>()(i)); |
1217 | auto status = protobuf::util::JsonToBinaryStream( |
1218 | resolver_.get(), "type.googleapis.com/tensorflow.Example" , &in, &out); |
1219 | OP_REQUIRES(ctx, status.ok(), |
1220 | errors::InvalidArgument("Error while parsing JSON: " , |
1221 | string(status.message()))); |
1222 | } |
1223 | } |
1224 | |
1225 | private: |
1226 | std::unique_ptr<protobuf::util::TypeResolver> resolver_; |
1227 | }; |
1228 | |
// CPU-only registration; compiled out on mobile builds (see the
// IS_MOBILE_PLATFORM guard above) where lite protos lack JSON support.
REGISTER_KERNEL_BUILDER(Name("DecodeJSONExample" ).Device(DEVICE_CPU),
                        DecodeJSONExampleOp);
1231 | #endif |
1232 | |
1233 | } // namespace tensorflow |
1234 | |