1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#include "tensorflow/core/util/example_proto_helper.h"
16
17#include <vector>
18
19#include "tensorflow/core/example/example.pb.h"
20#include "tensorflow/core/example/feature.pb.h"
21#include "tensorflow/core/framework/numeric_op.h"
22#include "tensorflow/core/framework/register_types.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/platform/logging.h"
25#include "tensorflow/core/platform/protobuf.h"
26#include "tensorflow/core/util/sparse/sparse_tensor.h"
27
28namespace tensorflow {
29
30Status CheckValidType(const DataType& dtype) {
31 switch (dtype) {
32 case DT_INT64:
33 case DT_FLOAT:
34 case DT_STRING:
35 return OkStatus();
36 default:
37 return errors::InvalidArgument("Received input dtype: ",
38 DataTypeString(dtype));
39 }
40}
41
42Status CheckTypesMatch(const Feature& feature, const DataType& dtype,
43 bool* match) {
44 switch (dtype) {
45 case DT_INT64:
46 *match = (feature.kind_case() == Feature::kInt64List);
47 break;
48 case DT_FLOAT:
49 *match = (feature.kind_case() == Feature::kFloatList);
50 break;
51 case DT_STRING:
52 *match = (feature.kind_case() == Feature::kBytesList);
53 break;
54 default:
55 return errors::InvalidArgument("Invalid input dtype: ",
56 DataTypeString(dtype));
57 }
58 return OkStatus();
59}
60
61Status FeatureDenseCopy(const std::size_t out_index, const string& name,
62 const string& key, const DataType& dtype,
63 const TensorShape& shape, const Feature& feature,
64 Tensor* out) {
65 const std::size_t num_elements = shape.num_elements();
66 const std::size_t offset = out_index * num_elements;
67
68 switch (dtype) {
69 case DT_INT64: {
70 const Int64List& values = feature.int64_list();
71 if (static_cast<size_t>(values.value_size()) != num_elements) {
72 return errors::InvalidArgument(
73 "Name: ", name, ", Key: ", key, ", Index: ", out_index,
74 ". Number of int64 values != expected. "
75 "values size: ",
76 values.value_size(), " but output shape: ", shape.DebugString());
77 }
78 auto out_p = out->flat<int64_t>().data() + offset;
79 std::copy_n(values.value().data(), num_elements, out_p);
80 return OkStatus();
81 }
82 case DT_FLOAT: {
83 const FloatList& values = feature.float_list();
84 if (static_cast<size_t>(values.value_size()) != num_elements) {
85 return errors::InvalidArgument(
86 "Name: ", name, ", Key: ", key, ", Index: ", out_index,
87 ". Number of float values != expected. "
88 "values size: ",
89 values.value_size(), " but output shape: ", shape.DebugString());
90 }
91 auto out_p = out->flat<float>().data() + offset;
92 std::copy_n(values.value().data(), num_elements, out_p);
93 return OkStatus();
94 }
95 case DT_STRING: {
96 const BytesList& values = feature.bytes_list();
97 if (static_cast<size_t>(values.value_size()) != num_elements) {
98 return errors::InvalidArgument(
99 "Name: ", name, ", Key ", key, ", Index: ", out_index,
100 ". Number of bytes values != expected. "
101 "Values size: ",
102 values.value_size(), " but output shape: ", shape.DebugString());
103 }
104 auto out_p = out->flat<tstring>().data() + offset;
105 std::transform(values.value().data(),
106 values.value().data() + num_elements, out_p,
107 [](const string* s) { return *s; });
108 return OkStatus();
109 }
110 default:
111 return errors::InvalidArgument("Invalid input dtype: ",
112 DataTypeString(dtype));
113 }
114}
115
116Tensor FeatureSparseCopy(const std::size_t batch, const string& key,
117 const DataType& dtype, const Feature& feature) {
118 switch (dtype) {
119 case DT_INT64: {
120 const Int64List& values = feature.int64_list();
121 const int64_t num_elements = values.value_size();
122 Tensor out(dtype, TensorShape({num_elements}));
123 auto out_p = out.flat<int64_t>().data();
124 std::copy_n(values.value().data(), num_elements, out_p);
125 return out;
126 }
127 case DT_FLOAT: {
128 const FloatList& values = feature.float_list();
129 const int64_t num_elements = values.value_size();
130 Tensor out(dtype, TensorShape({num_elements}));
131 auto out_p = out.flat<float>().data();
132 std::copy_n(values.value().data(), num_elements, out_p);
133 return out;
134 }
135 case DT_STRING: {
136 const BytesList& values = feature.bytes_list();
137 const int64_t num_elements = values.value_size();
138 Tensor out(dtype, TensorShape({num_elements}));
139 auto out_p = out.flat<tstring>().data();
140 std::transform(values.value().data(),
141 values.value().data() + num_elements, out_p,
142 [](const string* s) { return *s; });
143 return out;
144 }
145 default:
146 LOG(FATAL) << "not supposed to be here. dtype requested: " << dtype;
147 }
148}
149
150int64_t CopyIntoSparseTensor(const Tensor& in, const int batch,
151 const int64_t offset, Tensor* indices,
152 Tensor* values) {
153 const int64_t num_elements = in.shape().num_elements();
154 const DataType& dtype = in.dtype();
155 CHECK_EQ(dtype, values->dtype());
156
157 // Update indices.
158 if (num_elements > 0) {
159 auto ix_t = indices->matrix<int64_t>();
160 int64_t* ix_p = &ix_t(offset, 0);
161 for (int64_t i = 0; i < num_elements; ++i, ix_p += 2) {
162 *ix_p = batch; // Column 0 stores the batch entry
163 *(ix_p + 1) = i; // Column 1 stores the index in the batch
164 }
165 }
166
167 // Copy values over.
168 switch (dtype) {
169 case DT_INT64: {
170 std::copy_n(in.flat<int64_t>().data(), num_elements,
171 values->flat<int64_t>().data() + offset);
172 break;
173 }
174 case DT_FLOAT: {
175 std::copy_n(in.flat<float>().data(), num_elements,
176 values->flat<float>().data() + offset);
177 break;
178 }
179 case DT_STRING: {
180 std::copy_n(in.flat<tstring>().data(), num_elements,
181 values->flat<tstring>().data() + offset);
182 break;
183 }
184 default:
185 LOG(FATAL) << "Not supposed to be here. Saw dtype: " << dtype;
186 }
187
188 return num_elements;
189}
190
191void RowDenseCopy(const std::size_t& out_index, const DataType& dtype,
192 const Tensor& in, Tensor* out) {
193 const std::size_t num_elements = in.shape().num_elements();
194 const std::size_t offset = out_index * num_elements;
195
196 switch (dtype) {
197 case DT_INT64: {
198 std::copy_n(in.flat<int64_t>().data(), num_elements,
199 out->flat<int64_t>().data() + offset);
200 break;
201 }
202 case DT_FLOAT: {
203 std::copy_n(in.flat<float>().data(), num_elements,
204 out->flat<float>().data() + offset);
205 break;
206 }
207 case DT_STRING: {
208 // TODO(dero): verify.
209 std::copy_n(in.flat<tstring>().data(), num_elements,
210 out->flat<tstring>().data() + offset);
211 break;
212 }
213 default:
214 LOG(FATAL) << "Not supposed to be here. Saw dtype: " << dtype;
215 }
216}
217
218Status SingleExampleProtoToTensors(
219 const Example& example, const string& example_name, const int batch_index,
220 const std::vector<FixedLenFeature>& fixed_len_features,
221 const std::vector<VarLenFeature>& var_len_features,
222 std::vector<Tensor*>* output_dense_values_tensor,
223 std::vector<std::vector<Tensor>>* output_sparse_values_tmp) {
224 const Features& features = example.features();
225 const auto& feature_dict = features.feature();
226
227 // Handle dense features.
228 for (size_t d = 0; d < fixed_len_features.size(); ++d) {
229 const FixedLenFeature& feature_config = fixed_len_features[d];
230 const string& key = feature_config.key;
231 const DataType& dtype = feature_config.dtype;
232 const TensorShape& shape = feature_config.shape;
233 const Tensor& default_value = feature_config.default_value;
234 bool required = (default_value.NumElements() == 0);
235 const auto& feature_found = feature_dict.find(key);
236 const bool feature_has_data = // Found key & data type is set
237 (feature_found != feature_dict.end() &&
238 (feature_found->second.kind_case() != Feature::KIND_NOT_SET));
239
240 const bool required_ok = feature_has_data || !required;
241 if (!required_ok) {
242 return errors::InvalidArgument("Name: ", example_name, ", Feature: ", key,
243 " is required but could not be found.");
244 }
245
246 // Perform the FeatureDenseCopy into the output dense_values tensor (if
247 // the value is present).
248 if (feature_has_data) {
249 const Feature& f = feature_found->second;
250 bool types_match;
251 TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match));
252 if (!types_match) {
253 return errors::InvalidArgument("Name: ", example_name,
254 ", Feature: ", key,
255 ". Data types don't match. ",
256 "Expected type: ", DataTypeString(dtype),
257 " Feature is: ", f.DebugString());
258 }
259 TF_RETURN_IF_ERROR(FeatureDenseCopy(batch_index, example_name, key, dtype,
260 shape, f,
261 (*output_dense_values_tensor)[d]));
262 } else {
263 // If the value is missing, RowDenseCopy the default value.
264 RowDenseCopy(batch_index, dtype, default_value,
265 (*output_dense_values_tensor)[d]);
266 }
267 }
268
269 // Handle sparse features.
270 for (size_t d = 0; d < var_len_features.size(); ++d) {
271 const VarLenFeature& feature_config = var_len_features[d];
272 const string& key = feature_config.key;
273 const DataType& dtype = feature_config.dtype;
274 const auto& feature_found = feature_dict.find(key);
275
276 const bool feature_has_data = // Found key & data type is set
277 (feature_found != feature_dict.end() &&
278 (feature_found->second.kind_case() != Feature::KIND_NOT_SET));
279
280 if (feature_has_data) {
281 const Feature& f = feature_found->second;
282 bool types_match;
283 TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match));
284 if (!types_match) {
285 return errors::InvalidArgument("Name: ", example_name,
286 ", Feature: ", key,
287 ". Data types don't match. ",
288 "Expected type: ", DataTypeString(dtype),
289 " Feature is: ", f.DebugString());
290 }
291 (*output_sparse_values_tmp)[d][batch_index] =
292 FeatureSparseCopy(batch_index, key, dtype, f);
293 } else {
294 (*output_sparse_values_tmp)[d][batch_index] =
295 Tensor(dtype, TensorShape({0}));
296 }
297 }
298 return OkStatus();
299}
300
301Status GetSparseTensorShapes(const VarLenFeature& var_len_feature,
302 const std::vector<Tensor>& sparse_values_tmp,
303 const int batch_size,
304 VarLenFeatureBatchShapes* output_shapes) {
305 int64_t total_num_features = 0;
306 int64_t max_num_features = 0;
307 for (int b = 0; b < batch_size; ++b) {
308 const Tensor& t = sparse_values_tmp[b];
309 const int64_t num_elements = t.shape().num_elements();
310 total_num_features += num_elements;
311 max_num_features = std::max(max_num_features, num_elements);
312 }
313 output_shapes->indices_shape.AddDim(total_num_features);
314 output_shapes->indices_shape.AddDim(2);
315 output_shapes->values_shape.AddDim(total_num_features);
316 output_shapes->max_num_features = max_num_features;
317 return OkStatus();
318}
319
320Status BatchExampleProtoToTensors(
321 const std::vector<const Example*>& examples,
322 const std::vector<string>& names,
323 const std::vector<FixedLenFeature>& fixed_len_features,
324 const std::vector<VarLenFeature>& var_len_features, Allocator* allocator,
325 std::vector<Tensor>* output_dense_values_tensor,
326 std::vector<Tensor>* output_sparse_indices_tensor,
327 std::vector<Tensor>* output_sparse_values_tensor,
328 std::vector<Tensor>* output_sparse_shapes_tensor) {
329 const int batch_size = examples.size();
330
331 const bool has_names = (!names.empty());
332 if (has_names) {
333 if (names.size() != examples.size()) {
334 return errors::InvalidArgument(
335 "Expected len(names) == len(examples), but got: ", names.size(),
336 " vs. ", examples.size());
337 }
338 }
339
340 // We also need a map of Tensor pointers for the SingleExampleProtoToTensors
341 // call. (Is there a better solution here?)
342 std::vector<Tensor*> output_dense_values_tensor_ptrs(
343 fixed_len_features.size());
344
345 // Preallocate dense_values, since we know their sizes.
346 for (size_t d = 0; d < fixed_len_features.size(); ++d) {
347 const FixedLenFeature& config = fixed_len_features[d];
348 TensorShape out_shape;
349 out_shape.AddDim(batch_size);
350 const TensorShape& shape = config.shape;
351 const DataType& dtype = config.dtype;
352 for (const int dim : shape.dim_sizes()) out_shape.AddDim(dim);
353 (*output_dense_values_tensor)[d] = Tensor(allocator, dtype, out_shape);
354 output_dense_values_tensor_ptrs[d] = &(*output_dense_values_tensor)[d];
355 }
356
357 // Temporary vector to hold sparse values.
358 std::vector<std::vector<Tensor>> sparse_values_tmp(var_len_features.size());
359
360 for (size_t d = 0; d < var_len_features.size(); ++d) {
361 sparse_values_tmp[d] = std::vector<Tensor>(batch_size);
362 }
363
364 for (size_t b = 0; b < examples.size(); ++b) {
365 const Example& ex = *(examples[b]);
366 const string& example_name = (has_names) ? names[b] : "<unknown>";
367 TF_RETURN_IF_ERROR(SingleExampleProtoToTensors(
368 ex, example_name, b, fixed_len_features, var_len_features,
369 &output_dense_values_tensor_ptrs, &sparse_values_tmp));
370 }
371
372 for (size_t d = 0; d < var_len_features.size(); ++d) {
373 const VarLenFeature& feature_config = var_len_features[d];
374 const DataType& dtype = feature_config.dtype;
375 const std::vector<Tensor>& sparse_values_tensor = sparse_values_tmp[d];
376
377 VarLenFeatureBatchShapes sparse_tensor_batch_shapes;
378 TF_RETURN_IF_ERROR(GetSparseTensorShapes(feature_config,
379 sparse_values_tensor, batch_size,
380 &sparse_tensor_batch_shapes));
381 const TensorShape& indices_shape = sparse_tensor_batch_shapes.indices_shape;
382 const TensorShape& values_shape = sparse_tensor_batch_shapes.values_shape;
383
384 // Allocate the sparse indices here.
385 (*output_sparse_indices_tensor)[d] =
386 Tensor(allocator, DT_INT64, indices_shape);
387 (*output_sparse_values_tensor)[d] = Tensor(allocator, dtype, values_shape);
388 (*output_sparse_shapes_tensor)[d] =
389 Tensor(allocator, DT_INT64, TensorShape({2}));
390
391 auto shape_t = (*output_sparse_shapes_tensor)[d].vec<int64_t>();
392 shape_t(0) = batch_size;
393 shape_t(1) = sparse_tensor_batch_shapes.max_num_features;
394
395 Tensor* sp_indices_d = &(*output_sparse_indices_tensor)[d];
396 Tensor* sp_values_d = &(*output_sparse_values_tensor)[d];
397
398 int64_t offset = 0;
399 for (int b = 0; b < batch_size; ++b) {
400 const int64_t num_elements = CopyIntoSparseTensor(
401 sparse_values_tensor[b], b, offset, sp_indices_d, sp_values_d);
402 offset += num_elements;
403 }
404 }
405 return OkStatus();
406}
407
408Status ParseExampleAttrs::FinishInit(int op_version) {
409 switch (op_version) {
410 case 1:
411 num_ragged = 0;
412 break;
413 case 2:
414 num_dense = dense_types.size();
415 num_ragged = ragged_value_types.size();
416 break;
417 default:
418 return errors::InvalidArgument("Unexpected op_version", op_version);
419 }
420 if (static_cast<size_t>(num_sparse) != sparse_types.size()) {
421 return errors::InvalidArgument("len(sparse_keys) != len(sparse_types)");
422 }
423 if (static_cast<size_t>(num_dense) != dense_types.size()) {
424 return errors::InvalidArgument("len(dense_keys) != len(dense_types)");
425 }
426 if (static_cast<size_t>(num_dense) != dense_shapes.size()) {
427 return errors::InvalidArgument("len(dense_keys) != len(dense_shapes)");
428 }
429 if (static_cast<size_t>(num_ragged) != ragged_value_types.size()) {
430 return errors::InvalidArgument(
431 "len(ragged_keys) != len(ragged_value_types)");
432 }
433 if (static_cast<size_t>(num_ragged) != ragged_split_types.size()) {
434 return errors::InvalidArgument(
435 "len(ragged_keys) != len(ragged_split_types)");
436 }
437 if (num_dense > std::numeric_limits<int32>::max()) {
438 return errors::InvalidArgument("num_dense_ too large");
439 }
440 for (const DataType& type : dense_types) {
441 TF_RETURN_IF_ERROR(CheckValidType(type));
442 }
443 for (const DataType& type : sparse_types) {
444 TF_RETURN_IF_ERROR(CheckValidType(type));
445 }
446 for (const DataType& type : ragged_value_types) {
447 TF_RETURN_IF_ERROR(CheckValidType(type));
448 }
449 for (const DataType& type : ragged_split_types) {
450 if (!(type == DT_INT64 || type == DT_INT32)) {
451 return errors::InvalidArgument("Invalid ragged_split_type: ",
452 DataTypeString(type));
453 }
454 }
455 return OkStatus();
456}
457
458Status ParseSingleExampleAttrs::FinishInit() {
459 if (sparse_keys.size() != sparse_types.size()) {
460 return errors::InvalidArgument("len(sparse_keys) != len(sparse_types)");
461 }
462 if (dense_keys.size() != dense_types.size()) {
463 return errors::InvalidArgument("len(dense_keys) != len(dense_types)");
464 }
465 if (dense_keys.size() != dense_shapes.size()) {
466 return errors::InvalidArgument("len(dense_keys) != len(dense_shapes)");
467 }
468 for (const DataType& type : dense_types) {
469 TF_RETURN_IF_ERROR(CheckValidType(type));
470 }
471 for (const DataType& type : sparse_types) {
472 TF_RETURN_IF_ERROR(CheckValidType(type));
473 }
474 return OkStatus();
475}
476
477Status ParseSequenceExampleAttrs::FinishInit(int op_version) {
478 switch (op_version) {
479 case 1:
480 num_context_ragged = 0;
481 num_feature_list_ragged = 0;
482 if (num_context_sparse != context_sparse_keys.size()) {
483 return errors::InvalidArgument(
484 "num_context_sparse (", num_context_sparse,
485 ") must match the size of context_sparse_keys (",
486 context_sparse_keys.size(), ")");
487 }
488 if (num_context_dense != context_dense_keys.size()) {
489 return errors::InvalidArgument(
490 "num_context_dense (", num_context_dense,
491 ") must match the size of context_dense_keys (",
492 context_dense_keys.size(), ")");
493 }
494 if (num_feature_list_sparse != feature_list_sparse_keys.size()) {
495 return errors::InvalidArgument(
496 "num_feature_list_sparse (", num_feature_list_sparse,
497 ") must match the size of feature_list_sparse_keys (",
498 feature_list_sparse_keys.size(), ")");
499 }
500 if (num_feature_list_dense != feature_list_dense_keys.size()) {
501 return errors::InvalidArgument(
502 "num_feature_list_dense (", num_feature_list_dense,
503 ") must match the size of feature_list_dense_keys (",
504 feature_list_dense_keys.size(), ")");
505 }
506 break;
507 case 2:
508 num_context_dense = context_dense_types.size();
509 num_context_ragged = context_ragged_value_types.size();
510 num_feature_list_ragged = feature_list_ragged_value_types.size();
511 break;
512 default:
513 return errors::InvalidArgument("Unexpected op_version", op_version);
514 }
515 if (num_context_sparse != context_sparse_types.size()) {
516 return errors::InvalidArgument(
517 "num_context_sparse (", num_context_sparse,
518 ") must match the size of context_sparse_types (",
519 context_sparse_types.size(), ")");
520 }
521 if (num_context_dense != context_dense_types.size() ||
522 num_context_dense != context_dense_shapes.size()) {
523 return errors::InvalidArgument(
524 "num_context_dense (", num_context_dense,
525 ") must match the size of context_dense_types (",
526 context_dense_types.size(), ") and context_dense_shapes (",
527 context_dense_shapes.size(), ")");
528 }
529 if ((num_context_ragged != context_ragged_value_types.size()) ||
530 (num_context_ragged != context_ragged_split_types.size())) {
531 return errors::InvalidArgument(
532 "num_context_ragged (", num_context_ragged,
533 ") must match the size of context_ragged_value_types (",
534 context_ragged_value_types.size(), ") and context_ragged_split_types (",
535 context_ragged_split_types.size(), ")");
536 }
537 if (num_feature_list_sparse != feature_list_sparse_types.size()) {
538 return errors::InvalidArgument(
539 "num_feature_list_sparse (", num_feature_list_sparse,
540 ") must match the size of feature_list_sparse_types (",
541 feature_list_sparse_types.size(), ")");
542 }
543 if (num_feature_list_dense != feature_list_dense_types.size() ||
544 num_feature_list_dense != feature_list_dense_shapes.size()) {
545 return errors::InvalidArgument(
546 "num_feature_list_dense (", num_feature_list_dense,
547 ") must match the size of feature_list_dense_types (",
548 feature_list_dense_types.size(), ") and feature_list_dense_shapes (",
549 feature_list_dense_shapes.size(), ")");
550 }
551 if ((num_feature_list_ragged != feature_list_ragged_value_types.size()) ||
552 (num_feature_list_ragged != feature_list_ragged_split_types.size())) {
553 return errors::InvalidArgument(
554 "num_feature_list_ragged (", num_feature_list_ragged,
555 ") must match the size of feature_list_ragged_value_types (",
556 feature_list_ragged_value_types.size(),
557 ") and feature_list_ragged_split_types (",
558 feature_list_ragged_split_types.size(), ")");
559 }
560 for (const DataType& type : context_dense_types) {
561 TF_RETURN_IF_ERROR(CheckValidType(type));
562 }
563 for (const DataType& type : context_sparse_types) {
564 TF_RETURN_IF_ERROR(CheckValidType(type));
565 }
566 for (const DataType& type : feature_list_dense_types) {
567 TF_RETURN_IF_ERROR(CheckValidType(type));
568 }
569 for (const DataType& type : feature_list_sparse_types) {
570 TF_RETURN_IF_ERROR(CheckValidType(type));
571 }
572 for (const DataType& type : context_ragged_value_types) {
573 TF_RETURN_IF_ERROR(CheckValidType(type));
574 }
575 for (const DataType& type : context_ragged_split_types) {
576 if (!(type == DT_INT64 || type == DT_INT32)) {
577 return errors::InvalidArgument("Invalid context_ragged_split_type: ",
578 DataTypeString(type));
579 }
580 }
581 for (const DataType& type : feature_list_ragged_value_types) {
582 TF_RETURN_IF_ERROR(CheckValidType(type));
583 }
584 for (const DataType& type : feature_list_ragged_split_types) {
585 if (!(type == DT_INT64 || type == DT_INT32)) {
586 return errors::InvalidArgument("Invalid feature_list_ragged_split_type: ",
587 DataTypeString(type));
588 }
589 }
590
591 return OkStatus();
592}
593
594Status ParseSingleSequenceExampleAttrs::FinishInit() {
595 if (static_cast<size_t>(num_context_sparse) != context_sparse_types.size()) {
596 return errors::InvalidArgument(
597 "len(context_sparse_keys) != len(context_sparse_types)");
598 }
599 if (static_cast<size_t>(num_context_dense) != context_dense_types.size()) {
600 return errors::InvalidArgument(
601 "len(context_dense_keys) != len(context_dense_types)");
602 }
603 if (static_cast<size_t>(num_context_dense) != context_dense_shapes.size()) {
604 return errors::InvalidArgument(
605 "len(context_dense_keys) != len(context_dense_shapes)");
606 }
607 if (static_cast<size_t>(num_feature_list_sparse) !=
608 feature_list_sparse_types.size()) {
609 return errors::InvalidArgument(
610 "len(feature_list_sparse_keys) != len(feature_list_sparse_types)");
611 }
612 if (static_cast<size_t>(num_feature_list_dense) !=
613 feature_list_dense_types.size()) {
614 return errors::InvalidArgument(
615 "len(feature_list_dense_keys) != "
616 "len(feature_list_dense_types)");
617 }
618 for (const DataType& type : context_dense_types) {
619 TF_RETURN_IF_ERROR(CheckValidType(type));
620 }
621 for (const DataType& type : context_sparse_types) {
622 TF_RETURN_IF_ERROR(CheckValidType(type));
623 }
624 for (const DataType& type : feature_list_dense_types) {
625 TF_RETURN_IF_ERROR(CheckValidType(type));
626 }
627 for (const DataType& type : feature_list_sparse_types) {
628 TF_RETURN_IF_ERROR(CheckValidType(type));
629 }
630 return OkStatus();
631}
632
633Status GetDenseShapes(const std::vector<PartialTensorShape>& dense_shapes,
634 std::vector<bool>* variable_length,
635 std::vector<std::size_t>* elements_per_stride) {
636 // Temporary check until we start allowing a variable length outer
637 // dimension.
638 for (int i = 0; i < dense_shapes.size(); ++i) {
639 bool shape_ok = true;
640 if (dense_shapes[i].dims() == -1) {
641 shape_ok = false;
642 } else {
643 for (int d = 1; d < dense_shapes[i].dims(); ++d) {
644 if (dense_shapes[i].dim_size(d) == -1) {
645 shape_ok = false;
646 }
647 }
648 }
649 if (!shape_ok) {
650 return errors::InvalidArgument(
651 "dense_shapes[", i,
652 "] has unknown rank or unknown inner dimensions: ",
653 dense_shapes[i].DebugString());
654 }
655 TensorShape dense_shape;
656 if (dense_shapes[i].dims() > 0 && dense_shapes[i].dim_size(0) == -1) {
657 variable_length->push_back(true);
658 for (int d = 1; d < dense_shapes[i].dims(); ++d) {
659 dense_shape.AddDim(dense_shapes[i].dim_size(d));
660 }
661 } else {
662 variable_length->push_back(false);
663 dense_shapes[i].AsTensorShape(&dense_shape);
664 }
665 elements_per_stride->push_back(dense_shape.num_elements());
666 }
667 return OkStatus();
668}
669
670} // namespace tensorflow
671