1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/array_ops.cc. |
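// GPU implementation of ConcatGPU: concatenates flattened 2D inputs either
// through the slice-based path (ConcatGPUSlice) or through a custom kernel
// driven by device-side arrays of input pointers and output offsets
// (ConcatGPUCall / ConcatGPUImpl).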
17 | |
18 | #include <vector> |
19 | |
20 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
21 | #include "tensorflow/core/framework/op_kernel.h" |
22 | #include "tensorflow/core/framework/register_types.h" |
23 | #include "tensorflow/core/framework/tensor.h" |
24 | #include "tensorflow/core/framework/tensor_types.h" |
25 | #include "tensorflow/core/framework/types.h" |
26 | |
27 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
28 | |
29 | #include "tensorflow/core/kernels/concat_lib_gpu.h" |
30 | #include "tensorflow/core/kernels/gpu_device_array.h" |
31 | |
32 | namespace tensorflow { |
33 | namespace { |
34 | |
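// Builds device-visible arrays of per-input data pointers and of the output
// column offsets, then launches the concat kernel via ConcatGPUImpl.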
35 | template <typename T, typename IntType> |
36 | void ConcatGPUCall( |
37 | OpKernelContext* c, |
38 | const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& |
39 | inputs_flat, |
40 | typename TTypes<T, 2>::Tensor* output_flat) { |
41 | GpuDeviceArrayOnHost<const T*> input_ptrs(c, inputs_flat.size()); |
42 | OP_REQUIRES_OK(c, input_ptrs.Init()); |
43 | for (int i = 0; i < inputs_flat.size(); ++i) { |
44 | input_ptrs.Set(i, inputs_flat[i]->data()); |
45 | } |
46 | OP_REQUIRES_OK(c, input_ptrs.Finalize()); |
47 | |
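  // output_scan holds the running sum of input widths: entry i is the column
  // offset in the output at which input i starts, and the last entry is the
  // total output width.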
48 | GpuDeviceArrayOnHost<IntType> output_scan(c, inputs_flat.size() + 1); |
49 | OP_REQUIRES_OK(c, output_scan.Init()); |
50 | IntType scan = 0; |
51 | output_scan.Set(0, scan); |
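  // Detect whether every input has the same width along dimension 1.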
52 | bool one_size_input = true; |
53 | for (int i = 0; i < inputs_flat.size(); ++i) { |
54 | if (one_size_input && i < inputs_flat.size() - 1 && |
55 | inputs_flat[i]->dimension(1) != inputs_flat[i + 1]->dimension(1)) { |
56 | one_size_input = false; |
57 | } |
58 | scan += inputs_flat[i]->dimension(1); |
59 | output_scan.Set(i + 1, scan); |
60 | } |
61 | if (!one_size_input) OP_REQUIRES_OK(c, output_scan.Finalize()); |
62 | |
63 | ConcatGPUImpl<T, IntType>(c->eigen_gpu_device(), input_ptrs.data(), |
64 | output_scan.data(), one_size_input, |
65 | inputs_flat[0]->dimension(1), output_flat); |
66 | } |
67 | |
68 | } // end namespace |
69 | |
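// Concatenates the flattened 2D inputs along their second dimension into
// output_flat, selecting 32- or 64-bit indexing based on the output size.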
70 | template <typename T> |
71 | void ConcatGPU( |
72 | OpKernelContext* c, |
73 | const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& |
74 | inputs_flat, |
75 | Tensor* output, typename TTypes<T, 2>::Tensor* output_flat) { |
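  // With fewer than 16 inputs, copy each input with ConcatGPUSlice; beyond
  // that, build the device-side metadata and launch the concat kernel.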
76 | if (inputs_flat.size() < 16) { |
77 | if (output->NumElements() < std::numeric_limits<int32>::max()) { |
78 | ConcatGPUSlice<T, int32>(c->eigen_gpu_device(), inputs_flat, output_flat); |
79 | } else { |
80 | ConcatGPUSlice<T, int64_t>(c->eigen_gpu_device(), inputs_flat, |
81 | output_flat); |
82 | } |
83 | } else { |
    // Switching to int64 indexing can hurt performance, so the GPU kernel
    // keeps int32 indexing unless the output is too large to address with
    // int32.
87 | if (output->NumElements() < std::numeric_limits<int32>::max()) { |
88 | ConcatGPUCall<T, int32>(c, inputs_flat, output_flat); |
89 | } else { |
90 | ConcatGPUCall<T, int64_t>(c, inputs_flat, output_flat); |
91 | } |
92 | } |
93 | } |
94 | |
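// Explicitly instantiate ConcatGPU for every element type registered for the
// GPU concat kernels.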
95 | #define REGISTER(T) \ |
96 | template void ConcatGPU<T>( \ |
97 | OpKernelContext * c, \ |
98 | const std::vector<std::unique_ptr<typename TTypes<T, 2>::ConstMatrix>>& \ |
99 | inputs_flat, \ |
100 | Tensor* output, typename TTypes<T, 2>::Tensor* output_flat); |
101 | |
TF_CALL_INTEGRAL_TYPES(REGISTER);  // int32 needed for TensorLists.
103 | TF_CALL_bfloat16(REGISTER); |
104 | TF_CALL_GPU_ALL_TYPES(REGISTER); |
105 | |
106 | #undef REGISTER |
107 | |
108 | } // namespace tensorflow |
109 | |
110 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
111 | |