1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // DecodeProto is a TensorFlow op which extracts arbitrary fields from protos |
17 | // serialized as strings. |
18 | // |
19 | // See docs in ../ops/decode_proto_op.cc. |
20 | // |
21 | // This implementation reads the serialized format using a handful of calls from |
22 | // the WireFormatLite API used by generated proto code. WireFormatLite is marked |
23 | // as an "internal" proto API but is widely used in practice and highly unlikely |
24 | // to change. This will be much faster than the previous implementation based on |
25 | // constructing a temporary dynamic message in memory and using the proto |
// reflection API to read it. It can be used with any proto whose descriptors
27 | // are available at runtime but should be competitive in speed with approaches |
28 | // that compile in the proto definitions. |
29 | |
30 | #include <memory> |
31 | #include <string> |
32 | #include <vector> |
33 | |
34 | #include "absl/container/flat_hash_map.h" |
35 | #include "absl/types/span.h" |
36 | #include "third_party/eigen3/Eigen/Core" |
37 | #include "tensorflow/core/framework/op_kernel.h" |
38 | #include "tensorflow/core/framework/tensor_types.h" |
39 | #include "tensorflow/core/framework/types.h" |
40 | #include "tensorflow/core/lib/core/errors.h" |
41 | #include "tensorflow/core/platform/logging.h" |
42 | #include "tensorflow/core/platform/protobuf.h" |
43 | #include "tensorflow/core/util/proto/decode.h" |
44 | #include "tensorflow/core/util/proto/descriptors.h" |
45 | #include "tensorflow/core/util/proto/proto_utils.h" |
46 | #include "tensorflow/core/util/ptr_util.h" |
47 | |
48 | namespace tensorflow { |
49 | namespace { |
50 | |
51 | using ::tensorflow::MakeUnique; |
52 | using ::tensorflow::protobuf::Descriptor; |
53 | using ::tensorflow::protobuf::DescriptorPool; |
54 | using ::tensorflow::protobuf::DynamicMessageFactory; |
55 | using ::tensorflow::protobuf::FieldDescriptor; |
56 | using ::tensorflow::protobuf::Message; |
57 | using ::tensorflow::protobuf::TextFormat; |
58 | using ::tensorflow::protobuf::internal::WireFormatLite; |
59 | using ::tensorflow::protobuf::io::CodedInputStream; |
60 | |
61 | const bool kFailOnDecodeError = true; |
62 | |
// Used to store the default value of a protocol message field, cast to the
// type of the output tensor.
65 | // |
66 | // TODO(paskin): Use absl::variant once TensorFlow gets absl dependencies. |
67 | struct DefaultValue { |
68 | DataType dtype = DataType::DT_INVALID; |
69 | union Value { |
70 | bool v_bool; // DT_BOOL |
71 | double v_double; // DT_DOUBLE |
72 | float v_float; // DT_FLOAT |
73 | int8 v_int8; // DT_INT8 |
74 | int32 v_int32; // DT_INT32 |
75 | int64_t v_int64; // DT_INT64 |
76 | const char* v_string; // DT_STRING |
77 | uint8 v_uint8; // DT_UINT8 |
78 | uint8 v_uint32; // DT_UINT32 |
79 | uint8 v_uint64; // DT_UINT64 |
80 | }; |
81 | Value value; |
82 | }; |
83 | |
84 | // Initializes a DefaultValue object. This generic template handles numeric |
85 | // types and strings are handled by a template specialization below. |
86 | // |
87 | // Args: |
88 | // dtype: the type of the output tensor |
89 | // value: the default value as obtained from the FieldDescriptor |
90 | // result: the object to initialize |
91 | template <typename T> |
92 | Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) { |
93 | result->dtype = dtype; |
94 | switch (dtype) { |
95 | case DT_BOOL: |
96 | result->value.v_bool = static_cast<bool>(value); |
97 | break; |
98 | case DT_DOUBLE: |
99 | result->value.v_double = static_cast<double>(value); |
100 | break; |
101 | case DT_FLOAT: |
102 | result->value.v_float = static_cast<float>(value); |
103 | break; |
104 | case DT_INT8: |
105 | result->value.v_int8 = static_cast<int8>(value); |
106 | break; |
107 | case DT_INT32: |
108 | result->value.v_int32 = static_cast<int32>(value); |
109 | break; |
110 | case DT_INT64: |
111 | result->value.v_int64 = static_cast<int64_t>(value); |
112 | break; |
113 | case DT_UINT8: |
114 | result->value.v_uint8 = static_cast<uint8>(value); |
115 | break; |
116 | case DT_UINT32: |
117 | result->value.v_uint32 = static_cast<uint32>(value); |
118 | break; |
119 | case DT_UINT64: |
120 | result->value.v_uint64 = static_cast<uint64>(value); |
121 | break; |
122 | default: |
123 | // We should never get here, given the type checking that occurs earlier. |
124 | return errors::Internal( |
125 | "Cannot initialize default value for unsupported type: " , |
126 | DataTypeString(dtype)); |
127 | } |
128 | return OkStatus(); |
129 | } |
130 | |
131 | template <> |
132 | Status InitDefaultValue(DataType dtype, const char* value, |
133 | DefaultValue* result) { |
134 | // These are sanity checks that should never trigger given the code that |
135 | // leads here. |
136 | if (TF_PREDICT_FALSE(dtype != DT_STRING)) { |
137 | return errors::InvalidArgument( |
138 | "Cannot cast field to anything but DT_STRING" ); |
139 | } |
140 | if (TF_PREDICT_FALSE(value == nullptr)) { |
141 | return errors::InvalidArgument("Null default string value." ); |
142 | } |
143 | result->dtype = DT_STRING; |
144 | result->value.v_string = value; |
145 | return OkStatus(); |
146 | } |
147 | |
// Initializes a default value from the output data type and the field
// descriptor.
//
// Note: the switch condition is a FieldDescriptor type while the case labels
// are WireFormatLite constants. The two unscoped enums use identical numeric
// values (see the static_cast comment in FieldInfo's constructor), so the
// labels match as intended.
Status InitDefaultValueFromFieldDescriptor(DataType dtype,
                                           const FieldDescriptor* field_desc,
                                           DefaultValue* result) {
  switch (field_desc->type()) {
    case WireFormatLite::TYPE_DOUBLE:
      return InitDefaultValue(dtype, field_desc->default_value_double(),
                              result);
    case WireFormatLite::TYPE_FLOAT:
      return InitDefaultValue(dtype, field_desc->default_value_float(), result);
    case WireFormatLite::TYPE_INT64:
    case WireFormatLite::TYPE_SINT64:
    case WireFormatLite::TYPE_SFIXED64:
      return InitDefaultValue(dtype, field_desc->default_value_int64(), result);
    case WireFormatLite::TYPE_FIXED64:
    case WireFormatLite::TYPE_UINT64:
      return InitDefaultValue(dtype, field_desc->default_value_uint64(),
                              result);
    case WireFormatLite::TYPE_INT32:
    case WireFormatLite::TYPE_SINT32:
    case WireFormatLite::TYPE_SFIXED32:
      return InitDefaultValue(dtype, field_desc->default_value_int32(), result);
    case WireFormatLite::TYPE_FIXED32:
    case WireFormatLite::TYPE_UINT32:
      return InitDefaultValue(dtype, field_desc->default_value_uint32(),
                              result);
    case WireFormatLite::TYPE_BOOL:
      return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
    case WireFormatLite::TYPE_ENUM:
      // Enums surface as their numeric value in the output tensor.
      return InitDefaultValue(dtype, field_desc->default_value_enum()->number(),
                              result);
    case WireFormatLite::TYPE_BYTES:
    case WireFormatLite::TYPE_STRING:
      // Manipulating default string values as C-style pointers should be OK
      // for typical code-generated protocol messages. It is possible in
      // principle to register a message descriptor on the fly, and these
      // pointers may not be stable if that descriptor has a weird
      // implementation. (But the return type of default_value_string() is
      // const string&, so it'd have to be very weird.)
      return InitDefaultValue(dtype, field_desc->default_value_string().c_str(),
                              result);
    case WireFormatLite::TYPE_GROUP:
    case WireFormatLite::TYPE_MESSAGE:
      // Submessages have no scalar default; an empty string stands in.
      return InitDefaultValue(dtype, "", result);
    // default: intentionally omitted in order to enable static checking.
  }
  // Reached only for enum values outside the declared range; also placates
  // compilers that cannot prove the switch exhaustive.
  return OkStatus();
}
197 | |
198 | // A FieldInfo holds a handful of information from the FieldDescriptor |
199 | // and user attributes. |
200 | struct FieldInfo { |
201 | FieldInfo(const FieldDescriptor* field_desc, int user_index, |
202 | DefaultValue def_value) |
203 | : output_index(user_index), default_value(def_value) { |
204 | // Without this intermediate data structure, the profile had hotspots |
205 | // calling methods of FieldDescriptor. |
206 | number = field_desc->number(); |
207 | |
208 | // The wire format library defines the same constants used in |
209 | // descriptor.proto. This static_cast is safe because they are guaranteed to |
210 | // stay in sync. We need the field type from the FieldDescriptor here |
211 | // because the wire format doesn't tell us anything about what happens |
212 | // inside a packed repeated field: there is enough information in the wire |
213 | // format to skip the whole field but not enough to know how to parse what's |
214 | // inside. For that we go to the schema. |
215 | type = static_cast<WireFormatLite::FieldType>(field_desc->type()); |
216 | is_repeated = field_desc->is_repeated(); |
217 | } |
218 | |
219 | // Disable copy and move. |
220 | FieldInfo(const FieldInfo&) = delete; |
221 | FieldInfo& operator=(const FieldInfo&) = delete; |
222 | |
223 | // Internally we sort field descriptors by wire number for fast lookup. In |
224 | // general this is different from the order given by the user. Output_index |
225 | // gives the index into the field_names and output_types attributes and into |
226 | // the output tensor list. |
227 | int output_index = -1; |
228 | |
229 | // This is a cache of the relevant fields from `FieldDescriptorProto`. This |
230 | // was added after noticing that FieldDescriptor->type() was using 6% of the |
231 | // cpu profile. |
232 | WireFormatLite::FieldType type; |
233 | int number; |
234 | bool is_repeated; |
235 | DefaultValue default_value; |
236 | }; |
237 | |
238 | // A CountCollector counts sizes of repeated and optional fields in a proto. |
239 | // |
240 | // Each field is tracked by a single CountCollector instance. The instance |
241 | // manages a single count, which is stored as a pointer (it is intended to be a |
242 | // reference to the `sizes` output which is being filled in). The pointer is |
243 | // passed in at initialization. |
244 | // |
245 | // Counting is done as a separate pass in order to allocate output tensors all |
246 | // at once. This allows the TensorFlow runtime to optimize allocation for the |
247 | // consumer, while removing the need for copying inside this op. After this |
248 | // pass, the DenseCollector class (below) gathers the data: it is more complex |
249 | // and provides better motivation for the API here. |
class CountCollector {
 public:
  CountCollector() = delete;

  // The count may be stored inside an Eigen Tensor to eliminate copying.
  explicit CountCollector(int32* count) : count_ptr_(count) {}

  // Reads (in this case counts) a single value.
  Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
    // Only repeated fields can have count > 1. A non-repeated field that
    // appears more than once on the wire keeps its count at 1 (the last
    // occurrence wins when the data is actually collected).
    if (*count_ptr_ == 0 || field.is_repeated) {
      (*count_ptr_)++;
    }
    // We expect a wire type based on the schema field_type, to allow a little
    // more checking.
    if (!SkipValue(input, field)) {
      return errors::DataLoss("ReadValue: Failed skipping field when counting");
    }
    return OkStatus();
  }

  // Reads (in this case counts) a length-delimited list of values.
  // `buf_size` is the byte length of the packed payload, already read from
  // the wire by the caller.
  Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
                          size_t buf_size) {
    if (buf_size == 0) {
      return OkStatus();
    }

    const void* tmpbuf;
    int unused_max_buf_size;

    input->GetDirectBufferPointerInline(&tmpbuf, &unused_max_buf_size);
    // This is safe because the underlying storage for the CodedInputStream is
    // owned by the input tensor. If it were a Cord or file-backed stream this
    // pointer would go stale after the bytes were skipped.
    const uint8* buf = reinterpret_cast<const uint8*>(tmpbuf);

    // Important: we skipped the input->{Push,Pop}Limit() calls for speed,
    // so the bounds check on buf_size inside Skip() is critical, and
    // must be done before scanning the contents.
    if (!input->Skip(buf_size)) {
      return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
    }

    // Dispatch to the appropriately typed field reader based on the schema
    // type. Varint-encoded types must be parsed byte-by-byte to count
    // elements; fixed-width types are counted arithmetically; the
    // length-delimited types can never legally appear packed.
    Status st;
    switch (field.type) {
      case WireFormatLite::TYPE_DOUBLE:
        st = CountPackedFixed<double>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FLOAT:
        st = CountPackedFixed<float>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_INT64:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_UINT64:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_INT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FIXED64:
        st = CountPackedFixed<uint64>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FIXED32:
        st = CountPackedFixed<uint32>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_BOOL:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_STRING:
        st = errors::DataLoss("TYPE_STRING encountered as packed");
        break;
      case WireFormatLite::TYPE_GROUP:
        st = errors::DataLoss("TYPE_GROUP encountered as packed");
        break;
      case WireFormatLite::TYPE_MESSAGE:
        st = errors::DataLoss("TYPE_MESSAGE encountered as packed");
        break;
      case WireFormatLite::TYPE_BYTES:
        st = errors::DataLoss("TYPE_BYTES encountered as packed");
        break;
      case WireFormatLite::TYPE_UINT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_ENUM:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SFIXED32:
        st = CountPackedFixed<int32>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SFIXED64:
        st = CountPackedFixed<int64_t>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SINT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SINT64:
        st = CountPackedVarint(buf, buf_size);
        break;
        // default: intentionally omitted in order to enable static checking.
    }
    if (!st.ok()) {
      return st;
    }

    // A packed occurrence of a non-repeated field still only contributes one
    // output slot: clamp the count back to 1.
    if (!field.is_repeated && *count_ptr_ > 1) {
      *count_ptr_ = 1;
    }
    return OkStatus();
  }

 private:
  // Skips a length-delimited value.
  static bool SkipBytes(CodedInputStream* input) {
    uint32 length;
    if (!input->ReadVarint32(&length)) {
      return false;
    }
    return input->Skip(length);
  }

  // Counts the number of packed varints in an array. The end of a varint is
  // signaled by a value < 0x80, so counting them requires parsing the
  // bytestream. It is the caller's responsibility to ensure that len > 0.
  Status CountPackedVarint(const uint8* buf, size_t len) {
    const uint8* bound = buf + len;
    int count;

    // The last byte in a valid encoded varint is guaranteed to have the high
    // bit unset. We rely on this property to prevent ReadVarint64FromArray
    // from going out of bounds, so validate the end of the buf before
    // scanning anything.
    if (bound[-1] & 0x80) {
      return errors::DataLoss("Corrupt packed varint");
    }

    // Now we can trust ReadVarint64FromArray to stay in bounds.
    for (count = 0; buf < bound; ++count) {
      uint64 temp;
      bool ok;
      buf = internal::ReadVarint64FromArray(buf, &ok, &temp);
      if (!ok) {
        return errors::DataLoss("Corrupt packed varint");
      }
    }

    *count_ptr_ += count;
    return OkStatus();
  }

  // Counts the number of fixed-size values in a packed field. This can be
  // done without actually parsing anything; it only checks that the payload
  // length is an exact multiple of the element size.
  template <typename T>
  Status CountPackedFixed(const uint8* unused_buf, size_t len) {
    int count = len / sizeof(T);
    if (count * sizeof(T) != len) {
      return errors::DataLoss(
          "Illegal data length for packed fixed-size type: ", len);
    }
    *count_ptr_ += len / sizeof(T);
    return OkStatus();
  }

  // Skips a single value in the input stream. Dispatches to the appropriately
  // typed field skipper based on the schema type tag. This is not as
  // permissive as just handling the wire type.
  static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
    uint32 tmp32;
    protobuf_uint64 tmp64;
    switch (field.type) {
      case WireFormatLite::TYPE_DOUBLE:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_FLOAT:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_INT64:
        return input->ReadVarint64(&tmp64);
      case WireFormatLite::TYPE_UINT64:
        return input->ReadVarint64(&tmp64);
      case WireFormatLite::TYPE_INT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_FIXED64:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_FIXED32:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_BOOL:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_STRING:
        return SkipBytes(input);
      case WireFormatLite::TYPE_GROUP:
        // Groups are delimited by a matching END_GROUP tag, so delegate to
        // the WireFormatLite skipper, which understands nesting.
        return WireFormatLite::SkipField(
            input, WireFormatLite::MakeTag(
                       field.number, WireFormatLite::WIRETYPE_START_GROUP));
      case WireFormatLite::TYPE_MESSAGE:
        return SkipBytes(input);
      case WireFormatLite::TYPE_BYTES:
        return SkipBytes(input);
      case WireFormatLite::TYPE_UINT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_ENUM:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_SFIXED32:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_SFIXED64:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_SINT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_SINT64:
        return input->ReadVarint64(&tmp64);
        // default: intentionally omitted in order to enable static checking.
        // Every in-range FieldType value returns above; control only falls
        // off the end for out-of-range values.
    }
  }

  // Destination for the count; points into the `sizes` output tensor. Not
  // owned.
  int32* count_ptr_ = nullptr;
};
467 | |
468 | // A DenseCollector accumulates values from a proto into a tensor. |
469 | // |
470 | // There is an instance of DenseCollector for each field of each proto. The |
471 | // DenseCollector deserializes the value from the wire directly into the |
472 | // preallocated output Tensor. |
473 | // |
474 | // This class is named DenseCollector because in the future there should be a |
475 | // SparseCollector that accumulates field data into sparse tensors if the user |
476 | // requests it. |
class DenseCollector {
 public:
  DenseCollector() = delete;

