1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ |
18 | |
19 | #define _USE_MATH_DEFINES |
20 | #include <cmath> |
21 | #include <functional> |
22 | #include <type_traits> |
23 | |
24 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
25 | #include "tensorflow/core/framework/bounds_check.h" |
26 | #include "tensorflow/core/framework/numeric_types.h" |
27 | #include "tensorflow/core/framework/tensor_types.h" |
28 | |
29 | namespace Eigen { |
30 | namespace internal { |
31 | |
#if GOOGLE_CUDA
// CUDA builds: arg(z) for complex z, computed as atan2(imag(z), real(z)) with
// the global-namespace C math functions (atan2f for float, atan2 for double).
template <>
struct scalar_arg_op<std::complex<float>> {
  typedef Eigen::NumTraits<std::complex<float>>::Real result_type;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE float operator()(
      const std::complex<float>& z) const {
    const float imaginary = z.imag();
    const float real = z.real();
    return ::atan2f(imaginary, real);
  }
};

template <>
struct scalar_arg_op<std::complex<double>> {
  typedef Eigen::NumTraits<std::complex<double>>::Real result_type;
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE double operator()(
      const std::complex<double>& z) const {
    const double imaginary = z.imag();
    const double real = z.real();
    return ::atan2(imaginary, real);
  }
};
#endif
51 | |
52 | #if EIGEN_HAS_CXX11_MATH == 0 |
53 | template <typename T> |
54 | struct scalar_asinh_op { |
55 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { |
56 | return static_cast<T>(std::asinh(a)); |
57 | } |
58 | }; |
59 | template <typename T> |
60 | struct functor_traits<scalar_asinh_op<T>> { |
61 | enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; |
62 | }; |
63 | |
64 | template <typename T> |
65 | struct scalar_acosh_op { |
66 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { |
67 | return static_cast<T>(std::acosh(a)); |
68 | } |
69 | }; |
70 | template <typename T> |
71 | struct functor_traits<scalar_acosh_op<T>> { |
72 | enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; |
73 | }; |
74 | |
75 | template <typename T> |
76 | struct scalar_atanh_op { |
77 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { |
78 | return static_cast<T>(std::atanh(a)); |
79 | } |
80 | }; |
81 | template <typename T> |
82 | struct functor_traits<scalar_atanh_op<T>> { |
83 | enum { Cost = 5 * NumTraits<T>::MulCost, PacketAccess = false }; |
84 | }; |
85 | #endif |
86 | |
87 | template <typename Scalar, typename Exponent> |
88 | struct safe_scalar_binary_pow_op { |
89 | static_assert(std::is_integral<Scalar>::value, "Integer type expected" ); |
90 | static_assert(std::is_integral<Exponent>::value && |
91 | std::is_signed<Exponent>::value, |
92 | "Signed integer type expected" ); |
93 | |
94 | bool* const error; |
95 | |
96 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_scalar_binary_pow_op(bool* error) |
97 | : error(error) {} |
98 | |
99 | EIGEN_DEVICE_FUNC inline Scalar operator()(const Scalar& a, |
100 | const Exponent& b) const { |
101 | const Exponent safe_b = tensorflow::internal::SubtleMustCopy(b); |
102 | if (TF_PREDICT_TRUE(safe_b >= 0)) { |
103 | return numext::pow(a, safe_b); |
104 | } else { |
105 | *error = true; |
106 | return 0; |
107 | } |
108 | } |
109 | }; |
110 | |
111 | template <typename Scalar, typename Exponent> |
112 | struct functor_traits<safe_scalar_binary_pow_op<Scalar, Exponent>> { |
113 | enum { Cost = 5 * NumTraits<Scalar>::MulCost, PacketAccess = false }; |
114 | }; |
115 | |
116 | template <typename T, typename DivOrMod> |
117 | struct safe_div_or_mod_op { |
118 | static_assert(std::is_integral<T>::value, "Integer type expected" ); |
119 | |
120 | bool* const error; |
121 | |
122 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE safe_div_or_mod_op(bool* error) |
123 | : error(error) {} |
124 | |
125 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, |
126 | const T& b) const { |
127 | const T safe_b = tensorflow::internal::SubtleMustCopy(b); |
128 | if (TF_PREDICT_TRUE(safe_b != 0)) { |
129 | // Avoid FPE for INT_MIN/-1. |
130 | const T safe_a = tensorflow::internal::SubtleMustCopy(a); |
131 | if (TF_PREDICT_FALSE(std::is_signed<T>::value && |
132 | safe_a == std::numeric_limits<T>::min() && |
133 | safe_b == T(-1))) { |
134 | // Prefer to overflow 'a' instead of crashing. |
135 | return DivOrMod()(-safe_a, 1); |
136 | } |
137 | return DivOrMod()(safe_a, safe_b); |
138 | } else { |
139 | *error = true; |
140 | return 0; |
141 | } |
142 | } |
143 | }; |
144 | |
145 | template <typename T, typename DivOrMod> |
146 | struct functor_traits<safe_div_or_mod_op<T, DivOrMod>> { |
147 | enum { |
148 | Cost = functor_traits<DivOrMod>::Cost + NumTraits<T>::AddCost, |
149 | PacketAccess = false, |
150 | }; |
151 | }; |
152 | |
153 | template <typename T, typename Binary> |
154 | struct no_nan_op { |
155 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, |
156 | const T& b) const { |
157 | if (b != T(0)) { |
158 | return Binary()(a, b); |
159 | } else { |
160 | return T(0); |
161 | } |
162 | } |
163 | template <typename Packet> |
164 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, |
165 | const Packet& b) const { |
166 | const Packet mask = pcmp_eq(b, pzero(b)); |
167 | const Packet quotient = Binary().packetOp(a, b); |
168 | return pandnot(quotient, mask); |
169 | } |
170 | }; |
171 | |
172 | template <typename T, bool IsComplex = Eigen::NumTraits<T>::IsComplex> |
173 | struct div_no_nan_op; |
174 | |
175 | template <typename T> |
176 | struct div_no_nan_op<T, /*IsComplex=*/false> |
177 | : public no_nan_op<T, scalar_quotient_op<T>> { |
178 | }; |
179 | |
180 | template <typename T> |
181 | struct functor_traits<div_no_nan_op<T, /*IsComplex=*/false>> { |
182 | enum { |
183 | Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::AddCost, |
184 | PacketAccess = true, |
185 | }; |
186 | }; |
187 | |
188 | // Whether or not complex division produces a NaN depends on the underlying |
189 | // implementation. Some compilers (e.g. gcc) use a simple method that divides |
190 | // by |b|^2, which may underflow to 0 for b != 0. |
191 | template <typename T> |
192 | struct div_no_nan_op<T, /*IsComplex=*/true> { |
193 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a, |
194 | const T& b) const { |
195 | if (b == T(0)) { |
196 | return T(0); |
197 | } else { |
198 | // If the numerator is zero, then the result must be zero even if |b|^2 |
199 | // underflows to zero. |
200 | const T numerator = |
201 | scalar_product_op<T>()(a, scalar_conjugate_op<T>()(b)); |
202 | if (numerator == T(0)) { |
203 | return T(0); |
204 | } |
205 | } |
206 | return scalar_quotient_op<T>()(a, b); |
207 | } |
208 | template <typename Packet> |
209 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, |
210 | const Packet& b) const { |
211 | const Packet numerator = pmul(a, pconj(b)); |
212 | const Packet mask = por(pcmp_eq(b, pzero(a)), pcmp_eq(numerator, pzero(a))); |
213 | const Packet quotient = pdiv(a, b); |
214 | return pandnot(quotient, mask); |
215 | } |
216 | }; |
217 | |
218 | template <typename T> |
219 | struct functor_traits<div_no_nan_op<T, /*IsComplex=*/true>> { |
220 | enum { |
221 | Cost = functor_traits<scalar_quotient_op<T>>::Cost + NumTraits<T>::MulCost, |
222 | PacketAccess = packet_traits<T>::HasMul && packet_traits<T>::HasDiv && |
223 | packet_traits<T>::HasConj, |
224 | }; |
225 | }; |
226 | |
227 | template <typename T> |
228 | struct mul_no_nan_op : public no_nan_op<T, scalar_product_op<T>> { |
229 | }; |
230 | |
231 | template <typename T> |
232 | struct functor_traits<mul_no_nan_op<T>> { |
233 | enum { |
234 | Cost = functor_traits<scalar_product_op<T>>::Cost + NumTraits<T>::AddCost, |
235 | PacketAccess = true, |
236 | }; |
237 | }; |
238 | |
// scalar_left and scalar_right are template helpers to partially
// apply a binary function.
//
// Suppose Binary is a binary functor f(x, y), scalar_left<> is a
// unary functor g_x(y) = f(x, y), where x is provided via the
// constructor. Similarly, scalar_right<> is a unary functor g_y(x) =
// f(x, y).
//
// scalar_left stores only a pointer to x and does not own it. operator()
// dereferences that pointer on every call, so the pointee must outlive the
// functor. Extra constructor arguments are forwarded to Binary's constructor.
// When is_scalar_in_host_memory is true, the constructor also reads *c once and
// caches it broadcast into left_packet for the TinPacket packet path.
template <typename Tout, typename Tin, typename Binary,
          bool is_scalar_in_host_memory = false>
