1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// DecodeProto is a TensorFlow op which extracts arbitrary fields from protos
17// serialized as strings.
18//
19// See docs in ../ops/decode_proto_op.cc.
20//
21// This implementation reads the serialized format using a handful of calls from
22// the WireFormatLite API used by generated proto code. WireFormatLite is marked
23// as an "internal" proto API but is widely used in practice and highly unlikely
24// to change. This will be much faster than the previous implementation based on
25// constructing a temporary dynamic message in memory and using the proto
26// reflection api to read it. It can be used with any proto whose descriptors
27// are available at runtime but should be competitive in speed with approaches
28// that compile in the proto definitions.
29
30#include <memory>
31#include <string>
32#include <vector>
33
34#include "absl/container/flat_hash_map.h"
35#include "absl/types/span.h"
36#include "third_party/eigen3/Eigen/Core"
37#include "tensorflow/core/framework/op_kernel.h"
38#include "tensorflow/core/framework/tensor_types.h"
39#include "tensorflow/core/framework/types.h"
40#include "tensorflow/core/lib/core/errors.h"
41#include "tensorflow/core/platform/logging.h"
42#include "tensorflow/core/platform/protobuf.h"
43#include "tensorflow/core/util/proto/decode.h"
44#include "tensorflow/core/util/proto/descriptors.h"
45#include "tensorflow/core/util/proto/proto_utils.h"
46#include "tensorflow/core/util/ptr_util.h"
47
48namespace tensorflow {
49namespace {
50
51using ::tensorflow::MakeUnique;
52using ::tensorflow::protobuf::Descriptor;
53using ::tensorflow::protobuf::DescriptorPool;
54using ::tensorflow::protobuf::DynamicMessageFactory;
55using ::tensorflow::protobuf::FieldDescriptor;
56using ::tensorflow::protobuf::Message;
57using ::tensorflow::protobuf::TextFormat;
58using ::tensorflow::protobuf::internal::WireFormatLite;
59using ::tensorflow::protobuf::io::CodedInputStream;
60
// If true, a decoding error during the counting pass aborts the op with an
// error status (see CountFields); if false, the corrupt message is treated as
// empty and only a warning is logged.
constexpr bool kFailOnDecodeError = true;
62
63// Used to store the default value of a protocol message field, casted to the
64// type of the output tensor.
65//
66// TODO(paskin): Use absl::variant once TensorFlow gets absl dependencies.
67struct DefaultValue {
68 DataType dtype = DataType::DT_INVALID;
69 union Value {
70 bool v_bool; // DT_BOOL
71 double v_double; // DT_DOUBLE
72 float v_float; // DT_FLOAT
73 int8 v_int8; // DT_INT8
74 int32 v_int32; // DT_INT32
75 int64_t v_int64; // DT_INT64
76 const char* v_string; // DT_STRING
77 uint8 v_uint8; // DT_UINT8
78 uint8 v_uint32; // DT_UINT32
79 uint8 v_uint64; // DT_UINT64
80 };
81 Value value;
82};
83
84// Initializes a DefaultValue object. This generic template handles numeric
85// types and strings are handled by a template specialization below.
86//
87// Args:
88// dtype: the type of the output tensor
89// value: the default value as obtained from the FieldDescriptor
90// result: the object to initialize
91template <typename T>
92Status InitDefaultValue(DataType dtype, const T value, DefaultValue* result) {
93 result->dtype = dtype;
94 switch (dtype) {
95 case DT_BOOL:
96 result->value.v_bool = static_cast<bool>(value);
97 break;
98 case DT_DOUBLE:
99 result->value.v_double = static_cast<double>(value);
100 break;
101 case DT_FLOAT:
102 result->value.v_float = static_cast<float>(value);
103 break;
104 case DT_INT8:
105 result->value.v_int8 = static_cast<int8>(value);
106 break;
107 case DT_INT32:
108 result->value.v_int32 = static_cast<int32>(value);
109 break;
110 case DT_INT64:
111 result->value.v_int64 = static_cast<int64_t>(value);
112 break;
113 case DT_UINT8:
114 result->value.v_uint8 = static_cast<uint8>(value);
115 break;
116 case DT_UINT32:
117 result->value.v_uint32 = static_cast<uint32>(value);
118 break;
119 case DT_UINT64:
120 result->value.v_uint64 = static_cast<uint64>(value);
121 break;
122 default:
123 // We should never get here, given the type checking that occurs earlier.
124 return errors::Internal(
125 "Cannot initialize default value for unsupported type: ",
126 DataTypeString(dtype));
127 }
128 return OkStatus();
129}
130
131template <>
132Status InitDefaultValue(DataType dtype, const char* value,
133 DefaultValue* result) {
134 // These are sanity checks that should never trigger given the code that
135 // leads here.
136 if (TF_PREDICT_FALSE(dtype != DT_STRING)) {
137 return errors::InvalidArgument(
138 "Cannot cast field to anything but DT_STRING");
139 }
140 if (TF_PREDICT_FALSE(value == nullptr)) {
141 return errors::InvalidArgument("Null default string value.");
142 }
143 result->dtype = DT_STRING;
144 result->value.v_string = value;
145 return OkStatus();
146}
147
// Initializes a default value from the output data type and the field
// descriptor.
//
// NOTE(review): the switch condition is a FieldDescriptor::Type while the
// case labels are WireFormatLite::TYPE_* constants. Both enums mirror the
// constants in descriptor.proto (see the FieldInfo comment below), so the
// values correspond; this compiles via integral promotion.
Status InitDefaultValueFromFieldDescriptor(DataType dtype,
                                           const FieldDescriptor* field_desc,
                                           DefaultValue* result) {
  switch (field_desc->type()) {
    case WireFormatLite::TYPE_DOUBLE:
      return InitDefaultValue(dtype, field_desc->default_value_double(),
                              result);
    case WireFormatLite::TYPE_FLOAT:
      return InitDefaultValue(dtype, field_desc->default_value_float(), result);
    // All signed 64-bit encodings share one default accessor.
    case WireFormatLite::TYPE_INT64:
    case WireFormatLite::TYPE_SINT64:
    case WireFormatLite::TYPE_SFIXED64:
      return InitDefaultValue(dtype, field_desc->default_value_int64(), result);
    case WireFormatLite::TYPE_FIXED64:
    case WireFormatLite::TYPE_UINT64:
      return InitDefaultValue(dtype, field_desc->default_value_uint64(),
                              result);
    // All signed 32-bit encodings share one default accessor.
    case WireFormatLite::TYPE_INT32:
    case WireFormatLite::TYPE_SINT32:
    case WireFormatLite::TYPE_SFIXED32:
      return InitDefaultValue(dtype, field_desc->default_value_int32(), result);
    case WireFormatLite::TYPE_FIXED32:
    case WireFormatLite::TYPE_UINT32:
      return InitDefaultValue(dtype, field_desc->default_value_uint32(),
                              result);
    case WireFormatLite::TYPE_BOOL:
      return InitDefaultValue(dtype, field_desc->default_value_bool(), result);
    // Enums are represented by their integer number in the output tensor.
    case WireFormatLite::TYPE_ENUM:
      return InitDefaultValue(dtype, field_desc->default_value_enum()->number(),
                              result);
    case WireFormatLite::TYPE_BYTES:
    case WireFormatLite::TYPE_STRING:
      // Manipulating default string values as C-style pointers should be OK
      // for typical code-generated protocol messages. It is possible in
      // principle to register a message descriptor on the fly, and these
      // pointers may not be stable if that descriptor has a weird
      // implementation. (But the return type of default_value_string() is
      // const string&, so it'd have to be very weird.)
      return InitDefaultValue(dtype, field_desc->default_value_string().c_str(),
                              result);
    // Message-valued fields have no scalar default; use the empty string.
    case WireFormatLite::TYPE_GROUP:
    case WireFormatLite::TYPE_MESSAGE:
      return InitDefaultValue(dtype, "", result);
    // default: intentionally omitted in order to enable static checking.
  }
  // Unreachable for in-range enum values; keeps all return paths explicit.
  return OkStatus();
}
197
198// A FieldInfo holds a handful of information from the FieldDescriptor
199// and user attributes.
200struct FieldInfo {
201 FieldInfo(const FieldDescriptor* field_desc, int user_index,
202 DefaultValue def_value)
203 : output_index(user_index), default_value(def_value) {
204 // Without this intermediate data structure, the profile had hotspots
205 // calling methods of FieldDescriptor.
206 number = field_desc->number();
207
208 // The wire format library defines the same constants used in
209 // descriptor.proto. This static_cast is safe because they are guaranteed to
210 // stay in sync. We need the field type from the FieldDescriptor here
211 // because the wire format doesn't tell us anything about what happens
212 // inside a packed repeated field: there is enough information in the wire
213 // format to skip the whole field but not enough to know how to parse what's
214 // inside. For that we go to the schema.
215 type = static_cast<WireFormatLite::FieldType>(field_desc->type());
216 is_repeated = field_desc->is_repeated();
217 }
218
219 // Disable copy and move.
220 FieldInfo(const FieldInfo&) = delete;
221 FieldInfo& operator=(const FieldInfo&) = delete;
222
223 // Internally we sort field descriptors by wire number for fast lookup. In
224 // general this is different from the order given by the user. Output_index
225 // gives the index into the field_names and output_types attributes and into
226 // the output tensor list.
227 int output_index = -1;
228
229 // This is a cache of the relevant fields from `FieldDescriptorProto`. This
230 // was added after noticing that FieldDescriptor->type() was using 6% of the
231 // cpu profile.
232 WireFormatLite::FieldType type;
233 int number;
234 bool is_repeated;
235 DefaultValue default_value;
236};
237
// A CountCollector counts sizes of repeated and optional fields in a proto.
//
// Each field is tracked by a single CountCollector instance. The instance
// manages a single count, which is stored as a pointer (it is intended to be a
// reference to the `sizes` output which is being filled in). The pointer is
// passed in at initialization.
//
// Counting is done as a separate pass in order to allocate output tensors all
// at once. This allows the TensorFlow runtime to optimize allocation for the
// consumer, while removing the need for copying inside this op. After this
// pass, the DenseCollector class (below) gathers the data: it is more complex
// and provides better motivation for the API here.
class CountCollector {
 public:
  CountCollector() = delete;

  // The count may be stored inside an Eigen Tensor to eliminate copying.
  explicit CountCollector(int32* count) : count_ptr_(count) {}

  // Reads (in this case counts) a single value. Returns DataLoss if the
  // value's bytes cannot be skipped according to the schema type.
  Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
    // Only repeated fields can have count > 1.
    if (*count_ptr_ == 0 || field.is_repeated) {
      (*count_ptr_)++;
    }
    // We expect a wire type based on the schema field_type, to allow a little
    // more checking.
    if (!SkipValue(input, field)) {
      return errors::DataLoss("ReadValue: Failed skipping field when counting");
    }
    return OkStatus();
  }

  // Reads (in this case counts) a length-delimited list of values.
  // `buf_size` is the byte length of the packed payload, already read from
  // the wire by the caller.
  Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
                          size_t buf_size) {
    if (buf_size == 0) {
      return OkStatus();
    }

    const void* tmpbuf;
    int unused_max_buf_size;

    input->GetDirectBufferPointerInline(&tmpbuf, &unused_max_buf_size);
    // This is safe because the underlying storage for the CodedInputStream is
    // owned by the input tensor. If it were a Cord or file-backed stream this
    // pointer would go stale after the bytes were skipped.
    const uint8* buf = reinterpret_cast<const uint8*>(tmpbuf);

    // Important: we skipped the input->{Push,Pop}Limit() calls for speed,
    // so the bounds check on buf_size inside Skip() is critical, and
    // must be done before scanning the contents.
    if (!input->Skip(buf_size)) {
      return errors::DataLoss("ReadPackedValues: Skipping packed field failed");
    }

    // Dispatch to the appropriately typed field reader based on the schema
    // type.
    Status st;
    switch (field.type) {
      case WireFormatLite::TYPE_DOUBLE:
        st = CountPackedFixed<double>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FLOAT:
        st = CountPackedFixed<float>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_INT64:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_UINT64:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_INT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FIXED64:
        st = CountPackedFixed<uint64>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_FIXED32:
        st = CountPackedFixed<uint32>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_BOOL:
        st = CountPackedVarint(buf, buf_size);
        break;
      // Length-delimited and group types can never legally appear packed.
      case WireFormatLite::TYPE_STRING:
        st = errors::DataLoss("TYPE_STRING encountered as packed");
        break;
      case WireFormatLite::TYPE_GROUP:
        st = errors::DataLoss("TYPE_GROUP encountered as packed");
        break;
      case WireFormatLite::TYPE_MESSAGE:
        st = errors::DataLoss("TYPE_MESSAGE encountered as packed");
        break;
      case WireFormatLite::TYPE_BYTES:
        st = errors::DataLoss("TYPE_BYTES encountered as packed");
        break;
      case WireFormatLite::TYPE_UINT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_ENUM:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SFIXED32:
        st = CountPackedFixed<int32>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SFIXED64:
        st = CountPackedFixed<int64_t>(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SINT32:
        st = CountPackedVarint(buf, buf_size);
        break;
      case WireFormatLite::TYPE_SINT64:
        st = CountPackedVarint(buf, buf_size);
        break;
      // default: intentionally omitted in order to enable static checking.
    }
    if (!st.ok()) {
      return st;
    }

    // Non-repeated fields keep only the last value seen, so the count is
    // clamped to 1 even if multiple packed entries were present.
    if (!field.is_repeated && *count_ptr_ > 1) {
      *count_ptr_ = 1;
    }
    return OkStatus();
  }

 private:
  // Skips a length-delimited value.
  static bool SkipBytes(CodedInputStream* input) {
    uint32 length;
    if (!input->ReadVarint32(&length)) {
      return false;
    }
    return input->Skip(length);
  }

  // Counts the number of packed varints in an array. The end of a varint is
  // signaled by a value < 0x80, so counting them requires parsing the
  // bytestream. It is the caller's responsibility to ensure that len > 0.
  Status CountPackedVarint(const uint8* buf, size_t len) {
    const uint8* bound = buf + len;
    int count;

    // The last byte in a valid encoded varint is guaranteed to have the high
    // bit unset. We rely on this property to prevent ReadVarint64FromArray from
    // going out of bounds, so validate the end of the buf before scanning
    // anything.
    if (bound[-1] & 0x80) {
      return errors::DataLoss("Corrupt packed varint");
    }

    // Now we can trust ReadVarint64FromArray to stay in bounds.
    for (count = 0; buf < bound; ++count) {
      uint64 temp;
      bool ok;
      buf = internal::ReadVarint64FromArray(buf, &ok, &temp);
      if (!ok) {
        return errors::DataLoss("Corrupt packed varint");
      }
    }

    *count_ptr_ += count;
    return OkStatus();
  }

  // Counts the number of fixed-size values in a packed field. This can be done
  // without actually parsing anything. `len` must be an exact multiple of the
  // element size, otherwise the payload is corrupt.
  template <typename T>
  Status CountPackedFixed(const uint8* unused_buf, size_t len) {
    int count = len / sizeof(T);
    if (count * sizeof(T) != len) {
      return errors::DataLoss(
          "Illegal data length for packed fixed-size type: ", len);
    }
    *count_ptr_ += len / sizeof(T);
    return OkStatus();
  }

  // Skips a single value in the input stream. Dispatches to the appropriately
  // typed field skipper based on the schema type tag. This is not as permissive
  // as just handling the wire type.
  static bool SkipValue(CodedInputStream* input, const FieldInfo& field) {
    uint32 tmp32;
    protobuf_uint64 tmp64;
    switch (field.type) {
      case WireFormatLite::TYPE_DOUBLE:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_FLOAT:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_INT64:
        return input->ReadVarint64(&tmp64);
      case WireFormatLite::TYPE_UINT64:
        return input->ReadVarint64(&tmp64);
      case WireFormatLite::TYPE_INT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_FIXED64:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_FIXED32:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_BOOL:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_STRING:
        return SkipBytes(input);
      case WireFormatLite::TYPE_GROUP:
        return WireFormatLite::SkipField(
            input, WireFormatLite::MakeTag(
                       field.number, WireFormatLite::WIRETYPE_START_GROUP));
      case WireFormatLite::TYPE_MESSAGE:
        return SkipBytes(input);
      case WireFormatLite::TYPE_BYTES:
        return SkipBytes(input);
      case WireFormatLite::TYPE_UINT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_ENUM:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_SFIXED32:
        return input->ReadLittleEndian32(&tmp32);
      case WireFormatLite::TYPE_SFIXED64:
        return input->ReadLittleEndian64(&tmp64);
      case WireFormatLite::TYPE_SINT32:
        return input->ReadVarint32(&tmp32);
      case WireFormatLite::TYPE_SINT64:
        return input->ReadVarint64(&tmp64);
      // default: intentionally omitted in order to enable static checking.
    }
    // NOTE(review): there is no return after the switch; every enumerator is
    // covered above, and the omitted default lets -Wswitch flag any new one.
  }

  int32* count_ptr_ = nullptr;
};
467
468// A DenseCollector accumulates values from a proto into a tensor.
469//
470// There is an instance of DenseCollector for each field of each proto. The
471// DenseCollector deserializes the value from the wire directly into the
472// preallocated output Tensor.
473//
474// This class is named DenseCollector because in the future there should be a
475// SparseCollector that accumulates field data into sparse tensors if the user
476// requests it.
477class DenseCollector {
478 public:
479 DenseCollector() = delete;
480
481 // A DenseCollector applies to one field of a serialized message.
482 // Note that default_value.dtype is the type of the output tensor.
483 DenseCollector(uint8* datap, DefaultValue default_value, int max_repeat_count)
484 : datap_(datap),
485 default_value_(default_value),
486 max_repeat_count_(max_repeat_count) {}
487
488 // Reads a value from the input stream and stores it.
489 //
490 // Always inlining gave a ~50% speedup on microbenchmarks at one point.
491 // TODO(nix): try removing it to see if that still holds.
492 // TODO(jsimsa): ABSL_ATTRIBUTE_ALWAYS_INLINE
493 Status ReadValue(CodedInputStream* input, const FieldInfo& field) {
494 // For required and optional fields, we overwrite values[0] with
495 // the latest one in the wire stream.
496 // See https://developers.google.com/protocol-buffers/docs/encoding#optional
497 // Only for repeated fields do we advance the next_repeat_index_ past 1.
498 // TODO(nix): to handle oneof we must also zero out any previous values
499 // seen on the wire.
500 int32_t index = 0;
501 if (field.is_repeated) {
502 index = next_repeat_index_;
503 }
504 next_repeat_index_ = index + 1;
505
506 return internal::ReadValue(input, field.type, field.number,
507 default_value_.dtype, index, datap_);
508 }
509
510 // Reads and stores a length-delimited list of values.
511 Status ReadPackedValues(CodedInputStream* input, const FieldInfo& field,
512 const size_t buf_size) {
513 const void* buf;
514 int unused_max_buf_size;
515 input->GetDirectBufferPointerInline(&buf, &unused_max_buf_size);
516 // This is safe because the underlying storage for the CodedInputStream is
517 // owned by the input tensor. If it were a Cord or file-backed stream this
518 // pointer would go stale after the bytes were skipped.
519 if (!input->Skip(buf_size)) {
520 return errors::DataLoss(
521 "ReadPackedValues: Skipping packed field failed. Field tag: ",
522 field.number);
523 }
524
525 // Setting stride=0 causes new values to overwrite old ones for
526 // non-repeated fields.
527 const int stride = field.is_repeated ? 1 : 0;
528
529 if (next_repeat_index_ >= max_repeat_count_) {
530 return errors::DataLoss(
531 "ReadPackedValues: Tried to write more entries than allowed. "
532 "Field tag: ",
533 field.number, ", Max entries allowed: ", max_repeat_count_);
534 } else {
535 return internal::ReadPackedFromArray(buf, buf_size, field.type,
536 field.number, default_value_.dtype,
537 stride, &next_repeat_index_, datap_);
538 }
539 }
540
541 // Fills in any missing values in the output array with defaults. Dispatches
542 // to the appropriately typed field default based on the runtime type tag.
543 Status FillWithDefaults() {
544 switch (default_value_.dtype) {
545 case DataType::DT_BOOL:
546 return FillDefault<bool>(default_value_.value.v_bool);
547 case DataType::DT_FLOAT:
548 return FillDefault<float>(default_value_.value.v_float);
549 case DataType::DT_DOUBLE:
550 return FillDefault<double>(default_value_.value.v_double);
551 case DataType::DT_INT8:
552 return FillDefault<int8>(default_value_.value.v_int8);
553 case DataType::DT_INT32:
554 return FillDefault<int32>(default_value_.value.v_int32);
555 case DataType::DT_INT64:
556 return FillDefault<int64_t>(default_value_.value.v_int64);
557 case DataType::DT_STRING:
558 return FillDefault<tstring>(default_value_.value.v_string);
559 case DataType::DT_UINT8:
560 return FillDefault<uint8>(default_value_.value.v_uint8);
561 case DataType::DT_UINT32:
562 return FillDefault<uint32>(default_value_.value.v_uint32);
563 case DataType::DT_UINT64:
564 return FillDefault<uint64>(default_value_.value.v_uint64);
565 default:
566 // There are many tensorflow dtypes not handled here, but they
567 // should not come up unless type casting is added to the Op.
568 // Chaining with tf.cast() should do the right thing until then.
569 return errors::DataLoss("Failed filling defaults for ",
570 DataTypeString(default_value_.dtype));
571 }
572 }
573
574 private:
575 // Fills empty values in the dense representation with a default value. This
576 // uses next_repeat_index_ which counts the number of parsed values for the
577 // field.
578 template <class T>
579 Status FillDefault(const T& default_value) {
580 for (int i = next_repeat_index_; i < max_repeat_count_; i++) {
581 reinterpret_cast<T*>(datap_)[i] = default_value;
582 }
583 return OkStatus();
584 }
585
586 int32 next_repeat_index_ = 0;
587
588 // This is a pointer to data_[message_index_]. There is no bounds checking at
589 // this level: we computed the max repeat size for each field in
590 // CountCollector and use the same code to traverse it here, so we are
591 // guaranteed not to be called for more items than we have allocated space.
592 void* const datap_ = nullptr;
593
594 const DefaultValue default_value_;
595 const int max_repeat_count_ = 0;
596};
597
598class DecodeProtoOp : public OpKernel {
599 public:
600 explicit DecodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
601 string descriptor_source;
602 OP_REQUIRES_OK(context,
603 context->GetAttr("descriptor_source", &descriptor_source));
604
605 // We always get back a desc_pool, but we may not own it. If we own it,
606 // owned_desc_pool_ will be filled in.
607 DescriptorPool const* desc_pool;
608 OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
609 &desc_pool, &owned_desc_pool_));
610
611 string message_type;
612 OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
613
614 const Descriptor* message_desc =
615 desc_pool->FindMessageTypeByName(message_type);
616 OP_REQUIRES(context, message_desc != nullptr,
617 errors::InvalidArgument("No descriptor found for message type ",
618 message_type));
619
620 std::vector<string> field_names;
621 OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names));
622 std::vector<DataType> output_types;
623 OP_REQUIRES_OK(context, context->GetAttr("output_types", &output_types));
624 OP_REQUIRES(
625 context, field_names.size() == output_types.size(),
626 errors::InvalidArgument("field_names and output_types attributes must "
627 "have the same length"));
628
629 // Gather the field descriptors and check that requested output types match.
630 int field_index = 0;
631 std::vector<const FieldDescriptor*> field_descs;
632 std::vector<const FieldDescriptor*> exts;
633 absl::flat_hash_map<string, const FieldDescriptor*> ext_name_to_field;
634 std::vector<const FieldDescriptor*>::iterator ext_it = exts.begin();
635 for (const string& name : field_names) {
636 auto fd = message_desc->FindFieldByName(name);
637 if (fd == nullptr) {
638 // If field can't be found in original message, try to find a matching
639 // extension (by its full_name). First check a hashmap for a matching
640 // extension, and if not found, then iterate through available
641 // extensions to find a match (updating the hashmap while iterating.)
642 auto lookup_result = ext_name_to_field.find(name);
643 if (lookup_result != ext_name_to_field.end()) {
644 fd = lookup_result->second;
645 } else {
646 if (ext_it == exts.begin()) {
647 desc_pool->FindAllExtensions(message_desc, &exts);
648 ext_it = exts.begin();
649 }
650 while (ext_it != exts.end()) {
651 auto ext_name = (*ext_it)->full_name();
652 auto ext_field = *ext_it;
653 ++ext_it;
654
655 ext_name_to_field.insert({ext_name, ext_field});
656 if (ext_name == name) {
657 fd = ext_field;
658 break;
659 }
660 }
661 }
662 }
663 OP_REQUIRES(context, fd != nullptr,
664 errors::InvalidArgument("Unknown field: ", name,
665 " in message type ", message_type));
666 OP_REQUIRES(
667 context,
668 proto_utils::IsCompatibleType(fd->type(), output_types[field_index]),
669 // Many TensorFlow types don't have corresponding proto types and the
670 // user will get an error if they are requested. It would be nice to
671 // allow conversions here, but tf.cast already exists so we don't
672 // duplicate the functionality.
673 errors::InvalidArgument("Unexpected output type for ",
674 fd->full_name(), ": ", fd->cpp_type(), " to ",
675 output_types[field_index]));
676
677 field_index++;
678 field_descs.push_back(fd);
679 }
680
681 // Internally we want the field_descs sorted by their number on the wire.
682 // But the output tensors are allocated in the order given by the caller.
683 // Build a mapping i->j, where field_descs[i] corresponds to outputs[j].
684 std::vector<int> output_indices;
685 output_indices.reserve(field_names.size());
686 for (int i = 0; i < field_names.size(); i++) {
687 output_indices.push_back(i);
688 }
689 std::sort(output_indices.begin(), output_indices.end(),
690 [field_descs](int a, int b) {
691 return field_descs[a]->number() < field_descs[b]->number();
692 });
693
694 // Now store the fields in sorted order.
695 for (int i = 0; i < field_names.size(); i++) {
696 const int output_index = output_indices[i];
697 const DataType dtype = output_types[output_index];
698 const FieldDescriptor* field_descriptor = field_descs[output_index];
699 DefaultValue default_value;
700 OP_REQUIRES_OK(context, InitDefaultValueFromFieldDescriptor(
701 dtype, field_descriptor, &default_value));
702 fields_.push_back(
703 MakeUnique<FieldInfo>(field_descriptor, output_index, default_value));
704 }
705
706 message_prototype_ = message_factory_.GetPrototype(message_desc);
707 OP_REQUIRES(context, message_prototype_ != nullptr,
708 errors::InvalidArgument("Couldn't get prototype message: ",
709 message_desc->full_name()));
710 string format;
711 OP_REQUIRES_OK(context, context->GetAttr("message_format", &format));
712 OP_REQUIRES(
713 context, format == "binary" || format == "text",
714 errors::InvalidArgument("format must be one of binary or text"));
715 is_binary_ = format == "binary";
716
717 // Enable the initial protobuf sanitizer, which is much more expensive than
718 // the decoder.
719 // TODO(nix): Remove this once the fast decoder has passed security review.
720 OP_REQUIRES_OK(context, context->GetAttr("sanitize", &sanitize_));
721 }
722
723 void Compute(OpKernelContext* ctx) override {
724 const Tensor& buf_tensor = ctx->input(0);
725 int message_count = buf_tensor.NumElements();
726 OP_REQUIRES(ctx, message_count >= 1,
727 errors::InvalidArgument(
728 "Bufs argument must contain at least one value"));
729
730 int field_count = fields_.size();
731
732 // Save the argument shape for later, then flatten the input Tensor since we
733 // are working componentwise. We will restore the same shape in the returned
734 // Tensor.
735 const TensorShape& shape_prefix = buf_tensor.shape();
736
737 TensorShape sizes_shape = shape_prefix;
738 sizes_shape.AddDim(field_count);
739 Tensor* sizes_tensor = nullptr;
740 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, sizes_shape, &sizes_tensor));
741
742 // This is used to allocate binary bufs if used. It serves only to define
743 // memory ownership.
744 std::vector<tstring> tmp_binary_bufs(message_count);
745
746 // These are the actual buffers to use, which may be in tmp_binary_bufs
747 // or may be pointers into the buf_tensor. Either way they are not owned
748 // here.
749 std::vector<const tstring*> bufs;
750
751 if (is_binary_ && !sanitize_) {
752 // Fast path.
753 for (int mi = 0; mi < message_count; ++mi) {
754 const tstring* buf = &buf_tensor.flat<tstring>()(mi);
755 bufs.push_back(buf);
756 }
757 } else {
758 // We will have to allocate a copy, either to convert from text to binary
759 // or to sanitize a binary proto.
760 for (int mi = 0; mi < message_count; ++mi) {
761 ReserializeMessage(ctx, buf_tensor.flat<tstring>()(mi),
762 &tmp_binary_bufs[mi]);
763 if (!ctx->status().ok()) {
764 return;
765 }
766 bufs.push_back(&tmp_binary_bufs[mi]);
767 }
768 }
769
770 // Walk through all the strings in the input tensor, counting the number of
771 // fields in each. We can't allocate our actual output Tensor until we know
772 // the maximum repeat count, so we do a first pass through the serialized
773 // proto just counting fields. We always allocate at least one value so that
774 // optional fields are populated with default values - this avoids a TF
775 // conditional when handling the output data. The caller can distinguish
776 // between real data and defaults using the repeat count matrix that is
777 // returned by decode_proto.
778 std::vector<int32> max_sizes(field_count, 1);
779 for (int mi = 0; mi < message_count; ++mi) {
780 CountFields(ctx, mi, *bufs[mi], sizes_tensor, &max_sizes);
781 if (!ctx->status().ok()) {
782 return;
783 }
784 }
785
786 // Allocate the output tensors now that we've seen the max size.
787 // TODO(nix): Use allocate_output_or_forward_input for the largest
788 // output tensor. This can avoid one large allocation by re-using
789 // the memory of the input tensor.
790 std::vector<Tensor*> outputs(field_count);
791 for (int fi = 0; fi < field_count; ++fi) {
792 TensorShape flat_shape = {static_cast<int64_t>(message_count),
793 max_sizes[fi]};
794 TensorShape out_shape = shape_prefix;
795 out_shape.AddDim(max_sizes[fi]);
796
797 // Surprisingly we don't specify the types from the output_types
798 // attribute: that is done for us based on the Op declaration:
799 // REGISTER_OP(...)
800 // .Attr("output_types: list(type) >= 0")
801 // .Output("values: output_types")
802 OP_REQUIRES_OK(ctx, ctx->allocate_output(fields_[fi]->output_index + 1,
803 out_shape, &outputs[fi]));
804 }
805
806 // Make the second pass through the serialized proto, decoding into
807 // preallocated tensors.
808 AccumulateFields(ctx, bufs, outputs);
809 }
810
811 private:
812 // Copy a serialized message to binary, e.g. to handle text proto inputs.
813 void ReserializeMessage(OpKernelContext* ctx, const tstring& buf,
814 tstring* binary_buf) {
815 // Handle text protos by translating them to binary.
816 std::unique_ptr<Message> message(message_prototype_->New());
817 OP_REQUIRES(ctx, message, errors::DataLoss("Initializing message failed"));
818
819 if (is_binary_) {
820 // If we get here we are sanitizing the input protobuf by parsing
821 // and reserializing it with a trusted (but very slow) library.
822 OP_REQUIRES(ctx, message->ParseFromString(buf),
823 errors::DataLoss("Unable to parse binary protobuf"));
824 } else {
825 OP_REQUIRES(ctx, TextFormat::ParseFromString(buf, message.get()),
826 errors::DataLoss("Unable to parse text protobuf"));
827 }
828
829 OP_REQUIRES(ctx, SerializeToTString(*message, binary_buf),
830 errors::DataLoss("Unable to reserialize text proto as binary"));
831 }
832
833 // Count the number of occurrences of each requested field in a message batch.
834 void CountFields(OpKernelContext* ctx, int message_index, const tstring& buf,
835 Tensor* sizes_tensor, std::vector<int32>* max_sizes) {
836 int field_count = fields_.size();
837
838 CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
839 buf.size());
840
841 std::vector<int32> field_sizes(field_count, 0);
842 std::vector<CountCollector> counters;
843 counters.reserve(field_count);
844 for (int i = 0; i < field_count; i++) {
845 counters.emplace_back(&field_sizes[i]);
846 }
847
848 Status st = Collect(&input, absl::MakeSpan(counters));
849 if (st.ok() && !input.ConsumedEntireMessage()) {
850 st = errors::DataLoss("CountFields: Failed to consume entire buffer");
851 }
852 if (kFailOnDecodeError) {
853 OP_REQUIRES_OK(ctx, st); // NOLINT
854 }
855 if (!st.ok()) {
856 // This code suppresses the corrupt proto, treating it as empty
857 // to avoid crashing the process.
858 LOG(WARNING) << "Proto counting error for message type " << message_type_
859 << ": " << st;
860
861 for (int fi = 0; fi < field_count; fi++) {
862 field_sizes[fi] = 0;
863 }
864 // Finished decoding this message.
865 return;
866 }
867
868 // Update the size tensor and max repeat size for each field.
869 auto sizes = sizes_tensor->flat_inner_dims<int32>();
870 for (int fi = 0; fi < field_count; fi++) {
871 int32_t size = field_sizes[fi];
872 sizes(message_index, fields_[fi]->output_index) = size;
873 if ((*max_sizes)[fi] < size) {
874 (*max_sizes)[fi] = size;
875 }
876 }
877 }
878
879 // Parse fields from a serialized message into preallocated tensors.
880 void AccumulateFields(OpKernelContext* ctx,
881 const std::vector<const tstring*>& bufs,
882 std::vector<Tensor*> outputs) {
883 struct TensorInfo {
884 explicit TensorInfo(Tensor* tensor) {
885 // Note that we can decode only max_repeat_count values before overflow.
886 // No other bounds checking is done for repeated fields. For
887 // optional fields there is a check to make sure that only the last
888 // value on the wire appears in the output tensor.
889 dtype = tensor->dtype();
890 last_dim_size = tensor->dim_size(tensor->dims() - 1);
891
892 if (dtype != DT_STRING) {
893 const int element_size = DataTypeSize(dtype);
894 CHECK_GT(element_size, 0);
895 stride = last_dim_size * element_size;
896
897 const int64_t flatshape[1] = {tensor->NumElements() * element_size};
898 data = tensor->bit_casted_shaped<uint8, 1>(flatshape).data();
899 } else {
900 // DataTypeSize() returns 0 for string types.
901 stride = last_dim_size * sizeof(tstring);
902 data = reinterpret_cast<uint8*>(tensor->flat<tstring>().data());
903 }
904 }
905
906 DataType dtype;
907 int last_dim_size;
908 int stride;
909 uint8* data;
910 };
911
912 int field_count = fields_.size();
913
914 std::vector<TensorInfo> tensors;
915 tensors.reserve(field_count);
916 for (int fi = 0; fi < field_count; fi++) {
917 tensors.emplace_back(outputs[fi]);
918 }
919
920 for (int message_index = 0; message_index < bufs.size(); ++message_index) {
921 const tstring& buf = *bufs[message_index];
922
923 std::vector<DenseCollector> collectors;
924 collectors.reserve(field_count);
925 for (int output_index = 0; output_index < field_count; ++output_index) {
926 const TensorInfo& info = tensors[output_index];
927 const FieldInfo* field_info = fields_[output_index].get();
928 DCHECK(field_info != nullptr);
929 const DefaultValue default_value = field_info->default_value;
930 collectors.emplace_back(info.data + message_index * info.stride,
931 default_value, info.last_dim_size);
932 }
933
934 // Fill in output tensors from the wire.
935 CodedInputStream input(reinterpret_cast<const uint8*>(buf.c_str()),
936 buf.size());
937 Status st = Collect(&input, absl::MakeSpan(collectors));
938 if (st.ok() && !input.ConsumedEntireMessage()) {
939 st = errors::DataLoss(
940 "AccumulateFields: Failed to consume entire buffer");
941 }
942 if (kFailOnDecodeError) {
943 OP_REQUIRES_OK(ctx, st); // NOLINT
944 }
945 if (!st.ok()) {
946 // This code suppresses the corrupt proto, treating it as empty
947 // to avoid crashing training.
948 LOG(WARNING) << "Proto counting error for message type "
949 << message_type_ << ": " << st;
950 }
951
952 // Fill the remainder of the dense outputs with default values.
953 for (auto& collector : collectors) {
954 OP_REQUIRES_OK(ctx, collector.FillWithDefaults());
955 }
956 }
957 }
958
  // Traverses a serialized protobuf, dispatching values to the collectors.
  // There is exactly one collector per entry of fields_, in the same order.
  // The scan stops cleanly at end of input (tag 0) or at an END_GROUP tag;
  // it returns a DataLoss error if an unrequested field cannot be skipped,
  // and propagates any error from CollectField.
  template <class CollectorClass>
  Status Collect(CodedInputStream* input,
                 absl::Span<CollectorClass> collectors) {
    // At the beginning of each loop, the last field number that was seen,
    // regardless of whether it was collected or not, or -1 if no field has
    // been seen before.
    int last_seen_field_number = -1;
    // The FieldInfo that is expected to be used next.
    // It was either used to collect the last seen field number, or if the
    // last seen field number was not in fields_, it is the next FieldInfo after
    // the last seen field number. At the beginning it is the first FieldInfo.
    auto expected_field_info_iter = fields_.begin();

    // The 'tag' variable should always be treated as tainted.
    for (uint32 tag = input->ReadTag();
         tag != 0 && WireFormatLite::GetTagWireType(tag) !=
                         WireFormatLite::WIRETYPE_END_GROUP;
         tag = input->ReadTag()) {
      // Loop invariants: the last seen field number lies strictly after the
      // FieldInfo preceding the expected one, and at or before the expected
      // one itself (when those neighbors exist).
      DCHECK(expected_field_info_iter == fields_.begin() ||
             last_seen_field_number >
                 (*(expected_field_info_iter - 1))->number);
      DCHECK(expected_field_info_iter == fields_.end() ||
             last_seen_field_number <= (*expected_field_info_iter)->number);

      // The field wire number.
      const int field_number = WireFormatLite::GetTagFieldNumber(tag);
      // The field info associated with the field wire number.
      const FieldInfo* field_info = nullptr;

      // fields_ are ordered by their field numbers. If the field numbers
      // on wire are also ordered (which is a convention), then we can
      // monotonically increment `expected_field_info_iter` as the field
      // numbers on wire get larger. If we detect any out-of-order
      // field number, we reset `expected_field_info_iter`, and expect that
      // future wire numbers are ordered. This algorithm is quadratic in the
      // worst case where field numbers on wire are in descending order, however
      // it works well in the case where two serialized protobufs are
      // concatenated together.
      if (field_number < last_seen_field_number) {
        expected_field_info_iter = fields_.begin();
      }

      // Advance expected_field_info_iter until
      // field_number <= expected_field_number.
      for (; expected_field_info_iter != fields_.end();
           ++expected_field_info_iter) {
        DCHECK(expected_field_info_iter == fields_.begin() ||
               field_number > (*(expected_field_info_iter - 1))->number);
        const FieldInfo* expected_field_info = expected_field_info_iter->get();
        if (field_number <= expected_field_info->number) {
          if (field_number == expected_field_info->number) {
            field_info = expected_field_info;
          }
          break;
        }
      }
      last_seen_field_number = field_number;
      if (!field_info) {
        // This DCHECK verifies that if we skip a field, we didn't want it.
        // In particular, field_builders is empty or the field_number is either:
        // before fields_.begin().number or after (fields_.end() - 1).number or
        // in-between expected_field_info_iter and expected_field_info_iter - 1.
        DCHECK(fields_.empty() || (field_number < (*fields_.begin())->number) ||
               (field_number > (*(fields_.end() - 1))->number) ||
               (((*(expected_field_info_iter - 1))->number < field_number) &&
                (field_number < (*(expected_field_info_iter))->number)));
        // Unknown and unrequested fields are skipped.
        if (!WireFormatLite::SkipField(input, tag)) {
          return errors::DataLoss("Failed skipping unrequested field");
        }
        continue;
      }

      // `expected_field_info_iter` points at the matched FieldInfo, so its
      // offset into fields_ selects the corresponding collector.
      TF_RETURN_IF_ERROR(CollectField(
          *field_info, WireFormatLite::GetTagWireType(tag), input,
          &collectors[expected_field_info_iter - fields_.begin()]));
    }
    return OkStatus();
  }