  // A DenseCollector applies to one field of a serialized message.
  // Note that default_value.dtype is the type of the output tensor.
  // `datap` points at this field's slice of the preallocated output tensor;
  // `max_repeat_count` is the per-message element capacity computed by the
  // counting pass.
  DenseCollector(uint8* datap, DefaultValue default_value, int max_repeat_count)
      : datap_(datap),
        default_value_(default_value),
        max_repeat_count_(max_repeat_count) {}

  // Reads a value from the input stream and stores it.
  //
  // Always inlining gave a ~50% speedup on microbenchmarks at one point.
  // TODO(nix): try removing it to see if that still holds.
  // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE
  Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
    // For required and optional fields, we overwrite values[0] with
    // the latest one in the wire stream.
    // See https://developers.google.com/protocol-buffers/docs/encoding#optional
    // Only for repeated fields do we advance the next_repeat_index_ past 1.
    // TODO(nix): to handle oneof we must also zero out any previous values
    //  seen on the wire.
    int32_t index = 0;
    if (field.is_repeated) {
      index = next_repeat_index_;
    }
    next_repeat_index_ = index + 1;

    return internal::ReadValue(input, field.type, field.number,
                               default_value_.dtype, index, datap_);
  }

  // Reads and stores a length-delimited list of values.
  Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
                          const size_t buf_size) {
    const void* buf;
    int unused_max_buf_size;
    input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size);
    // This is safe because the underlying storage for the CodedInputStream is
    // owned by the input tensor. If it were a Cord or file-backed stream this
    // pointer would go stale after the bytes were skipped.
    if (!input->Skip(buf_size)) {
      return errors::DataLoss(
          "ReadPackedValues: Skipping packed field failed. Field tag: ",
          field.number);
    }

    // Setting stride=0 causes new values to overwrite old ones for
    // non-repeated fields.
    const int stride = field.is_repeated ? 1 : 0;

    // Unlike ReadValue above, this path does check capacity up front, since a
    // single packed payload can contain many elements.
    if (next_repeat_index_ >= max_repeat_count_) {
      return errors::DataLoss(
          "ReadPackedValues: Tried to write more entries than allowed. "
          "Field tag: ",
          field.number, ", Max entries allowed: ", max_repeat_count_);
    } else {
      return internal::ReadPackedFromArray(buf, buf_size, field.type,
                                           field.number, default_value_.dtype,
                                           stride, &next_repeat_index_, datap_);
    }
  }

  // Fills in any missing values in the output array with defaults. Dispatches
  // to the appropriately typed field default based on the runtime type tag.
  Status FillWithDefaults() {
    switch (default_value_.dtype) {
      case DataType::DT_BOOL:
        return FillDefault<bool>(default_value_.value.v_bool);
      case DataType::DT_FLOAT:
        return FillDefault<float>(default_value_.value.v_float);
      case DataType::DT_DOUBLE:
        return FillDefault<double>(default_value_.value.v_double);
      case DataType::DT_INT8:
        return FillDefault<int8>(default_value_.value.v_int8);
      case DataType::DT_INT32:
        return FillDefault<int32>(default_value_.value.v_int32);
      case DataType::DT_INT64:
        return FillDefault<int64_t>(default_value_.value.v_int64);
      case DataType::DT_STRING:
        // v_string is a const char*; FillDefault<tstring> converts it to a
        // temporary tstring bound to the const reference parameter.
        return FillDefault<tstring>(default_value_.value.v_string);
      case DataType::DT_UINT8:
        return FillDefault<uint8>(default_value_.value.v_uint8);
      case DataType::DT_UINT32:
        return FillDefault<uint32>(default_value_.value.v_uint32);
      case DataType::DT_UINT64:
        return FillDefault<uint64>(default_value_.value.v_uint64);
      default:
        // There are many tensorflow dtypes not handled here, but they
        // should not come up unless type casting is added to the Op.
        // Chaining with tf.cast() should do the right thing until then.
        return errors::DataLoss("Failed filling defaults for ",
                                DataTypeString(default_value_.dtype));
    }
  }

 private:
  // Fills empty values in the dense representation with a default value. This
  // uses next_repeat_index_ which counts the number of parsed values for the
  // field.
  template <class T>
  Status FillDefault(const T& default_value) {
    for (int i = next_repeat_index_; i < max_repeat_count_; i++) {
      reinterpret_cast<T*>(datap_)[i] = default_value;
    }
    return OkStatus();
  }

  // Index of the next element to write; doubles as the count of values parsed
  // so far for this field in the current message.
  int32 next_repeat_index_ = 0;

  // This is a pointer to data_[message_index_]. There is no bounds checking
  // at this level: we computed the max repeat size for each field in
  // CountCollector and use the same code to traverse it here, so we are
  // guaranteed not to be called for more items than we have allocated space.
  void* const datap_ = nullptr;

  const DefaultValue default_value_;
  const int max_repeat_count_ = 0;
};
597 | |
598 | class DecodeProtoOp : public OpKernel { |
599 | public: |
  // Resolves the message descriptor, validates the requested fields against
  // the schema, precomputes per-field metadata (sorted by wire number), and
  // caches the attrs that Compute() needs. Any failure is reported through
  // the OP_REQUIRES* macros, which return early from this constructor.
  explicit DecodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
    string descriptor_source;
    OP_REQUIRES_OK(context,
                   context->GetAttr("descriptor_source", &descriptor_source));

    // We always get back a desc_pool, but we may not own it. If we own it,
    // owned_desc_pool_ will be filled in.
    DescriptorPool const* desc_pool;
    OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
                                              &desc_pool, &owned_desc_pool_));

    string message_type;
    OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));

    const Descriptor* message_desc =
        desc_pool->FindMessageTypeByName(message_type);
    OP_REQUIRES(context, message_desc != nullptr,
                errors::InvalidArgument("No descriptor found for message type ",
                                        message_type));

    std::vector<string> field_names;
    OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names));
    std::vector<DataType> output_types;
    OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_types));
    OP_REQUIRES(
        context, field_names.size() == output_types.size(),
        errors::InvalidArgument("field_names and output_types attributes must "
                                "have the same length"));

    // Gather the field descriptors and check that requested output types
    // match.
    int field_index = 0;
    std::vector<const FieldDescriptor*> field_descs;
    // Extension handling state: `exts` is lazily populated on the first name
    // that misses FindFieldByName; `ext_it == exts.begin()` is the "not yet
    // loaded" sentinel (an empty vector's begin() equals end(), so the lazy
    // load is retried harmlessly if no extensions exist).
    std::vector<const FieldDescriptor*> exts;
    absl::flat_hash_map<string, const FieldDescriptor*> ext_name_to_field;
    std::vector<const FieldDescriptor*>::iterator ext_it = exts.begin();
    for (const string& name : field_names) {
      auto fd = message_desc->FindFieldByName(name);
      if (fd == nullptr) {
        // If field can't be found in original message, try to find a matching
        // extension (by its full_name). First check a hashmap for a matching
        // extension, and if not found, then iterate through available
        // extensions to find a match (updating the hashmap while iterating.)
        auto lookup_result = ext_name_to_field.find(name);
        if (lookup_result != ext_name_to_field.end()) {
          fd = lookup_result->second;
        } else {
          if (ext_it == exts.begin()) {
            desc_pool->FindAllExtensions(message_desc, &exts);
            ext_it = exts.begin();
          }
          // Resume the scan where the previous lookup stopped; entries seen
          // along the way are memoized into the hashmap for later names.
          while (ext_it != exts.end()) {
            auto ext_name = (*ext_it)->full_name();
            auto ext_field = *ext_it;
            ++ext_it;

            ext_name_to_field.insert({ext_name, ext_field});
            if (ext_name == name) {
              fd = ext_field;
              break;
            }
          }
        }
      }
      OP_REQUIRES(context, fd != nullptr,
                  errors::InvalidArgument("Unknown field: ", name,
                                          " in message type ", message_type));
      OP_REQUIRES(
          context,
          proto_utils::IsCompatibleType(fd->type(), output_types[field_index]),
          // Many TensorFlow types don't have corresponding proto types and
          // the user will get an error if they are requested. It would be
          // nice to allow conversions here, but tf.cast already exists so we
          // don't duplicate the functionality.
          errors::InvalidArgument("Unexpected output type for ",
                                  fd->full_name(), ": ", fd->cpp_type(), " to ",
                                  output_types[field_index]));

      field_index++;
      field_descs.push_back(fd);
    }

    // Internally we want the field_descs sorted by their number on the wire.
    // But the output tensors are allocated in the order given by the caller.
    // Build a mapping i->j, where field_descs[i] corresponds to outputs[j].
    std::vector<int> output_indices;
    output_indices.reserve(field_names.size());
    for (int i = 0; i < field_names.size(); i++) {
      output_indices.push_back(i);
    }
    std::sort(output_indices.begin(), output_indices.end(),
              [field_descs](int a, int b) {
                return field_descs[a]->number() < field_descs[b]->number();
              });

    // Now store the fields in sorted order.
    for (int i = 0; i < field_names.size(); i++) {
      const int output_index = output_indices[i];
      const DataType dtype = output_types[output_index];
      const FieldDescriptor* field_descriptor = field_descs[output_index];
      DefaultValue default_value;
      OP_REQUIRES_OK(context, InitDefaultValueFromFieldDescriptor(
                                  dtype, field_descriptor, &default_value));
      fields_.push_back(
          MakeUnique<FieldInfo>(field_descriptor, output_index, default_value));
    }

    // A prototype message is needed for the text-format and sanitizing paths.
    message_prototype_ = message_factory_.GetPrototype(message_desc);
    OP_REQUIRES(context, message_prototype_ != nullptr,
                errors::InvalidArgument("Couldn't get prototype message: ",
                                        message_desc->full_name()));
    string format;
    OP_REQUIRES_OK(context, context->GetAttr("message_format", &format));
    OP_REQUIRES(
        context, format == "binary" || format == "text",
        errors::InvalidArgument("format must be one of binary or text"));
    is_binary_ = format == "binary";

    // Enable the initial protobuf sanitizer, which is much more expensive
    // than the decoder.
    // TODO(nix): Remove this once the fast decoder has passed security
    // review.
    OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
  }
