1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
17#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
18
19#include <string>
20#include <unordered_set>
21#include <vector>
22
23#define EIGEN_USE_THREADS
24
25#include "tensorflow/core/framework/tensor.pb.h"
26#include "tensorflow/core/framework/type_index.h"
27#include "tensorflow/core/framework/types.h"
28#include "tensorflow/core/framework/variant.h"
29#include "tensorflow/core/framework/variant_encode_decode.h"
30#include "tensorflow/core/lib/gtl/flatmap.h"
31#include "tensorflow/core/lib/hash/hash.h"
32#include "tensorflow/core/platform/abi.h"
33
34namespace tensorflow {
35
class OpKernelContext;

// Operations that UnaryOpVariant() can dispatch; each (op, device, type)
// triple maps to one callback registered in UnaryVariantOpRegistry.
enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2
};

// Returns a printable name for `op`; used when building error messages.
const char* VariantUnaryOpToString(VariantUnaryOp op);

// Operations that BinaryOpVariants() can dispatch.
enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,
  ADD_VARIANT_BINARY_OP = 1
};

// Returns a printable name for `op`; used when building error messages.
const char* VariantBinaryOpToString(VariantBinaryOp op);

// Direction of a copy performed by VariantDeviceCopy(). Copy functions are
// registered separately for each direction and type.
enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3
};

