/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_
#define TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_

#include <string>
#include <unordered_set>
#include <vector>

#define EIGEN_USE_THREADS

#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/type_index.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/abi.h"

namespace tensorflow {

class OpKernelContext;
// A global UnaryVariantOpRegistry is used to hold callback functions
// for different variant types. It is used by ops such as ShapeOp,
// RankOp, and SizeOp, as well as for Variant decoding and device copy.

enum VariantUnaryOp {
  INVALID_VARIANT_UNARY_OP = 0,
  ZEROS_LIKE_VARIANT_UNARY_OP = 1,
  CONJ_VARIANT_UNARY_OP = 2,
};

const char* VariantUnaryOpToString(VariantUnaryOp op);

enum VariantBinaryOp {
  INVALID_VARIANT_BINARY_OP = 0,
  ADD_VARIANT_BINARY_OP = 1,
};

const char* VariantBinaryOpToString(VariantBinaryOp op);

enum VariantDeviceCopyDirection {
  INVALID_DEVICE_COPY_DIRECTION = 0,
  HOST_TO_DEVICE = 1,
  DEVICE_TO_HOST = 2,
  DEVICE_TO_DEVICE = 3,
};

class UnaryVariantOpRegistry;
extern UnaryVariantOpRegistry* UnaryVariantOpRegistryGlobal();

class UnaryVariantOpRegistry {
 public:
  typedef std::function<bool(Variant*)> VariantDecodeFn;
  typedef std::function<Status(OpKernelContext*, const Variant&, Variant*)>
      VariantUnaryOpFn;
  typedef std::function<Status(OpKernelContext*, const Variant&, const Variant&,
                               Variant*)>
      VariantBinaryOpFn;

  // An AsyncTensorDeviceCopyFn is the function passed to a user-provided
  // DeviceCopyFn callback as its third argument ("copier").
  //
  // Expected inputs:
  //   from: A Tensor on the host (if performing cpu->gpu copy), or
  //         device (if performing gpu->cpu or gpu->gpu copy).
  //   to: An empty/uninitialized tensor. It will be updated upon
  //       successful return of the function with the correct dtype and
  //       shape. However, the copied data will not be available until the
  //       compute stream has been synchronized.
  //
  // Returns:
  //   The status upon memory allocation / initialization of the
  //   "to" tensor, and enqueue of the copy onto the compute stream.
  //   Any failure of the copy itself will update the underlying
  //   stream status and propagate through the runtime independent
  //   of the caller.
  typedef std::function<Status(const Tensor& from, Tensor* to)>
      AsyncTensorDeviceCopyFn;

  // The AsyncVariantDeviceCopyFn is the signature of the 'device_copy_fn'
  // expected to be passed to the registration macro
  // INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION.
  typedef std::function<Status(const Variant& from, Variant* to,
                               AsyncTensorDeviceCopyFn copy_fn)>
      AsyncVariantDeviceCopyFn;
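
  // As a rough sketch (names like `MyVariantType` and its `tensor` field are
  // illustrative, not part of this header), a registered
  // AsyncVariantDeviceCopyFn typically rebuilds the output Variant on the
  // host and forwards each contained Tensor to 'copy_fn':
  //
  //   [](const Variant& from, Variant* to,
  //      UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn copy_fn) -> Status {
  //     MyVariantType out;
  //     const MyVariantType& in = *from.get<MyVariantType>();
  //     TF_RETURN_IF_ERROR(copy_fn(in.tensor, &out.tensor));  // enqueue copy
  //     *to = std::move(out);
  //     return Status::OK();
  //   }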

  // Add a decode function to the registry.
  void RegisterDecodeFn(const std::string& type_name,
                        const VariantDecodeFn& decode_fn);

  // Returns nullptr if no decode function was found for the given TypeName.
  VariantDecodeFn* GetDecodeFn(StringPiece type_name);

  // Add a copy-to-GPU function to the registry.
  void RegisterDeviceCopyFn(const VariantDeviceCopyDirection direction,
                            const TypeIndex& type_index,
                            const AsyncVariantDeviceCopyFn& device_copy_fn) {
    AsyncVariantDeviceCopyFn* existing = GetDeviceCopyFn(direction, type_index);
    CHECK_EQ(existing, nullptr)
        << "UnaryVariantDeviceCopy for direction: " << direction
        << " and type_index: " << port::MaybeAbiDemangle(type_index.name())
        << " already registered";
    device_copy_fns.insert(
        std::pair<std::pair<VariantDeviceCopyDirection, TypeIndex>,
                  AsyncVariantDeviceCopyFn>(
            std::make_pair(direction, type_index), device_copy_fn));
  }

  // Returns nullptr if no copy function was found for the given
  // TypeName and direction.
  AsyncVariantDeviceCopyFn* GetDeviceCopyFn(
      const VariantDeviceCopyDirection direction, const TypeIndex& type_index) {
    auto found = device_copy_fns.find(std::make_pair(direction, type_index));
    if (found == device_copy_fns.end()) return nullptr;
    return &found->second;
  }

  // Add a unary op function to the registry.
  void RegisterUnaryOpFn(VariantUnaryOp op, const std::string& device,
                         const TypeIndex& type_index,
                         const VariantUnaryOpFn& unary_op_fn) {
    VariantUnaryOpFn* existing = GetUnaryOpFn(op, device, type_index);
    CHECK_EQ(existing, nullptr)
        << "Unary VariantUnaryOpFn for type_index: "
        << port::MaybeAbiDemangle(type_index.name())
        << " already registered for device type: " << device;
    unary_op_fns.insert(std::pair<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn>(
        {op, GetPersistentStringPiece(device), type_index}, unary_op_fn));
  }

  // Returns nullptr if no unary op function was found for the given
  // op, device, and TypeName.
  VariantUnaryOpFn* GetUnaryOpFn(VariantUnaryOp op, StringPiece device,
                                 const TypeIndex& type_index) {
    auto found = unary_op_fns.find({op, device, type_index});
    if (found == unary_op_fns.end()) return nullptr;
    return &found->second;
  }

  // Add a binary op function to the registry.
  void RegisterBinaryOpFn(VariantBinaryOp op, const std::string& device,
                          const TypeIndex& type_index,
                          const VariantBinaryOpFn& add_fn) {
    VariantBinaryOpFn* existing = GetBinaryOpFn(op, device, type_index);
    CHECK_EQ(existing, nullptr)
161 | << "Unary VariantBinaryOpFn for type_index: " |
        << port::MaybeAbiDemangle(type_index.name())
        << " already registered for device type: " << device;
    binary_op_fns.insert(
        std::pair<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn>(
            {op, GetPersistentStringPiece(device), type_index}, add_fn));
  }

  // Returns nullptr if no binary op function was found for the given
  // op, device and TypeName.
  VariantBinaryOpFn* GetBinaryOpFn(VariantBinaryOp op, StringPiece device,
                                   const TypeIndex& type_index) {
    auto found = binary_op_fns.find({op, device, type_index});
    if (found == binary_op_fns.end()) return nullptr;
    return &found->second;
  }

  // Get a pointer to a global UnaryVariantOpRegistry object
  static UnaryVariantOpRegistry* Global() {
    return UnaryVariantOpRegistryGlobal();
  }

  // Get a pointer to a global persistent string storage object.
  // ISO/IEC C++ working draft N4296 clarifies that insertion into an
  // std::unordered_set does not invalidate memory locations of
  // *values* inside the set (though it may invalidate existing
  // iterators). In other words, one may safely point a StringPiece to
  // a value in the set without that StringPiece being invalidated by
  // future insertions.
  static std::unordered_set<string>* PersistentStringStorage();

 private:
  struct TypeIndexHash {
    std::size_t operator()(const TypeIndex& x) const { return x.hash_code(); }
  };

  gtl::FlatMap<StringPiece, VariantDecodeFn, StringPieceHasher> decode_fns;

  // Map std::pair<VariantDeviceCopyDirection, TypeIndex> to function.
  struct PairHash {
    template <typename Direction>
    std::size_t operator()(const std::pair<Direction, TypeIndex>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, std::get<1>(x).hash_code());
      return ret;
    }
  };

  gtl::FlatMap<std::pair<VariantDeviceCopyDirection, TypeIndex>,
               AsyncVariantDeviceCopyFn, PairHash>
      device_copy_fns;

  // Map FuncTuple<Op> (op enum, device string, TypeIndex) to function.

  // Using a bare std::tuple as the map key breaks by falling victim to
  // "too perfect forwarding"; see
  // https://stackoverflow.com/questions/44475317/variadic-template-issue
  // and references therein. FuncTuple is used as the key type instead.
  template <typename Op>
  struct FuncTuple {
    FuncTuple(const Op& op, const StringPiece& dev, const TypeIndex& type_index)
        : op_type_(op), device_(dev), type_index_(type_index) {}
    Op op_type_;
    StringPiece device_;
    TypeIndex type_index_;
  };
  // friend declaration for operator==
  // needed for clang
  template <typename Op>
  friend bool operator==(const FuncTuple<Op>& l, const FuncTuple<Op>& r);
  struct TupleHash {
    template <typename Op>
    std::size_t operator()(
        const std::tuple<Op, StringPiece, TypeIndex>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(std::get<0>(x));
      ret = Hash64Combine(ret, sp_hasher_(std::get<1>(x)));
      ret = Hash64Combine(ret, std::get<2>(x).hash_code());
      return ret;
    }

    template <typename Op>
    std::size_t operator()(const FuncTuple<Op>& x) const {
      // The hash of an enum is just its value as a std::size_t.
      std::size_t ret = static_cast<std::size_t>(x.op_type_);
      ret = Hash64Combine(ret, sp_hasher_(x.device_));
      ret = Hash64Combine(ret, x.type_index_.hash_code());
      return ret;
    }
    StringPieceHasher sp_hasher_;
  };
  gtl::FlatMap<FuncTuple<VariantUnaryOp>, VariantUnaryOpFn, TupleHash>
      unary_op_fns;
  gtl::FlatMap<FuncTuple<VariantBinaryOp>, VariantBinaryOpFn, TupleHash>
      binary_op_fns;

  // Find or insert a string into a persistent string storage
  // container; return the StringPiece pointing to the permanent string
  // location.
  static StringPiece GetPersistentStringPiece(const std::string& str) {
    const auto string_storage = PersistentStringStorage();
    auto found = string_storage->find(str);
    if (found == string_storage->end()) {
      auto inserted = string_storage->insert(str);
      return StringPiece(*inserted.first);
    } else {
      return StringPiece(*found);
    }
  }
};

template <typename Op>
inline bool operator==(const UnaryVariantOpRegistry::FuncTuple<Op>& lhs,
                       const UnaryVariantOpRegistry::FuncTuple<Op>& rhs) {
  return (lhs.op_type_ == rhs.op_type_) && (lhs.device_ == rhs.device_) &&
         (lhs.type_index_ == rhs.type_index_);
}

// Decodes the Variant whose data_type has a registered decode
// function. Returns false if the Variant does not have a registered
// decode function, or if the decoding function fails.
//
// REQUIRES:
//   variant is not null.
//
bool DecodeUnaryVariant(Variant* variant);
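
// A minimal usage sketch (assuming `proto` is a VariantTensorDataProto whose
// type_name has a registered decode function; the variable names are
// illustrative only):
//
//   Variant v = std::move(proto);  // still holds the undecoded proto
//   if (!DecodeUnaryVariant(&v)) {
//     return errors::Internal("Could not decode variant: ", v.TypeName());
//   }
//   // v now holds the registered C++ type, e.g. v.get<MyVariantType>().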

// Copies a variant between CPU<->GPU, or between GPU<->GPU.
// The variant 'from' must have a registered DeviceCopyFn for the
// given direction. The returned variant 'to' will have
// (some subset of its) tensors stored on the destination, according
// to the DeviceCopyFn registered for the given direction. Returns
// an Internal error if the Variant does not have a registered
// DeviceCopyFn for the given direction, or if initiating the
// copy fails.
//
// REQUIRES:
//   'to' is not null.
//
Status VariantDeviceCopy(
    const VariantDeviceCopyDirection direction, const Variant& from,
    Variant* to,
    const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copy_fn);
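
// A rough sketch of a call site (the lambda body and helper name are
// illustrative assumptions, not part of this header): the caller supplies a
// 'copy_fn' that allocates the destination Tensor and enqueues the actual
// memory copy, e.g. on the device's compute stream:
//
//   Status s = VariantDeviceCopy(
//       VariantDeviceCopyDirection::HOST_TO_DEVICE, from_variant, &to_variant,
//       [&](const Tensor& from, Tensor* to) {
//         // Allocate *to on the device and enqueue the host->device copy.
//         return CopyHostTensorToDevice(from, to);
//       });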

// Sets *v_out = unary_op(v). The variant v must have a registered
// UnaryOp function for the given Device. Returns an Internal error
// if v does not have a registered unary_op function for this device, or if
// UnaryOp fails.
//
// REQUIRES:
//   v_out is not null.
//
template <typename Device>
Status UnaryOpVariant(OpKernelContext* ctx, VariantUnaryOp op, const Variant& v,
                      Variant* v_out) {
  const std::string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantUnaryOpFn* unary_op_fn =
      UnaryVariantOpRegistry::Global()->GetUnaryOpFn(op, device, v.TypeId());
  if (unary_op_fn == nullptr) {
    return errors::Internal("No unary variant unary_op function found for op ",
                            VariantUnaryOpToString(op),
                            " Variant type_name: ", v.TypeName(),
                            " for device type: ", device);
  }
  return (*unary_op_fn)(ctx, v, v_out);
}
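
// A minimal kernel-side usage sketch (assuming the usual kernel alias
// `typedef Eigen::ThreadPoolDevice CPUDevice;` and an input variant `v`
// whose type has a ZEROS_LIKE_VARIANT_UNARY_OP function registered for CPU):
//
//   Variant v_out;
//   OP_REQUIRES_OK(ctx, UnaryOpVariant<CPUDevice>(
//                           ctx, ZEROS_LIKE_VARIANT_UNARY_OP, v, &v_out));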

// Sets *out = binary_op(a, b). The variants a and b must be the same type
// and have a registered binary_op function for the given Device. Returns an
// Internal error if a and b do not have the same type_name, if a does not
// have a registered binary_op function for this device, or if the binary_op
// fails.
//
// REQUIRES:
//   out is not null.
//
template <typename Device>
Status BinaryOpVariants(OpKernelContext* ctx, VariantBinaryOp op,
                        const Variant& a, const Variant& b, Variant* out) {
  if (a.TypeId() != b.TypeId()) {
    return errors::Internal(
        "BinaryOpVariants: Variants a and b have different "
        "type ids. Type names: '",
        a.TypeName(), "' vs. '", b.TypeName(), "'");
  }
  const std::string& device = DeviceName<Device>::value;
  UnaryVariantOpRegistry::VariantBinaryOpFn* binary_op_fn =
      UnaryVariantOpRegistry::Global()->GetBinaryOpFn(op, device, a.TypeId());
  if (binary_op_fn == nullptr) {
    return errors::Internal("No unary variant binary_op function found for op ",
                            VariantBinaryOpToString(op),
                            " Variant type_name: '", a.TypeName(),
                            "' for device type: ", device);
  }
  return (*binary_op_fn)(ctx, a, b, out);
}
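
// A minimal kernel-side usage sketch (same CPUDevice alias assumption as
// above; `a` and `b` must hold the same registered variant type):
//
//   Variant out;
//   OP_REQUIRES_OK(ctx, BinaryOpVariants<CPUDevice>(
//                           ctx, ADD_VARIANT_BINARY_OP, a, b, &out));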

namespace variant_op_registry_fn_registration {

template <typename T>
class UnaryVariantDecodeRegistration {
 public:
  UnaryVariantDecodeRegistration(const std::string& type_name) {
    // The Variant is passed by pointer because it should be
    // mutable: get below may Decode the variant, which
    // is a self-mutating behavior. The variant is not modified in
    // any other way.
    UnaryVariantOpRegistry::Global()->RegisterDecodeFn(
        type_name, [type_name](Variant* v) -> bool {
          DCHECK_NE(v, nullptr);
          VariantTensorDataProto* t = v->get<VariantTensorDataProto>();
          if (t == nullptr) {
            return false;
          }
          Variant decoded = T();
          VariantTensorData data(std::move(*t));
          if (!decoded.Decode(std::move(data))) {
            return false;
          }
          std::swap(decoded, *v);
          return true;
        });
  }
};

template <typename T>
class UnaryVariantDeviceCopyRegistration {
 public:
  typedef std::function<Status(const T& t, T* t_out,
                               UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn)>
      LocalVariantDeviceCopyFn;
  UnaryVariantDeviceCopyRegistration(
      const VariantDeviceCopyDirection direction, const TypeIndex& type_index,
      const LocalVariantDeviceCopyFn& device_copy_fn) {
    const std::string type_index_name =
        port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterDeviceCopyFn(
        direction, type_index,
        [type_index_name, device_copy_fn](
            const Variant& from, Variant* to,
            UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn
                device_copy_tensor_fn) -> Status {
          DCHECK_NE(to, nullptr);
          *to = T();
          if (from.get<T>() == nullptr) {
            return errors::Internal(
                "VariantCopyToGPUFn: Could not access object, type_index: ",
                type_index_name);
          }
          const T& t = *from.get<T>();
          T* t_out = to->get<T>();
          return device_copy_fn(t, t_out, device_copy_tensor_fn);
        });
  }
};

template <typename T>
class UnaryVariantUnaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& t, T* t_out)>
      LocalVariantUnaryOpFn;

 public:
  UnaryVariantUnaryOpRegistration(VariantUnaryOp op, const std::string& device,
                                  const TypeIndex& type_index,
                                  const LocalVariantUnaryOpFn& unary_op_fn) {
    const std::string type_index_name =
        port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterUnaryOpFn(
        op, device, type_index,
        [type_index_name, unary_op_fn](OpKernelContext* ctx, const Variant& v,
                                       Variant* v_out) -> Status {
          DCHECK_NE(v_out, nullptr);
          *v_out = T();
          if (v.get<T>() == nullptr) {
            return errors::Internal(
                "VariantUnaryOpFn: Could not access object, type_index: ",
                type_index_name);
          }
          const T& t = *v.get<T>();
          T* t_out = v_out->get<T>();
          return unary_op_fn(ctx, t, t_out);
        });
  }
};

template <typename T>
class UnaryVariantBinaryOpRegistration {
  typedef std::function<Status(OpKernelContext* ctx, const T& a, const T& b,
                               T* out)>
      LocalVariantBinaryOpFn;

 public:
  UnaryVariantBinaryOpRegistration(VariantBinaryOp op,
                                   const std::string& device,
                                   const TypeIndex& type_index,
                                   const LocalVariantBinaryOpFn& binary_op_fn) {
    const std::string type_index_name =
        port::MaybeAbiDemangle(type_index.name());
    UnaryVariantOpRegistry::Global()->RegisterBinaryOpFn(
        op, device, type_index,
        [type_index_name, binary_op_fn](OpKernelContext* ctx, const Variant& a,
                                        const Variant& b,
                                        Variant* out) -> Status {
          DCHECK_NE(out, nullptr);
          *out = T();
          if (a.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'a', type_index: ",
                type_index_name);
          }
          if (b.get<T>() == nullptr) {
            return errors::Internal(
                "VariantBinaryOpFn: Could not access object 'b', type_index: ",
                type_index_name);
          }
          const T& t_a = *a.get<T>();
          const T& t_b = *b.get<T>();
          T* t_out = out->get<T>();
          return binary_op_fn(ctx, t_a, t_b, t_out);
        });
  }
};

}  // namespace variant_op_registry_fn_registration

// Register a unary decode variant function for the given type.
#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION(T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(__COUNTER__, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ_HELPER(ctr, T, type_name) \
  REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name)

#define REGISTER_UNARY_VARIANT_DECODE_FUNCTION_UNIQ(ctr, T, type_name) \
  static ::tensorflow::variant_op_registry_fn_registration::           \
      UnaryVariantDecodeRegistration<T>                                 \
          register_unary_variant_op_decoder_fn_##ctr(type_name)
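
// For example, a hypothetical user-defined type `MyVariantType` (one that
// provides the TypeName/Encode/Decode methods expected by the Variant
// framework; the name is illustrative, not part of TensorFlow) could be
// made decodable with:
//
//   REGISTER_UNARY_VARIANT_DECODE_FUNCTION(MyVariantType, "MyVariantType");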

// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
//
// Register a device copy variant function for the given copy
// direction and type; where direction is the enum
// VariantDeviceCopyDirection, and the device_copy_fn has signature:
//
//   Status device_copy_fn(
//       const T& t, T* t_out,
//       const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier);
//
// And device_copy_fn calls copier 0 or more times. For details on
// the behavior of the copier function, see the comments at the
// declaration of UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn.
//
// Note, the device_copy_fn may choose to keep some tensors
// on host, e.g. by assigning to->tensor = from.tensor (assuming
// from.tensor is already on host); or by setting
//   to->tensor = Tensor(cpu_allocator(), ...)
// and manually updating its values.
//
// If this is the case, the CopyFns for HOST_TO_DEVICE,
// DEVICE_TO_HOST, and DEVICE_TO_DEVICE must perform host-to-host
// copies in a consistent manner. For example, one must always
// manually copy any "always on host" tensors in all directions instead of e.g.
//   - performing a host-to-host copy in one direction,
//   - using the provided copier function in the reverse direction.
// Doing the latter will cause program failures.
//
// ****** NOTE ******
// FOR INTERNAL USE ONLY. IF YOU USE THIS WE MAY BREAK YOUR CODE.
// ****** NOTE ******
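//
// A rough usage sketch (the struct, its fields, and the helper name are
// illustrative assumptions, not part of this header):
//
//   struct MyVariantType {
//     string metadata;
//     Tensor tensor;
//   };
//
//   static Status MyVariantDeviceCopy(
//       const MyVariantType& from, MyVariantType* to,
//       const UnaryVariantOpRegistry::AsyncTensorDeviceCopyFn& copier) {
//     to->metadata = from.metadata;             // stays on host
//     return copier(from.tensor, &to->tensor);  // enqueue the tensor copy
//   }
//
//   INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(
//       MyVariantType, VariantDeviceCopyDirection::HOST_TO_DEVICE,
//       MyVariantDeviceCopy);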
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction,    \
                                                              device_copy_fn) \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER(           \
      __COUNTER__, T, direction, TypeIndex::Make<T>(), device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
    ctr, T, direction, type_index, device_copy_fn)                        \
  INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ(              \
      ctr, T, direction, type_index, device_copy_fn)

#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ( \
    ctr, T, direction, type_index, device_copy_fn)                 \
  static variant_op_registry_fn_registration::                     \
      UnaryVariantDeviceCopyRegistration<T>                        \
          register_unary_variant_op_device_copy_fn_##ctr(          \
              direction, type_index, device_copy_fn)

// Register a unary_op variant function with the signature:
//    Status UnaryOpFn(OpKernelContext* ctx, const T& t, T* t_out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantUnaryOp enum op.
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T,     \
                                                 unary_op_function) \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(             \
      __COUNTER__, op, device, T, TypeIndex::Make<T>(), unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER(       \
    ctr, op, device, T, type_index, unary_op_function)              \
  REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(ctr, op, device, T, \
                                                type_index, unary_op_function)

#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ(                       \
    ctr, op, device, T, type_index, unary_op_function)                       \
  static ::tensorflow::variant_op_registry_fn_registration::                 \
      UnaryVariantUnaryOpRegistration<T>                                     \
          register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
                                                     unary_op_function)
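
// For example, a hypothetical zeros_like function for `MyVariantType`
// (both names are illustrative assumptions) might be registered for CPU as:
//
//   static Status MyVariantZerosLike(OpKernelContext* ctx,
//                                    const MyVariantType& t,
//                                    MyVariantType* t_out);
//
//   REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(ZEROS_LIKE_VARIANT_UNARY_OP,
//                                            DEVICE_CPU, MyVariantType,
//                                            MyVariantZerosLike);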

// Register a binary_op variant function with the signature:
//    Status BinaryOpFn(OpKernelContext* ctx, const T& a, const T& b, T* out);
// to Variants having TypeIndex type_index, for device string device,
// for VariantBinaryOp enum op.
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T,      \
                                                  binary_op_function) \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER(              \
      __COUNTER__, op, device, T, TypeIndex::Make<T>(), binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
    ctr, op, device, T, type_index, binary_op_function)        \
  REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(              \
      ctr, op, device, T, type_index, binary_op_function)

#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ(                      \
    ctr, op, device, T, type_index, binary_op_function)                      \
  static ::tensorflow::variant_op_registry_fn_registration::                 \
      UnaryVariantBinaryOpRegistration<T>                                    \
          register_unary_variant_op_decoder_fn_##ctr(op, device, type_index, \
                                                     binary_op_function)
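
// For example, a hypothetical element-wise add for `MyVariantType`
// (both names are illustrative assumptions) might be registered for CPU as:
//
//   static Status MyVariantAdd(OpKernelContext* ctx, const MyVariantType& a,
//                              const MyVariantType& b, MyVariantType* out);
//
//   REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(ADD_VARIANT_BINARY_OP,
//                                             DEVICE_CPU, MyVariantType,
//                                             MyVariantAdd);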

}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_VARIANT_OP_REGISTRY_H_