1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16// EncodeProto is a TensorFlow Op which serializes tensors into
17// arbitrary protobufs.
18//
19// See the docstring in ../ops/encode_proto_op.cc for usage of the op.
20//
21// This implementation writes the serialized format using a handful of
22// calls from the WireFormatLite API.
23
#include <algorithm>
#include <memory>
#include <numeric>
#include <vector>

#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/proto/descriptors.h"
#include "tensorflow/core/util/proto/proto_utils.h"
35
36namespace tensorflow {
37namespace {
38
39using ::tensorflow::protobuf::Descriptor;
40using ::tensorflow::protobuf::DescriptorPool;
41using ::tensorflow::protobuf::FieldDescriptor;
42using ::tensorflow::protobuf::internal::WireFormatLite;
43using ::tensorflow::protobuf::io::CodedOutputStream;
44using ::tensorflow::protobuf::io::StringOutputStream;
45
// Computes the total serialized size for a packed repeated field. For
// fixed-size types this can just multiply, but for variable-sized types it has
// to iterate through the values in the tensor.
//
// Template parameters:
//   FieldType - the proto wire-format field type being written.
//   TensorT   - the element type of the tensor holding the values.
//
// Arguments:
//   input         - tensor holding the values; read via flat_inner_dims by the
//                   variable-size specializations.
//   message_index - row (message) of the flattened tensor to measure.
//   size          - number of leading elements in that row to include.
//
// Only the declaration is given here; this file supplies explicit
// specializations for the (FieldType, TensorT) pairs it uses.
template <WireFormatLite::FieldType FieldType, typename TensorT>
size_t TotalPackedSize(const Tensor& input, int message_index, int size);
51
52template <>
53size_t TotalPackedSize<WireFormatLite::TYPE_DOUBLE, double>(const Tensor& input,
54 int message_index,
55 int size) {
56 return size * WireFormatLite::kDoubleSize;
57}
58
59template <>
60size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, double>(const Tensor& input,
61 int message_index,
62 int size) {
63 return size * WireFormatLite::kFloatSize;
64}
65
66template <>
67size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, float>(const Tensor& input,
68 int message_index,
69 int size) {
70 return size * WireFormatLite::kFloatSize;
71}
72
73template <>
74size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64_t>(const Tensor& input,
75 int message_index,
76 int size) {
77 size_t data_size = 0;
78 auto input_t = input.flat_inner_dims<int64_t>();
79 for (int64_t i = 0; i < size; i++) {
80 data_size += WireFormatLite::Int64Size(
81 input_t(static_cast<int64_t>(message_index), i));
82 }
83 return data_size;
84}
85
86template <>
87size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, uint64>(const Tensor& input,
88 int message_index,
89 int size) {
90 size_t data_size = 0;
91 auto input_t = input.flat_inner_dims<uint64>();
92 for (int64_t i = 0; i < size; i++) {
93 data_size += WireFormatLite::UInt64Size(
94 input_t(static_cast<int64_t>(message_index), i));
95 }
96 return data_size;
97}
98
99template <>
100size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int64_t>(const Tensor& input,
101 int message_index,
102 int size) {
103 size_t data_size = 0;
104 auto input_t = input.flat_inner_dims<int64_t>();
105 for (int64_t i = 0; i < size; i++) {
106 data_size += WireFormatLite::Int32Size(
107 input_t(static_cast<int64_t>(message_index), i));
108 }
109 return data_size;
110}
111
112template <>
113size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input,
114 int message_index,
115 int size) {
116 size_t data_size = 0;
117 auto input_t = input.flat_inner_dims<int32>();
118 for (int64_t i = 0; i < size; i++) {
119 data_size += WireFormatLite::Int32Size(
120 input_t(static_cast<int64_t>(message_index), i));
121 }
122 return data_size;
123}
124
125template <>
126size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, uint64>(
127 const Tensor& input, int message_index, int size) {
128 return size * WireFormatLite::kFixed64Size;
129}
130
131template <>
132size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint64>(
133 const Tensor& input, int message_index, int size) {
134 return size * WireFormatLite::kFixed32Size;
135}
136
137template <>
138size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint32>(
139 const Tensor& input, int message_index, int size) {
140 return size * WireFormatLite::kFixed32Size;
141}
142
143template <>
144size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input,
145 int message_index,
146 int size) {
147 return size * WireFormatLite::kBoolSize;
148}
149
150template <>
151size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint64>(const Tensor& input,
152 int message_index,
153 int size) {
154 size_t data_size = 0;
155 auto input_t = input.flat_inner_dims<uint64>();
156 for (int64_t i = 0; i < size; i++) {
157 data_size += WireFormatLite::UInt32Size(
158 input_t(static_cast<int64_t>(message_index), i));
159 }
160 return data_size;
161}
162
163template <>
164size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint32>(const Tensor& input,
165 int message_index,
166 int size) {
167 size_t data_size = 0;
168 auto input_t = input.flat_inner_dims<uint32>();
169 for (int64_t i = 0; i < size; i++) {
170 data_size += WireFormatLite::UInt32Size(
171 input_t(static_cast<int64_t>(message_index), i));
172 }
173 return data_size;
174}
175
176template <>
177size_t TotalPackedSize<WireFormatLite::TYPE_ENUM, int32>(const Tensor& input,
178 int message_index,
179 int size) {
180 size_t data_size = 0;
181 auto input_t = input.flat_inner_dims<int32>();
182 for (int64_t i = 0; i < size; i++) {
183 data_size += WireFormatLite::EnumSize(
184 input_t(static_cast<int64_t>(message_index), i));
185 }
186 return data_size;
187}
188
189template <>
190size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>(
191 const Tensor& input, int message_index, int size) {
192 return size * WireFormatLite::kSFixed32Size;
193}
194
195template <>
196size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int64_t>(
197 const Tensor& input, int message_index, int size) {
198 return size * WireFormatLite::kSFixed32Size;
199}
200
201template <>
202size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64_t>(
203 const Tensor& input, int message_index, int size) {
204 return size * WireFormatLite::kSFixed64Size;
205}
206
207template <>
208size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input,
209 int message_index,
210 int size) {
211 size_t data_size = 0;
212 auto input_t = input.flat_inner_dims<int32>();
213 for (int64_t i = 0; i < size; i++) {
214 data_size += WireFormatLite::SInt32Size(
215 input_t(static_cast<int64_t>(message_index), i));
216 }
217 return data_size;
218}
219
220template <>
221size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int64_t>(
222 const Tensor& input, int message_index, int size) {
223 size_t data_size = 0;
224 auto input_t = input.flat_inner_dims<int64_t>();
225 for (int64_t i = 0; i < size; i++) {
226 data_size += WireFormatLite::SInt32Size(
227 input_t(static_cast<int64_t>(message_index), i));
228 }
229 return data_size;
230}
231
232template <>
233size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64_t>(
234 const Tensor& input, int message_index, int size) {
235 size_t data_size = 0;
236 auto input_t = input.flat_inner_dims<int64_t>();
237 for (int64_t i = 0; i < size; i++) {
238 data_size += WireFormatLite::SInt64Size(
239 input_t(static_cast<int64_t>(message_index), i));
240 }
241 return data_size;
242}
243
244// Writes a possibly repeated primitive field. TensorFlow does not have unsigned
245// types, so we decode them to signed and encode them back to unsigned.
246template <typename TensorT, typename ProtoT,
247 WireFormatLite::FieldType FieldType,
248 void Writer(ProtoT, CodedOutputStream*)>
249Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
250 int message_index, int size, CodedOutputStream* output) {
251 auto wire_type = WireFormatLite::WireTypeForFieldType(
252 WireFormatLite::FieldType(field_desc.type()));
253
254 auto input_t = input.flat_inner_dims<TensorT>();
255 if (field_desc.options().packed()) {
256 // Write the tag for the packed field.
257 WireFormatLite::WriteTag(field_desc.number(),
258 WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output);
259
260 // Write the total packed length.
261 size_t data_size =
262 TotalPackedSize<FieldType, TensorT>(input, message_index, size);
263 output->WriteVarint32(data_size);
264
265 // Write individual values.
266 for (int64_t i = 0; i < size; i++) {
267 // Note implicit cast from signed to unsigned.
268 const ProtoT& value = input_t(static_cast<int64_t>(message_index), i);
269 Writer(value, output);
270 }
271 } else {
272 for (int64_t i = 0; i < size; i++) {
273 WireFormatLite::WriteTag(field_desc.number(), wire_type, output);
274
275 // Note implicit cast from signed to unsigned.
276 const ProtoT& value = input_t(static_cast<int64_t>(message_index), i);
277 Writer(value, output);
278 }
279 }
280 return OkStatus();
281}
282
283// Writes a possibly repeated string, bytes, or message field.
284template <typename T, void Writer(int, const T&, CodedOutputStream*)>
285Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input,
286 int message_index, int size,
287 CodedOutputStream* output) {
288 auto input_t = input.flat_inner_dims<T>();
289 for (int64_t i = 0; i < size; i++) {
290 const T& value = input_t(static_cast<int64_t>(message_index), i);
291 // TODO(nix): there doesn't seem to be an inlined version of
292 // WireFormatLite::WriteString or its relatives, which might allow a
293 // small speedup.
294 Writer(field_desc.number(), value, output);
295 }
296 return OkStatus();
297}
298
// Adapter that gives WireFormatLite::WriteString the
// `void(int, const tstring&, CodedOutputStream*)` shape required by the Writer
// template argument of WriteVarLenField. Writes `value` as field
// `field_number` to `output`.
static void WriteStringAdapter(int field_number, const tstring& value,
                               CodedOutputStream* output) {
  // Unfortunately, external proto does not accept string_view.
#if defined(PLATFORM_GOOGLE)
  // With PLATFORM_GOOGLE defined, the value is passed as a StringPiece.
  WireFormatLite::WriteString(field_number, StringPiece(value), output);
#else
  // Otherwise the value is first converted to a `string`.
  WireFormatLite::WriteString(field_number, string(value), output);
#endif
}
308
// Adapter that gives WireFormatLite::WriteBytes the
// `void(int, const tstring&, CodedOutputStream*)` shape required by the Writer
// template argument of WriteVarLenField. Writes `value` as field
// `field_number` to `output`.
static void WriteBytesAdapter(int field_number, const tstring& value,
                              CodedOutputStream* output) {
  // Unfortunately, external proto does not accept string_view.
#if defined(PLATFORM_GOOGLE)
  // With PLATFORM_GOOGLE defined, the value is passed as a StringPiece.
  WireFormatLite::WriteBytes(field_number, StringPiece(value), output);
#else
  // Otherwise the value is first converted to a `string`.
  WireFormatLite::WriteBytes(field_number, string(value), output);
#endif
}
318
319// Writes a group field. Groups are treated like submessages, but tag-delimited
320// instead of length-delimited. WireFormatLite handles this differently so we
321// code it ourselves.
322Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input,
323 int message_index, int size, CodedOutputStream* output) {
324 auto input_t = input.flat_inner_dims<tstring>();
325 for (int64_t i = 0; i < size; i++) {
326 const string& value = input_t(static_cast<int64_t>(message_index), i);
327 WireFormatLite::WriteTag(field_desc.number(),
328 WireFormatLite::WIRETYPE_START_GROUP, output);
329 // Note the use of WriteRaw instead of WriteString to skip the length.
330 output->WriteRaw(value.data(), value.size());
331 WireFormatLite::WriteTag(field_desc.number(),
332 WireFormatLite::WIRETYPE_END_GROUP, output);
333 }
334 return OkStatus();
335}
336
337// Writes a (possibly repeated) field into an output stream. It is the caller's
338// responsibility to ensure that the type of the input tensor is compatible with
339// the type of the proto field descriptor, and that (message_index, size-1) is
340// within bounds.
341Status WriteField(const FieldDescriptor& field_desc, const Tensor& input,
342 int message_index, int size, CodedOutputStream* output) {
343 DataType dtype = input.dtype();
344
345 switch (field_desc.type()) {
346 case WireFormatLite::TYPE_DOUBLE:
347 return WriteField<double, double, WireFormatLite::TYPE_DOUBLE,
348 WireFormatLite::WriteDoubleNoTag>(
349 field_desc, input, message_index, size, output);
350 case WireFormatLite::TYPE_FLOAT:
351 switch (dtype) {
352 case DataType::DT_FLOAT:
353 return WriteField<float, float, WireFormatLite::TYPE_FLOAT,
354 WireFormatLite::WriteFloatNoTag>(
355 field_desc, input, message_index, size, output);
356 case DataType::DT_DOUBLE:
357 return WriteField<double, float, WireFormatLite::TYPE_FLOAT,
358 WireFormatLite::WriteFloatNoTag>(
359 field_desc, input, message_index, size, output);
360 default:
361 return errors::DataLoss("Failed writing TYPE_FLOAT for ",
362 DataTypeString(dtype));
363 }
364 case WireFormatLite::TYPE_INT64:
365 return WriteField<int64_t, protobuf_int64, WireFormatLite::TYPE_INT64,
366 WireFormatLite::WriteInt64NoTag>(
367 field_desc, input, message_index, size, output);
368 case WireFormatLite::TYPE_UINT64:
369 return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_UINT64,
370 WireFormatLite::WriteUInt64NoTag>(
371 field_desc, input, message_index, size, output);
372 case WireFormatLite::TYPE_INT32:
373 switch (dtype) {
374 case DataType::DT_INT64:
375 return WriteField<int64_t, int32, WireFormatLite::TYPE_INT32,
376 WireFormatLite::WriteInt32NoTag>(
377 field_desc, input, message_index, size, output);
378 case DataType::DT_INT32:
379 return WriteField<int32, int32, WireFormatLite::TYPE_INT32,
380 WireFormatLite::WriteInt32NoTag>(
381 field_desc, input, message_index, size, output);
382 default:
383 return errors::DataLoss("Failed writing TYPE_INT32 for ",
384 DataTypeString(dtype));
385 }
386 case WireFormatLite::TYPE_FIXED64:
387 return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_FIXED64,
388 WireFormatLite::WriteFixed64NoTag>(
389 field_desc, input, message_index, size, output);
390 case WireFormatLite::TYPE_FIXED32:
391 switch (dtype) {
392 case DataType::DT_UINT64:
393 return WriteField<uint64, uint32, WireFormatLite::TYPE_FIXED32,
394 WireFormatLite::WriteFixed32NoTag>(
395 field_desc, input, message_index, size, output);
396 case DataType::DT_UINT32:
397 return WriteField<uint32, uint32, WireFormatLite::TYPE_FIXED32,
398 WireFormatLite::WriteFixed32NoTag>(
399 field_desc, input, message_index, size, output);
400 default:
401 return errors::DataLoss("Failed writing TYPE_FIXED32 for ",
402 DataTypeString(dtype));
403 }
404 case WireFormatLite::TYPE_BOOL:
405 return WriteField<bool, bool, WireFormatLite::TYPE_BOOL,
406 WireFormatLite::WriteBoolNoTag>(
407 field_desc, input, message_index, size, output);
408 case WireFormatLite::TYPE_STRING:
409 return WriteVarLenField<tstring, WriteStringAdapter>(
410 field_desc, input, message_index, size, output);
411 case WireFormatLite::TYPE_GROUP:
412 return WriteGroup(field_desc, input, message_index, size, output);
413 case WireFormatLite::TYPE_MESSAGE:
414 return WriteVarLenField<tstring, WriteBytesAdapter>(
415 field_desc, input, message_index, size, output);
416 case WireFormatLite::TYPE_BYTES:
417 return WriteVarLenField<tstring, WriteBytesAdapter>(
418 field_desc, input, message_index, size, output);
419 case WireFormatLite::TYPE_UINT32:
420 switch (dtype) {
421 case DataType::DT_UINT64:
422 return WriteField<uint64, uint32, WireFormatLite::TYPE_UINT32,
423 WireFormatLite::WriteUInt32NoTag>(
424 field_desc, input, message_index, size, output);
425 case DataType::DT_UINT32:
426 return WriteField<uint32, uint32, WireFormatLite::TYPE_UINT32,
427 WireFormatLite::WriteUInt32NoTag>(
428 field_desc, input, message_index, size, output);
429 default:
430 return errors::DataLoss("Failed writing TYPE_UINT32 for ",
431 DataTypeString(dtype));
432 }
433 case WireFormatLite::TYPE_ENUM:
434 return WriteField<int32, int32, WireFormatLite::TYPE_ENUM,
435 WireFormatLite::WriteEnumNoTag>(
436 field_desc, input, message_index, size, output);
437 case WireFormatLite::TYPE_SFIXED32:
438 switch (dtype) {
439 case DataType::DT_INT64:
440 return WriteField<int64_t, int32, WireFormatLite::TYPE_SFIXED32,
441 WireFormatLite::WriteSFixed32NoTag>(
442 field_desc, input, message_index, size, output);
443 case DataType::DT_INT32:
444 return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32,
445 WireFormatLite::WriteSFixed32NoTag>(
446 field_desc, input, message_index, size, output);
447 default:
448 return errors::DataLoss("Failed writing TYPE_SFIXED32 for ",
449 DataTypeString(dtype));
450 }
451 case WireFormatLite::TYPE_SFIXED64:
452 return WriteField<int64_t, protobuf_int64, WireFormatLite::TYPE_SFIXED64,
453 WireFormatLite::WriteSFixed64NoTag>(
454 field_desc, input, message_index, size, output);
455 case WireFormatLite::TYPE_SINT32:
456 switch (dtype) {
457 case DataType::DT_INT64:
458 return WriteField<int64_t, int32, WireFormatLite::TYPE_SINT32,
459 WireFormatLite::WriteSInt32NoTag>(
460 field_desc, input, message_index, size, output);
461 case DataType::DT_INT32:
462 return WriteField<int32, int32, WireFormatLite::TYPE_SINT32,
463 WireFormatLite::WriteSInt32NoTag>(
464 field_desc, input, message_index, size, output);
465 default:
466 return errors::DataLoss("Failed writing TYPE_SINT32 for ",
467 DataTypeString(dtype));
468 }
469 case WireFormatLite::TYPE_SINT64:
470 return WriteField<int64_t, protobuf_int64, WireFormatLite::TYPE_SINT64,
471 WireFormatLite::WriteSInt64NoTag>(
472 field_desc, input, message_index, size, output);
473 // default: intentionally omitted in order to enable static checking.
474 }
475}
476
477class EncodeProtoOp : public OpKernel {
478 public:
479 explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) {
480 string descriptor_source;
481 OP_REQUIRES_OK(context,
482 context->GetAttr("descriptor_source", &descriptor_source));
483 // We always get back a desc_pool, but we may not own it. If we own it,
484 // owned_desc_pool_ will be filled in.
485 DescriptorPool const* desc_pool;
486 OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source,
487 &desc_pool, &owned_desc_pool_));
488
489 string message_type;
490 OP_REQUIRES_OK(context, context->GetAttr("message_type", &message_type));
491 const Descriptor* message_desc =
492 desc_pool->FindMessageTypeByName(message_type);
493 OP_REQUIRES(context, message_desc != nullptr,
494 errors::InvalidArgument("No descriptor found for message type ",
495 message_type));
496
497 OP_REQUIRES_OK(context, context->GetAttr("field_names", &field_names_));
498
499 // Gather the field descriptors for the given field_names.
500 field_descs_.resize(field_names_.size());
501 for (int i = 0; i < field_names_.size(); i++) {
502 const string& name = field_names_[i];
503 auto field_desc = message_desc->FindFieldByName(name);
504 OP_REQUIRES(context, field_desc != nullptr,
505 errors::InvalidArgument("Unknown field: ", name,
506 " in message type ", message_type));
507
508 field_descs_[i] = field_desc;
509 }
510
511 // Build a list of indices into field_descs sorted by increasing
512 // field_number. This will be used to output fields in sorted order,
513 // which is strongly encouraged when serializing protobufs.
514 sorted_field_index_.resize(field_names_.size());
515 // Start with the fields sorted by current index.
516 for (int i = 0; i < field_names_.size(); i++) sorted_field_index_[i] = i;
517 // Then sort the field indices by their proto field number.
518 std::sort(sorted_field_index_.begin(), sorted_field_index_.end(),
519 [this](int a, int b) -> bool {
520 return field_descs_[a]->number() < field_descs_[b]->number();
521 });
522 }
523
524 void Compute(OpKernelContext* ctx) override {
525 const Tensor* sizes_tensor;
526 OP_REQUIRES_OK(ctx, ctx->input("sizes", &sizes_tensor));
527
528 OpInputList values;
529 OP_REQUIRES_OK(ctx, ctx->input_list("values", &values));
530
531 OP_REQUIRES(ctx, field_descs_.size() == values.size(),
532 errors::InvalidArgument(
533 "Length of inputs list must match field_names"));
534
535 // Check the arguments for consistency.
536 TensorShape common_prefix;
537 int message_count = 0;
538 for (int i = 0; i < field_descs_.size(); i++) {
539 const Tensor& v = values[i];
540
541 // The type of each value tensor must match the corresponding field.
542 OP_REQUIRES(
543 ctx,
544 proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()),
545 errors::InvalidArgument(
546 "Incompatible type for field ", field_names_[i],
547 ". Saw dtype: ", DataTypeString(v.dtype()),
548 " but field type is: ", field_descs_[i]->type_name()));
549
550 OP_REQUIRES(
551 ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()),
552 errors::InvalidArgument("Invalid shape for field ", field_names_[i],
553 ". Saw shape ", v.shape().DebugString(),
554 " but it should be at least a matrix."));
555
556 // All value tensors must have the same shape prefix (i.e. batch size).
557 TensorShape shape_prefix = v.shape();
558 shape_prefix.RemoveDim(shape_prefix.dims() - 1);
559
560 // Do some initialization on the first input value. The rest will
561 // have to match this one.
562 if (i == 0) {
563 OP_REQUIRES(ctx, v.dims() >= 1,
564 errors::InvalidArgument(
565 "Expected value to be at least a vector, saw shape: ",
566 v.shape().DebugString()));
567 common_prefix = shape_prefix;
568 message_count = common_prefix.num_elements();
569 } else {
570 OP_REQUIRES(ctx, shape_prefix == common_prefix,
571 errors::InvalidArgument(
572 "Values must match up to the last dimension"));
573 }
574 }
575
576 TensorShape expected_sizes_shape = common_prefix;
577 expected_sizes_shape.AddDim(field_descs_.size());
578
579 OP_REQUIRES(ctx, sizes_tensor->shape() == expected_sizes_shape,
580 errors::InvalidArgument(
581 "sizes should be batch_size + [len(field_names)]. Saw: ",
582 sizes_tensor->shape().DebugString(),
583 " but expected: ", expected_sizes_shape.DebugString()));
584
585 auto sizes = sizes_tensor->flat_inner_dims<int32>();
586
587 for (int i = 0; i < field_descs_.size(); ++i) {
588 const Tensor& v = values[i];
589 int max_size = v.dim_size(v.dims() - 1);
590
591 // The last dimension of a value tensor must be greater than the
592 // corresponding size in the sizes tensor.
593 for (int message_index = 0; message_index < message_count;
594 message_index++) {
595 OP_REQUIRES(
596 ctx, sizes(message_index, i) <= max_size,
597 errors::InvalidArgument(
598 "Size to write must not be larger than value tensor; but saw: ",
599 sizes(message_index, i), " > ", max_size, " at message ",
600 message_index, " field ", i));
601 }
602 }
603
604 // This pointer is owned by the context.
605 Tensor* output_tensor;
606 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, common_prefix, &output_tensor));
607
608 auto bufs = output_tensor->flat<tstring>();
609 for (int message_index = 0; message_index < message_count;
610 message_index++) {
611 // TODO(nix): possibly optimize allocation here by calling
612 // `bufs(message_index).reserve(DEFAULT_BUF_SIZE)`.
613 TStringOutputStream output_string(&bufs(message_index));
614 CodedOutputStream out(&output_string);
615 // Write fields in ascending field_number order.
616 for (int i : sorted_field_index_) {
617 auto& field_desc = *field_descs_[i];
618 const Tensor& v = values[i];
619 int size = sizes(message_index, i);
620 if (!size) continue;
621 OP_REQUIRES_OK(ctx,
622 WriteField(field_desc, v, message_index, size, &out));
623 }
624 }
625 }
626
627 private:
628 std::vector<string> field_names_;
629 std::vector<const FieldDescriptor*> field_descs_;
630
631 // Owned_desc_pool_ is null when using descriptor_source=local.
632 std::unique_ptr<DescriptorPool> owned_desc_pool_;
633
634 // Contains indices into field_names_, sorted by field number since that's the
635 // order of writing.
636 std::vector<int> sorted_field_index_;
637
638 TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp);
639};
640
// Registers EncodeProtoOp as the kernel for the "EncodeProto" op on
// DEVICE_CPU.
REGISTER_KERNEL_BUILDER(Name("EncodeProto").Device(DEVICE_CPU), EncodeProtoOp);
642
643} // namespace
644} // namespace tensorflow
645