// A global UnaryVariantOpRegistry holds the callback functions for the
// different Variant types. It is used by ShapeOp, RankOp and SizeOp, by
// decoding, etc.
class UnaryVariantOpRegistry;
extern UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal();
65
66class UnaryVariantOpRegistry {
67 public:
68 typedef std::function<bool(Variant*)> VariantDecodeFn;
69 typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
70 VariantUnaryOpFn;
71 typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
72 Variant*)>
73 VariantBinaryOpFn;
74
75 // An AsyncTensorDeviceCopyFn is a function provided to
76 // the user-provided DeviceCopyFn callback as the third argument ("copier").
77 //
78 // Expected inputs:
79 // from: A Tensor on the host (if performing cpu->gpu copy), or
80 // device (if performing gpu->cpu or gpu->gpu copy).
81 // to: An empty/uninitialized tensor. It will be updated upon
82 // successful return of the function with the correct dtype and shape.
83 // However, the copied data will not be available until the compute
84 // stream has been synchronized.
85 //
86 // Returns:
87 // The status upon memory allocation / initialization of the
88 // "to" tensor, and enqueue of the copy onto the compute stream.
89 // Any failure of the copy itself will update the underlying
90 // stream status and propagate through the runtime independent
91 // of the caller.
92 typedef std::function<Status(const Tensor& from, Tensor* to)>
93 AsyncTensorDeviceCopyFn;
94
95 // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
96 // expected to be passed to the registration macro
97 // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
98 typedef std::function<Status(const Variant& from, Variant* to,
99 AsyncTensorDeviceCopyFn copy_fn)>
100 AsyncVariantDeviceCopyFn;
101
102 // Add a decode function to the registry.
103 void RegisterDecodeFn(const std::string& type_name,
104 const VariantDecodeFn& decode_fn);
105
106 // Returns nullptr if no decode function was found for the given TypeName.
107 VariantDecodeFn* GetDecodeFn(StringPiece type_name);
108
109 // Add a copy-to-GPU function to the registry.
110 void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
111 const TypeIndex& type_index,
112 const AsyncVariantDeviceCopyFn& device_copy_fn) {
113 AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
114 CHECK_EQ(existing, nullptr)
115 << "UnaryVariantDeviceCopy for direction: " << direction
116 << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
117 << " already registered";
118 device_copy_fns.insert(
119 std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
120 AsyncVariantDeviceCopyFn>(
121 std::make_pair(direction, type_index), device_copy_fn));
122 }
123
124 // Returns nullptr if no copy function was found for the given
125 // TypeName and direction.
126 AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
127 const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
128 auto found = device_copy_fns.find(std::make_pair(direction, type_index));
129 if (found == device_copy_fns.end()) return nullptr;
130 return &found->second;
131 }
132
133 // Add a unary op function to the registry.
134 void RegisterUnaryOpFn(VariantUnaryOp op, const std::string& device,
135 const TypeIndex& type_index,
136 const VariantUnaryOpFn& unary_op_fn) {
137 VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
138 CHECK_EQ(existing, nullptr)
139 << "Unary VariantUnaryOpFn for type_index: "
140 << port::MaybeAbiDemangle(type_index.name())
141 << " already registered for device type: " << device;
142 unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
143 {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
144 }
145
146 // Returns nullptr if no unary op function was found for the given
147 // op, device, and TypeName.
148 VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
149 const TypeIndex& type_index) {
150 auto found = unary_op_fns.find({op, device, type_index});
151 if (found == unary_op_fns.end()) return nullptr;
152 return &found->second;
153 }
154
155 // Add a binary op function to the registry.
156 void RegisterBinaryOpFn(VariantBinaryOp op, const std::string& device,
157 const TypeIndex& type_index,
158 const VariantBinaryOpFn& add_fn) {
159 VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
160 CHECK_EQ(existing, nullptr)
161 << "Unary VariantBinaryOpFn for type_index: "
162 << port::MaybeAbiDemangle(type_index.name())
163 << " already registered for device type: " << device;
164 binary_op_fns.insert(
165 std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
166 {op, GetPersistentStringPiece(device), type_index}, add_fn));
167 }
168
169 // Returns nullptr if no binary op function was found for the given
170 // op, device and TypeName.
171 VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
172 const TypeIndex& type_index) {
173 auto found = binary_op_fns.find({op, device, type_index});
174 if (found == binary_op_fns.end()) return nullptr;
175 return &found->second;
176 }
177
178 // Get a pointer to a global UnaryVariantOpRegistry object
179 static UnaryVariantOpRegistry* Global() {
180 return UnaryVariantOpRegistryGlobal();
181 }
182
183 // Get a pointer to a global persistent string storage object.
184 // ISO/IEC C++ working draft N4296 clarifies that insertion into an
185 // std::unordered_set does not invalidate memory locations of
186 // *values* inside the set (though it may invalidate existing
187 // iterators). In other words, one may safely point a StringPiece to
188 // a value in the set without that StringPiece being invalidated by
189 // future insertions.
190 static std::unordered_set<string>* PersistentStringStorage();
191
192 private:
193 struct TypeIndexHash {
194 std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
195 };
196
197 gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;
198
199 // Map std::pair<Direction, type_name> to function.
200 struct PairHash {
201 template <typename Direction>
202 std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
203 // The hash of an enum is just its value as a std::size_t.
204 std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
205 ret = Hash64Combine(ret, std::get<1>(x).hash_code());
206 return ret;
207 }
208 };
209
210 gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
211 AsyncVariantDeviceCopyFn, PairHash>
212 device_copy_fns;
213
214 // Map std::tuple<Op, device, type_name> to function.
215
216 // this breaks by falling victim to "too perfect forwarding"
217 // see https://stackoverflow.com/questions/44475317/variadic-template-issue
218 // and references therein
219 template <typename Op>
220 struct FuncTuple {
221 FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
222 : op_type_(op), device_(dev), type_index_(type_index) {}
223 Op op_type_;
224 StringPiece device_;
225 TypeIndex type_index_;
226 };
227 // friend declaration for operator==
228 // needed for clang
229 template <typename Op>
230 friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
231 struct TupleHash {
232 template <typename Op>
233 std::size_t operator()(
234 const std::tuple<Op, StringPiece, TypeIndex>& x) const {
235 // The hash of an enum is just its value as a std::size_t.
236 std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
237 ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
238 ret = Hash64Combine(ret, std::get<2>(x).hash_code());
239 return ret;
240 }
241
242 template <typename Op>
243 std::size_t operator()(const FuncTuple<Op>& x) const {
244 // The hash of an enum is just its value as a std::size_t.
245 std::size_t ret = static_cast<std::size_t>(x.op_type_);
246 ret = Hash64Combine(ret, sp_hasher_(x.device_));
247 ret = Hash64Combine(ret, x.type_index_.hash_code());
248 return ret;
249 }
250 StringPieceHasher sp_hasher_;
251 };
252 gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
253 unary_op_fns;
254 gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
255 binary_op_fns;
256
257 // Find or insert a string into a persistent string storage
258 // container; return the StringPiece pointing to the permanent string
259 // location.
260 static StringPiece GetPersistentStringPiece(const std::string& str) {
261 const auto string_storage = PersistentStringStorage();
262 auto found = string_storage->find(str);
263 if (found == string_storage->end()) {
264 auto inserted = string_storage->insert(str);
265 return StringPiece(*inserted.first);
266 } else {
267 return StringPiece(*found);
268 }
269 }
270};
271template <typename Op>
272inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
273 const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
274 return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
275 (lhs.type_index_ == rhs.type_index_);
276}
277
278// Decodes the Variant whose data_type has a registered decode
279// function. Returns an Internal error if the Variant does not have a
280// registered decode function, or if the decoding function fails.
281//
282// REQUIRES:
283// variant is not null.
284//
285bool DecodeUnaryVariant(Variant* variant);
286
287// Copies a variant between CPU<->GPU, or between GPU<->GPU.
288// The variant 'from' must have a registered DeviceCopyFn for the
289// given direction. The returned variant 'to' will have
290// (some subset of its) tensors stored on destination according to the
291// registered DeviceCopyFn function for the given direction. Returns
292// an Internal error if the Variant does not have a registered
293// DeviceCopyFn function for the given direction, or if initiating the
294// copy fails.
295//
296// REQUIRES:
297// 'to' is not null.
298//
299Status VariantDeviceCopy(
300 const VariantDeviceCopyDirection direction, const Variant& from,
301 Variant* to,
302 const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
303
304// Sets *v_out = unary_op(v). The variant v must have a registered
305// UnaryOp function for the given Device. Returns an Internal error
306// if v does not have a registered unary_op function for this device, or if
307// UnaryOp fails.
308//
309// REQUIRES:
310// v_out is not null.
311//
312template <typename Device>
313Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
314 Variant* v_out) {
315 const std::string& device = DeviceName<Device>::value;
316 UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
317 UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
318 if (unary_op_fn == nullptr) {
319 return errors::Internal("No unary variant unary_op function found for op ",
320 VariantUnaryOpToString(op),
321 " Variant type_name: ", v.TypeName(),
322 " for device type: ", device);
323 }
324 return (*unary_op_fn)(ctx, v, v_out);
325}
326
327// Sets *out = binary_op(a, b). The variants a and b must be the same type
328// and have a registered binary_op function for the given Device. Returns an
329// Internal error if a and b are not the same type_name or if
330// if a does not have a registered op function for this device, or if
331// BinaryOp fails.
332//
333// REQUIRES:
334// out is not null.
335//
336template <typename Device>
337Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
338 const Variant& a, const Variant& b, Variant* out) {
339 if (a.TypeId() != b.TypeId()) {
340 return errors::Internal(
341 "BinaryOpVariants: Variants a and b have different "
342 "type ids. Type names: '",
343 a.TypeName(), "' vs. '", b.TypeName(), "'");
344 }
345 const std::string& device = DeviceName<Device>::value;
346 UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
347 UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
348 if (binary_op_fn == nullptr) {
349 return errors::Internal("No unary variant binary_op function found for op ",
350 VariantBinaryOpToString(op),
351 " Variant type_name: '", a.TypeName(),
352 "' for device type: ", device);
353 }
354 return (*binary_op_fn)(ctx, a, b, out);
355}
356
357namespace variant_op_registry_fn_registration {
358
359template <typename T>
360class UnaryVariantDecodeRegistration {
361 public:
362 UnaryVariantDecodeRegistration(const std::string& type_name) {
363 // The Variant is passed by pointer because it should be
364 // mutable: get below may Decode the variant, which
365 // is a self-mutating behavior. The variant is not modified in
366 // any other way.
367 UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
368 type_name, [type_name](Variant* v) -> bool {
369 DCHECK_NE(v, nullptr);
370 VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
371 if (t == nullptr) {
372 return false;
373 }
374 Variant decoded = T();
375 VariantTensorData data(std::move(*t));
376 if (!decoded.Decode(std::move(data))) {
377 return false;
378 }
379 std::swap(decoded, *v);
380 return true;
381 });
382 }
383};
384
385template <typename T>
386class UnaryVariantDeviceCopyRegistration {
387 public:
388 typedef std::function<Status(const T& t, T* t_out,
389 UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
390 LocalVariantDeviceCopyFn;
391 UnaryVariantDeviceCopyRegistration(
392 const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
393 const LocalVariantDeviceCopyFn& device_copy_fn) {
394 const std::string type_index_name =
395 port::MaybeAbiDemangle(type_index.name());
396 UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
397 direction, type_index,
398 [type_index_name, device_copy_fn](
399 const Variant& from, Variant* to,
400 UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
401 device_copy_tensor_fn) -> Status {
402 DCHECK_NE(to, nullptr);
403 *to = T();
404 if (from.get<T>() == nullptr) {
405 return errors::Internal(
406 "VariantCopyToGPUFn: Could not access object, type_index: ",
407 type_index_name);
408 }
409 const T& t = *from.get<T>();
410 T* t_out = to->get<T>();
411 return device_copy_fn(t, t_out, device_copy_tensor_fn);
412 });
413 }
414};
415
416template <typename T>
417class UnaryVariantUnaryOpRegistration {
418 typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
419 LocalVariantUnaryOpFn;
420
421 public:
422 UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const std::string& device,
423 const TypeIndex& type_index,
424 const LocalVariantUnaryOpFn& unary_op_fn) {
425 const std::string type_index_name =
426 port::MaybeAbiDemangle(type_index.name());
427 UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
428 op, device, type_index,
429 [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
430 Variant* v_out) -> Status {
431 DCHECK_NE(v_out, nullptr);
432 *v_out = T();
433 if (v.get<T>() == nullptr) {
434 return errors::Internal(
435 "VariantUnaryOpFn: Could not access object, type_index: ",
436 type_index_name);
437 }
438 const T& t = *v.get<T>();
439 T* t_out = v_out->get<T>();
440 return unary_op_fn(ctx, t, t_out);
441 });
442 }
443};
444
445template <typename T>
446class UnaryVariantBinaryOpRegistration {
447 typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
448 T* out)>
449 LocalVariantBinaryOpFn;
450
451 public:
452 UnaryVariantBinaryOpRegistration(VariantBinaryOp op,
453 const std::string& device,
454 const TypeIndex& type_index,
455 const LocalVariantBinaryOpFn& binary_op_fn) {
456 const std::string type_index_name =
457 port::MaybeAbiDemangle(type_index.name());
458 UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
459 op, device, type_index,
460 [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
461 const Variant& b,
462 Variant* out) -> Status {
463 DCHECK_NE(out, nullptr);
464 *out = T();
465 if (a.get<T>() == nullptr) {
466 return errors::Internal(
467 "VariantBinaryOpFn: Could not access object 'a', type_index: ",
468 type_index_name);
469 }
470 if (b.get<T>() == nullptr) {
471 return errors::Internal(
472 "VariantBinaryOpFn: Could not access object 'b', type_index: ",
473 type_index_name);
474 }
475 const T& t_a = *a.get<T>();
476 const T& t_b = *b.get<T>();
477 T* t_out = out->get<T>();
478 return binary_op_fn(ctx, t_a, t_b, t_out);
479 });
480 }
481};
482
483}; // namespace variant_op_registry_fn_registration
484
// Registers a decode function for Variants holding T, looked up by
// `type_name`.
//
// Expands to a `static` UnaryVariantDecodeRegistration<T> object whose
// constructor performs the registration. __COUNTER__ gives each expansion a
// distinct variable name. The _UNIQ_HELPER indirection makes __COUNTER__
// expand to a number before it reaches the ## paste in _UNIQ.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(counter, T, name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(counter, T, name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(counter, T, name) \
  static ::tensorflow::variant_op_registry_fn_registration::         \
      UnaryVariantDecodeRegistration<T>                              \
          register_unary_variant_op_decoder_fn_##counter(name)
