1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
17#define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
18
19#define _USE_MATH_DEFINES
20#include <cmath>
21#include <functional>
22#include <type_traits>
23
24#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
25#include "tensorflow/core/framework/bounds_check.h"
26#include "tensorflow/core/framework/numeric_types.h"
27#include "tensorflow/core/framework/tensor_types.h"
28
29namespace Eigen {
30namespace internal {
31
32#if GOOGLE_CUDA
33template <>
34struct scalar_arg_op<std::complex<float>> {
35 typedef typename Eigen::NumTraits<std::complex<float>>::Real result_type;
36 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
37 const std::complex<float>& a) const {
38 return ::atan2f(a.imag(), a.real());
39 }
40};
41
42template <>
43struct scalar_arg_op<std::complex<double>> {
44 typedef typename Eigen::NumTraits<std::complex<double>>::Real result_type;
45 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double operator()(
46 const std::complex<double>& a) const {
47 return ::atan2(a.imag(), a.real());
48 }
49};
50#endif
51
52#if EIGEN_HAS_CXX11_MATH == 0
53template <typename T>
54struct scalar_asinh_op {
55 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
56 return static_cast<T>(std::asinh(a));
57 }
58};
59template <typename T>
60struct functor_traits<scalar_asinh_op<T>> {
61 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
62};
63
64template <typename T>
65struct scalar_acosh_op {
66 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
67 return static_cast<T>(std::acosh(a));
68 }
69};
70template <typename T>
71struct functor_traits<scalar_acosh_op<T>> {
72 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
73};
74
75template <typename T>
76struct scalar_atanh_op {
77 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
78 return static_cast<T>(std::atanh(a));
79 }
80};
81template <typename T>
82struct functor_traits<scalar_atanh_op<T>> {
83 enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false };
84};
85#endif
86
87template <typename Scalar, typename Exponent>
88struct safe_scalar_binary_pow_op {
89 static_assert(std::is_integral<Scalar>::value, "Integer type expected");
90 static_assert(std::is_integral<Exponent>::value &&
91 std::is_signed<Exponent>::value,
92 "Signed integer type expected");
93
94 bool* const error;
95
96 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error)
97 : error(error) {}
98
99 EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a,
100 const Exponent& b) const {
101 const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b);
102 if (TF_PREDICT_TRUE(safe_b >= 0)) {
103 return numext::pow(a, safe_b);
104 } else {
105 *error = true;
106 return 0;
107 }
108 }
109};
110
111template <typename Scalar, typename Exponent>
112struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> {
113 enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false };
114};
115
116template <typename T, typename DivOrMod>
117struct safe_div_or_mod_op {
118 static_assert(std::is_integral<T>::value, "Integer type expected");
119
120 bool* const error;
121
122 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error)
123 : error(error) {}
124
125 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
126 const T& b) const {
127 const T safe_b = tensorflow::internal::SubtleMustCopy(b);
128 if (TF_PREDICT_TRUE(safe_b != 0)) {
129 // Avoid FPE for INT_MIN/-1.
130 const T safe_a = tensorflow::internal::SubtleMustCopy(a);
131 if (TF_PREDICT_FALSE(std::is_signed<T>::value &&
132 safe_a == std::numeric_limits<T>::min() &&
133 safe_b == T(-1))) {
134 // Prefer to overflow 'a' instead of crashing.
135 return DivOrMod()(-safe_a, 1);
136 }
137 return DivOrMod()(safe_a, safe_b);
138 } else {
139 *error = true;
140 return 0;
141 }
142 }
143};
144
145template <typename T, typename DivOrMod>
146struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> {
147 enum {
148 Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost,
149 PacketAccess = false,
150 };
151};
152
153template <typename T, typename Binary>
154struct no_nan_op {
155 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
156 const T& b) const {
157 if (b != T(0)) {
158 return Binary()(a, b);
159 } else {
160 return T(0);
161 }
162 }
163 template <typename Packet>
164 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
165 const Packet& b) const {
166 const Packet mask = pcmp_eq(b, pzero(b));
167 const Packet quotient = Binary().packetOp(a, b);
168 return pandnot(quotient, mask);
169 }
170};
171
172template <typename T, bool IsComplex = Eigen::NumTraits<T>::IsComplex>
173struct div_no_nan_op;
174
175template <typename T>
176struct div_no_nan_op<T, /*IsComplex=*/false>
177 : public no_nan_op<T, scalar_quotient_op<T>> {
178};
179
180template <typename T>
181struct functor_traits<div_no_nan_op<T, /*IsComplex=*/false>> {
182 enum {
183 Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost,
184 PacketAccess = true,
185 };
186};
187
188// Whether or not complex division produces a NaN depends on the underlying
189// implementation. Some compilers (e.g. gcc) use a simple method that divides
190// by |b|^2, which may underflow to 0 for b != 0.
191template <typename T>
192struct div_no_nan_op<T, /*IsComplex=*/true> {
193 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a,
194 const T& b) const {
195 if (b == T(0)) {
196 return T(0);
197 } else {
198 // If the numerator is zero, then the result must be zero even if |b|^2
199 // underflows to zero.
200 const T numerator =
201 scalar_product_op<T>()(a, scalar_conjugate_op<T>()(b));
202 if (numerator == T(0)) {
203 return T(0);
204 }
205 }
206 return scalar_quotient_op<T>()(a, b);
207 }
208 template <typename Packet>
209 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
210 const Packet& b) const {
211 const Packet numerator = pmul(a, pconj(b));
212 const Packet mask = por(pcmp_eq(b, pzero(a)), pcmp_eq(numerator, pzero(a)));
213 const Packet quotient = pdiv(a, b);
214 return pandnot(quotient, mask);
215 }
216};
217
218template <typename T>
219struct functor_traits<div_no_nan_op<T, /*IsComplex=*/true>> {
220 enum {
221 Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::MulCost,
222 PacketAccess = packet_traits<T>::HasMul && packet_traits<T>::HasDiv &&
223 packet_traits<T>::HasConj,
224 };
225};
226
227template <typename T>
228struct mul_no_nan_op : public no_nan_op<T, scalar_product_op<T>> {
229};
230
231template <typename T>
232struct functor_traits<mul_no_nan_op<T>> {
233 enum {
234 Cost = functor_traits<scalar_product_op<T>>::Cost + NumTraits<T>::AddCost,
235 PacketAccess = true,
236 };
237};
238
239// scalar_left and scalar_right are template helpers to partially
240// apply a binary function.
241//
242// Suppose Binary is a binary functor f(x, y), scalar_left<> is a
243// unary functor g_x(y) = f(x, y), where x is provided via the
244// constructor. Similarly, scalar_right<> is a unary functor g_y(x) =
245// f(x, y).
246
247template <typename Tout, typename Tin, typename Binary,
248 bool is_scalar_in_host_memory = false>
249struct scalar_left : private Binary {
250 using result_type = Tout;
251 using TinPacket = typename Eigen::internal::packet_traits<Tin>::type;
252
253 const Tin* left;
254 TinPacket left_packet; // initialized iff is_scalar_in_host_memory == true
255
256 inline scalar_left(const scalar_left& other) = default;
257
258 template <typename... Args>
259 EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
260 : Binary(args...), left(c) {
261 if (is_scalar_in_host_memory) {
262 left_packet = Eigen::internal::pset1<TinPacket>(*left);
263 }
264 }
265
266 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
267 return Binary::operator()(*left, right);
268 }
269
270 template <typename Packet,
271 typename std::enable_if<!is_scalar_in_host_memory ||
272 !std::is_same<TinPacket, Packet>::value,
273 int>::type = 0>
274 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
275 const Packet left_packet = Eigen::internal::pset1<Packet>(*left);
276 return Binary::packetOp(left_packet, right_packet);
277 }
278
279 template <typename Packet,
280 typename std::enable_if<is_scalar_in_host_memory &&
281 std::is_same<TinPacket, Packet>::value,
282 int>::type = 0>
283 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
284 return Binary::packetOp(left_packet, right_packet);
285 }
286};
287
288template <typename Tout, typename Tin, typename Binary,
289 bool is_scalar_in_host_memory>
290struct functor_traits<
291 scalar_left<Tout, Tin, Binary, is_scalar_in_host_memory>> {
292 enum {
293 Cost = functor_traits<Binary>::Cost,
294 PacketAccess = functor_traits<Binary>::PacketAccess,
295 };
296};
297
298template <typename Tout, typename Tin, typename Binary,
299 bool is_scalar_in_host_memory = false>
300struct scalar_right : private Binary {
301 using result_type = Tout;
302 using TinPacket = typename Eigen::internal::packet_traits<Tin>::type;
303
304 const Tin* right;
305 TinPacket right_packet; // initialized iff is_scalar_in_host_memory == true
306
307 inline scalar_right(const scalar_right& other) = default;
308
309 template <typename... Args>
310 EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
311 : Binary(args...), right(c) {
312 if (is_scalar_in_host_memory) {
313 right_packet = Eigen::internal::pset1<TinPacket>(*right);
314 }
315 }
316
317 EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
318 return Binary::operator()(left, *right);
319 }
320
321 template <typename Packet,
322 typename std::enable_if<!is_scalar_in_host_memory ||
323 !std::is_same<TinPacket, Packet>::value,
324 int>::type = 0>
325 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
326 const Packet right_packet = Eigen::internal::pset1<Packet>(*right);
327 return Binary::packetOp(left_packet, right_packet);
328 }
329
330 template <typename Packet,
331 typename std::enable_if<is_scalar_in_host_memory &&
332 std::is_same<TinPacket, Packet>::value,
333 int>::type = 0>
334 EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
335 return Binary::packetOp(left_packet, right_packet);
336 }
337};
338
339template <typename Tout, typename Tin, typename Binary,
340 bool is_scalar_in_host_memory>
341struct functor_traits<
342 scalar_right<Tout, Tin, Binary, is_scalar_in_host_memory>> {
343 enum {
344 Cost = functor_traits<Binary>::Cost,
345 PacketAccess = functor_traits<Binary>::PacketAccess,
346 };
347};
348
349// similar to std::equal_to, but with the DEVICE_FUNC qualifier
350template <class T>
351struct equal_to : std::function<bool(T, T)> {
352 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
353 const T& y) const {
354 return x == y;
355 }
356};
357
358// similar to std::not_equal_to, but with the DEVICE_FUNC qualifier
359template <class T>
360struct not_equal_to : std::function<bool(T, T)> {
361 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
362 const T& y) const {
363 return x != y;
364 }
365};
366
367// similar to std::greater, but with the DEVICE_FUNC qualifier
368template <class T>
369struct greater : std::function<bool(T, T)> {
370 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
371 const T& y) const {
372 return x > y;
373 }
374};
375
376// similar to std::less, but with the DEVICE_FUNC qualifier
377template <class T>
378struct less : std::function<bool(T, T)> {
379 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
380 const T& y) const {
381 return x < y;
382 }
383};
384
385// similar to std::greater_equal, but with the DEVICE_FUNC qualifier
386template <class T>
387struct greater_equal : std::function<bool(T, T)> {
388 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
389 const T& y) const {
390 return x >= y;
391 }
392};
393
394// similar to std::less_equal, but with the DEVICE_FUNC qualifier
395template <class T>
396struct less_equal : std::function<bool(T, T)> {
397 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x,
398 const T& y) const {
399 return x <= y;
400 }
401};
402
403// Functor that enables squared difference functor.
404template <typename Scalar>
405struct scalar_squared_difference_op {
406 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
407 operator()(const Scalar& a, const Scalar& b) const {
408 const Scalar v = scalar_difference_op<Scalar>()(a, b);
409 return scalar_product_op<Scalar>()(v, scalar_conjugate_op<Scalar>()(v));
410 }
411 template <typename Packet>
412 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
413 const Packet& b) const {
414 const Packet v = scalar_difference_op<Scalar>().packetOp(a, b);
415 return scalar_product_op<Scalar>().packetOp(
416 v, scalar_conjugate_op<Scalar>().packetOp(v));
417 }
418};
419
420template <typename Scalar>
421struct functor_traits<scalar_squared_difference_op<Scalar>> {
422 enum {
423 Cost = functor_traits<scalar_difference_op<Scalar>>::Cost +
424 functor_traits<scalar_conjugate_op<Scalar>>::Cost +
425 functor_traits<scalar_product_op<Scalar>>::Cost,
426 PacketAccess = functor_traits<scalar_difference_op<Scalar>>::PacketAccess &&
427 functor_traits<scalar_conjugate_op<Scalar>>::PacketAccess &&
428 functor_traits<scalar_product_op<Scalar>>::PacketAccess
429 };
430};
431
432// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
433template <typename T, typename Enable = void>
434struct google_floor_div {
435 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
436 const T& y) const {
437 const T z = x / y;
438 // Subtract one if there is a remainder and if the inputs have opposite
439 // signs. This approach avoids unnecessary overflows.
440 return z * y != x && (x < T(0) != y < T(0)) ? z - T(1) : z;
441 }
442 template <typename Packet>
443 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
444 const Packet& y) const {
445 Packet zeros = pzero(x);
446 Packet x_mask = pcmp_lt(x, zeros);
447 Packet y_mask = pcmp_lt(y, zeros);
448 Packet x_div_y = pdiv(x, y);
449 Packet x_div_y_times_y = pmul(x_div_y, y);
450 return pselect(por(peq(x_div_y_times_y, x), peq(x_mask, y_mask)), x_div_y,
451 psub(x_div_y, pones(x)));
452 }
453};
454
455template <typename T>
456struct google_floor_div<
457 T, typename std::enable_if<std::is_unsigned<T>::value>::type> {
458 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
459 const T& y) const {
460 return x / y;
461 }
462 template <typename Packet>
463 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
464 const Packet& y) const {
465 return pdiv(x, y);
466 }
467};
468
469template <typename Scalar>
470struct functor_traits<google_floor_div<Scalar>> {
471 enum {
472 Cost = 2 * Eigen::internal::scalar_div_cost<
473 Scalar, packet_traits<Scalar>::HasDiv>::value +
474 NumTraits<Scalar>::AddCost,
475 PacketAccess = packet_traits<Scalar>::HasDiv
476 };
477};
478
479template <typename T, typename Enable = void>
480struct google_floor_div_real {
481 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
482 const T& y) const {
483 return Eigen::numext::floor(x / y);
484 }
485 template <typename Packet>
486 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
487 const Packet& y) const {
488 return pfloor(pdiv(x, y));
489 }
490};
491
492template <typename Scalar>
493struct functor_traits<google_floor_div_real<Scalar>> {
494 enum {
495 Cost = 2 * Eigen::internal::scalar_div_cost<
496 Scalar, packet_traits<Scalar>::HasDiv>::value +
497 2 * NumTraits<Scalar>::AddCost,
498 PacketAccess =
499 packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasFloor
500 };
501};
502
503// TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here.
504template <typename T>
505struct google_floor_fmod {
506 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
507 const T& y) const {
508 // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL);
509 T trunc_mod = scalar_fmod_op<T>()(x, y);
510 return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y
511 : trunc_mod;
512 }
513};
514
515template <typename Scalar>
516struct functor_traits<google_floor_fmod<Scalar>> {
517 enum {
518 Cost = functor_traits<Eigen::internal::scalar_fmod_op<Scalar>>::Cost +
519 NumTraits<Scalar>::AddCost,
520 PacketAccess = false
521 };
522};
523
524// TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here.
525template <typename T>
526struct google_floor_mod {
527 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
528 const T& y) const {
529 // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL);
530 T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(x, y);
531 return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y
532 : trunc_mod;
533 }
534};
535
536template <typename Scalar>
537struct functor_traits<google_floor_mod<Scalar>> {
538 enum {
539 Cost = functor_traits<Eigen::internal::scalar_mod2_op<Scalar>>::Cost +
540 NumTraits<Scalar>::AddCost,
541 PacketAccess = false
542 };
543};
544
// DISABLE_FLOAT_EQUALITY_WARNING / ENABLE_FLOAT_EQUALITY_WARNING bracket a
// region in which GCC's -Wfloat-equal diagnostic is suppressed (push + ignore,
// then pop). They expand to nothing unless EIGEN_COMP_GNUC is set and
// __cplusplus is newer than C++98.
#if !(EIGEN_COMP_GNUC && __cplusplus > 199711L)
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#else
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#endif
554
555template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger,
556 bool HasRint = packet_traits<Scalar>::HasRint>
557struct scalar_round_half_to_even_op {
558 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
559 operator()(const Scalar& x) const {
560 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
561 NUMERIC_TYPE_MUST_BE_REAL)
562
563 const Scalar round_val = Eigen::numext::floor(x + Scalar(0.5));
564 const Scalar fraction = round_val - x;
565 if (TF_PREDICT_FALSE(fraction == Scalar(.5))) {
566 return Scalar(2) * Eigen::numext::floor(Scalar(.5) * x + Scalar(0.5));
567 } else {
568 return round_val;
569 }
570 }
571
572 template <typename Packet>
573 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
574 Packet half = pset1<Packet>(Scalar(0.5));
575 Packet round_val = pfloor(padd(x, half));
576 Packet fraction = psub(round_val, x);
577 Packet half_mask = pcmp_eq(fraction, half);
578 bool any_halves = predux_any(half_mask);
579 if (TF_PREDICT_FALSE(any_halves)) {
580 Packet two = pset1<Packet>(Scalar(2));
581 Packet nearest_even = pmul(two, pfloor(pmadd(half, x, half)));
582 return pselect(half_mask, nearest_even, round_val);
583 } else {
584 return round_val;
585 }
586 }
587};
588
589template <typename Scalar>
590struct scalar_round_half_to_even_op<Scalar, true, false> {
591 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
592 operator()(const Scalar& x) const {
593 return x;
594 }
595 template <typename Packet>
596 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
597 return x;
598 }
599};
600
601template <typename Scalar>
602struct scalar_round_half_to_even_op<Scalar, false, true> {
603 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
604 operator()(const Scalar& x) const {
605 return Eigen::numext::rint(x);
606 }
607 template <typename Packet>
608 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
609 return print(x);
610 }
611};
612
613template <typename Scalar>
614struct functor_traits<scalar_round_half_to_even_op<Scalar>> {
615 enum {
616 Cost = Eigen::NumTraits<Scalar>::IsInteger ? 0
617 : 4 * NumTraits<Scalar>::AddCost,
618 PacketAccess = packet_traits<Scalar>::HasFloor &&
619 packet_traits<Scalar>::HasAdd &&
620 packet_traits<Scalar>::HasMul,
621 };
622};
623
624template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger>
625struct scalar_round_up_op {
626 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
627 operator()(const Scalar& x) const {
628 EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
629 NUMERIC_TYPE_MUST_BE_REAL)
630 return Eigen::numext::floor(x + Scalar(0.5));
631 }
632
633 template <typename Packet>
634 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
635 return pfloor(padd(x, pset1<Packet>(0.5)));
636 }
637};
638
639template <typename Scalar>
640struct scalar_round_up_op<Scalar, true> {
641 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
642 operator()(const Scalar& x) const {
643 return x;
644 }
645
646 template <typename Packet>
647 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
648 return x;
649 }
650};
651
652template <typename Scalar, bool IsInteger>
653struct functor_traits<scalar_round_up_op<Scalar, IsInteger>> {
654 enum {
655 Cost = IsInteger ? 0 : 4 * NumTraits<Scalar>::AddCost,
656 PacketAccess = IsInteger || packet_traits<Scalar>::HasFloor
657 };
658};
659
660#undef ENABLE_FLOAT_EQUALITY_WARNING
661#undef DISABLE_FLOAT_EQUALITY_WARNING
662
663template <typename Scalar>
664struct bitwise_xor_op {
665 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
666 operator()(const Scalar& x, const Scalar& y) const {
667 return x ^ y;
668 }
669 template <typename Packet>
670 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a,
671 const Packet& b) const {
672 return Eigen::internal::pxor(a, b);
673 }
674};
675
676template <typename Scalar>
677struct functor_traits<bitwise_xor_op<Scalar>> {
678 enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true };
679};
680
681template <typename Scalar>
682struct xlogy_op {
683 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
684 operator()(const Scalar& x, const Scalar& y) const {
685 if (x == Scalar(0.)) {
686 return Scalar(0.);
687 }
688 return x * numext::log(y);
689 }
690 template <typename Packet>
691 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
692 const Packet& y) const {
693 Packet zeros = pzero(x);
694 Packet mask = pcmp_eq(x, zeros);
695 scalar_log_op<Scalar> log_op;
696 Packet log_y = log_op.packetOp(y);
697 Packet x_log_y = pmul(x, log_y);
698 return pselect(mask, x, x_log_y);
699 }
700};
701
702template <typename Scalar>
703struct functor_traits<xlogy_op<Scalar>> {
704 enum {
705 Cost = functor_traits<scalar_log_op<Scalar>>::Cost +
706 Eigen::NumTraits<Scalar>::MulCost,
707 PacketAccess = functor_traits<scalar_log_op<Scalar>>::PacketAccess
708 };
709};
710
711template <typename Scalar>
712struct xlog1py_op {
713 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
714 operator()(const Scalar& x, const Scalar& y) const {
715 if (x == Scalar(0.)) {
716 return Scalar(0.);
717 }
718 return x * numext::log1p(y);
719 }
720 template <typename Packet>
721 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
722 const Packet& y) const {
723 Packet zeros = pzero(x);
724 Packet mask = pcmp_eq(x, zeros);
725 scalar_log1p_op<Scalar> log1p_op;
726 Packet log1p_y = log1p_op.packetOp(y);
727 Packet x_log1p_y = pmul(x, log1p_y);
728 return pselect(mask, x, x_log1p_y);
729 }
730};
731
732template <typename Scalar>
733struct functor_traits<xlog1py_op<Scalar>> {
734 enum {
735 Cost = functor_traits<scalar_log1p_op<Scalar>>::Cost +
736 Eigen::NumTraits<Scalar>::MulCost,
737#if TENSORFLOW_USE_ROCM
738 PacketAccess = false,
739#else
740 PacketAccess = functor_traits<scalar_log1p_op<Scalar>>::PacketAccess
741#endif
742 };
743};
744
745template <typename Scalar>
746struct xdivy_op {
747 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
748 operator()(const Scalar& x, const Scalar& y) const {
749 if (x == Scalar(0.)) {
750 return Scalar(0.);
751 }
752 return x / y;
753 }
754 template <typename Packet>
755 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
756 const Packet& y) const {
757 Packet zeros = pzero(x);
758 Packet mask = pcmp_eq(x, zeros);
759 Packet x_div_y = pdiv(x, y);
760 return pselect(mask, x, x_div_y);
761 }
762};
763
764template <typename Scalar>
765struct functor_traits<xdivy_op<Scalar>> {
766 enum {
767 Cost =
768 Eigen::NumTraits<Scalar>::AddCost +
769 Eigen::internal::scalar_div_cost<Scalar,
770 packet_traits<Scalar>::HasDiv>::value,
771 PacketAccess = packet_traits<Scalar>::HasDiv
772 };
773};
774
775template <typename T>
776struct scalar_erfinv_op {
777 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const {
778 constexpr T half = T(0.5);
779 T y = numext::ndtri(half * x + half);
780 constexpr T half_sqrt = T(M_SQRT1_2);
781 return y * half_sqrt;
782 }
783 template <typename Packet>
784 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
785 Packet half = pset1<Packet>(T(0.5));
786 Packet y = pndtri<Packet>(pmadd(half, x, half));
787 Packet half_sqrt = pset1<Packet>(T(M_SQRT1_2));
788 return pmul(y, half_sqrt);
789 }
790};
791
792template <typename T>
793struct functor_traits<scalar_erfinv_op<T>> {
794 enum {
795 Cost = functor_traits<scalar_ndtri_op<T>>::Cost + NumTraits<T>::AddCost,
796 PacketAccess = packet_traits<T>::HasNdtri,
797 };
798};
799
800} // end namespace internal
801} // end namespace Eigen
802
803namespace tensorflow {
804namespace functor {
805
806////////////////////////////////////////////////////////////////////////////////
807// Helpers
808////////////////////////////////////////////////////////////////////////////////
809
810// Base template for functors whose input scalar type is T and
811// output scalar type is R.
812template <typename T, typename F, typename R = T>
813struct base {
814 // func defines operator() and its vectorized version packetOp().
815 typedef F func;
816
817 // If true, the functor's corresponding binary op will instantiate
818 // specialized kernels to perform an optimized broadcast
819 // operation. Each functor for which this is enabled increases the
820 // code size, so by default this is disabled for binary functors and
821 // is enabled on a per-op basis as needed.
822 static constexpr bool use_bcast_optimization = false;
823
824 // operator() has the signature:
825 // out_type operator()(in_type in0, in_type in1 ...)
826 typedef R out_type;
827 typedef T in_type;
828
829 // TensorFlow provides tensor-ized version of "func". Roughly
830 // speaking, the tensorflow operation has the signature:
831 // tout_type op(tin_type in0)
832 // tout_type op(tin_type in0, tin_type in1)
833 // tout_type op(tin_type in0, in_type scalar)
834 typedef typename TTypes<out_type>::Flat tout_type;
835 typedef typename TTypes<in_type>::ConstFlat tin_type;
836 typedef typename TTypes<in_type>::ConstScalar tscalar_type;
837
838 // Whether the functor can error out. Currently applies only to integer
839 // div and mod.
840 static constexpr bool has_errors = false;
841};
842
// For now, we only apply certain speed optimization for
// float/double's broadcast binary op.
//
// Boolean type trait: use_bcast_optimization<T>::value is true for float and
// double and false for every other type. Expressed through the standard
// std::true_type / std::false_type bases instead of a hand-written `value`
// member.
template <typename T>
struct use_bcast_optimization : std::false_type {};

template <>
struct use_bcast_optimization<float> : std::true_type {};

template <>
struct use_bcast_optimization<double> : std::true_type {};
859
860////////////////////////////////////////////////////////////////////////////////
861// Unary functors
862////////////////////////////////////////////////////////////////////////////////
863
864// abs(x) = |x|
865// neg(x) = - x
866// inverse(x) = 1 / x
867// square(x) = x^2
868// sqrt(x) = x^(1/2)
869// rsqrt(x) = x^(-1/2)
870// exp(x) = e^x
871// expm1(x) = e^x - 1
872// log(x) = natural logarithm of x
873// log1p(x) = natural logarithm of 1 + x
874// tanh = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
875// sigmoid = 1 / (1 + exp(-x)) // a.k.a, logistic
876//
// NOTE: We may eventually implement common functions used in NN
// here. E.g., rectifier, softplus, derivatives of tanh, sigmoid, etc.
// For reference, see speech/lstm/eigen_functors.h.
880
881template <typename T>
882struct abs : base<T, Eigen::internal::scalar_abs_op<T>,
883 typename Eigen::internal::scalar_abs_op<T>::result_type> {};
884
885template <typename T>
886struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {};
887
888template <typename T>
889struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {};
890
891template <typename T>
892struct square : base<T, Eigen::internal::scalar_square_op<T>> {};
893
894template <typename T>
895struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {};
896
897template <typename T>
898struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {};
899
900template <typename T>
901struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {};
902
903template <typename T>
904struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {};
905
906template <typename T>
907struct log : base<T, Eigen::internal::scalar_log_op<T>> {};
908
909template <typename T>
910struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {};
911
912template <typename T>
913struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {};
914
915template <typename T>
916struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {};
917
918template <typename T>
919struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {};
920
921template <typename T>
922struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {};
923
924template <typename T>
925struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {};
926
927template <typename T>
928struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {};
929
930template <typename T>
931struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {};
932
933template <typename T>
934struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {};
935
936template <typename T>
937struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {};
938
939template <typename T>
940struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {};
941
942template <typename T>
943struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {};
944
945template <typename T>
946struct ndtri : base<T, Eigen::internal::scalar_ndtri_op<T>> {};
947
948template <typename T>
949struct erfinv : base<T, Eigen::internal::scalar_erfinv_op<T>> {};
950
951template <typename T>
952struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {};
953
954template <typename T>
955struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {};
956
957template <typename T>
958struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {};
959
960template <typename T>
961struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {};
962
963template <typename T>
964struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {};
965
966template <typename T>
967struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {};
968
969template <typename T>
970struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {};
971
972struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> {
973};
974
975// Flip all bits. Named invert to be consistent with numpy.
976template <typename T>
977struct invert_op {
978 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const {
979 return ~a;
980 }
981};
982
983template <typename T>
984struct invert : base<T, invert_op<T>> {};
985
// NOTE: std::isinf, std::isnan and std::isfinite are plain functions.
// Therefore we need to wrap them in functors to be used with Eigen's
// type system.
//
// All three pass `bool` as the third template argument of `base`.
// NOTE(review): presumably that argument is the output element type — confirm
// against the definition of `base`.

// isinf: wraps Eigen::internal::scalar_isinf_op<T>.
template <typename T>
struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {};

// isnan: wraps Eigen::internal::scalar_isnan_op<T>.
template <typename T>
struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {};

// isfinite: wraps Eigen::internal::scalar_isfinite_op<T>.
template <typename T>
struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {};
997
// Rounding functors.

// floor: wraps Eigen::internal::scalar_floor_op<T>.
template <typename T>
struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {};

// round: backed by scalar_round_half_to_even_op, i.e. the op whose name says
// ties go to the even neighbour (not a round-half-away-from-zero op).
template <typename T>
struct round : base<T, Eigen::internal::scalar_round_half_to_even_op<T>> {};

// ceil: wraps Eigen::internal::scalar_ceil_op<T>.
template <typename T>
struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {};

// Note: rint rounds half values to even, just like round_half_to_even_op.
// It is nevertheless backed by a different Eigen op (scalar_rint_op) than
// `round` above.
template <typename T>
struct rint : base<T, Eigen::internal::scalar_rint_op<T>> {};
1010
1011////////////////////////////////////////////////////////////////////////////////
1012// Binary functors
1013////////////////////////////////////////////////////////////////////////////////
1014
// Binary functors:
//
// add(x, y) = x + y
// sub(x, y) = x - y
// mul(x, y) = x * y
// div(x, y) = x / y
// mod(x, y) = x % y (int32 and int64 only)
// fmod(x, y) = fmod(x, y) (float and double only)
// pow(x, y) = x raised to the power y (not bitwise xor)
// maximum(x, y) = x > y ? x : y (NaN-propagating, see Eigen::PropagateNaN)
// minimum(x, y) = x < y ? x : y (NaN-propagating, see Eigen::PropagateNaN)
// squared_difference(x, y) = conj(x - y) * (x - y)
1027
// add(x, y) = x + y. Wraps Eigen::internal::scalar_sum_op<T>.
// use_bcast_optimization is a compile-time flag read by code outside this
// section. NOTE(review): presumably it opts the functor into a specialised
// broadcasting path — confirm at the point of use.
template <typename T>
struct add : base<T, Eigen::internal::scalar_sum_op<T>> {
  static constexpr bool use_bcast_optimization = true;
};

// sub(x, y) = x - y. Wraps Eigen::internal::scalar_difference_op<T>; sets the
// same use_bcast_optimization flag as add.
template <typename T>
struct sub : base<T, Eigen::internal::scalar_difference_op<T>> {
  static constexpr bool use_bcast_optimization = true;
};

// mul(x, y) = x * y. Wraps Eigen::internal::scalar_product_op<T>; sets the
// same use_bcast_optimization flag as add.
template <typename T>
struct mul : base<T, Eigen::internal::scalar_product_op<T>> {
  static constexpr bool use_bcast_optimization = true;
};
1042
// mul_no_nan: wraps Eigen::internal::mul_no_nan_op<T>.
template <typename T>
struct mul_no_nan : base<T, Eigen::internal::mul_no_nan_op<T>> {};

// div(x, y) = x / y. Wraps Eigen::internal::scalar_quotient_op<T> directly,
// without the safe_div_or_mod_op guard used by safe_div.
template <typename T>
struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {};

// safe_div: scalar_quotient_op wrapped in safe_div_or_mod_op. has_errors is
// overridden to true; BinaryFunctor defaults its own has_errors template
// parameter from this member.
template <typename T>
struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op<
                              T, Eigen::internal::scalar_quotient_op<T>>> {
  static constexpr bool has_errors = true;
};

// div_no_nan: wraps Eigen::internal::div_no_nan_op<T>.
template <typename T>
struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {};
1057
// fmod(x, y) = fmod(x, y) (float and double only, per the table above).
// Wraps Eigen::internal::scalar_fmod_op<T>.
template <typename T>
struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {};

// mod(x, y) = x % y (int32 and int64 only, per the table above).
// Wraps Eigen::internal::scalar_mod2_op<T> without any guard.
template <typename T>
struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {};

// safe_mod: scalar_mod2_op wrapped in safe_div_or_mod_op. has_errors is
// overridden to true; BinaryFunctor defaults its own has_errors template
// parameter from this member.
template <typename T>
struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op<
                              T, Eigen::internal::scalar_mod2_op<T>>> {
  static constexpr bool has_errors = true;
};
1069
// floor_fmod: wraps Eigen::internal::google_floor_fmod<T>. Note this is a
// different op from google_floor_mod used by safe_floor_mod below.
template <typename T>
struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {};

// safe_floor_mod: google_floor_mod wrapped in safe_div_or_mod_op, with
// has_errors overridden to true.
template <typename T>
struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op<
                                    T, Eigen::internal::google_floor_mod<T>>> {
  static constexpr bool has_errors = true;
};

// floor_div: wraps Eigen::internal::google_floor_div<T> without any guard.
template <typename T>
struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {};

// safe_floor_div: google_floor_div wrapped in safe_div_or_mod_op, with
// has_errors overridden to true.
template <typename T>
struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op<
                                    T, Eigen::internal::google_floor_div<T>>> {
  static constexpr bool has_errors = true;
};

// floor_div_real: wraps Eigen::internal::google_floor_div_real<T>.
template <typename T>
struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {};
1090
// pow: wraps Eigen::internal::scalar_pow_op<T, T> (base and exponent share the
// same type T).
template <typename T>
struct pow : base<T, Eigen::internal::scalar_pow_op<T, T>> {};

// safe_pow: wraps Eigen::internal::safe_scalar_binary_pow_op<T, T> and
// overrides has_errors to true. See safe_pow_ignore_error_op below for the
// variant that does not report errors.
template <typename T>
struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> {
  static constexpr bool has_errors = true;
};
1098
1099// Version of safe_pow for integers which returns 0 if RHS is negative and LHS
1100// is not 1 or -1. For use on GPUs, where we cannot raise an error.
1101template <typename T>
1102struct safe_pow_ignore_error_op {
1103 static_assert(std::is_integral<T>::value, "Integer type expected");
1104 EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
1105 const T& y) const {
1106 if (TF_PREDICT_FALSE(y < 0)) {
1107 if (x == T(-1)) {
1108 T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(y, T(2));
1109 return trunc_mod == T(-1) ? T(-1) : T(1);
1110 }
1111 return x == T(1) ? T(1) : T(0);
1112 }
1113 return Eigen::internal::scalar_pow_op<T, T>{}(x, y);
1114 }
1115};
1116
// safe_pow_ignore_error: functor wrapper around safe_pow_ignore_error_op<T>.
// Unlike safe_pow it does not override has_errors.
template <typename T>
struct safe_pow_ignore_error : base<T, safe_pow_ignore_error_op<T>> {};
1119
// maximum: wraps Eigen's scalar_max_op instantiated with the
// Eigen::PropagateNaN policy.
template <typename T>
struct maximum
    : base<T, Eigen::internal::scalar_max_op<T, T, Eigen::PropagateNaN>> {};

// minimum: wraps Eigen's scalar_min_op instantiated with the
// Eigen::PropagateNaN policy.
template <typename T>
struct minimum
    : base<T, Eigen::internal::scalar_min_op<T, T, Eigen::PropagateNaN>> {};
1127
// Binary special-function functors. Each struct forwards to the Eigen scalar
// op named in its base clause and declares nothing else.

// igamma: wraps Eigen::internal::scalar_igamma_op<T>.
template <typename T>
struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {};

// random_gamma_grad: wraps Eigen::internal::scalar_gamma_sample_der_alpha_op<T>
// (note the Eigen op uses a different name from the functor).
template <typename T>
struct random_gamma_grad
    : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {};

// igammac: wraps Eigen::internal::scalar_igammac_op<T>.
template <typename T>
struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {};

// zeta: wraps Eigen::internal::scalar_zeta_op<T>.
template <typename T>
struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {};

// polygamma: wraps Eigen::internal::scalar_polygamma_op<T>.
template <typename T>
struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {};
1143
// Two-argument arctangent functor: returns atan2(y, x) converted to Scalar.
// Note the argument order — y comes first, then x.
template <typename Scalar>
struct scalar_atan2_op {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
  operator()(const Scalar& y, const Scalar& x) const {
#if TENSORFLOW_USE_ROCM
    // ROCm builds call the global-namespace atan2 rather than std::atan2.
    const auto angle = ::atan2(y, x);
#else
    const auto angle = std::atan2(y, x);
#endif
    return static_cast<Scalar>(angle);
  }
};
1155
// atan2: functor wrapper around the scalar_atan2_op<T> defined in this
// namespace (not an Eigen::internal op).
template <typename T>
struct atan2 : base<T, scalar_atan2_op<T>> {};
1158
// squared_difference(x, y) = conj(x - y) * (x - y), per the table above.
// Wraps Eigen::internal::scalar_squared_difference_op<T>.
template <typename T>
struct squared_difference
    : base<T, Eigen::internal::scalar_squared_difference_op<T>> {};

// xdivy: wraps Eigen::internal::xdivy_op<T>.
template <typename T>
struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {};

// xlogy: wraps Eigen::internal::xlogy_op<T>.
template <typename T>
struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {};

// xlog1py: wraps Eigen::internal::xlog1py_op<T>.
template <typename T>
struct xlog1py : base<T, Eigen::internal::xlog1py_op<T>> {};
1171
// Comparison functors. Each one wraps the Eigen::internal op of the same name
// and passes `bool` as the third template argument of `base`.
// NOTE(review): presumably that argument is the output element type — confirm
// against the definition of `base`.

// less: wraps Eigen::internal::less<T>.
template <typename T>
struct less : base<T, Eigen::internal::less<T>, bool> {};

// less_equal: wraps Eigen::internal::less_equal<T>.
template <typename T>
struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {};

// greater: wraps Eigen::internal::greater<T>.
template <typename T>
struct greater : base<T, Eigen::internal::greater<T>, bool> {};

// greater_equal: wraps Eigen::internal::greater_equal<T>.
template <typename T>
struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {};

// equal_to: wraps Eigen::internal::equal_to<T>.
template <typename T>
struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {};

// not_equal_to: wraps Eigen::internal::not_equal_to<T>.
template <typename T>
struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {};
1189
// logical_and: non-template functor fixed to bool; wraps
// Eigen::internal::scalar_boolean_and_op.
struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {};

// logical_or: non-template functor fixed to bool; wraps
// Eigen::internal::scalar_boolean_or_op.
struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {};
1193
// Functor returning the bitwise AND of its two operands.
template <typename T>
struct bitwise_and_op {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& lhs,
                                                     const T& rhs) const {
    // The cast spells out the conversion back to T after integral promotion.
    return static_cast<T>(lhs & rhs);
  }
};
1201
// Functor returning the bitwise OR of its two operands.
template <typename T>
struct bitwise_or_op {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& lhs,
                                                     const T& rhs) const {
    // The cast spells out the conversion back to T after integral promotion.
    return static_cast<T>(lhs | rhs);
  }
};
1209
// bitwise_and: functor wrapper around bitwise_and_op<T> defined in this
// namespace.
template <typename T>
struct bitwise_and : base<T, bitwise_and_op<T>> {};

