1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// See docs in ../ops/parsing_ops.cc.
17
18#include <numeric>
19#include <unordered_set>
20#include <vector>
21
22#include "absl/base/call_once.h"
23#include "tensorflow/core/example/example.pb.h"
24#include "tensorflow/core/example/feature.pb.h"
25#include "tensorflow/core/framework/common_shape_fns.h"
26#include "tensorflow/core/framework/metrics.h"
27#include "tensorflow/core/framework/numeric_op.h"
28#include "tensorflow/core/framework/register_types.h"
29#include "tensorflow/core/lib/core/errors.h"
30#include "tensorflow/core/lib/gtl/array_slice.h"
31#include "tensorflow/core/platform/logging.h"
32#include "tensorflow/core/platform/protobuf.h"
33#include "tensorflow/core/util/example_proto_fast_parsing.h"
34#include "tensorflow/core/util/example_proto_helper.h"
35#include "tensorflow/core/util/sparse/sparse_tensor.h"
36#include "tensorflow/core/util/work_sharder.h"
37
38namespace tensorflow {
39
namespace {

// Names of the "V2" op variants. Kernels that are registered for both a V1
// and a V2 op compare the NodeDef's op name against these constants to pick
// the op version they run as.
constexpr const char* const kParseExampleV2 = "ParseExampleV2";
constexpr const char* const kParseSequenceExampleV2 = "ParseSequenceExampleV2";

}  // namespace
44
45// Note: this kernel is used by both the ParseExample op and the ParseExampleV2
46// op. It automatically determines which op was used by checking if the
47// "ragged_value_types" attribute exists.
48class ParseExampleOp : public OpKernel {
49 public:
50 explicit ParseExampleOp(OpKernelConstruction* ctx)
51 : OpKernel(ctx), op_version_(ctx->def().op() == kParseExampleV2 ? 2 : 1) {
52 OP_REQUIRES_OK(ctx, attrs_.Init(ctx, op_version_));
53 }
54
55 void Compute(OpKernelContext* ctx) override {
56 const Tensor* names;
57 const Tensor* serialized;
58 std::vector<StringPiece> dense_keys_t;
59 std::vector<StringPiece> sparse_keys_t;
60 std::vector<StringPiece> ragged_keys_t;
61 OpInputList dense_defaults;
62
63 // Grab the inputs.
64 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
65 OP_REQUIRES_OK(ctx, ctx->input("names", &names));
66 if (op_version_ == 2) {
67 OP_REQUIRES_OK(ctx, GetTensorKeys(ctx, "dense_keys", &dense_keys_t));
68 OP_REQUIRES_OK(ctx, GetTensorKeys(ctx, "sparse_keys", &sparse_keys_t));
69 OP_REQUIRES_OK(ctx, GetTensorKeys(ctx, "ragged_keys", &ragged_keys_t));
70 } else {
71 OP_REQUIRES_OK(ctx, GetInputListKeys(ctx, "dense_keys", &dense_keys_t));
72 OP_REQUIRES_OK(ctx, GetInputListKeys(ctx, "sparse_keys", &sparse_keys_t));
73 }
74 absl::call_once(flag_, [&dense_keys_t, &sparse_keys_t, &ragged_keys_t]() {
75 metrics::RecordParseDenseFeature(dense_keys_t.size());
76 metrics::RecordParseSparseFeature(sparse_keys_t.size());
77 metrics::RecordParseRaggedFeature(ragged_keys_t.size());
78 });
79 OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));
80
81 // Validate input tensor shapes.
82 OP_REQUIRES_OK(
83 ctx, CheckInputShapes(serialized, names, dense_defaults, dense_keys_t,
84 sparse_keys_t, ragged_keys_t));
85
86 example::FastParseExampleConfig config =
87 MakeConfig(dense_keys_t, sparse_keys_t, ragged_keys_t, dense_defaults);
88
89 example::Result result;
90 if (TensorShapeUtils::IsVector(serialized->shape())) {
91 OP_REQUIRES_OK(
92 ctx, ParseExampleVector(config, serialized, names, ctx, &result));
93 } else {
94 OP_REQUIRES_OK(ctx, ParseExampleScalar(config, serialized, ctx, &result));
95 }
96 OP_REQUIRES_OK(ctx, WriteOutput(result, ctx));
97 }
98
99 protected:
100 // Copies keys from tensor to std::vector<string>.
101 Status GetTensorKeys(OpKernelContext* ctx, StringPiece input_name,
102 std::vector<StringPiece>* keys) const {
103 const Tensor* key_t;
104 TF_RETURN_IF_ERROR(ctx->input(input_name, &key_t));
105 keys->reserve(key_t->NumElements());
106 auto keys_flat = key_t->flat<tstring>();
107 for (int i = 0; i < keys_flat.size(); ++i) {
108 keys->push_back(keys_flat(i));
109 }
110 return OkStatus();
111 }
112
113 // Copies keys from OpInputList of scalar to std::vector<string>.
114 Status GetInputListKeys(OpKernelContext* ctx, StringPiece input_name,
115 std::vector<StringPiece>* keys) const {
116 OpInputList key_list;
117 TF_RETURN_IF_ERROR(ctx->input_list(input_name, &key_list));
118 keys->reserve(key_list.size());
119 for (const auto& key : key_list) {
120 keys->push_back(key.scalar<tstring>()());
121 }
122 return OkStatus();
123 }
124
125 // Validates the shapes of input tensors.
126 Status CheckInputShapes(const Tensor* serialized, const Tensor* names,
127 const OpInputList& dense_defaults,
128 const std::vector<StringPiece>& dense_keys_t,
129 const std::vector<StringPiece>& sparse_keys_t,
130 const std::vector<StringPiece>& ragged_keys_t) const {
131 if (op_version_ == 2) {
132 if (TensorShapeUtils::IsMatrixOrHigher(serialized->shape())) {
133 return errors::InvalidArgument(
134 "Expected serialized to be a scalar or vector, got shape: ",
135 serialized->shape().DebugString());
136 }
137 } else {
138 if (!TensorShapeUtils::IsVector(serialized->shape())) {
139 return errors::InvalidArgument(
140 "Expected serialized to be a vector, got shape: ",
141 serialized->shape().DebugString());
142 }
143 }
144 if (names->NumElements() > 0 && names->shape() != serialized->shape()) {
145 return errors::InvalidArgument(
146 "Expected names have the same shape as serialized: name.shape=",
147 names->shape().DebugString(),
148 ", serialized.shape=", serialized->shape().DebugString());
149 }
150 if (op_version_ == 2) {
151 if (dense_keys_t.size() != attrs_.num_dense) {
152 return errors::InvalidArgument(
153 "Expected len(dense_keys) == len(dense_types) but got: ",
154 dense_keys_t.size(), " vs. ", attrs_.num_dense);
155 }
156 if (sparse_keys_t.size() != attrs_.num_sparse) {
157 return errors::InvalidArgument(
158 "Expected len(sparse_keys) == num_sparse but got: ",
159 sparse_keys_t.size(), " vs. ", attrs_.num_sparse);
160 }
161 if (ragged_keys_t.size() != attrs_.num_ragged) {
162 return errors::InvalidArgument(
163 "Expected len(ragged_keys) == len(ragged_value_types) but got: ",
164 ragged_keys_t.size(), " vs. ", attrs_.num_ragged);
165 }
166 }
167
168 if (dense_defaults.size() != attrs_.num_dense) {
169 return errors::InvalidArgument(
170 "Expected len(dense_defaults) == len(dense_keys) but got: ",
171 dense_defaults.size(), " vs. ", attrs_.num_dense);
172 }
173
174 for (int d = 0; d < static_cast<int>(attrs_.num_dense); ++d) {
175 const Tensor& def_value = dense_defaults[d];
176 if (attrs_.variable_length[d]) {
177 if (def_value.NumElements() != 1) {
178 return errors::InvalidArgument(
179 "dense_shape[", d, "] is a variable length shape: ",
180 attrs_.dense_shapes[d].DebugString(),
181 ", therefore "
182 "def_value[",
183 d,
184 "] must contain a single element ("
185 "the padding element). But its shape is: ",
186 def_value.shape().DebugString());
187 }
188 } else if (def_value.NumElements() > 0) {
189 if (!attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape())) {
190 return errors::InvalidArgument(
191 "def_value[", d, "].shape() == ", def_value.shape().DebugString(),
192 " is not compatible with dense_shapes_[", d,
193 "] == ", attrs_.dense_shapes[d].DebugString());
194 }
195 }
196 if (def_value.dtype() != attrs_.dense_types[d]) {
197 return errors::InvalidArgument(
198 "dense_defaults[", d,
199 "].dtype() == ", DataTypeString(def_value.dtype()),
200 " != dense_types_[", d,
201 "] == ", DataTypeString(attrs_.dense_types[d]));
202 }
203 }
204 return OkStatus();
205 }
206
207 // Populates the FastParseExampleConfig from keys & defaults.
208 example::FastParseExampleConfig MakeConfig(
209 const std::vector<StringPiece>& dense_keys_t,
210 const std::vector<StringPiece>& sparse_keys_t,
211 const std::vector<StringPiece>& ragged_keys_t,
212 const OpInputList& dense_defaults) const {
213 example::FastParseExampleConfig config;
214 config.dense.reserve(attrs_.num_dense);
215 for (int d = 0; d < attrs_.num_dense; ++d) {
216 config.dense.emplace_back(dense_keys_t[d], attrs_.dense_types[d],
217 attrs_.dense_shapes[d], dense_defaults[d],
218 attrs_.variable_length[d],
219 attrs_.elements_per_stride[d]);
220 }
221 config.sparse.reserve(attrs_.num_sparse);
222 for (int d = 0; d < attrs_.num_sparse; ++d) {
223 config.sparse.emplace_back(sparse_keys_t[d], attrs_.sparse_types[d]);
224 }
225 config.ragged.reserve(attrs_.num_ragged);
226 for (int d = 0; d < attrs_.num_ragged; ++d) {
227 config.ragged.emplace_back(ragged_keys_t[d], attrs_.ragged_value_types[d],
228 attrs_.ragged_split_types[d]);
229 }
230 return config;
231 }
232
233 // Parses a single example.
234 Status ParseExampleScalar(const example::FastParseExampleConfig& config,
235 const Tensor* serialized, OpKernelContext* ctx,
236 example::Result* result) const {
237 const tstring& serialized_proto = serialized->scalar<tstring>()();
238 return FastParseSingleExample(config, serialized_proto, result);
239 }
240
241 // Parses a vector of examples.
242 Status ParseExampleVector(const example::FastParseExampleConfig& config,
243 const Tensor* serialized, const Tensor* names,
244 OpKernelContext* ctx,
245 example::Result* result) const {
246 auto serialized_t = serialized->flat<tstring>();
247 auto names_t = names->flat<tstring>();
248 gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
249 gtl::ArraySlice<tstring> names_slice(names_t.data(), names_t.size());
250 return FastParseExample(
251 config, slice, names_slice,
252 ctx->device()->tensorflow_cpu_worker_threads()->workers, result);
253 }
254
255 Status WriteOutput(const example::Result& result,
256 OpKernelContext* ctx) const {
257 OpOutputList dense_values;
258 OpOutputList sparse_indices;
259 OpOutputList sparse_values;
260 OpOutputList sparse_shapes;
261 TF_RETURN_IF_ERROR(ctx->output_list("dense_values", &dense_values));
262 TF_RETURN_IF_ERROR(ctx->output_list("sparse_indices", &sparse_indices));
263 TF_RETURN_IF_ERROR(ctx->output_list("sparse_values", &sparse_values));
264 TF_RETURN_IF_ERROR(ctx->output_list("sparse_shapes", &sparse_shapes));
265 for (int d = 0; d < attrs_.num_dense; ++d) {
266 dense_values.set(d, result.dense_values[d]);
267 }
268 for (int d = 0; d < attrs_.num_sparse; ++d) {
269 sparse_indices.set(d, result.sparse_indices[d]);
270 sparse_values.set(d, result.sparse_values[d]);
271 sparse_shapes.set(d, result.sparse_shapes[d]);
272 }
273 if (op_version_ == 2) {
274 OpOutputList ragged_values;
275 OpOutputList ragged_splits;
276 TF_RETURN_IF_ERROR(ctx->output_list("ragged_values", &ragged_values));
277 TF_RETURN_IF_ERROR(ctx->output_list("ragged_row_splits", &ragged_splits));
278 for (int d = 0; d < attrs_.num_ragged; ++d) {
279 ragged_values.set(d, result.ragged_values[d]);
280 ragged_splits.set(d, result.ragged_splits[d]);
281 }
282 }
283 return OkStatus();
284 }
285
286 ParseExampleAttrs attrs_;
287 int op_version_;
288 absl::once_flag flag_;
289};
290
291REGISTER_KERNEL_BUILDER(Name("ParseExample").Device(DEVICE_CPU),
292 ParseExampleOp);
293REGISTER_KERNEL_BUILDER(Name("ParseExampleV2").Device(DEVICE_CPU),
294 ParseExampleOp);
295
296class ParseSingleExampleOp : public OpKernel {
297 public:
298 explicit ParseSingleExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
299 OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
300 metrics::RecordParseDenseFeature(attrs_.dense_keys.size());
301 metrics::RecordParseSparseFeature(attrs_.sparse_keys.size());
302 }
303
304 void Compute(OpKernelContext* ctx) override {
305 const Tensor* serialized;
306 OpInputList dense_defaults;
307
308 // Grab the input list arguments.
309 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
310 OP_REQUIRES_OK(ctx, ctx->input_list("dense_defaults", &dense_defaults));
311
312 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
313 errors::InvalidArgument(
314 "Expected serialized to be a scalar, got shape: ",
315 serialized->shape().DebugString()));
316 OP_REQUIRES(ctx, dense_defaults.size() == attrs_.dense_keys.size(),
317 errors::InvalidArgument(
318 "Expected len(dense_defaults) == len(dense_keys) but got: ",
319 dense_defaults.size(), " vs. ", attrs_.dense_keys.size()));
320
321 for (size_t d = 0; d < attrs_.dense_keys.size(); ++d) {
322 const Tensor& def_value = dense_defaults[d];
323 if (attrs_.variable_length[d]) {
324 OP_REQUIRES(ctx, def_value.NumElements() == 1,
325 errors::InvalidArgument(
326 "dense_shape[", d, "] is a variable length shape: ",
327 attrs_.dense_shapes[d].DebugString(),
328 ", therefore "
329 "def_value[",
330 d,
331 "] must contain a single element ("
332 "the padding element). But its shape is: ",
333 def_value.shape().DebugString()));
334 } else if (def_value.NumElements() > 0) {
335 OP_REQUIRES(ctx,
336 attrs_.dense_shapes[d].IsCompatibleWith(def_value.shape()),
337 errors::InvalidArgument(
338 "def_value[", d,
339 "].shape() == ", def_value.shape().DebugString(),
340 " is not compatible with dense_shapes_[", d,
341 "] == ", attrs_.dense_shapes[d].DebugString()));
342 }
343 OP_REQUIRES(ctx, def_value.dtype() == attrs_.dense_types[d],
344 errors::InvalidArgument(
345 "dense_defaults[", d, "].dtype() == ",
346 DataTypeString(def_value.dtype()), " != dense_types_[", d,
347 "] == ", DataTypeString(attrs_.dense_types[d])));
348 }
349
350 example::Result result;
351
352 // TODO(mrry): Build the configuration once and cache it.
353 example::FastParseExampleConfig config;
354 for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
355 config.dense.push_back({attrs_.dense_keys[d], attrs_.dense_types[d],
356 attrs_.dense_shapes[d], dense_defaults[d],
357 attrs_.variable_length[d],
358 attrs_.elements_per_stride[d]});
359 }
360 for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
361 config.sparse.push_back({attrs_.sparse_keys[d], attrs_.sparse_types[d]});
362 }
363
364 const tstring& serialized_proto = serialized->scalar<tstring>()();
365
366 OP_REQUIRES_OK(ctx,
367 FastParseSingleExample(config, serialized_proto, &result));
368
369 OpOutputList dense_values;
370 OpOutputList sparse_indices;
371 OpOutputList sparse_values;
372 OpOutputList sparse_shapes;
373 OP_REQUIRES_OK(ctx, ctx->output_list("dense_values", &dense_values));
374 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_indices", &sparse_indices));
375 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_values", &sparse_values));
376 OP_REQUIRES_OK(ctx, ctx->output_list("sparse_shapes", &sparse_shapes));
377 for (int d = 0; d < attrs_.dense_keys.size(); ++d) {
378 dense_values.set(d, result.dense_values[d]);
379 }
380 for (int d = 0; d < attrs_.sparse_keys.size(); ++d) {
381 sparse_indices.set(d, result.sparse_indices[d]);
382 sparse_values.set(d, result.sparse_values[d]);
383 sparse_shapes.set(d, result.sparse_shapes[d]);
384 }
385 }
386
387 protected:
388 ParseSingleExampleAttrs attrs_;
389};
390
391REGISTER_KERNEL_BUILDER(Name("ParseSingleExample").Device(DEVICE_CPU),
392 ParseSingleExampleOp);
393
394class ParseSequenceExampleOp : public OpKernel {
395 public:
396 explicit ParseSequenceExampleOp(OpKernelConstruction* ctx)
397 : OpKernel(ctx),
398 op_version_(ctx->def().op() == kParseSequenceExampleV2 ? 2 : 1) {
399 OP_REQUIRES_OK(ctx, attrs_.Init(ctx, op_version_));
400 metrics::RecordParseDenseFeature(attrs_.context_dense_keys.size() +
401 attrs_.feature_list_dense_keys.size());
402 metrics::RecordParseSparseFeature(attrs_.context_sparse_keys.size() +
403 attrs_.feature_list_sparse_keys.size());
404 }
405
406 void Compute(OpKernelContext* ctx) override {
407 const Tensor* debug_name;
408 const Tensor* serialized;
409 OpInputList context_dense_defaults;
410
411 OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
412 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
413 OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
414 &context_dense_defaults));
415 const Tensor* context_dense_keys = nullptr;
416 const Tensor* context_sparse_keys = nullptr;
417 const Tensor* context_ragged_keys = nullptr;
418 const Tensor* feature_list_dense_keys = nullptr;
419 const Tensor* feature_list_sparse_keys = nullptr;
420 const Tensor* feature_list_ragged_keys = nullptr;
421 const Tensor* feature_list_dense_missing_assumed_empty = nullptr;
422 if (op_version_ == 2) {
423 OP_REQUIRES_OK(ctx,
424 ctx->input("feature_list_dense_missing_assumed_empty",
425 &feature_list_dense_missing_assumed_empty));
426 OP_REQUIRES_OK(ctx,
427 ctx->input("context_dense_keys", &context_dense_keys));
428 OP_REQUIRES_OK(ctx,
429 ctx->input("context_sparse_keys", &context_sparse_keys));
430 OP_REQUIRES_OK(ctx,
431 ctx->input("context_ragged_keys", &context_ragged_keys));
432 OP_REQUIRES_OK(
433 ctx, ctx->input("feature_list_dense_keys", &feature_list_dense_keys));
434 OP_REQUIRES_OK(ctx, ctx->input("feature_list_sparse_keys",
435 &feature_list_sparse_keys));
436 OP_REQUIRES_OK(ctx, ctx->input("feature_list_ragged_keys",
437 &feature_list_ragged_keys));
438 absl::call_once(flag_, [&]() {
439 metrics::RecordParseDenseFeature(
440 context_dense_keys->NumElements() +
441 feature_list_dense_keys->NumElements());
442 metrics::RecordParseSparseFeature(
443 context_sparse_keys->NumElements() +
444 feature_list_sparse_keys->NumElements());
445 metrics::RecordParseRaggedFeature(
446 context_ragged_keys->NumElements() +
447 feature_list_ragged_keys->NumElements());
448 });
449 }
450
451 // Validate input tensor shapes.
452 OP_REQUIRES_OK(ctx, CheckInputShapes(
453 serialized, debug_name, context_dense_defaults,
454 context_dense_keys, context_sparse_keys,
455 context_ragged_keys, feature_list_dense_keys,
456 feature_list_sparse_keys, feature_list_ragged_keys,
457 feature_list_dense_missing_assumed_empty));
458
459 example::FastParseExampleConfig context_config =
460 MakeContextConfig(context_dense_keys, context_sparse_keys,
461 context_ragged_keys, context_dense_defaults);
462 example::FastParseExampleConfig feature_list_config = MakeFeatureListConfig(
463 feature_list_dense_keys, feature_list_sparse_keys,
464 feature_list_ragged_keys, feature_list_dense_missing_assumed_empty);
465
466 bool is_batch = TensorShapeUtils::IsVector(serialized->shape());
467 auto serialized_t = serialized->flat<tstring>();
468 auto debug_name_t = debug_name->flat<tstring>();
469 gtl::ArraySlice<tstring> slice(serialized_t.data(), serialized_t.size());
470 gtl::ArraySlice<tstring> names_slice(debug_name_t.data(),
471 debug_name_t.size());
472
473 example::Result context_result, feature_list_result;
474 std::vector<Tensor> dense_feature_lengths;
475 OP_REQUIRES_OK(
476 ctx, FastParseSequenceExample(
477 context_config, feature_list_config, slice, names_slice,
478 ctx->device()->tensorflow_cpu_worker_threads()->workers,
479 &context_result, &feature_list_result, &dense_feature_lengths,
480 is_batch));
481
482 OP_REQUIRES_OK(ctx, WriteOutput(context_result, feature_list_result,
483 dense_feature_lengths, ctx));
484 }
485
486 protected:
487 Status CheckInputShapes(
488 const Tensor* serialized, const Tensor* names,
489 const OpInputList& context_dense_defaults,
490
491 const Tensor* context_dense_keys, const Tensor* context_sparse_keys,
492 const Tensor* context_ragged_keys, const Tensor* feature_list_dense_keys,
493 const Tensor* feature_list_sparse_keys,
494 const Tensor* feature_list_ragged_keys,
495 const Tensor* feature_list_dense_missing_assumed_empty) const {
496 if (TensorShapeUtils::IsMatrixOrHigher(serialized->shape())) {
497 return errors::InvalidArgument(
498 "Expected serialized to be a scalar or vector, got shape: ",
499 serialized->shape().DebugString());
500 }
501 if (op_version_ > 1) {
502 if (context_dense_keys->NumElements() != attrs_.num_context_dense) {
503 return errors::InvalidArgument(
504 "Expected len(context_dense_keys) to match len(Tcontext_dense)");
505 }
506 if (context_sparse_keys->NumElements() != attrs_.num_context_sparse) {
507 return errors::InvalidArgument(
508 "Expected len(context_sparse_keys) to match Ncontext_sparse");
509 }
510 if (context_ragged_keys->NumElements() != attrs_.num_context_ragged) {
511 return errors::InvalidArgument(
512 "Expected len(context_ragged_keys) to match "
513 "len(context_ragged_value_types)");
514 }
515 if (feature_list_dense_keys->NumElements() !=
516 attrs_.num_feature_list_dense) {
517 return errors::InvalidArgument(
518 "Expected len(feature_list_dense_keys) to match "
519 "Nfeature_list_dense");
520 }
521 if (feature_list_dense_missing_assumed_empty->NumElements() !=
522 attrs_.num_feature_list_dense) {
523 return errors::InvalidArgument(
524 "Expected len(feature_list_dense_missing_assumed_empty to match "
525 "Nfeature_list_dense");
526 }
527 if (feature_list_sparse_keys->NumElements() !=
528 attrs_.num_feature_list_sparse) {
529 return errors::InvalidArgument(
530 "Expected len(feature_list_sparse_keys) to match "
531 "Nfeature_list_sparse");
532 }
533 if (feature_list_ragged_keys->NumElements() !=
534 attrs_.num_feature_list_ragged) {
535 return errors::InvalidArgument(
536 "Expected len(feature_list_ragged_keys) to match "
537 "len(feature_list_ragged_value_types)");
538 }
539 }
540 if (context_dense_defaults.size() != attrs_.num_context_dense) {
541 return errors::InvalidArgument(
542 "Expected len(context_dense_defaults) "
543 "== len(context_dense_keys) but got: ",
544 context_dense_defaults.size(), " vs. ", attrs_.num_context_dense);
545 }
546 for (int d = 0; d < attrs_.num_context_dense; ++d) {
547 const Tensor& def_value = context_dense_defaults[d];
548 if (def_value.NumElements() > 0) {
549 if (def_value.shape() != attrs_.context_dense_shapes[d]) {
550 return errors::InvalidArgument(
551 "default_value[", d,
552 "].shape() == ", def_value.shape().DebugString(),
553 " != context_dense_shapes[", d,
554 "] == ", attrs_.context_dense_shapes[d].DebugString());
555 }
556 if (def_value.dtype() != attrs_.context_dense_types[d]) {
557 return errors::InvalidArgument(
558 "context_dense_defaults[", d,
559 "].dtype() == ", DataTypeString(def_value.dtype()),
560 " != context_dense_types[", d,
561 "] == ", DataTypeString(attrs_.context_dense_types[d]));
562 }
563 }
564 }
565 return OkStatus();
566 }
567
568 example::FastParseExampleConfig MakeContextConfig(
569 const Tensor* dense_keys, const Tensor* sparse_keys,
570 const Tensor* ragged_keys,
571 const OpInputList& context_dense_defaults) const {
572 // Convert the tensors/attrs to ArraySlices once, instead of re-evaluating
573 // them in each loop iteration.
574 gtl::ArraySlice<tstring> dense_keys_slice =
575 dense_keys
576 ? gtl::ArraySlice<tstring>(dense_keys->flat<tstring>().data(),
577 attrs_.num_context_dense)
578 : attrs_.context_dense_keys;
579 gtl::ArraySlice<tstring> sparse_keys_slice =
580 sparse_keys
581 ? gtl::ArraySlice<tstring>(sparse_keys->flat<tstring>().data(),
582 attrs_.num_context_sparse)
583 : attrs_.context_sparse_keys;
584 gtl::ArraySlice<tstring> ragged_keys_slice =
585 ragged_keys
586 ? gtl::ArraySlice<tstring>(ragged_keys->flat<tstring>().data(),
587 attrs_.num_context_ragged)
588 : gtl::ArraySlice<tstring>(nullptr, 0);
589
590 example::FastParseExampleConfig config;
591 config.dense.reserve(attrs_.num_context_dense);
592 for (int d = 0; d < attrs_.num_context_dense; ++d) {
593 const tstring& key = dense_keys_slice[d];
594 config.dense.emplace_back(key, attrs_.context_dense_types[d],
595 attrs_.context_dense_shapes[d],
596 context_dense_defaults[d],
597 false /* attrs_.context_variable_length[d] */,
598 0 /*attrs_.context_elements_per_stride[d] */);
599 }
600 config.sparse.reserve(attrs_.num_context_sparse);
601 for (int d = 0; d < attrs_.num_context_sparse; ++d) {
602 const tstring& key = sparse_keys_slice[d];
603 config.sparse.emplace_back(key, attrs_.context_sparse_types[d]);
604 }
605 config.ragged.reserve(attrs_.num_context_ragged);
606 for (int d = 0; d < attrs_.num_context_ragged; ++d) {
607 config.ragged.emplace_back(ragged_keys_slice[d],
608 attrs_.context_ragged_value_types[d],
609 attrs_.context_ragged_split_types[d]);
610 }
611 return config;
612 }
613
614 static Tensor ConstructDefaultScalar(DataType dtype) {
615 switch (dtype) {
616 case DT_INT64:
617 return Tensor(static_cast<int64_t>(0));
618 case DT_FLOAT:
619 return Tensor(static_cast<float>(0.0));
620 case DT_STRING:
621 return Tensor("");
622 default:
623 return Tensor(DT_INVALID);
624 }
625 }
626
627 example::FastParseExampleConfig MakeFeatureListConfig(
628 const Tensor* dense_keys, const Tensor* sparse_keys,
629 const Tensor* ragged_keys,
630 const Tensor* feature_list_dense_missing_assumed_empty) const {
631 // Convert the tensors/attrs to ArraySlices once, instead of re-evaluating
632 // them in each loop iteration.
633 gtl::ArraySlice<tstring> dense_keys_slice =
634 dense_keys
635 ? gtl::ArraySlice<tstring>(dense_keys->flat<tstring>().data(),
636 attrs_.num_feature_list_dense)
637 : attrs_.feature_list_dense_keys;
638 gtl::ArraySlice<tstring> sparse_keys_slice =
639 sparse_keys
640 ? gtl::ArraySlice<tstring>(sparse_keys->flat<tstring>().data(),
641 attrs_.num_feature_list_sparse)
642 : attrs_.feature_list_sparse_keys;
643 gtl::ArraySlice<tstring> ragged_keys_slice =
644 ragged_keys
645 ? gtl::ArraySlice<tstring>(ragged_keys->flat<tstring>().data(),
646 attrs_.num_feature_list_ragged)
647 : gtl::ArraySlice<tstring>(nullptr, 0);
648 // Use an empty slice to indicate that the map in attrs_ should be used
649 // instead.
650 gtl::ArraySlice<bool> feature_list_dense_missing_assumed_empty_slice =
651 feature_list_dense_missing_assumed_empty
652 ? gtl::ArraySlice<bool>(
653 feature_list_dense_missing_assumed_empty->flat<bool>().data(),
654 attrs_.num_feature_list_dense)
655 : gtl::ArraySlice<bool>(nullptr, 0);
656
657 example::FastParseExampleConfig config;
658 config.dense.reserve(attrs_.num_feature_list_dense);
659 for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
660 const tstring& key = dense_keys_slice[d];
661 bool missing_assumed_empty =
662 !feature_list_dense_missing_assumed_empty_slice.empty()
663 ? feature_list_dense_missing_assumed_empty_slice[d]
664 : attrs_.feature_list_dense_missing_assumed_empty.count(key) > 0;
665 DataType dtype = attrs_.feature_list_dense_types[d];
666 config.dense.emplace_back(
667 key, dtype, attrs_.feature_list_dense_shapes[d],
668 ConstructDefaultScalar(dtype), missing_assumed_empty,
669 0 /*attrs_.feature_list_elements_per_stride[d] */);
670 }
671 config.sparse.reserve(attrs_.num_feature_list_sparse);
672 for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
673 const tstring& key = sparse_keys_slice[d];
674 config.sparse.emplace_back(key, attrs_.feature_list_sparse_types[d]);
675 }
676 config.ragged.reserve(attrs_.num_feature_list_ragged);
677 for (int d = 0; d < attrs_.num_feature_list_ragged; ++d) {
678 config.ragged.emplace_back(ragged_keys_slice[d],
679 attrs_.feature_list_ragged_value_types[d],
680 attrs_.feature_list_ragged_split_types[d]);
681 }
682 return config;
683 }
684
685 Status WriteOutput(const example::Result& context_result,
686 const example::Result& feature_list_result,
687 const std::vector<Tensor>& dense_feature_lengths,
688 OpKernelContext* ctx) const {
689 OpOutputList context_sparse_indices;
690 OpOutputList context_sparse_values;
691 OpOutputList context_sparse_shapes;
692 OpOutputList context_dense_values;
693 OpOutputList feature_list_sparse_indices;
694 OpOutputList feature_list_sparse_values;
695 OpOutputList feature_list_sparse_shapes;
696 OpOutputList feature_list_dense_values;
697 OpOutputList feature_list_dense_lengths;
698
699 TF_RETURN_IF_ERROR(
700 ctx->output_list("context_sparse_indices", &context_sparse_indices));
701 TF_RETURN_IF_ERROR(
702 ctx->output_list("context_sparse_values", &context_sparse_values));
703 TF_RETURN_IF_ERROR(
704 ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
705 TF_RETURN_IF_ERROR(
706 ctx->output_list("context_dense_values", &context_dense_values));
707 TF_RETURN_IF_ERROR(
708 ctx->output_list("context_sparse_indices", &context_sparse_indices));
709 TF_RETURN_IF_ERROR(ctx->output_list("feature_list_sparse_indices",
710 &feature_list_sparse_indices));
711 TF_RETURN_IF_ERROR(ctx->output_list("feature_list_sparse_values",
712 &feature_list_sparse_values));
713 TF_RETURN_IF_ERROR(ctx->output_list("feature_list_sparse_shapes",
714 &feature_list_sparse_shapes));
715 TF_RETURN_IF_ERROR(ctx->output_list("feature_list_dense_values",
716 &feature_list_dense_values));
717 TF_RETURN_IF_ERROR(ctx->output_list("feature_list_dense_lengths",
718 &feature_list_dense_lengths));
719 for (int d = 0; d < attrs_.num_context_dense; ++d) {
720 context_dense_values.set(d, context_result.dense_values[d]);
721 }
722 for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
723 feature_list_dense_values.set(d, feature_list_result.dense_values[d]);
724 feature_list_dense_lengths.set(d, dense_feature_lengths[d]);
725 }
726 for (int d = 0; d < attrs_.num_context_sparse; ++d) {
727 context_sparse_indices.set(d, context_result.sparse_indices[d]);
728 context_sparse_values.set(d, context_result.sparse_values[d]);
729 context_sparse_shapes.set(d, context_result.sparse_shapes[d]);
730 }
731 for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
732 feature_list_sparse_indices.set(d, feature_list_result.sparse_indices[d]);
733 feature_list_sparse_values.set(d, feature_list_result.sparse_values[d]);
734 feature_list_sparse_shapes.set(d, feature_list_result.sparse_shapes[d]);
735 }
736 if (op_version_ == 2) {
737 OpOutputList context_ragged_values;
738 OpOutputList context_ragged_splits;
739 OpOutputList feature_list_ragged_values;
740 OpOutputList feature_list_ragged_inner_splits;
741 OpOutputList feature_list_ragged_outer_splits;
742 TF_RETURN_IF_ERROR(
743 ctx->output_list("context_ragged_values", &context_ragged_values));
744 TF_RETURN_IF_ERROR(ctx->output_list("context_ragged_row_splits",
745 &context_ragged_splits));
746 TF_RETURN_IF_ERROR(ctx->output_list("feature_list_ragged_values",
747 &feature_list_ragged_values));
748 TF_RETURN_IF_ERROR(ctx->output_list("feature_list_ragged_inner_splits",
749 &feature_list_ragged_inner_splits));
750 TF_RETURN_IF_ERROR(ctx->output_list("feature_list_ragged_outer_splits",
751 &feature_list_ragged_outer_splits));
752 for (int d = 0; d < attrs_.num_context_ragged; ++d) {
753 context_ragged_values.set(d, context_result.ragged_values[d]);
754 context_ragged_splits.set(d, context_result.ragged_splits[d]);
755 }
756 for (int d = 0; d < attrs_.num_feature_list_ragged; ++d) {
757 feature_list_ragged_values.set(d, feature_list_result.ragged_values[d]);
758 feature_list_ragged_outer_splits.set(
759 d, feature_list_result.ragged_outer_splits[d]);
760 feature_list_ragged_inner_splits.set(
761 d, feature_list_result.ragged_splits[d]);
762 }
763 }
764 return OkStatus();
765 }
766
767 ParseSequenceExampleAttrs attrs_;
768 int op_version_;
769 absl::once_flag flag_;
770};
771
772REGISTER_KERNEL_BUILDER(Name("ParseSequenceExample").Device(DEVICE_CPU),
773 ParseSequenceExampleOp);
774REGISTER_KERNEL_BUILDER(Name("ParseSequenceExampleV2").Device(DEVICE_CPU),
775 ParseSequenceExampleOp);
776
777class ParseSingleSequenceExampleOp : public OpKernel {
778 public:
779 explicit ParseSingleSequenceExampleOp(OpKernelConstruction* ctx)
780 : OpKernel(ctx) {
781 OP_REQUIRES_OK(ctx, attrs_.Init(ctx));
782 }
783
784 void Compute(OpKernelContext* ctx) override {
785 const Tensor* debug_name;
786 const Tensor* serialized;
787 OpInputList context_dense_keys;
788 OpInputList context_sparse_keys;
789 OpInputList context_dense_defaults;
790 OpInputList feature_list_dense_keys;
791 OpInputList feature_list_sparse_keys;
792 const Tensor* feature_list_dense_missing_assumed_empty;
793
794 OP_REQUIRES_OK(ctx, ctx->input("debug_name", &debug_name));
795 OP_REQUIRES_OK(ctx, ctx->input("serialized", &serialized));
796 OP_REQUIRES_OK(ctx, ctx->input("feature_list_dense_missing_assumed_empty",
797 &feature_list_dense_missing_assumed_empty));
798 OP_REQUIRES_OK(ctx,
799 ctx->input_list("context_dense_keys", &context_dense_keys));
800 OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_dense_keys",
801 &feature_list_dense_keys));
802 OP_REQUIRES_OK(
803 ctx, ctx->input_list("context_sparse_keys", &context_sparse_keys));
804 OP_REQUIRES_OK(ctx, ctx->input_list("feature_list_sparse_keys",
805 &feature_list_sparse_keys));
806 OP_REQUIRES_OK(ctx, ctx->input_list("context_dense_defaults",
807 &context_dense_defaults));
808
809 std::vector<string> context_dense_keys_t(attrs_.num_context_dense);
810 std::vector<string> context_sparse_keys_t(attrs_.num_context_sparse);
811 std::vector<string> feature_list_dense_keys_t(
812 attrs_.num_feature_list_dense);
813 std::vector<string> feature_list_sparse_keys_t(
814 attrs_.num_feature_list_sparse);
815 absl::call_once(
816 flag_, [&context_dense_keys_t, &context_sparse_keys_t,
817 &feature_list_dense_keys_t, &feature_list_sparse_keys_t]() {
818 metrics::RecordParseDenseFeature(context_dense_keys_t.size() +
819 feature_list_dense_keys_t.size());
820 metrics::RecordParseSparseFeature(context_sparse_keys_t.size() +
821 feature_list_sparse_keys_t.size());
822 });
823 std::unordered_set<string> feature_list_dense_missing_assumed_empty_set;
824 CHECK_EQ(context_dense_keys.size(), attrs_.num_context_dense);
825 CHECK_EQ(context_sparse_keys.size(), attrs_.num_context_sparse);
826 CHECK_EQ(feature_list_dense_keys.size(), attrs_.num_feature_list_dense);
827 CHECK_EQ(feature_list_sparse_keys.size(), attrs_.num_feature_list_sparse);
828 for (int di = 0; di < attrs_.num_context_dense; ++di) {
829 OP_REQUIRES(ctx,
830 TensorShapeUtils::IsScalar(context_dense_keys[di].shape()),
831 errors::InvalidArgument(
832 "Expected context_dense_keys[", di,
833 "] to be a scalar, got shape: ",
834 context_dense_keys[di].shape().DebugString()));
835 context_dense_keys_t[di] = context_dense_keys[di].scalar<tstring>()();
836 }
837 for (int di = 0; di < attrs_.num_context_sparse; ++di) {
838 OP_REQUIRES(ctx,
839 TensorShapeUtils::IsScalar(context_sparse_keys[di].shape()),
840 errors::InvalidArgument(
841 "Expected context_sparse_keys[", di,
842 "] to be a scalar, got shape: ",
843 context_sparse_keys[di].shape().DebugString()));
844 context_sparse_keys_t[di] = context_sparse_keys[di].scalar<tstring>()();
845 }
846 for (int di = 0; di < attrs_.num_feature_list_dense; ++di) {
847 OP_REQUIRES(
848 ctx, TensorShapeUtils::IsScalar(feature_list_dense_keys[di].shape()),
849 errors::InvalidArgument(
850 "Expected feature_list_dense_keys[", di,
851 "] to be a scalar, got shape: ",
852 feature_list_dense_keys[di].shape().DebugString()));
853 feature_list_dense_keys_t[di] =
854 feature_list_dense_keys[di].scalar<tstring>()();
855 }
856 for (int di = 0; di < attrs_.num_feature_list_sparse; ++di) {
857 OP_REQUIRES(
858 ctx, TensorShapeUtils::IsScalar(feature_list_sparse_keys[di].shape()),
859 errors::InvalidArgument(
860 "Expected feature_list_sparse_keys[", di,
861 "] to be a scalar, got shape: ",
862 feature_list_sparse_keys[di].shape().DebugString()));
863 feature_list_sparse_keys_t[di] =
864 feature_list_sparse_keys[di].scalar<tstring>()();
865 }
866 OP_REQUIRES(
867 ctx,
868 TensorShapeUtils::IsVector(
869 feature_list_dense_missing_assumed_empty->shape()),
870 errors::InvalidArgument(
871 "Expected feature_list_dense_missing_assumed_empty ",
872 "to be a vector, got shape: ",
873 feature_list_dense_missing_assumed_empty->shape().DebugString()));
874 auto feature_list_dense_missing_assumped_empty_t =
875 feature_list_dense_missing_assumed_empty->vec<tstring>();
876 for (int de = 0;
877 de < feature_list_dense_missing_assumed_empty->NumElements(); ++de) {
878 feature_list_dense_missing_assumed_empty_set.insert(
879 feature_list_dense_missing_assumped_empty_t(de));
880 }
881
882 bool has_debug_name = (debug_name->NumElements() > 0);
883 if (has_debug_name) {
884 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(debug_name->shape()),
885 errors::InvalidArgument(
886 "Expected debug_name to be a scalar, got shape: ",
887 debug_name->shape().DebugString()));
888 }
889 auto debug_name_t = debug_name->scalar<tstring>();
890
891 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(serialized->shape()),
892 errors::InvalidArgument(
893 "Expected serialized to be a scalar, got shape: ",
894 serialized->shape().DebugString()));
895
896 OP_REQUIRES(ctx, context_dense_defaults.size() == attrs_.num_context_dense,
897 errors::InvalidArgument("Expected len(context_dense_defaults) "
898 "== len(context_dense_keys) but got: ",
899 context_dense_defaults.size(), " vs. ",
900 attrs_.num_context_dense));
901
902 std::vector<bool> required(attrs_.num_context_dense);
903 for (int d = 0; d < attrs_.num_context_dense; ++d) {
904 const Tensor& def_value = context_dense_defaults[d];
905 required[d] = (def_value.NumElements() == 0); // No default provided.
906
907 if (def_value.NumElements() > 0) {
908 OP_REQUIRES(ctx, def_value.shape() == attrs_.context_dense_shapes[d],
909 errors::InvalidArgument(
910 "def_value[", d,
911 "].shape() == ", def_value.shape().DebugString(),
912 " != context_dense_shapes_[", d,
913 "] == ", attrs_.context_dense_shapes[d].DebugString()));
914 OP_REQUIRES(
915 ctx, def_value.dtype() == attrs_.context_dense_types[d],
916 errors::InvalidArgument(
917 "context_dense_defaults[", d, "].dtype() == ",
918 DataTypeString(def_value.dtype()), " != context_dense_types_[",
919 d, "] == ", DataTypeString(attrs_.context_dense_types[d])));
920 }
921 }
922
923 auto serialized_t = serialized->scalar<tstring>();
924
925 OpOutputList context_sparse_indices;
926 OpOutputList context_sparse_values;
927 OpOutputList context_sparse_shapes;
928 OpOutputList context_dense_values;
929 OpOutputList feature_list_sparse_indices;
930 OpOutputList feature_list_sparse_values;
931 OpOutputList feature_list_sparse_shapes;
932 OpOutputList feature_list_dense_values;
933
934 OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
935 &context_sparse_indices));
936 OP_REQUIRES_OK(
937 ctx, ctx->output_list("context_sparse_values", &context_sparse_values));
938 OP_REQUIRES_OK(
939 ctx, ctx->output_list("context_sparse_shapes", &context_sparse_shapes));
940 OP_REQUIRES_OK(
941 ctx, ctx->output_list("context_dense_values", &context_dense_values));
942 OP_REQUIRES_OK(ctx, ctx->output_list("context_sparse_indices",
943 &context_sparse_indices));
944 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_indices",
945 &feature_list_sparse_indices));
946 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_values",
947 &feature_list_sparse_values));
948 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_sparse_shapes",
949 &feature_list_sparse_shapes));
950 OP_REQUIRES_OK(ctx, ctx->output_list("feature_list_dense_values",
951 &feature_list_dense_values));
952
953 // Allocate the SequenceExample on an arena. Provides better memory locality
954 // and greatly speeds up destruction.
955 protobuf::ArenaOptions options;
956 // We have some hint of what the final proto size will be based on the size
957 // of the serialized bytes- use this to set a custom allocation strategy.
958 // Note that the default allocation strategy is quite conservative (min
959 // block size of 256 bytes, and a max of 8 kilobytes).
960 const size_t block_size = serialized_t().size() * 1.1;
961 options.start_block_size = std::max(options.start_block_size, block_size);
962 options.max_block_size = std::max(options.max_block_size, block_size);
963 protobuf::Arena arena(options);
964 auto& ex = *protobuf::Arena::CreateMessage<SequenceExample>(&arena);
965
966 OP_REQUIRES(
967 ctx, ParseProtoUnlimited(&ex, serialized_t()),
968 errors::InvalidArgument("Could not parse example input, value: '",
969 serialized_t(), "'"));
970
971 const tstring& name = (has_debug_name) ? debug_name_t() : "<unknown>";
972 const Features& context = ex.context();
973 const auto& context_dict = context.feature();
974
975 // Context Dense -----------------------------------------------------------
976
977 // Preallocate context_dense_values, since we know their sizes
978 for (int d = 0; d < attrs_.num_context_dense; ++d) {
979 TensorShape out_shape;
980 for (const int dim : attrs_.context_dense_shapes[d].dim_sizes())
981 out_shape.AddDim(dim);
982 Tensor* out = nullptr;
983 OP_REQUIRES_OK(ctx, context_dense_values.allocate(d, out_shape, &out));
984 }
985
986 for (int d = 0; d < attrs_.num_context_dense; ++d) {
987 const tstring& key = context_dense_keys_t[d];
988 const DataType& dtype = attrs_.context_dense_types[d];
989 const TensorShape& shape = attrs_.context_dense_shapes[d];
990
991 const auto& feature_found = context_dict.find(key);
992 OP_REQUIRES(
993 ctx, (feature_found != context_dict.end()) || !required[d],
994 errors::InvalidArgument("Name: ", name, ", Context feature '", key,
995 "' is required but could not be found."));
996 if (feature_found != context_dict.end()) {
997 const Feature& f = feature_found->second;
998 bool types_match;
999 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
1000 OP_REQUIRES(
1001 ctx, types_match,
1002 errors::InvalidArgument("Name: ", name, ", Context feature: ", key,
1003 ". Data types don't match. ",
1004 "Expected type: ", DataTypeString(dtype),
1005 " Feature is: ", f.DebugString()));
1006
1007 OP_REQUIRES_OK(ctx, FeatureDenseCopy(0, name, key, dtype, shape, f,
1008 context_dense_values[d]));
1009 } else {
1010 RowDenseCopy(0, dtype, context_dense_defaults[d],
1011 context_dense_values[d]);
1012 }
1013 }
1014
1015 // Context Sparse ----------------------------------------------------------
1016 for (int d = 0; d < attrs_.num_context_sparse; ++d) {
1017 const tstring& key = context_sparse_keys_t[d];
1018 const DataType& dtype = attrs_.context_sparse_types[d];
1019
1020 const auto& feature_found = context_dict.find(key);
1021 bool feature_has_data = // Found key & data type is set
1022 (feature_found != context_dict.end() &&
1023 (feature_found->second.kind_case() != Feature::KIND_NOT_SET));
1024
1025 if (feature_has_data) {
1026 const Feature& f = feature_found->second;
1027 bool types_match;
1028 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
1029 OP_REQUIRES(
1030 ctx, types_match,
1031 errors::InvalidArgument("Name: ", name, ", Context feature: ", key,
1032 ". Data types don't match. ",
1033 "Expected type: ", DataTypeString(dtype),
1034 " Feature is: ", f.DebugString()));
1035
1036 Tensor feature_values = FeatureSparseCopy(0, key, dtype, f);
1037 const int64_t num_elements = feature_values.NumElements();
1038 TensorShape indices_shape({num_elements, 1});
1039 Tensor* sp_indices_d = nullptr;
1040 Tensor* sp_shape_d = nullptr;
1041 OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape,
1042 &sp_indices_d));
1043 context_sparse_values.set(d, feature_values);
1044 OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}),
1045 &sp_shape_d));
1046 auto shape_t = sp_shape_d->vec<int64_t>();
1047 shape_t(0) = num_elements;
1048 auto indices_t = sp_indices_d->matrix<int64_t>();
1049 std::iota(indices_t.data(), indices_t.data() + num_elements, 0);
1050 } else {
1051 TensorShape indices_shape({0, 1});
1052 TensorShape values_shape({0});
1053 Tensor* sp_indices_d = nullptr;
1054 Tensor* sp_values_d = nullptr;
1055 Tensor* sp_shape_d = nullptr;
1056 OP_REQUIRES_OK(ctx, context_sparse_indices.allocate(d, indices_shape,
1057 &sp_indices_d));
1058 OP_REQUIRES_OK(
1059 ctx, context_sparse_values.allocate(d, values_shape, &sp_values_d));
1060 OP_REQUIRES_OK(ctx, context_sparse_shapes.allocate(d, TensorShape({1}),
1061 &sp_shape_d));
1062 auto shape_t = sp_shape_d->vec<int64_t>();
1063 shape_t(0) = 0;
1064 }
1065 }
1066
1067 // Feature List Dense ------------------------------------------------------
1068
1069 // Preallocate context_dense_values, since we can infer their
1070 // sizes
1071 const FeatureLists& feature_lists = ex.feature_lists();
1072 const auto& feature_list_dict = feature_lists.feature_list();
1073 FeatureList empty_feature_list; // Placeholder for missing FLs
1074
1075 for (int d = 0; d < attrs_.num_feature_list_dense; ++d) {
1076 const tstring& key = feature_list_dense_keys_t[d];
1077 const DataType& dtype = attrs_.feature_list_dense_types[d];
1078 const TensorShape& shape = attrs_.feature_list_dense_shapes[d];
1079
1080 const auto& feature_list_found = feature_list_dict.find(key);
1081 bool feature_list_missing =
1082 (feature_list_found == feature_list_dict.end());
1083 bool feature_list_allowed_missing =
1084 (feature_list_dense_missing_assumed_empty_set.count(key) > 0);
1085
1086 OP_REQUIRES(
1087 ctx, !feature_list_missing || feature_list_allowed_missing,
1088 errors::InvalidArgument("Name: ", name, ", Feature list '", key,
1089 "' is required but could not be found. "
1090 "Did you mean to include it in "
1091 "feature_list_dense_missing_assumed_empty or "
1092 "feature_list_dense_defaults?"));
1093
1094 TensorShape out_shape;
1095 const FeatureList& fl = (feature_list_missing)
1096 ? empty_feature_list
1097 : feature_list_found->second;
1098 out_shape.AddDim(fl.feature_size());
1099 for (const int dim : attrs_.feature_list_dense_shapes[d].dim_sizes()) {
1100 out_shape.AddDim(dim);
1101 }
1102 Tensor* out = nullptr;
1103 OP_REQUIRES_OK(ctx,
1104 feature_list_dense_values.allocate(d, out_shape, &out));
1105
1106 for (int64_t t = 0; t < fl.feature_size(); ++t) {
1107 const Feature& f = fl.feature(t);
1108 bool types_match;
1109 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
1110 OP_REQUIRES(ctx, types_match,
1111 errors::InvalidArgument(
1112 "Name: ", name, ", Feature list: ", key, ", Index: ", t,
1113 ". Data types don't match. ",
1114 "Expected type: ", DataTypeString(dtype),
1115 " Feature is: ", f.DebugString()));
1116 OP_REQUIRES_OK(ctx, FeatureDenseCopy(t, name, key, dtype, shape, f,
1117 feature_list_dense_values[d]));
1118 }
1119 }
1120
1121 // Feature List Sparse -----------------------------------------------------
1122 for (int d = 0; d < attrs_.num_feature_list_sparse; ++d) {
1123 const tstring& key = feature_list_sparse_keys_t[d];
1124 const DataType& dtype = attrs_.feature_list_sparse_types[d];
1125
1126 const auto& feature_list_found = feature_list_dict.find(key);
1127 bool feature_list_has_data = // Found key
1128 (feature_list_found != feature_list_dict.end());
1129
1130 std::vector<Tensor> sparse_values_tmp;
1131 int64_t feature_list_size = 0;
1132 if (feature_list_has_data) {
1133 const FeatureList& fl = feature_list_found->second;
1134 feature_list_size = fl.feature_size();
1135 for (int64_t t = 0; t < feature_list_size; ++t) {
1136 const Feature& f = fl.feature(t);
1137 bool types_match;
1138 OP_REQUIRES_OK(ctx, CheckTypesMatch(f, dtype, &types_match));
1139 OP_REQUIRES(
1140 ctx, f.kind_case() == Feature::KIND_NOT_SET || types_match,
1141 errors::InvalidArgument("Name: ", name, ", Feature List: ", key,
1142 ", Index: ", t,
1143 ". Data types don't match. ",
1144 "Expected type: ", DataTypeString(dtype),
1145 " Feature is: ", f.DebugString()));
1146 sparse_values_tmp.push_back(FeatureSparseCopy(t, key, dtype, f));
1147 }
1148 } else {
1149 sparse_values_tmp.push_back(Tensor(dtype, TensorShape({0})));
1150 }
1151
1152 int64_t total_num_features = 0;
1153 int64_t max_num_features = 0;
1154 for (int t = 0; t < feature_list_size; ++t) {
1155 const Tensor& v = sparse_values_tmp[t];
1156 const int64_t num_elements = v.shape().num_elements();
1157 total_num_features += num_elements;
1158 max_num_features = std::max(max_num_features, num_elements);
1159 }
1160
1161 TensorShape indices_shape({total_num_features, 2});
1162 TensorShape values_shape({total_num_features});
1163 Tensor* sp_indices_d = nullptr;
1164 Tensor* sp_values_d = nullptr;
1165 Tensor* sp_shape_d = nullptr;
1166 OP_REQUIRES_OK(ctx, feature_list_sparse_indices.allocate(d, indices_shape,
1167 &sp_indices_d));
1168 OP_REQUIRES_OK(ctx, feature_list_sparse_values.allocate(d, values_shape,
1169 &sp_values_d));
1170 OP_REQUIRES_OK(ctx, feature_list_sparse_shapes.allocate(
1171 d, TensorShape({2}), &sp_shape_d));
1172 auto shape_t = sp_shape_d->vec<int64_t>();
1173 shape_t(0) = feature_list_size;
1174 shape_t(1) = max_num_features;
1175
1176 int64_t offset = 0;
1177
1178 for (int t = 0; t < feature_list_size; ++t) {
1179 const int64_t num_elements = CopyIntoSparseTensor(
1180 sparse_values_tmp[t], t, offset, sp_indices_d, sp_values_d);
1181 offset += num_elements;
1182 }
1183 }
1184 }
1185
1186 protected:
1187 ParseSingleSequenceExampleAttrs attrs_;
1188 absl::once_flag flag_;
1189};
1190
// Registers ParseSingleSequenceExampleOp as the CPU kernel for the
// ParseSingleSequenceExample op.
REGISTER_KERNEL_BUILDER(Name("ParseSingleSequenceExample").Device(DEVICE_CPU),
                        ParseSingleSequenceExampleOp);