496
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//     const T& t, T* t_out,
//     const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// (Taking `copier` by value also works; the registration object stores the
// function as a std::function that receives the copier by value.)
//
// And device_copy_fn calls copier 0 or more times.  For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning to->tensor = from.tensor (assuming
// from.tensor is already on host); or by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner.  For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY.  IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,   \
                                                             device_copy_fn) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(          \
      __COUNTER__, T, direction, ::tensorflow::TypeIndex::Make<T>(),         \
      device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    ctr, T, direction, type_index, device_copy_fn)                        \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      ctr, T, direction, type_index, device_copy_fn)

// The registration type is fully qualified, as in the other
// REGISTER_UNARY_VARIANT_* macros, so this also expands correctly outside
// namespace tensorflow.
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
    ctr, T, direction, type_index, device_copy_fn)                 \
  static ::tensorflow::variant_op_registry_fn_registration::       \
      UnaryVariantDeviceCopyRegistration<T>                        \
          register_unary_variant_op_device_copy_fn_##ctr(          \
              direction, type_index, device_copy_fn)
546
// Register a unary unary_op variant function with the signature:
//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// for Variants holding T (TypeIndex::Make<T>()), for device string `device`,
// and for the VariantUnaryOp enum value `op`.
//
// TypeIndex and the registration type are fully qualified so the macro can be
// expanded outside namespace tensorflow.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,       \
                                                 unary_op_function)   \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(               \
      __COUNTER__, op, device, T, ::tensorflow::TypeIndex::Make<T>(), \
      unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(       \
    ctr, op, device, T, type_index, unary_op_function)              \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
                                                type_index, unary_op_function)