struct scalar_left : private Binary {
  using result_type = Tout;
  // Packet type for Tin; the type of the cached broadcast left_packet.
  using TinPacket = typename Eigen::internal::packet_traits<Tin>::type;

  // Bound left operand (not owned).
  const Tin* left;
  TinPacket left_packet;  // initialized iff is_scalar_in_host_memory == true

  inline scalar_left(const scalar_left& other) = default;

  // c: pointer to the bound left operand; args: forwarded to Binary's ctor.
  template <typename... Args>
  EIGEN_DEVICE_FUNC inline explicit scalar_left(const Tin* c, Args... args)
      : Binary(args...), left(c) {
    // Broadcast *left into a packet once, so the cached-packet packetOp
    // overload below does not re-read *left on each call.
    if (is_scalar_in_host_memory) {
      left_packet = Eigen::internal::pset1<TinPacket>(*left);
    }
  }

  // Scalar path: f(*left, right).
  EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& right) const {
    return Binary::operator()(*left, right);
  }

  // Packet path used when no cached packet applies (caching is disabled, or
  // Packet is not TinPacket): broadcasts *left on every call.
  template <typename Packet,
            typename std::enable_if<!is_scalar_in_host_memory ||
                                        !std::is_same<TinPacket, Packet>::value,
                                    int>::type = 0>
  EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
    const Packet left_packet = Eigen::internal::pset1<Packet>(*left);
    return Binary::packetOp(left_packet, right_packet);
  }

  // Packet path that reuses the broadcast cached by the constructor.
  template <typename Packet,
            typename std::enable_if<is_scalar_in_host_memory &&
                                        std::is_same<TinPacket, Packet>::value,
                                    int>::type = 0>
  EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& right_packet) const {
    return Binary::packetOp(left_packet, right_packet);
  }
};

// Partial application keeps Binary's cost and vectorizability.
template <typename Tout, typename Tin, typename Binary,
          bool is_scalar_in_host_memory>
struct functor_traits<
    scalar_left<Tout, Tin, Binary, is_scalar_in_host_memory>> {
  enum {
    Cost = functor_traits<Binary>::Cost,
    PacketAccess = functor_traits<Binary>::PacketAccess,
  };
};
297 | |
// Unary functor g_y(x) = f(x, y) that binds the RIGHT operand of Binary.
// It stores only a pointer to y and does not own it. operator() dereferences
// that pointer on every call, so the pointee must outlive the functor. Extra
// constructor arguments are forwarded to Binary's constructor. When
// is_scalar_in_host_memory is true, the constructor also reads *c once and
// caches it broadcast into right_packet for the TinPacket packet path.
template <typename Tout, typename Tin, typename Binary,
          bool is_scalar_in_host_memory = false>
struct scalar_right : private Binary {
  using result_type = Tout;
  // Packet type for Tin; the type of the cached broadcast right_packet.
  using TinPacket = typename Eigen::internal::packet_traits<Tin>::type;

  // Bound right operand (not owned).
  const Tin* right;
  TinPacket right_packet;  // initialized iff is_scalar_in_host_memory == true

  inline scalar_right(const scalar_right& other) = default;

  // c: pointer to the bound right operand; args: forwarded to Binary's ctor.
  template <typename... Args>
  EIGEN_DEVICE_FUNC inline explicit scalar_right(const Tin* c, Args... args)
      : Binary(args...), right(c) {
    // Broadcast *right into a packet once, so the cached-packet packetOp
    // overload below does not re-read *right on each call.
    if (is_scalar_in_host_memory) {
      right_packet = Eigen::internal::pset1<TinPacket>(*right);
    }
  }

  // Scalar path: f(left, *right).
  EIGEN_DEVICE_FUNC inline Tout operator()(const Tin& left) const {
    return Binary::operator()(left, *right);
  }

  // Packet path used when no cached packet applies (caching is disabled, or
  // Packet is not TinPacket): broadcasts *right on every call.
  template <typename Packet,
            typename std::enable_if<!is_scalar_in_host_memory ||
                                        !std::is_same<TinPacket, Packet>::value,
                                    int>::type = 0>
  EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
    const Packet right_packet = Eigen::internal::pset1<Packet>(*right);
    return Binary::packetOp(left_packet, right_packet);
  }

  // Packet path that reuses the broadcast cached by the constructor.
  template <typename Packet,
            typename std::enable_if<is_scalar_in_host_memory &&
                                        std::is_same<TinPacket, Packet>::value,
                                    int>::type = 0>
  EIGEN_DEVICE_FUNC inline Packet packetOp(const Packet& left_packet) const {
    return Binary::packetOp(left_packet, right_packet);
  }
};

// Partial application keeps Binary's cost and vectorizability.
template <typename Tout, typename Tin, typename Binary,
          bool is_scalar_in_host_memory>