// bitwise_or: functor wrapper around bitwise_or_op<T> defined in this
// namespace.
template <typename T>
struct bitwise_or : base<T, bitwise_or_op<T>> {};

// bitwise_xor: unlike the two functors above, this one is backed by an op
// from Eigen::internal (bitwise_xor_op<T>).
template <typename T>
struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {};
1218
// Left shift with a clamped shift count.
//   x - value to shift.
//   y - requested shift count; negative counts are treated as 0 and counts
//       above (bit width of T - 1) are clamped to that maximum.
// Returns x << clamped(y), computed on the unsigned counterpart of T.
template <typename T>
struct left_shift_op {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
                                                     const T& y) const {
    using Unsigned = typename std::make_unsigned<T>::type;
    // Largest shift count that stays below the bit width of T; shifting by
    // the full width or more would be undefined behaviour.
    const T max_shift = static_cast<T>(sizeof(T) * CHAR_BIT - 1);
    const T shift = (y < 0) ? T(0) : ((y > max_shift) ? max_shift : y);
    // Shift in the unsigned domain so that moving bits into or past the sign
    // bit is not undefined behaviour.
    return static_cast<T>(static_cast<Unsigned>(x)
                          << static_cast<Unsigned>(shift));
  }
};
1235
// Right shift with a clamped shift count.
//   x - value to shift.
//   y - requested shift count; negative counts are treated as 0 and counts
//       above (bit width of T - 1) are clamped to that maximum.
// Returns x >> clamped(y).
template <typename T>
struct right_shift_op {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
                                                     const T& y) const {
    // Largest shift count that stays below the bit width of T; shifting by
    // the full width or more would be undefined behaviour.
    const T max_shift = static_cast<T>(sizeof(T) * CHAR_BIT - 1);
    const T shift = (y < 0) ? T(0) : ((y > max_shift) ? max_shift : y);
    // Technically right shifts of signed integers are not necessarily
    // arithmetic shifts according to the C++ standard. However in practice
    // most implementations are arithmetic shifts. If this proves to be a
    // problem in practice, an alternative implementation may be needed.
    return static_cast<T>(x >> shift);
  }
};
1254
// left_shift: functor wrapper around left_shift_op<T> (shift count clamped to
// [0, bit width of T - 1]).
template <typename T>
struct left_shift : base<T, left_shift_op<T>> {};

