/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.

#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_

#define EIGEN_USE_THREADS

#include <type_traits>
#include <utility>
#include <vector>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/util/matmul_bcast.h"
#include "tensorflow/core/util/work_sharder.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/platform/stream_executor.h"
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#if GOOGLE_CUDA
#include "third_party/gpus/cuda/include/cuda.h"
#include "tensorflow/compiler/xla/stream_executor/cuda/cuda_blas_lt.h"
#include "tensorflow/compiler/xla/stream_executor/host_or_device_scalar.h"
#include "tensorflow/core/kernels/matmul_util.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

// Returns the pair of dimensions along which to perform Tensor contraction to
// emulate matrix multiplication.
// For matrix multiplication of 2D Tensors X and Y, X is contracted along its
// second dimension and Y along its first dimension (if neither X nor Y is
// adjointed). The contraction dimension of an operand switches when that
// operand is adjointed.
// See http://en.wikipedia.org/wiki/Tensor_contraction
Eigen::IndexPair<Eigen::DenseIndex> ContractionDims(bool adj_x, bool adj_y) {
  return Eigen::IndexPair<Eigen::DenseIndex>(adj_x ? 0 : 1, adj_y ? 1 : 0);
}
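
// Worked example (illustrative only, not used below): with adj_x = adj_y =
// false the pair is (1, 0), i.e. the columns of X are contracted against the
// rows of Y, which is plain matrix multiplication. Adjointing X moves its
// contraction axis to dimension 0; adjointing Y moves its axis to dimension 1.
// A minimal sketch of how the kernels below consume this value:
//
//   Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> pairs;
//   pairs[0] = ContractionDims(/*adj_x=*/false, /*adj_y=*/false);  // (1, 0)
//   // z = x.contract(y, pairs);  // equivalent to z = x * y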

// Parallel batch matmul kernel based on the multi-threaded tensor contraction
// in Eigen.
template <typename Scalar, bool IsComplex = true>
struct ParallelMatMulKernel {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    auto z = out->tensor<Scalar, 3>();
    z.device(d) = z.conjugate();
  }

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    static_assert(IsComplex, "Complex type expected.");
    auto Tx = in_x.tensor<Scalar, 3>();
    auto Ty = in_y.tensor<Scalar, 3>();
    auto Tz = out->tensor<Scalar, 3>();
    // We use the identities
    //   conj(a) * conj(b) = conj(a * b)
    //   conj(a) * b = conj(a * conj(b))
    // to halve the number of cases. The final conjugation of the result is
    // done at the end of LaunchBatchMatMul<CPUDevice, Scalar>::Launch().
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();

    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    // TODO(rmlarsen): Consider launching these contractions asynchronously.
    for (int64_t i = 0; i < batch_size; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;

      auto x = Tx.template chip<0>(x_batch_index);
      auto z = Tz.template chip<0>(i);
      if (adj_x != adj_y) {
        auto y = Ty.template chip<0>(y_batch_index).conjugate();
        z.device(d) = x.contract(y, contract_pairs);
      } else {
        auto y = Ty.template chip<0>(y_batch_index);
        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// The Eigen contraction kernel used here is very large and slow to compile,
// so we partially specialize ParallelMatMulKernel for real types to avoid all
// but one of the instantiations.
template <typename Scalar>
struct ParallelMatMulKernel<Scalar, false> {
  static void Conjugate(const OpKernelContext* context, Tensor* out) {}

  static void Run(const OpKernelContext* context, const Tensor& in_x,
                  const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                  bool trans_y, const MatMulBCast& bcast, Tensor* out,
                  int batch_size) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const Eigen::ThreadPoolDevice d = context->eigen_cpu_device();
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> contract_pairs;
    contract_pairs[0] = ContractionDims(adj_x || trans_x, adj_y || trans_y);
    if (batch_size == 1 && !should_bcast) {
      auto Tx = in_x.flat_inner_dims<Scalar, 2>();
      auto Ty = in_y.flat_inner_dims<Scalar, 2>();
      auto Tz = out->flat_inner_dims<Scalar, 2>();
      Tz.device(d) = Tx.contract(Ty, contract_pairs);
    } else {
      auto Tx = in_x.tensor<Scalar, 3>();
      auto Ty = in_y.tensor<Scalar, 3>();
      auto Tz = out->tensor<Scalar, 3>();
      const auto& x_batch_indices = bcast.x_batch_indices();
      const auto& y_batch_indices = bcast.y_batch_indices();
      // TODO(rmlarsen): Consider launching these contractions asynchronously.
      for (int64_t i = 0; i < batch_size; ++i) {
        const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
        const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
        auto x = Tx.template chip<0>(x_batch_index);
        auto y = Ty.template chip<0>(y_batch_index);
        auto z = Tz.template chip<0>(i);

        z.device(d) = x.contract(y, contract_pairs);
      }
    }
  }
};

// Sequential batch matmul kernel that calls the regular Eigen matmul.
// We prefer this over the tensor contraction because it performs
// better on vector-matrix and matrix-vector products.
template <typename Scalar>
struct SequentialMatMulKernel {
  using Matrix =
      Eigen::Matrix<Scalar, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
  using ConstMatrixMap = Eigen::Map<const Matrix>;
  using MatrixMap = Eigen::Map<Matrix>;

  static ConstMatrixMap ConstTensorSliceToEigenMatrix(const Tensor& t,
                                                      int slice) {
    return ConstMatrixMap(
        t.flat<Scalar>().data() + slice * t.dim_size(1) * t.dim_size(2),
        t.dim_size(1), t.dim_size(2));
  }

  static MatrixMap TensorSliceToEigenMatrix(Tensor* t, int slice) {
    return MatrixMap(
        t->flat<Scalar>().data() + slice * t->dim_size(1) * t->dim_size(2),
        t->dim_size(1), t->dim_size(2));
  }
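
  // Note on the two helpers above: they assume `t` has already been reshaped
  // to rank 3 ([batch, rows, cols]), so slice `i` maps to the contiguous block
  // of dim_size(1) * dim_size(2) elements holding the i-th batch matrix. The
  // reshape is performed by BaseBatchMatMulOp::Compute() below before the
  // kernels are launched; this comment documents that precondition rather
  // than adding a new requirement.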

  static void Run(const Tensor& in_x, const Tensor& in_y, bool adj_x,
                  bool adj_y, bool trans_x, bool trans_y,
                  const MatMulBCast& bcast, Tensor* out, int start, int limit) {
    const bool should_bcast = bcast.IsBroadcastingRequired();
    const auto& x_batch_indices = bcast.x_batch_indices();
    const auto& y_batch_indices = bcast.y_batch_indices();
    for (int64_t i = start; i < limit; ++i) {
      const int64_t x_batch_index = should_bcast ? x_batch_indices[i] : i;
      const int64_t y_batch_index = should_bcast ? y_batch_indices[i] : i;
      auto x = ConstTensorSliceToEigenMatrix(in_x, x_batch_index);
      auto y = ConstTensorSliceToEigenMatrix(in_y, y_batch_index);
      auto z = TensorSliceToEigenMatrix(out, i);
      // Assume at most one of adj_x or trans_x is true. Similarly, for adj_y
      // and trans_y.
      if (!adj_x && !trans_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x * y;
        } else if (adj_y) {
          z.noalias() = x * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x * y.transpose();
        }
      } else if (adj_x) {
        if (!adj_y && !trans_y) {
          z.noalias() = x.adjoint() * y;
        } else if (adj_y) {
          z.noalias() = x.adjoint() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.adjoint() * y.transpose();
        }
      } else {  // trans_x == true
        if (!adj_y && !trans_y) {
          z.noalias() = x.transpose() * y;
        } else if (adj_y) {
          z.noalias() = x.transpose() * y.adjoint();
        } else {  // trans_y == true
          z.noalias() = x.transpose() * y.transpose();
        }
      }
    }
  }
};
}  // namespace

template <typename Device, typename Scalar>
struct LaunchBatchMatMul;

template <typename Scalar>
struct LaunchBatchMatMul<CPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    typedef ParallelMatMulKernel<Scalar, Eigen::NumTraits<Scalar>::IsComplex>
        ParallelMatMulKernel;
    bool conjugate_result = false;

    // Number of matrix multiplies, i.e. the size of the batch.
    const int64_t batch_size = bcast.output_batch_size();
    const int64_t cost_per_unit =
        in_x.dim_size(1) * in_x.dim_size(2) * out->dim_size(2);
    const int64_t small_dim = std::min(
        std::min(in_x.dim_size(1), in_x.dim_size(2)), out->dim_size(2));
    // NOTE(nikhilsarda): This heuristic is optimal in benchmarks as of
    // Jan 21, 2020.
    const int64_t kMaxCostOuterParallelism = 128 * 128;  // heuristic.
    auto worker_threads = *(context->device()->tensorflow_cpu_worker_threads());
    // TODO(rmlarsen): Reconsider the heuristics now that we have asynchronous
    // evaluation in Eigen Tensor.
    if (small_dim > 1 &&
        (batch_size == 1 || cost_per_unit > kMaxCostOuterParallelism)) {
      // Parallelize over inner dims.
      // For large matrix products it is counter-productive to parallelize
      // over the batch dimension.
      ParallelMatMulKernel::Run(context, in_x, in_y, adj_x, adj_y, trans_x,
                                trans_y, bcast, out, batch_size);
      conjugate_result = adj_x;
    } else {
      // Parallelize over outer dims. For small matrices and large batches, it
      // is counter-productive to parallelize the inner matrix multiplies.
      Shard(worker_threads.num_threads, worker_threads.workers, batch_size,
            cost_per_unit,
            [&in_x, &in_y, adj_x, adj_y, trans_x, trans_y, &bcast, out](
                int start, int limit) {
              SequentialMatMulKernel<Scalar>::Run(in_x, in_y, adj_x, adj_y,
                                                  trans_x, trans_y, bcast, out,
                                                  start, limit);
            });
    }
    if (conjugate_result) {
      // We used one of the identities
      //   conj(a) * conj(b) = conj(a * b)
      //   conj(a) * b = conj(a * conj(b))
      // above, so we need to conjugate the final output. This is a
      // no-op for non-complex types.
      ParallelMatMulKernel::Conjugate(context, out);
    }
  }
};

#if GOOGLE_CUDA

namespace {
// A dummy type to group matmul autotune results together.
struct BlasLtMatmulAutoTuneGroup {
  static string name() { return "MatmulLt"; }
};

typedef AutotuneSingleton<BlasLtMatmulAutoTuneGroup, BlasLtMatmulPlanParams,
                          se::blas::AlgorithmConfig,
                          absl::Hash<BlasLtMatmulPlanParams>>
    AutoTuneBatchMatmul;

}  // namespace

#endif  // GOOGLE_CUDA
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM

class BlasScratchAllocator : public se::ScratchAllocator {
 public:
  using Stream = se::Stream;
  using DeviceMemoryBytes = se::DeviceMemory<uint8>;

  BlasScratchAllocator(OpKernelContext* context)
      : memory_limit_(0), total_byte_size_(0), context_(context) {}

  BlasScratchAllocator(OpKernelContext* context, int64_t memory_limit)
      : memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}

  int64_t GetMemoryLimitInBytes() override { return memory_limit_; }

  se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
      int64_t byte_size) override {
    Tensor temporary_memory;

    if (memory_limit_ > 0 && byte_size > memory_limit_) {
      return se::port::Status{
          se::port::error::UNAVAILABLE,
          absl::StrCat("Requested memory size (", byte_size,
                       ") exceeds the memory limit (", memory_limit_, ").")};
    }
    AllocationAttributes allocation_attr;
    allocation_attr.retry_on_failure = false;
    Status allocation_status(context_->allocate_temp(
        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
    if (!allocation_status.ok()) {
      return se::port::Status{
          se::port::error::UNAVAILABLE,
          absl::StrCat("Failed to allocate requested memory of size (",
                       byte_size, ").")};
    }
    // Hold a reference to the allocated tensors until the allocator is
    // destroyed so the scratch memory stays alive for the duration of the
    // BLAS call.
    allocated_tensors_.push_back(temporary_memory);
    total_byte_size_ += byte_size;
    return se::port::StatusOr<DeviceMemoryBytes>(
        DeviceMemoryBytes::MakeFromByteSize(
            temporary_memory.flat<uint8>().data(),
            temporary_memory.flat<uint8>().size()));
  }
  int64_t TotalByteSize() { return total_byte_size_; }

 private:
  int64_t memory_limit_;
  int64_t total_byte_size_;
  OpKernelContext* context_;
  std::vector<Tensor> allocated_tensors_;
};
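
// Usage note (a sketch based on the call sites below, not a new API): a fresh
// BlasScratchAllocator is constructed per BLAS call, so any workspace it hands
// out is released as soon as the allocator goes out of scope, e.g.
//
//   BlasScratchAllocator scratch_allocator(context, max_scratch_size);
//   // ... pass it to DoBlasLtMatmul or ThenBlasGemmBatchedWithScratch ...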

template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
  static void Launch(OpKernelContext* context, const Tensor& in_x,
                     const Tensor& in_y, bool adj_x, bool adj_y, bool trans_x,
                     bool trans_y, const MatMulBCast& bcast, Tensor* out) {
    static const bool use_autotune = MatmulAutotuneEnable();
    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose,
                                   se::blas::Transpose::kConjugateTranspose};
    const uint64 m = in_x.dim_size(adj_x || trans_x ? 2 : 1);
    const uint64 k = in_x.dim_size(adj_x || trans_x ? 1 : 2);
    const uint64 n = in_y.dim_size(adj_y || trans_y ? 1 : 2);
    const int64_t batch_size = bcast.output_batch_size();
    auto blas_transpose_a = trans[adj_x ? 2 : (trans_x ? 1 : 0)];
    auto blas_transpose_b = trans[adj_y ? 2 : (trans_y ? 1 : 0)];

    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    typedef se::DeviceMemory<Scalar> DeviceMemoryType;
    std::vector<DeviceMemoryType> a_device_memory;
    std::vector<DeviceMemoryType> b_device_memory;
    std::vector<DeviceMemoryType> c_device_memory;
    std::vector<DeviceMemoryType*> a_ptrs;
    std::vector<DeviceMemoryType*> b_ptrs;
    std::vector<DeviceMemoryType*> c_ptrs;
    a_device_memory.reserve(bcast.x_batch_size());
    b_device_memory.reserve(bcast.y_batch_size());
    c_device_memory.reserve(batch_size);
    a_ptrs.reserve(batch_size);
    b_ptrs.reserve(batch_size);
    c_ptrs.reserve(batch_size);
    auto* a_base_ptr = in_x.template flat<Scalar>().data();
    auto* b_base_ptr = in_y.template flat<Scalar>().data();
    auto* c_base_ptr = out->template flat<Scalar>().data();
    uint64 a_stride;
    uint64 b_stride;
    uint64 c_stride;

    bool is_full_broadcast =
        std::min(bcast.x_batch_size(), bcast.y_batch_size()) == 1;

    // Use float as coefficient type for half precision inputs, otherwise use
    // the input type.
    typedef std::conditional_t<std::is_same_v<Scalar, Eigen::half>, float,
                               Scalar>
        Coefficient;

#if GOOGLE_CUDA
    if (EnableCublasLtGemm()) {
      static const int64_t max_scratch_size =
          GetWorkspaceLimit(1LL << 32);  // 4GB by default

      bool requires_mixed_broadcasting =
          bcast.IsBroadcastingRequired() && !is_full_broadcast;

      if (!requires_mixed_broadcasting) {
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());

        BlasLtMatmulPlanParams matmul_params{
            se::blas::ToDataType<Scalar>::value,
            static_cast<size_t>(m),
            static_cast<size_t>(n),
            static_cast<size_t>(k),
            blas_transpose_a,
            blas_transpose_b,
            static_cast<size_t>(batch_size),
            /*broadcast_a=*/bcast.x_batch_size() == 1,
            /*broadcast_b=*/bcast.y_batch_size() == 1};

        std::optional<int> max_algorithm_count;
        if (!use_autotune) max_algorithm_count = 1;

        auto plan_and_algorithms_or =
            GetPlanAndAlgorithms(stream, matmul_params, max_algorithm_count);
        OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
        const auto* plan_and_algorithms =
            std::move(plan_and_algorithms_or).value();
        const auto& plan = plan_and_algorithms->plan;
        const auto& algorithms = plan_and_algorithms->algorithms;

        se::cuda::BlasLt* blas_lt = se::cuda::GetBlasLt(stream);
        OP_REQUIRES(context, blas_lt != nullptr,
                    errors::Internal("blaslt not supported"));

        se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
        if (!use_autotune) {
          algorithm_config.set_algorithm(0);
        } else if (!AutoTuneBatchMatmul::GetInstance()->Find(
                       matmul_params, &algorithm_config)) {
          VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
                  << " algorithms.";
          se::blas::ProfileResult best_result;
          se::blas::ProfileResult profile_result;

          for (size_t i = 0; i != algorithms.size(); ++i) {
            const auto& profile_algorithm = algorithms[i];
            // Create a new scratch allocator for every autotuning run so that
            // scratch space is deallocated between runs.
            BlasScratchAllocator scratch_allocator(context, max_scratch_size);
            Status cublas_launch_status =
                DoBlasLtMatmul(stream, plan, *a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
                               profile_algorithm, scratch_allocator,
                               /*bias=*/{}, &profile_result);

            VLOG(4) << " Autotune algorithm " << i
                    << " result: " << profile_result.elapsed_time_in_ms()
                    << " ms, valid=" << profile_result.is_valid()
                    << ", workspace_size=" << profile_algorithm.workspace_size;

            if (cublas_launch_status.ok() && profile_result.is_valid() &&
                profile_result.elapsed_time_in_ms() <
                    best_result.elapsed_time_in_ms()) {
              best_result = profile_result;
              // Use the index into the algorithms array rather than the cuBLAS
              // internal ID.
              best_result.set_algorithm(i);
            }
          }

          if (best_result.is_valid()) {
            algorithm_config.set_algorithm(best_result.algorithm());
          }
          // Each matmul parameter set gets one pass of autotune. If no
          // algorithm works, kNoAlgorithm is added to the autotune map.
          AutoTuneBatchMatmul::GetInstance()->Insert(matmul_params,
                                                     algorithm_config);
        }
        se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
        OP_REQUIRES(context,
                    0 <= algorithm_idx && algorithm_idx < algorithms.size(),
                    errors::Internal("Missing/invalid BatchMatmul algorithm"));
        const auto& algorithm = algorithms[algorithm_idx];
        BlasScratchAllocator scratch_allocator(context, max_scratch_size);
        VLOG(4) << "Calling BlasLtMatMul: a.shape=(" << bcast.x_batch_size()
                << ", " << in_x.dim_size(1) << ", " << in_x.dim_size(2)
                << "), b.shape=(" << bcast.y_batch_size() << ", "
                << in_y.dim_size(1) << ", " << in_y.dim_size(2) << "), m=" << m
                << ", n=" << n << ", k=" << k << ", batch_size=" << batch_size
                << ", trans_x=" << trans_x << ", trans_y=" << trans_y
                << ", adj_x=" << adj_x << ", adj_y=" << adj_y;

        OP_REQUIRES_OK(
            context, DoBlasLtMatmul(stream, plan, *a_ptrs[0], *b_ptrs[0],
                                    *c_ptrs[0], algorithm, scratch_allocator));
      } else {  // requires mixed broadcasting
        const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
        const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
        for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
          a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        }
        for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
          b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        }
        for (int64_t i = 0; i < batch_size; ++i) {
          c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
          a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
          b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
          c_ptrs.push_back(&c_device_memory.back());
        }

        BlasScratchAllocator scratch_allocator(context, max_scratch_size);
        bool blas_launch_status =
            stream
                ->ThenBlasGemmBatchedWithScratch(
                    blas_transpose_b, blas_transpose_a, n, m, k,
                    static_cast<Coefficient>(1.0), b_ptrs,
                    adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                    static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                    &scratch_allocator)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMMBatched launch failed: a.shape=",
              in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k, ", batch_size=", batch_size));
        }
      }
    } else {
#endif  // GOOGLE_CUDA
      bool use_strided_batched =
          (!bcast.IsBroadcastingRequired() || is_full_broadcast) &&
          batch_size > 1;
      if (use_strided_batched) {
        a_stride = bcast.x_batch_size() != 1 ? m * k : 0;
        b_stride = bcast.y_batch_size() != 1 ? k * n : 0;
        c_stride = m * n;
        a_device_memory.push_back(AsDeviceMemory(a_base_ptr));
        b_device_memory.push_back(AsDeviceMemory(b_base_ptr));
        c_device_memory.push_back(AsDeviceMemory(c_base_ptr));
        a_ptrs.push_back(&a_device_memory.back());
        b_ptrs.push_back(&b_device_memory.back());
        c_ptrs.push_back(&c_device_memory.back());
      } else if (!bcast.IsBroadcastingRequired()) {
        for (int64_t i = 0; i < batch_size; ++i) {
          a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
          b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
          c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
          a_ptrs.push_back(&a_device_memory.back());
          b_ptrs.push_back(&b_device_memory.back());
          c_ptrs.push_back(&c_device_memory.back());
        }
      } else {
        const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
        const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
        for (int64_t i = 0; i < bcast.x_batch_size(); ++i) {
          a_device_memory.push_back(AsDeviceMemory(a_base_ptr + i * m * k));
        }
        for (int64_t i = 0; i < bcast.y_batch_size(); ++i) {
          b_device_memory.push_back(AsDeviceMemory(b_base_ptr + i * k * n));
        }
        for (int64_t i = 0; i < batch_size; ++i) {
          c_device_memory.push_back(AsDeviceMemory(c_base_ptr + i * m * n));
          a_ptrs.push_back(&a_device_memory[a_batch_indices[i]]);
          b_ptrs.push_back(&b_device_memory[b_batch_indices[i]]);
          c_ptrs.push_back(&c_device_memory.back());
        }
      }

      // Blas does
      //   C = A x B
      // where A, B and C are assumed to be in column-major order.
      // We want the output to be in row-major order, so we can compute
      //   C' = B' x A', where ' stands for transpose (not adjoint).
      // TODO(yangzihao): Choose the best of the three strategies using
      // autotune.
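      // A concrete example of the trick above (illustrative only): for a
      // row-major product C[m x n] = A[m x k] * B[k x n], the same buffers
      // reinterpreted as column-major hold A', B' and C', so a generic GEMM of
      // the form
      //   gemm(trans_b, trans_a, /*m=*/n, /*n=*/m, /*k=*/k,
      //        B, /*ldb=*/n, A, /*lda=*/k, C, /*ldc=*/n)
      // writes B' x A' = (A x B)' into C's buffer, which row-major code then
      // reads back as the desired C. This is why every call below passes the
      // operands in (b, a) order with m and n swapped.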
      if (batch_size == 1) {
        // This is a regular matrix*matrix or matrix*vector multiply. Avoid the
        // overhead of the scratch allocator and the batch interface.
        // TODO(benbarsdell): Use fp16 Gemv if it becomes supported by CUBLAS.
        if constexpr (!std::is_same_v<Scalar, Eigen::half>) {
          if (n == 1 &&
              blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
              blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
            // This is a matrix*vector multiply so use GEMV to compute A * b.
            // Here we are multiplying in the natural order, so we have to flip
            // the transposition flag to compensate for the tensor being stored
            // row-major. Since GEMV doesn't provide a way to just conjugate an
            // argument, we have to defer those cases to GEMM below.
            auto gemv_trans_a =
                blas_transpose_a == se::blas::Transpose::kTranspose
                    ? se::blas::Transpose::kNoTranspose
                    : se::blas::Transpose::kTranspose;
            bool blas_launch_status =
                stream
                    ->ThenBlasGemv(gemv_trans_a, adj_x || trans_x ? m : k,
                                   adj_x || trans_x ? k : m,
                                   static_cast<Coefficient>(1.0), *(a_ptrs[0]),
                                   adj_x || trans_x ? m : k, *(b_ptrs[0]), 1,
                                   static_cast<Coefficient>(0.0), c_ptrs[0], 1)
                    .ok();
            if (!blas_launch_status) {
              context->SetStatus(errors::Internal(
                  "Blas xGEMV launch failed: a.shape=",
                  in_x.shape().DebugString(), ", b.shape=",
                  in_y.shape().DebugString(), ", m=", m, ", n=", n, ", k=", k));
            }
            return;
          }
        }

        OP_REQUIRES_OK(context,
                       stream->ThenBlasGemm(
                           blas_transpose_b, blas_transpose_a, n, m, k,
                           *(b_ptrs[0]), adj_y || trans_y ? k : n, *(a_ptrs[0]),
                           adj_x || trans_x ? m : k, c_ptrs[0], n,
                           se::blas::kDefaultComputePrecision));
      } else if (use_strided_batched) {
        OP_REQUIRES_OK(
            context, stream->ThenBlasGemmStridedBatched(
                         blas_transpose_b, blas_transpose_a, n, m, k,
                         static_cast<Coefficient>(1.0), *b_ptrs[0],
                         adj_y || trans_y ? k : n, b_stride, *a_ptrs[0],
                         adj_x || trans_x ? m : k, a_stride,
                         static_cast<Coefficient>(0.0), c_ptrs[0], n, c_stride,
                         batch_size, se::blas::kDefaultComputePrecision));
      } else {
        BlasScratchAllocator scratch_allocator(context);
        bool blas_launch_status =
            stream
                ->ThenBlasGemmBatchedWithScratch(
                    blas_transpose_b, blas_transpose_a, n, m, k,
                    static_cast<Coefficient>(1.0), b_ptrs,
                    adj_y || trans_y ? k : n, a_ptrs, adj_x || trans_x ? m : k,
                    static_cast<Coefficient>(0.0), c_ptrs, n, batch_size,
                    &scratch_allocator)
                .ok();
        if (!blas_launch_status) {
          context->SetStatus(errors::Internal(
              "Blas xGEMMBatched launch failed: a.shape=",
              in_x.shape().DebugString(),
              ", b.shape=", in_y.shape().DebugString(), ", m=", m, ", n=", n,
              ", k=", k, ", batch_size=", batch_size));
        }
      }
#if GOOGLE_CUDA
    }
#endif  // GOOGLE_CUDA
  }
};

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename Ta, typename Tb, typename Tout>
class BaseBatchMatMulOp : public OpKernel {
 public:
  explicit BaseBatchMatMulOp(OpKernelConstruction* context,
                             bool is_legacy_matmul)
      : OpKernel(context) {
    if (is_legacy_matmul) {
      // The old MatMul kernel has "transpose_a/transpose_b" attributes.
      OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &trans_x_));
      OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &trans_y_));
      adj_x_ = false;
      adj_y_ = false;
    } else {
      OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
      OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
      trans_x_ = false;
      trans_y_ = false;
    }
  }

  ~BaseBatchMatMulOp() override {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& in0 = ctx->input(0);
    const Tensor& in1 = ctx->input(1);

    const Status s = ValidateInputTensors(ctx, in0, in1);
    if (!s.ok()) {
      ctx->SetStatus(s);
      return;
    }

    MatMulBCast bcast(in0.shape().dim_sizes(), in1.shape().dim_sizes());
    OP_REQUIRES(
        ctx, bcast.IsValid(),
        errors::InvalidArgument(
            "In[0] and In[1] must have compatible batch dimensions: ",
            in0.shape().DebugString(), " vs. ", in1.shape().DebugString()));

    TensorShape out_shape = bcast.output_batch_shape();
    auto batch_size = bcast.output_batch_size();
    auto d0 = in0.dim_size(in0.dims() - 2);
    auto d1 = in0.dim_size(in0.dims() - 1);
    Tensor in0_reshaped;
    OP_REQUIRES(
        ctx,
        in0_reshaped.CopyFrom(in0, TensorShape({bcast.x_batch_size(), d0, d1})),
        errors::Internal("Failed to reshape In[0] from ",
                         in0.shape().DebugString()));
    auto d2 = in1.dim_size(in1.dims() - 2);
    auto d3 = in1.dim_size(in1.dims() - 1);
    Tensor in1_reshaped;
    OP_REQUIRES(
        ctx,
        in1_reshaped.CopyFrom(in1, TensorShape({bcast.y_batch_size(), d2, d3})),
        errors::Internal("Failed to reshape In[1] from ",
                         in1.shape().DebugString()));
    if (adj_x_ || trans_x_) std::swap(d0, d1);
    if (adj_y_ || trans_y_) std::swap(d2, d3);
    OP_REQUIRES(
        ctx, d1 == d2,
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", in0.shape().DebugString(),
            ", In[1]: ", in1.shape().DebugString()));
    out_shape.AddDim(d0);
    out_shape.AddDim(d3);
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
    if (out->NumElements() == 0) {
      return;
    }
    if (in0.NumElements() == 0 || in1.NumElements() == 0) {
      functor::SetZeroFunctor<Device, Tout> f;
      f(ctx->eigen_device<Device>(), out->flat<Tout>());
      return;
    }
    Tensor out_reshaped;
    OP_REQUIRES(ctx,
                out_reshaped.CopyFrom(*out, TensorShape({batch_size, d0, d3})),
                errors::Internal("Failed to reshape output from ",
                                 out->shape().DebugString()));
    if (std::is_same<Ta, bfloat16>::value &&
        std::is_same<Tb, bfloat16>::value) {
      bool is_cpu = std::is_same<Device, CPUDevice>::value;
      OP_REQUIRES(ctx, is_cpu,
                  errors::Internal("bfloat16 matmul is not supported by GPU"));
      Tensor in0_reshaped_float, in1_reshaped_float, out_reshaped_float;
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in0_reshaped.shape(),
                                             &in0_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, in1_reshaped.shape(),
                                             &in1_reshaped_float));
      OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_FLOAT, out_reshaped.shape(),
                                             &out_reshaped_float));

      // TODO: Avoid extra copy to make bfloat16 matmul efficient on CPU.
      BFloat16ToFloat(in0_reshaped.flat<bfloat16>().data(),
                      in0_reshaped_float.flat<float>().data(),
                      in0_reshaped.NumElements());
      BFloat16ToFloat(in1_reshaped.flat<bfloat16>().data(),
                      in1_reshaped_float.flat<float>().data(),
                      in1_reshaped.NumElements());

      LaunchBatchMatMul<Device, float>::Launch(
          ctx, in0_reshaped_float, in1_reshaped_float, adj_x_, adj_y_, trans_x_,
          trans_y_, bcast, &out_reshaped_float);
      FloatToBFloat16(out_reshaped_float.flat<float>().data(),
                      out_reshaped.flat<bfloat16>().data(), out->NumElements());
    } else {
      // Cast tensor to desired type to reuse Eigen.
      // TODO(b/178749687): remove this cast if Eigen supports this natively.
      if (!std::is_same<Ta, Tout>::value) {
        in0_reshaped = CastTensor<Ta, Tout>(in0_reshaped);
      }
      if (!std::is_same<Tb, Tout>::value) {
        in1_reshaped = CastTensor<Tb, Tout>(in1_reshaped);
      }
      LaunchBatchMatMul<Device, Tout>::Launch(ctx, in0_reshaped, in1_reshaped,
                                              adj_x_, adj_y_, trans_x_,
                                              trans_y_, bcast, &out_reshaped);
    }
  }

 protected:
  virtual Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                                      const Tensor& in1) = 0;

 private:
  // TODO(171979567): Make the ops take both adj and transpose attributes.
  bool adj_x_ = false;
  bool adj_y_ = false;
  bool trans_x_ = false;
  bool trans_y_ = false;

  // Cast `t` from `SrcT` to `DstT`.
  template <typename SrcT, typename DstT>
  Tensor CastTensor(const Tensor& t) {
    Tensor res = Tensor(DataTypeToEnum<DstT>::v(), t.shape());
    res.flat<DstT>() = t.flat<SrcT>().template cast<DstT>();
    return res;
  }
};

// BatchMatMul Op implementation which disallows broadcasting.
template <typename Device, typename Ta, typename Tb, typename Tout,
          bool is_legacy_matmul = false>
class BatchMatMulOp : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
 public:
  explicit BatchMatMulOp(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context, is_legacy_matmul) {}

  ~BatchMatMulOp() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Disallow broadcasting: ensure that all batch dimensions of the input
    // tensors match.
    if (in0.dims() != in1.dims()) {
      return errors::InvalidArgument(
          "In[0] and In[1] have different ndims: ", in0.shape().DebugString(),
          " vs. ", in1.shape().DebugString());
    }
    const int ndims = in0.dims();
    if (is_legacy_matmul) {
      if (ndims != 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be == 2: ",
                                       ndims);
      }
    } else {
      if (ndims < 2) {
        return errors::InvalidArgument("In[0] and In[1] ndims must be >= 2: ",
                                       ndims);
      }
      for (int i = 0; i < ndims - 2; ++i) {
        if (in0.dim_size(i) != in1.dim_size(i)) {
          return errors::InvalidArgument(
              "In[0].dim(", i, ") and In[1].dim(", i,
              ") must be the same: ", in0.shape().DebugString(), " vs ",
              in1.shape().DebugString());
        }
      }
    }
    return OkStatus();
  }
};

// BatchMatMul Op implementation with broadcasting support.
template <typename Device, typename Ta, typename Tb, typename Tout>
class BatchMatMulV2Op : public BaseBatchMatMulOp<Device, Ta, Tb, Tout> {
 public:
  explicit BatchMatMulV2Op(OpKernelConstruction* context)
      : BaseBatchMatMulOp<Device, Ta, Tb, Tout>(context,
                                                /*is_legacy_matmul=*/false) {}

  ~BatchMatMulV2Op() override {}

 private:
  Status ValidateInputTensors(OpKernelContext* ctx, const Tensor& in0,
                              const Tensor& in1) override {
    // Enable broadcasting support. Validity of broadcasting is checked in
    // BaseBatchMatMulOp.
    if (in0.dims() < 2) {
      return errors::InvalidArgument("In[0] ndims must be >= 2: ", in0.dims());
    }
    if (in1.dims() < 2) {
      return errors::InvalidArgument("In[1] ndims must be >= 2: ", in1.dims());
    }
    return OkStatus();
  }
};

// Register for MatMul, BatchMatMul and BatchMatMulV2 where Tin = Tout.
#define REGISTER_BATCH_MATMUL_CPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE>);                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<CPUDevice, TYPE, TYPE, TYPE>);                      \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<CPUDevice, TYPE, TYPE, TYPE, /*is_legacy_matmul=*/true>)

#define REGISTER_BATCH_MATMUL_GPU(TYPE)                                   \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),   \
      BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE>);                        \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("BatchMatMulV2").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"), \
      BatchMatMulV2Op<GPUDevice, TYPE, TYPE, TYPE>);                      \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("MatMul").Device(DEVICE_GPU).TypeConstraint<TYPE>("T"),        \
      BatchMatMulOp<GPUDevice, TYPE, TYPE, TYPE, /*is_legacy_matmul=*/true>)

// Register for BatchMatMulV3 where Ta, Tb and Tout are not the same.
#define REGISTER_BATCH_MATMUL_TOUT_CPU(Ta, Tb, Tout)         \
  REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
                              .Device(DEVICE_CPU)            \
                              .TypeConstraint<Ta>("Ta")      \
                              .TypeConstraint<Tb>("Tb")      \
                              .TypeConstraint<Tout>("Tout"), \
                          BatchMatMulV2Op<CPUDevice, Ta, Tb, Tout>)

#define REGISTER_BATCH_MATMUL_TOUT_GPU(Ta, Tb, Tout)         \
  REGISTER_KERNEL_BUILDER(Name("BatchMatMulV3")              \
                              .Device(DEVICE_GPU)            \
                              .TypeConstraint<Ta>("Ta")      \
                              .TypeConstraint<Tb>("Tb")      \
                              .TypeConstraint<Tout>("Tout"), \
                          BatchMatMulV2Op<GPUDevice, Ta, Tb, Tout>)
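
// Illustrative only: the actual instantiations live in the matmul_op_*.cc
// translation units that include this header, and the exact type lists used
// there are not shown in this file. The macros above are expanded once per
// dtype, along the lines of
//
//   REGISTER_BATCH_MATMUL_CPU(float);
//   REGISTER_BATCH_MATMUL_GPU(Eigen::half);
//   REGISTER_BATCH_MATMUL_TOUT_CPU(bfloat16, bfloat16, float);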

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_IMPL_H_