1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_WHERE_OP_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_WHERE_OP_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | #include "tensorflow/core/framework/op_kernel.h" |
21 | #include "tensorflow/core/framework/tensor_types.h" |
22 | #include "tensorflow/core/platform/macros.h" |
23 | #include "tensorflow/core/platform/types.h" |
24 | |
25 | namespace tensorflow { |
26 | |
// Applies the macro `m` once per type, in this order: int8, uint8, int64,
// float, double, complex64, complex128, bool. Each step is delegated to the
// matching TF_CALL_<type> macro; those macros are not defined in this header,
// so they must already be in scope wherever this macro is expanded.
// Every step of the expansion, including the last one, is followed by ';'.
// NOTE(review): the name suggests this is the list of element types the GPU
// implementation of Where is instantiated for -- confirm at the use site.
27 | #define TF_CALL_WHERE_GPU_TYPES(m) \ |
28 | TF_CALL_int8(m); \ |
29 | TF_CALL_uint8(m); \ |
30 | TF_CALL_int64(m); \ |
31 | TF_CALL_float(m); \ |
32 | TF_CALL_double(m); \ |
33 | TF_CALL_complex64(m); \ |
34 | TF_CALL_complex128(m); \ |
35 | TF_CALL_bool(m); |
36 | |
37 | namespace functor { |
38 | |
// Functor interface parameterized on the device type, the input element type
// T, and the index type TIndex. Only the declaration of Compute is visible
// here; no definition appears in this view.
// NOTE(review): judging by the name, this presumably counts the true
// elements of `input` -- confirm against the per-device implementations.
//
// Compute arguments:
//   ctx      - the OpKernelContext passed in by the caller.
//   d        - the device object of type Device.
//   input    - read-only flat view over values of type T.
//   num_true - unaligned scalar view of type TIndex; presumably where the
//              resulting count is written (TODO confirm).
// Returns a Status.
39 | template <typename Device, typename T, typename TIndex> |
40 | struct NumTrue { |
41 | EIGEN_ALWAYS_INLINE static Status Compute( |
42 | OpKernelContext* ctx, const Device& d, |
43 | typename TTypes<T>::ConstFlat input, |
44 | typename TTypes<TIndex>::UnalignedScalar num_true); |
45 | }; |
46 | |
// Functor interface parameterized on the device type, the rank NDIM of the
// input, the input element type T, and the index type TIndex. Only the
// declaration of Compute is visible here; no definition appears in this view.
//
// Compute arguments:
//   ctx        - the OpKernelContext passed in by the caller.
//   d          - the device object of type Device.
//   input      - read-only tensor view of rank NDIM over values of type T.
//   output     - matrix (2-D) view of int64_t that receives the indices.
//                NOTE(review): presumably one row per true element and one
//                column per input dimension -- confirm with the callers.
//   found_true - out-pointer of type TIndex; its contract is described in the
//                comment inside the struct.
// Returns a Status.
47 | template <typename Device, int NDIM, typename T, typename TIndex> |
48 | struct Where { |
49 | // Copies indices of true values in input into output. The pointer |
50 | // found_true should sit on the host. Compute should copy the |
51 | // number of true elements found into it. At the end, if |
52 | // *found_true != output.dimension(0), |
53 | // then the input may have changed between the initial counting of |
54 | // the true values and the call to Where. |
55 | EIGEN_ALWAYS_INLINE static Status Compute( |
56 | OpKernelContext* ctx, const Device& d, |
57 | typename TTypes<T, NDIM>::ConstTensor input, |
58 | typename TTypes<int64_t>::Matrix output, TIndex* found_true); |
59 | }; |
60 | |
61 | } // namespace functor |
62 | |
63 | } // namespace tensorflow |
64 | |
65 | #endif // TENSORFLOW_CORE_KERNELS_WHERE_OP_H_ |
66 | |