// right_shift: functor wrapper around right_shift_op<T> (shift count clamped
// to [0, bit width of T - 1]).
template <typename T>
struct right_shift : base<T, right_shift_op<T>> {};
1260
// Builds a std::complex<T> from separate real and imaginary parts.
// result_type is exposed because the result type differs from the argument
// type T.
template <typename T>
struct make_complex_func {
  using result_type = std::complex<T>;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T re,
                                                               T im) const {
    return result_type(re, im);
  }
};
1269
// make_complex: wraps make_complex_func<T>; the third `base` argument is
// std::complex<T>, matching make_complex_func's result_type.
template <typename T>
struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {};

// get_real: wraps Eigen::internal::scalar_real_op<T>. T must expose a nested
// value_type (as std::complex does), which is passed as the third `base`
// argument.
template <typename T>
struct get_real
    : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {};

// get_imag: wraps Eigen::internal::scalar_imag_op<T>; same value_type
// requirement as get_real.
template <typename T>
struct get_imag
    : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {};

// get_angle: wraps Eigen::internal::scalar_arg_op<T>; same value_type
// requirement as get_real. Under GOOGLE_CUDA this header specialises
// scalar_arg_op for std::complex<float> and std::complex<double> (see the top
// of the file).
template <typename T>
struct get_angle
    : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {};

