1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/variant_op_registry.h"
17
18#include <string>
19
20#include "tensorflow/core/framework/register_types.h"
21#include "tensorflow/core/framework/type_index.h"
22#include "tensorflow/core/framework/variant.h"
23#include "tensorflow/core/lib/core/errors.h"
24#include "tensorflow/core/platform/logging.h"
25#include "tensorflow/core/public/version.h"
26
27namespace tensorflow {
28
29const char* VariantUnaryOpToString(VariantUnaryOp op) {
30 switch (op) {
31 case INVALID_VARIANT_UNARY_OP:
32 return "INVALID";
33 case ZEROS_LIKE_VARIANT_UNARY_OP:
34 return "ZEROS_LIKE";
35 case CONJ_VARIANT_UNARY_OP:
36 return "CONJ";
37 }
38}
39
40const char* VariantBinaryOpToString(VariantBinaryOp op) {
41 switch (op) {
42 case INVALID_VARIANT_BINARY_OP:
43 return "INVALID";
44 case ADD_VARIANT_BINARY_OP:
45 return "ADD";
46 }
47}
48
49std::unordered_set<string>* UnaryVariantOpRegistry::PersistentStringStorage() {
50 static std::unordered_set<string>* string_storage =
51 new std::unordered_set<string>();
52 return string_storage;
53}
54
55// Get a pointer to a global UnaryVariantOpRegistry object
56UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal() {
57 static UnaryVariantOpRegistry* global_unary_variant_op_registry = nullptr;
58
59 if (global_unary_variant_op_registry == nullptr) {
60 global_unary_variant_op_registry = new UnaryVariantOpRegistry;
61 }
62 return global_unary_variant_op_registry;
63}
64
65UnaryVariantOpRegistry::VariantDecodeFn* UnaryVariantOpRegistry::GetDecodeFn(
66 StringPiece type_name) {
67 auto found = decode_fns.find(type_name);
68 if (found == decode_fns.end()) return nullptr;
69 return &found->second;
70}
71
72void UnaryVariantOpRegistry::RegisterDecodeFn(
73 const string& type_name, const VariantDecodeFn& decode_fn) {
74 CHECK(!type_name.empty()) << "Need a valid name for UnaryVariantDecode";
75 VariantDecodeFn* existing = GetDecodeFn(type_name);
76 CHECK_EQ(existing, nullptr)
77 << "Unary VariantDecodeFn for type_name: " << type_name
78 << " already registered";
79 decode_fns.insert(std::pair<StringPiece, VariantDecodeFn>(
80 GetPersistentStringPiece(type_name), decode_fn));
81}
82
83bool DecodeUnaryVariant(Variant* variant) {
84 CHECK_NOTNULL(variant);
85 if (variant->TypeName().empty()) {
86 VariantTensorDataProto* t = variant->get<VariantTensorDataProto>();
87 if (t == nullptr || !t->metadata().empty() || !t->tensors().empty()) {
88 // Malformed variant.
89 return false;
90 } else {
91 // Serialization of an empty Variant.
92 variant->clear();
93 return true;
94 }
95 }
96 UnaryVariantOpRegistry::VariantDecodeFn* decode_fn =
97 UnaryVariantOpRegistry::Global()->GetDecodeFn(variant->TypeName());
98 if (decode_fn == nullptr) {
99 return false;
100 }
101 const string type_name = variant->TypeName();
102 bool decoded = (*decode_fn)(variant);
103 if (!decoded) return false;
104 if (variant->TypeName() != type_name) {
105 LOG(ERROR) << "DecodeUnaryVariant: Variant type_name before decoding was: "
106 << type_name
107 << " but after decoding was: " << variant->TypeName()
108 << ". Treating this as a failure.";
109 return false;
110 }
111 return true;
112}
113
// Add some basic registrations for use by others, e.g., for testing.

// Registers the unary-variant decode function for primitive type T under the
// name TF_STR(T). Presumably TF_STR stringizes T (e.g. "int"); confirm against
// the macro definition.
#define REGISTER_VARIANT_DECODE_TYPE(T) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, TF_STR(T));

// No encode/decode registered for std::complex<> and Eigen::half
// objects yet.
REGISTER_VARIANT_DECODE_TYPE(int);
REGISTER_VARIANT_DECODE_TYPE(float);
REGISTER_VARIANT_DECODE_TYPE(bool);
REGISTER_VARIANT_DECODE_TYPE(double);

#undef REGISTER_VARIANT_DECODE_TYPE
127
128Status VariantDeviceCopy(
129 const VariantDeviceCopyDirection direction, const Variant& from,
130 Variant* to,
131 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn) {
132 UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn* device_copy_fn =
133 UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(direction,
134 from.TypeId());
135 if (device_copy_fn == nullptr) {
136 return errors::Internal(
137 "No unary variant device copy function found for direction: ",
138 direction, " and Variant type_index: ",
139 port::MaybeAbiDemangle(from.TypeId().name()));
140 }
141 return (*device_copy_fn)(from, to, copy_fn);
142}
143
144namespace {
145template <typename T>
146Status DeviceCopyPrimitiveType(
147 const T& in, T* out,
148 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
149 // Dummy copy, we don't actually bother copying to the device and back for
150 // testing.
151 *out = in;
152 return OkStatus();
153}
154} // namespace
155
// Registers DeviceCopyPrimitiveType<T> as T's unary-variant device-copy
// function for each of the three directions: HOST_TO_DEVICE, DEVICE_TO_HOST
// and DEVICE_TO_DEVICE.
#define REGISTER_VARIANT_DEVICE_COPY_TYPE(T)            \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      T, VariantDeviceCopyDirection::HOST_TO_DEVICE,    \
      DeviceCopyPrimitiveType<T>);                      \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      T, VariantDeviceCopyDirection::DEVICE_TO_HOST,    \
      DeviceCopyPrimitiveType<T>);                      \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION( \
      T, VariantDeviceCopyDirection::DEVICE_TO_DEVICE,  \
      DeviceCopyPrimitiveType<T>);

