1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/math_ops.cc. |
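
// SparseTensorDenseMatMulOp computes out = op(A) * op(B), where A is a
// rank-2 SparseTensor in COO form (a_indices: [nnz, 2], a_values: [nnz],
// a_shape: [2]), B is a dense matrix, and op(X) is either X or its
// adjoint (conjugate transpose), selected by the adjoint_a / adjoint_b
// attributes.
//
// Worked example (dense view of A shown for illustration only):
//   A (2x3, nnz = 2): indices = [[0, 1], [1, 2]], values = [5, 7]
//     => dense A = [[0, 5, 0],
//                   [0, 0, 7]]
//   B (3x2)        = [[1, 2],
//                     [3, 4],
//                     [5, 6]]
//   out = A * B    = [[15, 20],
//                     [35, 42]]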
17 | |
18 | #define EIGEN_USE_THREADS |
19 | |
20 | #include "tensorflow/core/kernels/sparse_tensor_dense_matmul_op.h" |
21 | |
22 | #include "tensorflow/core/framework/bounds_check.h" |
23 | #include "tensorflow/core/framework/op.h" |
24 | #include "tensorflow/core/framework/op_kernel.h" |
25 | #include "tensorflow/core/kernels/fill_functor.h" |
26 | #include "tensorflow/core/platform/bfloat16.h" |
27 | |
28 | namespace tensorflow { |
29 | |
30 | typedef Eigen::ThreadPoolDevice CPUDevice; |
31 | typedef Eigen::GpuDevice GPUDevice; |
32 | |
33 | template <typename Device, typename T, typename Tindices> |
34 | class SparseTensorDenseMatMulOp : public OpKernel { |
35 | public: |
36 | explicit SparseTensorDenseMatMulOp(OpKernelConstruction* ctx) |
37 | : OpKernel(ctx) { |
    OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_a", &adjoint_a_));
    OP_REQUIRES_OK(ctx, ctx->GetAttr("adjoint_b", &adjoint_b_));
40 | } |
41 | |
42 | void Compute(OpKernelContext* ctx) override { |
43 | const Tensor* a_indices; |
44 | const Tensor* a_values; |
45 | const Tensor* a_shape; |
46 | const Tensor* b; |
    OP_REQUIRES_OK(ctx, ctx->input("a_indices", &a_indices));
    OP_REQUIRES_OK(ctx, ctx->input("a_values", &a_values));
    OP_REQUIRES_OK(ctx, ctx->input("a_shape", &a_shape));
    OP_REQUIRES_OK(ctx, ctx->input("b", &b));
51 | |
52 | // Check that the dimensions of the two matrices are valid. |
53 | OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(b->shape()), |
                errors::InvalidArgument("Tensor 'b' is not a matrix"));
55 | |
56 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_shape->shape()), |
                errors::InvalidArgument("Tensor 'a_shape' is not a vector"));
58 | |
59 | OP_REQUIRES( |
60 | ctx, a_shape->NumElements() == 2, |
        errors::InvalidArgument("Tensor 'a_shape' must have 2 elements"));
62 | |
63 | OP_REQUIRES(ctx, TensorShapeUtils::IsVector(a_values->shape()), |
                errors::InvalidArgument("Tensor 'a_values' is not a vector"));
65 | |
66 | OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(a_indices->shape()), |
                errors::InvalidArgument("Tensor 'a_indices' is not a matrix"));
68 | |
69 | const int64_t nnz = a_indices->shape().dim_size(0); |
    OP_REQUIRES(ctx, nnz == a_values->NumElements(),
                errors::InvalidArgument("Number of rows of a_indices does not "
                                        "match number of entries in a_values"));
73 | |
    OP_REQUIRES(
        ctx, a_indices->shape().dim_size(1) == a_shape->NumElements(),
        errors::InvalidArgument("Number of columns of a_indices does not match "
                                "number of entries in a_shape"));
78 | |
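    // With adjoint_a, the sparse operand is treated as conj(A^T), so its
    // logical shape is [a_shape[1], a_shape[0]]; adjoint_b likewise swaps
    // the roles of B's two dimensions.  outer_* are the output dimensions,
    // and the inner_* dimensions must agree.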
79 | auto a_shape_t = a_shape->vec<int64_t>(); |
80 | const int64_t outer_left = (adjoint_a_) ? a_shape_t(1) : a_shape_t(0); |
81 | const int64_t outer_right = |
82 | (adjoint_b_) ? b->shape().dim_size(0) : b->shape().dim_size(1); |
83 | const int64_t inner_left = (adjoint_a_) ? a_shape_t(0) : a_shape_t(1); |
84 | const int64_t inner_right = |
85 | (adjoint_b_) ? b->shape().dim_size(1) : b->shape().dim_size(0); |
86 | |
    OP_REQUIRES(
        ctx, inner_right == inner_left,
        errors::InvalidArgument(
            "Cannot multiply A and B because inner dimension does not match: ",
            inner_left, " vs. ", inner_right,
            ". Did you forget a transpose? "
            "Dimensions of A: [",
            a_shape_t(0), ", ", a_shape_t(1),
            "]. Dimensions of B: ", b->shape().DebugString()));
96 | |
97 | if (std::is_same<Device, GPUDevice>::value) { |
      // The GPU implementation is optimized to use 32-bit indexing, so
      // give a friendly error to the programmer early on if any dimension
      // exceeds that limit.
101 | const int int32max = std::numeric_limits<int>::max(); |
102 | OP_REQUIRES( |
103 | ctx, |
104 | (FastBoundsCheck(inner_left, int32max) && |
105 | FastBoundsCheck(inner_right, int32max) && |
106 | FastBoundsCheck(outer_left, int32max) && |
107 | FastBoundsCheck(outer_right, int32max) && |
108 | FastBoundsCheck(b->NumElements(), int32max) && |
109 | FastBoundsCheck(outer_left * outer_right, int32max) && |
110 | FastBoundsCheck(a_values->NumElements(), int32max)), |
          errors::InvalidArgument("Cannot use GPU for > 2^31 entry inputs"));
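      // The GPU functor additionally enumerates nnz * outer_right
      // (nonzero, output-column) pairs, hence the extra bound below.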
      OP_REQUIRES(ctx, FastBoundsCheck(nnz * outer_right, int32max),
                  errors::InvalidArgument(
                      "Cannot use GPU when output.shape[1] * nnz(a) > 2^31"));
115 | } |
116 | |
117 | TensorShape out_shape({outer_left, outer_right}); |
118 | Tensor* out = nullptr; |
119 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); |
120 | |
121 | if (out->NumElements() == 0) { |
122 | // If a has shape [0, x] or b has shape [x, 0], the output shape |
123 | // is a 0-element matrix, so there is nothing to do. |
124 | return; |
125 | } |
126 | |
127 | if (a_values->NumElements() == 0 || b->NumElements() == 0) { |
128 | // If a has shape [x, 0] and b has shape [0, y], the |
129 | // output shape is [x, y] where x and y are non-zero, so we fill |
130 | // the output with zeros. |
131 | functor::SetZeroFunctor<Device, T> f; |
132 | f(ctx->eigen_device<Device>(), out->flat<T>()); |
133 | return; |
134 | } |
135 | |
136 | #define MAYBE_ADJOINT(ADJ_A, ADJ_B) \ |
137 | if (adjoint_a_ == ADJ_A && adjoint_b_ == ADJ_B) { \ |
138 | Status functor_status = functor::SparseTensorDenseMatMulFunctor< \ |
139 | Device, T, Tindices, ADJ_A, \ |
140 | ADJ_B>::Compute(ctx, out->matrix<T>(), a_indices->matrix<Tindices>(), \ |
141 | a_values->vec<T>(), b->matrix<T>()); \ |
142 | OP_REQUIRES_OK(ctx, functor_status); \ |
143 | } |
144 | |
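    // Dispatch on the (adjoint_a_, adjoint_b_) attribute pair: exactly one
    // of the four cases below fires, and each instantiates the functor with
    // ADJ_A / ADJ_B as compile-time constants.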
145 | MAYBE_ADJOINT(false, false); |
146 | MAYBE_ADJOINT(false, true); |
147 | MAYBE_ADJOINT(true, false); |
148 | MAYBE_ADJOINT(true, true); |
149 | |
150 | #undef MAYBE_ADJOINT |
151 | } |
152 | |
153 | private: |
154 | bool adjoint_a_; |
155 | bool adjoint_b_; |
156 | }; |
157 | |
158 | #define REGISTER_CPU(TypeT, TypeIndex) \ |
159 | REGISTER_KERNEL_BUILDER( \ |
160 | Name("SparseTensorDenseMatMul") \ |
161 | .Device(DEVICE_CPU) \ |
162 | .TypeConstraint<TypeT>("T") \ |
163 | .TypeConstraint<TypeIndex>("Tindices") \ |
164 | .HostMemory("a_shape"), \ |
165 | SparseTensorDenseMatMulOp<CPUDevice, TypeT, TypeIndex>); |
166 | |
167 | #define REGISTER_KERNELS_CPU(T) \ |
168 | REGISTER_CPU(T, int64_t); \ |
169 | REGISTER_CPU(T, int32) |
170 | |
171 | REGISTER_KERNELS_CPU(Eigen::half); |
172 | REGISTER_KERNELS_CPU(float); |
173 | REGISTER_KERNELS_CPU(double); |
174 | REGISTER_KERNELS_CPU(int32); |
175 | REGISTER_KERNELS_CPU(complex64); |
176 | REGISTER_KERNELS_CPU(complex128); |
177 | REGISTER_KERNELS_CPU(bfloat16); |
178 | |
179 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
180 | |
181 | namespace functor { |
182 | #define DECLARE_GPU_SPEC(T, Tindices, ADJ_A, ADJ_B) \ |
183 | template <> \ |
184 | Status SparseTensorDenseMatMulFunctor< \ |
185 | GPUDevice, T, Tindices, ADJ_A, \ |
186 | ADJ_B>::Compute(OpKernelContext* ctx, typename TTypes<T>::Matrix out, \ |
                      typename TTypes<Tindices>::ConstMatrix a_indices,     \
188 | typename TTypes<T>::ConstVec a_values, \ |
189 | typename TTypes<T>::ConstMatrix b); \ |
190 | extern template struct SparseTensorDenseMatMulFunctor< \ |
191 | GPUDevice, T, Tindices, ADJ_A, ADJ_B>; |
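
// The explicit specializations declared above are defined in the GPU
// translation unit of this kernel; declaring them here (with the extern
// templates) keeps this file from instantiating the templates itself.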
192 | |
193 | #define REGISTER_GPU_SPEC(T, ADJ_A, ADJ_B) \ |
194 | DECLARE_GPU_SPEC(T, int32, ADJ_A, ADJ_B); \ |
195 | DECLARE_GPU_SPEC(T, int64_t, ADJ_A, ADJ_B) |
196 | |
197 | #define DECLARE_ADJOINT_GPU_SPEC(T) \ |
198 | REGISTER_GPU_SPEC(T, false, false) \ |
199 | REGISTER_GPU_SPEC(T, false, true) \ |
200 | REGISTER_GPU_SPEC(T, true, false) \ |
201 | REGISTER_GPU_SPEC(T, true, true) |
202 | |
203 | DECLARE_ADJOINT_GPU_SPEC(Eigen::half); |
204 | DECLARE_ADJOINT_GPU_SPEC(float); |
205 | DECLARE_ADJOINT_GPU_SPEC(double); |
206 | DECLARE_ADJOINT_GPU_SPEC(complex64); |
207 | DECLARE_ADJOINT_GPU_SPEC(complex128); |
208 | |
209 | #undef DECLARE_ADJOINT_GPU_SPEC |
210 | #undef DECLARE_GPU_SPEC |
211 | #undef REGISTER_GPU_SPEC |
212 | |
213 | } // namespace functor |
214 | |
215 | #define REGISTER_GPU(TypeT, TypeIndex) \ |
216 | REGISTER_KERNEL_BUILDER( \ |
217 | Name("SparseTensorDenseMatMul") \ |
218 | .Device(DEVICE_GPU) \ |
219 | .TypeConstraint<TypeT>("T") \ |
220 | .TypeConstraint<TypeIndex>("Tindices") \ |
221 | .HostMemory("a_shape"), \ |
222 | SparseTensorDenseMatMulOp<GPUDevice, TypeT, TypeIndex>); |
223 | |
224 | #define REGISTER_KERNELS_GPU(T) \ |
225 | REGISTER_GPU(T, int64_t); \ |
226 | REGISTER_GPU(T, int32) |
227 | |
228 | REGISTER_KERNELS_GPU(Eigen::half); |
229 | REGISTER_KERNELS_GPU(float); |
230 | REGISTER_KERNELS_GPU(double); |
231 | REGISTER_KERNELS_GPU(complex64); |
232 | REGISTER_KERNELS_GPU(complex128); |
233 | |
234 | #undef REGISTER_GPU |
235 | #undef REGISTER_KERNELS_GPU |
236 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
237 | |
238 | namespace functor { |
239 | |
240 | namespace { |
241 | Status KOutOfBoundsError(int64_t k, std::size_t i, int rhs_index_a, |
242 | std::size_t lhs_right) { |
  return errors::InvalidArgument("k (", k, ") from index[", i, ",", rhs_index_a,
                                 "] out of bounds (>=", lhs_right, ")");
245 | } |
246 | |
247 | Status MOutOfBoundsError(int64_t m, std::size_t i, int lhs_index_a, |
248 | int64_t out_dim0) { |
  return errors::InvalidArgument("m (", m, ") from index[", i, ",", lhs_index_a,
                                 "] out of bounds (>=", out_dim0, ")");
251 | } |
252 | |
253 | template <typename T, typename Tsum, typename Tindices, bool ADJ_A, bool ADJ_B> |
254 | Status SparseTensorDenseMatMulImpl( |
255 | typename TTypes<Tsum>::Matrix out, |
256 | typename TTypes<Tindices>::ConstMatrix a_indices, |
257 | typename TTypes<T>::ConstVec a_values, typename TTypes<T>::ConstMatrix b) { |
  // Take the vectorized (Eigen) path only when each output row has at
  // least this many elements.
259 | static constexpr std::size_t kNumVectorize = 32; |
260 | |
261 | const std::size_t nnz = a_values.size(); |
262 | const std::size_t rhs_right = (ADJ_B ? b.dimension(0) : b.dimension(1)); |
263 | const std::size_t lhs_right = (ADJ_B ? b.dimension(1) : b.dimension(0)); |
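  // Column lhs_index_a of a_indices supplies the output row index m and
  // column rhs_index_a supplies the contraction index k; adjoint_a swaps
  // which physical column plays which role.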
264 | const int lhs_index_a = ADJ_A ? 1 : 0; |
265 | const int rhs_index_a = ADJ_A ? 0 : 1; |
266 | |
  // TODO(ebrevdo): After many failed experiments, we can't find a
  // multi-threaded approach that achieves the performance of the
  // single-threaded one.  Perhaps the Eigen threadpool implementation is
  // just too slow?
270 | |
271 | if (rhs_right < kNumVectorize) { |
    // Disable vectorization when each output row is too small to benefit.
273 | auto maybe_adjoint_b = MaybeAdjoint<decltype(b), ADJ_B>(b); |
274 | |
275 | for (std::size_t i = 0; i < nnz; ++i) { |
276 | const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); |
277 | const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); |
278 | if (!FastBoundsCheck(k, lhs_right)) { |
279 | return KOutOfBoundsError(k, i, rhs_index_a, lhs_right); |
280 | } |
281 | if (!FastBoundsCheck(m, out.dimension(0))) { |
282 | return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); |
283 | } |
284 | const T a_value = ADJ_A ? MaybeConj(a_values(i)) : a_values(i); |
285 | for (std::size_t n = 0; n < rhs_right; ++n) { |
286 | const T b_value = maybe_adjoint_b(k, n); |
287 | out(m, n) += static_cast<Tsum>(a_value) * static_cast<Tsum>(b_value); |
288 | } |
289 | } |
290 | } else { |
291 | // Vectorization via Eigen. |
292 | const int b_chip_index = ADJ_B ? 1 : 0; |
293 | |
294 | #define LOOP_NNZ(b_passed) \ |
295 | for (std::size_t i = 0; i < nnz; ++i) { \ |
296 | const Tindices m = internal::SubtleMustCopy(a_indices(i, lhs_index_a)); \ |
297 | const Tindices k = internal::SubtleMustCopy(a_indices(i, rhs_index_a)); \ |
298 | const T a_value = (ADJ_A) ? MaybeConj(a_values(i)) : a_values(i); \ |
299 | if (!FastBoundsCheck(k, lhs_right)) { \ |
300 | return KOutOfBoundsError(k, i, rhs_index_a, lhs_right); \ |
301 | } \ |
302 | if (!FastBoundsCheck(m, out.dimension(0))) { \ |
303 | return MOutOfBoundsError(m, i, lhs_index_a, out.dimension(0)); \ |
304 | } \ |
305 | out.template chip<0>(m) += \ |
306 | b_passed.template chip<b_chip_index>(k).template cast<Tsum>() * \ |
307 | static_cast<Tsum>(a_value); \ |
308 | } |
309 | |
310 | if (ADJ_B) { |
311 | // Perform transpose and conjugation on B once, since we chip out B's |
312 | // columns in the nnz loop. |
      Eigen::array<int, 2> shuffle{1, 0};  // preserve dimension order
314 | Eigen::Tensor<T, 2, Eigen::ColMajor> col_major_conj_b = |
315 | b.swap_layout().shuffle(shuffle).conjugate(); |
316 | LOOP_NNZ(col_major_conj_b); |
317 | } else { |
318 | LOOP_NNZ(b); |
319 | } |
320 | #undef LOOP_NNZ |
321 | } |
322 | return OkStatus(); |
323 | } |
324 | } // namespace |
325 | |
326 | template <typename T, typename Tindices, bool ADJ_A, bool ADJ_B> |
327 | struct SparseTensorDenseMatMulFunctor<CPUDevice, T, Tindices, ADJ_A, ADJ_B> { |
328 | static Status Compute(OpKernelContext* ctx, typename TTypes<T>::Matrix out, |
329 | typename TTypes<Tindices>::ConstMatrix a_indices, |
330 | typename TTypes<T>::ConstVec a_values, |
331 | typename TTypes<T>::ConstMatrix b) { |
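    // SumType<T> widens reduced-precision accumulation (e.g. Eigen::half
    // accumulates in float); for full-precision T, Tsum == T.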
332 | using Tsum = typename SumType<T>::type; |
333 | Tensor temp_out_t; |
334 | if (!std::is_same<T, Tsum>::value) { |
335 | TF_RETURN_IF_ERROR(ctx->allocate_temp( |
336 | DataTypeToEnum<Tsum>::value, |
337 | TensorShape({out.dimension(0), out.dimension(1)}), &temp_out_t)); |
338 | auto temp_out = temp_out_t.matrix<Tsum>(); |
339 | temp_out.setZero(); |
340 | TF_RETURN_IF_ERROR( |
341 | SparseTensorDenseMatMulImpl<T, Tsum, Tindices, ADJ_A, ADJ_B>( |
342 | temp_out, a_indices, a_values, b)); |
343 | out = temp_out.template cast<T>(); |
344 | } else { |
345 | out.setZero(); |
      // This reinterpret_cast only avoids a compilation error: this branch
      // runs only when Tsum == T, so the cast is a no-op in practice.
348 | auto out_workaround = |
349 | *reinterpret_cast<typename TTypes<Tsum>::Matrix*>(&out); |
350 | TF_RETURN_IF_ERROR( |
351 | SparseTensorDenseMatMulImpl<T, Tsum, Tindices, ADJ_A, ADJ_B>( |
352 | out_workaround, a_indices, a_values, b)); |
353 | } |
354 | return OkStatus(); |
355 | } |
356 | }; |
357 | |
358 | } // namespace functor |
359 | |
360 | } // namespace tensorflow |
361 | |