struct functor_traits<
    scalar_right<Tout, Tin, Binary, is_scalar_in_host_memory>> {
  enum {
    Cost = functor_traits<Binary>::Cost,
    PacketAccess = functor_traits<Binary>::PacketAccess,
  };
};
348 | |
349 | // similar to std::equal_to, but with the DEVICE_FUNC qualifier |
350 | template <class T> |
351 | struct equal_to : std::function<bool(T, T)> { |
352 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, |
353 | const T& y) const { |
354 | return x == y; |
355 | } |
356 | }; |
357 | |
358 | // similar to std::not_equal_to, but with the DEVICE_FUNC qualifier |
359 | template <class T> |
360 | struct not_equal_to : std::function<bool(T, T)> { |
361 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, |
362 | const T& y) const { |
363 | return x != y; |
364 | } |
365 | }; |
366 | |
367 | // similar to std::greater, but with the DEVICE_FUNC qualifier |
368 | template <class T> |
369 | struct greater : std::function<bool(T, T)> { |
370 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, |
371 | const T& y) const { |
372 | return x > y; |
373 | } |
374 | }; |
375 | |
376 | // similar to std::less, but with the DEVICE_FUNC qualifier |
377 | template <class T> |
378 | struct less : std::function<bool(T, T)> { |
379 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, |
380 | const T& y) const { |
381 | return x < y; |
382 | } |
383 | }; |
384 | |
385 | // similar to std::greater_equal, but with the DEVICE_FUNC qualifier |
386 | template <class T> |
387 | struct greater_equal : std::function<bool(T, T)> { |
388 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, |
389 | const T& y) const { |
390 | return x >= y; |
391 | } |
392 | }; |
393 | |
394 | // similar to std::less_equal, but with the DEVICE_FUNC qualifier |
395 | template <class T> |
396 | struct less_equal : std::function<bool(T, T)> { |
397 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE bool operator()(const T& x, |
398 | const T& y) const { |
399 | return x <= y; |
400 | } |
401 | }; |
402 | |
403 | // Functor that enables squared difference functor. |
404 | template <typename Scalar> |
405 | struct scalar_squared_difference_op { |
406 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
407 | operator()(const Scalar& a, const Scalar& b) const { |
408 | const Scalar v = scalar_difference_op<Scalar>()(a, b); |
409 | return scalar_product_op<Scalar>()(v, scalar_conjugate_op<Scalar>()(v)); |
410 | } |
411 | template <typename Packet> |
412 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, |
413 | const Packet& b) const { |
414 | const Packet v = scalar_difference_op<Scalar>().packetOp(a, b); |
415 | return scalar_product_op<Scalar>().packetOp( |
416 | v, scalar_conjugate_op<Scalar>().packetOp(v)); |
417 | } |
418 | }; |
419 | |
420 | template <typename Scalar> |
421 | struct functor_traits<scalar_squared_difference_op<Scalar>> { |
422 | enum { |
423 | Cost = functor_traits<scalar_difference_op<Scalar>>::Cost + |
424 | functor_traits<scalar_conjugate_op<Scalar>>::Cost + |
425 | functor_traits<scalar_product_op<Scalar>>::Cost, |
426 | PacketAccess = functor_traits<scalar_difference_op<Scalar>>::PacketAccess && |
427 | functor_traits<scalar_conjugate_op<Scalar>>::PacketAccess && |
428 | functor_traits<scalar_product_op<Scalar>>::PacketAccess |
429 | }; |
430 | }; |
431 | |
// TODO(b/32239616): This kernel should be moved into Eigen and vectorized.
//
// Floor division: x / y rounded toward negative infinity. This primary
// template is used unless the std::is_unsigned specialization that follows
// applies. The scalar path assumes '/' truncates toward zero, as it does for
// built-in integers.
template <typename T, typename Enable = void>
struct google_floor_div {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x,
                                                     const T& y) const {
    // Truncated quotient.
    const T z = x / y;
    // Subtract one if there is a remainder and if the inputs have opposite
    // signs. This approach avoids unnecessary overflows.
    return z * y != x && (x < T(0) != y < T(0)) ? z - T(1) : z;
  }
  // Lane-wise version of the scalar rule: keep x_div_y when the division was
  // exact or when x and y have the same sign (equal sign masks); otherwise
  // subtract one.
  // NOTE(review): this assumes pdiv truncates like the scalar '/' for the
  // packet types it is instantiated with, and that `peq` resolves to a
  // lane-wise equality mask for them. Confirm both; PacketAccess only gates
  // this on HasDiv.
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x,
                                                        const Packet& y) const {
    Packet zeros = pzero(x);
    Packet x_mask = pcmp_lt(x, zeros);
    Packet y_mask = pcmp_lt(y, zeros);
    Packet x_div_y = pdiv(x, y);
    Packet x_div_y_times_y = pmul(x_div_y, y);
    return pselect(por(peq(x_div_y_times_y, x), peq(x_mask, y_mask)), x_div_y,
                   psub(x_div_y, pones(x)));
  }
};
454 | |
455 | template <typename T> |
456 | struct google_floor_div< |
457 | T, typename std::enable_if<std::is_unsigned<T>::value>::type> { |
458 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, |
459 | const T& y) const { |
460 | return x / y; |
461 | } |
462 | template <typename Packet> |
463 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, |
464 | const Packet& y) const { |
465 | return pdiv(x, y); |
466 | } |
467 | }; |
468 | |
469 | template <typename Scalar> |
470 | struct functor_traits<google_floor_div<Scalar>> { |
471 | enum { |
472 | Cost = 2 * Eigen::internal::scalar_div_cost< |
473 | Scalar, packet_traits<Scalar>::HasDiv>::value + |
474 | NumTraits<Scalar>::AddCost, |
475 | PacketAccess = packet_traits<Scalar>::HasDiv |
476 | }; |
477 | }; |
478 | |
479 | template <typename T, typename Enable = void> |
480 | struct google_floor_div_real { |
481 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, |
482 | const T& y) const { |
483 | return Eigen::numext::floor(x / y); |
484 | } |
485 | template <typename Packet> |
486 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, |
487 | const Packet& y) const { |
488 | return pfloor(pdiv(x, y)); |
489 | } |
490 | }; |
491 | |
492 | template <typename Scalar> |
493 | struct functor_traits<google_floor_div_real<Scalar>> { |
494 | enum { |
495 | Cost = 2 * Eigen::internal::scalar_div_cost< |
496 | Scalar, packet_traits<Scalar>::HasDiv>::value + |
497 | 2 * NumTraits<Scalar>::AddCost, |
498 | PacketAccess = |
499 | packet_traits<Scalar>::HasDiv && packet_traits<Scalar>::HasFloor |
500 | }; |
501 | }; |
502 | |
503 | // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. |
504 | template <typename T> |
505 | struct google_floor_fmod { |
506 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, |
507 | const T& y) const { |
508 | // EIGEN_STATIC_ASSERT(NUMERIC_TYPE_MUST_BE_REAL); |
509 | T trunc_mod = scalar_fmod_op<T>()(x, y); |
510 | return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y |
511 | : trunc_mod; |
512 | } |
513 | }; |
514 | |
515 | template <typename Scalar> |
516 | struct functor_traits<google_floor_fmod<Scalar>> { |
517 | enum { |
518 | Cost = functor_traits<Eigen::internal::scalar_fmod_op<Scalar>>::Cost + |
519 | NumTraits<Scalar>::AddCost, |
520 | PacketAccess = false |
521 | }; |
522 | }; |
523 | |
524 | // TODO(rmlarsen): Add vectorized mod & fmod in Eigen and use it here. |
525 | template <typename T> |
526 | struct google_floor_mod { |
527 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, |
528 | const T& y) const { |
529 | // EIGEN_STATIC_ASSERT(!NUMERIC_TYPE_MUST_BE_REAL); |
530 | T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(x, y); |
531 | return trunc_mod != T(0) && (y < T(0) != trunc_mod < T(0)) ? trunc_mod + y |
532 | : trunc_mod; |
533 | } |
534 | }; |
535 | |
536 | template <typename Scalar> |
537 | struct functor_traits<google_floor_mod<Scalar>> { |
538 | enum { |
539 | Cost = functor_traits<Eigen::internal::scalar_mod2_op<Scalar>>::Cost + |
540 | NumTraits<Scalar>::AddCost, |
541 | PacketAccess = false |
542 | }; |
543 | }; |
544 | |
// Helpers to locally silence -Wfloat-equal: DISABLE pushes the diagnostic
// state and ignores the warning, ENABLE pops it. They are active only when
// EIGEN_COMP_GNUC is set and compiling as C++11 or later; otherwise they expand
// to nothing. Both are #undef'd again after the rounding functors.
#if EIGEN_COMP_GNUC && __cplusplus > 199711L
#define DISABLE_FLOAT_EQUALITY_WARNING \
  _Pragma("GCC diagnostic push")       \
      _Pragma("GCC diagnostic ignored \"-Wfloat-equal\"")
