1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | // See docs in ../ops/math_ops.cc. |
17 | |
18 | #ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_ |
19 | #define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_ |
20 | |
21 | #define EIGEN_USE_THREADS |
22 | |
23 | #include <type_traits> |
24 | #include <utility> |
25 | #include <vector> |
26 | |
27 | #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/framework/register_types.h" |
30 | #include "tensorflow/core/framework/tensor.h" |
31 | #include "tensorflow/core/framework/tensor_shape.h" |
32 | #include "tensorflow/core/framework/type_traits.h" |
33 | #include "tensorflow/core/framework/types.h" |
34 | #include "tensorflow/core/kernels/fill_functor.h" |
35 | #include "tensorflow/core/lib/core/errors.h" |
36 | #include "tensorflow/core/lib/gtl/inlined_vector.h" |
37 | #include "tensorflow/core/platform/logging.h" |
38 | #include "tensorflow/core/platform/types.h" |
39 | #include "tensorflow/core/util/matmul_autotune.h" |
40 | #include "tensorflow/core/util/matmul_bcast.h" |
41 | #include "tensorflow/core/util/work_sharder.h" |
42 | |
43 | #if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL) |
44 | #include "tensorflow/core/kernels/eigen_contraction_kernel.h" |
45 | #endif |
46 | |
47 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
48 | #include "tensorflow/core/kernels/gpu_utils.h" |
49 | #include "tensorflow/core/platform/stream_executor.h" |
50 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
51 | #if GOOGLE_CUDA |
52 | #include "third_party/gpus/cuda/include/cuda.h" |
53 | #include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h" |
54 | #include "tensorflow/compiler/xla/stream_executor/host_or_device_scalar.h" |
55 | #include "tensorflow/core/kernels/matmul_util.h" |
56 | #endif // GOOGLE_CUDA |
57 | |
58 | namespace tensorflow { |
59 | |
60 | typedef Eigen::ThreadPoolDevice CPUDevice; |
61 | typedef Eigen::GpuDevice GPUDevice; |
62 | |
63 | namespace { |
64 | |
65 | // Returns the pair of dimensions along which to perform Tensor contraction to |
66 | // emulate matrix multiplication. |
67 | // For matrix multiplication of 2-D tensors X and Y, X is contracted along |
68 | // its second dimension and Y along its first dimension (if neither X nor Y |
69 | // is adjointed). The dimension to contract along is switched for any |
70 | // operand that is adjointed. |
71 | // See http://en.wikipedia.org/wiki/Tensor_contraction |
72 | Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) { |
73 | return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0); |
74 | } |
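// A small illustrative sketch (not part of the kernels below): for 2-D
// tensors X (M x K) and Y (K x N) with no adjoints, ContractionDims returns
// the pair (1, 0), i.e. X's columns are contracted against Y's rows, which
// is the ordinary matrix product; with adj_x = true the pair becomes (0, 0),
// so X is contracted along its first dimension instead.
//
//   Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> pairs;
//   pairs[0] = ContractionDims(/*adj_x=*/false, /*adj_y=*/false);  // (1, 0)
//   // z = x.contract(y, pairs) then equals the usual X * Y.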
75 | |
76 | // Parallel batch matmul kernel based on the multi-threaded tensor contraction |
77 | // in Eigen. |
78 | template <typename Scalar, bool IsComplex = true> |
79 | struct ParallelMatMulKernel { |
80 | static void Conjugate(const OpKernelContext* context, Tensor* out) { |
81 | const Eigen::ThreadPoolDevice d = context->eigen_cpu_device(); |
82 | auto z = out->tensor<Scalar, 3>(); |
83 | z.device(d) = z.conjugate(); |
84 | } |
85 | |
86 | static void Run(const OpKernelContext* context, const Tensor& in_x, |
87 | const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, |
88 | bool trans_y, const MatMulBCast& bcast, Tensor* out, |
89 | int batch_size) { |
90 | static_assert(IsComplex, "Complex type expected."); |
91 | auto Tx = in_x.tensor<Scalar, 3>(); |
92 | auto Ty = in_y.tensor<Scalar, 3>(); |
93 | auto Tz = out->tensor<Scalar, 3>(); |
94 | // We use the identities |
95 | // conj(a) * conj(b) = conj(a * b) |
96 | // conj(a) * b = conj(a * conj(b)) |
97 | // to halve the number of cases. The final conjugation of the result is |
98 | // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch(). |
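// Concretely (shown for trans_x = trans_y = false), combined with the final
// conjugation applied in Launch() when adj_x is true, the loop below computes:
//   (adj_x, adj_y) = (F, F):  z = x * y
//   (adj_x, adj_y) = (F, T):  z = x * conj(y)^T       = x * adj(y)
//   (adj_x, adj_y) = (T, F):  z = conj(x^T * conj(y)) = adj(x) * y
//   (adj_x, adj_y) = (T, T):  z = conj(x^T * y^T)     = adj(x) * adj(y)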
99 | Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs; |
100 | contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y); |
101 | const Eigen::ThreadPoolDevice d = context->eigen_cpu_device(); |
102 | |
103 | const bool should_bcast = bcast.IsBroadcastingRequired(); |
104 | const auto& x_batch_indices = bcast.x_batch_indices(); |
105 | const auto& y_batch_indices = bcast.y_batch_indices(); |
106 | // TODO(rmlarsen): Consider launching these contractions asynchronously. |
107 | for (int64_t i = 0; i < batch_size; ++i) { |
108 | const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i; |
109 | const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i; |
110 | |
111 | auto x = Tx.template chip<0>(x_batch_index); |
112 | auto z = Tz.template chip<0>(i); |
113 | if (adj_x != adj_y) { |
114 | auto y = Ty.template chip<0>(y_batch_index).conjugate(); |
115 | z.device(d) = x.contract(y, contract_pairs); |
116 | } else { |
117 | auto y = Ty.template chip<0>(y_batch_index); |
118 | z.device(d) = x.contract(y, contract_pairs); |
119 | } |
120 | } |
121 | } |
122 | }; |
123 | |
124 | // The Eigen contraction kernel used here is very large and slow to compile, |
125 | // so we partially specialize ParallelMatMulKernel for real types to avoid all |
126 | // but one of the instantiations. |
127 | template <typename Scalar> |
128 | struct ParallelMatMulKernel<Scalar, false> { |
129 | static void Conjugate(const OpKernelContext* context, Tensor* out) {} |
130 | |
131 | static void Run(const OpKernelContext* context, const Tensor& in_x, |
132 | const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, |
133 | bool trans_y, const MatMulBCast& bcast, Tensor* out, |
134 | int batch_size) { |
135 | const bool should_bcast = bcast.IsBroadcastingRequired(); |
136 | const Eigen::ThreadPoolDevice d = context->eigen_cpu_device(); |
137 | Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs; |
138 | contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y); |
139 | if (batch_size == 1 && !should_bcast) { |
140 | auto Tx = in_x.flat_inner_dims<Scalar, 2>(); |
141 | auto Ty = in_y.flat_inner_dims<Scalar, 2>(); |
142 | auto Tz = out->flat_inner_dims<Scalar, 2>(); |
143 | Tz.device(d) = Tx.contract(Ty, contract_pairs); |
144 | } else { |
145 | auto Tx = in_x.tensor<Scalar, 3>(); |
146 | auto Ty = in_y.tensor<Scalar, 3>(); |
147 | auto Tz = out->tensor<Scalar, 3>(); |
148 | const auto& x_batch_indices = bcast.x_batch_indices(); |
149 | const auto& y_batch_indices = bcast.y_batch_indices(); |
150 | // TODO(rmlarsen): Consider launching these contractions asynchronously. |
151 | for (int64_t i = 0; i < batch_size; ++i) { |
152 | const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i; |
153 | const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i; |
154 | auto x = Tx.template chip<0>(x_batch_index); |
155 | auto y = Ty.template chip<0>(y_batch_index); |
156 | auto z = Tz.template chip<0>(i); |
157 | |
158 | z.device(d) = x.contract(y, contract_pairs); |
159 | } |
160 | } |
161 | } |
162 | }; |
163 | |
164 | // Sequential batch matmul kernel that calls the regular Eigen matmul. |
165 | // We prefer this over the tensor contraction because it performs |
166 | // better on vector-matrix and matrix-vector products. |
167 | template <typename Scalar> |
168 | struct SequentialMatMulKernel { |
169 | using Matrix = |
170 | Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>; |
171 | using ConstMatrixMap = Eigen::Map<const Matrix>; |
172 | using MatrixMap = Eigen::Map<Matrix>; |
173 | |
174 | static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t, |
175 | int slice) { |
176 | return ConstMatrixMap( |
177 | t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2), |
178 | t.dim_size(1), t.dim_size(2)); |
179 | } |
180 | |
181 | static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) { |
182 | return MatrixMap( |
183 | t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2), |
184 | t->dim_size(1), t->dim_size(2)); |
185 | } |
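// For a tensor t of shape [batch, rows, cols], slice i above maps the
// contiguous block starting at element offset i * rows * cols to a row-major
// rows x cols Eigen matrix. An illustrative (hypothetical) use:
//
//   Tensor t(DT_FLOAT, TensorShape({4, 2, 3}));
//   auto m = SequentialMatMulKernel<float>::ConstTensorSliceToEigenMatrix(
//       t, /*slice=*/1);  // 2 x 3 view of the second batch entry.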
186 | |
187 | static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x, |
188 | bool adj_y, bool trans_x, bool trans_y, |
189 | const MatMulBCast& bcast, Tensor* out, int start, int limit) { |
190 | const bool should_bcast = bcast.IsBroadcastingRequired(); |
191 | const auto& x_batch_indices = bcast.x_batch_indices(); |
192 | const auto& y_batch_indices = bcast.y_batch_indices(); |
193 | for (int64_t i = start; i < limit; ++i) { |
194 | const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i; |
195 | const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i; |
196 | auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index); |
197 | auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index); |
198 | auto z = TensorSliceToEigenMatrix(out, i); |
199 | // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y |
200 | // and trans_y. |
201 | if (!adj_x && !trans_x) { |
202 | if (!adj_y && !trans_y) { |
203 | z.noalias() = x * y; |
204 | } else if (adj_y) { |
205 | z.noalias() = x * y.adjoint(); |
206 | } else { // trans_y == true |
207 | z.noalias() = x * y.transpose(); |
208 | } |
209 | } else if (adj_x) { |
210 | if (!adj_y && !trans_y) { |
211 | z.noalias() = x.adjoint() * y; |
212 | } else if (adj_y) { |
213 | z.noalias() = x.adjoint() * y.adjoint(); |
214 | } else { // trans_y == true |
215 | z.noalias() = x.adjoint() * y.transpose(); |
216 | } |
217 | } else { // trans_x == true |
218 | if (!adj_y && !trans_y) { |
219 | z.noalias() = x.transpose() * y; |
220 | } else if (adj_y) { |
221 | z.noalias() = x.transpose() * y.adjoint(); |
222 | } else { // trans_y == true |
223 | z.noalias() = x.transpose() * y.transpose(); |
224 | } |
225 | } |
226 | } |
227 | } |
228 | }; |
229 | } // namespace |
230 | |
231 | template <typename Device, typename Scalar> |
232 | struct LaunchBatchMatMul; |
233 | |
234 | template <typename Scalar> |
235 | struct LaunchBatchMatMul<CPUDevice, Scalar> { |
236 | static void Launch(OpKernelContext* context, const Tensor& in_x, |
237 | const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, |
238 | bool trans_y, const MatMulBCast& bcast, Tensor* out) { |
239 | typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex> |
240 | ParallelMatMulKernel; |
241 | bool conjugate_result = false; |
242 | |
243 | // Number of matrix multiplies, i.e., the size of the batch. |
244 | const int64_t batch_size = bcast.output_batch_size(); |
245 | const int64_t cost_per_unit = |
246 | in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2); |
247 | const int64_t small_dim = std::min( |
248 | std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2)); |
249 | // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of |
250 | // Jan 21, 2020. |
251 | const int64_t kMaxCostOuterParallelism = 128 * 128; // heuristic. |
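// For instance (hypothetical shapes): a batch of 8 products of 256 x 256
// matrices has cost_per_unit = 256 * 256 * 256 > kMaxCostOuterParallelism,
// so the ParallelMatMulKernel branch below is taken; a batch of 1000
// products of 8 x 8 matrices instead falls through to the sharded
// SequentialMatMulKernel branch.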
252 | auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads()); |
253 | // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous |
254 | // evaluation in Eigen Tensor. |
255 | if (small_dim > 1 && |
256 | (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) { |
257 | // Parallelize over inner dims. |
258 | // For large matrix products it is counter-productive to parallelize |
259 | // over the batch dimension. |
260 | ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x, |
261 | trans_y, bcast, out, batch_size); |
262 | conjugate_result = adj_x; |
263 | } else { |
264 | // Parallelize over outer dims. For small matrices and large batches, it |
265 | // is counter-productive to parallelize the inner matrix multiplies. |
266 | Shard(worker_threads.num_threads, worker_threads.workers, batch_size, |
267 | cost_per_unit, |
268 | [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out]( |
269 | int start, int limit) { |
270 | SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y, |
271 | trans_x, trans_y, bcast, out, |
272 | start, limit); |
273 | }); |
274 | } |
275 | if (conjugate_result) { |
276 | // Since we used one of the identities |
277 | // conj(a) * conj(b) = conj(a * b) |
278 | // conj(a) * b = conj(a * conj(b)) |
279 | // above, we need to conjugate the final output here. This is a |
280 | // no-op for non-complex types. |
281 | ParallelMatMulKernel::Conjugate(context, out); |
282 | } |
283 | } |
284 | }; |
285 | |
286 | #if GOOGLE_CUDA |
287 | |
288 | namespace { |
289 | // A dummy type to group matmul autotune results together. |
290 | struct BlasLtMatmulAutoTuneGroup { |
291 | static string name() { return "MatmulLt"; } |
292 | }; |
293 | |
294 | typedef AutotuneSingleton<BlasLtMatmulAutoTuneGroup, BlasLtMatmulPlanParams, |
295 | se::blas::AlgorithmConfig, |
296 | absl::Hash<BlasLtMatmulPlanParams>> |
297 | AutoTuneBatchMatmul; |
298 | |
299 | } // namespace |
300 | |
301 | #endif // GOOGLE_CUDA |
302 | #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
303 | |
304 | class BlasScratchAllocator : public se::ScratchAllocator { |
305 | public: |
306 | using Stream = se::Stream; |
307 | using DeviceMemoryBytes = se::DeviceMemory<uint8>; |
308 | |
309 | BlasScratchAllocator(OpKernelContext* context) |
310 | : memory_limit_(0), total_byte_size_(0), context_(context) {} |
311 | |
312 | BlasScratchAllocator(OpKernelContext* context, int64_t memory_limit) |
313 | : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {} |
314 | |
315 | int64_t GetMemoryLimitInBytes() override { return memory_limit_; } |
316 | |
317 | se::port::StatusOr<DeviceMemoryBytes> AllocateBytes( |
318 | int64_t byte_size) override { |
319 | Tensor temporary_memory; |
320 | |
321 | if (memory_limit_ > 0 && byte_size > memory_limit_) { |
322 | return se::port::Status{ |
323 | se::port::error::UNAVAILABLE, |
324 | absl::StrCat("Requested memory size (" , byte_size, |
325 | ") exceeds the memory limit (" , memory_limit_, ")." )}; |
326 | } |
327 | AllocationAttributes allocation_attr; |
328 | allocation_attr.retry_on_failure = false; |
329 | Status allocation_status(context_->allocate_temp( |
330 | DT_UINT8, TensorShape({byte_size}), &temporary_memory, AllocatorAttributes(), allocation_attr)); |
331 | if (!allocation_status.ok()) { |
332 | return se::port::Status{ |
333 | se::port::error::UNAVAILABLE, |
334 | absl::StrCat("Failed to allocate requested memory of (" , byte_size, |
335 | ")." )}; |
336 | } |
337 | // Hold references to the allocated tensors until the allocator is |
338 | // destroyed. |
339 | allocated_tensors_.push_back(temporary_memory); |
340 | total_byte_size_ += byte_size; |
341 | return se::port::StatusOr<DeviceMemoryBytes>( |
342 | DeviceMemoryBytes::MakeFromByteSize( |
343 | temporary_memory.flat<uint8>().data(), |
344 | temporary_memory.flat<uint8>().size())); |
345 | } |
346 | int64 TotalByteSize() { return total_byte_size_; } |
347 | |
348 | private: |
349 | int64_t memory_limit_; |
350 | int64_t total_byte_size_; |
351 | OpKernelContext* context_; |
352 | std::vector<Tensor> allocated_tensors_; |
353 | }; |
354 | |
355 | template <typename Scalar> |
356 | struct LaunchBatchMatMul<GPUDevice, Scalar> { |
357 | static void Launch(OpKernelContext* context, const Tensor& in_x, |
358 | const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x, |
359 | bool trans_y, const MatMulBCast& bcast, Tensor* out) { |
360 | static const bool use_autotune = MatmulAutotuneEnable(); |
361 | se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose, |
362 | se::blas::Transpose::kTranspose, |
363 | se::blas::Transpose::kConjugateTranspose}; |
364 | const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1); |
365 | const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2); |
366 | const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2); |
367 | const int64_t batch_size = bcast.output_batch_size(); |
368 | auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)]; |
369 | auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)]; |
370 | |
371 | auto* stream = context->op_device_context()->stream(); |
372 | OP_REQUIRES(context, stream, errors::Internal("No GPU stream available.")); |
373 | |
374 | typedef se::DeviceMemory<Scalar> DeviceMemoryType; |
375 | std::vector<DeviceMemoryType> a_device_memory; |
376 | std::vector<DeviceMemoryType> b_device_memory; |
377 | std::vector<DeviceMemoryType> c_device_memory; |
378 | std::vector<DeviceMemoryType*> a_ptrs; |
379 | std::vector<DeviceMemoryType*> b_ptrs; |
380 | std::vector<DeviceMemoryType*> c_ptrs; |
381 | a_device_memory.reserve(bcast.x_batch_size()); |
382 | b_device_memory.reserve(bcast.y_batch_size()); |
383 | c_device_memory.reserve(batch_size); |
384 | a_ptrs.reserve(batch_size); |
385 | b_ptrs.reserve(batch_size); |
386 | c_ptrs.reserve(batch_size); |
387 | auto* a_base_ptr = in_x.template flat<Scalar>().data(); |
388 | auto* b_base_ptr = in_y.template flat<Scalar>().data(); |
389 | auto* c_base_ptr = out->template flat<Scalar>().data(); |
390 | uint64 a_stride; |
391 | uint64 b_stride; |
392 | uint64 c_stride; |
393 | |
394 | bool is_full_broadcast = |
395 | std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1; |
396 | |
397 | // Use float as coefficient type for half precision inputs, otherwise use |
398 | // the input type. |
399 | typedef std::conditional_t<std::is_same_v<Scalar, Eigen::half>, float, |
400 | Scalar> |
401 | Coefficient; |
402 | |
403 | #if GOOGLE_CUDA |
404 | if (EnableCublasLtGemm()) { |
405 | static const int64_t max_scratch_size = |
406 | GetWorkspaceLimit(1LL << 32); // 4GB by default |
407 | |
408 | bool requires_mixed_broadcasting = |
409 | bcast.IsBroadcastingRequired() && !is_full_broadcast; |
410 | |
411 | if (!requires_mixed_broadcasting) { |
412 | a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); |
413 | b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); |
414 | c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); |
415 | a_ptrs.push_back(&a_device_memory.back()); |
416 | b_ptrs.push_back(&b_device_memory.back()); |
417 | c_ptrs.push_back(&c_device_memory.back()); |
418 | |
419 | BlasLtMatmulPlanParams matmul_params{ |
420 | se::blas::ToDataType<Scalar>::value, |
421 | static_cast<size_t>(m), |
422 | static_cast<size_t>(n), |
423 | static_cast<size_t>(k), |
424 | blas_transpose_a, |
425 | blas_transpose_b, |
426 | static_cast<size_t>(batch_size), |
427 | /*broadcast_a=*/bcast.x_batch_size() == 1, |
428 | /*broadcast_b=*/bcast.y_batch_size() == 1}; |
429 | |
430 | std::optional<int> max_algorithm_count; |
431 | if (!use_autotune) max_algorithm_count = 1; |
432 | |
433 | auto plan_and_algorithms_or = |
434 | GetPlanAndAlgorithms(stream, matmul_params, max_algorithm_count); |
435 | OP_REQUIRES_OK(context, plan_and_algorithms_or.status()); |
436 | const auto* plan_and_algorithms = |
437 | std::move(plan_and_algorithms_or).value(); |
438 | const auto& plan = plan_and_algorithms->plan; |
439 | const auto& algorithms = plan_and_algorithms->algorithms; |
440 | |
441 | se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream); |
442 | OP_REQUIRES(context, blas_lt != nullptr, |
443 | errors::Internal("blaslt not supported" )); |
444 | |
445 | se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm); |
446 | if (!use_autotune) { |
447 | algorithm_config.set_algorithm(0); |
448 | } else if (!AutoTuneBatchMatmul::GetInstance()->Find( |
449 | matmul_params, &algorithm_config)) { |
450 | VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size() |
451 | << " algorithms." ; |
452 | se::blas::ProfileResult best_result; |
453 | se::blas::ProfileResult profile_result; |
454 | |
455 | for (size_t i = 0; i != algorithms.size(); ++i) { |
456 | const auto& profile_algorithm = algorithms[i]; |
457 | // Create a new scratch allocator with every autotuning run so that |
458 | // scratch space is deallocated between runs. |
459 | BlasScratchAllocator scratch_allocator(context, max_scratch_size); |
460 | Status cublas_launch_status = |
461 | DoBlasLtMatmul(stream, plan, *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], |
462 | profile_algorithm, scratch_allocator, |
463 | /*bias = */ {}, &profile_result); |
464 | |
465 | VLOG(4) << " Autotune algorithm " << i |
466 | << " result: " << profile_result.elapsed_time_in_ms() |
467 | << " ms, valid=" << profile_result.is_valid() |
468 | << ", workspace_size=" << profile_algorithm.workspace_size; |
469 | |
470 | if (cublas_launch_status.ok() && profile_result.is_valid() && |
471 | profile_result.elapsed_time_in_ms() < |
472 | best_result.elapsed_time_in_ms()) { |
473 | best_result = profile_result; |
474 | // Use index into algorithms array, instead of cublas internal ID. |
475 | best_result.set_algorithm(i); |
476 | } |
477 | } |
478 | |
479 | if (best_result.is_valid()) { |
480 | algorithm_config.set_algorithm(best_result.algorithm()); |
481 | } |
482 | // Each matmul parameter set gets one pass of |
483 | // autotune. If no algorithm works, kNoAlgorithm is added to the |
484 | // autotune map. |
485 | AutoTuneBatchMatmul::GetInstance()->Insert(matmul_params, |
486 | algorithm_config); |
487 | } |
488 | se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm(); |
489 | OP_REQUIRES(context, |
490 | 0 <= algorithm_idx && algorithm_idx < algorithms.size(), |
491 | errors::Internal("Missing/invalid BatchMatmul algorithm" )); |
492 | const auto& algorithm = algorithms[algorithm_idx]; |
493 | BlasScratchAllocator scratch_allocator(context, max_scratch_size); |
494 | VLOG(4) << "Calling BlasLtMatMul: a.shape=(" << bcast.x_batch_size() |
495 | << ", " << in_x.dim_size(1) << ", " << in_x.dim_size(2) |
496 | << "), b.shape=(" << bcast.y_batch_size() << ", " |
497 | << in_y.dim_size(1) << ", " << in_y.dim_size(2) << "), m=" << m |
498 | << ", n=" << n << ", k=" << k << ", batch_size=" << batch_size |
499 | << "trans_x = " << trans_x << "trans_y = " << trans_y |
500 | << "adj_x = " << adj_x << "adj_y = " << adj_y; |
501 | |
502 | OP_REQUIRES_OK( |
503 | context, DoBlasLtMatmul(stream, plan, *a_ptrs[0], *b_ptrs[0], |
504 | *c_ptrs[0], algorithm, scratch_allocator)); |
505 | } else { // requires mixed broadcasting |
506 | const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices(); |
507 | const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices(); |
508 | for (int64_t i = 0; i < bcast.x_batch_size(); ++i) { |
509 | a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); |
510 | } |
511 | for (int64_t i = 0; i < bcast.y_batch_size(); ++i) { |
512 | b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); |
513 | } |
514 | for (int64_t i = 0; i < batch_size; ++i) { |
515 | c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); |
516 | a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); |
517 | b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); |
518 | c_ptrs.push_back(&c_device_memory.back()); |
519 | } |
520 | |
521 | BlasScratchAllocator scratch_allocator(context, max_scratch_size); |
522 | bool blas_launch_status = |
523 | stream |
524 | ->ThenBlasGemmBatchedWithScratch( |
525 | blas_transpose_b, blas_transpose_a, n, m, k, |
526 | static_cast<Coefficient>(1.0), b_ptrs, |
527 | adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, |
528 | static_cast<Coefficient>(0.0), c_ptrs, n, batch_size, |
529 | &scratch_allocator) |
530 | .ok(); |
531 | if (!blas_launch_status) { |
532 | context->SetStatus(errors::Internal( |
533 | "Blas xGEMMBatched launch failed: a.shape=" , |
534 | in_x.shape().DebugString(), |
535 | ", b.shape=" , in_y.shape().DebugString(), ", m=" , m, ", n=" , n, |
536 | ", k=" , k, ", batch_size=" , batch_size)); |
537 | } |
538 | } |
539 | } else { |
540 | #endif // GOOGLE_CUDA |
541 | bool use_strided_batched = |
542 | (!bcast.IsBroadcastingRequired() || is_full_broadcast) && |
543 | batch_size > 1; |
544 | if (use_strided_batched) { |
545 | a_stride = bcast.x_batch_size() != 1 ? m * k : 0; |
546 | b_stride = bcast.y_batch_size() != 1 ? k * n : 0; |
547 | c_stride = m * n; |
548 | a_device_memory.push_back(AsDeviceMemory(a_base_ptr)); |
549 | b_device_memory.push_back(AsDeviceMemory(b_base_ptr)); |
550 | c_device_memory.push_back(AsDeviceMemory(c_base_ptr)); |
551 | a_ptrs.push_back(&a_device_memory.back()); |
552 | b_ptrs.push_back(&b_device_memory.back()); |
553 | c_ptrs.push_back(&c_device_memory.back()); |
554 | } else if (!bcast.IsBroadcastingRequired()) { |
555 | for (int64_t i = 0; i < batch_size; ++i) { |
556 | a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); |
557 | b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); |
558 | c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); |
559 | a_ptrs.push_back(&a_device_memory.back()); |
560 | b_ptrs.push_back(&b_device_memory.back()); |
561 | c_ptrs.push_back(&c_device_memory.back()); |
562 | } |
563 | } else { |
564 | const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices(); |
565 | const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices(); |
566 | for (int64_t i = 0; i < bcast.x_batch_size(); ++i) { |
567 | a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k)); |
568 | } |
569 | for (int64_t i = 0; i < bcast.y_batch_size(); ++i) { |
570 | b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n)); |
571 | } |
572 | for (int64_t i = 0; i < batch_size; ++i) { |
573 | c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n)); |
574 | a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]); |
575 | b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]); |
576 | c_ptrs.push_back(&c_device_memory.back()); |
577 | } |
578 | } |
579 | |
580 | // Blas does |
581 | // C = A x B |
582 | // where A, B and C are assumed to be in column major. |
583 | // We want the output to be in row-major, so we can compute |
584 | // C' = B' x A', where ' stands for transpose (not adjoint). |
585 | // TODO(yangzihao): Choose the best of the three strategies using |
586 | // autotune. |
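// A short sketch of the trick used below: a row-major M x N matrix
// reinterpreted as column-major storage is exactly its transpose (N x M).
// Handing BLAS the operands in the order (B, A) with dimensions (n, m, k)
// therefore computes, in column-major terms,
//   C_cm = B_cm * A_cm = B' * A' = (A * B)' = C',
// which, read back as row-major, is the desired C = A * B.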
587 | if (batch_size == 1) { |
588 | // This is a regular matrix*matrix or matrix*vector multiply. Avoid the |
589 | // overhead of the scratch allocator and the batch interface. |
590 | // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS |
591 | if constexpr (!std::is_same_v<Scalar, Eigen::half>) { |
592 | if (n == 1 && |
593 | blas_transpose_b != se::blas::Transpose::kConjugateTranspose && |
594 | blas_transpose_a != se::blas::Transpose::kConjugateTranspose) { |
595 | // This is a matrix*vector multiply so use GEMV to compute A * b. |
596 | // Here we are multiplying in the natural order, so we have to flip |
597 | // the transposition flag to compensate for the tensor being stored |
598 | // row-major. Since GEMV doesn't provide a way to just conjugate an |
599 | // argument, we have to defer those cases to GEMM below. |
600 | auto gemv_trans_a = |
601 | blas_transpose_a == se::blas::Transpose::kTranspose |
602 | ? se::blas::Transpose::kNoTranspose |
603 | : se::blas::Transpose::kTranspose; |
604 | bool blas_launch_status = |
605 | stream |
606 | ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k, |
607 | adj_x || trans_x ? k : m, |
608 | static_cast<Coefficient>(1.0), *(a_ptrs[0]), |
609 | adj_x || trans_x ? m : k, *(b_ptrs[0]), 1, |
610 | static_cast<Coefficient>(0.0), c_ptrs[0], 1) |
611 | .ok(); |
612 | if (!blas_launch_status) { |
613 | context->SetStatus(errors::Internal( |
614 | "Blas xGEMV launch failed : a.shape=" , |
615 | in_x.shape().DebugString(), ", b.shape=" , |
616 | in_y.shape().DebugString(), ", m=" , m, ", n=" , n, ", k=" , k)); |
617 | } |
618 | return; |
619 | } |
620 | } |
621 | |
622 | OP_REQUIRES_OK(context, |
623 | stream->ThenBlasGemm( |
624 | blas_transpose_b, blas_transpose_a, n, m, k, |
625 | *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]), |
626 | adj_x || trans_x ? m : k, c_ptrs[0], n, |
627 | se::blas::kDefaultComputePrecision)); |
628 | } else if (use_strided_batched) { |
629 | OP_REQUIRES_OK( |
630 | context, stream->ThenBlasGemmStridedBatched( |
631 | blas_transpose_b, blas_transpose_a, n, m, k, |
632 | static_cast<Coefficient>(1.0), *b_ptrs[0], |
633 | adj_y || trans_y ? k : n, b_stride, *a_ptrs[0], |
634 | adj_x || trans_x ? m : k, a_stride, |
635 | static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride, |
636 | batch_size, se::blas::kDefaultComputePrecision)); |
637 | } else { |
638 | BlasScratchAllocator scratch_allocator(context); |
639 | bool blas_launch_status = |
640 | stream |
641 | ->ThenBlasGemmBatchedWithScratch( |
642 | blas_transpose_b, blas_transpose_a, n, m, k, |
643 | static_cast<Coefficient>(1.0), b_ptrs, |
644 | adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k, |
645 | static_cast<Coefficient>(0.0), c_ptrs, n, batch_size, |
646 | &scratch_allocator) |
647 | .ok(); |
648 | if (!blas_launch_status) { |
649 | context->SetStatus(errors::Internal( |
650 | "Blas xGEMMBatched launch failed : a.shape=" , |
651 | in_x.shape().DebugString(), |
652 | ", b.shape=" , in_y.shape().DebugString(), ", m=" , m, ", n=" , n, |
653 | ", k=" , k, ", batch_size=" , batch_size)); |
654 | } |
655 | } |
656 | #if GOOGLE_CUDA |
657 | } |
658 | #endif // GOOGLE_CUDA |
659 | } |
660 | }; |
661 | |
662 | #endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM |
663 | |
664 | template <typename Device, typename Ta, typename Tb, typename Tout> |
665 | class BaseBatchMatMulOp : public OpKernel { |
666 | public: |
667 | explicit BaseBatchMatMulOp(OpKernelConstruction* context, |
668 | bool is_legacy_matmul) |
669 | : OpKernel(context) { |
670 | if (is_legacy_matmul) { |
671 | // The old MatMul kernel has "transpose_a/transpose_b" attributes. |
672 | OP_REQUIRES_OK(context, context->GetAttr("transpose_a" , &trans_x_)); |
673 | OP_REQUIRES_OK(context, context->GetAttr("transpose_b" , &trans_y_)); |
674 | adj_x_ = false; |
675 | adj_y_ = false; |
676 | } else { |
677 | OP_REQUIRES_OK(context, context->GetAttr("adj_x" , &adj_x_)); |
678 | OP_REQUIRES_OK(context, context->GetAttr("adj_y" , &adj_y_)); |
679 | trans_x_ = false; |
680 | trans_y_ = false; |
681 | } |
682 | } |
683 | |
684 | ~BaseBatchMatMulOp() override {} |
685 | |
686 | void Compute(OpKernelContext* ctx) override { |
687 | const Tensor& in0 = ctx->input(0); |
688 | const Tensor& in1 = ctx->input(1); |
689 | |
690 | const Status s = ValidateInputTensors(ctx, in0, in1); |
691 | if (!s.ok()) { |
692 | ctx->SetStatus(s); |
693 | return; |
694 | } |
695 | |
696 | MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes()); |
697 | OP_REQUIRES( |
698 | ctx, bcast.IsValid(), |
699 | errors::InvalidArgument( |
700 | "In[0] and In[1] must have compatible batch dimensions: " , |
701 | in0.shape().DebugString(), " vs. " , in1.shape().DebugString())); |
702 | |
703 | TensorShape out_shape = bcast.output_batch_shape(); |
704 | auto batch_size = bcast.output_batch_size(); |
705 | auto d0 = in0.dim_size(in0.dims() - 2); |
706 | auto d1 = in0.dim_size(in0.dims() - 1); |
707 | Tensor in0_reshaped; |
708 | OP_REQUIRES( |
709 | ctx, |
710 | in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})), |
711 | errors::Internal("Failed to reshape In[0] from " , |
712 | in0.shape().DebugString())); |
713 | auto d2 = in1.dim_size(in1.dims() - 2); |
714 | auto d3 = in1.dim_size(in1.dims() - 1); |
715 | Tensor in1_reshaped; |
716 | OP_REQUIRES( |
717 | ctx, |
718 | in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})), |
719 | errors::Internal("Failed to reshape In[1] from " , |
720 | in1.shape().DebugString())); |
721 | if (adj_x_ || trans_x_) std::swap(d0, d1); |
722 | if (adj_y_ || trans_y_) std::swap(d2, d3); |
723 | OP_REQUIRES( |
724 | ctx, d1 == d2, |
725 | errors::InvalidArgument( |
726 | "Matrix size-incompatible: In[0]: " , in0.shape().DebugString(), |
727 | ", In[1]: " , in1.shape().DebugString())); |
728 | out_shape.AddDim(d0); |
729 | out_shape.AddDim(d3); |
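// Worked example (hypothetical shapes): in0 = [2, 3, 5] with adj_x gives
// d0 = 5, d1 = 3 after the swap; in1 = [2, 3, 7] gives d2 = 3, d3 = 7, so the
// d1 == d2 check passes and out_shape becomes [2, 5, 7].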
730 | Tensor* out = nullptr; |
731 | OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out)); |
732 | if (out->NumElements() == 0) { |
733 | return; |
734 | } |
735 | if (in0.NumElements() == 0 || in1.NumElements() == 0) { |
736 | functor::SetZeroFunctor<Device, Tout> f; |
737 | f(ctx->eigen_device<Device>(), out->flat<Tout>()); |
738 | return; |
739 | } |
740 | Tensor out_reshaped; |
741 | OP_REQUIRES(ctx, |
742 | out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})), |
743 | errors::Internal("Failed to reshape output from " , |
744 | out->shape().DebugString())); |
745 | if (std::is_same<Ta, bfloat16>::value && |
746 | std::is_same<Tb, bfloat16>::value) { |
747 | bool is_cpu = std::is_same<Device, CPUDevice>::value; |
748 | OP_REQUIRES(ctx, is_cpu, |
749 | errors::Internal("bfloat16 matmul is not supported by GPU" )); |
750 | Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float; |
751 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(), |
752 | &in0_reshaped_float)); |
753 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(), |
754 | &in1_reshaped_float)); |
755 | OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(), |
756 | &out_reshaped_float)); |
757 | |
758 | // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU. |
759 | BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(), |
760 | in0_reshaped_float.flat<float>().data(), |
761 | in0_reshaped.NumElements()); |
762 | BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(), |
763 | in1_reshaped_float.flat<float>().data(), |
764 | in1_reshaped.NumElements()); |
765 | |
766 | LaunchBatchMatMul<Device, float>::Launch( |
767 | ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_, |
768 | trans_y_, bcast, &out_reshaped_float); |
769 | FloatToBFloat16(out_reshaped_float.flat<float>().data(), |
770 | out_reshaped.flat<bfloat16>().data(), out->NumElements()); |
771 | } else { |
772 | // Cast tensor to desired type to reuse Eigen. |
773 | // TODO(b/178749687): remove this cast if Eigen supports this natively. |
774 | if (!std::is_same<Ta, Tout>::value) { |
775 | in0_reshaped = CastTensor<Ta, Tout>(in0_reshaped); |
776 | } |
777 | if (!std::is_same<Tb, Tout>::value) { |
778 | in1_reshaped = CastTensor<Tb, Tout>(in1_reshaped); |
779 | } |
780 | LaunchBatchMatMul<Device, Tout>::Launch(ctx, in0_reshaped, in1_reshaped, |
781 | adj_x_, adj_y_, trans_x_, |
782 | trans_y_, bcast, &out_reshaped); |
783 | } |
784 | } |
785 | |
786 | protected: |
787 | virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, |
788 | const Tensor& in1) = 0; |
789 | |
790 | private: |
791 | // TODO(171979567) Make the ops take both adj and transpose attributes. |
792 | bool adj_x_ = false; |
793 | bool adj_y_ = false; |
794 | bool trans_x_ = false; |
795 | bool trans_y_ = false; |
796 | |
797 | // Cast `t` from `SrcT` to `DstT`. |
798 | template <typename SrcT, typename DstT> |
799 | Tensor CastTensor(const Tensor& t) { |
800 | Tensor res = Tensor(DataTypeToEnum<DstT>::v(), t.shape()); |
801 | res.flat<DstT>() = t.flat<SrcT>().template cast<DstT>(); |
802 | return res; |
803 | } |
804 | }; |
805 | |
806 | // BatchMatMul Op implementation which disallows broadcasting. |
807 | template <typename Device, typename Ta, typename Tb, typename Tout, |
808 | bool is_legacy_matmul = false> |
809 | class BatchMatMulOp : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> { |
810 | public: |
811 | explicit BatchMatMulOp(OpKernelConstruction* context) |
812 | : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context, is_legacy_matmul) {} |
813 | |
814 | ~BatchMatMulOp() override {} |
815 | |
816 | private: |
817 | Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, |
818 | const Tensor& in1) override { |
819 | // Disallow broadcasting support. Ensure that all batch dimensions of the |
820 | // input tensors match. |
821 | if (in0.dims() != in1.dims()) { |
822 | return errors::InvalidArgument( |
823 | "In[0] and In[1] has different ndims: " , in0.shape().DebugString(), |
824 | " vs. " , in1.shape().DebugString()); |
825 | } |
826 | const int ndims = in0.dims(); |
827 | if (is_legacy_matmul) { |
828 | if (ndims != 2) { |
829 | return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: " , |
830 | ndims); |
831 | } |
832 | } else { |
833 | if (ndims < 2) { |
834 | return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: " , |
835 | ndims); |
836 | } |
837 | for (int i = 0; i < ndims - 2; ++i) { |
838 | if (in0.dim_size(i) != in1.dim_size(i)) { |
839 | return errors::InvalidArgument( |
840 | "In[0].dim(" , i, ") and In[1].dim(" , i, |
841 | ") must be the same: " , in0.shape().DebugString(), " vs " , |
842 | in1.shape().DebugString()); |
843 | } |
844 | } |
845 | } |
846 | return OkStatus(); |
847 | } |
848 | }; |
849 | |
850 | // BatchMatMul Op implementation with broadcasting support. |
851 | template <typename Device, typename Ta, typename Tb, typename Tout> |
852 | class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> { |
853 | public: |
854 | explicit BatchMatMulV2Op(OpKernelConstruction* context) |
855 | : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context, |
856 | /* is_legacy_matmul= */ false) { |
857 | } |
858 | |
859 | ~BatchMatMulV2Op() override {} |
860 | |
861 | private: |
862 | Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0, |
863 | const Tensor& in1) override { |
864 | // Enable broadcasting support. Validity of broadcasting is checked in |
865 | // BaseBatchMatMulOp. |
866 | if (in0.dims() < 2) { |
867 | return errors::InvalidArgument("In[0] ndims must be >= 2: " , in0.dims()); |
868 | } |
869 | if (in1.dims() < 2) { |
870 | return errors::InvalidArgument("In[1] ndims must be >= 2: " , in1.dims()); |
871 | } |
872 | return OkStatus(); |
873 | } |
874 | }; |
875 | |
876 | // Register for MatMul, BatchMatMul, BatchMatMulV2 where Tin = Tout. |
877 | #define REGISTER_BATCH_MATMUL_CPU(TYPE) \ |
878 | REGISTER_KERNEL_BUILDER( \ |
879 | Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ |
880 | BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE>); \ |
881 | REGISTER_KERNEL_BUILDER( \ |
882 | Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ |
883 | BatchMatMulV2Op<CPUDevice, TYPE, TYPE, TYPE>); \ |
884 | REGISTER_KERNEL_BUILDER( \ |
885 | Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \ |
886 | BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>) |
887 | |
888 | #define REGISTER_BATCH_MATMUL_GPU(TYPE) \ |
889 | REGISTER_KERNEL_BUILDER( \ |
890 | Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ |
891 | BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE>); \ |
892 | REGISTER_KERNEL_BUILDER( \ |
893 | Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ |
894 | BatchMatMulV2Op<GPUDevice, TYPE, TYPE, TYPE>); \ |
895 | REGISTER_KERNEL_BUILDER( \ |
896 | Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \ |
897 | BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE, /* is_legacy_matmul=*/true>) |
898 | |
899 | // Register for BatchMatMulV3 where Ta, Tb and Tout are not the same. |
900 | #define REGISTER_BATCH_MATMUL_TOUT_CPU(Ta, Tb, Tout) \ |
901 | REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3") \ |
902 | .Device(DEVICE_CPU) \ |
903 | .TypeConstraint<Ta>("Ta") \ |
904 | .TypeConstraint<Tb>("Tb") \ |
905 | .TypeConstraint<Tout>("Tout"), \ |
906 | BatchMatMulV2Op<CPUDevice, Ta, Tb, Tout>) |
907 | |
908 | #define REGISTER_BATCH_MATMUL_TOUT_GPU(Ta, Tb, Tout) \ |
909 | REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3") \ |
910 | .Device(DEVICE_GPU) \ |
911 | .TypeConstraint<Ta>("Ta") \ |
912 | .TypeConstraint<Tb>("Tb") \ |
913 | .TypeConstraint<Tout>("Tout"), \ |
914 | BatchMatMulV2Op<GPUDevice, Ta, Tb, Tout>) |
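// These macros are meant to be expanded from separate translation units that
// include this header; a hypothetical instantiation would look like:
//
//   REGISTER_BATCH_MATMUL_CPU(float);
//   REGISTER_BATCH_MATMUL_TOUT_CPU(Eigen::half, Eigen::half, float);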
915 | |
916 | } // namespace tensorflow |
917 | |
918 | #endif // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_ |
919 | |