/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_
#define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/variant.h"
#include "tensorflow/core/platform/prefetch.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {

// Helper that copies the gathered slices on the CPU. `params` is laid out as
// [batch, outer, gather, slice] and `out` as
// [batch, outer, indices_per_batch, slice]; `indices` is flat, with
// indices_per_batch entries per batch. Slices of simple types are copied with
// memcpy, other types with an Eigen assignment. Returns -1 on success, or the
// flat position in `indices` of the first out-of-range index.
template <typename T, typename Index, typename SliceIndex,
          SliceIndex static_slice_elems>
SliceIndex HandleCopiesBatched(OpKernelContext* ctx,
                               typename TTypes<T, 4>::ConstTensor params,
                               typename TTypes<Index>::ConstFlat indices,
                               SliceIndex slice_elems,
                               typename TTypes<T, 4>::Tensor out) {
  const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0));
  const SliceIndex outer_size = static_cast<SliceIndex>(params.dimension(1));
  const SliceIndex indices_size =
      static_cast<SliceIndex>(indices.dimension(0)) / batch_size;

  const Index limit = static_cast<Index>(params.dimension(2));
  if (static_slice_elems >= 0) {
    // Give compiler static knowledge of the number of elements/bytes.
    slice_elems = static_slice_elems;
  }
  // Compute slice_bytes here so that static knowledge is available.
  const size_t slice_bytes = slice_elems * sizeof(T);
  auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads();
  mutex mu;
  // Records the position of the first invalid index found, if any, so the
  // caller can report it. Shared across shards and guarded by `mu`.
  SliceIndex result = -1;
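  // Each shard handles a contiguous range [start, end) over the flattened
  // (batch, outer, index) space. The triple is decomposed from `start` once
  // and then advanced incrementally in the loop to avoid a div/mod per
  // iteration.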
  auto work = [&](int64_t start, int64_t end) {
    const int64_t r_start = start % (outer_size * indices_size);
    SliceIndex batch_idx =
        static_cast<SliceIndex>(start / (outer_size * indices_size));
    SliceIndex outer_idx = static_cast<SliceIndex>(r_start / indices_size);
    SliceIndex indices_idx = static_cast<SliceIndex>(r_start % indices_size);

    SliceIndex batch_offset = batch_idx * indices_size;
    for (; start < end; ++start) {
      SliceIndex i_next = indices_idx + 1;
      SliceIndex o_next = outer_idx;
      SliceIndex b_next = batch_idx;
      SliceIndex b_offset_next = batch_offset;

      if (i_next >= indices_size) {
        i_next = 0;
        if (++o_next >= outer_size) {
          o_next = 0;
          ++b_next;
          b_offset_next += indices_size;
        }
      }
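      // Prefetch the slices needed on the next iteration while the current
      // slice is being copied below.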
      if (start + 1 < end) {
        port::prefetch<port::PREFETCH_HINT_T0>(
            &params(b_next, o_next, indices(b_offset_next + i_next), 0));
        port::prefetch<port::PREFETCH_HINT_T0>(
            &out(b_next, o_next, i_next, 0));
      }
      const Index index = internal::SubtleMustCopy(
          indices(batch_offset + indices_idx));
      if (!FastBoundsCheck(index, limit)) {
        mutex_lock l(mu);
        result = batch_offset + indices_idx;
        return;
      }

      // Copy using memcpy if possible, otherwise an Eigen loop.
      // TODO(cwhipkey): avoid linking to framework to get Allocator (to
      // improve ahead-of-time compilation binary size).
      if (is_simple_type<T>::value) {
        // Avoid auto-promotion to Index from SliceIndex by casting.
        memcpy(
            &out(batch_idx, outer_idx, indices_idx, 0),
            &params(batch_idx, outer_idx, static_cast<SliceIndex>(index), 0),
            slice_bytes);
      } else {
        // For non-"simple" types (e.g. strings).
        out.template chip<0>(batch_idx)
            .template chip<0>(outer_idx)
            .template chip<0>(indices_idx) =
            params.template chip<0>(batch_idx)
                .template chip<0>(outer_idx)
                .template chip<0>(static_cast<SliceIndex>(index));
      }

      indices_idx = i_next;
      outer_idx = o_next;
      batch_idx = b_next;
      batch_offset = b_offset_next;
    }
  };

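  // Shard the batch_size * outer_size * indices_size slice copies across the
  // CPU worker threads; slice_elems * sizeof(T) is the per-unit cost estimate.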
  Shard(worker_threads->num_threads, worker_threads->workers,
        batch_size * outer_size * indices_size, slice_elems * sizeof(T), work);
  return result;
}

template <typename T, typename Index>
struct GatherFunctorBatchedCPU {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 4>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 4>::Tensor out) {
    const int64_t indices_size = indices.size();  // Includes the batch_size.
    const int64_t slice_size = out.dimension(3);
    int64_t bad_i;

    const int64_t batch_size = params.dimension(0);
    const int64_t outer_size = params.dimension(1);

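    // Use 64-bit indexing only when some extent (or the total element count)
    // might not fit in int32; otherwise the 32-bit SliceIndex instantiation
    // keeps the inner-loop arithmetic cheaper.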
    bool use_large = (slice_size > std::numeric_limits<int32>::max() ||
                      params.size() > std::numeric_limits<int32>::max() ||
                      indices_size > std::numeric_limits<int32>::max() ||
                      batch_size * outer_size * indices_size * slice_size >
                          std::numeric_limits<int32>::max());
#define CALL(elems)                                                   \
  do {                                                                \
    if (use_large) {                                                  \
      bad_i = HandleCopiesBatched<T, Index, int64_t, elems>(          \
          ctx, params, indices, slice_size, out);                     \
    } else {                                                          \
      const int32 small_slice = static_cast<int32>(slice_size);       \
      bad_i = HandleCopiesBatched<T, Index, int32, elems>(            \
          ctx, params, indices, small_slice, out);                    \
    }                                                                 \
  } while (0)

    // TODO(rmlarsen): Investigate whether these specializations are still
    // needed and, if yes, whether the slice sizes are appropriate.
    if (slice_size == 10)
      CALL(10);
    else if (slice_size == 20)
      CALL(20);
    else
      CALL(-1);
#undef CALL

    return bad_i;
  }
};

template <typename Device, typename T, typename Index>
struct GatherFunctorBatched {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 4>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 4>::Tensor out);
};

template <typename T, typename Index>
struct GatherFunctorBatched<CPUDevice, T, Index> {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<T, 4>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<T, 4>::Tensor out) {
    return GatherFunctorBatchedCPU<T, Index>()(ctx, params, indices, out);
  }
};
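// Variant elements live in host memory, so gathering Variant params on the
// GPU device simply delegates to the CPU implementation.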
template <typename Index>
struct GatherFunctorBatched<GPUDevice, Variant, Index> {
  int64_t operator()(OpKernelContext* ctx,
                     typename TTypes<Variant, 4>::ConstTensor params,
                     typename TTypes<Index>::ConstFlat indices,
                     typename TTypes<Variant, 4>::Tensor out) {
    return GatherFunctorBatchedCPU<Variant, Index>()(ctx, params, indices, out);
  }
};
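
// Illustrative usage sketch (not part of this header): a kernel that has
// already reshaped its inputs to the 4-D layouts above might invoke the
// functor roughly as follows. `params_4d`, `indices_flat`, and `out_4d` are
// hypothetical placeholders for the caller's reshaped tensors.
//
//   functor::GatherFunctorBatched<CPUDevice, float, int32> gather;
//   const int64_t bad_i = gather(ctx, params_4d, indices_flat, out_4d);
//   OP_REQUIRES(ctx, bad_i < 0,
//               errors::InvalidArgument("indices(", bad_i,
//                                       ") is out of range"));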

}  // namespace functor
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_