#define ENABLE_FLOAT_EQUALITY_WARNING _Pragma("GCC diagnostic pop")
#else
#define DISABLE_FLOAT_EQUALITY_WARNING
#define ENABLE_FLOAT_EQUALITY_WARNING
#endif
554 | |
// Rounds to the nearest integer, with half-way cases going to the nearest even
// value. This primary template is used for the (IsInteger, HasRint)
// combinations not covered by the specializations that follow. It rounds with
// floor(x + 0.5) and then fixes up exact half-way cases.
template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger,
          bool HasRint = packet_traits<Scalar>::HasRint>
struct scalar_round_half_to_even_op {
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar
  operator()(const Scalar& x) const {
    EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex),
                        NUMERIC_TYPE_MUST_BE_REAL)

    // Round half up first.
    const Scalar round_val = Eigen::numext::floor(x + Scalar(0.5));
    const Scalar fraction = round_val - x;
    // fraction == 0.5 exactly when x was a half-way case. Then pick the even
    // neighbour: 2 * floor(x / 2 + 0.5).
    if (TF_PREDICT_FALSE(fraction == Scalar(.5))) {
      return Scalar(2) * Eigen::numext::floor(Scalar(.5) * x + Scalar(0.5));
    } else {
      return round_val;
    }
  }

  // Same algorithm lane-wise. The even-neighbour computation is skipped
  // entirely when no lane is a half-way case.
  template <typename Packet>
  EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const {
    Packet half = pset1<Packet>(Scalar(0.5));
    Packet round_val = pfloor(padd(x, half));
    Packet fraction = psub(round_val, x);
    Packet half_mask = pcmp_eq(fraction, half);
    bool any_halves = predux_any(half_mask);
    if (TF_PREDICT_FALSE(any_halves)) {
      Packet two = pset1<Packet>(Scalar(2));
      Packet nearest_even = pmul(two, pfloor(pmadd(half, x, half)));
      return pselect(half_mask, nearest_even, round_val);
    } else {
      return round_val;
    }
  }
};
588 | |
589 | template <typename Scalar> |
590 | struct scalar_round_half_to_even_op<Scalar, true, false> { |
591 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
592 | operator()(const Scalar& x) const { |
593 | return x; |
594 | } |
595 | template <typename Packet> |
596 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { |
597 | return x; |
598 | } |
599 | }; |
600 | |
601 | template <typename Scalar> |
602 | struct scalar_round_half_to_even_op<Scalar, false, true> { |
603 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
604 | operator()(const Scalar& x) const { |
605 | return Eigen::numext::rint(x); |
606 | } |
607 | template <typename Packet> |
608 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { |
609 | return print(x); |
610 | } |
611 | }; |
612 | |
613 | template <typename Scalar> |
614 | struct functor_traits<scalar_round_half_to_even_op<Scalar>> { |
615 | enum { |
616 | Cost = Eigen::NumTraits<Scalar>::IsInteger ? 0 |
617 | : 4 * NumTraits<Scalar>::AddCost, |
618 | PacketAccess = packet_traits<Scalar>::HasFloor && |
619 | packet_traits<Scalar>::HasAdd && |
620 | packet_traits<Scalar>::HasMul, |
621 | }; |
622 | }; |
623 | |
624 | template <typename Scalar, bool IsInteger = Eigen::NumTraits<Scalar>::IsInteger> |
625 | struct scalar_round_up_op { |
626 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
627 | operator()(const Scalar& x) const { |
628 | EIGEN_STATIC_ASSERT((!NumTraits<Scalar>::IsComplex), |
629 | NUMERIC_TYPE_MUST_BE_REAL) |
630 | return Eigen::numext::floor(x + Scalar(0.5)); |
631 | } |
632 | |
633 | template <typename Packet> |
634 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { |
635 | return pfloor(padd(x, pset1<Packet>(0.5))); |
636 | } |
637 | }; |
638 | |
639 | template <typename Scalar> |
640 | struct scalar_round_up_op<Scalar, true> { |
641 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
642 | operator()(const Scalar& x) const { |
643 | return x; |
644 | } |
645 | |
646 | template <typename Packet> |
647 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { |
648 | return x; |
649 | } |
650 | }; |
651 | |
652 | template <typename Scalar, bool IsInteger> |
653 | struct functor_traits<scalar_round_up_op<Scalar, IsInteger>> { |
654 | enum { |
655 | Cost = IsInteger ? 0 : 4 * NumTraits<Scalar>::AddCost, |
656 | PacketAccess = IsInteger || packet_traits<Scalar>::HasFloor |
657 | }; |
658 | }; |
659 | |
// End of the region where the float-equality warning helpers are defined.
#undef ENABLE_FLOAT_EQUALITY_WARNING
#undef DISABLE_FLOAT_EQUALITY_WARNING
662 | |
663 | template <typename Scalar> |
664 | struct bitwise_xor_op { |
665 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
666 | operator()(const Scalar& x, const Scalar& y) const { |
667 | return x ^ y; |
668 | } |
669 | template <typename Packet> |
670 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& a, |
671 | const Packet& b) const { |
672 | return Eigen::internal::pxor(a, b); |
673 | } |
674 | }; |
675 | |
676 | template <typename Scalar> |
677 | struct functor_traits<bitwise_xor_op<Scalar>> { |
678 | enum { Cost = Eigen::NumTraits<Scalar>::AddCost, PacketAccess = true }; |
679 | }; |
680 | |
681 | template <typename Scalar> |
682 | struct xlogy_op { |
683 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
684 | operator()(const Scalar& x, const Scalar& y) const { |
685 | if (x == Scalar(0.)) { |
686 | return Scalar(0.); |
687 | } |
688 | return x * numext::log(y); |
689 | } |
690 | template <typename Packet> |
691 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, |
692 | const Packet& y) const { |
693 | Packet zeros = pzero(x); |
694 | Packet mask = pcmp_eq(x, zeros); |
695 | scalar_log_op<Scalar> log_op; |
696 | Packet log_y = log_op.packetOp(y); |
697 | Packet x_log_y = pmul(x, log_y); |
698 | return pselect(mask, x, x_log_y); |
699 | } |
700 | }; |
701 | |
702 | template <typename Scalar> |
703 | struct functor_traits<xlogy_op<Scalar>> { |
704 | enum { |
705 | Cost = functor_traits<scalar_log_op<Scalar>>::Cost + |
706 | Eigen::NumTraits<Scalar>::MulCost, |
707 | PacketAccess = functor_traits<scalar_log_op<Scalar>>::PacketAccess |
708 | }; |
709 | }; |
710 | |
711 | template <typename Scalar> |
712 | struct xlog1py_op { |
713 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
714 | operator()(const Scalar& x, const Scalar& y) const { |
715 | if (x == Scalar(0.)) { |
716 | return Scalar(0.); |
717 | } |
718 | return x * numext::log1p(y); |
719 | } |
720 | template <typename Packet> |
721 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, |
722 | const Packet& y) const { |
723 | Packet zeros = pzero(x); |
724 | Packet mask = pcmp_eq(x, zeros); |
725 | scalar_log1p_op<Scalar> log1p_op; |
726 | Packet log1p_y = log1p_op.packetOp(y); |
727 | Packet x_log1p_y = pmul(x, log1p_y); |
728 | return pselect(mask, x, x_log1p_y); |
729 | } |
730 | }; |
731 | |
732 | template <typename Scalar> |
733 | struct functor_traits<xlog1py_op<Scalar>> { |
734 | enum { |
735 | Cost = functor_traits<scalar_log1p_op<Scalar>>::Cost + |
736 | Eigen::NumTraits<Scalar>::MulCost, |
737 | #if TENSORFLOW_USE_ROCM |
738 | PacketAccess = false, |
739 | #else |
740 | PacketAccess = functor_traits<scalar_log1p_op<Scalar>>::PacketAccess |
741 | #endif |
742 | }; |
743 | }; |
744 | |
745 | template <typename Scalar> |
746 | struct xdivy_op { |
747 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
748 | operator()(const Scalar& x, const Scalar& y) const { |
749 | if (x == Scalar(0.)) { |
750 | return Scalar(0.); |
751 | } |
752 | return x / y; |
753 | } |
754 | template <typename Packet> |
755 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x, |
756 | const Packet& y) const { |
757 | Packet zeros = pzero(x); |
758 | Packet mask = pcmp_eq(x, zeros); |
759 | Packet x_div_y = pdiv(x, y); |
760 | return pselect(mask, x, x_div_y); |
761 | } |
762 | }; |
763 | |
764 | template <typename Scalar> |
765 | struct functor_traits<xdivy_op<Scalar>> { |
766 | enum { |
767 | Cost = |
768 | Eigen::NumTraits<Scalar>::AddCost + |
769 | Eigen::internal::scalar_div_cost<Scalar, |
770 | packet_traits<Scalar>::HasDiv>::value, |
771 | PacketAccess = packet_traits<Scalar>::HasDiv |
772 | }; |
773 | }; |
774 | |
775 | template <typename T> |
776 | struct scalar_erfinv_op { |
777 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x) const { |
778 | constexpr T half = T(0.5); |
779 | T y = numext::ndtri(half * x + half); |
780 | constexpr T half_sqrt = T(M_SQRT1_2); |
781 | return y * half_sqrt; |
782 | } |
783 | template <typename Packet> |
784 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Packet packetOp(const Packet& x) const { |
785 | Packet half = pset1<Packet>(T(0.5)); |
786 | Packet y = pndtri<Packet>(pmadd(half, x, half)); |
787 | Packet half_sqrt = pset1<Packet>(T(M_SQRT1_2)); |
788 | return pmul(y, half_sqrt); |
789 | } |
790 | }; |
791 | |
792 | template <typename T> |
793 | struct functor_traits<scalar_erfinv_op<T>> { |
794 | enum { |
795 | Cost = functor_traits<scalar_ndtri_op<T>>::Cost + NumTraits<T>::AddCost, |
796 | PacketAccess = packet_traits<T>::HasNdtri, |
797 | }; |
798 | }; |
799 | |
800 | } // end namespace internal |
801 | } // end namespace Eigen |
802 | |
803 | namespace tensorflow { |
804 | namespace functor { |
805 | |
806 | //////////////////////////////////////////////////////////////////////////////// |
807 | // Helpers |
808 | //////////////////////////////////////////////////////////////////////////////// |
809 | |
810 | // Base template for functors whose input scalar type is T and |
811 | // output scalar type is R. |
812 | template <typename T, typename F, typename R = T> |
813 | struct base { |
814 | // func defines operator() and its vectorized version packetOp(). |
815 | typedef F func; |
816 | |
817 | // If true, the functor's corresponding binary op will instantiate |
818 | // specialized kernels to perform an optimized broadcast |
819 | // operation. Each functor for which this is enabled increases the |
820 | // code size, so by default this is disabled for binary functors and |
821 | // is enabled on a per-op basis as needed. |
822 | static constexpr bool use_bcast_optimization = false; |
823 | |
824 | // operator() has the signature: |
825 | // out_type operator()(in_type in0, in_type in1 ...) |
826 | typedef R out_type; |
827 | typedef T in_type; |
828 | |
829 | // TensorFlow provides tensor-ized version of "func". Roughly |
830 | // speaking, the tensorflow operation has the signature: |
831 | // tout_type op(tin_type in0) |
832 | // tout_type op(tin_type in0, tin_type in1) |
833 | // tout_type op(tin_type in0, in_type scalar) |
834 | typedef typename TTypes<out_type>::Flat tout_type; |
835 | typedef typename TTypes<in_type>::ConstFlat tin_type; |
836 | typedef typename TTypes<in_type>::ConstScalar tscalar_type; |
837 | |
838 | // Whether the functor can error out. Currently applies only to integer |
839 | // div and mod. |
840 | static constexpr bool has_errors = false; |
841 | }; |
842 | |
// For now, certain speed optimizations for broadcast binary ops are applied
// only to float and double.
template <typename T>
struct use_bcast_optimization {
  static constexpr bool value =
      std::is_same<T, float>::value || std::is_same<T, double>::value;
};
859 | |
860 | //////////////////////////////////////////////////////////////////////////////// |
861 | // Unary functors |
862 | //////////////////////////////////////////////////////////////////////////////// |
863 | |
864 | // abs(x) = |x| |
865 | // neg(x) = - x |
866 | // inverse(x) = 1 / x |
867 | // square(x) = x^2 |
868 | // sqrt(x) = x^(1/2) |
869 | // rsqrt(x) = x^(-1/2) |
870 | // exp(x) = e^x |
871 | // expm1(x) = e^x - 1 |
872 | // log(x) = natural logarithm of x |
873 | // log1p(x) = natural logarithm of 1 + x |
// tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
// sigmoid(x) = 1 / (1 + exp(-x))  // a.k.a. logistic
//
// NOTE: We may eventually implement common functions used in NN
// here. E.g., rectifier, softplus, derivatives of tanh, sigmoid, etc.
879 | // For reference, see speech/lstm/eigen_functors.h. |
880 | |
881 | template <typename T> |
882 | struct abs : base<T, Eigen::internal::scalar_abs_op<T>, |
883 | typename Eigen::internal::scalar_abs_op<T>::result_type> {}; |
884 | |
885 | template <typename T> |
886 | struct neg : base<T, Eigen::internal::scalar_opposite_op<T>> {}; |
887 | |
888 | template <typename T> |
889 | struct inverse : base<T, Eigen::internal::scalar_inverse_op<T>> {}; |
890 | |
891 | template <typename T> |
892 | struct square : base<T, Eigen::internal::scalar_square_op<T>> {}; |
893 | |
894 | template <typename T> |
895 | struct sqrt : base<T, Eigen::internal::scalar_sqrt_op<T>> {}; |
896 | |
897 | template <typename T> |
898 | struct rsqrt : base<T, Eigen::internal::scalar_rsqrt_op<T>> {}; |
899 | |
900 | template <typename T> |
901 | struct exp : base<T, Eigen::internal::scalar_exp_op<T>> {}; |
902 | |
903 | template <typename T> |
904 | struct expm1 : base<T, Eigen::internal::scalar_expm1_op<T>> {}; |
905 | |
906 | template <typename T> |
907 | struct log : base<T, Eigen::internal::scalar_log_op<T>> {}; |
908 | |
909 | template <typename T> |
910 | struct log1p : base<T, Eigen::internal::scalar_log1p_op<T>> {}; |
911 | |
912 | template <typename T> |
913 | struct sign : base<T, Eigen::internal::scalar_sign_op<T>> {}; |
914 | |
915 | template <typename T> |
916 | struct sinh : base<T, Eigen::internal::scalar_sinh_op<T>> {}; |
917 | |
918 | template <typename T> |
919 | struct cosh : base<T, Eigen::internal::scalar_cosh_op<T>> {}; |
920 | |
921 | template <typename T> |
922 | struct tanh : base<T, Eigen::internal::scalar_tanh_op<T>> {}; |
923 | |
924 | template <typename T> |
925 | struct asinh : base<T, Eigen::internal::scalar_asinh_op<T>> {}; |
926 | |
927 | template <typename T> |
928 | struct acosh : base<T, Eigen::internal::scalar_acosh_op<T>> {}; |
929 | |
930 | template <typename T> |
931 | struct atanh : base<T, Eigen::internal::scalar_atanh_op<T>> {}; |
932 | |
933 | template <typename T> |
934 | struct lgamma : base<T, Eigen::internal::scalar_lgamma_op<T>> {}; |
935 | |
936 | template <typename T> |
937 | struct digamma : base<T, Eigen::internal::scalar_digamma_op<T>> {}; |
938 | |
939 | template <typename T> |
940 | struct erf : base<T, Eigen::internal::scalar_erf_op<T>> {}; |
941 | |
942 | template <typename T> |
943 | struct erfc : base<T, Eigen::internal::scalar_erfc_op<T>> {}; |
944 | |
945 | template <typename T> |
946 | struct ndtri : base<T, Eigen::internal::scalar_ndtri_op<T>> {}; |
947 | |
948 | template <typename T> |
949 | struct erfinv : base<T, Eigen::internal::scalar_erfinv_op<T>> {}; |
950 | |
951 | template <typename T> |
952 | struct sigmoid : base<T, Eigen::internal::scalar_logistic_op<T>> {}; |
953 | |
954 | template <typename T> |
955 | struct sin : base<T, Eigen::internal::scalar_sin_op<T>> {}; |
956 | |
957 | template <typename T> |
958 | struct cos : base<T, Eigen::internal::scalar_cos_op<T>> {}; |
959 | |
960 | template <typename T> |
961 | struct tan : base<T, Eigen::internal::scalar_tan_op<T>> {}; |
962 | |
963 | template <typename T> |
964 | struct asin : base<T, Eigen::internal::scalar_asin_op<T>> {}; |
965 | |
966 | template <typename T> |
967 | struct acos : base<T, Eigen::internal::scalar_acos_op<T>> {}; |
968 | |
969 | template <typename T> |
970 | struct atan : base<T, Eigen::internal::scalar_atan_op<T>> {}; |
971 | |
972 | struct logical_not : base<bool, Eigen::internal::scalar_boolean_not_op<bool>> { |
973 | }; |
974 | |
975 | // Flip all bits. Named invert to be consistent with numpy. |
976 | template <typename T> |
977 | struct invert_op { |
978 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& a) const { |
979 | return ~a; |
980 | } |
981 | }; |
982 | |
983 | template <typename T> |
984 | struct invert : base<T, invert_op<T>> {}; |
985 | |
986 | // NOTE: std::isinf, std::isnan, std::isfinite are plain function. |
987 | // Therefore we need to wrap them in functors to be used with Eigen's |
988 | // type system. |
989 | template <typename T> |
990 | struct isinf : base<T, Eigen::internal::scalar_isinf_op<T>, bool> {}; |
991 | |
992 | template <typename T> |
993 | struct isnan : base<T, Eigen::internal::scalar_isnan_op<T>, bool> {}; |
994 | |
995 | template <typename T> |
996 | struct isfinite : base<T, Eigen::internal::scalar_isfinite_op<T>, bool> {}; |
997 | |
998 | template <typename T> |
999 | struct floor : base<T, Eigen::internal::scalar_floor_op<T>> {}; |
1000 | |
1001 | template <typename T> |
1002 | struct round : base<T, Eigen::internal::scalar_round_half_to_even_op<T>> {}; |
1003 | |
1004 | template <typename T> |
1005 | struct ceil : base<T, Eigen::internal::scalar_ceil_op<T>> {}; |
1006 | |
1007 | // Note: rint rounds half values to even, just like round_half_to_even_op. |
1008 | template <typename T> |
1009 | struct rint : base<T, Eigen::internal::scalar_rint_op<T>> {}; |
1010 | |
1011 | //////////////////////////////////////////////////////////////////////////////// |
1012 | // Binary functors |
1013 | //////////////////////////////////////////////////////////////////////////////// |
1014 | |
1015 | // Binary functors: |
1016 | // |
1017 | // add(x, y) = x + y |
1018 | // sub(x, y) = x - y |
1019 | // mul(x, y) = x * y |
1020 | // div(x, y) = x / y |
1021 | // mod(x, y) = x % y (int32 and int64 only) |
1022 | // fmod(x, y) = fmod(x, y) (float and double only) |
1023 | // pow(x, y) = x ^ y |
1024 | // maximum(x, y) = x > y ? x : y |
1025 | // minimum(x, y) = x < y ? x : y |
1026 | // squared_difference(x, y) = conj(x - y) * (x - y) |
1027 | |
1028 | template <typename T> |
1029 | struct add : base<T, Eigen::internal::scalar_sum_op<T>> { |
1030 | static constexpr bool use_bcast_optimization = true; |
1031 | }; |
1032 | |
1033 | template <typename T> |
1034 | struct sub : base<T, Eigen::internal::scalar_difference_op<T>> { |
1035 | static constexpr bool use_bcast_optimization = true; |
1036 | }; |
1037 | |
1038 | template <typename T> |
1039 | struct mul : base<T, Eigen::internal::scalar_product_op<T>> { |
1040 | static constexpr bool use_bcast_optimization = true; |
1041 | }; |
1042 | |
1043 | template <typename T> |
1044 | struct mul_no_nan : base<T, Eigen::internal::mul_no_nan_op<T>> {}; |
1045 | |
1046 | template <typename T> |
1047 | struct div : base<T, Eigen::internal::scalar_quotient_op<T>> {}; |
1048 | |
1049 | template <typename T> |
1050 | struct safe_div : base<T, Eigen::internal::safe_div_or_mod_op< |
1051 | T, Eigen::internal::scalar_quotient_op<T>>> { |
1052 | static constexpr bool has_errors = true; |
1053 | }; |
1054 | |
1055 | template <typename T> |
1056 | struct div_no_nan : base<T, Eigen::internal::div_no_nan_op<T>> {}; |
1057 | |
1058 | template <typename T> |
1059 | struct fmod : base<T, Eigen::internal::scalar_fmod_op<T>> {}; |
1060 | |
1061 | template <typename T> |
1062 | struct mod : base<T, Eigen::internal::scalar_mod2_op<T>> {}; |
1063 | |
1064 | template <typename T> |
1065 | struct safe_mod : base<T, Eigen::internal::safe_div_or_mod_op< |
1066 | T, Eigen::internal::scalar_mod2_op<T>>> { |
1067 | static constexpr bool has_errors = true; |
1068 | }; |
1069 | |
1070 | template <typename T> |
1071 | struct floor_fmod : base<T, Eigen::internal::google_floor_fmod<T>> {}; |
1072 | |
1073 | template <typename T> |
1074 | struct safe_floor_mod : base<T, Eigen::internal::safe_div_or_mod_op< |
1075 | T, Eigen::internal::google_floor_mod<T>>> { |
1076 | static constexpr bool has_errors = true; |
1077 | }; |
1078 | |
1079 | template <typename T> |
1080 | struct floor_div : base<T, Eigen::internal::google_floor_div<T>> {}; |
1081 | |
1082 | template <typename T> |
1083 | struct safe_floor_div : base<T, Eigen::internal::safe_div_or_mod_op< |
1084 | T, Eigen::internal::google_floor_div<T>>> { |
1085 | static constexpr bool has_errors = true; |
1086 | }; |
1087 | |
1088 | template <typename T> |
1089 | struct floor_div_real : base<T, Eigen::internal::google_floor_div_real<T>> {}; |
1090 | |
1091 | template <typename T> |
1092 | struct pow : base<T, Eigen::internal::scalar_pow_op<T, T>> {}; |
1093 | |
1094 | template <typename T> |
1095 | struct safe_pow : base<T, Eigen::internal::safe_scalar_binary_pow_op<T, T>> { |
1096 | static constexpr bool has_errors = true; |
1097 | }; |
1098 | |
1099 | // Version of safe_pow for integers which returns 0 if RHS is negative and LHS |
1100 | // is not 1 or -1. For use on GPUs, where we cannot raise an error. |
1101 | template <typename T> |
1102 | struct safe_pow_ignore_error_op { |
1103 | static_assert(std::is_integral<T>::value, "Integer type expected" ); |
1104 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, |
1105 | const T& y) const { |
1106 | if (TF_PREDICT_FALSE(y < 0)) { |
1107 | if (x == T(-1)) { |
1108 | T trunc_mod = Eigen::internal::scalar_mod2_op<T>()(y, T(2)); |
1109 | return trunc_mod == T(-1) ? T(-1) : T(1); |
1110 | } |
1111 | return x == T(1) ? T(1) : T(0); |
1112 | } |
1113 | return Eigen::internal::scalar_pow_op<T, T>{}(x, y); |
1114 | } |
1115 | }; |
1116 | |
1117 | template <typename T> |
1118 | struct safe_pow_ignore_error : base<T, safe_pow_ignore_error_op<T>> {}; |
1119 | |
1120 | template <typename T> |
1121 | struct maximum |
1122 | : base<T, Eigen::internal::scalar_max_op<T, T, Eigen::PropagateNaN>> {}; |
1123 | |
1124 | template <typename T> |
1125 | struct minimum |
1126 | : base<T, Eigen::internal::scalar_min_op<T, T, Eigen::PropagateNaN>> {}; |
1127 | |
1128 | template <typename T> |
1129 | struct igamma : base<T, Eigen::internal::scalar_igamma_op<T>> {}; |
1130 | |
1131 | template <typename T> |
1132 | struct random_gamma_grad |
1133 | : base<T, Eigen::internal::scalar_gamma_sample_der_alpha_op<T>> {}; |
1134 | |
1135 | template <typename T> |
1136 | struct igammac : base<T, Eigen::internal::scalar_igammac_op<T>> {}; |
1137 | |
1138 | template <typename T> |
1139 | struct zeta : base<T, Eigen::internal::scalar_zeta_op<T>> {}; |
1140 | |
1141 | template <typename T> |
1142 | struct polygamma : base<T, Eigen::internal::scalar_polygamma_op<T>> {}; |
1143 | |
1144 | template <typename Scalar> |
1145 | struct scalar_atan2_op { |
1146 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE Scalar |
1147 | operator()(const Scalar& y, const Scalar& x) const { |
1148 | #if TENSORFLOW_USE_ROCM |
1149 | return static_cast<Scalar>(::atan2(y, x)); |
1150 | #else |
1151 | return static_cast<Scalar>(std::atan2(y, x)); |
1152 | #endif |
1153 | } |
1154 | }; |
1155 | |
1156 | template <typename T> |
1157 | struct atan2 : base<T, scalar_atan2_op<T>> {}; |
1158 | |
1159 | template <typename T> |
1160 | struct squared_difference |
1161 | : base<T, Eigen::internal::scalar_squared_difference_op<T>> {}; |
1162 | |
1163 | template <typename T> |
1164 | struct xdivy : base<T, Eigen::internal::xdivy_op<T>> {}; |
1165 | |
1166 | template <typename T> |
1167 | struct xlogy : base<T, Eigen::internal::xlogy_op<T>> {}; |
1168 | |
1169 | template <typename T> |
1170 | struct xlog1py : base<T, Eigen::internal::xlog1py_op<T>> {}; |
1171 | |
1172 | template <typename T> |
1173 | struct less : base<T, Eigen::internal::less<T>, bool> {}; |
1174 | |
1175 | template <typename T> |
1176 | struct less_equal : base<T, Eigen::internal::less_equal<T>, bool> {}; |
1177 | |
1178 | template <typename T> |
1179 | struct greater : base<T, Eigen::internal::greater<T>, bool> {}; |
1180 | |
1181 | template <typename T> |
1182 | struct greater_equal : base<T, Eigen::internal::greater_equal<T>, bool> {}; |
1183 | |
1184 | template <typename T> |
1185 | struct equal_to : base<T, Eigen::internal::equal_to<T>, bool> {}; |
1186 | |
1187 | template <typename T> |
1188 | struct not_equal_to : base<T, Eigen::internal::not_equal_to<T>, bool> {}; |
1189 | |
1190 | struct logical_and : base<bool, Eigen::internal::scalar_boolean_and_op> {}; |
1191 | |
1192 | struct logical_or : base<bool, Eigen::internal::scalar_boolean_or_op> {}; |
1193 | |
1194 | template <typename T> |
1195 | struct bitwise_and_op { |
1196 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, |
1197 | const T& y) const { |
1198 | return x & y; |
1199 | } |
1200 | }; |
1201 | |
1202 | template <typename T> |
1203 | struct bitwise_or_op { |
1204 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, |
1205 | const T& y) const { |
1206 | return x | y; |
1207 | } |
1208 | }; |
1209 | |
1210 | template <typename T> |
1211 | struct bitwise_and : base<T, bitwise_and_op<T>> {}; |
1212 | |
1213 | template <typename T> |
1214 | struct bitwise_or : base<T, bitwise_or_op<T>> {}; |
1215 | |
1216 | template <typename T> |
1217 | struct bitwise_xor : base<T, Eigen::internal::bitwise_xor_op<T>> {}; |
1218 | |
1219 | template <typename T> |
1220 | struct left_shift_op { |
1221 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, |
1222 | const T& y) const { |
1223 | // Avoids UB: don't shift by larger than the bitwidth of T, and |
1224 | // performs left shifts as unsigned shifts. |
1225 | T y_clamped = y; |
1226 | if (y_clamped < 0) { |
1227 | y_clamped = 0; |
1228 | } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { |
1229 | y_clamped = sizeof(T) * CHAR_BIT - 1; |
1230 | } |
1231 | using U = typename std::make_unsigned<T>::type; |
1232 | return static_cast<T>(static_cast<U>(x) << static_cast<U>(y_clamped)); |
1233 | } |
1234 | }; |
1235 | |
1236 | template <typename T> |
1237 | struct right_shift_op { |
1238 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE T operator()(const T& x, |
1239 | const T& y) const { |
1240 | // Avoids UB: don't shift by larger than the bitwidth of T. |
1241 | T y_clamped = y; |
1242 | if (y_clamped < 0) { |
1243 | y_clamped = 0; |
1244 | } else if (y_clamped > sizeof(T) * CHAR_BIT - 1) { |
1245 | y_clamped = sizeof(T) * CHAR_BIT - 1; |
1246 | } |
1247 | // Technically right shifts of signed integers are not necessarily |
1248 | // arithmetic shifts according to the C++ standard. However in practice most |
1249 | // implementations are arithmetic shifts. If this proves to be a problem in |
1250 | // practice, we may need to use an alternative implementation. |
1251 | return x >> y_clamped; |
1252 | } |
1253 | }; |
1254 | |
1255 | template <typename T> |
1256 | struct left_shift : base<T, left_shift_op<T>> {}; |
1257 | |
1258 | template <typename T> |
1259 | struct right_shift : base<T, right_shift_op<T>> {}; |
1260 | |
1261 | template <typename T> |
1262 | struct make_complex_func { |
1263 | typedef std::complex<T> result_type; |
1264 | EIGEN_DEVICE_FUNC EIGEN_STRONG_INLINE result_type operator()(T real, |
1265 | T imag) const { |
1266 | return std::complex<T>(real, imag); |
1267 | } |
1268 | }; |
1269 | |
1270 | template <typename T> |
1271 | struct make_complex : base<T, make_complex_func<T>, std::complex<T>> {}; |
1272 | |
1273 | template <typename T> |
1274 | struct get_real |
1275 | : base<T, Eigen::internal::scalar_real_op<T>, typename T::value_type> {}; |
1276 | |
1277 | template <typename T> |
1278 | struct get_imag |
1279 | : base<T, Eigen::internal::scalar_imag_op<T>, typename T::value_type> {}; |
1280 | |
1281 | template <typename T> |
1282 | struct get_angle |
1283 | : base<T, Eigen::internal::scalar_arg_op<T>, typename T::value_type> {}; |
1284 | |
1285 | template <typename T> |
1286 | struct conj : base<T, Eigen::internal::scalar_conjugate_op<T>> {}; |
1287 | |
1288 | //////////////////////////////////////////////////////////////////////////////// |
1289 | // Functors takes 1 or 2 tensors, computes the base functor on |
1290 | // coefficient of the input tensors and puts the results in the output |
1291 | // tensor. |
1292 | //////////////////////////////////////////////////////////////////////////////// |
// Applies a unary functor elementwise on a device. The primary template only
// declares operator(); the definitions presumably come from device-specific
// specializations elsewhere.
template <typename Device, typename Functor>
struct UnaryFunctor {
  // Computes on device "d": out[i] = Functor(in[i])
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in);
};