1193
1194#ifndef IS_MOBILE_PLATFORM
// When using lite protos on mobile, JSON decoding is not available.
1196
1197class DecodeJSONExampleOp : public OpKernel {
1198 public:
1199 explicit DecodeJSONExampleOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
1200 resolver_.reset(protobuf::util::NewTypeResolverForDescriptorPool(
1201 "type.googleapis.com", protobuf::DescriptorPool::generated_pool()));
1202 }
1203
1204 void Compute(OpKernelContext* ctx) override {
1205 const Tensor* json_examples;
1206 OP_REQUIRES_OK(ctx, ctx->input("json_examples", &json_examples));
1207 Tensor* binary_examples;
1208 OP_REQUIRES_OK(
1209 ctx, ctx->allocate_output("binary_examples", json_examples->shape(),
1210 &binary_examples));
1211
1212 for (int i = 0; i < json_examples->NumElements(); ++i) {
1213 const tstring& json_example = json_examples->flat<tstring>()(i);
1214 protobuf::io::ArrayInputStream in(json_example.data(),
1215 json_example.size());
1216 TStringOutputStream out(&binary_examples->flat<tstring>()(i));
1217 auto status = protobuf::util::JsonToBinaryStream(
1218 resolver_.get(), "type.googleapis.com/tensorflow.Example", &in, &out);
1219 OP_REQUIRES(ctx, status.ok(),
1220 errors::InvalidArgument("Error while parsing JSON: ",
1221 string(status.message())));
1222 }
1223 }
1224
1225 private:
1226 std::unique_ptr<protobuf::util::TypeResolver> resolver_;
1227};
1228
// Registers DecodeJSONExampleOp as the CPU kernel for DecodeJSONExample
// (only compiled when IS_MOBILE_PLATFORM is not defined).
REGISTER_KERNEL_BUILDER(Name("DecodeJSONExample").Device(DEVICE_CPU),
                        DecodeJSONExampleOp);
1231#endif
1232
1233} // namespace tensorflow
1234