/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/data_flow_ops.cc.
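//
// Illustrative example of the stitch semantics (a sketch, not quoted from the
// op docs): with
//   indices[0] = [0, 2], data[0] = [[1, 2], [3, 4]]
//   indices[1] = [1],    data[1] = [[5, 6]]
// the merged output is [[1, 2], [5, 6], [3, 4]], i.e. merged[indices[m][i]] =
// data[m][i]. When an output row appears more than once, a later occurrence
// overwrites an earlier one (the GPU path below picks the last duplicate
// explicitly).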

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/threadpool.h"

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_device_array.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
typedef Eigen::GpuDevice GPUDevice;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

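// Shared base for the CPU and GPU kernels below: it checks the op signature
// (n int32 "indices" inputs followed by n "data" inputs of type T) and
// provides the common argument validation and output allocation logic.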
template <class T>
class DynamicStitchOpImplBase : public OpKernel {
 public:
  explicit DynamicStitchOpImplBase(OpKernelConstruction* c,
                                   const string& op_name)
      : OpKernel(c) {
    // Compute expected input signature
    const DataType dt = DataTypeToEnum<T>::v();
    const int n = c->num_inputs() / 2;
    DataTypeVector expected;
    for (int i = 0; i < n; i++) {
      expected.push_back(DT_INT32);
    }
    for (int i = 0; i < n; i++) {
      expected.push_back(dt);
    }
    OP_REQUIRES_OK(c, c->MatchSignature(expected, {dt}));
    OP_REQUIRES(c, c->num_inputs() > 0,
                errors::InvalidArgument(op_name + ": Must have some inputs"));
    OP_REQUIRES(c, c->num_inputs() % 2 == 0,
                errors::InvalidArgument(
                    op_name + ": Must have even number of arguments"));
  }

 protected:
  // Check if data0.shape[indices0.dims():] == data1.shape[indices1.dims():]
  static bool SameExtraShape(const Tensor& data0, const Tensor& indices0,
                             const Tensor& data1, const Tensor& indices1) {
    const int extra0 = data0.dims() - indices0.dims();
    const int extra1 = data1.dims() - indices1.dims();
    if (extra0 != extra1) return false;
    for (int i = 0; i < extra0; i++) {
      if (data0.dim_size(indices0.dims() + i) !=
          data1.dim_size(indices1.dims() + i)) {
        return false;
      }
    }
    return true;
  }

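  // Reads the "indices" and "data" input lists, computes the size of the
  // output's first dimension as (max index + 1), optionally reports the total
  // number of indices across all inputs, validates that every data[i].shape
  // starts with indices[i].shape and that all inputs agree on the trailing
  // ("extra") dimensions, and finally allocates the result tensor.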
  void CheckArgsAndAllocateResult(OpKernelContext* c,
                                  OpInputList* indices_inputs,
                                  OpInputList* data_inputs, int* first_dim_size,
                                  int* data_elements_size,
                                  Tensor** result_ptr) {
    // Find maximum index in the indices vectors
    OP_REQUIRES_OK(c, c->input_list("indices", indices_inputs));

    int32_t max_index = -1;
    if (data_elements_size) {
      *data_elements_size = 0;
    }
    for (const Tensor& indices : *indices_inputs) {
      if (indices.NumElements() > 0) {
        Eigen::Tensor<int32, 0, Eigen::RowMajor> m =
            indices.flat<int32>().maximum();
        max_index = std::max(m(), max_index);
      }
      if (data_elements_size) {
        *data_elements_size += indices.NumElements();
      }
    }

    *first_dim_size = max_index + 1;

    // Validate that data[i].shape = indices[i].shape + constant
    OP_REQUIRES_OK(c, c->input_list("data", data_inputs));
    const Tensor& data0 = (*data_inputs)[0];
    const Tensor& indices0 = (*indices_inputs)[0];
    for (int input_num = 0; input_num < indices_inputs->size(); input_num++) {
      const Tensor& indices = (*indices_inputs)[input_num];
      const Tensor& data = (*data_inputs)[input_num];
      OP_REQUIRES(
          c, TensorShapeUtils::StartsWith(data.shape(), indices.shape()),
          errors::InvalidArgument("data[", input_num,
                                  "].shape = ", data.shape().DebugString(),
                                  " does not start with indices[", input_num,
                                  "].shape = ", indices.shape().DebugString()));
      OP_REQUIRES(
          c, input_num == 0 || SameExtraShape(data0, indices0, data, indices),
          errors::InvalidArgument(
              "Need data[0].shape[", indices0.dims(), ":] = data[", input_num,
              "].shape[", indices.dims(),
              ":], got data[0].shape = ", data0.shape().DebugString(),
              ", data[", input_num, "].shape = ", data.shape().DebugString(),
              ", indices[0].shape = ", indices0.shape().DebugString(),
              ", indices[", input_num,
              "].shape = ", indices.shape().DebugString()));
    }

    // Allocate result tensor of shape
    //   [*first_dim_size] + data.shape[indices.dims:]
    TensorShape result_shape;
    result_shape.AddDim(*first_dim_size);
    for (int d = indices0.dims(); d < data0.dims(); d++) {
      result_shape.AddDim(data0.dim_size(d));
    }
    OP_REQUIRES_OK(c, c->allocate_output(0, result_shape, result_ptr));
  }
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

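// GPU launcher, defined in a separate GPU compilation unit. As set up by
// DynamicStitchOpGPU below, input_indices maps each of the first_dim_size
// output rows to the position of its source slice in input_ptrs (-1 marks
// rows not covered by any index, which are left uninitialized), and
// slice_size elements are copied per row into output.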
template <typename T>
void DynamicStitchGPUImpl(const Eigen::GpuDevice& gpu_device,
                          const int32_t slice_size,
                          const int32_t first_dim_size,
                          const GpuDeviceArrayStruct<int>& input_indices,
                          const GpuDeviceArrayStruct<const T*>& input_ptrs,
                          T* output);
#define REGISTER_GPU(T)                                           \
  extern template void DynamicStitchGPUImpl(                      \
      const Eigen::GpuDevice& gpu_device, const int32 slice_size, \
      const int32 first_dim_size,                                 \
      const GpuDeviceArrayStruct<int32>& input_indices,           \
      const GpuDeviceArrayStruct<const T*>& input_ptrs, T* output);

TF_CALL_int32(REGISTER_GPU);
TF_CALL_int64(REGISTER_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_GPU);
#undef REGISTER_GPU

template <class T>
class DynamicStitchOpGPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpGPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(c, "DynamicStitchOp") {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    int data_elements_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, &data_elements_size,
                                     &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error
      // is passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
    if (first_dim_size > 0) {
      // Because indices may collide, we have to resolve collisions on the
      // host before sending the data to the GPU kernel.
      // TODO(ekelsen): Instead of doing a serial scan on the CPU to pick the
      // last of duplicated indices, it could instead be done on the GPU
      // implicitly using atomics to make sure the last index is the final
      // write.
      const int slice_size = merged->flat_outer_dims<T>().dimension(1);
      GpuDeviceArrayOnHost<int32> indices_flat(c, first_dim_size);
      GpuDeviceArrayOnHost<const T*> data_flat(c, data_elements_size);
      OP_REQUIRES_OK(c, indices_flat.Init());
      OP_REQUIRES_OK(c, data_flat.Init());
      // Initialize indices_flat (-1 represents missing indices).
      for (int i = 0; i < first_dim_size; ++i) {
        indices_flat.Set(i, -1);
      }

      // Index into data_flat.
      int32_t idx = 0;
      // Running sum of indices_inputs[i].NumElements(), used to compute the
      // values stored in indices_flat.
      int32_t base_size = 0;
      for (int i = 0; i < indices_inputs.size(); ++i) {
        auto indices_vec = indices_inputs[i].flat<int32>();
        auto data_ptr_base = data_inputs[i].template flat<T>().data();
        for (int j = 0; j < indices_vec.size(); ++j) {
          // indices_flat's indices represent the indices of the output;
          // indices_flat's values represent the positions in the flattened
          // input data where the corresponding slices are located.
          indices_flat.Set(indices_vec(j), base_size + j);
          data_flat.Set(
              idx, const_cast<T*>(reinterpret_cast<const T*>(data_ptr_base) +
                                  j * slice_size));
          ++idx;
        }
        base_size += indices_vec.size();
      }
      OP_REQUIRES_OK(c, indices_flat.Finalize());
      OP_REQUIRES_OK(c, data_flat.Finalize());

      auto output = merged->template flat<T>().data();
      DynamicStitchGPUImpl<T>(c->eigen_gpu_device(), slice_size, first_dim_size,
                              indices_flat.data(), data_flat.data(), output);
    }
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

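// CPU implementation shared by DynamicStitch and ParallelDynamicStitch. The
// Parallel template parameter only controls whether the per-input copies are
// sharded across the intra-op thread pool; the copy logic itself is identical.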
template <class T, bool Parallel>
class DynamicStitchOpImplCPU : public DynamicStitchOpImplBase<T> {
 public:
  explicit DynamicStitchOpImplCPU(OpKernelConstruction* c)
      : DynamicStitchOpImplBase<T>(
            c, (Parallel ? "ParallelDynamicStitchOp" : "DynamicStitchOp")) {}

  void Compute(OpKernelContext* c) override {
    OpInputList indices_inputs;
    OpInputList data_inputs;
    int first_dim_size;
    Tensor* merged = nullptr;
    this->CheckArgsAndAllocateResult(c, &indices_inputs, &data_inputs,
                                     &first_dim_size, nullptr, &merged);
    if (!c->status().ok()) {
      // Avoid segmentation faults if merged cannot be allocated and an error
      // is passed back in the context.
      return;
    }

    // TODO(jeff): Currently we leave uninitialized any portions of
    // merged that aren't covered by an index in indices.  What should we do?
    if (first_dim_size > 0) {
      auto merged_flat = merged->flat_outer_dims<T>();
      // slice_size must not be stored as int for cases of tensors over 2GB.
      const auto slice_size = merged_flat.dimension(1);
      const size_t slice_bytes = slice_size * sizeof(T);
      auto OnInputNumber = [&](int input_num) {
        const Tensor& indices = indices_inputs[input_num];
        auto indices_vec = indices.flat<int32>();
        const Tensor& data = data_inputs[input_num];
        auto data_flat =
            data.shaped<T, 2>({indices_vec.dimension(0), slice_size});

        if (DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
          T* merged_base = merged_flat.data();
          const T* data_base = data_flat.data();
          for (int i = 0; i < indices_vec.size(); i++) {
            int32_t index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            memcpy(merged_base + index * slice_size, data_base + i * slice_size,
                   slice_bytes);
          }
        } else {
          Eigen::DSizes<Eigen::DenseIndex, 2> sizes(1, slice_size);
          for (int i = 0; i < indices_vec.size(); i++) {
            // Copy slice data[i] to merged[indices[i]]
            Eigen::DSizes<Eigen::DenseIndex, 2> data_indices(i, 0);
            int32_t index = internal::SubtleMustCopy(indices_vec(i));
            OP_REQUIRES(
                c, FastBoundsCheck(index, first_dim_size),
                errors::InvalidArgument("indices[", i, "] is out of range"));
            Eigen::DSizes<Eigen::DenseIndex, 2> merged_indices(index, 0);
            merged_flat.slice(merged_indices, sizes) =
                data_flat.slice(data_indices, sizes);
          }
        }
      };
      if (Parallel &&
          c->device()->tensorflow_cpu_worker_threads()->num_threads > 1) {
        auto thread_pool =
            c->device()->tensorflow_cpu_worker_threads()->workers;
        size_t total_indices_size = 0;
        for (int input_num = 0; input_num < indices_inputs.size();
             ++input_num) {
          total_indices_size += indices_inputs[input_num].NumElements();
        }
        const double avg_indices_size =
            static_cast<double>(total_indices_size) / indices_inputs.size();
        auto bytes_processed = slice_bytes * avg_indices_size;
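        // bytes_processed is a rough per-input cost estimate (average number
        // of bytes copied per input tensor) that lets ParallelFor decide how
        // finely to shard the inputs across worker threads.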
        auto LoopBody = [&](int first, int last) {
          for (int input_num = first; input_num < last; ++input_num) {
            OnInputNumber(input_num);
          }
        };
        thread_pool->ParallelFor(indices_inputs.size(), bytes_processed,
                                 LoopBody);
      } else {
        for (int input_num = 0; input_num < indices_inputs.size();
             input_num++) {
          OnInputNumber(input_num);
        }
      }
    }
  }
};

// Using inheritance rather than a typedef so that these classes might have
// more functionality later.

template <typename T>
struct DynamicStitchOpCPU : DynamicStitchOpImplCPU<T, false> {
  using DynamicStitchOpImplCPU<T, false>::DynamicStitchOpImplCPU;
};

template <typename T>
struct ParallelDynamicStitchOpCPU : DynamicStitchOpImplCPU<T, true> {
  using DynamicStitchOpImplCPU<T, true>::DynamicStitchOpImplCPU;
};

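// CPU registrations for DynamicStitch and ParallelDynamicStitch; both use the
// implementations above, with the "indices" inputs declared as host memory.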
#define REGISTER_DYNAMIC_STITCH(type)                    \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpCPU<type>)      \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_CPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_POD_STRING_TYPES(REGISTER_DYNAMIC_STITCH);
TF_CALL_variant(REGISTER_DYNAMIC_STITCH);
TF_CALL_QUANTIZED_TYPES(REGISTER_DYNAMIC_STITCH);
#undef REGISTER_DYNAMIC_STITCH

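// ParallelDynamicStitch registered under DEVICE_DEFAULT: the CPU
// implementation is reused, with indices, data, and merged all placed in host
// memory.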
#define REGISTER_PARALLEL_DYNAMIC_STITCH(type)           \
  REGISTER_KERNEL_BUILDER(Name("ParallelDynamicStitch")  \
                              .Device(DEVICE_DEFAULT)    \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices")     \
                              .HostMemory("data")        \
                              .HostMemory("merged"),     \
                          ParallelDynamicStitchOpCPU<type>)

TF_CALL_int32(REGISTER_PARALLEL_DYNAMIC_STITCH);
TF_CALL_int64(REGISTER_PARALLEL_DYNAMIC_STITCH);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_PARALLEL_DYNAMIC_STITCH);
TF_CALL_COMPLEX_TYPES(REGISTER_PARALLEL_DYNAMIC_STITCH);
#undef REGISTER_PARALLEL_DYNAMIC_STITCH

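// GPU registrations for DynamicStitch. The "indices" inputs stay in host
// memory because the kernel above resolves duplicate indices with a serial
// scan on the host before launching the device copy.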
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER_DYNAMIC_STITCH_GPU(type)                \
  REGISTER_KERNEL_BUILDER(Name("DynamicStitch")          \
                              .Device(DEVICE_GPU)        \
                              .TypeConstraint<type>("T") \
                              .HostMemory("indices"),    \
                          DynamicStitchOpGPU<type>)

TF_CALL_int32(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_int64(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
TF_CALL_COMPLEX_TYPES(REGISTER_DYNAMIC_STITCH_GPU);
#undef REGISTER_DYNAMIC_STITCH_GPU

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // namespace tensorflow