// Like UnaryFunctor, but the computation also receives an extra argument
// `val` of type Targ.
// NOTE(review): how `val` is applied (e.g. to parameterize the functor's
// `func`) is decided by the device-specific specializations, not here.
// Confirm there.
template <typename Device, typename Functor, typename Targ>
struct UnaryFunctorWithArg {
  // Computes on device "d": out[i] = Functor(in[i]), with `val` supplied as
  // an additional argument to the computation.
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in, Targ val);
};
1306 | |
// Applies a binary functor elementwise on a device. It supports
// tensor-tensor, scalar-tensor, tensor-scalar and explicitly broadcast
// operands. Only declarations live here; the definitions presumably live in
// device-specific specializations.
//
// `has_errors` defaults to Functor::has_errors, which is true for the safe_*
// functors. Every method takes an `error` out-parameter.
// NOTE(review): presumably it is written only when the functor detects a
// failure (e.g. integer division/mod by zero). Confirm in the
// implementations.
template <typename Device, typename Functor, int NDIMS,
          bool has_errors = Functor::has_errors>
struct BinaryFunctor {
  // Computes on device "d": out[i] = Functor(in0[i], in1[i])
  void operator()(const Device& d, typename Functor::tout_type out,
                  typename Functor::tin_type in0,
                  typename Functor::tin_type in1, bool* error);

