1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/variant.h"
17
18#include "tensorflow/core/framework/tensor.pb.h"
19#include "tensorflow/core/framework/variant_encode_decode.h"
20#include "tensorflow/core/framework/variant_op_registry.h"
21
22namespace tensorflow {
23
24Variant::~Variant() { ResetMemory(); }
25
26bool Variant::Decode(VariantTensorData data) {
27 if (!is_empty()) {
28 return GetValue()->Decode(std::move(data));
29 }
30 return true;
31}
32
33template <>
34void* Variant::get() {
35 if (is_empty()) {
36 return nullptr;
37 }
38 return GetValue()->RawPtr();
39}
40
41template <>
42const void* Variant::get() const {
43 if (is_empty()) {
44 return nullptr;
45 }
46 return GetValue()->RawPtr();
47}
48
49template <>
50string TypeNameVariant(const VariantTensorDataProto& value) {
51 return value.type_name();
52}
53
54template <>
55void EncodeVariant(const VariantTensorDataProto& value,
56 VariantTensorData* data) {
57 data->FromConstProto(value);
58}
59
60template <>
61bool DecodeVariant(VariantTensorData* data, VariantTensorDataProto* value) {
62 data->ToProto(value);
63 return true;
64}
65
66template <>
67void EncodeVariant(const VariantTensorDataProto& value, string* buf) {
68 value.SerializeToString(buf);
69}
70
71template <>
72bool DecodeVariant(string* buf, VariantTensorDataProto* value) {
73 return value->ParseFromString(*buf);
74}
75
76void EncodeVariantList(const Variant* variant_array, int64_t n,
77 std::unique_ptr<port::StringListEncoder> e) {
78 for (int i = 0; i < n; ++i) {
79 string s;
80 variant_array[i].Encode(&s);
81 e->Append(s);
82 }
83 e->Finalize();
84}
85
86bool DecodeVariantList(std::unique_ptr<port::StringListDecoder> d,
87 Variant* variant_array, int64_t n) {
88 std::vector<uint32> sizes(n);
89 if (!d->ReadSizes(&sizes)) return false;
90
91 for (int i = 0; i < n; ++i) {
92 if (variant_array[i].is_empty()) {
93 variant_array[i] = VariantTensorDataProto();
94 }
95 // TODO(ebrevdo): Replace with StringPiece? Any way to make this a
96 // zero-copy operation that keeps a reference to the data in d?
97 string str(d->Data(sizes[i]), sizes[i]);
98 if (!variant_array[i].Decode(std::move(str))) return false;
99 if (!DecodeUnaryVariant(&variant_array[i])) {
100 LOG(ERROR) << "Could not decode variant with type_name: \""
101 << variant_array[i].TypeName()
102 << "\". Perhaps you forgot to register a "
103 "decoder via REGISTER_UNARY_VARIANT_DECODE_FUNCTION?";
104 return false;
105 }
106 }
107 return true;
108}
109
110} // end namespace tensorflow
111