1 | /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_EIGEN_CONVOLUTION_HELPERS_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_EIGEN_CONVOLUTION_HELPERS_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | |
21 | namespace Eigen { |
22 | namespace internal { |
23 | |
// TensorEvaluatorHasPartialPacket<TensorEvaluatorType, PacketType, IndexType>
// provides a static `value` that is true if and only if TensorEvaluatorType
// has a member `PacketType partialPacket<PacketType>(IndexType,
// unpacket_traits<PacketType>::mask_t) const` and PacketType supports masked
// loads.
//
// Partial packets are used to:
//
// 1) Split a packet over two columns in Eigen-based spatial convolution, load
// each part with a partial (masked) load, and combine the parts into the
// required packet. This class is used to pick the correct implementation of
// the loadPacketStandard function.
//
// 2) Split a packet over two rows (within the same column) in Eigen-based
// cuboid convolution, load each part with a partial load, and combine the
// parts into the required packet. As in 1), this class is used to pick the
// correct implementation of the loadPacketStandard function.
//
// 3) Finalize packing of columns in gemm_pack_colmajor after the vectorized
// part has been processed with full packets (see eigen_spatial_convolutions.h).
44 | template <typename TensorEvaluatorType, typename PacketType, typename IndexType> |
45 | class TensorEvaluatorHasPartialPacket { |
46 | public: |
47 | template <typename TensorEvaluatorT, typename PacketT, typename IndexT> |
48 | static auto functionExistsSfinae( |
49 | typename std::enable_if< |
50 | unpacket_traits<PacketT>::masked_load_available && |
51 | std::is_same<PacketT, |
52 | decltype(std::declval<const TensorEvaluatorT>() |
53 | .template partialPacket<PacketT>( |
54 | std::declval<IndexT>(), |
55 | std::declval<typename unpacket_traits< |
56 | PacketT>::mask_t>()))>::value>:: |
57 | type*) -> std::true_type; |
58 | |
59 | template <typename TensorEvaluatorT, typename PacketT, typename IndexT> |
60 | static auto functionExistsSfinae(...) -> std::false_type; |
61 | |
62 | typedef decltype( |
63 | functionExistsSfinae<TensorEvaluatorType, PacketType, IndexType>( |
64 | nullptr)) status; |
65 | |
66 | static constexpr bool value = status::value; |
67 | }; |
68 | |
69 | // Compute a mask for loading/storing coefficients in/from a packet in a |
70 | // [from, to) range. If the mask bit is 1, element will be loaded/stored. |
71 | template <typename Packet> |
72 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE |
73 | typename std::enable_if<unpacket_traits<Packet>::masked_load_available, |
74 | typename unpacket_traits<Packet>::mask_t>::type |
75 | mask(int from, int to) { |
76 | const Index packet_size = internal::unpacket_traits<Packet>::size; |
77 | eigen_assert(0 <= from && to <= (packet_size + 1) && from < to); |
78 | |
79 | using Mask = typename internal::unpacket_traits<Packet>::mask_t; |
80 | const Mask mask_max = std::numeric_limits<Mask>::max(); |
81 | |
82 | return (mask_max >> (packet_size - to)) ^ (mask_max >> (packet_size - from)); |
83 | } |
84 | |
85 | } // namespace internal |
86 | } // namespace Eigen |
87 | |
88 | #endif // TENSORFLOW_CORE_KERNELS_EIGEN_CONVOLUTION_HELPERS_H_ |
89 | |