1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // EncodeProto is a TensorFlow Op which serializes tensors into |
17 | // arbitrary protobufs. |
18 | // |
19 | // See the docstring in ../ops/encode_proto_op.cc for usage of the op. |
20 | // |
21 | // This implementation writes the serialized format using a handful of |
22 | // calls from the WireFormatLite API. |
23 | |
24 | #include <memory> |
25 | #include <vector> |
26 | |
27 | #include "third_party/eigen3/Eigen/Core" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/tensor_types.h" |
30 | #include "tensorflow/core/lib/core/errors.h" |
31 | #include "tensorflow/core/platform/logging.h" |
32 | #include "tensorflow/core/platform/protobuf.h" |
33 | #include "tensorflow/core/util/proto/descriptors.h" |
34 | #include "tensorflow/core/util/proto/proto_utils.h" |
35 | |
36 | namespace tensorflow { |
37 | namespace { |
38 | |
39 | using ::tensorflow::protobuf::Descriptor; |
40 | using ::tensorflow::protobuf::DescriptorPool; |
41 | using ::tensorflow::protobuf::FieldDescriptor; |
42 | using ::tensorflow::protobuf::internal::WireFormatLite; |
43 | using ::tensorflow::protobuf::io::CodedOutputStream; |
44 | using ::tensorflow::protobuf::io::StringOutputStream; |
45 | |
46 | // Computes the total serialized size for a packed repeated field. For |
47 | // fixed-size types this can just multiply, but for variable-sized types it has |
48 | // to iterate through the values in the tensor. |
49 | template <WireFormatLite::FieldType FieldType, typename TensorT> |
50 | size_t TotalPackedSize(const Tensor& input, int message_index, int size); |
51 | |
52 | template <> |
53 | size_t TotalPackedSize<WireFormatLite::TYPE_DOUBLE, double>(const Tensor& input, |
54 | int message_index, |
55 | int size) { |
56 | return size * WireFormatLite::kDoubleSize; |
57 | } |
58 | |
59 | template <> |
60 | size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, double>(const Tensor& input, |
61 | int message_index, |
62 | int size) { |
63 | return size * WireFormatLite::kFloatSize; |
64 | } |
65 | |
66 | template <> |
67 | size_t TotalPackedSize<WireFormatLite::TYPE_FLOAT, float>(const Tensor& input, |
68 | int message_index, |
69 | int size) { |
70 | return size * WireFormatLite::kFloatSize; |
71 | } |
72 | |
73 | template <> |
74 | size_t TotalPackedSize<WireFormatLite::TYPE_INT64, int64_t>(const Tensor& input, |
75 | int message_index, |
76 | int size) { |
77 | size_t data_size = 0; |
78 | auto input_t = input.flat_inner_dims<int64_t>(); |
79 | for (int64_t i = 0; i < size; i++) { |
80 | data_size += WireFormatLite::Int64Size( |
81 | input_t(static_cast<int64_t>(message_index), i)); |
82 | } |
83 | return data_size; |
84 | } |
85 | |
86 | template <> |
87 | size_t TotalPackedSize<WireFormatLite::TYPE_UINT64, uint64>(const Tensor& input, |
88 | int message_index, |
89 | int size) { |
90 | size_t data_size = 0; |
91 | auto input_t = input.flat_inner_dims<uint64>(); |
92 | for (int64_t i = 0; i < size; i++) { |
93 | data_size += WireFormatLite::UInt64Size( |
94 | input_t(static_cast<int64_t>(message_index), i)); |
95 | } |
96 | return data_size; |
97 | } |
98 | |
99 | template <> |
100 | size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int64_t>(const Tensor& input, |
101 | int message_index, |
102 | int size) { |
103 | size_t data_size = 0; |
104 | auto input_t = input.flat_inner_dims<int64_t>(); |
105 | for (int64_t i = 0; i < size; i++) { |
106 | data_size += WireFormatLite::Int32Size( |
107 | input_t(static_cast<int64_t>(message_index), i)); |
108 | } |
109 | return data_size; |
110 | } |
111 | |
112 | template <> |
113 | size_t TotalPackedSize<WireFormatLite::TYPE_INT32, int32>(const Tensor& input, |
114 | int message_index, |
115 | int size) { |
116 | size_t data_size = 0; |
117 | auto input_t = input.flat_inner_dims<int32>(); |
118 | for (int64_t i = 0; i < size; i++) { |
119 | data_size += WireFormatLite::Int32Size( |
120 | input_t(static_cast<int64_t>(message_index), i)); |
121 | } |
122 | return data_size; |
123 | } |
124 | |
125 | template <> |
126 | size_t TotalPackedSize<WireFormatLite::TYPE_FIXED64, uint64>( |
127 | const Tensor& input, int message_index, int size) { |
128 | return size * WireFormatLite::kFixed64Size; |
129 | } |
130 | |
131 | template <> |
132 | size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint64>( |
133 | const Tensor& input, int message_index, int size) { |
134 | return size * WireFormatLite::kFixed32Size; |
135 | } |
136 | |
137 | template <> |
138 | size_t TotalPackedSize<WireFormatLite::TYPE_FIXED32, uint32>( |
139 | const Tensor& input, int message_index, int size) { |
140 | return size * WireFormatLite::kFixed32Size; |
141 | } |
142 | |
143 | template <> |
144 | size_t TotalPackedSize<WireFormatLite::TYPE_BOOL, bool>(const Tensor& input, |
145 | int message_index, |
146 | int size) { |
147 | return size * WireFormatLite::kBoolSize; |
148 | } |
149 | |
150 | template <> |
151 | size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint64>(const Tensor& input, |
152 | int message_index, |
153 | int size) { |
154 | size_t data_size = 0; |
155 | auto input_t = input.flat_inner_dims<uint64>(); |
156 | for (int64_t i = 0; i < size; i++) { |
157 | data_size += WireFormatLite::UInt32Size( |
158 | input_t(static_cast<int64_t>(message_index), i)); |
159 | } |
160 | return data_size; |
161 | } |
162 | |
163 | template <> |
164 | size_t TotalPackedSize<WireFormatLite::TYPE_UINT32, uint32>(const Tensor& input, |
165 | int message_index, |
166 | int size) { |
167 | size_t data_size = 0; |
168 | auto input_t = input.flat_inner_dims<uint32>(); |
169 | for (int64_t i = 0; i < size; i++) { |
170 | data_size += WireFormatLite::UInt32Size( |
171 | input_t(static_cast<int64_t>(message_index), i)); |
172 | } |
173 | return data_size; |
174 | } |
175 | |
176 | template <> |
177 | size_t TotalPackedSize<WireFormatLite::TYPE_ENUM, int32>(const Tensor& input, |
178 | int message_index, |
179 | int size) { |
180 | size_t data_size = 0; |
181 | auto input_t = input.flat_inner_dims<int32>(); |
182 | for (int64_t i = 0; i < size; i++) { |
183 | data_size += WireFormatLite::EnumSize( |
184 | input_t(static_cast<int64_t>(message_index), i)); |
185 | } |
186 | return data_size; |
187 | } |
188 | |
189 | template <> |
190 | size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int32>( |
191 | const Tensor& input, int message_index, int size) { |
192 | return size * WireFormatLite::kSFixed32Size; |
193 | } |
194 | |
195 | template <> |
196 | size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED32, int64_t>( |
197 | const Tensor& input, int message_index, int size) { |
198 | return size * WireFormatLite::kSFixed32Size; |
199 | } |
200 | |
201 | template <> |
202 | size_t TotalPackedSize<WireFormatLite::TYPE_SFIXED64, int64_t>( |
203 | const Tensor& input, int message_index, int size) { |
204 | return size * WireFormatLite::kSFixed64Size; |
205 | } |
206 | |
207 | template <> |
208 | size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int32>(const Tensor& input, |
209 | int message_index, |
210 | int size) { |
211 | size_t data_size = 0; |
212 | auto input_t = input.flat_inner_dims<int32>(); |
213 | for (int64_t i = 0; i < size; i++) { |
214 | data_size += WireFormatLite::SInt32Size( |
215 | input_t(static_cast<int64_t>(message_index), i)); |
216 | } |
217 | return data_size; |
218 | } |
219 | |
220 | template <> |
221 | size_t TotalPackedSize<WireFormatLite::TYPE_SINT32, int64_t>( |
222 | const Tensor& input, int message_index, int size) { |
223 | size_t data_size = 0; |
224 | auto input_t = input.flat_inner_dims<int64_t>(); |
225 | for (int64_t i = 0; i < size; i++) { |
226 | data_size += WireFormatLite::SInt32Size( |
227 | input_t(static_cast<int64_t>(message_index), i)); |
228 | } |
229 | return data_size; |
230 | } |
231 | |
232 | template <> |
233 | size_t TotalPackedSize<WireFormatLite::TYPE_SINT64, int64_t>( |
234 | const Tensor& input, int message_index, int size) { |
235 | size_t data_size = 0; |
236 | auto input_t = input.flat_inner_dims<int64_t>(); |
237 | for (int64_t i = 0; i < size; i++) { |
238 | data_size += WireFormatLite::SInt64Size( |
239 | input_t(static_cast<int64_t>(message_index), i)); |
240 | } |
241 | return data_size; |
242 | } |
243 | |
244 | // Writes a possibly repeated primitive field. TensorFlow does not have unsigned |
245 | // types, so we decode them to signed and encode them back to unsigned. |
246 | template <typename TensorT, typename ProtoT, |
247 | WireFormatLite::FieldType FieldType, |
248 | void Writer(ProtoT, CodedOutputStream*)> |
249 | Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, |
250 | int message_index, int size, CodedOutputStream* output) { |
251 | auto wire_type = WireFormatLite::WireTypeForFieldType( |
252 | WireFormatLite::FieldType(field_desc.type())); |
253 | |
254 | auto input_t = input.flat_inner_dims<TensorT>(); |
255 | if (field_desc.options().packed()) { |
256 | // Write the tag for the packed field. |
257 | WireFormatLite::WriteTag(field_desc.number(), |
258 | WireFormatLite::WIRETYPE_LENGTH_DELIMITED, output); |
259 | |
260 | // Write the total packed length. |
261 | size_t data_size = |
262 | TotalPackedSize<FieldType, TensorT>(input, message_index, size); |
263 | output->WriteVarint32(data_size); |
264 | |
265 | // Write individual values. |
266 | for (int64_t i = 0; i < size; i++) { |
267 | // Note implicit cast from signed to unsigned. |
268 | const ProtoT& value = input_t(static_cast<int64_t>(message_index), i); |
269 | Writer(value, output); |
270 | } |
271 | } else { |
272 | for (int64_t i = 0; i < size; i++) { |
273 | WireFormatLite::WriteTag(field_desc.number(), wire_type, output); |
274 | |
275 | // Note implicit cast from signed to unsigned. |
276 | const ProtoT& value = input_t(static_cast<int64_t>(message_index), i); |
277 | Writer(value, output); |
278 | } |
279 | } |
280 | return OkStatus(); |
281 | } |
282 | |
283 | // Writes a possibly repeated string, bytes, or message field. |
284 | template <typename T, void Writer(int, const T&, CodedOutputStream*)> |
285 | Status WriteVarLenField(const FieldDescriptor& field_desc, const Tensor& input, |
286 | int message_index, int size, |
287 | CodedOutputStream* output) { |
288 | auto input_t = input.flat_inner_dims<T>(); |
289 | for (int64_t i = 0; i < size; i++) { |
290 | const T& value = input_t(static_cast<int64_t>(message_index), i); |
291 | // TODO(nix): there doesn't seem to be an inlined version of |
292 | // WireFormatLite::WriteString or its relatives, which might allow a |
293 | // small speedup. |
294 | Writer(field_desc.number(), value, output); |
295 | } |
296 | return OkStatus(); |
297 | } |
298 | |
299 | static void WriteStringAdapter(int field_number, const tstring& value, |
300 | CodedOutputStream* output) { |
301 | // Unfortunately, external proto does not accept string_view. |
302 | #if defined(PLATFORM_GOOGLE) |
303 | WireFormatLite::WriteString(field_number, StringPiece(value), output); |
304 | #else |
305 | WireFormatLite::WriteString(field_number, string(value), output); |
306 | #endif |
307 | } |
308 | |
309 | static void WriteBytesAdapter(int field_number, const tstring& value, |
310 | CodedOutputStream* output) { |
311 | // Unfortunately, external proto does not accept string_view. |
312 | #if defined(PLATFORM_GOOGLE) |
313 | WireFormatLite::WriteBytes(field_number, StringPiece(value), output); |
314 | #else |
315 | WireFormatLite::WriteBytes(field_number, string(value), output); |
316 | #endif |
317 | } |
318 | |
319 | // Writes a group field. Groups are treated like submessages, but tag-delimited |
320 | // instead of length-delimited. WireFormatLite handles this differently so we |
321 | // code it ourselves. |
322 | Status WriteGroup(const FieldDescriptor& field_desc, const Tensor& input, |
323 | int message_index, int size, CodedOutputStream* output) { |
324 | auto input_t = input.flat_inner_dims<tstring>(); |
325 | for (int64_t i = 0; i < size; i++) { |
326 | const string& value = input_t(static_cast<int64_t>(message_index), i); |
327 | WireFormatLite::WriteTag(field_desc.number(), |
328 | WireFormatLite::WIRETYPE_START_GROUP, output); |
329 | // Note the use of WriteRaw instead of WriteString to skip the length. |
330 | output->WriteRaw(value.data(), value.size()); |
331 | WireFormatLite::WriteTag(field_desc.number(), |
332 | WireFormatLite::WIRETYPE_END_GROUP, output); |
333 | } |
334 | return OkStatus(); |
335 | } |
336 | |
337 | // Writes a (possibly repeated) field into an output stream. It is the caller's |
338 | // responsibility to ensure that the type of the input tensor is compatible with |
339 | // the type of the proto field descriptor, and that (message_index, size-1) is |
340 | // within bounds. |
341 | Status WriteField(const FieldDescriptor& field_desc, const Tensor& input, |
342 | int message_index, int size, CodedOutputStream* output) { |
343 | DataType dtype = input.dtype(); |
344 | |
345 | switch (field_desc.type()) { |
346 | case WireFormatLite::TYPE_DOUBLE: |
347 | return WriteField<double, double, WireFormatLite::TYPE_DOUBLE, |
348 | WireFormatLite::WriteDoubleNoTag>( |
349 | field_desc, input, message_index, size, output); |
350 | case WireFormatLite::TYPE_FLOAT: |
351 | switch (dtype) { |
352 | case DataType::DT_FLOAT: |
353 | return WriteField<float, float, WireFormatLite::TYPE_FLOAT, |
354 | WireFormatLite::WriteFloatNoTag>( |
355 | field_desc, input, message_index, size, output); |
356 | case DataType::DT_DOUBLE: |
357 | return WriteField<double, float, WireFormatLite::TYPE_FLOAT, |
358 | WireFormatLite::WriteFloatNoTag>( |
359 | field_desc, input, message_index, size, output); |
360 | default: |
361 | return errors::DataLoss("Failed writing TYPE_FLOAT for " , |
362 | DataTypeString(dtype)); |
363 | } |
364 | case WireFormatLite::TYPE_INT64: |
365 | return WriteField<int64_t, protobuf_int64, WireFormatLite::TYPE_INT64, |
366 | WireFormatLite::WriteInt64NoTag>( |
367 | field_desc, input, message_index, size, output); |
368 | case WireFormatLite::TYPE_UINT64: |
369 | return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_UINT64, |
370 | WireFormatLite::WriteUInt64NoTag>( |
371 | field_desc, input, message_index, size, output); |
372 | case WireFormatLite::TYPE_INT32: |
373 | switch (dtype) { |
374 | case DataType::DT_INT64: |
375 | return WriteField<int64_t, int32, WireFormatLite::TYPE_INT32, |
376 | WireFormatLite::WriteInt32NoTag>( |
377 | field_desc, input, message_index, size, output); |
378 | case DataType::DT_INT32: |
379 | return WriteField<int32, int32, WireFormatLite::TYPE_INT32, |
380 | WireFormatLite::WriteInt32NoTag>( |
381 | field_desc, input, message_index, size, output); |
382 | default: |
383 | return errors::DataLoss("Failed writing TYPE_INT32 for " , |
384 | DataTypeString(dtype)); |
385 | } |
386 | case WireFormatLite::TYPE_FIXED64: |
387 | return WriteField<uint64, protobuf_uint64, WireFormatLite::TYPE_FIXED64, |
388 | WireFormatLite::WriteFixed64NoTag>( |
389 | field_desc, input, message_index, size, output); |
390 | case WireFormatLite::TYPE_FIXED32: |
391 | switch (dtype) { |
392 | case DataType::DT_UINT64: |
393 | return WriteField<uint64, uint32, WireFormatLite::TYPE_FIXED32, |
394 | WireFormatLite::WriteFixed32NoTag>( |
395 | field_desc, input, message_index, size, output); |
396 | case DataType::DT_UINT32: |
397 | return WriteField<uint32, uint32, WireFormatLite::TYPE_FIXED32, |
398 | WireFormatLite::WriteFixed32NoTag>( |
399 | field_desc, input, message_index, size, output); |
400 | default: |
401 | return errors::DataLoss("Failed writing TYPE_FIXED32 for " , |
402 | DataTypeString(dtype)); |
403 | } |
404 | case WireFormatLite::TYPE_BOOL: |
405 | return WriteField<bool, bool, WireFormatLite::TYPE_BOOL, |
406 | WireFormatLite::WriteBoolNoTag>( |
407 | field_desc, input, message_index, size, output); |
408 | case WireFormatLite::TYPE_STRING: |
409 | return WriteVarLenField<tstring, WriteStringAdapter>( |
410 | field_desc, input, message_index, size, output); |
411 | case WireFormatLite::TYPE_GROUP: |
412 | return WriteGroup(field_desc, input, message_index, size, output); |
413 | case WireFormatLite::TYPE_MESSAGE: |
414 | return WriteVarLenField<tstring, WriteBytesAdapter>( |
415 | field_desc, input, message_index, size, output); |
416 | case WireFormatLite::TYPE_BYTES: |
417 | return WriteVarLenField<tstring, WriteBytesAdapter>( |
418 | field_desc, input, message_index, size, output); |
419 | case WireFormatLite::TYPE_UINT32: |
420 | switch (dtype) { |
421 | case DataType::DT_UINT64: |
422 | return WriteField<uint64, uint32, WireFormatLite::TYPE_UINT32, |
423 | WireFormatLite::WriteUInt32NoTag>( |
424 | field_desc, input, message_index, size, output); |
425 | case DataType::DT_UINT32: |
426 | return WriteField<uint32, uint32, WireFormatLite::TYPE_UINT32, |
427 | WireFormatLite::WriteUInt32NoTag>( |
428 | field_desc, input, message_index, size, output); |
429 | default: |
430 | return errors::DataLoss("Failed writing TYPE_UINT32 for " , |
431 | DataTypeString(dtype)); |
432 | } |
433 | case WireFormatLite::TYPE_ENUM: |
434 | return WriteField<int32, int32, WireFormatLite::TYPE_ENUM, |
435 | WireFormatLite::WriteEnumNoTag>( |
436 | field_desc, input, message_index, size, output); |
437 | case WireFormatLite::TYPE_SFIXED32: |
438 | switch (dtype) { |
439 | case DataType::DT_INT64: |
440 | return WriteField<int64_t, int32, WireFormatLite::TYPE_SFIXED32, |
441 | WireFormatLite::WriteSFixed32NoTag>( |
442 | field_desc, input, message_index, size, output); |
443 | case DataType::DT_INT32: |
444 | return WriteField<int32, int32, WireFormatLite::TYPE_SFIXED32, |
445 | WireFormatLite::WriteSFixed32NoTag>( |
446 | field_desc, input, message_index, size, output); |
447 | default: |
448 | return errors::DataLoss("Failed writing TYPE_SFIXED32 for " , |
449 | DataTypeString(dtype)); |
450 | } |
451 | case WireFormatLite::TYPE_SFIXED64: |
452 | return WriteField<int64_t, protobuf_int64, WireFormatLite::TYPE_SFIXED64, |
453 | WireFormatLite::WriteSFixed64NoTag>( |
454 | field_desc, input, message_index, size, output); |
455 | case WireFormatLite::TYPE_SINT32: |
456 | switch (dtype) { |
457 | case DataType::DT_INT64: |
458 | return WriteField<int64_t, int32, WireFormatLite::TYPE_SINT32, |
459 | WireFormatLite::WriteSInt32NoTag>( |
460 | field_desc, input, message_index, size, output); |
461 | case DataType::DT_INT32: |
462 | return WriteField<int32, int32, WireFormatLite::TYPE_SINT32, |
463 | WireFormatLite::WriteSInt32NoTag>( |
464 | field_desc, input, message_index, size, output); |
465 | default: |
466 | return errors::DataLoss("Failed writing TYPE_SINT32 for " , |
467 | DataTypeString(dtype)); |
468 | } |
469 | case WireFormatLite::TYPE_SINT64: |
470 | return WriteField<int64_t, protobuf_int64, WireFormatLite::TYPE_SINT64, |
471 | WireFormatLite::WriteSInt64NoTag>( |
472 | field_desc, input, message_index, size, output); |
473 | // default: intentionally omitted in order to enable static checking. |
474 | } |
475 | } |
476 | |
477 | class EncodeProtoOp : public OpKernel { |
478 | public: |
479 | explicit EncodeProtoOp(OpKernelConstruction* context) : OpKernel(context) { |
480 | string descriptor_source; |
481 | OP_REQUIRES_OK(context, |
482 | context->GetAttr("descriptor_source" , &descriptor_source)); |
483 | // We always get back a desc_pool, but we may not own it. If we own it, |
484 | // owned_desc_pool_ will be filled in. |
485 | DescriptorPool const* desc_pool; |
486 | OP_REQUIRES_OK(context, GetDescriptorPool(context->env(), descriptor_source, |
487 | &desc_pool, &owned_desc_pool_)); |
488 | |
489 | string message_type; |
490 | OP_REQUIRES_OK(context, context->GetAttr("message_type" , &message_type)); |
491 | const Descriptor* message_desc = |
492 | desc_pool->FindMessageTypeByName(message_type); |
493 | OP_REQUIRES(context, message_desc != nullptr, |
494 | errors::InvalidArgument("No descriptor found for message type " , |
495 | message_type)); |
496 | |
497 | OP_REQUIRES_OK(context, context->GetAttr("field_names" , &field_names_)); |
498 | |
499 | // Gather the field descriptors for the given field_names. |
500 | field_descs_.resize(field_names_.size()); |
501 | for (int i = 0; i < field_names_.size(); i++) { |
502 | const string& name = field_names_[i]; |
503 | auto field_desc = message_desc->FindFieldByName(name); |
504 | OP_REQUIRES(context, field_desc != nullptr, |
505 | errors::InvalidArgument("Unknown field: " , name, |
506 | " in message type " , message_type)); |
507 | |
508 | field_descs_[i] = field_desc; |
509 | } |
510 | |
511 | // Build a list of indices into field_descs sorted by increasing |
512 | // field_number. This will be used to output fields in sorted order, |
513 | // which is strongly encouraged when serializing protobufs. |
514 | sorted_field_index_.resize(field_names_.size()); |
515 | // Start with the fields sorted by current index. |
516 | for (int i = 0; i < field_names_.size(); i++) sorted_field_index_[i] = i; |
517 | // Then sort the field indices by their proto field number. |
518 | std::sort(sorted_field_index_.begin(), sorted_field_index_.end(), |
519 | [this](int a, int b) -> bool { |
520 | return field_descs_[a]->number() < field_descs_[b]->number(); |
521 | }); |
522 | } |
523 | |
524 | void Compute(OpKernelContext* ctx) override { |
525 | const Tensor* sizes_tensor; |
526 | OP_REQUIRES_OK(ctx, ctx->input("sizes" , &sizes_tensor)); |
527 | |
528 | OpInputList values; |
529 | OP_REQUIRES_OK(ctx, ctx->input_list("values" , &values)); |
530 | |
531 | OP_REQUIRES(ctx, field_descs_.size() == values.size(), |
532 | errors::InvalidArgument( |
533 | "Length of inputs list must match field_names" )); |
534 | |
535 | // Check the arguments for consistency. |
536 | TensorShape common_prefix; |
537 | int message_count = 0; |
538 | for (int i = 0; i < field_descs_.size(); i++) { |
539 | const Tensor& v = values[i]; |
540 | |
541 | // The type of each value tensor must match the corresponding field. |
542 | OP_REQUIRES( |
543 | ctx, |
544 | proto_utils::IsCompatibleType(field_descs_[i]->type(), v.dtype()), |
545 | errors::InvalidArgument( |
546 | "Incompatible type for field " , field_names_[i], |
547 | ". Saw dtype: " , DataTypeString(v.dtype()), |
548 | " but field type is: " , field_descs_[i]->type_name())); |
549 | |
550 | OP_REQUIRES( |
551 | ctx, TensorShapeUtils::IsMatrixOrHigher(v.shape()), |
552 | errors::InvalidArgument("Invalid shape for field " , field_names_[i], |
553 | ". Saw shape " , v.shape().DebugString(), |
554 | " but it should be at least a matrix." )); |
555 | |
556 | // All value tensors must have the same shape prefix (i.e. batch size). |
557 | TensorShape shape_prefix = v.shape(); |
558 | shape_prefix.RemoveDim(shape_prefix.dims() - 1); |
559 | |
560 | // Do some initialization on the first input value. The rest will |
561 | // have to match this one. |
562 | if (i == 0) { |
563 | OP_REQUIRES(ctx, v.dims() >= 1, |
564 | errors::InvalidArgument( |
565 | "Expected value to be at least a vector, saw shape: " , |
566 | v.shape().DebugString())); |
567 | common_prefix = shape_prefix; |
568 | message_count = common_prefix.num_elements(); |
569 | } else { |
570 | OP_REQUIRES(ctx, shape_prefix == common_prefix, |
571 | errors::InvalidArgument( |
572 | "Values must match up to the last dimension" )); |
573 | } |
574 | } |
575 | |
576 | TensorShape expected_sizes_shape = common_prefix; |
577 | expected_sizes_shape.AddDim(field_descs_.size()); |
578 | |
579 | OP_REQUIRES(ctx, sizes_tensor->shape() == expected_sizes_shape, |
580 | errors::InvalidArgument( |
581 | "sizes should be batch_size + [len(field_names)]. Saw: " , |
582 | sizes_tensor->shape().DebugString(), |
583 | " but expected: " , expected_sizes_shape.DebugString())); |
584 | |
585 | auto sizes = sizes_tensor->flat_inner_dims<int32>(); |
586 | |
587 | for (int i = 0; i < field_descs_.size(); ++i) { |
588 | const Tensor& v = values[i]; |
589 | int max_size = v.dim_size(v.dims() - 1); |
590 | |
591 | // The last dimension of a value tensor must be greater than the |
592 | // corresponding size in the sizes tensor. |
593 | for (int message_index = 0; message_index < message_count; |
594 | message_index++) { |
595 | OP_REQUIRES( |
596 | ctx, sizes(message_index, i) <= max_size, |
597 | errors::InvalidArgument( |
598 | "Size to write must not be larger than value tensor; but saw: " , |
599 | sizes(message_index, i), " > " , max_size, " at message " , |
600 | message_index, " field " , i)); |
601 | } |
602 | } |
603 | |
604 | // This pointer is owned by the context. |
605 | Tensor* output_tensor; |
606 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, common_prefix, &output_tensor)); |
607 | |
608 | auto bufs = output_tensor->flat<tstring>(); |
609 | for (int message_index = 0; message_index < message_count; |
610 | message_index++) { |
611 | // TODO(nix): possibly optimize allocation here by calling |
612 | // `bufs(message_index).reserve(DEFAULT_BUF_SIZE)`. |
613 | TStringOutputStream output_string(&bufs(message_index)); |
614 | CodedOutputStream out(&output_string); |
615 | // Write fields in ascending field_number order. |
616 | for (int i : sorted_field_index_) { |
617 | auto& field_desc = *field_descs_[i]; |
618 | const Tensor& v = values[i]; |
619 | int size = sizes(message_index, i); |
620 | if (!size) continue; |
621 | OP_REQUIRES_OK(ctx, |
622 | WriteField(field_desc, v, message_index, size, &out)); |
623 | } |
624 | } |
625 | } |
626 | |
627 | private: |
628 | std::vector<string> field_names_; |
629 | std::vector<const FieldDescriptor*> field_descs_; |
630 | |
631 | // Owned_desc_pool_ is null when using descriptor_source=local. |
632 | std::unique_ptr<DescriptorPool> owned_desc_pool_; |
633 | |
634 | // Contains indices into field_names_, sorted by field number since that's the |
635 | // order of writing. |
636 | std::vector<int> sorted_field_index_; |
637 | |
638 | TF_DISALLOW_COPY_AND_ASSIGN(EncodeProtoOp); |
639 | }; |
640 | |
// Registers EncodeProtoOp as the CPU kernel for the "EncodeProto" op.
REGISTER_KERNEL_BUILDER(Name("EncodeProto" ).Device(DEVICE_CPU), EncodeProtoOp);
642 | |
643 | } // namespace |
644 | } // namespace tensorflow |
645 | |