1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #include "tensorflow/core/util/example_proto_fast_parsing.h" |
16 | |
17 | #include <vector> |
18 | |
19 | #include "absl/base/casts.h" |
20 | #include "absl/container/flat_hash_map.h" |
21 | #include "tensorflow/core/example/example.pb.h" |
22 | #include "tensorflow/core/example/feature.pb.h" |
23 | #include "tensorflow/core/framework/allocator.h" |
24 | #include "tensorflow/core/framework/numeric_op.h" |
25 | #include "tensorflow/core/framework/op_kernel.h" |
26 | #include "tensorflow/core/framework/register_types.h" |
27 | #include "tensorflow/core/framework/types.pb.h" |
28 | #include "tensorflow/core/lib/core/errors.h" |
29 | #include "tensorflow/core/lib/core/threadpool.h" |
30 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
31 | #include "tensorflow/core/lib/monitoring/counter.h" |
32 | #include "tensorflow/core/platform/blocking_counter.h" |
33 | #include "tensorflow/core/platform/byte_order.h" |
34 | #include "tensorflow/core/platform/logging.h" |
35 | #include "tensorflow/core/platform/protobuf.h" |
36 | #include "tensorflow/core/util/presized_cuckoo_map.h" |
37 | #include "tensorflow/core/util/sparse/sparse_tensor.h" |
38 | |
39 | namespace tensorflow { |
40 | namespace example { |
41 | |
42 | namespace { |
43 | |
// Small inline vector used for per-feature value lists: stores up to 4
// elements without a heap allocation.
template <typename T>
using SmallVector = gtl::InlinedVector<T, 4>;
46 | |
// A fixed-capacity, vector-like facade over a caller-owned buffer of T.
// Unlike std::vector it never allocates: once the buffer is full, push_back
// drops the value but keeps advancing the logical length, so overflow is
// detectable afterwards as a negative EndDistance().
template <typename T>
class LimitedArraySlice {
 public:
  using value_type = T;

  // Wraps `num_elements` elements starting at `begin`; does not take
  // ownership of the buffer.
  LimitedArraySlice(T* begin, size_t num_elements)
      : current_(begin), begin_(begin), end_(begin + num_elements) {}

  // May return negative if there were push_back calls after slice was filled.
  int64_t EndDistance() const { return end_ - current_; }

  // Attempts to push value to the back of this. If the slice has
  // already been filled, this method has no effect on the underlying data, but
  // it changes the number returned by EndDistance into negative values.
  void push_back(T&& value) {
    if (EndDistance() > 0) *current_ = std::move(value);
    ++current_;
  }

  // "Constructs" an element at the back of this by resizing the slice, and
  // returns a mutable reference to the new last element.
  // REQUIRES: EndDistance() > 0.
  T& construct_at_end() {
    DCHECK_GT(EndDistance(), 0);
    return *(current_++);
  }

  // Returns a mutable reference to the last element in the slice.
  // REQUIRES: size() > 0.
  T& back() { return *(current_ - 1); }

  // Returns the number of elements in the slice, clamped to capacity (since
  // `current_` may point past `end_` after an overflowing push_back/resize).
  size_t size() const { return std::min(current_ - begin_, end_ - begin_); }

  // Attempts to resize the vector to the given size. It does so by advancing
  // the pointer to the current element, possibly beyond the end of the slice.
  // As a consequence, calling `size()` after `resize(x)` was called might
  // return a value less than `x`.
  void resize(size_t size) { current_ = begin_ + size; }

  // Returns the pointer to the underlying data buffer.
  T* data() { return begin_; }

 private:
  T* current_;  // One past the last logical element; may exceed end_.
  T* begin_;    // First element of the caller-owned buffer.
  T* end_;      // One past the last element of the buffer.
};
95 | |
// Calls a->EnableAliasing(true) when type A provides that member; selected
// via expression SFINAE (the decltype in the trailing return type is only
// well-formed if the call compiles). Aliasing allows the stream to return
// pointers directly into its underlying buffer instead of copying.
template <typename A>
auto EnableAliasing(A* a) -> decltype(a->EnableAliasing(true), void()) {
  a->EnableAliasing(true);
}

// Fallback overload: a no-op for stream types without EnableAliasing(bool).
template <typename A>
void EnableAliasing(A&& a) {}
103 | |
104 | uint8 PeekTag(protobuf::io::CodedInputStream* stream) { |
105 | DCHECK(stream != nullptr); |
106 | const void* ptr; |
107 | int size; |
108 | if (!stream->GetDirectBufferPointer(&ptr, &size)) return 0; |
109 | return *static_cast<const uint8*>(ptr); |
110 | } |
111 | |
// Build the single-byte proto wire tag for field number `tag` with wire type
// varint (0), length-delimited (2), or fixed32 (5) respectively. Only valid
// for field numbers small enough that the tag fits in one byte.
constexpr uint8 kVarintTag(uint32 tag) { return (tag << 3) | 0; }
constexpr uint8 kDelimitedTag(uint32 tag) { return (tag << 3) | 2; }
constexpr uint8 kFixed32Tag(uint32 tag) { return (tag << 3) | 5; }
115 | |
namespace parsed {

// ParseDataType has to be called first, then appropriate ParseZzzzList.
//
// A lazy, non-owning view over one serialized `Feature` proto: construction
// only stores an aliased StringPiece, and actual decoding happens in the
// Parse* methods below.
class Feature {
 public:
  Feature() {}
  explicit Feature(StringPiece serialized) : serialized_(serialized) {}

  // Determines which oneof field of the Feature proto is present by reading
  // (and consuming) the leading tag byte: field 1 = BytesList -> DT_STRING,
  // field 2 = FloatList -> DT_FLOAT, field 3 = Int64List -> DT_INT64.
  // An empty serialization yields DT_INVALID with an OK status; any other
  // tag byte is an InvalidArgument error.
  Status ParseDataType(DataType* dtype) {
    DCHECK(dtype != nullptr);
    if (serialized_.empty()) {
      *dtype = DT_INVALID;
      return OkStatus();
    }
    uint8 oneof_tag = static_cast<uint8>(*serialized_.data());
    serialized_.remove_prefix(1);
    switch (oneof_tag) {
      case kDelimitedTag(1):
        *dtype = DT_STRING;
        break;
      case kDelimitedTag(2):
        *dtype = DT_FLOAT;
        break;
      case kDelimitedTag(3):
        *dtype = DT_INT64;
        break;
      default:
        // Initialize variable to avoid compiler warning
        *dtype = DT_INVALID;
        return errors::InvalidArgument("Unsupported datatype.");
    }
    return OkStatus();
  }

  // Counts the strings in the BytesList without copying them (each entry is
  // skipped after reading its length). Returns false on malformed input.
  // Uses a local stream, so serialized_ is not advanced.
  bool GetNumElementsInBytesList(int* num_elements) {
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length = 0;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);
    *num_elements = 0;
    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      uint32 bytes_length = 0;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      if (!stream.Skip(bytes_length)) return false;
      ++*num_elements;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Helper methods
  // Appends one element to the output container and returns a pointer to it,
  // or nullptr when a fixed-capacity LimitedArraySlice is already full.
  tstring* construct_at_end(LimitedArraySlice<tstring>* bytes_list) {
    if (bytes_list->EndDistance() <= 0) {
      return nullptr;
    }
    return &bytes_list->construct_at_end();
  }
  tstring* construct_at_end(SmallVector<tstring>* bytes_list) {
    return &bytes_list->emplace_back();
  }

  // Appends every string of the BytesList to `bytes_list` (either a
  // SmallVector<tstring> or a LimitedArraySlice<tstring>). Returns false on
  // malformed input or when a LimitedArraySlice runs out of capacity.
  template <typename Result>
  bool ParseBytesList(Result* bytes_list) {
    DCHECK(bytes_list != nullptr);

    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());

    EnableAliasing(&stream);

    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    while (!stream.ExpectAtEnd()) {
      if (!stream.ExpectTag(kDelimitedTag(1))) return false;
      // parse string
      uint32 bytes_length;
      if (!stream.ReadVarint32(&bytes_length)) return false;
      tstring* bytes = construct_at_end(bytes_list);
      if (bytes == nullptr) return false;
      bytes->resize_uninitialized(bytes_length);
      if (!stream.ReadRaw(bytes->data(), bytes_length)) return false;
    }
    stream.PopLimit(limit);
    return true;
  }

  // Appends every value of the FloatList to `float_list`. Handles both the
  // packed encoding (field 1, length-delimited) and the non-packed encoding
  // (repeated fixed32 entries). Returns false on malformed input.
  template <typename Result>
  bool ParseFloatList(Result* float_list) {
    DCHECK(float_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kFixed32Tag(1)) {
        return false;
      }

      constexpr int32_t kNumFloatBytes = 4;
      if (peek_tag == kDelimitedTag(1)) {  // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        // Store the initial size to know the offset we have to start writing
        // data from before resizing the output "vector".
        const size_t initial_size = float_list->size();
        float_list->resize(initial_size + packed_length / kNumFloatBytes);

        // If the result data type is float and we are on a little endian
        // machine then we can simply memcpy the data from the proto into the
        // result vector.
        if (port::kLittleEndian &&
            sizeof(typename Result::value_type) == kNumFloatBytes) {
          // Calculate the length of the buffer available what can be less than
          // what we requested in resize in case of a LimitedArraySlice.
          const uint32 bytes_to_copy =
              std::min(static_cast<uint32>((float_list->size() - initial_size) *
                                           kNumFloatBytes),
                       packed_length);
          if (!stream.ReadRaw(float_list->data() + initial_size, bytes_to_copy))
            return false;
        } else {
          // Element-by-element fallback (big-endian host, or a Result whose
          // value_type is not 4 bytes wide).
          int64_t index = initial_size;
          while (!stream.ExpectAtEnd()) {
            uint32 buffer32;
            if (!stream.ReadLittleEndian32(&buffer32)) return false;
            if (index < float_list->size()) {
              float_list->data()[index] = absl::bit_cast<float>(buffer32);
              ++index;
            }
          }
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        const size_t initial_size = float_list->size();
        // 1 byte for the tag (`1` encoded as Variant32) and kNumFloatBytes for
        // the value.
        const int64_t num_elements =
            stream.BytesUntilLimit() / (1 + kNumFloatBytes);
        float_list->resize(initial_size + num_elements);
        int64_t index = initial_size;
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kFixed32Tag(1))) return false;
          uint32 buffer32;
          if (!stream.ReadLittleEndian32(&buffer32)) return false;
          float_list->data()[index] = absl::bit_cast<float>(buffer32);
          ++index;
        }
      }
    }

    stream.PopLimit(limit);
    return true;
  }

  // Appends every value of the Int64List to `int64_list`. Handles both the
  // packed encoding (field 1, length-delimited) and the non-packed encoding
  // (repeated varint entries). Returns false on malformed input.
  template <typename Result>
  bool ParseInt64List(Result* int64_list) {
    DCHECK(int64_list != nullptr);
    protobuf::io::CodedInputStream stream(
        reinterpret_cast<const uint8*>(serialized_.data()), serialized_.size());
    EnableAliasing(&stream);
    uint32 length;
    if (!stream.ReadVarint32(&length)) return false;
    auto limit = stream.PushLimit(length);

    if (!stream.ExpectAtEnd()) {
      uint8 peek_tag = PeekTag(&stream);
      if (peek_tag != kDelimitedTag(1) && peek_tag != kVarintTag(1)) {
        return false;
      }
      if (peek_tag == kDelimitedTag(1)) {  // packed
        if (!stream.ExpectTag(kDelimitedTag(1))) return false;  // packed tag
        uint32 packed_length;
        if (!stream.ReadVarint32(&packed_length)) return false;
        auto packed_limit = stream.PushLimit(packed_length);

        while (!stream.ExpectAtEnd()) {
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64_t>(n));
        }

        stream.PopLimit(packed_limit);
      } else {  // non-packed
        while (!stream.ExpectAtEnd()) {
          if (!stream.ExpectTag(kVarintTag(1))) return false;
          protobuf_uint64 n;  // There is no API for int64
          if (!stream.ReadVarint64(&n)) return false;
          int64_list->push_back(static_cast<int64_t>(n));
        }
      }
    }
    stream.PopLimit(limit);
    return true;
  }

  // The serialized bytes this Feature aliases (with the oneof tag byte
  // already stripped once ParseDataType has run).
  StringPiece GetSerialized() const { return serialized_; }

 private:
  // TODO(lew): Pair of uint8* would be more natural.
  StringPiece serialized_;
};