// No device copy registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_DEVICE_COPY_TYPE(int);
REGISTER_VARIANT_DEVICE_COPY_TYPE(float);
REGISTER_VARIANT_DEVICE_COPY_TYPE(double);
REGISTER_VARIANT_DEVICE_COPY_TYPE(bool);

#undef REGISTER_VARIANT_DEVICE_COPY_TYPE
174
175namespace {
176template <typename T>
177Status ZerosLikeVariantPrimitiveType(OpKernelContext* ctx, const T& t,
178 T* t_out) {
179 *t_out = T(0);
180 return OkStatus();
181}
182} // namespace
183
// Registers ZerosLikeVariantPrimitiveType<T> as T's ZEROS_LIKE unary op on
// DEVICE_CPU.
#define REGISTER_VARIANT_ZEROS_LIKE_TYPE(T)                             \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP, \
                                           DEVICE_CPU, T,               \
                                           ZerosLikeVariantPrimitiveType<T>);

// No zeros_like registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ZEROS_LIKE_TYPE(int);
REGISTER_VARIANT_ZEROS_LIKE_TYPE(float);
REGISTER_VARIANT_ZEROS_LIKE_TYPE(double);
REGISTER_VARIANT_ZEROS_LIKE_TYPE(bool);

#undef REGISTER_VARIANT_ZEROS_LIKE_TYPE
196
197namespace {
198template <typename T>
199Status AddVariantPrimitiveType(OpKernelContext* ctx, const T& a, const T& b,
200 T* out) {
201 *out = a + b;
202 return OkStatus();
203}
204} // namespace
205
// Registers AddVariantPrimitiveType<T> as T's ADD binary op on DEVICE_CPU.
#define REGISTER_VARIANT_ADD_TYPE(T)                                            \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP, DEVICE_CPU, \
                                            T, AddVariantPrimitiveType<T>);

// No add registered for std::complex<> or Eigen::half objects yet.
REGISTER_VARIANT_ADD_TYPE(int);
REGISTER_VARIANT_ADD_TYPE(float);
REGISTER_VARIANT_ADD_TYPE(double);
REGISTER_VARIANT_ADD_TYPE(bool);

#undef REGISTER_VARIANT_ADD_TYPE
217
218} // namespace tensorflow
219