722 | |
723 | void Compute(OpKernelContext* ctx) override { |
724 | const Tensor& buf_tensor = ctx->input(0); |
725 | int message_count = buf_tensor.NumElements(); |
726 | OP_REQUIRES(ctx, message_count >= 1, |
727 | errors::InvalidArgument( |
728 | "Bufs argument must contain at least one value" )); |
729 | |
730 | int field_count = fields_.size(); |
731 | |
732 | // Save the argument shape for later, then flatten the input Tensor since we |
733 | // are working componentwise. We will restore the same shape in the returned |
734 | // Tensor. |
735 | const TensorShape& shape_prefix = buf_tensor.shape(); |
736 | |
737 | TensorShape sizes_shape = shape_prefix; |
738 | sizes_shape.AddDim(field_count); |
739 | Tensor* sizes_tensor = nullptr; |
740 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor)); |
741 | |
742 | // This is used to allocate binary bufs if used. It serves only to define |
743 | // memory ownership. |
744 | std::vector<tstring> tmp_binary_bufs(message_count); |
745 | |
746 | // These are the actual buffers to use, which may be in tmp_binary_bufs |
747 | // or may be pointers into the buf_tensor. Either way they are not owned |
748 | // here. |
749 | std::vector<const tstring*> bufs; |
750 | |
751 | if (is_binary_ && !sanitize_) { |
752 | // Fast path. |
753 | for (int mi = 0; mi < message_count; ++mi) { |
754 | const tstring* buf = &buf_tensor.flat<tstring>()(mi); |
755 | bufs.push_back(buf); |
756 | } |
757 | } else { |
758 | // We will have to allocate a copy, either to convert from text to binary |
759 | // or to sanitize a binary proto. |
760 | for (int mi = 0; mi < message_count; ++mi) { |
761 | ReserializeMessage(ctx, buf_tensor.flat<tstring>()(mi), |
762 | &tmp_binary_bufs[mi]); |
763 | if (!ctx->status().ok()) { |
764 | return; |
765 | } |
766 | bufs.push_back(&tmp_binary_bufs[mi]); |
767 | } |
768 | } |
769 | |
770 | // Walk through all the strings in the input tensor, counting the number of |
771 | // fields in each. We can't allocate our actual output Tensor until we know |
772 | // the maximum repeat count, so we do a first pass through the serialized |
773 | // proto just counting fields. We always allocate at least one value so that |
774 | // optional fields are populated with default values - this avoids a TF |
775 | // conditional when handling the output data. The caller can distinguish |
776 | // between real data and defaults using the repeat count matrix that is |
777 | // returned by decode_proto. |
778 | std::vector<int32> max_sizes(field_count, 1); |
779 | for (int mi = 0; mi < message_count; ++mi) { |
780 | CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes); |
781 | if (!ctx->status().ok()) { |
782 | return; |
783 | } |
784 | } |
785 | |
786 | // Allocate the output tensors now that we've seen the max size. |
787 | // TODO(nix): Use allocate_output_or_forward_input for the largest |
788 | // output tensor. This can avoid one large allocation by re-using |
789 | // the memory of the input tensor. |
790 | std::vector<Tensor*> outputs(field_count); |
791 | for (int fi = 0; fi < field_count; ++fi) { |
792 | TensorShape flat_shape = {static_cast<int64_t>(message_count), |
793 | max_sizes[fi]}; |
794 | TensorShape out_shape = shape_prefix; |
795 | out_shape.AddDim(max_sizes[fi]); |
796 | |
797 | // Surprisingly we don't specify the types from the output_types |
798 | // attribute: that is done for us based on the Op declaration: |
799 | // REGISTER_OP(...) |
800 | // .Attr("output_types: list(type) >= 0") |
801 | // .Output("values: output_types") |
802 | OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1, |
803 | out_shape, &outputs[fi])); |
804 | } |
805 | |
806 | // Make the second pass through the serialized proto, decoding into |
807 | // preallocated tensors. |
808 | AccumulateFields(ctx, bufs, outputs); |
809 | } |
810 | |
811 | private: |
812 | // Copy a serialized message to binary, e.g. to handle text proto inputs. |
813 | void ReserializeMessage(OpKernelContext* ctx, const tstring& buf, |
814 | tstring* binary_buf) { |
815 | // Handle text protos by translating them to binary. |
816 | std::unique_ptr<Message> message(message_prototype_->New()); |
817 | OP_REQUIRES(ctx, message, errors::DataLoss("Initializing message failed" )); |
818 | |
819 | if (is_binary_) { |
820 | // If we get here we are sanitizing the input protobuf by parsing |
821 | // and reserializing it with a trusted (but very slow) library. |
822 | OP_REQUIRES(ctx, message->ParseFromString(buf), |
823 | errors::DataLoss("Unable to parse binary protobuf" )); |
824 | } else { |
825 | OP_REQUIRES(ctx, TextFormat::ParseFromString(buf, message.get()), |
826 | errors::DataLoss("Unable to parse text protobuf" )); |
827 | } |
828 | |
829 | OP_REQUIRES(ctx, SerializeToTString(*message, binary_buf), |
830 | errors::DataLoss("Unable to reserialize text proto as binary" )); |
831 | } |
832 | |
833 | // Count the number of occurrences of each requested field in a message batch. |
834 | void CountFields(OpKernelContext* ctx, int message_index, const tstring& buf, |
835 | Tensor* sizes_tensor, std::vector<int32>* max_sizes) { |
836 | int field_count = fields_.size(); |
837 | |
838 | CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()), |
839 | buf.size()); |
840 | |
841 | std::vector<int32> field_sizes(field_count, 0); |
842 | std::vector<CountCollector> counters; |
843 | counters.reserve(field_count); |
844 | for (int i = 0; i < field_count; i++) { |
845 | counters.emplace_back(&field_sizes[i]); |
846 | } |
847 | |
848 | Status st = Collect(&input, absl::MakeSpan(counters)); |
849 | if (st.ok() && !input.ConsumedEntireMessage()) { |
850 | st = errors::DataLoss("CountFields: Failed to consume entire buffer" ); |
851 | } |
852 | if (kFailOnDecodeError) { |
853 | OP_REQUIRES_OK(ctx, st); // NOLINT |
854 | } |
855 | if (!st.ok()) { |
856 | // This code suppresses the corrupt proto, treating it as empty |
857 | // to avoid crashing the process. |
858 | LOG(WARNING) << "Proto counting error for message type " << message_type_ |
859 | << ": " << st; |
860 | |
861 | for (int fi = 0; fi < field_count; fi++) { |
862 | field_sizes[fi] = 0; |
863 | } |
864 | // Finished decoding this message. |
865 | return; |
866 | } |
867 | |
868 | // Update the size tensor and max repeat size for each field. |
869 | auto sizes = sizes_tensor->flat_inner_dims<int32>(); |
870 | for (int fi = 0; fi < field_count; fi++) { |
871 | int32_t size = field_sizes[fi]; |
872 | sizes(message_index, fields_[fi]->output_index) = size; |
873 | if ((*max_sizes)[fi] < size) { |
874 | (*max_sizes)[fi] = size; |
875 | } |
876 | } |
877 | } |
878 | |
879 | // Parse fields from a serialized message into preallocated tensors. |
880 | void AccumulateFields(OpKernelContext* ctx, |
881 | const std::vector<const tstring*>& bufs, |
882 | std::vector<Tensor*> outputs) { |
883 | struct TensorInfo { |
884 | explicit TensorInfo(Tensor* tensor) { |
885 | // Note that we can decode only max_repeat_count values before overflow. |
886 | // No other bounds checking is done for repeated fields. For |
887 | // optional fields there is a check to make sure that only the last |
888 | // value on the wire appears in the output tensor. |
889 | dtype = tensor->dtype(); |
890 | last_dim_size = tensor->dim_size(tensor->dims() - 1); |
891 | |
892 | if (dtype != DT_STRING) { |
893 | const int element_size = DataTypeSize(dtype); |
894 | CHECK_GT(element_size, 0); |
895 | stride = last_dim_size * element_size; |
896 | |
897 | const int64_t flatshape[1] = {tensor->NumElements() * element_size}; |
898 | data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data(); |
899 | } else { |
900 | // DataTypeSize() returns 0 for string types. |
901 | stride = last_dim_size * sizeof(tstring); |
902 | data = reinterpret_cast<uint8*>(tensor->flat<tstring>().data()); |
903 | } |
904 | } |
905 | |
906 | DataType dtype; |
907 | int last_dim_size; |
908 | int stride; |
909 | uint8* data; |
910 | }; |
911 | |
912 | int field_count = fields_.size(); |
913 | |
914 | std::vector<TensorInfo> tensors; |
915 | tensors.reserve(field_count); |
916 | for (int fi = 0; fi < field_count; fi++) { |
917 | tensors.emplace_back(outputs[fi]); |
918 | } |
919 | |
920 | for (int message_index = 0; message_index < bufs.size(); ++message_index) { |
921 | const tstring& buf = *bufs[message_index]; |
922 | |
923 | std::vector<DenseCollector> collectors; |
924 | collectors.reserve(field_count); |
925 | for (int output_index = 0; output_index < field_count; ++output_index) { |
926 | const TensorInfo& info = tensors[output_index]; |
927 | const FieldInfo* field_info = fields_[output_index].get(); |
928 | DCHECK(field_info != nullptr); |
929 | const DefaultValue default_value = field_info->default_value; |
930 | collectors.emplace_back(info.data + message_index * info.stride, |
931 | default_value, info.last_dim_size); |
932 | } |
933 | |
934 | // Fill in output tensors from the wire. |
935 | CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()), |
936 | buf.size()); |
937 | Status st = Collect(&input, absl::MakeSpan(collectors)); |
938 | if (st.ok() && !input.ConsumedEntireMessage()) { |
939 | st = errors::DataLoss( |
940 | "AccumulateFields: Failed to consume entire buffer" ); |
941 | } |
942 | if (kFailOnDecodeError) { |
943 | OP_REQUIRES_OK(ctx, st); // NOLINT |
944 | } |
945 | if (!st.ok()) { |
946 | // This code suppresses the corrupt proto, treating it as empty |
947 | // to avoid crashing training. |
948 | LOG(WARNING) << "Proto counting error for message type " |
949 | << message_type_ << ": " << st; |
950 | } |
951 | |
952 | // Fill the remainder of the dense outputs with default values. |
953 | for (auto& collector : collectors) { |
954 | OP_REQUIRES_OK(ctx, collector.FillWithDefaults()); |
955 | } |
956 | } |
957 | } |
958 | |
  // Traverses a serialized protobuf, dispatching values to the collectors.
  // `collectors` runs parallel to fields_: collectors[i] receives the values
  // decoded for fields_[i]. Scanning stops at end of input or at an END_GROUP
  // tag; returns DataLoss if an unrequested field cannot be skipped.
  template <class CollectorClass>
  Status Collect(CodedInputStream* input,
                 absl::Span<CollectorClass> collectors) {
    // At the beginning of each loop, the last field number that was seen,
    // regardless of whether it was collected or not, or -1 if no field has
    // been seen before.
    int last_seen_field_number = -1;
    // The FieldInfo that is expected to be used next.
    // It was either used to collect the last seen field number, or if the
    // last seen field number was not in fields_, it is the next FieldInfo after
    // the last seen field number. At the beginning it is the first FieldInfo.
    auto expected_field_info_iter = fields_.begin();

    // The 'tag' variable should always be treated as tainted.
    for (uint32 tag = input->ReadTag();
         tag != 0 && WireFormatLite::GetTagWireType(tag) !=
                         WireFormatLite::WIRETYPE_END_GROUP;
         tag = input->ReadTag()) {
      // Loop invariant: last_seen_field_number lies strictly after the field
      // before expected_field_info_iter and at or before the expected field.
      DCHECK(expected_field_info_iter == fields_.begin() ||
             last_seen_field_number >
                 (*(expected_field_info_iter - 1))->number);
      DCHECK(expected_field_info_iter == fields_.end() ||
             last_seen_field_number <= (*expected_field_info_iter)->number);

      // The field wire number.
      const int field_number = WireFormatLite::GetTagFieldNumber(tag);
      // The field info associated with the field wire number.
      const FieldInfo* field_info = nullptr;

      // fields_ are ordered by their field numbers. If the field numbers
      // on wire are also ordered (which is a convention), then we can
      // monotonically increment `expected_field_info_iter` as the field
      // numbers on wire get larger. If we detect any out-of-order
      // field number, we reset `expected_field_info_iter`, and expect that
      // future wire numbers are ordered. This algorithm is quadratic in the
      // worst case where field numbers on wire are in descending order, however
      // it works well in the case where two serialized protobufs are
      // concatenated together.
      if (field_number < last_seen_field_number) {
        expected_field_info_iter = fields_.begin();
      }

      // Advance expected_field_info_iter until
      // field_number <= expected_field_number.
      for (; expected_field_info_iter != fields_.end();
           ++expected_field_info_iter) {
        DCHECK(expected_field_info_iter == fields_.begin() ||
               field_number > (*(expected_field_info_iter - 1))->number);
        const FieldInfo* expected_field_info = expected_field_info_iter->get();
        if (field_number <= expected_field_info->number) {
          if (field_number == expected_field_info->number) {
            field_info = expected_field_info;
          }
          break;
        }
      }
      last_seen_field_number = field_number;
      if (!field_info) {
        // This DCHECK verifies that if we skip a field, we didn't want it.
        // In particular, field_builders is empty or the field_number is either:
        // before fields_.begin().number or after (fields_.end() - 1).number or
        // in-between expected_field_info_iter and expected_field_info_iter - 1.
        DCHECK(fields_.empty() || (field_number < (*fields_.begin())->number) ||
               (field_number > (*(fields_.end() - 1))->number) ||
               (((*(expected_field_info_iter - 1))->number < field_number) &&
                (field_number < (*(expected_field_info_iter))->number)));
        // Unknown and unrequested fields are skipped.
        if (!WireFormatLite::SkipField(input, tag)) {
          return errors::DataLoss("Failed skipping unrequested field" );
        }
        continue;
      }

      // The collector index is the position of the matched FieldInfo in
      // fields_, which by construction is where expected_field_info_iter
      // stopped.
      TF_RETURN_IF_ERROR(CollectField(
          *field_info, WireFormatLite::GetTagWireType(tag), input,
          &collectors[expected_field_info_iter - fields_.begin()]));
    }
    return OkStatus();
  }