// conj: wraps Eigen::internal::scalar_conjugate_op<T>.
template <typename T>
struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {};
1287
1288////////////////////////////////////////////////////////////////////////////////
// Functors that take 1 or 2 tensors, compute the base functor on each
// coefficient of the input tensors, and put the results in the output
// tensor.
1292////////////////////////////////////////////////////////////////////////////////
// Device-specific driver for a unary functor. Only the interface is declared
// here; the definition of operator() is not part of this section.
//
// Template parameters:
//   Device  - type of the device object passed as `d`.
//   Functor - must provide the member types tout_type and tin_type.
template <typename Device, typename Functor>
struct UnaryFunctor {
  // Computes on device "d": out[i] = Functor(in[i])
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in);
};
1299
// Device-specific driver for a unary functor that takes one extra argument.
// Only the interface is declared here; the definition of operator() is not
// part of this section.
//
// Template parameters:
//   Device  - type of the device object passed as `d`.
//   Functor - must provide the member types tout_type and tin_type.
//   Targ    - type of the extra argument `val`.
template <typename Device, typename Functor, typename Targ>
struct UnaryFunctorWithArg {
  // Computes on device "d": out[i] = Functor(in[i])
  // NOTE(review): `val` is passed alongside `in`, but how it enters the
  // computation is decided by the definitions — confirm there.
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in, Targ val);
};
1306
// Device-specific driver for a binary functor. Only the interface is declared
// here; the member definitions are not part of this section.
//
// Template parameters:
//   Device     - type of the device object passed as `d`.
//   Functor    - must provide tout_type, tin_type, tscalar_type, out_type,
//                in_type and has_errors (all referenced below).
//   NDIMS      - rank of the tensors accepted by BCast().
//   has_errors - defaults to Functor::has_errors.
//
// Every member takes a `bool* error` out-parameter.
// NOTE(review): presumably it is written when the functor reports a failure
// (see the safe_* functors that set has_errors = true) — confirm against the
// definitions.
template <typename Device, typename Functor, int NDIMS,
          bool has_errors = Functor::has_errors>