  // Computes on device "d": out[i] = Functor(scalar[0], in[i])
  void Left(const Device& d, typename Functor::tout_type out,
            typename Functor::tscalar_type scalar,
            typename Functor::tin_type in, bool* error);

  // Computes on device "d": out[i] = Functor(in[i], scalar[0])
  void Right(const Device& d, typename Functor::tout_type out,
             typename Functor::tin_type in,
             typename Functor::tscalar_type scalar, bool* error);

  // Computes on device "d":
  //   out = Functor(in0.broadcast(bcast0), in1.broadcast(bcast1))
  // where bcast0/bcast1 hold per-dimension broadcast factors for NDIMS-rank
  // tensors.
  //
  // TODO(zhifengc): make BCast a template member function on NDIMS
  // instead of making BinaryFunctor a template on NDIMS.
  void BCast(const Device& d,
             typename TTypes<typename Functor::out_type, NDIMS>::Tensor out,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in0,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast0,
             typename TTypes<typename Functor::in_type, NDIMS>::ConstTensor in1,
             typename Eigen::array<Eigen::DenseIndex, NDIMS> bcast1,
             bool* error);
};
1338 | |
// Elementwise approximate comparison of x and y with the given tolerance.
// It writes one bool per element into z. Only the declaration is here.
// NOTE(review): the exact predicate (presumably |x[i] - y[i]| < tolerance)
// is defined by the device-specific implementation. Confirm there.
template <typename Device, typename T>
struct ApproximateEqual {
  void operator()(const Device& d, typename TTypes<T>::ConstFlat x,
                  typename TTypes<T>::ConstFlat y, T tolerance,
                  typename TTypes<bool>::Flat z);
};
1345 | |
1346 | template <int NDIMS> |
1347 | bool AllOne(const typename Eigen::array<Eigen::DenseIndex, NDIMS>& a) { |
1348 | for (size_t i = 0; i < a.size(); ++i) { |
1349 | if (a[i] != 1) return false; |
1350 | } |
1351 | return true; |
1352 | } |
1353 | |
// Elementwise select on flat tensors. Declaration only.
// NOTE(review): presumably out[i] = cond_flat[i] ? then_flat[i] : else_flat[i].
// Confirm in the device implementations.
template <typename Device, typename T>
struct SelectFunctor {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstFlat cond_flat,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat);
};

