1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/util/example_proto_helper.h" |
16 | |
17 | #include <vector> |
18 | |
19 | #include "tensorflow/core/example/example.pb.h" |
20 | #include "tensorflow/core/example/feature.pb.h" |
21 | #include "tensorflow/core/framework/numeric_op.h" |
22 | #include "tensorflow/core/framework/register_types.h" |
23 | #include "tensorflow/core/lib/core/errors.h" |
24 | #include "tensorflow/core/platform/logging.h" |
25 | #include "tensorflow/core/platform/protobuf.h" |
26 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | Status CheckValidType(const DataType& dtype) { |
31 | switch (dtype) { |
32 | case DT_INT64: |
33 | case DT_FLOAT: |
34 | case DT_STRING: |
35 | return OkStatus(); |
36 | default: |
37 | return errors::InvalidArgument("Received input dtype: " , |
38 | DataTypeString(dtype)); |
39 | } |
40 | } |
41 | |
42 | Status CheckTypesMatch(const Feature& feature, const DataType& dtype, |
43 | bool* match) { |
44 | switch (dtype) { |
45 | case DT_INT64: |
46 | *match = (feature.kind_case() == Feature::kInt64List); |
47 | break; |
48 | case DT_FLOAT: |
49 | *match = (feature.kind_case() == Feature::kFloatList); |
50 | break; |
51 | case DT_STRING: |
52 | *match = (feature.kind_case() == Feature::kBytesList); |
53 | break; |
54 | default: |
55 | return errors::InvalidArgument("Invalid input dtype: " , |
56 | DataTypeString(dtype)); |
57 | } |
58 | return OkStatus(); |
59 | } |
60 | |
61 | Status FeatureDenseCopy(const std::size_t out_index, const string& name, |
62 | const string& key, const DataType& dtype, |
63 | const TensorShape& shape, const Feature& feature, |
64 | Tensor* out) { |
65 | const std::size_t num_elements = shape.num_elements(); |
66 | const std::size_t offset = out_index * num_elements; |
67 | |
68 | switch (dtype) { |
69 | case DT_INT64: { |
70 | const Int64List& values = feature.int64_list(); |
71 | if (static_cast<size_t>(values.value_size()) != num_elements) { |
72 | return errors::InvalidArgument( |
73 | "Name: " , name, ", Key: " , key, ", Index: " , out_index, |
74 | ". Number of int64 values != expected. " |
75 | "values size: " , |
76 | values.value_size(), " but output shape: " , shape.DebugString()); |
77 | } |
78 | auto out_p = out->flat<int64_t>().data() + offset; |
79 | std::copy_n(values.value().data(), num_elements, out_p); |
80 | return OkStatus(); |
81 | } |
82 | case DT_FLOAT: { |
83 | const FloatList& values = feature.float_list(); |
84 | if (static_cast<size_t>(values.value_size()) != num_elements) { |
85 | return errors::InvalidArgument( |
86 | "Name: " , name, ", Key: " , key, ", Index: " , out_index, |
87 | ". Number of float values != expected. " |
88 | "values size: " , |
89 | values.value_size(), " but output shape: " , shape.DebugString()); |
90 | } |
91 | auto out_p = out->flat<float>().data() + offset; |
92 | std::copy_n(values.value().data(), num_elements, out_p); |
93 | return OkStatus(); |
94 | } |
95 | case DT_STRING: { |
96 | const BytesList& values = feature.bytes_list(); |
97 | if (static_cast<size_t>(values.value_size()) != num_elements) { |
98 | return errors::InvalidArgument( |
99 | "Name: " , name, ", Key " , key, ", Index: " , out_index, |
100 | ". Number of bytes values != expected. " |
101 | "Values size: " , |
102 | values.value_size(), " but output shape: " , shape.DebugString()); |
103 | } |
104 | auto out_p = out->flat<tstring>().data() + offset; |
105 | std::transform(values.value().data(), |
106 | values.value().data() + num_elements, out_p, |
107 | [](const string* s) { return *s; }); |
108 | return OkStatus(); |
109 | } |
110 | default: |
111 | return errors::InvalidArgument("Invalid input dtype: " , |
112 | DataTypeString(dtype)); |
113 | } |
114 | } |
115 | |
116 | Tensor FeatureSparseCopy(const std::size_t batch, const string& key, |
117 | const DataType& dtype, const Feature& feature) { |
118 | switch (dtype) { |
119 | case DT_INT64: { |
120 | const Int64List& values = feature.int64_list(); |
121 | const int64_t num_elements = values.value_size(); |
122 | Tensor out(dtype, TensorShape({num_elements})); |
123 | auto out_p = out.flat<int64_t>().data(); |
124 | std::copy_n(values.value().data(), num_elements, out_p); |
125 | return out; |
126 | } |
127 | case DT_FLOAT: { |
128 | const FloatList& values = feature.float_list(); |
129 | const int64_t num_elements = values.value_size(); |
130 | Tensor out(dtype, TensorShape({num_elements})); |
131 | auto out_p = out.flat<float>().data(); |
132 | std::copy_n(values.value().data(), num_elements, out_p); |
133 | return out; |
134 | } |
135 | case DT_STRING: { |
136 | const BytesList& values = feature.bytes_list(); |
137 | const int64_t num_elements = values.value_size(); |
138 | Tensor out(dtype, TensorShape({num_elements})); |
139 | auto out_p = out.flat<tstring>().data(); |
140 | std::transform(values.value().data(), |
141 | values.value().data() + num_elements, out_p, |
142 | [](const string* s) { return *s; }); |
143 | return out; |
144 | } |
145 | default: |
146 | LOG(FATAL) << "not supposed to be here. dtype requested: " << dtype; |
147 | } |
148 | } |
149 | |
150 | int64_t CopyIntoSparseTensor(const Tensor& in, const int batch, |
151 | const int64_t offset, Tensor* indices, |
152 | Tensor* values) { |
153 | const int64_t num_elements = in.shape().num_elements(); |
154 | const DataType& dtype = in.dtype(); |
155 | CHECK_EQ(dtype, values->dtype()); |
156 | |
157 | // Update indices. |
158 | if (num_elements > 0) { |
159 | auto ix_t = indices->matrix<int64_t>(); |
160 | int64_t* ix_p = &ix_t(offset, 0); |
161 | for (int64_t i = 0; i < num_elements; ++i, ix_p += 2) { |
162 | *ix_p = batch; // Column 0 stores the batch entry |
163 | *(ix_p + 1) = i; // Column 1 stores the index in the batch |
164 | } |
165 | } |
166 | |
167 | // Copy values over. |
168 | switch (dtype) { |
169 | case DT_INT64: { |
170 | std::copy_n(in.flat<int64_t>().data(), num_elements, |
171 | values->flat<int64_t>().data() + offset); |
172 | break; |
173 | } |
174 | case DT_FLOAT: { |
175 | std::copy_n(in.flat<float>().data(), num_elements, |
176 | values->flat<float>().data() + offset); |
177 | break; |
178 | } |
179 | case DT_STRING: { |
180 | std::copy_n(in.flat<tstring>().data(), num_elements, |
181 | values->flat<tstring>().data() + offset); |
182 | break; |
183 | } |
184 | default: |
185 | LOG(FATAL) << "Not supposed to be here. Saw dtype: " << dtype; |
186 | } |
187 | |
188 | return num_elements; |
189 | } |
190 | |
191 | void RowDenseCopy(const std::size_t& out_index, const DataType& dtype, |
192 | const Tensor& in, Tensor* out) { |
193 | const std::size_t num_elements = in.shape().num_elements(); |
194 | const std::size_t offset = out_index * num_elements; |
195 | |
196 | switch (dtype) { |
197 | case DT_INT64: { |
198 | std::copy_n(in.flat<int64_t>().data(), num_elements, |
199 | out->flat<int64_t>().data() + offset); |
200 | break; |
201 | } |
202 | case DT_FLOAT: { |
203 | std::copy_n(in.flat<float>().data(), num_elements, |
204 | out->flat<float>().data() + offset); |
205 | break; |
206 | } |
207 | case DT_STRING: { |
208 | // TODO(dero): verify. |
209 | std::copy_n(in.flat<tstring>().data(), num_elements, |
210 | out->flat<tstring>().data() + offset); |
211 | break; |
212 | } |
213 | default: |
214 | LOG(FATAL) << "Not supposed to be here. Saw dtype: " << dtype; |
215 | } |
216 | } |
217 | |
// Converts a single Example proto into the caller's preallocated dense output
// tensors and per-example temporary sparse value tensors.
//
// Dense features: if the example provides the feature, its values are
// type-checked and copied into row `batch_index` of the matching output
// tensor; otherwise the config's default value is used. An empty default
// marks the feature as required, so its absence is an error.
//
// Sparse features: the feature's values (or an empty tensor when absent) are
// stored in (*output_sparse_values_tmp)[d][batch_index] for later
// consolidation (see BatchExampleProtoToTensors).
Status SingleExampleProtoToTensors(
    const Example& example, const string& example_name, const int batch_index,
    const std::vector<FixedLenFeature>& fixed_len_features,
    const std::vector<VarLenFeature>& var_len_features,
    std::vector<Tensor*>* output_dense_values_tensor,
    std::vector<std::vector<Tensor>>* output_sparse_values_tmp) {
  const Features& features = example.features();
  const auto& feature_dict = features.feature();

  // Handle dense features.
  for (size_t d = 0; d < fixed_len_features.size(); ++d) {
    const FixedLenFeature& feature_config = fixed_len_features[d];
    const string& key = feature_config.key;
    const DataType& dtype = feature_config.dtype;
    const TensorShape& shape = feature_config.shape;
    const Tensor& default_value = feature_config.default_value;
    // An empty default value means the feature is required.
    bool required = (default_value.NumElements() == 0);
    const auto& feature_found = feature_dict.find(key);
    const bool feature_has_data =  // Found key & data type is set
        (feature_found != feature_dict.end() &&
         (feature_found->second.kind_case() != Feature::KIND_NOT_SET));

    const bool required_ok = feature_has_data || !required;
    if (!required_ok) {
      return errors::InvalidArgument("Name: ", example_name, ", Feature: ", key,
                                     " is required but could not be found.");
    }

    // Perform the FeatureDenseCopy into the output dense_values tensor (if
    // the value is present).
    if (feature_has_data) {
      const Feature& f = feature_found->second;
      bool types_match;
      TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match));
      if (!types_match) {
        return errors::InvalidArgument("Name: ", example_name,
                                       ", Feature: ", key,
                                       ". Data types don't match. ",
                                       "Expected type: ", DataTypeString(dtype),
                                       " Feature is: ", f.DebugString());
      }
      TF_RETURN_IF_ERROR(FeatureDenseCopy(batch_index, example_name, key, dtype,
                                          shape, f,
                                          (*output_dense_values_tensor)[d]));
    } else {
      // If the value is missing, RowDenseCopy the default value.
      RowDenseCopy(batch_index, dtype, default_value,
                   (*output_dense_values_tensor)[d]);
    }
  }

  // Handle sparse features.
  for (size_t d = 0; d < var_len_features.size(); ++d) {
    const VarLenFeature& feature_config = var_len_features[d];
    const string& key = feature_config.key;
    const DataType& dtype = feature_config.dtype;
    const auto& feature_found = feature_dict.find(key);

    const bool feature_has_data =  // Found key & data type is set
        (feature_found != feature_dict.end() &&
         (feature_found->second.kind_case() != Feature::KIND_NOT_SET));

    if (feature_has_data) {
      const Feature& f = feature_found->second;
      bool types_match;
      TF_RETURN_IF_ERROR(CheckTypesMatch(f, dtype, &types_match));
      if (!types_match) {
        return errors::InvalidArgument("Name: ", example_name,
                                       ", Feature: ", key,
                                       ". Data types don't match. ",
                                       "Expected type: ", DataTypeString(dtype),
                                       " Feature is: ", f.DebugString());
      }
      // Stash the per-example values; they are merged into the final sparse
      // tensors once all examples have been processed.
      (*output_sparse_values_tmp)[d][batch_index] =
          FeatureSparseCopy(batch_index, key, dtype, f);
    } else {
      // Absent sparse features contribute zero values.
      (*output_sparse_values_tmp)[d][batch_index] =
          Tensor(dtype, TensorShape({0}));
    }
  }
  return OkStatus();
}
300 | |
301 | Status GetSparseTensorShapes(const VarLenFeature& var_len_feature, |
302 | const std::vector<Tensor>& sparse_values_tmp, |
303 | const int batch_size, |
304 | VarLenFeatureBatchShapes* output_shapes) { |
305 | int64_t total_num_features = 0; |
306 | int64_t max_num_features = 0; |
307 | for (int b = 0; b < batch_size; ++b) { |
308 | const Tensor& t = sparse_values_tmp[b]; |
309 | const int64_t num_elements = t.shape().num_elements(); |
310 | total_num_features += num_elements; |
311 | max_num_features = std::max(max_num_features, num_elements); |
312 | } |
313 | output_shapes->indices_shape.AddDim(total_num_features); |
314 | output_shapes->indices_shape.AddDim(2); |
315 | output_shapes->values_shape.AddDim(total_num_features); |
316 | output_shapes->max_num_features = max_num_features; |
317 | return OkStatus(); |
318 | } |
319 | |
// Converts a batch of Example protos into dense and sparse output tensors.
//
// Dense outputs are preallocated as [batch_size, config.shape...] and filled
// per example. Sparse outputs are first collected per example into temporary
// tensors (since their sizes are unknown up front), then consolidated into
// the standard (indices, values, dense_shape) triple per feature.
//
// `names` is optional; when non-empty it must parallel `examples` and is used
// only to improve error messages.
Status BatchExampleProtoToTensors(
    const std::vector<const Example*>& examples,
    const std::vector<string>& names,
    const std::vector<FixedLenFeature>& fixed_len_features,
    const std::vector<VarLenFeature>& var_len_features, Allocator* allocator,
    std::vector<Tensor>* output_dense_values_tensor,
    std::vector<Tensor>* output_sparse_indices_tensor,
    std::vector<Tensor>* output_sparse_values_tensor,
    std::vector<Tensor>* output_sparse_shapes_tensor) {
  const int batch_size = examples.size();

  const bool has_names = (!names.empty());
  if (has_names) {
    if (names.size() != examples.size()) {
      return errors::InvalidArgument(
          "Expected len(names) == len(examples), but got: ", names.size(),
          " vs. ", examples.size());
    }
  }

  // We also need a map of Tensor pointers for the SingleExampleProtoToTensors
  // call. (Is there a better solution here?)
  std::vector<Tensor*> output_dense_values_tensor_ptrs(
      fixed_len_features.size());

  // Preallocate dense_values, since we know their sizes.
  for (size_t d = 0; d < fixed_len_features.size(); ++d) {
    const FixedLenFeature& config = fixed_len_features[d];
    TensorShape out_shape;
    // Output shape is [batch_size] + per-example shape.
    out_shape.AddDim(batch_size);
    const TensorShape& shape = config.shape;
    const DataType& dtype = config.dtype;
    for (const int dim : shape.dim_sizes()) out_shape.AddDim(dim);
    (*output_dense_values_tensor)[d] = Tensor(allocator, dtype, out_shape);
    output_dense_values_tensor_ptrs[d] = &(*output_dense_values_tensor)[d];
  }

  // Temporary vector to hold sparse values.
  std::vector<std::vector<Tensor>> sparse_values_tmp(var_len_features.size());

  for (size_t d = 0; d < var_len_features.size(); ++d) {
    sparse_values_tmp[d] = std::vector<Tensor>(batch_size);
  }

  // Pass 1: parse each example, filling dense rows directly and stashing
  // sparse values per example.
  for (size_t b = 0; b < examples.size(); ++b) {
    const Example& ex = *(examples[b]);
    const string& example_name = (has_names) ? names[b] : "<unknown>";
    TF_RETURN_IF_ERROR(SingleExampleProtoToTensors(
        ex, example_name, b, fixed_len_features, var_len_features,
        &output_dense_values_tensor_ptrs, &sparse_values_tmp));
  }

  // Pass 2: consolidate each sparse feature's per-example tensors into its
  // final (indices, values, shape) outputs.
  for (size_t d = 0; d < var_len_features.size(); ++d) {
    const VarLenFeature& feature_config = var_len_features[d];
    const DataType& dtype = feature_config.dtype;
    const std::vector<Tensor>& sparse_values_tensor = sparse_values_tmp[d];

    VarLenFeatureBatchShapes sparse_tensor_batch_shapes;
    TF_RETURN_IF_ERROR(GetSparseTensorShapes(feature_config,
                                             sparse_values_tensor, batch_size,
                                             &sparse_tensor_batch_shapes));
    const TensorShape& indices_shape = sparse_tensor_batch_shapes.indices_shape;
    const TensorShape& values_shape = sparse_tensor_batch_shapes.values_shape;

    // Allocate the sparse indices here.
    (*output_sparse_indices_tensor)[d] =
        Tensor(allocator, DT_INT64, indices_shape);
    (*output_sparse_values_tensor)[d] = Tensor(allocator, dtype, values_shape);
    (*output_sparse_shapes_tensor)[d] =
        Tensor(allocator, DT_INT64, TensorShape({2}));

    // Dense shape of the sparse tensor: [batch_size, max values per example].
    auto shape_t = (*output_sparse_shapes_tensor)[d].vec<int64_t>();
    shape_t(0) = batch_size;
    shape_t(1) = sparse_tensor_batch_shapes.max_num_features;

    Tensor* sp_indices_d = &(*output_sparse_indices_tensor)[d];
    Tensor* sp_values_d = &(*output_sparse_values_tensor)[d];

    // Concatenate per-example values, tracking the running write offset.
    int64_t offset = 0;
    for (int b = 0; b < batch_size; ++b) {
      const int64_t num_elements = CopyIntoSparseTensor(
          sparse_values_tensor[b], b, offset, sp_indices_d, sp_values_d);
      offset += num_elements;
    }
  }
  return OkStatus();
}
407 | |
// Finalizes and validates the attrs of the ParseExample op after they have
// been read from the op definition. `op_version` 1 is the original op (no
// ragged outputs); version 2 derives num_dense/num_ragged from the type
// lists. All count/list-size pairs must agree and every dtype must be a
// supported parser type; ragged split types must be int32 or int64.
Status ParseExampleAttrs::FinishInit(int op_version) {
  switch (op_version) {
    case 1:
      // V1 has no ragged feature support.
      num_ragged = 0;
      break;
    case 2:
      // V2 infers the counts from the declared type lists.
      num_dense = dense_types.size();
      num_ragged = ragged_value_types.size();
      break;
    default:
      return errors::InvalidArgument("Unexpected op_version", op_version);
  }
  if (static_cast<size_t>(num_sparse) != sparse_types.size()) {
    return errors::InvalidArgument("len(sparse_keys) != len(sparse_types)");
  }
  if (static_cast<size_t>(num_dense) != dense_types.size()) {
    return errors::InvalidArgument("len(dense_keys) != len(dense_types)");
  }
  if (static_cast<size_t>(num_dense) != dense_shapes.size()) {
    return errors::InvalidArgument("len(dense_keys) != len(dense_shapes)");
  }
  if (static_cast<size_t>(num_ragged) != ragged_value_types.size()) {
    return errors::InvalidArgument(
        "len(ragged_keys) != len(ragged_value_types)");
  }
  if (static_cast<size_t>(num_ragged) != ragged_split_types.size()) {
    return errors::InvalidArgument(
        "len(ragged_keys) != len(ragged_split_types)");
  }
  // num_dense is used as an int32 elsewhere; reject anything wider.
  if (num_dense > std::numeric_limits<int32>::max()) {
    return errors::InvalidArgument("num_dense_ too large");
  }
  for (const DataType& type : dense_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : sparse_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : ragged_value_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  // Ragged row splits are index tensors, so only integer types make sense.
  for (const DataType& type : ragged_split_types) {
    if (!(type == DT_INT64 || type == DT_INT32)) {
      return errors::InvalidArgument("Invalid ragged_split_type: ",
                                     DataTypeString(type));
    }
  }
  return OkStatus();
}
457 | |
458 | Status ParseSingleExampleAttrs::FinishInit() { |
459 | if (sparse_keys.size() != sparse_types.size()) { |
460 | return errors::InvalidArgument("len(sparse_keys) != len(sparse_types)" ); |
461 | } |
462 | if (dense_keys.size() != dense_types.size()) { |
463 | return errors::InvalidArgument("len(dense_keys) != len(dense_types)" ); |
464 | } |
465 | if (dense_keys.size() != dense_shapes.size()) { |
466 | return errors::InvalidArgument("len(dense_keys) != len(dense_shapes)" ); |
467 | } |
468 | for (const DataType& type : dense_types) { |
469 | TF_RETURN_IF_ERROR(CheckValidType(type)); |
470 | } |
471 | for (const DataType& type : sparse_types) { |
472 | TF_RETURN_IF_ERROR(CheckValidType(type)); |
473 | } |
474 | return OkStatus(); |
475 | } |
476 | |
// Finalizes and validates the attrs of the ParseSequenceExample op.
// `op_version` 1 is the original op (no ragged outputs; counts must match
// the key lists); version 2 derives the dense/ragged counts from the
// declared type lists. After the version-specific step, every count must
// match its corresponding type/shape list and all dtypes are validated.
Status ParseSequenceExampleAttrs::FinishInit(int op_version) {
  switch (op_version) {
    case 1:
      // V1 has no ragged feature support, and the counts come in as attrs
      // that must be cross-checked against the key lists.
      num_context_ragged = 0;
      num_feature_list_ragged = 0;
      if (num_context_sparse != context_sparse_keys.size()) {
        return errors::InvalidArgument(
            "num_context_sparse (", num_context_sparse,
            ") must match the size of context_sparse_keys (",
            context_sparse_keys.size(), ")");
      }
      if (num_context_dense != context_dense_keys.size()) {
        return errors::InvalidArgument(
            "num_context_dense (", num_context_dense,
            ") must match the size of context_dense_keys (",
            context_dense_keys.size(), ")");
      }
      if (num_feature_list_sparse != feature_list_sparse_keys.size()) {
        return errors::InvalidArgument(
            "num_feature_list_sparse (", num_feature_list_sparse,
            ") must match the size of feature_list_sparse_keys (",
            feature_list_sparse_keys.size(), ")");
      }
      if (num_feature_list_dense != feature_list_dense_keys.size()) {
        return errors::InvalidArgument(
            "num_feature_list_dense (", num_feature_list_dense,
            ") must match the size of feature_list_dense_keys (",
            feature_list_dense_keys.size(), ")");
      }
      break;
    case 2:
      // V2 infers the counts from the declared type lists.
      num_context_dense = context_dense_types.size();
      num_context_ragged = context_ragged_value_types.size();
      num_feature_list_ragged = feature_list_ragged_value_types.size();
      break;
    default:
      return errors::InvalidArgument("Unexpected op_version", op_version);
  }
  // Cross-check counts against the type/shape lists (both versions).
  if (num_context_sparse != context_sparse_types.size()) {
    return errors::InvalidArgument(
        "num_context_sparse (", num_context_sparse,
        ") must match the size of context_sparse_types (",
        context_sparse_types.size(), ")");
  }
  if (num_context_dense != context_dense_types.size() ||
      num_context_dense != context_dense_shapes.size()) {
    return errors::InvalidArgument(
        "num_context_dense (", num_context_dense,
        ") must match the size of context_dense_types (",
        context_dense_types.size(), ") and context_dense_shapes (",
        context_dense_shapes.size(), ")");
  }
  if ((num_context_ragged != context_ragged_value_types.size()) ||
      (num_context_ragged != context_ragged_split_types.size())) {
    return errors::InvalidArgument(
        "num_context_ragged (", num_context_ragged,
        ") must match the size of context_ragged_value_types (",
        context_ragged_value_types.size(), ") and context_ragged_split_types (",
        context_ragged_split_types.size(), ")");
  }
  if (num_feature_list_sparse != feature_list_sparse_types.size()) {
    return errors::InvalidArgument(
        "num_feature_list_sparse (", num_feature_list_sparse,
        ") must match the size of feature_list_sparse_types (",
        feature_list_sparse_types.size(), ")");
  }
  if (num_feature_list_dense != feature_list_dense_types.size() ||
      num_feature_list_dense != feature_list_dense_shapes.size()) {
    return errors::InvalidArgument(
        "num_feature_list_dense (", num_feature_list_dense,
        ") must match the size of feature_list_dense_types (",
        feature_list_dense_types.size(), ") and feature_list_dense_shapes (",
        feature_list_dense_shapes.size(), ")");
  }
  if ((num_feature_list_ragged != feature_list_ragged_value_types.size()) ||
      (num_feature_list_ragged != feature_list_ragged_split_types.size())) {
    return errors::InvalidArgument(
        "num_feature_list_ragged (", num_feature_list_ragged,
        ") must match the size of feature_list_ragged_value_types (",
        feature_list_ragged_value_types.size(),
        ") and feature_list_ragged_split_types (",
        feature_list_ragged_split_types.size(), ")");
  }
  // Validate every declared dtype.
  for (const DataType& type : context_dense_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : context_sparse_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : feature_list_dense_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : feature_list_sparse_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : context_ragged_value_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  // Ragged split tensors are row indices, so they must be integer typed.
  for (const DataType& type : context_ragged_split_types) {
    if (!(type == DT_INT64 || type == DT_INT32)) {
      return errors::InvalidArgument("Invalid context_ragged_split_type: ",
                                     DataTypeString(type));
    }
  }
  for (const DataType& type : feature_list_ragged_value_types) {
    TF_RETURN_IF_ERROR(CheckValidType(type));
  }
  for (const DataType& type : feature_list_ragged_split_types) {
    if (!(type == DT_INT64 || type == DT_INT32)) {
      return errors::InvalidArgument("Invalid feature_list_ragged_split_type: ",
                                     DataTypeString(type));
    }
  }

  return OkStatus();
}
593 | |
594 | Status ParseSingleSequenceExampleAttrs::FinishInit() { |
595 | if (static_cast<size_t>(num_context_sparse) != context_sparse_types.size()) { |
596 | return errors::InvalidArgument( |
597 | "len(context_sparse_keys) != len(context_sparse_types)" ); |
598 | } |
599 | if (static_cast<size_t>(num_context_dense) != context_dense_types.size()) { |
600 | return errors::InvalidArgument( |
601 | "len(context_dense_keys) != len(context_dense_types)" ); |
602 | } |
603 | if (static_cast<size_t>(num_context_dense) != context_dense_shapes.size()) { |
604 | return errors::InvalidArgument( |
605 | "len(context_dense_keys) != len(context_dense_shapes)" ); |
606 | } |
607 | if (static_cast<size_t>(num_feature_list_sparse) != |
608 | feature_list_sparse_types.size()) { |
609 | return errors::InvalidArgument( |
610 | "len(feature_list_sparse_keys) != len(feature_list_sparse_types)" ); |
611 | } |
612 | if (static_cast<size_t>(num_feature_list_dense) != |
613 | feature_list_dense_types.size()) { |
614 | return errors::InvalidArgument( |
615 | "len(feature_list_dense_keys) != " |
616 | "len(feature_list_dense_types)" ); |
617 | } |
618 | for (const DataType& type : context_dense_types) { |
619 | TF_RETURN_IF_ERROR(CheckValidType(type)); |
620 | } |
621 | for (const DataType& type : context_sparse_types) { |
622 | TF_RETURN_IF_ERROR(CheckValidType(type)); |
623 | } |
624 | for (const DataType& type : feature_list_dense_types) { |
625 | TF_RETURN_IF_ERROR(CheckValidType(type)); |
626 | } |
627 | for (const DataType& type : feature_list_sparse_types) { |
628 | TF_RETURN_IF_ERROR(CheckValidType(type)); |
629 | } |
630 | return OkStatus(); |
631 | } |
632 | |
633 | Status GetDenseShapes(const std::vector<PartialTensorShape>& dense_shapes, |
634 | std::vector<bool>* variable_length, |
635 | std::vector<std::size_t>* elements_per_stride) { |
636 | // Temporary check until we start allowing a variable length outer |
637 | // dimension. |
638 | for (int i = 0; i < dense_shapes.size(); ++i) { |
639 | bool shape_ok = true; |
640 | if (dense_shapes[i].dims() == -1) { |
641 | shape_ok = false; |
642 | } else { |
643 | for (int d = 1; d < dense_shapes[i].dims(); ++d) { |
644 | if (dense_shapes[i].dim_size(d) == -1) { |
645 | shape_ok = false; |
646 | } |
647 | } |
648 | } |
649 | if (!shape_ok) { |
650 | return errors::InvalidArgument( |
651 | "dense_shapes[" , i, |
652 | "] has unknown rank or unknown inner dimensions: " , |
653 | dense_shapes[i].DebugString()); |
654 | } |
655 | TensorShape dense_shape; |
656 | if (dense_shapes[i].dims() > 0 && dense_shapes[i].dim_size(0) == -1) { |
657 | variable_length->push_back(true); |
658 | for (int d = 1; d < dense_shapes[i].dims(); ++d) { |
659 | dense_shape.AddDim(dense_shapes[i].dim_size(d)); |
660 | } |
661 | } else { |
662 | variable_length->push_back(false); |
663 | dense_shapes[i].AsTensorShape(&dense_shape); |
664 | } |
665 | elements_per_stride->push_back(dense_shape.num_elements()); |
666 | } |
667 | return OkStatus(); |
668 | } |
669 | |
670 | } // namespace tensorflow |
671 | |