1039 | |
1040 | // Collects values for a single field. |
1041 | template <class CollectorClass> |
1042 | Status CollectField(const FieldInfo& field, |
1043 | WireFormatLite::WireType wire_type, |
1044 | CodedInputStream* input, CollectorClass* collector) { |
1045 | // The wire format library defines the same constants used in |
1046 | // descriptor.proto. This static_cast is safe because they are guaranteed to |
1047 | // stay in sync. |
1048 | // |
1049 | // We need the field type from the FieldDescriptor here because the wire |
1050 | // format doesn't tell us anything about what happens inside a packed |
1051 | // repeated field: there is enough information in the wire format to skip |
1052 | // the whole field but not enough to know how to parse what's inside. For |
1053 | // that we go to the schema. |
1054 | WireFormatLite::WireType schema_wire_type = |
1055 | WireFormatLite::WireTypeForFieldType(field.type); |
1056 | |
1057 | // Handle packed repeated fields. SkipField would skip the whole |
1058 | // length-delimited blob without letting us count the values, so we have to |
1059 | // scan them ourselves. |
1060 | if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED && |
1061 | schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) { |
1062 | // Handle packed repeated primitives. |
1063 | int length; |
1064 | if (!input->ReadVarintSizeAsInt(&length)) { |
1065 | return errors::DataLoss("CollectField: Failed reading packed size" ); |
1066 | } |
1067 | return collector->ReadPackedValues(input, field, length); |
1068 | } |
1069 | |
1070 | // Read ordinary values, including strings, bytes, and messages. |
1071 | if (wire_type != schema_wire_type) { |
1072 | if (!WireFormatLite::SkipField( |
1073 | input, WireFormatLite::MakeTag(field.number, wire_type))) { |
1074 | return errors::DataLoss( |
1075 | "CollectField: Failed skipping malformed field" ); |
1076 | } |
1077 | return OkStatus(); |
1078 | } |
1079 | return collector->ReadValue(input, field); |
1080 | } |
1081 | |
  // Full name of the message type being decoded; used in warning logs.
  string message_type_;
  // Note that fields are sorted by increasing field number, which is not in
  // general the order given by the user-specified field_names and output_types
  // Op attributes.
  std::vector<std::unique_ptr<const FieldInfo>> fields_;

  // owned_desc_pool_ is null when using descriptor_source=local.
  std::unique_ptr<DescriptorPool> owned_desc_pool_;
  // Factory for building mutable messages on the slow (reserialize) path.
  DynamicMessageFactory message_factory_;
  // Unowned; obtained from message_factory_.GetPrototype() in the ctor.
  const Message* message_prototype_;

  // True if decoding binary format, false if decoding text format.
  bool is_binary_;

  // True if the protos should be sanitized before parsing. Enables the initial
  // protobuf sanitizer, which is much more expensive than the decoder. The flag
  // defaults to true but can be set to false for trusted sources.
  //
  // TODO(nix): Flip the default to false when the fast decoder has passed
  // security review.
  bool sanitize_;

  TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);
1105 | }; |
1106 | |
// Register the kernel. DecodeProtoV2 has a CPU implementation only.
REGISTER_KERNEL_BUILDER(Name("DecodeProtoV2" ).Device(DEVICE_CPU),
                        DecodeProtoOp);
1109 | |
1110 | } // namespace |
1111 | } // namespace tensorflow |
1112 | |