1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | |
21 | #include "tensorflow/core/framework/bounds_check.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/tensor_types.h" |
24 | #include "tensorflow/core/framework/type_traits.h" |
25 | #include "tensorflow/core/framework/variant.h" |
26 | #include "tensorflow/core/platform/prefetch.h" |
27 | #include "tensorflow/core/platform/types.h" |
28 | #include "tensorflow/core/util/work_sharder.h" |
29 | |
30 | namespace tensorflow { |
31 | typedef Eigen::ThreadPoolDevice CPUDevice; |
32 | typedef Eigen::GpuDevice GPUDevice; |
33 | |
34 | namespace functor { |
35 | |
// Helper that copies gathered slices, using memcpy for simple types.
// params and out use the 4-D layout [batch, outer, gather/index, slice].
// Returns the flat position in indices of the first out-of-bounds index,
// or -1 if every index is in bounds.
37 | template <typename T, typename Index, typename SliceIndex, |
38 | SliceIndex static_slice_elems> |
39 | SliceIndex HandleCopiesBatched(OpKernelContext* ctx, |
40 | typename TTypes<T, 4>::ConstTensor params, |
41 | typename TTypes<Index>::ConstFlat indices, |
42 | SliceIndex slice_elems, |
43 | typename TTypes<T, 4>::Tensor out) { |
44 | const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0)); |
45 | const SliceIndex outer_size = static_cast<SliceIndex>(params.dimension(1)); |
46 | const SliceIndex indices_size = |
47 | static_cast<SliceIndex>(indices.dimension(0)) / batch_size; |
48 | |
49 | const Index limit = static_cast<Index>(params.dimension(2)); |
50 | if (static_slice_elems >= 0) { |
    // Give the compiler static knowledge of the number of elements/bytes.
52 | slice_elems = static_slice_elems; |
53 | } |
  // Compute slice_bytes here so that any static knowledge of slice_elems is
  // visible to the memcpy below.
55 | const size_t slice_bytes = slice_elems * sizeof(T); |
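  // Illustrative example (assuming T = float): with static_slice_elems = 10,
  // the branch above pins slice_elems to 10, so slice_bytes is the constant
  // 40 and the compiler can lower the memcpy below to a fixed-size copy.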
56 | auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); |
57 | mutex mu; |
  // Records the flat position of the first invalid index for error reporting;
  // shared across worker threads and guarded by mu.
60 | SliceIndex result = -1; |
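
  // Each shard processes a contiguous range [start, end) of flat positions
  // over the (batch, outer, index) space. A worked example of the
  // decomposition below, with illustrative sizes batch_size = 2,
  // outer_size = 3, indices_size = 4:
  //   start = 17
  //   batch_idx   = 17 / (3 * 4) = 1
  //   r_start     = 17 % (3 * 4) = 5
  //   outer_idx   = 5 / 4 = 1
  //   indices_idx = 5 % 4 = 1
  // i.e. this shard starts at (batch 1, outer 1, index 1); the loop then
  // advances the index coordinate fastest, outer next, and batch slowest.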
61 | auto work = [&](int64_t start, int64_t end) { |
62 | const int64_t r_start = start % (outer_size * indices_size); |
63 | SliceIndex batch_idx = static_cast<SliceIndex>( |
64 | start / (outer_size * indices_size)); |
65 | SliceIndex outer_idx = static_cast<SliceIndex>(r_start / indices_size); |
66 | SliceIndex indices_idx = static_cast<SliceIndex>(r_start % indices_size); |
67 | |
68 | SliceIndex batch_offset = batch_idx * indices_size; |
69 | for (; start < end; ++start) { |
70 | SliceIndex i_next = indices_idx + 1; |
71 | SliceIndex o_next = outer_idx; |
72 | SliceIndex b_next = batch_idx; |
73 | SliceIndex b_offset_next = batch_offset; |
74 | |
75 | if (i_next >= indices_size) { |
76 | i_next = 0; |
77 | if (++o_next >= outer_size) { |
78 | o_next = 0; |
79 | ++b_next; |
80 | b_offset_next += indices_size; |
81 | } |
82 | } |
83 | if (start + 1 < end) { |
84 | port::prefetch<port::PREFETCH_HINT_T0>( |
85 | ¶ms(b_next, o_next, indices(b_offset_next + i_next), 0)); |
86 | port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, o_next, i_next, 0)); |
87 | } |
88 | const Index index = internal::SubtleMustCopy( |
89 | indices(batch_offset + indices_idx)); |
90 | if (!FastBoundsCheck(index, limit)) { |
91 | mutex_lock l(mu); |
92 | result = batch_offset + indices_idx; |
93 | return; |
94 | } |
95 | |
      // Copy using memcpy if possible, otherwise use an Eigen assignment.
97 | // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve |
98 | // ahead-of-time compilation binary size). |
99 | if (is_simple_type<T>::value) { |
        // Cast index to SliceIndex to keep the indexing arithmetic in
        // SliceIndex rather than auto-promoting it to Index.
101 | memcpy( |
102 | &out(batch_idx, outer_idx, indices_idx, 0), |
103 | ¶ms(batch_idx, outer_idx, static_cast<SliceIndex>(index), 0), |
104 | slice_bytes); |
105 | } else { |
106 | // For non-"simple" types (e.g. strings). |
107 | out.template chip<0>(batch_idx) |
108 | .template chip<0>(outer_idx) |
109 | .template chip<0>(indices_idx) = |
110 | params.template chip<0>(batch_idx) |
111 | .template chip<0>(outer_idx) |
112 | .template chip<0>(static_cast<SliceIndex>(index)); |
113 | } |
114 | |
115 | indices_idx = i_next; |
116 | outer_idx = o_next; |
117 | batch_idx = b_next; |
118 | batch_offset = b_offset_next; |
119 | } |
120 | }; |
121 | |
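  // Shard the batch_size * outer_size * indices_size iteration space across
  // the CPU worker threads, using the byte size of one slice as the cost of
  // a single work unit.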
122 | Shard(worker_threads->num_threads, worker_threads->workers, |
123 | batch_size * outer_size * indices_size, slice_elems * sizeof(T), work); |
124 | return result; |
125 | } |
126 | |
127 | template <typename T, typename Index> |
128 | struct GatherFunctorBatchedCPU { |
129 | int64_t operator()(OpKernelContext* ctx, |
130 | typename TTypes<T, 4>::ConstTensor params, |
131 | typename TTypes<Index>::ConstFlat indices, |
132 | typename TTypes<T, 4>::Tensor out) { |
    const int64_t indices_size = indices.size();  // Total over all batches.
134 | const int64_t slice_size = out.dimension(3); |
135 | int64_t bad_i; |
136 | |
137 | const int64_t batch_size = params.dimension(0); |
138 | const int64_t outer_size = params.dimension(1); |
139 | |
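    // Fall back to 64-bit indexing only when some extent or the total element
    // count exceeds the 32-bit range; 32-bit indexing is cheaper when it
    // suffices.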
140 | bool use_large = (slice_size > std::numeric_limits<int32>::max() || |
141 | params.size() > std::numeric_limits<int32>::max() || |
142 | indices_size > std::numeric_limits<int32>::max() || |
143 | batch_size * outer_size * indices_size * slice_size > |
144 | std::numeric_limits<int32>::max()); |
145 | #define CALL(elems) \ |
146 | do { \ |
147 | if (use_large) { \ |
148 | bad_i = HandleCopiesBatched<T, Index, int64_t, elems>( \ |
149 | ctx, params, indices, slice_size, out); \ |
150 | } else { \ |
151 | const int32 small_slice = static_cast<int32>(slice_size); \ |
152 | bad_i = HandleCopiesBatched<T, Index, int32, elems>( \ |
153 | ctx, params, indices, small_slice, out); \ |
154 | } \ |
155 | } while (0) |
156 | |
157 | // TODO(rmlarsen): Investigate whether these specializations are still |
158 | // needed and, if yes, whether the slice sizes are appropriate. |
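    // For example, in the non-large case CALL(10) instantiates
    // HandleCopiesBatched<T, Index, int32, 10>, giving the copy loop a
    // compile-time slice size of ten elements.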
159 | if (slice_size == 10) |
160 | CALL(10); |
161 | else if (slice_size == 20) |
162 | CALL(20); |
163 | else |
164 | CALL(-1); |
165 | #undef CALL |
166 | |
167 | return bad_i; |
168 | } |
169 | }; |
170 | |
171 | template <typename Device, typename T, typename Index> |
172 | struct GatherFunctorBatched { |
173 | int64_t operator()(OpKernelContext* ctx, |
174 | typename TTypes<T, 4>::ConstTensor params, |
175 | typename TTypes<Index>::ConstFlat indices, |
176 | typename TTypes<T, 4>::Tensor out); |
177 | }; |
178 | |
179 | template <typename T, typename Index> |
180 | struct GatherFunctorBatched<CPUDevice, T, Index> { |
181 | int64_t operator()(OpKernelContext* ctx, |
182 | typename TTypes<T, 4>::ConstTensor params, |
183 | typename TTypes<Index>::ConstFlat indices, |
184 | typename TTypes<T, 4>::Tensor out) { |
185 | return GatherFunctorBatchedCPU<T, Index>()(ctx, params, indices, out); |
186 | } |
187 | }; |
188 | |
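// Variant elements live in host memory even when the tensor is placed on a
// GPU device, so this specialization simply reuses the CPU implementation.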
189 | template <typename Index> |
190 | struct GatherFunctorBatched<GPUDevice, Variant, Index> { |
191 | int64_t operator()(OpKernelContext* ctx, |
192 | typename TTypes<Variant, 4>::ConstTensor params, |
193 | typename TTypes<Index>::ConstFlat indices, |
194 | typename TTypes<Variant, 4>::Tensor out) { |
195 | return GatherFunctorBatchedCPU<Variant, Index>()(ctx, params, indices, out); |
196 | } |
197 | }; |
198 | |
199 | } // namespace functor |
200 | } // namespace tensorflow |
201 | |
202 | #endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_BATCHED_H_ |
203 | |