1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/data_flow_ops.cc. |
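//
// DynamicStitch interleaves slices from n data tensors into a single output:
// for each input m and each position i in indices[m],
//   merged[indices[m][i], ...] = data[m][i, ...]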
17 | |
18 | #include "tensorflow/core/framework/bounds_check.h" |
19 | #include "tensorflow/core/framework/op_kernel.h" |
20 | #include "tensorflow/core/framework/register_types.h" |
21 | #include "tensorflow/core/framework/tensor.h" |
22 | #include "tensorflow/core/lib/core/threadpool.h" |
23 | |
24 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
25 | #include "tensorflow/core/kernels/gpu_device_array.h" |
26 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
27 | |
28 | namespace tensorflow { |
29 | |
30 | typedef Eigen::ThreadPoolDevice CPUDevice; |
31 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
32 | typedef Eigen::GpuDevice GPUDevice; |
33 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
34 | |
35 | template <class T> |
36 | class DynamicStitchOpImplBase : public OpKernel { |
37 | public: |
38 | explicit DynamicStitchOpImplBase(OpKernelConstruction* c, |
39 | const string& op_name) |
40 | : OpKernel(c) { |
41 | // Compute expected input signature |
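    // (n int32 index tensors followed by n data tensors of type T, with a
    // single output of type T).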
42 | const DataType dt = DataTypeToEnum<T>::v(); |
43 | const int n = c->num_inputs() / 2; |
44 | DataTypeVector expected; |
45 | for (int i = 0; i < n; i++) { |
46 | expected.push_back(DT_INT32); |
47 | } |
48 | for (int i = 0; i < n; i++) { |
49 | expected.push_back(dt); |
50 | } |
51 | OP_REQUIRES_OK(c, c->MatchSignature(expected, {dt})); |
52 | OP_REQUIRES(c, c->num_inputs() > 0, |
                errors::InvalidArgument(op_name + ": Must have some inputs"));
54 | OP_REQUIRES(c, c->num_inputs() % 2 == 0, |
55 | errors::InvalidArgument( |
                    op_name + ": Must have even number of arguments"));
57 | } |
58 | |
59 | protected: |
60 | // Check if data0.shape[indices0.dims():] == data1.shape[indices1.dims():] |
  static bool SameExtraShape(const Tensor& data0, const Tensor& indices0,
                             const Tensor& data1, const Tensor& indices1) {
    const int extra0 = data0.dims() - indices0.dims();
    const int extra1 = data1.dims() - indices1.dims();
65 | if (extra0 != extra1) return false; |
66 | for (int i = 0; i < extra0; i++) { |
67 | if (data0.dim_size(indices0.dims() + i) != |
68 | data1.dim_size(indices1.dims() + i)) { |
69 | return false; |
70 | } |
71 | } |
72 | return true; |
73 | } |
74 | |
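  // Validates the "indices" and "data" input lists, computes the output's
  // first dimension (max index + 1; rows never referenced by any index are
  // left uninitialized), optionally returns the total number of index
  // elements in *data_elements_size, and allocates the output tensor into
  // *result_ptr.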
75 | void CheckArgsAndAllocateResult(OpKernelContext* c, |
76 | OpInputList* indices_inputs, |
77 | OpInputList* data_inputs, int* first_dim_size, |
78 | int* data_elements_size, |
79 | Tensor** result_ptr) { |
80 | // Find maximum index in the indices vectors |
    OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));
82 | |
83 | int32_t max_index = -1; |
84 | if (data_elements_size) { |
85 | *data_elements_size = 0; |
86 | } |
87 | for (const Tensor& indices : *indices_inputs) { |
88 | if (indices.NumElements() > 0) { |
89 | Eigen::Tensor<int32, 0, Eigen::RowMajor> m = |
90 | indices.flat<int32>().maximum(); |
91 | max_index = std::max(m(), max_index); |
92 | } |
93 | if (data_elements_size) { |
94 | *data_elements_size += indices.NumElements(); |
95 | } |
96 | } |
97 | |
98 | *first_dim_size = max_index + 1; |
99 | |
100 | // Validate that data[i].shape = indices[i].shape + constant |
    OP_REQUIRES_OK(c, c->input_list("data", data_inputs));
102 | const Tensor& data0 = (*data_inputs)[0]; |
103 | const Tensor& indices0 = (*indices_inputs)[0]; |
104 | for (int input_num = 0; input_num < indices_inputs->size(); input_num++) { |
105 | const Tensor& indices = (*indices_inputs)[input_num]; |
106 | const Tensor& data = (*data_inputs)[input_num]; |
107 | OP_REQUIRES( |
108 | c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()), |
          errors::InvalidArgument("data[", input_num,
                                  "].shape = ", data.shape().DebugString(),
                                  " does not start with indices[", input_num,
                                  "].shape = ", indices.shape().DebugString()));
113 | OP_REQUIRES( |
114 | c, input_num == 0 || SameExtraShape(data0, indices0, data, indices), |
          errors::InvalidArgument(
              "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
              "].shape[", indices.dims(),
              ":], got data[0].shape = ", data0.shape().DebugString(),
              ", data[", input_num, "].shape = ", data.shape().DebugString(),
              ", indices[0].shape = ", indices0.shape().DebugString(),
              ", indices[", input_num,
              "].shape = ", indices.shape().DebugString()));
123 | } |
124 | |
125 | // Allocate result tensor of shape |
126 | // [*first_dim_size] + data.shape[indices.dims:] |
127 | TensorShape result_shape; |
128 | result_shape.AddDim(*first_dim_size); |
129 | for (int d = indices0.dims(); d < data0.dims(); d++) { |
130 | result_shape.AddDim(data0.dim_size(d)); |
131 | } |
132 | OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, result_ptr)); |
133 | } |
134 | }; |
135 | |
136 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
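// Launcher for the device-side stitch kernel. The template is only declared
// here; instantiations for each registered type live in the CUDA/ROCm
// translation unit, hence the extern template declarations below.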
137 | |
138 | template <typename T> |
139 | void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device, |
140 | const int32_t slice_size, |
141 | const int32_t first_dim_size, |
142 | const GpuDeviceArrayStruct<int>& input_indices, |
143 | const GpuDeviceArrayStruct<const T*>& input_ptrs, |
144 | T* output); |
145 | #define REGISTER_GPU(T) \ |
146 | extern template void DynamicStitchGPUImpl( \ |
147 | const Eigen::GpuDevice& gpu_device, const int32 slice_size, \ |
148 | const int32 first_dim_size, \ |
149 | const GpuDeviceArrayStruct<int32>& input_indices, \ |
150 | const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output); |
151 | |
152 | TF_CALL_int32(REGISTER_GPU); |
153 | TF_CALL_int64(REGISTER_GPU); |
154 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU); |
155 | TF_CALL_COMPLEX_TYPES(REGISTER_GPU); |
156 | #undef REGISTER_GPU |
157 | |
158 | template <class T> |
159 | class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> { |
160 | public: |
161 | explicit DynamicStitchOpGPU(OpKernelConstruction* c) |
      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}
163 | |
164 | void Compute(OpKernelContext* c) override { |
165 | OpInputList indices_inputs; |
166 | OpInputList data_inputs; |
167 | int first_dim_size; |
168 | int data_elements_size; |
169 | Tensor* merged = nullptr; |
170 | this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs, |
171 | &first_dim_size, &data_elements_size, |
172 | &merged); |
173 | if (!c->status().ok()) { |
174 | // Avoid segmentation faults if merged cannot be allocated and an error is |
175 | // passed back in the context. |
176 | return; |
177 | } |
178 | |
179 | // TODO(jeff): Currently we leave uninitialized any portions of |
180 | // merged that aren't covered by an index in indices. What should we do? |
181 | if (first_dim_size > 0) { |
      // Because duplicate indices must resolve to the last value written, we
      // have to handle collisions before sending the data to the GPU kernel.
      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
      // last of duplicated indices, it could instead be done on the GPU
      // implicitly using atomics to make sure the last index is the final
      // write.
188 | const int slice_size = merged->flat_outer_dims<T>().dimension(1); |
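      // indices_flat maps each output row to the position of the input slice
      // that fills it (-1 if no index refers to that row); data_flat holds a
      // pointer to the start of each input slice. Finalize() below makes both
      // arrays available to the device.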
189 | GpuDeviceArrayOnHost<int32> indices_flat(c, first_dim_size); |
190 | GpuDeviceArrayOnHost<const T*> data_flat(c, data_elements_size); |
191 | OP_REQUIRES_OK(c, indices_flat.Init()); |
192 | OP_REQUIRES_OK(c, data_flat.Init()); |
      // Initialize indices_flat (-1 represents a missing index).
194 | for (int i = 0; i < first_dim_size; ++i) { |
195 | indices_flat.Set(i, -1); |
196 | } |
197 | |
      // Index into data_flat.
      int32_t idx = 0;
      // Running sum of indices_inputs[i].NumElements(); used to compute the
      // values stored in indices_flat.
      int32_t base_size = 0;
202 | for (int i = 0; i < indices_inputs.size(); ++i) { |
203 | auto indices_vec = indices_inputs[i].flat<int32>(); |
204 | auto data_ptr_base = data_inputs[i].template flat<T>().data(); |
205 | for (int j = 0; j < indices_vec.size(); ++j) { |
          // indices_flat is indexed by output row; its values are the
          // positions in the flattened input data where that row's slice is
          // located.
209 | indices_flat.Set(indices_vec(j), base_size + j); |
210 | data_flat.Set( |
211 | idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) + |
212 | j * slice_size)); |
213 | ++idx; |
214 | } |
215 | base_size += indices_vec.size(); |
216 | } |
217 | OP_REQUIRES_OK(c, indices_flat.Finalize()); |
218 | OP_REQUIRES_OK(c, data_flat.Finalize()); |
219 | |
220 | auto output = merged->template flat<T>().data(); |
221 | DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size, first_dim_size, |
222 | indices_flat.data(), data_flat.data(), output); |
223 | } |
224 | } |
225 | }; |
226 | |
227 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
228 | |
229 | template <class T, bool Parallel> |
230 | class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> { |
231 | public: |
232 | explicit DynamicStitchOpImplCPU(OpKernelConstruction* c) |
233 | : DynamicStitchOpImplBase<T>( |
          c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}
235 | |
236 | void Compute(OpKernelContext* c) override { |
237 | OpInputList indices_inputs; |
238 | OpInputList data_inputs; |
239 | int first_dim_size; |
240 | Tensor* merged = nullptr; |
241 | this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs, |
242 | &first_dim_size, nullptr, &merged); |
243 | if (!c->status().ok()) { |
244 | // Avoid segmentation faults if merged cannot be allocated and an error is |
245 | // passed back in the context. |
246 | return; |
247 | } |
248 | |
249 | // TODO(jeff): Currently we leave uninitialized any portions of |
250 | // merged that aren't covered by an index in indices. What should we do? |
251 | if (first_dim_size > 0) { |
252 | auto merged_flat = merged->flat_outer_dims<T>(); |
      // slice_size must not be stored as an int, so that slices of tensors
      // larger than 2GB are handled correctly.
254 | const auto slice_size = merged_flat.dimension(1); |
255 | const size_t slice_bytes = slice_size * sizeof(T); |
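      // Copies every slice of data_inputs[input_num] into the output row named
      // by the corresponding entry of indices_inputs[input_num], using memcpy
      // for types that allow it and Eigen slice assignment otherwise.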
256 | auto OnInputNumber = [&](int input_num) { |
257 | const Tensor& indices = indices_inputs[input_num]; |
258 | auto indices_vec = indices.flat<int32>(); |
259 | const Tensor& data = data_inputs[input_num]; |
260 | auto data_flat = |
261 | data.shaped<T, 2>({indices_vec.dimension(0), slice_size}); |
262 | |
263 | if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) { |
264 | T* merged_base = merged_flat.data(); |
265 | const T* data_base = data_flat.data(); |
266 | for (int i = 0; i < indices_vec.size(); i++) { |
267 | int32_t index = internal::SubtleMustCopy(indices_vec(i)); |
268 | OP_REQUIRES( |
269 | c, FastBoundsCheck(index, first_dim_size), |
                errors::InvalidArgument("indices[", i, "] is out of range"));
271 | memcpy(merged_base + index * slice_size, data_base + i * slice_size, |
272 | slice_bytes); |
273 | } |
274 | } else { |
275 | Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size); |
276 | for (int i = 0; i < indices_vec.size(); i++) { |
277 | // Copy slice data[i] to merged[indices[i]] |
278 | Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0); |
279 | int32_t index = internal::SubtleMustCopy(indices_vec(i)); |
280 | OP_REQUIRES( |
281 | c, FastBoundsCheck(index, first_dim_size), |
                errors::InvalidArgument("indices[", i, "] is out of range"));
283 | Eigen::DSizes<Eigen::DenseIndex, 2> merged_indices(index, 0); |
284 | merged_flat.slice(merged_indices, sizes) = |
285 | data_flat.slice(data_indices, sizes); |
286 | } |
287 | } |
288 | }; |
289 | if (Parallel && |
290 | c->device()->tensorflow_cpu_worker_threads()->num_threads > 1) { |
291 | auto thread_pool = |
292 | c->device()->tensorflow_cpu_worker_threads()->workers; |
293 | size_t total_indices_size = 0; |
294 | for (int input_num = 0; input_num < indices_inputs.size(); |
295 | ++input_num) { |
296 | total_indices_size += indices_inputs[input_num].NumElements(); |
297 | } |
298 | const double avg_indices_size = |
299 | static_cast<double>(total_indices_size) / indices_inputs.size(); |
300 | auto bytes_processed = slice_bytes * avg_indices_size; |
301 | auto LoopBody = [&](int first, int last) { |
302 | for (int input_num = first; input_num < last; ++input_num) { |
303 | OnInputNumber(input_num); |
304 | } |
305 | }; |
306 | thread_pool->ParallelFor(indices_inputs.size(), bytes_processed, |
307 | LoopBody); |
308 | } else { |
309 | for (int input_num = 0; input_num < indices_inputs.size(); |
310 | input_num++) { |
311 | OnInputNumber(input_num); |
312 | } |
313 | } |
314 | } |
315 | } |
316 | }; |
317 | |
318 | // Using inheritance rather than a typedef so that these classes might have more |
319 | // functionality later. |
320 | |
321 | template <typename T> |
322 | struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> { |
323 | using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU; |
324 | }; |
325 | |
326 | template <typename T> |
327 | struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> { |
328 | using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU; |
329 | }; |
330 | |
331 | #define REGISTER_DYNAMIC_STITCH(type) \ |
332 | REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ |
333 | .Device(DEVICE_CPU) \ |
334 | .TypeConstraint<type>("T") \ |
335 | .HostMemory("indices"), \ |
336 | DynamicStitchOpCPU<type>) \ |
337 | REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \ |
338 | .Device(DEVICE_CPU) \ |
339 | .TypeConstraint<type>("T") \ |
340 | .HostMemory("indices"), \ |
341 | ParallelDynamicStitchOpCPU<type>) |
342 | |
343 | TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH); |
344 | TF_CALL_variant(REGISTER_DYNAMIC_STITCH); |
345 | TF_CALL_QUANTIZED_TYPES(REGISTER_DYNAMIC_STITCH); |
346 | #undef REGISTER_DYNAMIC_STITCH |
347 | |
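// ParallelDynamicStitch is also registered for DEVICE_DEFAULT with every
// tensor pinned to host memory, reusing the CPU implementation above.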
348 | #define REGISTER_PARALLEL_DYNAMIC_STITCH(type) \ |
349 | REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch") \ |
350 | .Device(DEVICE_DEFAULT) \ |
351 | .TypeConstraint<type>("T") \ |
352 | .HostMemory("indices") \ |
353 | .HostMemory("data") \ |
354 | .HostMemory("merged"), \ |
355 | ParallelDynamicStitchOpCPU<type>) |
356 | |
357 | TF_CALL_int32(REGISTER_PARALLEL_DYNAMIC_STITCH); |
358 | TF_CALL_int64(REGISTER_PARALLEL_DYNAMIC_STITCH); |
359 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_DYNAMIC_STITCH); |
360 | TF_CALL_COMPLEX_TYPES(REGISTER_PARALLEL_DYNAMIC_STITCH); |
361 | #undef REGISTER_PARALLEL_DYNAMIC_STITCH |
362 | |
363 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
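// The GPU kernel keeps "indices" in host memory: the maximum index and the
// duplicate-resolution scan in Compute() are performed on the CPU before the
// device kernel is launched.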
364 | #define REGISTER_DYNAMIC_STITCH_GPU(type) \ |
365 | REGISTER_KERNEL_BUILDER(Name("DynamicStitch") \ |
366 | .Device(DEVICE_GPU) \ |
367 | .TypeConstraint<type>("T") \ |
368 | .HostMemory("indices"), \ |
369 | DynamicStitchOpGPU<type>) |
370 | |
371 | TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU); |
372 | TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU); |
373 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU); |
374 | TF_CALL_COMPLEX_TYPES(REGISTER_DYNAMIC_STITCH_GPU); |
375 | #undef REGISTER_DYNAMIC_STITCH_GPU |
376 | |
377 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
378 | |
379 | } // namespace tensorflow |
380 | |