1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/array_ops.cc. |
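//
// SplitV splits a tensor along `split_dim` into pieces whose sizes are given
// by the `size_splits` input; at most one entry may be -1, meaning "infer
// from the remaining size". For example (illustrative): splitting a [10, 4]
// tensor with size_splits {2, -1, 3} along dim 0 yields outputs of shape
// [2, 4], [5, 4], and [3, 4].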
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
21 | #define EIGEN_USE_GPU |
22 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
23 | |
24 | #include <numeric> |
25 | |
26 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
27 | #include "tensorflow/core/framework/bounds_check.h" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/register_types.h" |
30 | #include "tensorflow/core/framework/tensor.h" |
31 | #include "tensorflow/core/kernels/ops_util.h" |
32 | #include "tensorflow/core/kernels/split_lib.h" |
33 | #include "tensorflow/core/lib/core/status.h" |
34 | #include "tensorflow/core/lib/gtl/array_slice.h" |
35 | #include "tensorflow/core/util/work_sharder.h" |
36 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
37 | #include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h" |
38 | #include "tensorflow/core/kernels/gpu_device_array.h" |
39 | #include "tensorflow/core/kernels/split_lib_gpu.h" |
40 | #include "tensorflow/core/platform/stream_executor.h" |
41 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
42 | |
43 | namespace tensorflow { |
44 | |
45 | typedef Eigen::ThreadPoolDevice CPUDevice; |
46 | typedef Eigen::GpuDevice GPUDevice; |
47 | |
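// Shared base for the CPU and GPU SplitV kernels. ComputeEasyCases validates
// the inputs, resolves an inferred (-1) split size, and handles the two cases
// that require no data movement: a single output, and dim-0 splits whose
// outputs can alias the input buffer.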
48 | template <typename Device, typename T, typename Tlen> |
49 | class SplitVOpBase : public OpKernel { |
50 | public: |
51 | explicit SplitVOpBase(OpKernelConstruction* c) : OpKernel(c) {} |
52 | |
53 | void ComputeEasyCases(OpKernelContext* context, bool* done, |
54 | std::vector<Tlen>* split_sizes_vec) { |
55 | const int32_t num_split = context->num_outputs(); |
56 | const Tensor& input = context->input(0); |
57 | const TensorShape& input_shape = input.shape(); |
58 | const Tensor& split_tensor = context->input(1); |
59 | const Tensor& split_dim_tensor = context->input(2); |
60 | |
61 | OP_REQUIRES(context, split_dim_tensor.NumElements() == 1, |
62 | errors::InvalidArgument("split_dim_tensor must have " |
63 | "exactly one element." )); |
64 | |
65 | const int32_t split_dim_orig = split_dim_tensor.flat<int32>()(0); |
66 | const int32_t split_dim = |
67 | split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; |
68 | |
69 | OP_REQUIRES( |
70 | context, |
71 | split_tensor.dims() == 1 && split_tensor.NumElements() == num_split, |
72 | errors::InvalidArgument("size of the split_tensor must be 1-D and have " |
73 | "the same elements as outputs got " , |
74 | split_tensor.dims(), " -D and " , |
75 | split_tensor.NumElements(), " elements" )); |
76 | |
77 | auto split_sizes_d = split_tensor.vec<Tlen>(); |
78 | |
79 | split_sizes_vec->resize(split_sizes_d.size()); |
80 | |
81 | std::copy(split_sizes_d.data(), split_sizes_d.data() + split_sizes_d.size(), |
82 | split_sizes_vec->begin()); |
83 | |
84 | OP_REQUIRES( |
85 | context, num_split > 0, |
86 | errors::InvalidArgument( |
87 | "Number of ways to split should be > 0, but got " , num_split)); |
88 | |
89 | OP_REQUIRES( |
90 | context, 0 <= split_dim && split_dim < input.dims(), |
91 | errors::InvalidArgument("-input rank(-" , input.dims(), |
92 | ") <= split_dim < input rank (" , input.dims(), |
93 | "), but got " , split_dim_orig)); |
94 | |
95 | Tlen input_size_split_dim = input_shape.dim_size(split_dim); |
96 | |
97 | // Special case 1: num_split == 1. Nothing to do. |
98 | if (num_split == 1) { |
99 | context->set_output(0, context->input(0)); |
100 | OP_REQUIRES( |
101 | context, (*split_sizes_vec)[0] == input_size_split_dim, |
102 | errors::InvalidArgument("If there is only one output, it must have " |
103 | "the same size as the input. Input size: " , |
104 | input_size_split_dim, |
105 | " output size: " , (*split_sizes_vec)[0])); |
106 | *done = true; |
107 | return; |
108 | } |
109 | |
    // Determine the output sizes, resolving the single allowed -1 ("infer")
    // entry against the input's size along split_dim.
111 | int neg_one_dim = -1; |
112 | Tlen determined_size = 0; |
113 | for (int d = 0; d < split_sizes_vec->size(); ++d) { |
114 | Tlen size = (*split_sizes_vec)[d]; |
115 | |
116 | if (size == -1) { |
117 | OP_REQUIRES(context, neg_one_dim == -1, |
118 | errors::InvalidArgument("There can only be one -1 in the " |
119 | "input." )); |
120 | neg_one_dim = d; |
121 | } else { |
122 | determined_size += size; |
123 | } |
124 | } |
125 | |
126 | OP_REQUIRES( |
127 | context, |
128 | (neg_one_dim == -1 && determined_size == input_size_split_dim) || |
129 | (neg_one_dim >= 0 && determined_size <= input_size_split_dim), |
130 | errors::InvalidArgument("Determined shape must either match " |
131 | "input shape along split_dim exactly if " |
132 | "fully specified, or be less than the size of " |
133 | "the input along split_dim if not fully " |
134 | "specified. Got: " , |
135 | determined_size)); |
136 | |
137 | if (neg_one_dim >= 0) { |
138 | (*split_sizes_vec)[neg_one_dim] = input_size_split_dim - determined_size; |
139 | } |
140 | |
141 | for (int i = 0; i < split_sizes_vec->size(); ++i) { |
142 | const Tlen& split_size = (*split_sizes_vec)[i]; |
143 | OP_REQUIRES(context, split_size >= Tlen(0), |
144 | errors::InvalidArgument("Split size at index " , i, |
145 | " must be >= 0. Got: " , split_size)); |
146 | } |
147 | |
    // Special case 2: split along the first dimension. If every output
    // produced by slicing the outer dimension starts at an aligned offset,
    // either because of how the dimension is sliced or because the split
    // sizes guarantee it, we can share the underlying buffer.
    //
    // Apply this optimization conservatively: if the input is aligned, the
    // resulting tensors must be aligned too. It's conservative because if
    // the immediate consumers of the resulting tensors do not use Eigen for
    // computation, it is perfectly fine to skip the copy even without
    // alignment.
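    // For example (illustrative): an [8, 1024] float input split into sizes
    // {4, 4} along dim 0 yields two outputs that are simply views into the
    // input buffer at row offsets 0 and 4, provided both slice offsets are
    // suitably aligned for Eigen; no copy is performed.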
158 | if (SplitHasAlignedOutputsInFirstDimension( |
159 | input_shape, split_dim, absl::MakeConstSpan(*split_sizes_vec))) { |
160 | Tlen start = 0; |
161 | for (int i = 0; i < num_split; ++i) { |
162 | context->set_output(i, |
163 | input.Slice(start, start + (*split_sizes_vec)[i])); |
164 | start += (*split_sizes_vec)[i]; |
165 | } |
166 | *done = true; |
167 | return; |
168 | } |
169 | } |
170 | |
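  // Computes the product of the dimensions before split_dim, the size of
  // split_dim itself, and the product of the dimensions after it. For
  // example (illustrative): for input_shape [2, 3, 4] and split_dim == 1
  // this returns (2, 3, 4); for split_dim == 0 it returns (1, 2, 12).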
171 | template <typename IndexType> |
172 | std::tuple<IndexType, IndexType, IndexType> SetDims( |
173 | const TensorShape& input_shape, const int32_t split_dim) const { |
174 | static_assert(std::is_integral<IndexType>::value, |
175 | "IndexType must be an integer type" ); |
    IndexType prefix_dim_size = 1;
177 | for (int i = 0; i < split_dim; ++i) { |
178 | prefix_dim_size *= input_shape.dim_size(i); |
179 | } |
180 | |
181 | // Caller must ensure that dim_size and suffix_dim_size are < |
182 | // std::numeric_limits<IndexType>::max() |
183 | IndexType split_dim_size = |
184 | static_cast<IndexType>(input_shape.dim_size(split_dim)); |
185 | |
186 | IndexType suffix_dim_size = 1; |
187 | for (int i = split_dim + 1; i < input_shape.dims(); ++i) { |
188 | suffix_dim_size *= static_cast<IndexType>(input_shape.dim_size(i)); |
189 | } |
190 | return std::make_tuple(prefix_dim_size, split_dim_size, suffix_dim_size); |
191 | } |
192 | |
193 | private: |
194 | // Determines whether the given split configuration can be done using slicing |
195 | // on the first dimension of the tensor. The requirement is that each result |
196 | // tensor from the slice is correctly aligned within the input tensor. |
197 | static bool SplitHasAlignedOutputsInFirstDimension( |
198 | const TensorShape& input_shape, int32_t split_dim, |
199 | absl::Span<const Tlen> split_sizes) { |
200 | if (split_dim != 0) { |
201 | return false; |
202 | } |
203 | Tlen start = 0; |
204 | for (const Tlen split_size : split_sizes) { |
205 | if (!IsDim0SliceAligned<T>(input_shape, start, start + split_size)) { |
206 | return false; |
207 | } |
208 | start += split_size; |
209 | } |
210 | return true; |
211 | } |
212 | }; |
213 | |
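// CPU implementation of the non-trivial SplitV cases: copies each output's
// slice out of the reshaped input. Depending on a size heuristic, the work is
// either sharded across outputs (one output per thread, each copy sequential)
// or run over outputs sequentially with the Split functor free to parallelize
// each copy internally.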
214 | template <typename T, typename Tlen, typename InputReshapedType, int NDims> |
215 | class SplitVOpCPUImpl { |
216 | public: |
217 | template <typename MakeSizesType, typename ReshapeResultType> |
218 | void operator()(OpKernelContext* context, |
219 | const InputReshapedType& input_reshaped, |
220 | const std::vector<int64_t>& split_start_points, |
221 | const TensorShape& input_shape, int32_t split_dim, |
222 | Eigen::DenseIndex prefix_dim_size, |
223 | Eigen::DenseIndex split_dim_size, |
224 | Eigen::DenseIndex suffix_dim_size, |
225 | std::vector<Tlen>& split_sizes_vec, |
226 | const MakeSizesType& make_sizes, |
227 | const ReshapeResultType& reshape_result) const { |
228 | constexpr uint64 kMinimumSplitNum = 4; |
229 | |
230 | Eigen::DSizes<Eigen::DenseIndex, NDims> indices; |
231 | for (int i = 0; i < NDims; ++i) { |
232 | indices[i] = 0; |
233 | } |
234 | const auto num_threads = |
235 | context->device()->tensorflow_cpu_worker_threads()->num_threads; |
236 | // TODO(jewillco): Tune heuristic further. |
237 | const auto input_element_count = input_shape.num_elements(); |
238 | const int num_split = split_start_points.size(); |
239 | const bool use_parallelism_between_outputs = |
240 | (num_split >= kMinimumSplitNum && |
241 | input_element_count >= std::min(num_threads, num_split) * 4096 && |
242 | input_element_count < num_split * 180 * 1024); |
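    // The thresholds above are presumably empirical: sharding across outputs
    // only pays off with enough outputs (kMinimumSplitNum), enough elements
    // per participating thread (the 4096 factor), and an input that is not so
    // large that the functor's internal parallelism would serve better.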
243 | |
244 | auto range_output_func = [&indices, context, &input_shape, split_dim, |
245 | &split_sizes_vec, &split_start_points, |
246 | use_parallelism_between_outputs, &input_reshaped, |
247 | &make_sizes, |
248 | &reshape_result](int64_t start, int64_t limit) { |
249 | for (int64_t i = start; i < limit; ++i) { |
250 | TensorShape output_shape(input_shape); |
251 | output_shape.set_dim(split_dim, split_sizes_vec[i]); |
252 | Tensor* result = nullptr; |
253 | OP_REQUIRES_OK(context, |
254 | context->allocate_output(i, output_shape, &result)); |
255 | |
256 | const auto sizes = make_sizes(split_sizes_vec[i]); |
257 | |
258 | if (sizes.TotalSize() > 0) { |
259 | auto result_shaped = reshape_result(result, split_sizes_vec[i]); |
260 | |
261 | auto current_indices = indices; |
262 | current_indices[NDims - 2] = split_start_points[i]; |
263 | if (use_parallelism_between_outputs) { |
            // Copy this output sequentially; parallelism comes from sharding
            // across outputs below.
265 | result_shaped = input_reshaped.slice(current_indices, sizes); |
266 | } else { |
267 | // This implementation may be parallel internally. |
268 | functor::Split<CPUDevice, T, NDims>()( |
269 | context->eigen_device<CPUDevice>(), result_shaped, |
270 | input_reshaped, current_indices, sizes); |
271 | } |
272 | } |
273 | } |
274 | }; |
275 | |
276 | if (use_parallelism_between_outputs) { |
      // Each thread is mapped to one output tensor: it traverses the input
      // and copies the relevant slice into its output. Run the outputs in
      // parallel while keeping each per-output copy sequential.
280 | Shard(num_split, |
281 | context->device()->tensorflow_cpu_worker_threads()->workers, |
282 | num_split, input_element_count / num_split, range_output_func); |
283 | } else { |
284 | // Run sequentially, but allow internal parallelism in functor. |
285 | range_output_func(0, num_split); |
286 | } |
287 | } |
288 | }; |
289 | |
290 | template <typename T, typename Tlen> |
291 | class SplitVOpCPU : public SplitVOpBase<CPUDevice, T, Tlen> { |
292 | public: |
293 | typedef SplitVOpBase<CPUDevice, T, Tlen> Base; |
294 | explicit SplitVOpCPU(OpKernelConstruction* c) : Base(c) {} |
295 | |
296 | void Compute(OpKernelContext* context) override { |
297 | bool done = false; |
298 | std::vector<Tlen> split_sizes_vec; |
299 | Base::ComputeEasyCases(context, &done, &split_sizes_vec); |
300 | if (!context->status().ok() || done) { |
301 | return; |
302 | } |
303 | const int32_t num_split = Base::num_outputs(); |
304 | const Tensor& input = context->input(0); |
305 | const TensorShape& input_shape = input.shape(); |
306 | const int32_t split_dim_orig = context->input(2).flat<int32>()(0); |
307 | const int32_t split_dim = |
308 | split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; |
309 | |
310 | // Android also uses int32 indexing, so check here also. |
311 | OP_REQUIRES( |
312 | context, |
313 | FastBoundsCheck(input.NumElements(), |
314 | std::numeric_limits<Eigen::DenseIndex>::max()), |
315 | errors::InvalidArgument("Split requires input size < " , |
316 | std::numeric_limits<Eigen::DenseIndex>::max())); |
317 | |
318 | Eigen::DenseIndex prefix_dim_size; |
319 | Eigen::DenseIndex split_dim_size; |
320 | Eigen::DenseIndex suffix_dim_size; |
321 | |
322 | std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) = |
323 | Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim); |
    std::vector<int64_t> split_start_points(num_split);
    split_start_points[0] = 0;
    for (int i = 1; i < num_split; ++i) {
      split_start_points[i] =
          split_start_points[i - 1] + split_sizes_vec[i - 1];
    }
333 | |
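    // If no dimensions precede split_dim, view the input as a 2-D
    // [split, suffix] tensor; otherwise as a 3-D [prefix, split, suffix]
    // tensor. The lambdas build the slice extents and reshape each output to
    // match.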
334 | if (prefix_dim_size == 1) { |
335 | auto input_reshaped = |
336 | input.shaped<T, 2>({split_dim_size, suffix_dim_size}); |
337 | auto make_sizes = [&](Eigen::DenseIndex split_size) { |
338 | return Eigen::DSizes<Eigen::DenseIndex, 2>{split_size, suffix_dim_size}; |
339 | }; |
340 | auto reshape_result = [&](Tensor* result, Tlen split_size) { |
341 | return result->shaped<T, 2>({split_size, suffix_dim_size}); |
342 | }; |
343 | SplitVOpCPUImpl<T, Tlen, decltype(input_reshaped), 2>{}( |
344 | context, input_reshaped, split_start_points, input_shape, split_dim, |
345 | prefix_dim_size, split_dim_size, suffix_dim_size, split_sizes_vec, |
346 | make_sizes, reshape_result); |
347 | } else { |
348 | auto input_reshaped = input.shaped<T, 3>( |
349 | {prefix_dim_size, split_dim_size, suffix_dim_size}); |
350 | auto make_sizes = [&](Eigen::DenseIndex split_size) { |
351 | return Eigen::DSizes<Eigen::DenseIndex, 3>{prefix_dim_size, split_size, |
352 | suffix_dim_size}; |
353 | }; |
354 | auto reshape_result = [&](Tensor* result, Tlen split_size) { |
355 | return result->shaped<T, 3>( |
356 | {prefix_dim_size, split_size, suffix_dim_size}); |
357 | }; |
358 | SplitVOpCPUImpl<T, Tlen, decltype(input_reshaped), 3>{}( |
359 | context, input_reshaped, split_start_points, input_shape, split_dim, |
360 | prefix_dim_size, split_dim_size, suffix_dim_size, split_sizes_vec, |
361 | make_sizes, reshape_result); |
362 | } |
363 | } |
364 | }; |
365 | |
366 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
367 | |
368 | // Partial specialization for GPU |
369 | template <typename T, typename Tlen> |
370 | class SplitVOpGPU : public SplitVOpBase<GPUDevice, T, Tlen> { |
371 | public: |
372 | typedef SplitVOpBase<GPUDevice, T, Tlen> Base; |
373 | explicit SplitVOpGPU(OpKernelConstruction* c) : Base(c) {} |
374 | |
375 | void Compute(OpKernelContext* context) override { |
376 | bool done = false; |
377 | std::vector<Tlen> split_sizes_vec; |
378 | Base::ComputeEasyCases(context, &done, &split_sizes_vec); |
379 | if (!context->status().ok() || done) { |
380 | return; |
381 | } |
382 | const int32_t num_split = Base::num_outputs(); |
383 | const Tensor& input = context->input(0); |
384 | const TensorShape& input_shape = input.shape(); |
385 | const int32_t split_dim_orig = context->input(2).flat<int32>()(0); |
386 | const int32_t split_dim = |
387 | split_dim_orig < 0 ? split_dim_orig + input.dims() : split_dim_orig; |
388 | OP_REQUIRES( |
389 | context, |
390 | FastBoundsCheck(input.NumElements(), std::numeric_limits<int32>::max()), |
391 | errors::InvalidArgument("Split on GPU requires input size " |
392 | "< max int32" )); |
393 | |
394 | int32_t prefix_dim_size; |
395 | int32_t split_dim_size; |
396 | int32_t suffix_dim_size; |
397 | std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) = |
398 | Base::template SetDims<int32>(input_shape, split_dim); |
399 | |
    // Use the same approach as concat (see the documentation there): reshape
    // the input to 2-D and split it along the second dimension.
402 | |
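    // With many outputs, launch a single fused kernel that receives the
    // output pointers and per-output column offsets through device-side
    // arrays; with few outputs, one Eigen slice per output (the else branch)
    // is cheap enough. The threshold of 16 is a heuristic.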
403 | if (num_split > 16) { |
404 | GpuDeviceArrayOnHost<T*> ptrs(context, num_split); |
405 | OP_REQUIRES_OK(context, ptrs.Init()); |
406 | |
407 | GpuDeviceArrayOnHost<Tlen> offsets(context, num_split + 1); |
408 | OP_REQUIRES_OK(context, offsets.Init()); |
409 | |
410 | Tlen offset = 0; |
      Tlen entry = split_sizes_vec[0];
      bool fixed_size =
          std::all_of(split_sizes_vec.begin(), split_sizes_vec.end(),
                      [&entry](Tlen n) { return n == entry; });
415 | |
416 | for (int i = 0; i < num_split; ++i) { |
417 | TensorShape output_shape(input_shape); |
418 | output_shape.set_dim(split_dim, split_sizes_vec[i]); |
419 | Tensor* result = nullptr; |
420 | OP_REQUIRES_OK(context, |
421 | context->allocate_output(i, output_shape, &result)); |
422 | ptrs.Set(i, result->flat<T>().data()); |
423 | offsets.Set(i, offset); |
424 | offset += split_sizes_vec[i] * suffix_dim_size; |
425 | } |
426 | offsets.Set(num_split, offset); |
427 | OP_REQUIRES_OK(context, ptrs.Finalize()); |
428 | OP_REQUIRES_OK(context, offsets.Finalize()); |
429 | |
430 | if (input.NumElements() > 0) { |
431 | SplitVOpGPULaunch<T, Tlen>().Run( |
432 | context->eigen_device<GPUDevice>(), fixed_size, |
433 | input.flat<T>().data(), prefix_dim_size, |
434 | input.NumElements() / prefix_dim_size, offsets.data(), ptrs.data()); |
435 | OP_REQUIRES( |
436 | context, context->op_device_context()->stream()->ok(), |
437 | errors::Internal("Launch of gpu kernel for SplitVOp failed" )); |
438 | } |
439 | } else { |
440 | Eigen::DenseIndex prefix_dim_size; |
441 | Eigen::DenseIndex split_dim_size; |
442 | Eigen::DenseIndex suffix_dim_size; |
443 | |
444 | std::tie(prefix_dim_size, split_dim_size, suffix_dim_size) = |
445 | Base::template SetDims<Eigen::DenseIndex>(input_shape, split_dim); |
446 | auto input_reshaped = input.shaped<T, 2>( |
447 | {prefix_dim_size, split_dim_size * suffix_dim_size}); |
448 | |
449 | Eigen::DSizes<Eigen::DenseIndex, 2> indices{0, 0}; |
450 | |
451 | for (int i = 0; i < num_split; ++i) { |
452 | TensorShape output_shape(input_shape); |
453 | output_shape.set_dim(split_dim, split_sizes_vec[i]); |
454 | Tensor* result = nullptr; |
455 | OP_REQUIRES_OK(context, |
456 | context->allocate_output(i, output_shape, &result)); |
457 | |
458 | Eigen::DSizes<Eigen::DenseIndex, 2> sizes{ |
459 | prefix_dim_size, split_sizes_vec[i] * suffix_dim_size}; |
460 | |
461 | if (sizes.TotalSize() > 0) { |
462 | auto result_shaped = result->shaped<T, 2>( |
463 | {prefix_dim_size, split_sizes_vec[i] * suffix_dim_size}); |
464 | |
465 | functor::SplitCustom<GPUDevice, T>()( |
466 | context->eigen_device<GPUDevice>(), result_shaped, input_reshaped, |
467 | indices, sizes); |
468 | } |
469 | indices[1] += split_sizes_vec[i] * suffix_dim_size; |
470 | } |
471 | } |
472 | } |
473 | }; |
474 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
475 | |
476 | #define REGISTER_SPLIT(type, len_type) \ |
477 | REGISTER_KERNEL_BUILDER(Name("SplitV") \ |
478 | .Device(DEVICE_CPU) \ |
479 | .TypeConstraint<len_type>("Tlen") \ |
480 | .TypeConstraint<type>("T") \ |
481 | .HostMemory("size_splits") \ |
482 | .HostMemory("split_dim"), \ |
483 | SplitVOpCPU<type, len_type>); |
484 | |
485 | #define REGISTER_SPLIT_LEN(type) \ |
486 | REGISTER_SPLIT(type, int8); \ |
487 | REGISTER_SPLIT(type, int32); \ |
488 | REGISTER_SPLIT(type, int64_t); |
489 | |
490 | TF_CALL_ALL_TYPES(REGISTER_SPLIT_LEN); |
491 | |
492 | #undef REGISTER_SPLIT_LEN |
493 | #undef REGISTER_SPLIT |
494 | |
495 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
496 | |
497 | #define REGISTER_GPU(type, len_type) \ |
498 | REGISTER_KERNEL_BUILDER(Name("SplitV") \ |
499 | .Device(DEVICE_GPU) \ |
500 | .TypeConstraint<len_type>("Tlen") \ |
501 | .TypeConstraint<type>("T") \ |
502 | .HostMemory("size_splits") \ |
503 | .HostMemory("split_dim"), \ |
504 | SplitVOpGPU<type, len_type>); |
505 | |
506 | #define REGISTER_GPU_LEN(type) \ |
507 | REGISTER_GPU(type, int8); \ |
508 | REGISTER_GPU(type, int32); \ |
509 | REGISTER_GPU(type, int64_t); |
510 | |
511 | TF_CALL_bfloat16(REGISTER_GPU_LEN); |
512 | TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_LEN); |
513 | TF_CALL_COMPLEX_TYPES(REGISTER_GPU_LEN); |
514 | #undef REGISTER_GPU_LEN |
515 | #undef REGISTER_GPU |
516 | |
// Special registration for int32 on GPU: int32 tensors live in host memory on
// GPU devices, so all operands are pinned to the host and the CPU
// implementation is used.
518 | |
519 | #define REGISTER_GPU_int32(len_type) \ |
520 | REGISTER_KERNEL_BUILDER(Name("SplitV") \ |
521 | .Device(DEVICE_GPU) \ |
522 | .TypeConstraint<int32>("T") \ |
523 | .TypeConstraint<len_type>("Tlen") \ |
524 | .HostMemory("size_splits") \ |
525 | .HostMemory("split_dim") \ |
526 | .HostMemory("value") \ |
527 | .HostMemory("output"), \ |
528 | SplitVOpCPU<int32, len_type>); |
529 | |
530 | REGISTER_GPU_int32(int32); |
531 | REGISTER_GPU_int32(int64_t); |
532 | |
533 | #undef REGISTER_GPU_int32 |
534 | |
535 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
536 | |
537 | } // end namespace tensorflow |
538 | |