// One (feature name, lazily-parsed feature) entry of an Example's feature
// map; both sides alias the original input buffer.
using FeatureMapEntry = std::pair<StringPiece, Feature>;
// A parsed Example: its feature-map entries in wire order (duplicates kept).
using Example = std::vector<FeatureMapEntry>;

}  // namespace parsed
335 | |
336 | inline bool (protobuf::io::CodedInputStream* stream) { |
337 | uint32 data; |
338 | protobuf_uint64 dummy; |
339 | switch (stream->ReadTag() & 0x7) { |
340 | case 0: // varint |
341 | if (!stream->ReadVarint32(&data)) return false; |
342 | return true; |
343 | case 1: // fixed64 |
344 | if (!stream->ReadLittleEndian64(&dummy)) return false; |
345 | return true; |
346 | case 2: // length delimited |
347 | if (!stream->ReadVarint32(&data)) return false; |
348 | stream->Skip(data); |
349 | return true; |
350 | case 3: // group begin |
351 | return false; // groups not supported. |
352 | case 4: // group end |
353 | return false; // groups not supported. |
354 | case 5: // fixed32 |
355 | if (!stream->ReadLittleEndian32(&data)) return false; |
356 | return true; |
357 | } |
358 | return false; // unrecognized tag type |
359 | } |
360 | |
361 | bool ParseString(protobuf::io::CodedInputStream* stream, StringPiece* result) { |
362 | DCHECK(stream != nullptr); |
363 | DCHECK(result != nullptr); |
364 | uint32 length; |
365 | if (!stream->ReadVarint32(&length)) return false; |
366 | if (length == 0) { |
367 | *result = StringPiece(nullptr, 0); |
368 | return true; |
369 | } |
370 | const void* stream_alias; |
371 | int stream_size; |
372 | if (!stream->GetDirectBufferPointer(&stream_alias, &stream_size)) { |
373 | return false; |
374 | } |
375 | if (static_cast<uint32>(stream_size) < length) return false; |
376 | *result = StringPiece(static_cast<const char*>(stream_alias), length); |
377 | stream->Skip(length); |
378 | return true; |
379 | } |
380 | |
381 | bool ParseFeatureMapEntry(protobuf::io::CodedInputStream* stream, |
382 | parsed::FeatureMapEntry* feature_map_entry) { |
383 | DCHECK(stream != nullptr); |
384 | DCHECK(feature_map_entry != nullptr); |
385 | uint32 length; |
386 | if (!stream->ReadVarint32(&length)) return false; |
387 | auto limit = stream->PushLimit(length); |
388 | |
389 | // Protobufs allow an arbitrary order for the key and value fields. |
390 | for (int n = 0; n < 2; ++n) { |
391 | const uint32_t tag = stream->ReadTag(); |
392 | switch (tag) { |
393 | case kDelimitedTag(1): |
394 | if (!ParseString(stream, &feature_map_entry->first)) return false; |
395 | break; |
396 | |
397 | case kDelimitedTag(2): { |
398 | StringPiece feature_string_piece; |
399 | if (!ParseString(stream, &feature_string_piece)) return false; |
400 | feature_map_entry->second = parsed::Feature(feature_string_piece); |
401 | break; |
402 | } |
403 | |
404 | default: |
405 | return false; |
406 | } |
407 | } |
408 | |
409 | if (!stream->ExpectAtEnd()) return false; |
410 | stream->PopLimit(limit); |
411 | return true; |
412 | } |
413 | |
414 | bool ParseFeatures(protobuf::io::CodedInputStream* stream, |
415 | parsed::Example* example) { |
416 | DCHECK(stream != nullptr); |
417 | DCHECK(example != nullptr); |
418 | uint32 length; |
419 | if (!stream->ReadVarint32(&length)) return false; |
420 | auto limit = stream->PushLimit(length); |
421 | while (!stream->ExpectAtEnd()) { |
422 | parsed::FeatureMapEntry feature_map_entry; |
423 | if (!stream->ExpectTag(kDelimitedTag(1))) return false; |
424 | if (!ParseFeatureMapEntry(stream, &feature_map_entry)) return false; |
425 | example->push_back(std::move(feature_map_entry)); |
426 | } |
427 | stream->PopLimit(limit); |
428 | return true; |
429 | } |
430 | |
431 | bool ParseExample(protobuf::io::CodedInputStream* stream, |
432 | parsed::Example* example) { |
433 | DCHECK(stream != nullptr); |
434 | DCHECK(example != nullptr); |
435 | // Loop over the input stream which may contain multiple serialized Example |
436 | // protos merged together as strings. This behavior is consistent with Proto's |
437 | // ParseFromString when string representations are concatenated. |
438 | while (!stream->ExpectAtEnd()) { |
439 | if (!stream->ExpectTag(kDelimitedTag(1))) { |
440 | if (!SkipExtraneousTag(stream)) return false; |
441 | } else { |
442 | if (!ParseFeatures(stream, example)) return false; |
443 | } |
444 | } |
445 | return true; |
446 | } |
447 | |
448 | bool ParseExample(StringPiece serialized, parsed::Example* example) { |
449 | DCHECK(example != nullptr); |
450 | protobuf::io::CodedInputStream stream( |
451 | reinterpret_cast<const uint8*>(serialized.data()), serialized.size()); |
452 | EnableAliasing(&stream); |
453 | return ParseExample(&stream, example); |
454 | } |
455 | |
456 | } // namespace |
457 | |
// Test/reference entry point: runs the fast parser on `serialized` and
// materializes the result into a regular Example proto. Returns false if the
// input cannot be parsed.
bool TestFastParse(const string& serialized, Example* example) {
  DCHECK(example != nullptr);
  parsed::Example parsed_example;
  if (!ParseExample(serialized, &parsed_example)) return false;
  auto& features = *example->mutable_features();
  size_t parsed_example_size = parsed_example.size();
  for (size_t i = 0; i < parsed_example_size; ++i) {
    // This is a logic that standard protobuf parsing is implementing.
    // I.e. last entry in the map overwrites all the previous ones.
    parsed::FeatureMapEntry& name_and_feature =
        parsed_example[parsed_example_size - i - 1];
    string name(name_and_feature.first);
    // Iterating in reverse plus this skip implements last-wins: once a name
    // has been emitted, earlier duplicates of it are ignored.
    if ((*features.mutable_feature()).count(name) > 0) continue;

    auto& value = (*features.mutable_feature())[name];
    DataType dtype;
    if (!name_and_feature.second.ParseDataType(&dtype).ok()) return false;
    switch (dtype) {
      case DT_INVALID:
        // Feature with no oneof field set: keep the empty entry as-is.
        break;
      case DT_STRING: {
        SmallVector<tstring> list;
        if (!name_and_feature.second.ParseBytesList(&list)) return false;
        auto* result_list = value.mutable_bytes_list();
        for (auto& bytes : list) {
          result_list->add_value(bytes.data(), bytes.size());
        }
        break;
      }
      case DT_FLOAT: {
        SmallVector<float> list;
        if (!name_and_feature.second.ParseFloatList(&list)) return false;
        auto* result_list = value.mutable_float_list();
        for (float f : list) {
          result_list->add_value(f);
        }
        break;
      }
      case DT_INT64: {
        SmallVector<int64_t> list;
        if (!name_and_feature.second.ParseInt64List(&list)) return false;
        auto* result_list = value.mutable_int64_list();
        for (int64_t i : list) {
          result_list->add_value(i);
        }
        break;
      }
      default:
        // ParseDataType only produces the four cases above.
        LOG(FATAL) << "Should not happen.";
    }
  }
  return true;
}
511 | |
512 | // ----------------------------------------------------------------------------- |
513 | |
514 | namespace { |
515 | |
// Shorthand for the parse configuration used throughout this namespace.
using Config = FastParseExampleConfig;
517 | |
518 | void ParallelFor(const std::function<void(size_t)>& f, size_t n, |
519 | thread::ThreadPool* thread_pool) { |
520 | if (n == 0) return; |
521 | if (thread_pool == nullptr) { |
522 | for (size_t i = 0; i < n; ++i) { |
523 | f(i); |
524 | } |
525 | } else { |
526 | BlockingCounter counter(n - 1); |
527 | for (size_t i = 1; i < n; ++i) { |
528 | thread_pool->Schedule([i, &f, &counter] { |
529 | f(i); |
530 | counter.DecrementCount(); |
531 | }); |
532 | } |
533 | f(0); |
534 | counter.Wait(); |
535 | } |
536 | } |
537 | |
// Enumeration for distinguishing feature types.
// Note: FastParseSequenceExample constructs a map that includes Type values,
// and relies on the fact that they are default-initialized to Dense.
enum class Type { Dense, Sparse, Ragged };

// Note: We use SparseBuffer for sparse, ragged, and dense_varlen features.
struct SparseBuffer {
  // Features are in one of the 3 vectors below depending on config's dtype.
  // Other 2 vectors remain empty.
  SmallVector<tstring> bytes_list;
  SmallVector<float> float_list;
  SmallVector<int64_t> int64_list;

  // Features of example i are elements with indices
  // from example_end_indices[i-1] to example_end_indices[i]-1 on the
  // appropriate xxxxx_list (i.e. cumulative end offsets, one per example).
  std::vector<size_t> example_end_indices;
};
556 | |
// Hash functor for feature names. Uses a fixed non-zero seed so the same
// hash values are produced when building and when probing the
// PresizedCuckooMap of configured features.
struct SeededHasher {
  uint64 operator()(StringPiece s) const {
    return Hash64(s.data(), s.size(), seed);
  }
  uint64 seed{0xDECAFCAFFE};
};
563 | |
// Warns, and increments a monitoring counter, when a dense feature key
// appears in more than one of the concatenated tf.Example protos; all but
// the last occurrence are dropped by the caller.
void LogDenseFeatureDataLoss(StringPiece feature_name) {
  LOG(WARNING) << "Data loss! Feature '" << feature_name
               << "' is present in multiple concatenated "
                  "tf.Examples. Ignoring all but last one.";
  // Created lazily on first data-loss event; counter persists for the
  // process lifetime.
  static auto* duplicated_dense_feature = monitoring::Counter<0>::New(
      "/tensorflow/core/util/example_proto_fast_parsing/"
      "duplicated_dense_feature",
      "Dense feature appears twice in a tf.Example");
  duplicated_dense_feature->GetCell()->IncrementBy(1);
}
574 | |
// Sparse/ragged counterpart of LogDenseFeatureDataLoss: warns and counts a
// duplicated sparse feature key across concatenated tf.Examples.
void LogSparseFeatureDataLoss(StringPiece feature_name) {
  LOG(WARNING) << "Data loss! Feature '" << feature_name
               << "' is present in multiple concatenated "
                  "tf.Examples. Ignoring all but last one.";
  // Created lazily on first data-loss event; counter persists for the
  // process lifetime.
  static auto* duplicated_sparse_feature = monitoring::Counter<0>::New(
      "/tensorflow/core/util/example_proto_fast_parsing/"
      "duplicated_sparse_feature",
      "Sparse feature appears twice in a tf.Example");
  duplicated_sparse_feature->GetCell()->IncrementBy(1);
}
585 | |
586 | Status FastParseSerializedExample( |
587 | const tstring& serialized_example, const tstring& example_name, |
588 | const size_t example_index, const Config& config, |
589 | const PresizedCuckooMap<std::pair<size_t, Type>>& config_index, |
590 | SeededHasher hasher, std::vector<Tensor>* output_dense, |
591 | std::vector<SparseBuffer>* output_varlen_dense, |
592 | std::vector<SparseBuffer>* output_sparse, |
593 | std::vector<SparseBuffer>* output_ragged, |
594 | PerExampleFeatureStats* output_stats) { |
595 | DCHECK(output_dense != nullptr); |
596 | DCHECK(output_sparse != nullptr); |
597 | DCHECK(output_ragged != nullptr); |
598 | parsed::Example parsed_example; |
599 | if (!ParseExample(serialized_example, &parsed_example)) { |
600 | return errors::InvalidArgument("Could not parse example input, value: '" , |
601 | serialized_example, "'" ); |
602 | } |
603 | std::vector<int64_t> sparse_feature_last_example(config.sparse.size(), -1); |
604 | std::vector<int64_t> dense_feature_last_example(config.dense.size(), -1); |
605 | std::vector<int64_t> ragged_feature_last_example(config.ragged.size(), -1); |
606 | |
607 | // Handle features present in the example. |
608 | const size_t parsed_example_size = parsed_example.size(); |
609 | |
610 | if (output_stats) { |
611 | // TODO(b/111553342): This may over-count the number of features if there |
612 | // are duplicate keys in the feature map. Consider deduplicating the keys |
613 | // before computing the count. |
614 | output_stats->features_count = parsed_example_size; |
615 | } |
616 | |
617 | for (size_t i = 0; i < parsed_example_size; ++i) { |
618 | // This is a logic that standard protobuf parsing is implementing. |
619 | // I.e. last entry in the map overwrites all the previous ones. |
620 | parsed::FeatureMapEntry& name_and_feature = |
621 | parsed_example[parsed_example_size - i - 1]; |
622 | |
623 | const StringPiece feature_name = name_and_feature.first; |
624 | parsed::Feature& feature = name_and_feature.second; |
625 | |
626 | std::pair<size_t, Type> d_and_type; |
627 | uint64 h = hasher(feature_name); |
628 | if (!config_index.Find(h, &d_and_type)) continue; |
629 | |
630 | size_t d = d_and_type.first; |
631 | bool is_dense = d_and_type.second == Type::Dense; |
632 | bool is_ragged = d_and_type.second == Type::Ragged; |
633 | |
634 | { |
635 | // Testing for PresizedCuckooMap collision. |
636 | // TODO(lew): Use dense_hash_map and avoid this and hasher creation. |
637 | const tstring& config_feature_name = |
638 | is_dense ? config.dense[d].feature_name |
639 | : (is_ragged ? config.ragged[d].feature_name |
640 | : config.sparse[d].feature_name); |
641 | if (feature_name != config_feature_name) continue; |
642 | } |
643 | |
644 | auto example_error = [&](StringPiece suffix) { |
645 | return errors::InvalidArgument("Name: " , example_name, |
646 | ", Key: " , feature_name, |
647 | ", Index: " , example_index, ". " , suffix); |
648 | }; |
649 | |
650 | auto parse_error = [&] { |
651 | return example_error("Can't parse serialized Example." ); |
652 | }; |
653 | |
654 | DataType example_dtype; |
655 | TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype)); |
656 | |
657 | if (is_dense) { |
658 | if (example_dtype == DT_INVALID) continue; |
659 | |
660 | // If feature was already visited, skip. |
661 | // Compare comment at the beginning of the loop. |
662 | if (dense_feature_last_example[d] == example_index) { |
663 | LogDenseFeatureDataLoss(feature_name); |
664 | continue; |
665 | } |
666 | dense_feature_last_example[d] = example_index; |
667 | |
668 | if (example_dtype != config.dense[d].dtype) { |
669 | return example_error(strings::StrCat( |
670 | "Data types don't match. Data type: " , |
671 | DataTypeString(example_dtype), |
672 | " but expected type: " , DataTypeString(config.dense[d].dtype))); |
673 | } |
674 | if (!config.dense[d].variable_length) { |
675 | Tensor& out = (*output_dense)[d]; |
676 | |
677 | const std::size_t num_elements = config.dense[d].elements_per_stride; |
678 | if (output_stats) { |
679 | // TODO(b/111553342): If desirable, we could add support for counting |
680 | // elements in the features that aren't parsed, but this could add |
681 | // considerable runtime cost. |
682 | output_stats->feature_values_count += num_elements; |
683 | } |
684 | |
685 | const std::size_t offset = example_index * num_elements; |
686 | |
687 | auto shape_error = [&](size_t size, StringPiece type_str) { |
688 | return example_error(strings::StrCat( |
689 | "Number of " , type_str, |
690 | " values != expected. " |
691 | "Values size: " , |
692 | size, |
693 | " but output shape: " , config.dense[d].shape.DebugString())); |
694 | }; |
695 | |
696 | switch (config.dense[d].dtype) { |
697 | case DT_INT64: { |
698 | auto out_p = out.flat<int64_t>().data() + offset; |
699 | LimitedArraySlice<int64_t> slice(out_p, num_elements); |
700 | if (!feature.ParseInt64List(&slice)) return parse_error(); |
701 | if (slice.EndDistance() != 0) { |
702 | return shape_error(num_elements - slice.EndDistance(), "int64" ); |
703 | } |
704 | break; |
705 | } |
706 | case DT_FLOAT: { |
707 | auto out_p = out.flat<float>().data() + offset; |
708 | LimitedArraySlice<float> slice(out_p, num_elements); |
709 | if (!feature.ParseFloatList(&slice)) return parse_error(); |
710 | if (slice.EndDistance() != 0) { |
711 | return shape_error(num_elements - slice.EndDistance(), "float" ); |
712 | } |
713 | break; |
714 | } |
715 | case DT_STRING: { |
716 | auto out_p = out.flat<tstring>().data() + offset; |
717 | LimitedArraySlice<tstring> slice(out_p, num_elements); |
718 | if (!feature.ParseBytesList(&slice)) return parse_error(); |
719 | if (slice.EndDistance() != 0) { |
720 | return shape_error(num_elements - slice.EndDistance(), "bytes" ); |
721 | } |
722 | break; |
723 | } |
724 | default: |
725 | LOG(FATAL) << "Should not happen." ; |
726 | } |
727 | } else { // if variable length |
728 | SparseBuffer& out = (*output_varlen_dense)[d]; |
729 | |
730 | const std::size_t num_elements = config.dense[d].elements_per_stride; |
731 | |
732 | if (example_dtype != DT_INVALID && |
733 | example_dtype != config.dense[d].dtype) { |
734 | return example_error(strings::StrCat( |
735 | "Data types don't match. " , |
736 | "Expected type: " , DataTypeString(config.dense[d].dtype))); |
737 | } |
738 | |
739 | auto shape_error = [&](size_t size, StringPiece type_str) { |
740 | return example_error(strings::StrCat( |
741 | "Number of " , type_str, |
742 | " values is not a multiple of stride length. Saw " , size, |
743 | " values but output shape is: " , |
744 | config.dense[d].shape.DebugString())); |
745 | }; |
746 | |
747 | switch (config.dense[d].dtype) { |
748 | case DT_INT64: { |
749 | if (example_dtype != DT_INVALID) { |
750 | if (!feature.ParseInt64List(&out.int64_list)) { |
751 | return parse_error(); |
752 | } |
753 | if (out.int64_list.size() % num_elements != 0) { |
754 | return shape_error(out.int64_list.size(), "int64" ); |
755 | } |
756 | } |
757 | out.example_end_indices.push_back(out.int64_list.size()); |
758 | break; |
759 | } |
760 | case DT_FLOAT: { |
761 | if (example_dtype != DT_INVALID) { |
762 | if (!feature.ParseFloatList(&out.float_list)) { |
763 | return parse_error(); |
764 | } |
765 | if (out.float_list.size() % num_elements != 0) { |
766 | return shape_error(out.float_list.size(), "float" ); |
767 | } |
768 | } |
769 | out.example_end_indices.push_back(out.float_list.size()); |
770 | break; |
771 | } |
772 | case DT_STRING: { |
773 | if (example_dtype != DT_INVALID) { |
774 | if (!feature.ParseBytesList(&out.bytes_list)) { |
775 | return parse_error(); |
776 | } |
777 | if (out.bytes_list.size() % num_elements != 0) { |
778 | return shape_error(out.bytes_list.size(), "bytes" ); |
779 | } |
780 | } |
781 | out.example_end_indices.push_back(out.bytes_list.size()); |
782 | break; |
783 | } |
784 | default: |
785 | LOG(FATAL) << "Should not happen." ; |
786 | } |
787 | |
788 | if (output_stats) { |
789 | // Use `out.example_end_indices` to determine the feature-value count |
790 | // for this feature, because the preceding switch statement pushes |
791 | // the length of the appropriate feature list to that vector. |
792 | // TODO(b/111553342): If desirable, we could add support for counting |
793 | // elements in the features that aren't parsed, but this could add |
794 | // considerable runtime cost. |
795 | const size_t out_examples_count = out.example_end_indices.size(); |
796 | if (out_examples_count == 1) { |
797 | output_stats->feature_values_count += out.example_end_indices[0]; |
798 | } else { |
799 | output_stats->feature_values_count += |
800 | out.example_end_indices[out_examples_count - 1] - |
801 | out.example_end_indices[out_examples_count - 2]; |
802 | } |
803 | } |
804 | } |
805 | } else { |
806 | // Feature is sparse or ragged. |
807 | auto& last_example = |
808 | is_ragged ? ragged_feature_last_example : sparse_feature_last_example; |
809 | |
810 | // If feature was already visited, skip. |
811 | // Compare comment at the beginning of the loop. |
812 | if (last_example[d] == example_index) { |
813 | LogSparseFeatureDataLoss(feature_name); |
814 | continue; |
815 | } |
816 | last_example[d] = example_index; |
817 | |
818 | // Handle sparse features. |
819 | SparseBuffer& out = is_ragged ? (*output_ragged)[d] : (*output_sparse)[d]; |
820 | DataType feature_dtype = |
821 | is_ragged ? config.ragged[d].dtype : config.sparse[d].dtype; |
822 | if (example_dtype != DT_INVALID && example_dtype != feature_dtype) { |
823 | return example_error( |
824 | strings::StrCat("Data types don't match. " , |
825 | "Expected type: " , DataTypeString(feature_dtype), |
826 | ", Actual type: " , DataTypeString(example_dtype))); |
827 | } |
828 | |
829 | switch (feature_dtype) { |
830 | case DT_INT64: { |
831 | if (example_dtype != DT_INVALID) { |
832 | if (!feature.ParseInt64List(&out.int64_list)) { |
833 | return parse_error(); |
834 | } |
835 | } |
836 | out.example_end_indices.push_back(out.int64_list.size()); |
837 | break; |
838 | } |
839 | case DT_FLOAT: { |
840 | if (example_dtype != DT_INVALID) { |
841 | if (!feature.ParseFloatList(&out.float_list)) { |
842 | return parse_error(); |
843 | } |
844 | } |
845 | out.example_end_indices.push_back(out.float_list.size()); |
846 | break; |
847 | } |
848 | case DT_STRING: { |
849 | if (example_dtype != DT_INVALID) { |
850 | if (!feature.ParseBytesList(&out.bytes_list)) { |
851 | return parse_error(); |
852 | } |
853 | } |
854 | out.example_end_indices.push_back(out.bytes_list.size()); |
855 | break; |
856 | } |
857 | default: |
858 | LOG(FATAL) << "Should not happen." ; |
859 | } |
860 | |
861 | if (output_stats) { |
862 | // Use `out.example_end_indices` to determine the feature-value count |
863 | // for this feature, because the preceding switch statement pushes |
864 | // the length of the appropriate feature list to that vector. |
865 | // TODO(b/111553342): If desirable, we could add support for counting |
866 | // elements in the features that aren't parsed, but this could add |
867 | // considerable runtime cost. |
868 | const size_t out_examples_count = out.example_end_indices.size(); |
869 | if (out_examples_count == 1) { |
870 | output_stats->feature_values_count += out.example_end_indices[0]; |
871 | } else { |
872 | output_stats->feature_values_count += |
873 | out.example_end_indices[out_examples_count - 1] - |
874 | out.example_end_indices[out_examples_count - 2]; |
875 | } |
876 | } |
877 | } |
878 | } |
879 | |
880 | // Handle missing dense features for fixed strides. |
881 | for (size_t d = 0; d < config.dense.size(); ++d) { |
882 | if (config.dense[d].variable_length) continue; |
883 | if (dense_feature_last_example[d] == example_index) continue; |
884 | if (config.dense[d].default_value.NumElements() == 0) { |
885 | return errors::InvalidArgument( |
886 | "Name: " , example_name, ", Feature: " , config.dense[d].feature_name, |
887 | " (data type: " , DataTypeString(config.dense[d].dtype), ")" , |
888 | " is required but could not be found." ); |
889 | } |
890 | const Tensor& in = config.dense[d].default_value; |
891 | Tensor& out = (*output_dense)[d]; |
892 | const std::size_t num_elements = in.shape().num_elements(); |
893 | const std::size_t offset = example_index * num_elements; |
894 | |
895 | switch (config.dense[d].dtype) { |
896 | case DT_INT64: { |
897 | std::copy_n(in.flat<int64_t>().data(), num_elements, |
898 | out.flat<int64_t>().data() + offset); |
899 | break; |
900 | } |
901 | case DT_FLOAT: { |
902 | std::copy_n(in.flat<float>().data(), num_elements, |
903 | out.flat<float>().data() + offset); |
904 | break; |
905 | } |
906 | case DT_STRING: { |
907 | std::copy_n(in.flat<tstring>().data(), num_elements, |
908 | out.flat<tstring>().data() + offset); |
909 | break; |
910 | } |
911 | default: |
912 | LOG(FATAL) << "Should not happen." ; |
913 | } |
914 | } |
915 | |
916 | // Handle missing varlen dense features. |
917 | for (size_t d = 0; d < config.dense.size(); ++d) { |
918 | if (!config.dense[d].variable_length) continue; |
919 | if (dense_feature_last_example[d] == example_index) continue; |
920 | SparseBuffer& out = (*output_varlen_dense)[d]; |
921 | size_t prev_example_end_index = |
922 | out.example_end_indices.empty() ? 0 : out.example_end_indices.back(); |
923 | out.example_end_indices.push_back(prev_example_end_index); |
924 | } |
925 | |
926 | // Handle missing sparse features. |
927 | for (size_t d = 0; d < config.sparse.size(); ++d) { |
928 | if (sparse_feature_last_example[d] == example_index) continue; |
929 | SparseBuffer& out = (*output_sparse)[d]; |
930 | size_t prev_example_end_index = |
931 | out.example_end_indices.empty() ? 0 : out.example_end_indices.back(); |
932 | out.example_end_indices.push_back(prev_example_end_index); |
933 | } |
934 | |
935 | // Handle missing ragged features. |
936 | for (size_t d = 0; d < config.ragged.size(); ++d) { |
937 | if (ragged_feature_last_example[d] == example_index) continue; |
938 | SparseBuffer& out = (*output_ragged)[d]; |
939 | size_t prev_example_end_index = |
940 | out.example_end_indices.empty() ? 0 : out.example_end_indices.back(); |
941 | out.example_end_indices.push_back(prev_example_end_index); |
942 | } |
943 | |
944 | return OkStatus(); |
945 | } |
946 | |
947 | Status CheckConfigDataType(DataType dtype) { |
948 | switch (dtype) { |
949 | case DT_INT64: |
950 | case DT_FLOAT: |
951 | case DT_STRING: |
952 | return OkStatus(); |
953 | default: |
954 | return errors::InvalidArgument("Invalid config dtype: " , |
955 | DataTypeString(dtype)); |
956 | } |
957 | } |
958 | |
959 | // Use this in the "default" clause of switch statements when dispatching |
960 | // on a dtype variable that was checked by CheckConfigDataType(): |
961 | inline void ReportUnexpectedDataType(DataType dtype) { |
962 | DCHECK(false) |
963 | << "Encountered unexpected DataType " << DataTypeString(dtype) |
964 | << "in variable that should have been checked by CheckConfigDataType()." ; |
965 | } |
966 | |
967 | Status CheckConfigDataTypes(const Config& config) { |
968 | // Check config so we can safely CHECK(false) in switches on config.*.dtype |
969 | for (auto& c : config.sparse) { |
970 | TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); |
971 | } |
972 | for (auto& c : config.dense) { |
973 | TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); |
974 | } |
975 | for (auto& c : config.ragged) { |
976 | TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype)); |
977 | if (!(c.splits_dtype == DT_INT32 || c.splits_dtype == DT_INT64)) { |
978 | return errors::InvalidArgument("Invalid ragged_split_type: " , |
979 | DataTypeString(c.splits_dtype)); |
980 | } |
981 | } |
982 | return OkStatus(); |
983 | } |
984 | |
// Returns the list inside `buffer` that stores values of element type T.
// The primary template is only declared; each supported dtype has a
// specialization below, so instantiating with any other type fails at
// link time.
template <typename T>
const SmallVector<T>& GetListFromBuffer(const SparseBuffer& buffer);

