1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/op.h" |
17 | #include "tensorflow/core/framework/shape_inference.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | using tensorflow::shape_inference::InferenceContext; |
22 | using tensorflow::shape_inference::ShapeHandle; |
23 | |
24 | REGISTER_OP("DecodeProtoV2" ) |
25 | .Input("bytes: string" ) |
26 | .Attr("message_type: string" ) |
27 | .Attr("field_names: list(string)" ) |
28 | .Attr("output_types: list(type) >= 0" ) |
29 | .Attr("descriptor_source: string = 'local://'" ) |
30 | .Attr("message_format: string = 'binary'" ) |
31 | .Attr("sanitize: bool = false" ) |
32 | .Output("sizes: int32" ) |
33 | .Output("values: output_types" ) |
34 | .SetShapeFn([](InferenceContext* c) { |
35 | ShapeHandle input = c->input(0); |
36 | |
37 | std::vector<tensorflow::DataType> output_types; |
38 | TF_RETURN_IF_ERROR(c->GetAttr("output_types" , &output_types)); |
39 | |
40 | ShapeHandle sizes; |
41 | TF_RETURN_IF_ERROR( |
42 | c->Concatenate(input, c->Vector(output_types.size()), &sizes)); |
43 | c->set_output(0, sizes); |
44 | |
45 | // TODO(nix): to do the best possible job of shape inference, we |
46 | // should examine the proto descriptors here in order to set shape |
47 | // indices to 1 instead of unknown for optional or required fields. |
48 | // Any general-purpose code will have to handle the unknown case, |
49 | // but there might be XLA code that could be sped up with the additional |
50 | // knowledge. |
51 | for (int i = 0; i < output_types.size(); ++i) { |
52 | ShapeHandle values; |
53 | TF_RETURN_IF_ERROR( |
54 | c->Concatenate(input, c->Vector(c->UnknownDim()), &values)); |
55 | c->set_output(i + 1, values); |
56 | } |
57 | |
58 | return OkStatus(); |
59 | }); |
60 | |
61 | // TODO(nix): Consider adding an additional input argument that truncates |
62 | // repeated fields to a maximum count. For now this could be done by passing |
63 | // the output through tf.slice. |
64 | |
65 | // TODO(nix): define missing value behavior. |
66 | |
67 | } // namespace tensorflow |
68 | |