// The static is named after what it registers (a unary op, not a decoder).
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                     \
    ctr, op, device, T, type_index, unary_op_function)                     \
  static ::tensorflow::variant_op_registry_fn_registration::               \
      UnaryVariantUnaryOpRegistration<T>                                   \
          register_unary_variant_unary_op_fn_##ctr(op, device, type_index, \
                                                   unary_op_function)
567
// Register a binary_op variant function with the signature:
//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// for Variants holding T (TypeIndex::Make<T>()), for device string `device`,
// and for the VariantBinaryOp enum value `op`.
//
// TypeIndex and the registration type are fully qualified so the macro can be
// expanded outside namespace tensorflow.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,      \
                                                  binary_op_function) \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(              \
      __COUNTER__, op, device, T, ::tensorflow::TypeIndex::Make<T>(), \
      binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, binary_op_function)        \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, binary_op_function)

// The static is named after what it registers (a binary op, not a decoder).
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                     \
    ctr, op, device, T, type_index, binary_op_function)                     \
  static ::tensorflow::variant_op_registry_fn_registration::                \
      UnaryVariantBinaryOpRegistration<T>                                   \
          register_unary_variant_binary_op_fn_##ctr(op, device, type_index, \
                                                    binary_op_function)
588
589} // end namespace tensorflow
590
591#endif // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
592