1/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
17#define TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
18
19// Specialization of GatherNdSlice to CPU
20
21#define EIGEN_USE_THREADS
22
23#include <atomic>
24
25#include "tensorflow/core/framework/bounds_check.h"
26#include "tensorflow/core/framework/op_kernel.h"
27#include "tensorflow/core/framework/register_types.h"
28#include "tensorflow/core/framework/tensor.h"
29#include "tensorflow/core/kernels/gather_nd_op.h"
30#include "tensorflow/core/platform/logging.h"
31#include "tensorflow/core/platform/mem.h"
32#include "tensorflow/core/platform/types.h"
33#include "tensorflow/core/util/util.h"
34
35namespace tensorflow {
36
37typedef Eigen::ThreadPoolDevice CPUDevice;
38
39namespace generator {
40
41template <typename T, typename Index, int IXDIM>
42class GatherNdSliceGenerator {
43 public:
44 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE GatherNdSliceGenerator(
45 const Index slice_size, typename TTypes<Index>::ConstMatrix Tindices,
46 typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
47 typename TTypes<T>::Matrix Tout, std::atomic<Index>* error_loc)
48 : slice_size_(slice_size),
49 Tindices_(Tindices),
50 Tparams_(Tparams),
51 Tout_(Tout),
52 error_loc_(error_loc) {}
53
54 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE bool GenerateIndices(
55 const Index loc, Eigen::array<Eigen::DenseIndex, IXDIM + 1>* ix) const {
56 (*ix)[IXDIM] = 0;
57 bool out_of_bounds = false;
58 for (int i = 0; i < IXDIM; ++i) {
59 const Index ix_i = internal::SubtleMustCopy(Tindices_(loc, i));
60 (*ix)[i] = ix_i;
61 out_of_bounds |= !FastBoundsCheck(ix_i, Tparams_.dimension(i));
62 }
63 return out_of_bounds;
64 }
65
66 EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE int32
67 operator()(const Eigen::array<Eigen::DenseIndex, 1>& loc_array) const {
68 const Index loc = loc_array[0];
69 Eigen::array<Eigen::DenseIndex, IXDIM + 1> ix;
70 Eigen::array<Eigen::DenseIndex, 2> ix_out;
71 ix_out[0] = loc;
72 ix_out[1] = 0;
73 const bool out_of_bounds = GenerateIndices(loc, &ix);
74 if (TF_PREDICT_FALSE(out_of_bounds)) {
75 error_loc_->store(loc);
76 std::fill_n(&Tout_(ix_out), slice_size_, T());
77 } else {
78 std::copy_n(&Tparams_(ix), slice_size_, &Tout_(ix_out));
79 }
80
81 return static_cast<int32>(0); // Return something...
82 }
83
84 private:
85 const Index slice_size_;
86 const typename TTypes<Index>::ConstMatrix Tindices_;
87 const typename TTypes<T, IXDIM + 1>::ConstTensor Tparams_;
88 mutable typename TTypes<T>::Matrix Tout_;
89 std::atomic<Index>* error_loc_;
90};
91
92} // namespace generator
93
94namespace functor {
95
96template <typename T, typename Index, int IXDIM>
97struct GatherNdSlice<CPUDevice, T, Index, IXDIM> {
98 Index operator()(const CPUDevice& d, const Index slice_size,
99 typename TTypes<int32>::Scalar Tscratch,
100 typename TTypes<T, IXDIM + 1>::ConstTensor Tparams,
101 typename TTypes<Index>::ConstMatrix Tindices,
102 typename TTypes<T>::Matrix Tout) {
103 std::atomic<Index> error_loc(-1);
104 const Eigen::Index batch_size = Tindices.dimension(0);
105 generator::GatherNdSliceGenerator<T, Index, IXDIM> gather_nd_generator(
106 slice_size, Tindices, Tparams, Tout, &error_loc);
107
108 auto compute_shard = [&](Eigen::Index begin, Eigen::Index end) {
109 for (Eigen::Index i = begin; i < end; ++i) {
110 const Eigen::array<Eigen::Index, 1> loc{i};
111 gather_nd_generator(loc);
112 }
113 };
114 Eigen::Index bytes_moved = sizeof(T) * (slice_size + IXDIM);
115 auto cost = Eigen::TensorOpCost(bytes_moved /* bytes loaded */,
116 bytes_moved /* bytes stored */,
117 slice_size + IXDIM /* compute cycles */);
118 d.parallelFor(batch_size, cost, compute_shard);
119
120 // error_loc() returns -1 if there's no out-of-bounds index,
121 // otherwise it returns the location of an OOB index in Tindices.
122 return error_loc.load();
123 }
124};
125
// Explicit instantiations of the CPU GatherNdSlice functor.
//
// CPU_PROVIDED_IXDIM is not defined in this file. It must be defined before
// this header is included (presumably by the including .cc file). Because of
// the include guard, each translation unit instantiates exactly one index
// depth. For that depth, the functor is instantiated for every type in
// TF_CALL_ALL_TYPES and TF_CALL_QUANTIZED_TYPES, with both int32 and int64
// indices.

// Explicitly instantiates GatherNdSlice<CPUDevice, T, Index,
// CPU_PROVIDED_IXDIM>::operator(). The trailing ';' inside the macro must stay:
// the TF_CALL_* macros presumably paste REGISTER_GATHER_ND_CPU(type)
// invocations back to back without separators, so each instantiation has to
// terminate itself.
#define REGISTER_GATHER_ND_FULL(T, Index)                                     \
  template Index GatherNdSlice<CPUDevice, T, Index, CPU_PROVIDED_IXDIM>::     \
  operator()(const CPUDevice& d, const Index slice_size,                      \
             typename TTypes<int32>::Scalar Tscratch,                         \
             typename TTypes<T, CPU_PROVIDED_IXDIM + 1>::ConstTensor Tparams, \
             typename TTypes<Index>::ConstMatrix Tindices,                    \
             typename TTypes<T>::Matrix Tout);

// Instantiates the functor for element type `type` with both supported index
// types.
#define REGISTER_GATHER_ND_CPU(type)     \
  REGISTER_GATHER_ND_FULL(type, int32); \
  REGISTER_GATHER_ND_FULL(type, int64)

TF_CALL_ALL_TYPES(REGISTER_GATHER_ND_CPU);
TF_CALL_QUANTIZED_TYPES(REGISTER_GATHER_ND_CPU);
140
141} // namespace functor
142
143} // namespace tensorflow
144
145#endif // TENSORFLOW_CORE_KERNELS_GATHER_ND_OP_CPU_IMPL_H_
146