1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/variant.h" |
17 | |
18 | #include "tensorflow/core/framework/tensor.pb.h" |
19 | #include "tensorflow/core/framework/variant_encode_decode.h" |
20 | #include "tensorflow/core/framework/variant_op_registry.h" |
21 | |
22 | namespace tensorflow { |
23 | |
24 | Variant::~Variant() { ResetMemory(); } |
25 | |
26 | bool Variant::Decode(VariantTensorData data) { |
27 | if (!is_empty()) { |
28 | return GetValue()->Decode(std::move(data)); |
29 | } |
30 | return true; |
31 | } |
32 | |
33 | template <> |
34 | void* Variant::get() { |
35 | if (is_empty()) { |
36 | return nullptr; |
37 | } |
38 | return GetValue()->RawPtr(); |
39 | } |
40 | |
41 | template <> |
42 | const void* Variant::get() const { |
43 | if (is_empty()) { |
44 | return nullptr; |
45 | } |
46 | return GetValue()->RawPtr(); |
47 | } |
48 | |
49 | template <> |
50 | string TypeNameVariant(const VariantTensorDataProto& value) { |
51 | return value.type_name(); |
52 | } |
53 | |
54 | template <> |
55 | void EncodeVariant(const VariantTensorDataProto& value, |
56 | VariantTensorData* data) { |
57 | data->FromConstProto(value); |
58 | } |
59 | |
60 | template <> |
61 | bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) { |
62 | data->ToProto(value); |
63 | return true; |
64 | } |
65 | |
66 | template <> |
67 | void EncodeVariant(const VariantTensorDataProto& value, string* buf) { |
68 | value.SerializeToString(buf); |
69 | } |
70 | |
71 | template <> |
72 | bool DecodeVariant(string* buf, VariantTensorDataProto* value) { |
73 | return value->ParseFromString(*buf); |
74 | } |
75 | |
76 | void EncodeVariantList(const Variant* variant_array, int64_t n, |
77 | std::unique_ptr<port::StringListEncoder> e) { |
78 | for (int i = 0; i < n; ++i) { |
79 | string s; |
80 | variant_array[i].Encode(&s); |
81 | e->Append(s); |
82 | } |
83 | e->Finalize(); |
84 | } |
85 | |
86 | bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d, |
87 | Variant* variant_array, int64_t n) { |
88 | std::vector<uint32> sizes(n); |
89 | if (!d->ReadSizes(&sizes)) return false; |
90 | |
91 | for (int i = 0; i < n; ++i) { |
92 | if (variant_array[i].is_empty()) { |
93 | variant_array[i] = VariantTensorDataProto(); |
94 | } |
95 | // TODO(ebrevdo): Replace with StringPiece? Any way to make this a |
96 | // zero-copy operation that keeps a reference to the data in d? |
97 | string str(d->Data(sizes[i]), sizes[i]); |
98 | if (!variant_array[i].Decode(std::move(str))) return false; |
99 | if (!DecodeUnaryVariant(&variant_array[i])) { |
100 | LOG(ERROR) << "Could not decode variant with type_name: \"" |
101 | << variant_array[i].TypeName() |
102 | << "\". Perhaps you forgot to register a " |
103 | "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?" ; |
104 | return false; |
105 | } |
106 | } |
107 | return true; |
108 | } |
109 | |
110 | } // end namespace tensorflow |
111 | |