// int64 values live in `int64_list`.
template <>
const SmallVector<int64_t>& GetListFromBuffer<int64_t>(
    const SparseBuffer& buffer) {
  return buffer.int64_list;
}
// float values live in `float_list`.
template <>
const SmallVector<float>& GetListFromBuffer<float>(const SparseBuffer& buffer) {
  return buffer.float_list;
}
// string/bytes values live in `bytes_list`.
template <>
const SmallVector<tstring>& GetListFromBuffer<tstring>(
    const SparseBuffer& buffer) {
  return buffer.bytes_list;
}
1002 | |
// Copies the range [b, e) to the destination starting at `t`. The generic
// version copies; the tstring specialization uses std::move so that, where
// the element type permits, string contents need not be duplicated (the
// source buffer is discarded after the merge either way).
template <typename T>
void CopyOrMoveBlock(const T* b, const T* e, T* t) {
  std::copy(b, e, t);
}
template <>
void CopyOrMoveBlock(const tstring* b, const tstring* e, tstring* t) {
  std::move(b, e, t);
}
1011 | |
1012 | template <typename T> |
1013 | void FillAndCopyVarLen( |
1014 | const int d, const size_t num_elements, |
1015 | const size_t num_elements_per_minibatch, const Config& config, |
1016 | const std::vector<std::vector<SparseBuffer>>& varlen_dense_buffers, |
1017 | Tensor* values) { |
1018 | const Tensor& default_value = config.dense[d].default_value; |
1019 | |
1020 | // Copy-fill the tensors (creating the zero/fill-padding) |
1021 | std::fill(values->flat<T>().data(), values->flat<T>().data() + num_elements, |
1022 | default_value.flat<T>()(0)); |
1023 | |
1024 | // Data is [batch_size, max_num_elements, data_stride_size] |
1025 | // and num_elements_per_minibatch = max_num_elements * data_stride_size |
1026 | auto data = values->flat<T>().data(); |
1027 | |
1028 | // Iterate over minibatch elements |
1029 | for (size_t i = 0; i < varlen_dense_buffers.size(); ++i) { |
1030 | const SparseBuffer& buffer = varlen_dense_buffers[i][d]; |
1031 | // Number of examples being stored in this buffer |
1032 | const auto& end_indices = buffer.example_end_indices; |
1033 | const size_t examples_in_buffer = end_indices.size(); |
1034 | // const size_t stride_size = config.dense[d].elements_per_stride; |
1035 | |
1036 | const auto& list = GetListFromBuffer<T>(buffer); |
1037 | auto list_ptr = list.begin(); |
1038 | |
1039 | size_t elements_tally = 0; |
1040 | // Iterate through all the examples stored in this buffer. |
1041 | for (size_t j = 0; j < examples_in_buffer; ++j) { |
1042 | // Number of elements stored for this example. |
1043 | const size_t num_elems = end_indices[j] - elements_tally; |
1044 | CopyOrMoveBlock(list_ptr, list_ptr + num_elems, data); |
1045 | // Move forward this many elements in the varlen buffer. |
1046 | list_ptr += num_elems; |
1047 | // Move forward to the next minibatch entry in the values output. |
1048 | data += num_elements_per_minibatch; |
1049 | elements_tally = end_indices[j]; |
1050 | } |
1051 | DCHECK(elements_tally == list.size()); |
1052 | } |
1053 | } |
1054 | |
// Thin vector-like interface wrapper around a Tensor. This enables us to
// directly populate a tensor during parsing instead of having to first
// create a vector and then copy the data over.
template <typename T>
class TensorVector {
 public:
  using value_type = T;

  // Returns the wrapped tensor, lazily materializing an empty (size-0)
  // tensor if resize() was never called.
  const Tensor& tensor() {
    if (!tensor_.has_value()) {
      resize(0);
    }
    return *tensor_;
  }

  // Number of elements currently held; 0 before any resize().
  int64_t size() const {
    return tensor_.has_value() ? tensor_->NumElements() : 0;
  }
  // Allocates the backing 1-D tensor. May only be called once: the DCHECK
  // guards against re-allocating, which would discard existing data and
  // invalidate pointers returned by data().
  void resize(int64_t new_size) {
    DCHECK(!tensor_.has_value());
    tensor_ = Tensor(DataTypeToEnum<T>::v(), TensorShape({new_size}));
    data_ = tensor_->flat<T>().data();
  }
  // Raw pointer to the tensor's elements; nullptr until resize() is called.
  T* data() { return data_; }
  const T* data() const { return data_; }

 private:
  // Use absl::optional to avoid calling the default constructor of Tensor
  // unnecessarily.
  absl::optional<Tensor> tensor_;