1039
1040 // Collects values for a single field.
1041 template <class CollectorClass>
1042 Status CollectField(const FieldInfo& field,
1043 WireFormatLite::WireType wire_type,
1044 CodedInputStream* input, CollectorClass* collector) {
1045 // The wire format library defines the same constants used in
1046 // descriptor.proto. This static_cast is safe because they are guaranteed to
1047 // stay in sync.
1048 //
1049 // We need the field type from the FieldDescriptor here because the wire
1050 // format doesn't tell us anything about what happens inside a packed
1051 // repeated field: there is enough information in the wire format to skip
1052 // the whole field but not enough to know how to parse what's inside. For
1053 // that we go to the schema.
1054 WireFormatLite::WireType schema_wire_type =
1055 WireFormatLite::WireTypeForFieldType(field.type);
1056
1057 // Handle packed repeated fields. SkipField would skip the whole
1058 // length-delimited blob without letting us count the values, so we have to
1059 // scan them ourselves.
1060 if (wire_type == WireFormatLite::WIRETYPE_LENGTH_DELIMITED &&
1061 schema_wire_type != WireFormatLite::WIRETYPE_LENGTH_DELIMITED) {
1062 // Handle packed repeated primitives.
1063 int length;
1064 if (!input->ReadVarintSizeAsInt(&length)) {
1065 return errors::DataLoss("CollectField: Failed reading packed size");
1066 }
1067 return collector->ReadPackedValues(input, field, length);
1068 }
1069
1070 // Read ordinary values, including strings, bytes, and messages.
1071 if (wire_type != schema_wire_type) {
1072 if (!WireFormatLite::SkipField(
1073 input, WireFormatLite::MakeTag(field.number, wire_type))) {
1074 return errors::DataLoss(
1075 "CollectField: Failed skipping malformed field");
1076 }
1077 return OkStatus();
1078 }
1079 return collector->ReadValue(input, field);
1080 }
1081
  // Name of the proto message type being decoded; used in log messages.
  string message_type_;
  // Note that fields are sorted by increasing field number, which is not in
  // general the order given by the user-specified field_names and output_types
  // Op attributes.
  std::vector<std::unique_ptr<const FieldInfo>> fields_;

  // Owned_desc_pool_ is null when using descriptor_source=local.
  std::unique_ptr<DescriptorPool> owned_desc_pool_;
  // Factory for dynamic messages of the configured type.
  DynamicMessageFactory message_factory_;
  // Prototype cloned by ReserializeMessage() to create scratch messages.
  // NOTE(review): presumably owned by message_factory_ (standard
  // DynamicMessageFactory semantics) — confirm where it is assigned, which
  // is outside this view.
  const Message* message_prototype_;

  // True if decoding binary format, false if decoding text format.
  bool is_binary_;

  // True if the protos should be sanitized before parsing. Enables the initial
  // protobuf sanitizer, which is much more expensive than the decoder. The flag
  // defaults to true but can be set to false for trusted sources.
  //
  // TODO(nix): Flip the default to false when the fast decoder has passed
  // security review.
  bool sanitize_;

  TF_DISALLOW_COPY_AND_ASSIGN(DecodeProtoOp);
1105};
1106
// Register the CPU kernel implementing the DecodeProtoV2 op.
REGISTER_KERNEL_BUILDER(Name("DecodeProtoV2").Device(DEVICE_CPU),
                        DecodeProtoOp);
1109
1110} // namespace
1111} // namespace tensorflow
1112