1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_ |
18 | |
19 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
20 | |
21 | #include "tensorflow/core/framework/bounds_check.h" |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/framework/tensor_types.h" |
24 | #include "tensorflow/core/framework/type_traits.h" |
25 | #include "tensorflow/core/framework/variant.h" |
26 | #include "tensorflow/core/platform/prefetch.h" |
27 | #include "tensorflow/core/platform/types.h" |
28 | #include "tensorflow/core/util/work_sharder.h" |
29 | |
30 | namespace tensorflow { |
31 | typedef Eigen::ThreadPoolDevice CPUDevice; |
32 | typedef Eigen::GpuDevice GPUDevice; |
33 | |
34 | namespace functor { |
35 | |
// Helper that copies gathered slices from `params` into `out`, using memcpy
// for simple types and Eigen assignment otherwise. Returns the position of
// the first out-of-bounds index, or -1 if all indices are in bounds.
37 | template <typename T, typename Index, typename SliceIndex, |
38 | SliceIndex static_slice_elems> |
39 | SliceIndex HandleCopies(OpKernelContext* ctx, |
40 | typename TTypes<T, 3>::ConstTensor params, |
41 | typename TTypes<Index>::ConstFlat indices, |
42 | SliceIndex slice_elems, |
43 | typename TTypes<T, 3>::Tensor out) { |
44 | const SliceIndex indices_size = static_cast<SliceIndex>(indices.dimension(0)); |
45 | const SliceIndex batch_size = static_cast<SliceIndex>(params.dimension(0)); |
46 | const Index limit = static_cast<Index>(params.dimension(1)); |
47 | T* out_base = out.data(); |
48 | const T* params_base = params.data(); |
49 | if (static_slice_elems >= 0) { |
    // Give the compiler static knowledge of the number of elements/bytes.
51 | slice_elems = static_slice_elems; |
52 | } |
53 | // Compute slice_bytes here so that static knowledge is available |
54 | const size_t slice_bytes = slice_elems * sizeof(T); |
55 | auto* worker_threads = ctx->device()->tensorflow_cpu_worker_threads(); |
56 | mutex mu; |
  // Stores the position of the first invalid index, for error reporting; it
  // is shared across worker threads and guarded by `mu`.
59 | SliceIndex result = -1; |
60 | auto work = [&](int64_t start, int64_t end) { |
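    // Map the flat work range [start, end) onto (batch, index) coordinates;
    // work units are ordered batch-major, `indices_size` units per batch.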
61 | SliceIndex batch_idx = static_cast<SliceIndex>(start / indices_size); |
62 | SliceIndex indices_idx = static_cast<SliceIndex>(start % indices_size); |
63 | SliceIndex batch_idx_end = static_cast<SliceIndex>(end / indices_size); |
64 | SliceIndex indices_idx_end = static_cast<SliceIndex>(end % indices_size); |
65 | |
66 | while ((batch_idx < batch_idx_end) || |
67 | (batch_idx == batch_idx_end && indices_idx < indices_idx_end)) { |
68 | SliceIndex i_next = indices_idx + 1; |
69 | SliceIndex b_next = batch_idx + 1; |
70 | const Index index = internal::SubtleMustCopy(indices(indices_idx)); |
71 | if (!FastBoundsCheck(index, limit)) { |
72 | mutex_lock l(mu); |
73 | result = indices_idx; |
74 | return; |
75 | } |
76 | if ((batch_idx == batch_idx_end && i_next < indices_idx_end) || |
77 | (i_next < indices_size)) { |
78 | port::prefetch<port::PREFETCH_HINT_T0>( |
79 | ¶ms(batch_idx, indices(i_next), 0)); |
80 | port::prefetch<port::PREFETCH_HINT_T0>(&out(batch_idx, i_next, 0)); |
81 | b_next = batch_idx; |
82 | } else if (b_next <= batch_idx_end) { |
83 | port::prefetch<port::PREFETCH_HINT_T0>(¶ms(b_next, indices(0), 0)); |
84 | port::prefetch<port::PREFETCH_HINT_T0>(&out(b_next, 0, 0)); |
85 | i_next = 0; |
86 | } |
87 | // Copy using memcpy if possible, otherwise an Eigen loop |
88 | // TODO(cwhipkey): avoid linking to framework to get Allocator (to improve |
89 | // ahead-of-time compilation binary size). |
90 | if (is_simple_type<T>::value) { |
91 | // Avoid auto-promotion to Index from SliceIndex by casting. |
92 | memcpy( |
93 | out_base + (batch_idx * indices_size + indices_idx) * slice_elems, |
94 | params_base + (batch_idx * static_cast<SliceIndex>(limit) + |
95 | static_cast<SliceIndex>(index)) * |
96 | slice_elems, |
97 | slice_bytes); |
98 | } else { |
99 | // For non-"simple" types (e.g. strings). |
100 | out.template chip<0>(batch_idx).template chip<0>(indices_idx) = |
101 | params.template chip<0>(batch_idx).template chip<0>(index); |
102 | } |
103 | indices_idx = i_next; |
104 | batch_idx = b_next; |
105 | } |
106 | }; |
107 | |
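  // Shard the (batch, index) work units across the worker threads; the
  // per-unit cost estimate is the number of bytes copied per slice.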
108 | Shard(worker_threads->num_threads, worker_threads->workers, |
109 | batch_size * indices_size, slice_elems * sizeof(T), work); |
110 | return result; |
111 | } |
112 | |
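// CPU implementation: chooses 32- vs 64-bit slice indexing, specializes a few
// static slice sizes, and forwards to HandleCopies.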
113 | template <typename T, typename Index> |
114 | struct GatherFunctorCPU { |
115 | int64_t operator()(OpKernelContext* ctx, |
116 | typename TTypes<T, 3>::ConstTensor params, |
117 | typename TTypes<Index>::ConstFlat indices, |
118 | typename TTypes<T, 3>::Tensor out) { |
119 | const int64_t indices_size = indices.size(); |
120 | const int64_t slice_size = out.dimension(2); |
121 | int64_t bad_i; |
122 | |
123 | const int64_t batch_size = params.dimension(0); |
124 | |
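    // Use 64-bit slice indexing when any extent, or the total number of
    // elements touched, could overflow int32; otherwise take the cheaper
    // 32-bit path.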
125 | bool use_large = (slice_size > std::numeric_limits<int32>::max() || |
126 | params.size() > std::numeric_limits<int32>::max() || |
127 | indices_size > std::numeric_limits<int32>::max() || |
128 | batch_size * indices_size * slice_size > |
129 | std::numeric_limits<int32>::max()); |
130 | #define CALL(elems) \ |
131 | do { \ |
132 | if (use_large) { \ |
133 | bad_i = HandleCopies<T, Index, int64_t, elems>(ctx, params, indices, \ |
134 | slice_size, out); \ |
135 | } else { \ |
136 | const int32 small_slice = static_cast<int32>(slice_size); \ |
137 | bad_i = HandleCopies<T, Index, int32, elems>(ctx, params, indices, \ |
138 | small_slice, out); \ |
139 | } \ |
140 | } while (0) |
141 | |
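    // Specialize a couple of small slice sizes so the compiler sees a
    // constant copy length; all other sizes take the dynamic path (-1).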
142 | if (slice_size == 10) |
143 | CALL(10); |
144 | else if (slice_size == 20) |
145 | CALL(20); |
146 | else |
147 | CALL(-1); |
148 | #undef CALL |
149 | |
150 | return bad_i; |
151 | } |
152 | }; |
153 | |
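// Functor dispatched on device type by the Gather op kernels; the
// specializations below cover the CPU device and Variant tensors on GPU.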
154 | template <typename Device, typename T, typename Index> |
155 | struct GatherFunctor { |
156 | int64_t operator()(OpKernelContext* ctx, |
157 | typename TTypes<T, 3>::ConstTensor params, |
158 | typename TTypes<Index>::ConstFlat indices, |
159 | typename TTypes<T, 3>::Tensor out); |
160 | }; |
161 | |
162 | template <typename T, typename Index> |
163 | struct GatherFunctor<CPUDevice, T, Index> { |
164 | int64_t operator()(OpKernelContext* ctx, |
165 | typename TTypes<T, 3>::ConstTensor params, |
166 | typename TTypes<Index>::ConstFlat indices, |
167 | typename TTypes<T, 3>::Tensor out) { |
168 | return GatherFunctorCPU<T, Index>()(ctx, params, indices, out); |
169 | } |
170 | }; |
171 | |
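// Variant tensors hold host-side objects even on GPU devices, so gathering
// Variants on the GPU delegates to the CPU implementation.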
172 | template <typename Index> |
173 | struct GatherFunctor<GPUDevice, Variant, Index> { |
174 | int64_t operator()(OpKernelContext* ctx, |
175 | typename TTypes<Variant, 3>::ConstTensor params, |
176 | typename TTypes<Index>::ConstFlat indices, |
177 | typename TTypes<Variant, 3>::Tensor out) { |
178 | return GatherFunctorCPU<Variant, Index>()(ctx, params, indices, out); |
179 | } |
180 | }; |
181 | |
182 | } // namespace functor |
183 | } // namespace tensorflow |
184 | |
185 | #endif // TENSORFLOW_CORE_KERNELS_GATHER_FUNCTOR_H_ |
186 | |