struct BinaryFunctor {
  // Computes on device "d": out[i] = Functor(in0[i], in1[i])
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error);

  // Computes on device "d": out[i] = Functor(scalar[0], in[i])
  void Left(const Device& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error);

  // Computes on device "d": out[i] = Functor(in[i], scalar[0])
  void Right(const Device& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error);

  // Computes on device "d":
  //   out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1))
  //
  // TODO(zhifengc): make BCast a template member function on NDIMS
  // instead of making BinaryFunctor a template on NDIMS.
  void BCast(const Device& d,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error);
};
1338
// Device-specific functor taking two flat inputs of element type T, a
// tolerance of type T, and a flat bool output. Only the interface is declared
// here.
// NOTE(review): presumably z[i] reports whether x[i] and y[i] differ by less
// than `tolerance`; the exact comparison lives in the definitions — confirm
// there.
template <typename Device, typename T>
struct ApproximateEqual {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z);
};
1345
1346template <int NDIMS>
1347bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) {
1348 for (size_t i = 0; i < a.size(); ++i) {
1349 if (a[i] != 1) return false;
1350 }
1351 return true;
1352}
1353
// Device-specific select over flat tensors: takes a flat bool condition plus
// flat "then" and "else" inputs of element type T and writes to a flat output.
// Only the interface is declared here.
// NOTE(review): presumably out[i] = cond_flat[i] ? then_flat[i] : else_flat[i]
// — confirm against the definitions.
template <typename Device, typename T>
struct SelectFunctor {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat);
};
1361
// Variant of SelectFunctor whose condition is a single bool scalar
// (TTypes<bool>::ConstScalar) instead of a flat tensor. Only the interface is
// declared here.
// NOTE(review): presumably the whole of then_flat or else_flat is copied to
// `out` depending on `cond` — confirm against the definitions.
template <typename Device, typename T>
struct SelectScalarFunctor {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstScalar cond,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat);
};
1369
// Variant of SelectFunctor operating on matrices with a bool vector condition.
// Only the interface is declared here.
// NOTE(review): the parameter names suggest the inputs are flattened to
// (outer dim, rest) and cond_vec holds one entry per outer index — confirm
// against the callers.
template <typename Device, typename T>
struct BatchSelectFunctor {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims);
};
1378
// Variant of SelectFunctor for rank-NDIMS tensors, where the condition, "then"
// and "else" inputs each come with their own NDIMS-entry broadcast array.
// Only the interface is declared here.
// NOTE(review): presumably each input is broadcast by its *_bcast factors
// before the element-wise select — confirm against the definitions.
template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctor {
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS>::Tensor output_tensor,
                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
};
1390
1391} // end namespace functor
1392} // end namespace tensorflow
1393
1394#endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_
1395