  // Cached pointer to the raw data inside the tensor.
  T* data_ = nullptr;
};
1089 | |
1090 | void CountSparseFeatures( |
1091 | const std::vector<std::vector<SparseBuffer>>& sparse_buffers, size_t d, |
1092 | size_t* total_num_features, size_t* max_num_features) { |
1093 | for (auto& sparse_values_tmp : sparse_buffers) { |
1094 | const std::vector<size_t>& end_indices = |
1095 | sparse_values_tmp[d].example_end_indices; |
1096 | *total_num_features += end_indices.back(); |
1097 | *max_num_features = std::max(*max_num_features, end_indices[0]); |
1098 | for (size_t i = 1; i < end_indices.size(); ++i) { |
1099 | size_t example_size = end_indices[i] - end_indices[i - 1]; |
1100 | *max_num_features = std::max(*max_num_features, example_size); |
1101 | } |
1102 | } |
1103 | } |
1104 | |
1105 | void CopySparseBufferToTensor(DataType dtype, size_t offset, SparseBuffer* src, |
1106 | Tensor* dst) { |
1107 | switch (dtype) { |
1108 | case DT_INT64: { |
1109 | std::copy(src->int64_list.begin(), src->int64_list.end(), |
1110 | dst->flat<int64_t>().data() + offset); |
1111 | break; |
1112 | } |
1113 | case DT_FLOAT: { |
1114 | std::copy(src->float_list.begin(), src->float_list.end(), |
1115 | dst->flat<float>().data() + offset); |
1116 | break; |
1117 | } |
1118 | case DT_STRING: { |
1119 | std::move(src->bytes_list.begin(), src->bytes_list.end(), |
1120 | dst->flat<tstring>().data() + offset); |
1121 | break; |
1122 | } |
1123 | default: |
1124 | ReportUnexpectedDataType(dtype); |
1125 | } |
1126 | } |
1127 | |
1128 | } // namespace |
1129 | |
// Parses a batch of serialized tf.Example protos into the dense, sparse,
// and ragged output tensors described by `config`, filling `result`.
//
// The batch is split into minibatches that are parsed in parallel on
// `thread_pool`. Fixed-stride dense outputs are written directly into
// their final tensors during parsing; variable-length dense, sparse, and
// ragged values are buffered per minibatch and merged afterwards.
//
// Args:
//   config: feature configuration; dtypes must pass CheckConfigDataTypes.
//   serialized: serialized Example protos, one per batch entry.
//   example_names: optional names parallel to `serialized`, used only in
//     error messages; may be empty.
//   thread_pool: pool used by ParallelFor for minibatch parsing.
//   result: non-null output struct.
Status FastParseExample(const Config& config,
                        gtl::ArraySlice<tstring> serialized,
                        gtl::ArraySlice<tstring> example_names,
                        thread::ThreadPool* thread_pool, Result* result) {
  DCHECK(result != nullptr);
  // Check config so we can safely CHECK(false) in switches on config.*.dtype
  TF_RETURN_IF_ERROR(CheckConfigDataTypes(config));

  if (config.collect_feature_stats) {
    result->feature_stats.resize(serialized.size());
  }

  size_t config_size =
      config.dense.size() + config.sparse.size() + config.ragged.size();
  SeededHasher hasher;
  // Build config index.
  // Maps hash(feature_name) -> (index into the config vector, feature
  // kind). On a hash collision the whole index is rebuilt with a new seed;
  // 1000 attempts bounds the retry loop.
  PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size);
  bool ok = true;
  for (size_t i = 0; i < 1000; ++i) {
    for (size_t d = 0; d < config.dense.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name),
                                      {d, Type::Dense});
    }
    for (size_t d = 0; d < config.sparse.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name),
                                      {d, Type::Sparse});
    }
    for (size_t d = 0; d < config.ragged.size(); ++d) {
      ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name),
                                      {d, Type::Ragged});
    }
    if (ok) break;
    LOG(WARNING) << "Collision found. This should happen only if you have "
                    "around 2^32 entries in your config." ;
    hasher.seed++;
    config_index.Clear(config_size);
    ok = true;
  }
  if (!ok) {
    return errors::Internal(
        "Could not avoid collision. This should not happen." );
  }

  // Allocate dense output for fixed length dense values
  // (variable-length dense and sparse and ragged have to be buffered).
  std::vector<Tensor> fixed_dense_values(config.dense.size());
  for (size_t d = 0; d < config.dense.size(); ++d) {
    if (config.dense[d].variable_length) continue;
    TensorShape out_shape;
    // Output shape is [batch_size] + configured per-example shape.
    out_shape.AddDim(serialized.size());
    for (const int64_t dim : config.dense[d].shape.dim_sizes()) {
      out_shape.AddDim(dim);
    }
    fixed_dense_values[d] = Tensor(config.dense[d].dtype, out_shape);
  }

  // This parameter affects performance in a big and data-dependent way.
  const size_t kMiniBatchSizeBytes = 50000;

  // Calculate number of minibatches.
  // In main regime make each minibatch around kMiniBatchSizeBytes bytes.
  // Apply 'special logic' below for small and big regimes.
  const size_t num_minibatches = [&] {
    size_t result = 0;
    size_t minibatch_bytes = 0;
    for (size_t i = 0; i < serialized.size(); i++) {
      if (minibatch_bytes == 0) {  // start minibatch
        result++;
      }
      minibatch_bytes += serialized[i].size() + 1;
      if (minibatch_bytes > kMiniBatchSizeBytes) {
        minibatch_bytes = 0;
      }
    }
    // 'special logic'
    const size_t min_minibatches = std::min<size_t>(8, serialized.size());
    const size_t max_minibatches = 64;
    return std::max<size_t>(min_minibatches,
                            std::min<size_t>(max_minibatches, result));
  }();

  // Maps a minibatch index to the batch index of its first example; also
  // used with (minibatch + 1) to find the end of a minibatch's range.
  auto first_example_of_minibatch = [&](size_t minibatch) -> size_t {
    return (serialized.size() * minibatch) / num_minibatches;
  };

  // TODO(lew): A big performance low-hanging fruit here is to improve
  // num_minibatches calculation to take into account actual amount of work
  // needed, as the size in bytes is not perfect. Linear combination of
  // size in bytes and average number of features per example is promising.
  // Even better: measure time instead of estimating, but this is too costly
  // in small batches.
  // Maybe accept outside parameter #num_minibatches?

  // Do minibatches in parallel.
  // Each minibatch gets its own buffers and Status so workers never share
  // mutable state; fixed_dense_values is safe to share because each
  // example writes a disjoint slice.
  std::vector<std::vector<SparseBuffer>> sparse_buffers(num_minibatches);
  std::vector<std::vector<SparseBuffer>> varlen_dense_buffers(num_minibatches);
  std::vector<std::vector<SparseBuffer>> ragged_buffers(num_minibatches);
  std::vector<Status> status_of_minibatch(num_minibatches);
  auto ProcessMiniBatch = [&](size_t minibatch) {
    sparse_buffers[minibatch].resize(config.sparse.size());
    varlen_dense_buffers[minibatch].resize(config.dense.size());
    ragged_buffers[minibatch].resize(config.ragged.size());
    size_t start = first_example_of_minibatch(minibatch);
    size_t end = first_example_of_minibatch(minibatch + 1);
    for (size_t e = start; e < end; ++e) {
      PerExampleFeatureStats* stats = nullptr;
      if (config.collect_feature_stats) {
        stats = &result->feature_stats[e];
      }
      status_of_minibatch[minibatch] = FastParseSerializedExample(
          serialized[e],
          (!example_names.empty() ? example_names[e] : "<unknown>" ), e, config,
          config_index, hasher, &fixed_dense_values,
          &varlen_dense_buffers[minibatch], &sparse_buffers[minibatch],
          &ragged_buffers[minibatch], stats);
      if (!status_of_minibatch[minibatch].ok()) break;
    }
  };

  ParallelFor(ProcessMiniBatch, num_minibatches, thread_pool);

  // Propagate the first parse error, if any.
  for (Status& status : status_of_minibatch) {
    TF_RETURN_IF_ERROR(status);
  }

  result->sparse_indices.reserve(config.sparse.size());
  result->sparse_values.reserve(config.sparse.size());
  result->sparse_shapes.reserve(config.sparse.size());
  result->dense_values.reserve(config.dense.size());
  result->ragged_values.reserve(config.ragged.size());
  result->ragged_splits.reserve(config.ragged.size());

  for (size_t d = 0; d < config.dense.size(); ++d) {
    result->dense_values.push_back(std::move(fixed_dense_values[d]));
  }

  // Merge SparseBuffers from all minibatches for every config.sparse.
  auto MergeSparseMinibatches = [&](size_t d) {
    // Loop over minibatches
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    CountSparseFeatures(sparse_buffers, d, &total_num_features,
                        &max_num_features);

    TensorShape indices_shape;
    indices_shape.AddDim(total_num_features);
    indices_shape.AddDim(2);
    result->sparse_indices.emplace_back(DT_INT64, indices_shape);
    Tensor* indices = &result->sparse_indices.back();

    TensorShape values_shape;
    values_shape.AddDim(total_num_features);
    result->sparse_values.emplace_back(config.sparse[d].dtype, values_shape);
    Tensor* values = &result->sparse_values.back();

    // dense_shape of the output SparseTensor: [batch_size, max row length].
    result->sparse_shapes.emplace_back(DT_INT64, TensorShape({2}));
    auto shapes_shape_t = result->sparse_shapes.back().vec<int64_t>();
    shapes_shape_t(0) = serialized.size();
    shapes_shape_t(1) = max_num_features;

    // `offset` tracks how many values have been emitted so far, so each
    // minibatch's values land after the previous minibatch's.
    size_t offset = 0;
    for (size_t i = 0; i < sparse_buffers.size(); ++i) {
      SparseBuffer& buffer = sparse_buffers[i][d];

      // Update indices.
      size_t delta = 0;

      if (indices->NumElements() > 0) {
        int64* ix_p = &indices->matrix<int64_t>()(offset, 0);
        size_t example_index = first_example_of_minibatch(i);
        for (size_t example_end_index : buffer.example_end_indices) {
          size_t feature_index = 0;
          for (; delta < example_end_index; ++delta) {
            // Column 0: example index
            *ix_p = example_index;
            // Column 1: the feature index buffer example
            *(ix_p + 1) = feature_index;
            ix_p += 2;
            ++feature_index;
          }
          ++example_index;
        }
      }

      CopySparseBufferToTensor(config.sparse[d].dtype, offset, &buffer, values);
      offset += delta;
    }
  };

  // Merge SparseBuffers from all minibatches for every config.ragged.
  auto MergeRaggedMinibatches = [&](size_t d) {
    // Loop over minibatches
    size_t total_num_features = 0;
    size_t max_num_features = 0;
    CountSparseFeatures(ragged_buffers, d, &total_num_features,
                        &max_num_features);

    // Allocate the row_splits output: one entry per example plus the
    // leading zero.
    TensorShape row_splits_shape;
    row_splits_shape.AddDim(serialized.size() + 1);
    result->ragged_splits.emplace_back(config.ragged[d].splits_dtype,
                                       row_splits_shape);
    Tensor* row_splits = &result->ragged_splits.back();
    if (config.ragged[d].splits_dtype == DT_INT64) {
      row_splits->flat<int64_t>()(0) = 0;
    } else {
      row_splits->flat<int32>()(0) = 0;
    }

    TensorShape values_shape;
    values_shape.AddDim(total_num_features);
    result->ragged_values.emplace_back(config.ragged[d].dtype, values_shape);
    Tensor* values = &result->ragged_values.back();

    size_t values_offset = 0;
    size_t splits_offset = 0;
    for (size_t i = 0; i < ragged_buffers.size(); ++i) {
      SparseBuffer& buffer = ragged_buffers[i][d];
      if (buffer.example_end_indices.empty()) continue;

      // Update row_splits.  row_splits are formed by concatenating the example
      // end_indices (adjusting each to start after the previous one ends).
      if (config.ragged[d].splits_dtype == DT_INT64) {
        int64* row_splits_out = &row_splits->flat<int64_t>()(splits_offset);
        int64_t start = *row_splits_out;
        for (size_t example_end_index : buffer.example_end_indices) {
          *++row_splits_out = start + example_end_index;
        }
      } else {
        int32* row_splits_out = &row_splits->flat<int32>()(splits_offset);
        int32_t start = *row_splits_out;
        for (size_t example_end_index : buffer.example_end_indices) {
          *++row_splits_out = start + example_end_index;
        }
      }

      CopySparseBufferToTensor(config.ragged[d].dtype, values_offset, &buffer,
                               values);
      values_offset += buffer.example_end_indices.back();
      splits_offset += buffer.example_end_indices.size();
    }
  };

  // Merge SparseBuffers from all minibatches for every config.dense having
  // variable_length.
  auto MergeDenseVarLenMinibatches = [&](size_t d) {
    if (!config.dense[d].variable_length) return;

    // Loop over minibatches
    // Find the longest example (in elements) to size the padded output.
    size_t max_num_features = 0;
    for (auto& dense_values_tmp : varlen_dense_buffers) {
      std::vector<size_t>& end_indices =
          dense_values_tmp[d].example_end_indices;
      max_num_features = std::max(max_num_features, end_indices[0]);
      for (size_t i = 1; i < end_indices.size(); ++i) {
        size_t example_size = end_indices[i] - end_indices[i - 1];
        max_num_features = std::max(max_num_features, example_size);
      }
    }

    const size_t stride_size = config.dense[d].elements_per_stride;
    const size_t max_num_elements = max_num_features / stride_size;
    TensorShape values_shape;
    DCHECK_EQ(max_num_features % config.dense[d].elements_per_stride, 0);
    const size_t batch_size = serialized.size();
    // Output shape: [batch_size, max_num_elements] + trailing config dims.
    values_shape.AddDim(batch_size);
    values_shape.AddDim(max_num_elements);
    for (int i = 1; i < config.dense[d].shape.dims(); ++i) {
      values_shape.AddDim(config.dense[d].shape.dim_size(i));
    }
    Tensor values(config.dense[d].dtype, values_shape);
    result->dense_values[d] = values;
    const size_t num_elements = values.NumElements();

    // Nothing to write, exit early.
    if (num_elements == 0) return;

    const size_t num_elements_per_minibatch = num_elements / batch_size;

    switch (config.dense[d].dtype) {
      case DT_INT64: {
        FillAndCopyVarLen<int64_t>(d, num_elements, num_elements_per_minibatch,
                                   config, varlen_dense_buffers, &values);
        break;
      }
      case DT_FLOAT: {
        FillAndCopyVarLen<float>(d, num_elements, num_elements_per_minibatch,
                                 config, varlen_dense_buffers, &values);
        break;
      }
      case DT_STRING: {
        FillAndCopyVarLen<tstring>(d, num_elements, num_elements_per_minibatch,
                                   config, varlen_dense_buffers, &values);
        break;
      }
      default:
        ReportUnexpectedDataType(config.dense[d].dtype);
    }
  };

  for (size_t d = 0; d < config.dense.size(); ++d) {
    MergeDenseVarLenMinibatches(d);
  }

  for (size_t d = 0; d < config.sparse.size(); ++d) {
    MergeSparseMinibatches(d);
  }

  for (size_t d = 0; d < config.ragged.size(); ++d) {
    MergeRaggedMinibatches(d);
  }

  return OkStatus();
}
1443 | |
1444 | Status FastParseSingleExample(const Config& config, StringPiece serialized, |
1445 | Result* result) { |
1446 | DCHECK(result != nullptr); |
1447 | // Check config so we can safely CHECK(false) in switches on config.*.dtype |
1448 | TF_RETURN_IF_ERROR(CheckConfigDataTypes(config)); |
1449 | |
1450 | PerExampleFeatureStats* stats = nullptr; |
1451 | if (config.collect_feature_stats) { |
1452 | result->feature_stats.emplace_back(); |
1453 | stats = &result->feature_stats.back(); |
1454 | } |
1455 | |
1456 | // TODO(mrry): Cache the construction of this map at Op construction time. |
1457 | size_t config_size = |
1458 | config.dense.size() + config.sparse.size() + config.ragged.size(); |
1459 | SeededHasher hasher; |
1460 | // Build config index. |
1461 | PresizedCuckooMap<std::pair<size_t, Type>> config_index(config_size); |
1462 | bool ok = true; |
1463 | for (size_t i = 0; i < 1000; ++i) { |
1464 | for (size_t d = 0; d < config.dense.size(); ++d) { |
1465 | ok &= config_index.InsertUnique(hasher(config.dense[d].feature_name), |
1466 | {d, Type::Dense}); |
1467 | } |
1468 | for (size_t d = 0; d < config.sparse.size(); ++d) { |
1469 | ok &= config_index.InsertUnique(hasher(config.sparse[d].feature_name), |
1470 | {d, Type::Sparse}); |
1471 | } |
1472 | for (size_t d = 0; d < config.ragged.size(); ++d) { |
1473 | ok &= config_index.InsertUnique(hasher(config.ragged[d].feature_name), |
1474 | {d, Type::Ragged}); |
1475 | } |
1476 | if (ok) break; |
1477 | LOG(WARNING) << "Collision found. This should happen only if you have " |
1478 | "around 2^32 entries in your config." ; |
1479 | hasher.seed++; |
1480 | config_index.Clear(config_size); |
1481 | ok = true; |
1482 | } |
1483 | if (!ok) { |
1484 | return errors::Internal( |
1485 | "Could not avoid collision. This should not happen." ); |
1486 | } |
1487 | |
1488 | result->sparse_indices.reserve(config.sparse.size()); |
1489 | result->sparse_values.reserve(config.sparse.size()); |
1490 | result->sparse_shapes.reserve(config.sparse.size()); |
1491 | result->dense_values.reserve(config.dense.size()); |
1492 | result->ragged_values.reserve(config.ragged.size()); |
1493 | result->ragged_splits.reserve(config.ragged.size()); |
1494 | |
1495 | // Allocate dense output tensors. |
1496 | for (size_t d = 0; d < config.dense.size(); ++d) { |
1497 | if (!config.dense[d].variable_length) { |
1498 | TensorShape values_shape; |
1499 | if (!config.dense[d].shape.AsTensorShape(&values_shape)) { |
1500 | return errors::Internal( |
1501 | "Fixed-length shape was not a statically defined shape." ); |
1502 | } |
1503 | result->dense_values.emplace_back(config.dense[d].dtype, values_shape); |
1504 | } else { |
1505 | // Variable-length tensor will be allocated later. |
1506 | result->dense_values.emplace_back(); |
1507 | } |
1508 | } |
1509 | |
1510 | // Allocate sparse output tensors. |
1511 | for (size_t d = 0; d < config.sparse.size(); ++d) { |
1512 | // The dense_shape is always a vector of length 1. |
1513 | result->sparse_shapes.emplace_back(DT_INT64, TensorShape({1})); |
1514 | // Variable-length tensors will be allocated later. |
1515 | result->sparse_indices.emplace_back(); |
1516 | result->sparse_values.emplace_back(); |
1517 | } |
1518 | |
1519 | // Allocate ragged output tensors. |
1520 | for (size_t d = 0; d < config.ragged.size(); ++d) { |
1521 | // Variable-length values tensors will be allocated later. |
1522 | result->ragged_values.emplace_back(); |
1523 | // Splits tensors are empty (unused) for single (scalar) inputs. |
1524 | const auto splits_dtype = config.ragged[d].splits_dtype; |
1525 | result->ragged_splits.emplace_back(splits_dtype, TensorShape({0})); |
1526 | } |
1527 | |
1528 | parsed::Example parsed_example; |
1529 | if (!ParseExample(serialized, &parsed_example)) { |
1530 | return errors::InvalidArgument("Could not parse example input, value: '" , |
1531 | serialized, "'" ); |
1532 | } |
1533 | std::vector<bool> sparse_feature_already_seen(config.sparse.size(), false); |
1534 | std::vector<bool> dense_feature_already_seen(config.dense.size(), false); |
1535 | std::vector<bool> ragged_feature_already_seen(config.ragged.size(), false); |
1536 | |
1537 | if (stats) { |
1538 | // TODO(b/111553342): This may over-count the number of features if there |
1539 | // are duplicate keys in the feature map. Consider deduplicating the keys |
1540 | // before computing the count. |
1541 | stats->features_count = parsed_example.size(); |
1542 | } |
1543 | |
1544 | // Handle features present in the example. |
1545 | const size_t parsed_example_size = parsed_example.size(); |
1546 | for (size_t i = 0; i < parsed_example_size; ++i) { |
1547 | // This is a logic that standard protobuf parsing is implementing. |
1548 | // I.e. last entry in the map overwrites all the previous ones. |
1549 | parsed::FeatureMapEntry& name_and_feature = |
1550 | parsed_example[parsed_example_size - i - 1]; |
1551 | |
1552 | const StringPiece feature_name = name_and_feature.first; |
1553 | parsed::Feature& feature = name_and_feature.second; |
1554 | |
1555 | std::pair<size_t, Type> d_and_type; |
1556 | uint64 h = hasher(feature_name); |
1557 | if (!config_index.Find(h, &d_and_type)) continue; |
1558 | |
1559 | size_t d = d_and_type.first; |
1560 | bool is_dense = d_and_type.second == Type::Dense; |
1561 | bool is_sparse = d_and_type.second == Type::Sparse; |
1562 | |
1563 | { |
1564 | // Testing for PresizedCuckooMap collision. |
1565 | // TODO(lew): Use dense_hash_map and avoid this and hasher creation. |
1566 | const tstring& config_feature_name = |
1567 | is_dense ? config.dense[d].feature_name |
1568 | : (is_sparse ? config.sparse[d].feature_name |
1569 | : config.ragged[d].feature_name); |
1570 | if (feature_name != config_feature_name) continue; |
1571 | } |
1572 | |
1573 | auto example_error = [feature_name](StringPiece suffix) { |
1574 | return errors::InvalidArgument("Key: " , feature_name, ". " , suffix); |
1575 | }; |
1576 | |
1577 | auto parse_error = [feature_name] { |
1578 | return errors::InvalidArgument("Key: " , feature_name, |
1579 | ". Can't parse serialized Example." ); |
1580 | }; |
1581 | |
1582 | DataType example_dtype; |
1583 | TF_RETURN_IF_ERROR(feature.ParseDataType(&example_dtype)); |
1584 | if (example_dtype == DT_INVALID) continue; |
1585 | |
1586 | if (is_dense && !config.dense[d].variable_length) { |
1587 | // If feature was already visited, skip. |
1588 | // Compare comment at the beginning of the loop. |
1589 | if (dense_feature_already_seen[d]) { |
1590 | LogDenseFeatureDataLoss(feature_name); |
1591 | continue; |
1592 | } |
1593 | dense_feature_already_seen[d] = true; |
1594 | |
1595 | if (example_dtype != config.dense[d].dtype) { |
1596 | return example_error(strings::StrCat( |
1597 | "Data types don't match. Data type: " , |
1598 | DataTypeString(example_dtype), |
1599 | " but expected type: " , DataTypeString(config.dense[d].dtype))); |
1600 | } |
1601 | |
1602 | Tensor* out = &result->dense_values[d]; |
1603 | const std::size_t num_elements = config.dense[d].elements_per_stride; |
1604 | if (stats) { |
1605 | // TODO(b/111553342): If desirable, we could add support for counting |
1606 | // elements in the features that aren't parsed, but this could add |
1607 | // considerable runtime cost. |
1608 | stats->feature_values_count += num_elements; |
1609 | } |
1610 | switch (example_dtype) { |
1611 | case DT_INT64: { |
1612 | auto out_p = out->flat<int64_t>().data(); |
1613 | LimitedArraySlice<int64_t> slice(out_p, num_elements); |
1614 | if (!feature.ParseInt64List(&slice)) return parse_error(); |
1615 | if (slice.EndDistance() != 0) { |
1616 | return parse_error(); |
1617 | } |
1618 | break; |
1619 | } |
1620 | case DT_FLOAT: { |
1621 | auto out_p = out->flat<float>().data(); |
1622 | LimitedArraySlice<float> slice(out_p, num_elements); |
1623 | if (!feature.ParseFloatList(&slice)) return parse_error(); |
1624 | if (slice.EndDistance() != 0) { |
1625 | return parse_error(); |
1626 | } |
1627 | break; |
1628 | } |
1629 | case DT_STRING: { |
1630 | auto out_p = out->flat<tstring>().data(); |
1631 | LimitedArraySlice<tstring> slice(out_p, num_elements); |
1632 | if (!feature.ParseBytesList(&slice)) return parse_error(); |
1633 | if (slice.EndDistance() != 0) { |
1634 | return parse_error(); |
1635 | } |
1636 | break; |
1637 | } |
1638 | default: |
1639 | ReportUnexpectedDataType(example_dtype); |
1640 | } |
1641 | |
1642 | } else { // if variable length |
1643 | SmallVector<tstring> bytes_list; |
1644 | TensorVector<float> float_list; |
1645 | SmallVector<int64_t> int64_list; |
1646 | |
1647 | const size_t num_elements_divisor = |
1648 | is_dense ? config.dense[d].elements_per_stride : 1; |
1649 | size_t num_elements; |
1650 | |
1651 | if (is_dense) { |
1652 | // If feature was already visited, skip. |
1653 | // Compare comment at the beginning of the loop. |
1654 | if (dense_feature_already_seen[d]) { |
1655 | LogDenseFeatureDataLoss(feature_name); |
1656 | continue; |
1657 | } |
1658 | dense_feature_already_seen[d] = true; |
1659 | if (example_dtype != config.dense[d].dtype) { |
1660 | return example_error(strings::StrCat( |
1661 | "Data types don't match. Data type: " , |
1662 | DataTypeString(example_dtype), |
1663 | " but expected type: " , DataTypeString(config.dense[d].dtype))); |
1664 | } |
1665 | } else { |
1666 | // Feature is sparse or ragged. |
1667 | auto& feature_already_seen = is_sparse ? sparse_feature_already_seen |
1668 | : ragged_feature_already_seen; |
1669 | auto& feature_dtype = |
1670 | is_sparse ? config.sparse[d].dtype : config.ragged[d].dtype; |
1671 | // If feature was already visited, skip. |
1672 | // Compare comment at the beginning of the loop. |
1673 | if (feature_already_seen[d]) { |
1674 | LogSparseFeatureDataLoss(feature_name); |
1675 | continue; |
1676 | } |
1677 | feature_already_seen[d] = true; |
1678 | |
1679 | // Handle sparse features. |
1680 | if (example_dtype != DT_INVALID && example_dtype != feature_dtype) { |
1681 | return example_error(strings::StrCat( |
1682 | "Data types don't match. " , |
1683 | "Expected type: " , DataTypeString(feature_dtype), |
1684 | ", Actual type: " , DataTypeString(example_dtype))); |
1685 | } |
1686 | } |
1687 | |
1688 | switch (example_dtype) { |
1689 | case DT_INT64: { |
1690 | // TODO(mrry): Use the fact that the `int64_list` is packed to read |
1691 | // out the length and pre-allocate the output tensor. |
1692 | if (!feature.ParseInt64List(&int64_list)) return parse_error(); |
1693 | num_elements = int64_list.size(); |
1694 | break; |
1695 | } |
1696 | case DT_FLOAT: { |
1697 | if (!feature.ParseFloatList(&float_list)) return parse_error(); |
1698 | num_elements = float_list.size(); |
1699 | break; |
1700 | } |
1701 | case DT_STRING: { |
1702 | int actual_num_elements = 0; |
1703 | if (!feature.GetNumElementsInBytesList(&actual_num_elements)) { |
1704 | return parse_error(); |
1705 | } |
1706 | bytes_list.reserve(actual_num_elements); |
1707 | if (!feature.ParseBytesList(&bytes_list)) return parse_error(); |
1708 | num_elements = bytes_list.size(); |
1709 | break; |
1710 | } |
1711 | default: |
1712 | num_elements = 0; |
1713 | ReportUnexpectedDataType(example_dtype); |
1714 | } |
1715 | |
1716 | if (num_elements % num_elements_divisor != 0) { |
1717 | return parse_error(); |
1718 | } |
1719 | |
1720 | if (stats) { |
1721 | stats->feature_values_count += num_elements; |
1722 | } |
1723 | |
1724 | Tensor* out; |
1725 | DataType out_dtype; |
1726 | TensorShape out_shape; |
1727 | if (is_dense) { |
1728 | out_shape.AddDim(num_elements / num_elements_divisor); |
1729 | for (int i = 1; i < config.dense[d].shape.dims(); ++i) { |
1730 | out_shape.AddDim(config.dense[d].shape.dim_size(i)); |
1731 | } |
1732 | |
1733 | out = &result->dense_values[d]; |
1734 | out_dtype = config.dense[d].dtype; |
1735 | } else if (is_sparse) { |
1736 | Tensor* out_indices = &result->sparse_indices[d]; |
1737 | Tensor* out_dense_shape = &result->sparse_shapes[d]; |
1738 | |
1739 | // TODO(mrry): Investigate the possibility of not materializing |
1740 | // the indices (and perhaps dense_shape) until they are needed. |
1741 | *out_indices = Tensor( |
1742 | DT_INT64, TensorShape({static_cast<int64_t>(num_elements), 1})); |
1743 | auto indices_flat = out_indices->flat<int64_t>(); |
1744 | for (size_t i = 0; i < num_elements; ++i) { |
1745 | indices_flat(i) = static_cast<int64_t>(i); |
1746 | } |
1747 | |
1748 | *out_dense_shape = Tensor(DT_INT64, TensorShape({1})); |
1749 | auto shapes_shape_t = out_dense_shape->vec<int64_t>(); |
1750 | shapes_shape_t(0) = num_elements; |
1751 | |
1752 | out = &result->sparse_values[d]; |
1753 | out_dtype = config.sparse[d].dtype; |
1754 | out_shape.AddDim(num_elements); |
1755 | } else { |
1756 | out = &result->ragged_values[d]; |
1757 | out_dtype = config.ragged[d].dtype; |
1758 | out_shape.AddDim(num_elements); |
1759 | } |
1760 | |
1761 | switch (example_dtype) { |
1762 | case DT_INT64: { |
1763 | *out = Tensor(out_dtype, out_shape); |
1764 | CopyOrMoveBlock(int64_list.begin(), int64_list.end(), |
1765 | out->flat<int64_t>().data()); |
1766 | break; |
1767 | } |
1768 | case DT_FLOAT: { |
1769 | if (!out->CopyFrom(float_list.tensor(), out_shape)) { |
1770 | return parse_error(); |
1771 | } |
1772 | break; |
1773 | } |
1774 | case DT_STRING: { |
1775 | *out = Tensor(out_dtype, out_shape); |
1776 | CopyOrMoveBlock(bytes_list.begin(), bytes_list.end(), |
1777 | out->flat<tstring>().data()); |
1778 | break; |
1779 | } |
1780 | default: |
1781 | ReportUnexpectedDataType(example_dtype); |
1782 | } |
1783 | } |
1784 | } |
1785 | |
1786 | // Handle missing dense features. |
1787 | for (size_t d = 0; d < config.dense.size(); ++d) { |
1788 | if (!dense_feature_already_seen[d]) { |
1789 | if (!config.dense[d].variable_length) { |
1790 | // Handle missing fixed-length dense feature. |
1791 | if (config.dense[d].default_value.NumElements() == 0) { |
1792 | return errors::InvalidArgument( |
1793 | "Feature: " , config.dense[d].feature_name, |
1794 | " (data type: " , DataTypeString(config.dense[d].dtype), ")" , |
1795 | " is required but could not be found." ); |
1796 | } |
1797 | result->dense_values[d] = config.dense[d].default_value; |
1798 | } else { |
1799 | // Handle missing varlen dense feature. |
1800 | TensorShape empty_shape; |
1801 | empty_shape.AddDim(0); |
1802 | for (int i = 1; i < config.dense[d].shape.dims(); ++i) { |
1803 | empty_shape.AddDim(config.dense[d].shape.dim_size(i)); |
1804 | } |
1805 | result->dense_values[d] = Tensor(config.dense[d].dtype, empty_shape); |
1806 | } |
1807 | } |
1808 | } |
1809 | |
1810 | // Handle missing sparse features. |
1811 | for (size_t d = 0; d < config.sparse.size(); ++d) { |
1812 | if (!sparse_feature_already_seen[d]) { |
1813 | result->sparse_indices[d] = Tensor(DT_INT64, TensorShape({0, 1})); |
1814 | result->sparse_values[d] = |
1815 | Tensor(config.sparse[d].dtype, TensorShape({0})); |
1816 | result->sparse_shapes[d].vec<int64_t>()(0) = 0; |
1817 | } |
1818 | } |
1819 | |
1820 | // Handle missing ragged features. |
1821 | for (size_t d = 0; d < config.ragged.size(); ++d) { |
1822 | if (!ragged_feature_already_seen[d]) { |
1823 | result->ragged_values[d] = |
1824 | Tensor(config.ragged[d].dtype, TensorShape({0})); |
1825 | } |
1826 | } |
1827 | |
1828 | return OkStatus(); |
1829 | } |
1830 | |
1831 | // Private helper functions for FastParseSequenceExample. |
1832 | namespace { |
1833 | |
// A struct used by FastParseSequenceExample to hold the serialized proto
// substrings for a single feature, plus some auxiliary information derived
// from those protos (such as the total value length).
struct FeatureProtos {
  // Proto substrings from each serialized SequenceExample that correspond
  // with this feature.  Both vectors are indexed by the example's position
  // in the input batch.  `protos_present` records whether the proto had a
  // value defined (even if that value is empty).
  std::vector<StringPiece> protos;
  std::vector<bool> protos_present;

  // Information derived from protos:
  // NOTE(review): neither field has a default initializer here, so they must
  // be zeroed wherever FeatureProtos entries are created -- confirm at the
  // construction sites before relying on the accumulated values.
  size_t length;    // total length for ragged/sparse, max row length for dense.
  size_t num_rows;  // only populated for ragged sequence features.

  // Information from the config:
  Type type;  // Whether this feature is sparse, ragged, or dense.
  DataType dtype;
};
1852 | |
1853 | // Map from feature name to FeatureProtos for that feature. |
1854 | using FeatureProtosMap = absl::flat_hash_map<StringPiece, FeatureProtos>; |
1855 | |
1856 | string ExampleName(const gtl::ArraySlice<tstring> example_names, int n) { |
1857 | return example_names.empty() ? "<unknown>" : example_names[n]; |
1858 | } |
1859 | |
1860 | // Return the number of bytes elements parsed, or -1 on error. If out is null, |
1861 | // this method simply counts the number of elements without any copying. |
1862 | inline int ParseBytesFeature(protobuf::io::CodedInputStream* stream, |
1863 | tstring* out) { |
1864 | int num_elements = 0; |
1865 | uint32 length; |
1866 | if (!stream->ExpectTag(kDelimitedTag(1)) || !stream->ReadVarint32(&length)) { |
1867 | return -1; |
1868 | } |
1869 | if (length > 0) { |
1870 | auto limit = stream->PushLimit(length); |
1871 | while (!stream->ExpectAtEnd()) { |
1872 | uint32 bytes_length; |
1873 | if (!stream->ExpectTag(kDelimitedTag(1)) || |
1874 | !stream->ReadVarint32(&bytes_length)) { |
1875 | return -1; |
1876 | } |
1877 | if (out == nullptr) { |
1878 | stream->Skip(bytes_length); |
1879 | } else { |
1880 | out->resize_uninitialized(bytes_length); |
1881 | if (!stream->ReadRaw(out->data(), bytes_length)) { |
1882 | return -1; |
1883 | } |
1884 | out++; |
1885 | } |
1886 | num_elements++; |
1887 | } |
1888 | stream->PopLimit(limit); |
1889 | } |
1890 | return num_elements; |
1891 | } |
1892 | |
// Writes `num_to_pad` zero-valued floats starting at `out`.
inline void PadFloatFeature(int num_to_pad, float* out) {
  while (num_to_pad-- > 0) {
    *out++ = 0.0f;
  }
}
1898 | |
// Writes `num_to_pad` zero-valued int64s starting at `out`.
inline void PadInt64Feature(int num_to_pad, int64_t* out) {
  while (num_to_pad-- > 0) {
    *out++ = 0;
  }
}
1904 | |
// Return the number of float elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseFloatFeature(protobuf::io::CodedInputStream* stream,
                             float* out) {
  int num_elements = 0;
  uint32 length;
  // FloatList is field 2 of the Feature proto; expect its length-delimited
  // tag followed by the submessage length.
  if (!stream->ExpectTag(kDelimitedTag(2)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      // Packed encoding: a single length-delimited run of little-endian
      // fixed32 values.
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          // Reinterpret the 4 raw bytes as an IEEE-754 float.
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kFixed32Tag(1)) {
      // Unpacked encoding: every value carries its own fixed32 tag.
      while (!stream->ExpectAtEnd()) {
        uint32 buffer32;
        if (!stream->ExpectTag(kFixed32Tag(1)) ||
            !stream->ReadLittleEndian32(&buffer32)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = absl::bit_cast<float>(buffer32);
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
1955 | |
// Return the number of int64 elements parsed, or -1 on error. If out is null,
// this method simply counts the number of elements without any copying.
inline int ParseInt64Feature(protobuf::io::CodedInputStream* stream,
                             int64_t* out) {
  int num_elements = 0;
  uint32 length;
  // Int64List is field 3 of the Feature proto; expect its length-delimited
  // tag followed by the submessage length.
  if (!stream->ExpectTag(kDelimitedTag(3)) || !stream->ReadVarint32(&length)) {
    return -1;
  }
  if (length > 0) {
    auto limit = stream->PushLimit(length);
    uint8 peek_tag = PeekTag(stream);
    if (peek_tag == kDelimitedTag(1)) {  // packed
      // Packed encoding: a single length-delimited run of varints.
      uint32 packed_length;
      if (!stream->ExpectTag(kDelimitedTag(1)) ||
          !stream->ReadVarint32(&packed_length)) {
        return -1;
      }
      auto packed_limit = stream->PushLimit(packed_length);
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
      stream->PopLimit(packed_limit);
    } else if (peek_tag == kVarintTag(1)) {
      // Unpacked encoding: every value carries its own varint tag.
      while (!stream->ExpectAtEnd()) {
        protobuf_uint64 n;  // There is no API for int64
        if (!stream->ExpectTag(kVarintTag(1)) || !stream->ReadVarint64(&n)) {
          return -1;
        }
        if (out != nullptr) {
          *out++ = n;
        }
        num_elements++;
      }
    } else {
      // Unknown tag.
      return -1;
    }
    stream->PopLimit(limit);
  }
  return num_elements;
}
2005 | |
2006 | // Parses the next feature on `stream` into `out` starting at `out_offset`. |
2007 | // Updates `out_offset`, and returns the number of values added. |
2008 | // Returns -1 if the next feature on `stream` doesn't match `dtype`. |
2009 | inline int ParseFeature(DataType dtype, protobuf::io::CodedInputStream* stream, |
2010 | Tensor* out, size_t* out_offset) { |
2011 | int delta; |
2012 | switch (dtype) { |
2013 | case DT_STRING: |
2014 | delta = |
2015 | ParseBytesFeature(stream, out->flat<tstring>().data() + *out_offset); |
2016 | break; |
2017 | case DT_FLOAT: |
2018 | delta = |
2019 | ParseFloatFeature(stream, out->flat<float>().data() + *out_offset); |
2020 | break; |
2021 | case DT_INT64: |
2022 | delta = |
2023 | ParseInt64Feature(stream, out->flat<int64_t>().data() + *out_offset); |
2024 | break; |
2025 | default: |
2026 | ReportUnexpectedDataType(dtype); |
2027 | delta = 0; |
2028 | } |
2029 | if (delta > 0) { |
2030 | *out_offset += delta; |
2031 | } |
2032 | return delta; |
2033 | } |
2034 | |
2035 | // Returns the length of the next feature on `stream`. |
2036 | // Returns -1 if the next feature on `stream` doesn't match `dtype`. |
2037 | inline int GetFeatureLength(DataType dtype, |
2038 | protobuf::io::CodedInputStream* stream) { |
2039 | switch (dtype) { |
2040 | case DT_STRING: |
2041 | return ParseBytesFeature(stream, nullptr); |
2042 | case DT_FLOAT: |
2043 | return ParseFloatFeature(stream, nullptr); |
2044 | case DT_INT64: |
2045 | return ParseInt64Feature(stream, nullptr); |
2046 | default: |
2047 | ReportUnexpectedDataType(dtype); |
2048 | return -1; |
2049 | } |
2050 | } |
2051 | |
2052 | inline DataType ParseDataType(protobuf::io::CodedInputStream* stream) { |
2053 | uint8 peek_tag = PeekTag(stream); |
2054 | switch (peek_tag) { |
2055 | case kDelimitedTag(1): |
2056 | return DT_STRING; |
2057 | case kDelimitedTag(2): |
2058 | return DT_FLOAT; |
2059 | case kDelimitedTag(3): |
2060 | return DT_INT64; |
2061 | default: |
2062 | return DT_INVALID; |
2063 | } |
2064 | } |
2065 | |
2066 | inline bool SkipEmptyFeature(protobuf::io::CodedInputStream* stream, |
2067 | DataType dtype) { |
2068 | switch (dtype) { |
2069 | case DT_STRING: |
2070 | if (!stream->ExpectTag(kDelimitedTag(1))) { |
2071 | return false; |
2072 | } |
2073 | break; |
2074 | case DT_FLOAT: |
2075 | if (!stream->ExpectTag(kDelimitedTag(2))) { |
2076 | return false; |
2077 | } |
2078 | break; |
2079 | case DT_INT64: |
2080 | if (!stream->ExpectTag(kDelimitedTag(3))) { |
2081 | return false; |
2082 | } |
2083 | break; |
2084 | default: |
2085 | return false; |
2086 | } |
2087 | uint32 length; |
2088 | return stream->ReadVarint32(&length) && length == 0; |
2089 | } |
2090 | |
2091 | // Reads an example proto, and extracts a StringPiece pointer to each feature. |
2092 | Status ( |
2093 | const gtl::ArraySlice<tstring> examples, |
2094 | const gtl::ArraySlice<tstring> example_names, |
2095 | FeatureProtosMap* context_features, FeatureProtosMap* sequence_features) { |
2096 | for (int d = 0; d < examples.size(); d++) { |
2097 | const tstring& example = examples[d]; |
2098 | protobuf::io::CodedInputStream stream( |
2099 | reinterpret_cast<const uint8*>(example.data()), example.size()); |
2100 | // Not clear what this does. Why not stream.EnableAliasing()? |
2101 | EnableAliasing(&stream); |
2102 | |
2103 | // Extract pointers to all features within this serialized example. |
2104 | while (!stream.ExpectAtEnd()) { |
2105 | FeatureProtosMap* features = nullptr; |
2106 | if (stream.ExpectTag(kDelimitedTag(1))) { |
2107 | // Context |
2108 | features = context_features; |
2109 | } else if (stream.ExpectTag(kDelimitedTag(2))) { |
2110 | // Sequence |
2111 | features = sequence_features; |
2112 | } else if (!SkipExtraneousTag(&stream)) { |
2113 | return errors::InvalidArgument( |
2114 | "Invalid protocol message input, example id: " , |
2115 | ExampleName(example_names, d)); |
2116 | } |
2117 | if (features != nullptr) { |
2118 | uint32 length; |
2119 | if (!stream.ReadVarint32(&length)) { |
2120 | return errors::InvalidArgument( |
2121 | "Invalid protocol message input, example id: " , |
2122 | ExampleName(example_names, d)); |
2123 | } |
2124 | auto limit = stream.PushLimit(length); |
2125 | while (!stream.ExpectAtEnd()) { |
2126 | StringPiece key, value; |
2127 | uint32 length; |
2128 | if (!stream.ExpectTag(kDelimitedTag(1)) || |
2129 | !stream.ReadVarint32(&length)) { |
2130 | return errors::InvalidArgument( |
2131 | "Invalid protocol message input, example id: " , |
2132 | ExampleName(example_names, d)); |
2133 | } |
2134 | auto limit = stream.PushLimit(length); |
2135 | if (!stream.ExpectTag(kDelimitedTag(1)) || |
2136 | !ParseString(&stream, &key) || |
2137 | !stream.ExpectTag(kDelimitedTag(2)) || |
2138 | !ParseString(&stream, &value) || !stream.ExpectAtEnd()) { |
2139 | return errors::InvalidArgument( |
2140 | "Invalid protocol message input, example id: " , |
2141 | ExampleName(example_names, d)); |
2142 | } |
2143 | stream.PopLimit(limit); |
2144 | // Only save if this feature was requested. |
2145 | auto feature_iter = features->find(key); |
2146 | if (feature_iter != features->end()) { |
2147 | auto& feature = feature_iter->second; |
2148 | feature.protos[d] = value; |
2149 | feature.protos_present[d] = true; |
2150 | } |
2151 | } |
2152 | stream.PopLimit(limit); |
2153 | } |
2154 | } |
2155 | } |
2156 | return OkStatus(); |
2157 | } |
2158 | |
2159 | // Populates context_features[k].length based on context_features[k].protos |
2160 | // (for all k). |
2161 | Status GetContextFeatureLengths(const gtl::ArraySlice<tstring> example_names, |
2162 | FeatureProtosMap* context_features) { |
2163 | for (auto& c : *context_features) { |
2164 | FeatureProtos& feature = c.second; |
2165 | for (int d = 0; d < feature.protos.size(); ++d) { |
2166 | const auto& proto = feature.protos[d]; |
2167 | if (proto.empty()) continue; |
2168 | protobuf::io::CodedInputStream stream( |
2169 | reinterpret_cast<const uint8*>(proto.data()), proto.size()); |
2170 | EnableAliasing(&stream); |
2171 | int num_elements = GetFeatureLength(feature.dtype, &stream); |
2172 | if (num_elements < 0) { |
2173 | return errors::InvalidArgument( |
2174 | "Name: " , ExampleName(example_names, d), |
2175 | ", Context feature: " , c.first, |
2176 | ". Data types don't match. Expected type: " , |
2177 | DataTypeString(feature.dtype)); |
2178 | } |
2179 | switch (feature.type) { |
2180 | case Type::Sparse: // intentional fall-through |
2181 | case Type::Ragged: |
2182 | feature.length += num_elements; |
2183 | break; |
2184 | case Type::Dense: |
2185 | feature.length = |
2186 | std::max(feature.length, static_cast<size_t>(num_elements)); |
2187 | break; |
2188 | } |
2189 | } |
2190 | } |
2191 | return OkStatus(); |
2192 | } |
2193 | |
// Populates sequence_features[k].length and sequence_features[k].num_rows based
// on sequence_features[k].protos (for all k).
Status GetSequenceFeatureLengths(const gtl::ArraySlice<tstring> example_names,
                                 FeatureProtosMap* sequence_features) {
  for (auto& c : *sequence_features) {
    FeatureProtos& feature = c.second;
    for (int d = 0; d < feature.protos.size(); ++d) {
      const auto& proto = feature.protos[d];
      if (proto.empty()) continue;

      // Per-example tallies for this FeatureList.
      size_t num_rows = 0;      // number of Feature messages (rows).
      size_t num_elements = 0;  // total values across all rows.
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(proto.data()), proto.size());
      EnableAliasing(&stream);
      while (!stream.ExpectAtEnd()) {
        // Each row is a length-delimited Feature submessage (field 1).
        uint32 feature_bytes;
        if (!stream.ExpectTag(kDelimitedTag(1)) ||
            !stream.ReadVarint32(&feature_bytes)) {
          return errors::InvalidArgument("Error in sequence feature " , c.first,
                                         " in example " ,
                                         ExampleName(example_names, d));
        }
        if (feature_bytes > 2) {
          // A non-empty value list: count its elements (no copying).
          auto limit = stream.PushLimit(feature_bytes);
          int delta = GetFeatureLength(feature.dtype, &stream);
          if (delta < 0) {
            return errors::InvalidArgument(
                "Name: " , ExampleName(example_names, d),
                ", Feature list: " , c.first, ", Index: " , num_rows,
                ". Data types don't match. Expected type: " ,
                DataTypeString(feature.dtype));
          }
          num_elements += delta;
          stream.PopLimit(limit);
        } else if (feature_bytes == 2) {
          // Exactly 2 bytes encodes a Feature whose value list is present but
          // empty (one tag byte + one zero-length byte); still verify that
          // the list's tag matches the expected dtype.
          if (!SkipEmptyFeature(&stream, feature.dtype)) {
            return errors::InvalidArgument(
                "Name: " , ExampleName(example_names, d),
                ", Feature list: " , c.first, ", Index: " , num_rows,
                ". Data types don't match. Expected type: " ,
                DataTypeString(feature.dtype));
          }
        } else if (feature_bytes != 0) {
          // A 1-byte Feature submessage cannot be valid.  (0 bytes is a
          // completely empty Feature, which counts as an empty row.)
          return errors::InvalidArgument("Error in sequence feature " , c.first,
                                         " in example " ,
                                         ExampleName(example_names, d));
        }
        ++num_rows;
      }
      switch (feature.type) {
        case Type::Sparse:
          feature.length += num_elements;
          break;
        case Type::Ragged:
          feature.length += num_elements;
          feature.num_rows += num_rows;
          break;
        case Type::Dense:
          // Dense sequence features are padded to the longest example.
          feature.length = std::max(feature.length, num_elements);
          break;
      }
    }
  }
  return OkStatus();
}
2260 | |
2261 | // Copies src into dst[dst_offset:dst_offset+src.size], and then increments |
2262 | // dst_offset by src.size. |
2263 | void CopyTensorIntoTensor(DataType dtype, const Tensor& src, Tensor* dst, |
2264 | size_t* dst_offset) { |
2265 | size_t src_size = src.NumElements(); |
2266 | switch (dtype) { |
2267 | case DT_INT64: { |
2268 | auto src_t = src.flat<int64_t>().data(); |
2269 | std::copy(src_t, src_t + src_size, |
2270 | dst->flat<int64_t>().data() + *dst_offset); |
2271 | break; |
2272 | } |
2273 | case DT_FLOAT: { |
2274 | auto src_t = src.flat<float>().data(); |
2275 | std::copy(src_t, src_t + src_size, |
2276 | dst->flat<float>().data() + *dst_offset); |
2277 | break; |
2278 | } |
2279 | case DT_STRING: { |
2280 | auto src_t = src.flat<tstring>().data(); |
2281 | std::copy(src_t, src_t + src_size, |
2282 | dst->flat<tstring>().data() + *dst_offset); |
2283 | break; |
2284 | } |
2285 | default: |
2286 | ReportUnexpectedDataType(dtype); |
2287 | } |
2288 | *dst_offset += src_size; |
2289 | } |
2290 | |
// Parses dense features in `context_features`, and writes their parsed
// values to `context_result`.
//
// For each configured dense feature, allocates a tensor shaped
// [num_examples] + c.shape (the leading batch dim only when `is_batch`) and
// fills it example by example, substituting c.default_value for examples
// where the feature is absent.
Status ParseContextDenseFeatures(const FeatureProtosMap& context_features,
                                 const FastParseExampleConfig& context_config,
                                 gtl::ArraySlice<tstring> example_names,
                                 bool is_batch, int num_examples,
                                 Allocator* allocator, Result* context_result) {
  for (int t = 0; t < context_config.dense.size(); ++t) {
    const auto& c = context_config.dense[t];
    // NOTE(review): assumes every configured feature name is a key in
    // `context_features` -- find() is not checked against end(); confirm the
    // map is pre-populated from the same config by the caller.
    const FeatureProtos& feature =
        context_features.find(c.feature_name)->second;
    TensorShape dense_shape, example_shape;
    DataType dtype = c.dtype;
    const size_t data_max_elements = feature.length;
    // The configured shape must be fully defined and agree with the largest
    // element count observed across examples.
    if (!c.shape.AsTensorShape(&example_shape) ||
        data_max_elements != example_shape.num_elements()) {
      return errors::InvalidArgument(
          "Inconsistent max number of elements for feature " , c.feature_name,
          ": expected " , example_shape.num_elements(), ", but found " ,
          data_max_elements);
    }
    if (is_batch) {
      dense_shape.AddDim(num_examples);
    }
    for (const int dim : c.shape.dim_sizes()) {
      dense_shape.AddDim(dim);
    }
    context_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);

    Tensor& out = context_result->dense_values[t];
    size_t out_offset = 0;  // running write position into `out`.

    // Fill in the values.
    for (int e = 0; e < num_examples; e++) {
      size_t num_elements = 0;
      const auto& feature_proto = feature.protos[e];
      if (!feature.protos_present[e]) {
        // Copy the default value, if present. If not, return an error.
        if (c.default_value.NumElements() == 0) {
          return errors::InvalidArgument(
              "Feature: " , c.feature_name,
              " (data type: " , DataTypeString(c.dtype), ")" ,
              " is required but could not be found." );
        }
        CopyTensorIntoTensor(dtype, c.default_value, &out, &out_offset);
        num_elements += c.default_value.NumElements();
      } else if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        // NOTE(review): ParseFeature returns -1 on a dtype mismatch; adding
        // that to the unsigned num_elements wraps to a huge value, which the
        // check below then reports as a count mismatch rather than as a type
        // error -- safe, but the error message is misleading.
        num_elements += ParseFeature(dtype, &stream, &out, &out_offset);
      }
      if (num_elements != data_max_elements) {
        return errors::InvalidArgument(
            "Unexpected number of elements in example " ,
            ExampleName(example_names, e));
      }
    }
  }
  return OkStatus();
}
2353 | |
2354 | // Parses sparse features in `context_features`, and writes their parsed |
2355 | // values to `context_results`. |
2356 | Status ParseContextSparseFeatures(const FeatureProtosMap& context_features, |
2357 | const FastParseExampleConfig& context_config, |
2358 | gtl::ArraySlice<tstring> example_names, |
2359 | bool is_batch, int num_examples, |
2360 | Allocator* allocator, |
2361 | Result* context_result) { |
2362 | for (int t = 0; t < context_config.sparse.size(); ++t) { |
2363 | const auto& c = context_config.sparse[t]; |
2364 | const FeatureProtos& feature = |
2365 | context_features.find(c.feature_name)->second; |
2366 | TensorShape indices_shape, values_shape; |
2367 | DataType dtype = c.dtype; |
2368 | size_t expected_num_elements = feature.length; |
2369 | indices_shape.AddDim(expected_num_elements); |
2370 | indices_shape.AddDim(is_batch ? 2 : 1); |
2371 | values_shape.AddDim(expected_num_elements); |
2372 | context_result->sparse_indices[t] = |
2373 | Tensor(allocator, DT_INT64, indices_shape); |
2374 | context_result->sparse_values[t] = Tensor(allocator, dtype, values_shape); |
2375 | context_result->sparse_shapes[t] = |
2376 | Tensor(allocator, DT_INT64, TensorShape({is_batch ? 2 : 1})); |
2377 | Tensor& out_values = context_result->sparse_values[t]; |
2378 | size_t out_values_offset = 0; |
2379 | int64_t* out_indices = |
2380 | context_result->sparse_indices[t].flat<int64_t>().data(); |
2381 | auto out_shape = context_result->sparse_shapes[t].vec<int64_t>(); |
2382 | |
2383 | // Fill in the values. |
2384 | size_t num_elements = 0; |
2385 | size_t max_num_cols = 0; |
2386 | for (int e = 0; e < num_examples; e++) { |
2387 | const auto& feature_proto = feature.protos[e]; |
2388 | if (feature_proto.empty()) continue; |
2389 | protobuf::io::CodedInputStream stream( |
2390 | reinterpret_cast<const uint8*>(feature_proto.data()), |
2391 | feature_proto.size()); |
2392 | EnableAliasing(&stream); |
2393 | size_t num_added = |
2394 | ParseFeature(dtype, &stream, &out_values, &out_values_offset); |
2395 | num_elements += num_added; |
2396 | max_num_cols = std::max(max_num_cols, num_added); |
2397 | for (int i = 0; i < num_added; i++) { |
2398 | if (is_batch) *out_indices++ = e; |
2399 | *out_indices++ = i; |
2400 | } |
2401 | } |
2402 | if (num_elements != expected_num_elements) { |
2403 | return errors::InvalidArgument( |
2404 | "Unexpected total number of elements in feature " , c.feature_name); |
2405 | } |
2406 | if (is_batch) { |
2407 | out_shape(0) = num_examples; |
2408 | out_shape(1) = max_num_cols; |
2409 | } else { |
2410 | out_shape(0) = max_num_cols; |
2411 | } |
2412 | } |
2413 | return OkStatus(); |
2414 | } |
2415 | |
2416 | // Parses ragged features in `context_features`, and writes their parsed |
2417 | // values to `context_results`. |
2418 | Status (const FeatureProtosMap& context_features, |
2419 | const FastParseExampleConfig& context_config, |
2420 | gtl::ArraySlice<tstring> example_names, |
2421 | bool is_batch, int num_examples, |
2422 | Allocator* allocator, |
2423 | Result* context_result) { |
2424 | for (int t = 0; t < context_config.ragged.size(); ++t) { |
2425 | const auto& c = context_config.ragged[t]; |
2426 | const FeatureProtos& feature = |
2427 | context_features.find(c.feature_name)->second; |
2428 | TensorShape values_shape, splits_shape; |
2429 | DataType dtype = c.dtype; |
2430 | DataType splits_dtype = c.splits_dtype; |
2431 | size_t expected_num_elements = feature.length; |
2432 | values_shape.AddDim(expected_num_elements); |
2433 | if (is_batch) { |
2434 | splits_shape.AddDim(num_examples + 1); |
2435 | } |
2436 | context_result->ragged_values[t] = Tensor(allocator, dtype, values_shape); |
2437 | context_result->ragged_splits[t] = |
2438 | Tensor(allocator, splits_dtype, splits_shape); |
2439 | Tensor& out_values = context_result->ragged_values[t]; |
2440 | size_t out_values_offset = 0; |
2441 | int32* int32_splits = |
2442 | is_batch && splits_dtype == DT_INT32 |
2443 | ? context_result->ragged_splits[t].vec<int32>().data() |
2444 | : nullptr; |
2445 | int64_t* int64_splits = |
2446 | is_batch && splits_dtype == DT_INT64 |
2447 | ? context_result->ragged_splits[t].vec<int64_t>().data() |
2448 | : nullptr; |
2449 | if (int32_splits) { |
2450 | *int32_splits++ = 0; |
2451 | } else if (int64_splits) { |
2452 | *int64_splits++ = 0; |
2453 | } |
2454 | |
2455 | // Fill in the values. |
2456 | size_t split = 0; // = total number of elements we've seen so far |
2457 | for (int e = 0; e < num_examples; e++) { |
2458 | const auto& feature_proto = feature.protos[e]; |
2459 | if (!feature_proto.empty()) { |
2460 | protobuf::io::CodedInputStream stream( |
2461 | reinterpret_cast<const uint8*>(feature_proto.data()), |
2462 | feature_proto.size()); |
2463 | EnableAliasing(&stream); |
2464 | size_t num_added = |
2465 | ParseFeature(dtype, &stream, &out_values, &out_values_offset); |
2466 | split += num_added; |
2467 | } |
2468 | if (int32_splits) { |
2469 | *int32_splits++ = split; |
2470 | } else if (int64_splits) { |
2471 | *int64_splits++ = split; |
2472 | } |
2473 | } |
2474 | if (split != expected_num_elements) { |
2475 | return errors::InvalidArgument( |
2476 | "Unexpected total number of elements in feature " , c.feature_name); |
2477 | } |
2478 | if (int32_splits || int64_splits) { |
2479 | int actual_splits = |
2480 | int32_splits |
2481 | ? int32_splits - |
2482 | context_result->ragged_splits[t].vec<int32>().data() |
2483 | : int64_splits - |
2484 | context_result->ragged_splits[t].vec<int64_t>().data(); |
2485 | if (actual_splits != num_examples + 1) { |
2486 | return errors::InvalidArgument( |
2487 | "Unexpected number of examples for feature " , c.feature_name); |
2488 | } |
2489 | } |
2490 | } |
2491 | return OkStatus(); |
2492 | } |
2493 | |
// Parses dense features in `sequence_features`, and writes their parsed
// values to `sequence_result`.
//
// For each configured dense sequence feature, this allocates an output tensor
// of shape [num_examples?, max_rows, row_shape...] (the leading batch dim is
// present only when `is_batch`), re-parses each example's serialized
// FeatureList bytes to fill the values, records per-example row counts in
// `dense_feature_lengths`, and zero/default-pads examples with fewer rows
// than the maximum. `feature.length` (total element capacity) was computed by
// an earlier scanning pass (GetSequenceFeatureLengths), so allocation is
// exact and several parse failures below are expected to be unreachable.
Status ParseSequenceDenseFeatures(const FeatureProtosMap& sequence_features,
                                  const FastParseExampleConfig& sequence_config,
                                  gtl::ArraySlice<tstring> example_names,
                                  bool is_batch, int num_examples,
                                  Allocator* allocator, Result* sequence_result,
                                  std::vector<Tensor>* dense_feature_lengths) {
  TensorShape dense_length_shape;
  if (is_batch) {
    dense_length_shape.AddDim(num_examples);
  }
  for (int t = 0; t < sequence_config.dense.size(); ++t) {
    const auto& c = sequence_config.dense[t];
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape dense_shape, row_shape;
    DataType dtype = c.dtype;
    const size_t expected_max_elements = feature.length;
    // The configured per-row shape must be fully defined and must evenly
    // divide the total element count; otherwise the feature list cannot be
    // reshaped to [-1, row_shape...].
    if (!c.shape.AsTensorShape(&row_shape) ||
        expected_max_elements !=
            (expected_max_elements / row_shape.num_elements()) *
                row_shape.num_elements()) {
      PartialTensorShape total_shape = row_shape;
      total_shape.InsertDim(0, -1);
      return errors::InvalidArgument(
          "Feature list '" , c.feature_name,
          "' has an unexpected number of values. Total values size: " ,
          expected_max_elements,
          " is not consistent with output shape: " , total_shape.DebugString());
    }
    int64_t expected_max_rows =
        expected_max_elements / row_shape.num_elements();
    if (is_batch) {
      dense_shape.AddDim(num_examples);
    }
    dense_shape.AddDim(expected_max_rows);
    for (const int dim : sequence_config.dense[t].shape.dim_sizes()) {
      dense_shape.AddDim(dim);
    }
    sequence_result->dense_values[t] = Tensor(allocator, dtype, dense_shape);
    (*dense_feature_lengths)[t] =
        Tensor(allocator, DT_INT64, dense_length_shape);
    int64_t* out_lengths = (*dense_feature_lengths)[t].flat<int64_t>().data();

    // Raw write cursor into the output tensor; exactly one of these is
    // non-null, selected by dtype.
    tstring* out_bytes = nullptr;
    float* out_float = nullptr;
    int64_t* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = sequence_result->dense_values[t].flat<tstring>().data();
        break;
      case DT_FLOAT:
        out_float = sequence_result->dense_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = sequence_result->dense_values[t].flat<int64_t>().data();
        break;
      default:
        ReportUnexpectedDataType(dtype);
    }

    // Fill in the values.
    for (int e = 0; e < num_examples; e++) {
      size_t num_elements = 0, num_rows = 0;
      const auto& feature_proto = feature.protos[e];
      if (!feature.protos_present[e]) {
        // Return an error if this feature was not allowed to be missing.
        // Otherwise, we'll pad as needed below.
        if (!c.variable_length) {
          return errors::InvalidArgument(
              "Name: " , ExampleName(example_names, e), ", Feature list '" ,
              c.feature_name,
              "' is required but could not be found. "
              "Did you mean to include it in "
              "feature_list_dense_missing_assumed_empty or "
              "feature_list_dense_defaults?" );
        }
      } else if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        // Each row of the FeatureList is a length-delimited Feature message
        // in field 1; iterate until the stream is exhausted.
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            return errors::InvalidArgument("Error in sequence feature " ,
                                           c.feature_name, " in example " ,
                                           ExampleName(example_names, e));
          }
          auto limit = stream.PushLimit(feature_length);
          int num_added = 0;
          // Payloads of <= 2 bytes encode an empty Feature (see the sparse
          // parser's SkipEmptyFeature handling); only longer payloads carry
          // values.
          if (feature_length > 2) {
            switch (dtype) {
              case DT_STRING:
                num_added = ParseBytesFeature(&stream, out_bytes);
                out_bytes += num_added;
                break;
              case DT_FLOAT:
                num_added = ParseFloatFeature(&stream, out_float);
                out_float += num_added;
                break;
              case DT_INT64:
                num_added = ParseInt64Feature(&stream, out_int64);
                out_int64 += num_added;
                break;
              default:
                ReportUnexpectedDataType(dtype);
                num_added = 0;
            }
            if (num_added < 0) {
              // This should be unreachable -- we already scanned the feature in
              // GetSequenceFeatureLengths, and it hasn't changed since then.
              return errors::InvalidArgument("Error in sequence feature " ,
                                             c.feature_name, " in example " ,
                                             ExampleName(example_names, e));
            }
          }
          // Every row must contain exactly row_shape.num_elements() values.
          if (num_added != row_shape.num_elements()) {
            return errors::InvalidArgument(
                "Name: " , ExampleName(example_names, e),
                ", Key: " , c.feature_name, ", Index: " , num_rows,
                ". Number of values != expected. values size: " , num_added,
                " but output shape: " , row_shape.DebugString());
          }
          num_elements += num_added;
          num_rows++;
          stream.PopLimit(limit);
        }
      }
      *out_lengths++ = num_rows;
      // Pad as necessary.
      int num_to_pad = expected_max_elements - num_elements;
      switch (dtype) {
        case DT_STRING:
          // tstring elements are already default-constructed (empty), so
          // advancing the cursor suffices.
          out_bytes += num_to_pad;
          break;
        case DT_FLOAT:
          PadFloatFeature(num_to_pad, out_float);
          out_float += num_to_pad;
          break;
        case DT_INT64:
          PadInt64Feature(num_to_pad, out_int64);
          out_int64 += num_to_pad;
          break;
        default:
          ReportUnexpectedDataType(dtype);
      }
    }
  }
  return OkStatus();
}
2647 | |
// Parses sparse features in `sequence_features`, and writes their parsed
// values to `sequence_result`.
//
// Output is in COO form: for each value this emits an index tuple of
// ([example], row, column) -- 3 components in batch mode, 2 otherwise --
// plus a dense-shape vector of ([num_examples], max_rows, max_cols).
// `feature.length` (the total value count) was computed by an earlier
// scanning pass, so the values/indices tensors are allocated exactly and
// several parse failures below are expected to be unreachable.
Status ParseSequenceSparseFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.sparse.size(); ++t) {
    const auto& c = sequence_config.sparse[t];
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape indices_shape, values_shape;
    DataType dtype = c.dtype;
    size_t expected_num_elements = feature.length;
    indices_shape.AddDim(expected_num_elements);
    // Each index tuple has 3 components (example, row, col) in batch mode,
    // 2 (row, col) otherwise.
    indices_shape.AddDim(is_batch ? 3 : 2);
    values_shape.AddDim(expected_num_elements);
    sequence_result->sparse_indices[t] =
        Tensor(allocator, DT_INT64, indices_shape);
    sequence_result->sparse_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->sparse_shapes[t] =
        Tensor(allocator, DT_INT64, TensorShape({is_batch ? 3 : 2}));

    // Raw write cursor into the values tensor; exactly one of these is
    // non-null, selected by dtype.
    tstring* out_bytes = nullptr;
    float* out_float = nullptr;
    int64_t* out_int64 = nullptr;
    switch (dtype) {
      case DT_STRING:
        out_bytes = sequence_result->sparse_values[t].flat<tstring>().data();
        break;
      case DT_FLOAT:
        out_float = sequence_result->sparse_values[t].flat<float>().data();
        break;
      case DT_INT64:
        out_int64 = sequence_result->sparse_values[t].flat<int64_t>().data();
        break;
      default:
        ReportUnexpectedDataType(dtype);
    }
    int64_t* out_indices =
        sequence_result->sparse_indices[t].flat<int64_t>().data();
    auto out_shape = sequence_result->sparse_shapes[t].vec<int64_t>();

    // Fill in the values.
    size_t num_elements = 0;
    size_t max_num_rows = 0;  // across all examples
    size_t max_num_cols = 0;  // widest single row seen anywhere
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (feature_proto.empty()) continue;
      protobuf::io::CodedInputStream stream(
          reinterpret_cast<const uint8*>(feature_proto.data()),
          feature_proto.size());
      EnableAliasing(&stream);
      size_t num_rows = 0;
      // Each row of the FeatureList is a length-delimited Feature message in
      // field 1; iterate until the stream is exhausted.
      while (!stream.ExpectAtEnd()) {
        uint32 feature_length;
        if (!stream.ExpectTag(kDelimitedTag(1)) ||
            !stream.ReadVarint32(&feature_length)) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature " ,
                                         c.feature_name, " in example " ,
                                         ExampleName(example_names, e));
        }
        // Payload length > 2: a Feature carrying actual values.
        if (feature_length > 2) {
          auto limit = stream.PushLimit(feature_length);
          size_t num_added;
          switch (dtype) {
            case DT_STRING:
              num_added = ParseBytesFeature(&stream, out_bytes);
              out_bytes += num_added;
              break;
            case DT_FLOAT:
              num_added = ParseFloatFeature(&stream, out_float);
              out_float += num_added;
              break;
            case DT_INT64:
              num_added = ParseInt64Feature(&stream, out_int64);
              out_int64 += num_added;
              break;
            default:
              ReportUnexpectedDataType(dtype);
              num_added = 0;
          }
          num_elements += num_added;
          max_num_cols = std::max(max_num_cols, num_added);
          // Emit one ([example], row, col) index tuple per parsed value.
          for (int i = 0; i < num_added; i++) {
            if (is_batch) *out_indices++ = e;
            *out_indices++ = num_rows;
            *out_indices++ = i;
          }
          stream.PopLimit(limit);
        } else if (feature_length == 2) {
          // A 2-byte payload encodes an empty Feature (kind tag + zero
          // length); it still counts as a row but contributes no values.
          if (!SkipEmptyFeature(&stream, dtype)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature " ,
                                           c.feature_name, " in example " ,
                                           ExampleName(example_names, e));
          }
        } else if (feature_length != 0) {
          // This should be unreachable -- we already scanned the feature in
          // GetSequenceFeatureLengths, and it hasn't changed since then.
          return errors::InvalidArgument("Error in sequence feature " ,
                                         c.feature_name, " in example " ,
                                         ExampleName(example_names, e));
        }
        num_rows++;
      }
      max_num_rows = std::max(max_num_rows, num_rows);
    }
    if (num_elements != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements in feature " , c.feature_name);
    }
    // Dense shape: ([num_examples], max_rows, max_cols).
    if (is_batch) {
      out_shape(0) = num_examples;
      out_shape(1) = max_num_rows;
      out_shape(2) = max_num_cols;
    } else {
      out_shape(0) = max_num_rows;
      out_shape(1) = max_num_cols;
    }
  }
  return OkStatus();
}
2775 | |
// Parses ragged features in `sequence_features`, and writes their parsed
// values to `sequence_result`.
//
// Produces a two-level ragged encoding: a flat values tensor, an inner
// row-splits tensor with one entry per row (Feature) plus one, and -- in
// batch mode -- an outer splits tensor with one entry per example plus one,
// partitioning rows into examples. `feature.length` and `feature.num_rows`
// were computed by an earlier scanning pass, so allocation is exact and
// several parse failures below are expected to be unreachable.
Status ParseSequenceRaggedFeatures(
    const FeatureProtosMap& sequence_features,
    const FastParseExampleConfig& sequence_config,
    gtl::ArraySlice<tstring> example_names, bool is_batch, int num_examples,
    Allocator* allocator, Result* sequence_result) {
  for (int t = 0; t < sequence_config.ragged.size(); ++t) {
    const auto& c = sequence_config.ragged[t];
    const FeatureProtos& feature =
        sequence_features.find(c.feature_name)->second;
    TensorShape values_shape, inner_splits_shape, outer_splits_shape;
    DataType dtype = c.dtype;
    DataType splits_dtype = c.splits_dtype;
    size_t expected_num_elements = feature.length;
    size_t expected_num_rows = feature.num_rows;
    values_shape.AddDim(expected_num_elements);
    inner_splits_shape.AddDim(expected_num_rows + 1);
    if (is_batch) {
      outer_splits_shape.AddDim(num_examples + 1);
    }
    sequence_result->ragged_values[t] = Tensor(allocator, dtype, values_shape);
    sequence_result->ragged_splits[t] =
        Tensor(allocator, splits_dtype, inner_splits_shape);
    sequence_result->ragged_outer_splits[t] =
        Tensor(allocator, splits_dtype, outer_splits_shape);
    Tensor& out_values = sequence_result->ragged_values[t];
    size_t out_values_offset = 0;
    // For each splits tensor, exactly one of the int32/int64 cursors is
    // non-null, selected by the configured splits dtype. Outer splits are
    // only produced in batch mode.
    int32* int32_inner_splits =
        splits_dtype == DT_INT32
            ? sequence_result->ragged_splits[t].vec<int32>().data()
            : nullptr;
    int64_t* int64_inner_splits =
        splits_dtype == DT_INT64
            ? sequence_result->ragged_splits[t].vec<int64_t>().data()
            : nullptr;
    int32* int32_outer_splits =
        is_batch && splits_dtype == DT_INT32
            ? sequence_result->ragged_outer_splits[t].vec<int32>().data()
            : nullptr;
    int64_t* int64_outer_splits =
        is_batch && splits_dtype == DT_INT64
            ? sequence_result->ragged_outer_splits[t].vec<int64_t>().data()
            : nullptr;
    // Both split vectors always begin with 0.
    if (int32_inner_splits) {
      *int32_inner_splits++ = 0;
    } else if (int64_inner_splits) {
      *int64_inner_splits++ = 0;
    }
    if (int32_outer_splits) {
      *int32_outer_splits++ = 0;
    } else if (int64_outer_splits) {
      *int64_outer_splits++ = 0;
    }

    // Fill in the values.
    size_t inner_split = 0;  // total number of elements we've seen so far
    size_t outer_split = 0;  // total number of rows we've seen so far
    for (int e = 0; e < num_examples; e++) {
      const auto& feature_proto = feature.protos[e];
      if (!feature_proto.empty()) {
        protobuf::io::CodedInputStream stream(
            reinterpret_cast<const uint8*>(feature_proto.data()),
            feature_proto.size());
        EnableAliasing(&stream);
        // Each row of the FeatureList is a length-delimited Feature message
        // in field 1; iterate until the stream is exhausted.
        while (!stream.ExpectAtEnd()) {
          uint32 feature_length;
          if (!stream.ExpectTag(kDelimitedTag(1)) ||
              !stream.ReadVarint32(&feature_length)) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature " ,
                                           c.feature_name, " in example " ,
                                           ExampleName(example_names, e));
          }
          if (feature_length > 2) {
            // Non-empty Feature: parse its values into the flat tensor.
            auto limit = stream.PushLimit(feature_length);
            size_t num_added =
                ParseFeature(dtype, &stream, &out_values, &out_values_offset);
            inner_split += num_added;
            stream.PopLimit(limit);
          } else if (feature_length == 2) {
            // A 2-byte payload encodes an empty Feature; it still counts as
            // a row but contributes no values.
            if (!SkipEmptyFeature(&stream, dtype)) {
              // This should be unreachable -- we already scanned the feature in
              // GetSequenceFeatureLengths, and it hasn't changed since then.
              return errors::InvalidArgument("Error in sequence feature " ,
                                             c.feature_name, " in example " ,
                                             ExampleName(example_names, e));
            }
          } else if (feature_length != 0) {
            // This should be unreachable -- we already scanned the feature in
            // GetSequenceFeatureLengths, and it hasn't changed since then.
            return errors::InvalidArgument("Error in sequence feature " ,
                                           c.feature_name, " in example " ,
                                           ExampleName(example_names, e));
          }
          // One inner split per row: the running element total.
          if (int32_inner_splits) {
            *int32_inner_splits++ = inner_split;
          } else if (int64_inner_splits) {
            *int64_inner_splits++ = inner_split;
          }
          outer_split++;
        }
      }
      // One outer split per example: the running row total.
      if (int32_outer_splits) {
        *int32_outer_splits++ = outer_split;
      } else if (int64_outer_splits) {
        *int64_outer_splits++ = outer_split;
      }
    }
    if (outer_split != expected_num_rows) {
      return errors::InvalidArgument("Unexpected number of rows for feature " ,
                                     c.feature_name);
    }
    if (inner_split != expected_num_elements) {
      return errors::InvalidArgument(
          "Unexpected number of elements for feature " , c.feature_name);
    }

    // Verify the split cursors advanced exactly as many times as the
    // allocated split tensors have entries.
    if (int32_inner_splits || int64_inner_splits) {
      const auto& inner_splits = sequence_result->ragged_splits[t];
      int num_inner_splits =
          int32_inner_splits
              ? int32_inner_splits - inner_splits.vec<int32>().data()
              : int64_inner_splits - inner_splits.vec<int64_t>().data();
      if (num_inner_splits != expected_num_rows + 1) {
        return errors::InvalidArgument("Unexpected number of rows for feature " ,
                                       c.feature_name);
      }
    }
    if (int32_outer_splits || int64_outer_splits) {
      const auto& outer_splits = sequence_result->ragged_outer_splits[t];
      int num_outer_splits =
          int32_outer_splits
              ? int32_outer_splits - outer_splits.vec<int32>().data()
              : int64_outer_splits - outer_splits.vec<int64_t>().data();
      if (num_outer_splits != num_examples + 1) {
        return errors::InvalidArgument(
            "Unexpected number of examples for feature " , c.feature_name);
      }
    }
  }
  return OkStatus();
}
2920 | |
2921 | } // namespace |
2922 | |
// TODO(sundberg): Use the threadpool to parallelize example parsing.
// TODO(b/111553342): Support extracting feature statistics from the examples.
//
// Parses a batch of serialized SequenceExample protos in three phases:
//   1. Validate the context/sequence configs and build FeatureProtosMap
//      entries describing each requested feature.
//   2. Extract each feature's serialized bytes from every example and scan
//      them to compute exact output sizes (lengths/row counts).
//   3. Allocate the output tensors and run the per-kind parse helpers
//      (dense/sparse/ragged, for context and sequence features).
// `is_batch` controls whether outputs carry a leading batch dimension.
// Note: `thread_pool` is currently unused (see the TODO above).
Status FastParseSequenceExample(const FastParseExampleConfig& context_config,
                                const FastParseExampleConfig& sequence_config,
                                gtl::ArraySlice<tstring> serialized,
                                gtl::ArraySlice<tstring> example_names,
                                thread::ThreadPool* thread_pool,
                                Result* context_result, Result* sequence_result,
                                std::vector<Tensor>* dense_feature_lengths,
                                bool is_batch) {
  int num_examples = serialized.size();
  DCHECK(context_result != nullptr);
  DCHECK(sequence_result != nullptr);
  DCHECK(dense_feature_lengths != nullptr);
  size_t num_context_features = context_config.sparse.size() +
                                context_config.dense.size() +
                                context_config.ragged.size();
  FeatureProtosMap context_features;
  context_features.reserve(num_context_features);

  if (!example_names.empty() && example_names.size() != num_examples) {
    return errors::InvalidArgument(
        "example_names must be empty or have the correct number of elements" );
  }
  // Register the requested context features. Order matters for conflict
  // detection: sparse first, then ragged (which rejects a name already
  // registered as sparse), then dense (which rejects any non-dense entry --
  // operator[] default-constructs new entries as Type::Dense, so a dense
  // feature that collides with a sparse/ragged one is caught here).
  for (auto& c : context_config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = context_features[c.feature_name];
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.type = Type::Sparse;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  for (auto& c : context_config.ragged) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = context_features[c.feature_name];
    if (feature.type == Type::Sparse) {
      return errors::InvalidArgument("Context feature " + c.feature_name +
                                     " cannot be both ragged and sparse" );
    }
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.type = Type::Ragged;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  for (auto& c : context_config.dense) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = context_features[c.feature_name];
    if (feature.type != Type::Dense) {
      return errors::InvalidArgument("Context feature " + c.feature_name +
                                     " cannot be both dense and sparse" );
    }
    // A dense context feature's default value (if any) must match the
    // configured shape.
    if (c.default_value.NumElements() > 0) {
      if (!c.shape.IsCompatibleWith(c.default_value.shape())) {
        return errors::InvalidArgument("Default value for context feature " ,
                                       c.feature_name,
                                       " has an incorrect shape: saw " ,
                                       c.default_value.shape().DebugString(),
                                       " but expected " , c.shape.DebugString());
      }
    }
    feature.dtype = c.dtype;
    // Seed the length with the default's element count so missing features
    // still reserve room for the default value.
    feature.length = c.default_value.NumElements();
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  // Register the requested sequence features, with the same conflict rules
  // as the context features above.
  size_t num_sequence_features = sequence_config.sparse.size() +
                                 sequence_config.dense.size() +
                                 sequence_config.ragged.size();
  FeatureProtosMap sequence_features;
  sequence_features.reserve(num_sequence_features);
  for (auto& c : sequence_config.sparse) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = sequence_features[c.feature_name];
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.type = Type::Sparse;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  for (auto& c : sequence_config.ragged) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = sequence_features[c.feature_name];
    if (feature.type == Type::Sparse) {
      return errors::InvalidArgument("Sequence feature " + c.feature_name +
                                     " cannot be both ragged and sparse" );
    }
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.type = Type::Ragged;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }
  for (auto& c : sequence_config.dense) {
    TF_RETURN_IF_ERROR(CheckConfigDataType(c.dtype));
    FeatureProtos& feature = sequence_features[c.feature_name];
    if (feature.type != Type::Dense) {
      return errors::InvalidArgument("Sequence feature " + c.feature_name +
                                     " cannot be both dense and sparse" );
    }
    feature.dtype = c.dtype;
    feature.length = 0;
    feature.protos.resize(num_examples);
    feature.protos_present.resize(num_examples);
  }

  // Find the serialized proto substrings for each feature.
  TF_RETURN_IF_ERROR(ExtractFeaturesFromSequenceExamples(
      serialized, example_names, &context_features, &sequence_features));

  // Scan through the protos to determine how much memory we need to allocate.
  TF_RETURN_IF_ERROR(
      GetContextFeatureLengths(example_names, &context_features));
  TF_RETURN_IF_ERROR(
      GetSequenceFeatureLengths(example_names, &sequence_features));

  // Allocate memory.
  context_result->sparse_values.resize(context_config.sparse.size());
  context_result->sparse_indices.resize(context_config.sparse.size());
  context_result->sparse_shapes.resize(context_config.sparse.size());
  context_result->dense_values.resize(context_config.dense.size());
  context_result->ragged_values.resize(context_config.ragged.size());
  context_result->ragged_splits.resize(context_config.ragged.size());
  context_result->ragged_outer_splits.resize(context_config.ragged.size());
  sequence_result->sparse_values.resize(sequence_config.sparse.size());
  sequence_result->sparse_indices.resize(sequence_config.sparse.size());
  sequence_result->sparse_shapes.resize(sequence_config.sparse.size());
  sequence_result->dense_values.resize(sequence_config.dense.size());
  sequence_result->ragged_values.resize(sequence_config.ragged.size());
  sequence_result->ragged_splits.resize(sequence_config.ragged.size());
  sequence_result->ragged_outer_splits.resize(sequence_config.ragged.size());
  dense_feature_lengths->resize(sequence_config.dense.size());

  // NOTE(mrry): Cache the CPU allocator here and use it in Tensor construction,
  // to avoid lock contention in `tensorflow::cpu_allocator()`.
  Allocator* allocator = tensorflow::cpu_allocator();

  // Parse each feature kind into the preallocated result tensors.
  TF_RETURN_IF_ERROR(ParseContextDenseFeatures(
      context_features, context_config, example_names, is_batch, num_examples,
      allocator, context_result));
  TF_RETURN_IF_ERROR(ParseContextSparseFeatures(
      context_features, context_config, example_names, is_batch, num_examples,
      allocator, context_result));
  TF_RETURN_IF_ERROR(ParseContextRaggedFeatures(
      context_features, context_config, example_names, is_batch, num_examples,
      allocator, context_result));
  TF_RETURN_IF_ERROR(ParseSequenceDenseFeatures(
      sequence_features, sequence_config, example_names, is_batch, num_examples,
      allocator, sequence_result, dense_feature_lengths));
  TF_RETURN_IF_ERROR(ParseSequenceSparseFeatures(
      sequence_features, sequence_config, example_names, is_batch, num_examples,
      allocator, sequence_result));
  TF_RETURN_IF_ERROR(ParseSequenceRaggedFeatures(
      sequence_features, sequence_config, example_names, is_batch, num_examples,
      allocator, sequence_result));

  return OkStatus();
}
3082 | |
3083 | } // namespace example |
3084 | } // namespace tensorflow |
3085 | |