1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ |
18 | |
19 | // Specialization of GatherNdSlice to CPU |
20 | |
21 | #define EIGEN_USE_THREADS |
22 | |
23 | #include <atomic> |
24 | |
25 | #include "tensorflow/core/framework/bounds_check.h" |
26 | #include "tensorflow/core/framework/op_kernel.h" |
27 | #include "tensorflow/core/framework/register_types.h" |
28 | #include "tensorflow/core/framework/tensor.h" |
29 | #include "tensorflow/core/kernels/gather_nd_op.h" |
30 | #include "tensorflow/core/platform/logging.h" |
31 | #include "tensorflow/core/platform/mem.h" |
32 | #include "tensorflow/core/platform/types.h" |
33 | #include "tensorflow/core/util/util.h" |
34 | |
35 | namespace tensorflow { |
36 | |
37 | typedef Eigen::ThreadPoolDevice CPUDevice; |
38 | |
39 | namespace generator { |
40 | |
41 | template <typename T, typename Index, int IXDIM> |
42 | class GatherNdSliceGenerator { |
43 | public: |
44 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator( |
45 | const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices, |
46 | typename TTypes<T, IXDIM + 1>::ConstTensor Tparams, |
47 | typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc) |
48 | : slice_size_(slice_size), |
49 | Tindices_(Tindices), |
50 | Tparams_(Tparams), |
51 | Tout_(Tout), |
52 | error_loc_(error_loc) {} |
53 | |
54 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices( |
55 | const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const { |
56 | (*ix)[IXDIM] = 0; |
57 | bool out_of_bounds = false; |
58 | for (int i = 0; i < IXDIM; ++i) { |
59 | const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i)); |
60 | (*ix)[i] = ix_i; |
61 | out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i)); |
62 | } |
63 | return out_of_bounds; |
64 | } |
65 | |
66 | EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32 |
67 | operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const { |
68 | const Index loc = loc_array[0]; |
69 | Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix; |
70 | Eigen::array<Eigen::DenseIndex, 2> ix_out; |
71 | ix_out[0] = loc; |
72 | ix_out[1] = 0; |
73 | const bool out_of_bounds = GenerateIndices(loc, &ix); |
74 | if (TF_PREDICT_FALSE(out_of_bounds)) { |
75 | error_loc_->store(loc); |
76 | std::fill_n(&Tout_(ix_out), slice_size_, T()); |
77 | } else { |
78 | std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out)); |
79 | } |
80 | |
81 | return static_cast<int32>(0); // Return something... |
82 | } |
83 | |
84 | private: |
85 | const Index slice_size_; |
86 | const typename TTypes<Index>::ConstMatrix Tindices_; |
87 | const typename TTypes<T, IXDIM + 1>::ConstTensor Tparams_; |
88 | mutable typename TTypes<T>::Matrix Tout_; |
89 | std::atomic<Index>* error_loc_; |
90 | }; |
91 | |
92 | } // namespace generator |
93 | |
94 | namespace functor { |
95 | |
96 | template <typename T, typename Index, int IXDIM> |
97 | struct GatherNdSlice<CPUDevice, T, Index, IXDIM> { |
98 | Index operator()(const CPUDevice& d, const Index slice_size, |
99 | typename TTypes<int32>::Scalar Tscratch, |
100 | typename TTypes<T, IXDIM + 1>::ConstTensor Tparams, |
101 | typename TTypes<Index>::ConstMatrix Tindices, |
102 | typename TTypes<T>::Matrix Tout) { |
103 | std::atomic<Index> error_loc(-1); |
104 | const Eigen::Index batch_size = Tindices.dimension(0); |
105 | generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator( |
106 | slice_size, Tindices, Tparams, Tout, &error_loc); |
107 | |
108 | auto compute_shard = [&](Eigen::Index begin, Eigen::Index end) { |
109 | for (Eigen::Index i = begin; i < end; ++i) { |
110 | const Eigen::array<Eigen::Index, 1> loc{i}; |
111 | gather_nd_generator(loc); |
112 | } |
113 | }; |
114 | Eigen::Index bytes_moved = sizeof(T) * (slice_size + IXDIM); |
115 | auto cost = Eigen::TensorOpCost(bytes_moved /* bytes loaded */, |
116 | bytes_moved /* bytes stored */, |
117 | slice_size + IXDIM /* compute cycles */); |
118 | d.parallelFor(batch_size, cost, compute_shard); |
119 | |
120 | // error_loc() returns -1 if there's no out-of-bounds index, |
121 | // otherwise it returns the location of an OOB index in Tindices. |
122 | return error_loc.load(); |
123 | } |
124 | }; |
125 | |
126 | #define REGISTER_GATHER_ND_FULL(T, Index) \ |
127 | template Index GatherNdSlice<CPUDevice, T, Index, CPU_PROVIDED_IXDIM>:: \ |
128 | operator()(const CPUDevice& d, const Index slice_size, \ |
129 | typename TTypes<int32>::Scalar Tscratch, \ |
130 | typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::ConstTensor Tparams, \ |
131 | typename TTypes<Index>::ConstMatrix Tindices, \ |
132 | typename TTypes<T>::Matrix Tout); |
133 | |
134 | #define REGISTER_GATHER_ND_CPU(type) \ |
135 | REGISTER_GATHER_ND_FULL(type, int32); \ |
136 | REGISTER_GATHER_ND_FULL(type, int64) |
137 | |
138 | TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU); |
139 | TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU); |
140 | |
141 | } // namespace functor |
142 | |
143 | } // namespace tensorflow |
144 | |
145 | #endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_ |
146 | |