/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Helper method to copy using memcpy.
template <typename T, typename Index, typename SliceIndex,
          SliceIndex static_slice_elems>
SliceIndex HandleCopies(OpKernelContext* ctx,
                        typename TTypes<T, 3>::ConstTensor params,
                        typename TTypes<Index>::ConstFlat indices,
                        SliceIndex slice_elems,
                        typename TTypes<T, 3>::Tensor out) {
  const SliceIndex indices_size =
      static_cast<SliceIndex>(indices.dimension(0));
  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
  const Index limit = static_cast<Index>(params.dimension(1));
  T* out_base = out.data();
  const T* params_base = params.data();
  if (static_slice_elems >= 0) {
    // Give the compiler static knowledge of the number of elements/bytes.
    slice_elems = static_slice_elems;
  }
  // Compute slice_bytes here so that static knowledge is available.
  const size_t slice_bytes = slice_elems * sizeof(T);
  auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
  mutex mu;
  // Stores the first invalid (out-of-bounds) index found, for error reporting;
  // it is shared by all worker threads, so writes are guarded by `mu`.
  SliceIndex result = -1;
  auto work = [&](int64_t start, int64_t end) {
    SliceIndex batch_idx = static_cast<SliceIndex>(start / indices_size);
    SliceIndex indices_idx = static_cast<SliceIndex>(start % indices_size);
    SliceIndex batch_idx_end = static_cast<SliceIndex>(end / indices_size);
    SliceIndex indices_idx_end = static_cast<SliceIndex>(end % indices_size);

    while ((batch_idx < batch_idx_end) ||
           (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) {
      SliceIndex i_next = indices_idx + 1;
      SliceIndex b_next = batch_idx + 1;
      const Index index = internal::SubtleMustCopy(indices(indices_idx));
      if (!FastBoundsCheck(index, limit)) {
        mutex_lock l(mu);
        result = indices_idx;
        return;
      }
      if ((batch_idx == batch_idx_end && i_next < indices_idx_end) ||
          (i_next < indices_size)) {
        port::prefetch<port::PREFETCH_HINT_T0>(
            &params(batch_idx, indices(i_next), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(batch_idx, i_next, 0));
        b_next = batch_idx;
      } else if (b_next <= batch_idx_end) {
        port::prefetch<port::PREFETCH_HINT_T0>(&params(b_next, indices(0), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, 0, 0));
        i_next = 0;
      }
      // Copy using memcpy if possible, otherwise an Eigen loop.
      // TODO(cwhipkey): avoid linking to framework to get Allocator (to
      // improve ahead-of-time compilation binary size).
      if (is_simple_type<T>::value) {
        // Avoid auto-promotion to Index from SliceIndex by casting.
        memcpy(
            out_base + (batch_idx * indices_size + indices_idx) * slice_elems,
            params_base + (batch_idx * static_cast<SliceIndex>(limit) +
                           static_cast<SliceIndex>(index)) *
                              slice_elems,
            slice_bytes);
      } else {
        // For non-"simple" types (e.g. strings).
        out.template chip<0>(batch_idx).template chip<0>(indices_idx) =
            params.template chip<0>(batch_idx).template chip<0>(index);
      }
      indices_idx = i_next;
      batch_idx = b_next;
    }
  };

  Shard(worker_threads->num_threads, worker_threads->workers,
        batch_size * indices_size, slice_elems * sizeof(T), work);
  return result;
}
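
// Illustration only (not part of the library): a hedged, worked example of how
// the flat work range that Shard passes to `work` maps onto
// (batch_idx, indices_idx) pairs. The concrete sizes below are assumptions
// chosen purely for this sketch.
//
//   Suppose batch_size = 2 and indices_size = 5, so there are 10 slice copies,
//   numbered 0..9 in batch-major order. For a flat position p:
//     batch_idx   = p / indices_size
//     indices_idx = p % indices_size
//   A shard covering [start, end) = [3, 8) therefore copies the slices
//     (0, 3), (0, 4), (1, 0), (1, 1), (1, 2)
//   which matches the while-loop above: indices_idx wraps back to 0 whenever
//   batch_idx advances to the next batch.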

template <typename T, typename Index>
struct GatherFunctorCPU {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 3>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 3>::Tensor out) {
    const int64_t indices_size = indices.size();
    const int64_t slice_size = out.dimension(2);
    int64_t bad_i;

    const int64_t batch_size = params.dimension(0);

    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
                      params.size() > std::numeric_limits<int32>::max() ||
                      indices_size > std::numeric_limits<int32>::max() ||
                      batch_size * indices_size * slice_size >
                          std::numeric_limits<int32>::max());
#define CALL(elems)                                                        \
  do {                                                                     \
    if (use_large) {                                                       \
      bad_i = HandleCopies<T, Index, int64_t, elems>(ctx, params, indices, \
                                                     slice_size, out);     \
    } else {                                                               \
      const int32 small_slice = static_cast<int32>(slice_size);            \
      bad_i = HandleCopies<T, Index, int32, elems>(ctx, params, indices,   \
                                                   small_slice, out);      \
    }                                                                      \
  } while (0)

    if (slice_size == 10)
      CALL(10);
    else if (slice_size == 20)
      CALL(20);
    else
      CALL(-1);
#undef CALL

    return bad_i;
  }
};
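
// A rough sketch of how the dispatch above resolves for a few hypothetical
// shapes; the shapes are assumptions made up for illustration, not taken from
// any real model:
//
//   params [4, 1000, 10], indices [32]: everything fits in int32 and
//     slice_size == 10, so CALL(10) -> HandleCopies<T, Index, int32, 10>
//     (static slice size, 32-bit index arithmetic).
//   params [1, 50000, 777], indices [128]: fits in int32 but the slice size
//     is neither 10 nor 20, so CALL(-1) -> HandleCopies<T, Index, int32, -1>
//     (dynamic slice size).
//   params [1, 300000000, 20], indices [64]: params.size() exceeds the int32
//     maximum, so CALL(20) -> HandleCopies<T, Index, int64_t, 20>
//     (static slice size, 64-bit index arithmetic).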

template <typename Device, typename T, typename Index>
struct GatherFunctor {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 3>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 3>::Tensor out);
};

template <typename T, typename Index>
struct GatherFunctor<CPUDevice, T, Index> {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 3>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 3>::Tensor out) {
    return GatherFunctorCPU<T, Index>()(ctx, params, indices, out);
  }
};

template <typename Index>
struct GatherFunctor<GPUDevice, Variant, Index> {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<Variant, 3>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<Variant, 3>::Tensor out) {
    return GatherFunctorCPU<Variant, Index>()(ctx, params, indices, out);
  }
};
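
// A minimal usage sketch, assuming a caller (e.g. an op kernel) that has
// already reshaped its inputs to the 3-D [outer_batch, gather_dim, slice]
// layout this functor expects. The variable names below are hypothetical:
//
//   functor::GatherFunctor<Device, T, Index> f;
//   const int64_t bad_i = f(ctx, params_3d, indices_flat, out_3d);
//   OP_REQUIRES(ctx, bad_i < 0,
//               errors::InvalidArgument("index ", bad_i, " is out of range"));
//
// A negative return value means every index passed the bounds check; otherwise
// the return value is the position within `indices` of an out-of-bounds entry.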

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_