// Select driven by a single scalar condition. Declaration only.
// NOTE(review): presumably copies all of then_flat when cond() is true and
// all of else_flat otherwise. Confirm.
template <typename Device, typename T>
struct SelectScalarFunctor {
  void operator()(const Device& d, typename TTypes<T>::Flat out,
                  typename TTypes<bool>::ConstScalar cond,
                  typename TTypes<T>::ConstFlat then_flat,
                  typename TTypes<T>::ConstFlat else_flat);
};

// Select on matrix views whose condition is a vector. Declaration only.
// NOTE(review): presumably cond_vec holds one bool per outer row and selects
// whole rows from then/else. Confirm.
template <typename Device, typename T>
struct BatchSelectFunctor {
  void operator()(const Device& d,
                  typename TTypes<T>::Matrix output_flat_outer_dims,
                  TTypes<bool>::ConstVec cond_vec,
                  typename TTypes<T>::ConstMatrix then_flat_outer_dims,
                  typename TTypes<T>::ConstMatrix else_flat_outer_dims);
};

// Select with per-operand broadcasting on NDIMS-rank tensors. Declaration
// only.
// NOTE(review): presumably each of cond/then/else is broadcast by its own
// *_bcast factors before the elementwise select. Confirm.
template <typename Device, typename T, int NDIMS>
struct BCastSelectFunctor {
  void operator()(const Device& d,
                  typename TTypes<T, NDIMS>::Tensor output_tensor,
                  typename TTypes<bool, NDIMS>::ConstTensor cond_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor then_tensor,
                  typename TTypes<T, NDIMS>::ConstTensor else_tensor,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> cond_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> then_bcast,
                  typename Eigen::array<Eigen::DenseIndex, NDIMS> else_bcast);
};
1390 | |
1391 | } // end namespace functor |
1392 | } // end namespace tensorflow |
1393 | |
1394 | #endif // TENSORFLOW_CORE_KERNELS_CWISE_OPS_H_ |
1395 | |