/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Implements matmul operations with other kernels baked into the
// processing, to optimize latency and memory usage:
//  - MatMul + BiasAdd + <Activation>
//
// Activation: Relu, Relu6, Elu, LeakyRelu, etc...
//
// Supported on CPU, and on GPU when TensorFlow is built with CUDA.
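//
// For illustration (a typical pattern produced by the grappler remapper pass;
// the exact attribute layout may vary between TF versions):
//
//   y = Relu(BiasAdd(MatMul(x, w), b))
//
// is rewritten into a single node
//
//   y = _FusedMatMul(x, w, b, fused_ops = ["BiasAdd", "Relu"])
//
// which this kernel then executes in one pass.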

#ifndef TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_
#define TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_

#define USE_EIGEN_TENSOR
#define EIGEN_USE_THREADS

#if GOOGLE_CUDA
#define EIGEN_USE_GPU
#endif  // GOOGLE_CUDA

#include <string>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/bounds_check.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/kernels/fused_eigen_output_kernels.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "tensorflow/core/util/tensor_format.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

#if GOOGLE_CUDA
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_asm_opts.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h"
#include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h"
#include "tensorflow/core/kernels/conv_ops_gpu.h"
#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/kernels/matmul_op_impl.h"
#include "tensorflow/core/kernels/matmul_util.h"
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/profiler/lib/scoped_annotation.h"
#include "tensorflow/core/protobuf/autotuning.pb.h"
#include "tensorflow/core/util/autotune_maps/conv_parameters.h"
#include "tensorflow/core/util/proto/proto_utils.h"
#include "tensorflow/core/util/use_cudnn.h"
#endif  // GOOGLE_CUDA

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

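// Launches the fused MatMul computation for a given device. The primary
// template is only declared; the CPU and (under GOOGLE_CUDA) GPU
// specializations below provide the actual implementations.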
template <typename Device, typename T>
struct LaunchFusedMatMulOp {
  void operator()(
      OpKernelContext* context, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      FusedComputationType fusion, const FusedComputationArgs& fusion_args,
      Tensor* output, bool use_autotune);
};

template <typename T>
struct LaunchFusedMatMulOp<CPUDevice, T> {
  void operator()(
      OpKernelContext* context, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      FusedComputationType fusion, const FusedComputationArgs& fusion_args,
      Tensor* output, bool use_autotune) {
    OP_REQUIRES(context, DataTypeToEnum<T>::value != DT_HALF,
                errors::InvalidArgument("_FusedMatMul doesn't support DT_HALF "
                                        "data type on CPU devices."));
    auto lhs = a.matrix<T>();
    auto rhs = b.matrix<T>();
    auto out = output->matrix<T>();

    auto& d = context->eigen_device<CPUDevice>();

    // Executes the Eigen contraction with the output kernel wrapped into a
    // type-erased wrapper to reduce the number of unique template
    // instantiations.
    auto executeWithOutputKernel = [&](auto output_kernel) {
      OutputKernelWrapper output_kernel_wrapper(
          [&output_kernel](
              const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
              const Eigen::TensorContractionParams& params, Eigen::Index i,
              Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) {
            output_kernel(output_mapper, params, i, j, num_rows, num_cols);
          });

      out.device(d) = lhs.contract(rhs, dim_pair, output_kernel_wrapper);
    };

    BiasAddArgs<T> bias_add_args;
    if (BiasAddArgs<T>::IsSupported(fusion)) {
      if (fusion == FusedComputationType::kBiasAddWithLeakyRelu) {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args,
                                                &fusion_args.leakyrelu_alpha));
      } else {
        OP_REQUIRES_OK(context, InitBiasAddArgs(context, &bias_add_args));
      }
    }

    switch (fusion) {
      case FusedComputationType::kBiasAdd:
        executeWithOutputKernel(WithBiasAdd<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithRelu:
        executeWithOutputKernel(WithBiasAddAndRelu<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithRelu6:
        executeWithOutputKernel(WithBiasAddAndRelu6<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithElu:
        executeWithOutputKernel(WithBiasAddAndElu<T>(bias_add_args));
        break;
      case FusedComputationType::kBiasAddWithLeakyRelu:
        executeWithOutputKernel(WithBiasAddAndLeakyRelu<T>(bias_add_args));
        break;
      case FusedComputationType::kUndefined:
        OP_REQUIRES_OK(context, errors::Internal("Fusion type is undefined"));
        break;
      default:
        OP_REQUIRES_OK(context,
                       errors::Internal("Fusion type is not supported"));
    }
  }

 private:
  // Wraps output_kernel into a type-erased struct to reduce the number of
  // unique template instantiations for Eigen Tensor contraction expressions.
  //
  // We do not pass std::function directly as an output kernel because it blows
  // up the binary size in debug mode with very long symbol names.
  struct OutputKernelWrapper {
    using OutputKernelFn =
        std::function<void(const ContractionOutputMapper<T, Eigen::Index>&,
                           const Eigen::TensorContractionParams&, Eigen::Index,
                           Eigen::Index, Eigen::Index, Eigen::Index)>;

    explicit OutputKernelWrapper(OutputKernelFn fn)
        : output_kernel_fn(std::move(fn)) {}

    void operator()(
        const ContractionOutputMapper<T, Eigen::Index>& output_mapper,
        const Eigen::TensorContractionParams& params, Eigen::Index i,
        Eigen::Index j, Eigen::Index num_rows, Eigen::Index num_cols) const {
      output_kernel_fn(output_mapper, params, i, j, num_rows, num_cols);
    }

    OutputKernelFn output_kernel_fn;
  };
};

#if GOOGLE_CUDA
namespace {

StatusOr<se::cuda::BlasLt::Epilogue> GetBlasLtEpilogOp(
    FusedComputationType fusion) {
  if (fusion == FusedComputationType::kBiasAdd) {
    return se::cuda::BlasLt::Epilogue::kBias;
  } else if (fusion == FusedComputationType::kBiasAddWithRelu) {
    return se::cuda::BlasLt::Epilogue::kBiasThenReLU;
  } else if (fusion == FusedComputationType::kBiasAddWithGeluApproximate) {
    return se::cuda::BlasLt::Epilogue::kBiasThenGeLUApproximate;
  } else {
    return errors::Internal("Unsupported fusion for BlasLt Matmul");
  }
}

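// Runs every candidate BlasLt algorithm once, picks the fastest valid one, and
// caches the resulting AlgorithmConfig in the AutoTuneBatchMatmul singleton
// keyed by `matmul_params`, so subsequent calls with the same parameters skip
// profiling.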
template <typename LaunchFunc>
se::blas::AlgorithmConfig AutotuneMatmul(
    const std::vector<se::cuda::BlasLt::MatmulAlgorithm>& algorithms,
    BlasLtMatmulPlanParams& matmul_params, OpKernelContext* context,
    const LaunchFunc& launch_func) {
  // Note that algorithm_config.algorithm() here is used to refer
  // to the index within the algorithms vector, not the algorithm
  // itself.
  se::blas::AlgorithmConfig algorithm_config(se::blas::kNoAlgorithm);
  if (!AutoTuneBatchMatmul::GetInstance()->Find(matmul_params,
                                                &algorithm_config)) {
    VLOG(4) << "Autotuning BlasLtMatmul over " << algorithms.size()
            << " algorithms.";
    se::blas::ProfileResult best_result;
    se::blas::ProfileResult profile_result;

    for (size_t i = 0; i != algorithms.size(); ++i) {
      const auto& profile_algorithm = algorithms[i];

      // Create a new scratch allocator with every autotuning run so that
      // scratch space is deallocated between runs.
      BlasScratchAllocator scratch_allocator(context);

      Status cublaslt_launch =
          launch_func(scratch_allocator, profile_algorithm, &profile_result);

      VLOG(4) << "  Autotune algorithm " << i
              << " result: " << profile_result.elapsed_time_in_ms()
              << " ms, valid=" << profile_result.is_valid()
              << ", workspace_size=" << profile_algorithm.workspace_size;

      if (cublaslt_launch.ok() && profile_result.is_valid() &&
          profile_result.elapsed_time_in_ms() <
              best_result.elapsed_time_in_ms()) {
        best_result = profile_result;
        // Use index into algorithms array, instead of cublas internal ID.
        best_result.set_algorithm(i);
      }
    }

    if (best_result.is_valid()) {
      algorithm_config.set_algorithm(best_result.algorithm());
    }
    // We make sure that each matmul parameter set only gets one pass of
    // autotune. If no algorithm works, we add kNoAlgorithm to the autotune
    // map.
    AutoTuneBatchMatmul::GetInstance()->Insert(matmul_params, algorithm_config);
  }
  return algorithm_config;
}

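// Profiles each cuDNN fused-matmul runner once with `launch_func` and returns
// one AutotuneResult per runner; failed executions are recorded as failure
// results so the output stays one-to-one with `runners`. When
// `actually_do_autotune` is false, every runner is reported with a zero
// elapsed time instead of being executed.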
template <typename LaunchFunc, typename Sig>
StatusOr<std::vector<tensorflow::AutotuneResult>> AutotuneMatMulImpl(
    OpKernelContext* ctx,
    std::vector<std::unique_ptr<const se::dnn::OpRunner<Sig>>>& runners,
    bool actually_do_autotune, const LaunchFunc& launch_func,
    size_t scratch_size_limit, const se::RedzoneAllocator& rz_allocator) {
  auto* stream = ctx->op_device_context()->stream();

  se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}),
                                              stream);

  std::vector<tensorflow::AutotuneResult> results;
  results.reserve(runners.size());
  // TODO(reedwm): Warn if determinism is enabled after autotune is run
  for (auto& runner : runners) {
    // TODO(zhengxq): profile each algorithm multiple times for better
    // accuracy.
    se::RedzoneAllocator rz_scratch_allocator(
        stream, &tf_allocator_adapter, se::GpuAsmOpts(),
        /*memory_limit=*/scratch_size_limit);
    BlasScratchAllocator scratch_allocator(ctx, scratch_size_limit);
    se::ScratchAllocator* allocator_used =
        !RedzoneCheckDisabled()
            ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator)
            : static_cast<se::ScratchAllocator*>(&scratch_allocator);

    TF_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc());
    se::dnn::ProfileResult profile_result;
    Status cudnn_launch_status =
        actually_do_autotune
            ? launch_func(allocator_used, runner, &profile_result)
            : OkStatus();
    if (!actually_do_autotune) {
      // Make the result valid according to `is_valid`.
      profile_result.set_algorithm(desc);
      profile_result.set_elapsed_time_in_ms(0);
    }

    // We need to make sure the profiling results are one-to-one with the
    // "runners". So, we insert dummy results when the execution fails.
    results.emplace_back();
    auto& result = results.back();
    *result.mutable_algorithm() = desc.ToProto();
    if (cudnn_launch_status.ok() && profile_result.is_valid()) {
      result.set_scratch_bytes(
          !RedzoneCheckDisabled()
              ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones()
              : scratch_allocator.TotalByteSize());
      *result.mutable_run_time() = proto_utils::ToDurationProto(
          absl::Milliseconds(profile_result.elapsed_time_in_ms()));

      CheckRedzones(rz_scratch_allocator, &result);
      CheckRedzones(rz_allocator, &result);
    } else {
      result.mutable_failure()->set_kind(AutotuneResult::UNKNOWN);
      result.mutable_failure()->set_msg(
          absl::StrCat("Profiling failure on CUDNN engine ", desc.ToString(),
                       ": ", cudnn_launch_status.ToString()));
    }
  }

  return results;
}

struct FusedMatmulAutotuneGroup {
  static string name() { return "FusedMatmul"; }
};

typedef AutotuneSingleton<FusedMatmulAutotuneGroup, MatmulParameters,
                          AutotuneEntry<se::dnn::FusedMatmulOp>>
    FusedMatmulAutotuneMap;

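// Looks up (or, on a cache miss, autotunes and inserts) the best cuDNN
// fused-matmul engine for `params`. Heuristics engines are tried first;
// fallback engines are only profiled when every heuristics engine fails.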
template <typename T>
StatusOr<AutotuneEntry<se::dnn::FusedMatmulOp>> AutotuneFusedMatmul(
    bool cudnn_use_autotune,
    AutotuneMap<MatmulParameters, AutotuneEntry<se::dnn::FusedMatmulOp>>*
        autotune_map,
    const MatmulParameters& params, OpKernelContext* ctx, bool trans_a,
    bool trans_b, uint64_t m, uint64_t n, uint64_t k, int64_t lda, int64_t ldb,
    int64_t ldc, se::dnn::ActivationMode activation_mode,
    se::DeviceMemory<T> a_ptr, se::DeviceMemory<T> b_ptr,
    se::DeviceMemory<T> c_ptr, se::DeviceMemory<T> bias_ptr,
    int64_t scratch_size_limit) {
  AutotuneEntry<se::dnn::FusedMatmulOp> autotune_entry;
  auto* stream = ctx->op_device_context()->stream();

  if (!autotune_map->Find(params, &autotune_entry)) {
    profiler::ScopedAnnotation trace("cudnn_autotuning");

    se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}),
                                                stream);
    se::RedzoneAllocator rz_allocator(stream, &tf_allocator_adapter,
                                      se::GpuAsmOpts());
    se::DeviceMemory<T> c_ptr_rz(WrapRedzoneBestEffort(&rz_allocator, c_ptr));

    std::vector<std::unique_ptr<const se::dnn::FusedMatmulRunner>> runners;
    auto element_type = se::dnn::ToDataType<T>::value;
    TF_RETURN_IF_ERROR(stream->parent()->GetFusedMatmulRunners(
        CudnnUseFrontend(), element_type, element_type, element_type, stream,
        trans_a, trans_b, m, n, k, lda, ldb, ldc, activation_mode,
        /*use_fallback=*/false, &runners));

    auto launch_func =
        [&](se::ScratchAllocator* allocator_used,
            const std::unique_ptr<const se::dnn::FusedMatmulRunner>& runner,
            se::dnn::ProfileResult* profile_result) -> Status {
      TF_ASSIGN_OR_RETURN(auto scratch, allocator_used->AllocateBytes(
                                            runner->GetWorkspaceSize()));
      return (*runner)(stream, profile_result, scratch, a_ptr, b_ptr, bias_ptr,
                       c_ptr_rz);
    };

    TF_ASSIGN_OR_RETURN(
        auto results,
        AutotuneMatMulImpl(ctx, runners, cudnn_use_autotune, launch_func,
                           scratch_size_limit, rz_allocator));
    // Only log on an autotune cache miss.
    LogFusedMatmulAutotuneResults(element_type, element_type, a_ptr, b_ptr,
                                  c_ptr, bias_ptr, trans_a, trans_b, m, n, k,
                                  lda, ldb, ldc, activation_mode,
                                  stream->parent(), results);

    // Two-level autotuning: Cudnn frontend supports two engine lists:
    // heuristics and fallback. Heuristics engines are normally faster.
    // To reduce autotuning time, we evaluate the fallback engines only when
    // none of the heuristics engines work.
    const bool found_working_engine =
        std::any_of(results.cbegin(), results.cend(),
                    [](const auto& result) { return !result.has_failure(); });

    if (found_working_engine) {
      TF_ASSIGN_OR_RETURN(autotune_entry,
                          BestCudnnConvAlgorithm<se::dnn::FusedMatmulOp>(
                              results, std::move(runners)));
    } else {
      LOG(WARNING)
          << "None of the algorithms provided by cuDNN frontend heuristics "
             "worked; trying fallback algorithms. Matmul: "
          << params.ToString();
      std::vector<std::unique_ptr<const se::dnn::FusedMatmulRunner>>
          fallback_runners;
      TF_RETURN_IF_ERROR(stream->parent()->GetFusedMatmulRunners(
          CudnnUseFrontend(), element_type, element_type, element_type, stream,
          trans_a, trans_b, m, n, k, lda, ldb, ldc, activation_mode,
          /*use_fallback=*/true, &fallback_runners));

      TF_ASSIGN_OR_RETURN(
          auto fallback_results,
          AutotuneMatMulImpl(ctx, fallback_runners, cudnn_use_autotune,
                             launch_func, scratch_size_limit, rz_allocator));

      LogFusedMatmulAutotuneResults(element_type, element_type, a_ptr, b_ptr,
                                    c_ptr, bias_ptr, trans_a, trans_b, m, n, k,
                                    lda, ldb, ldc, activation_mode,
                                    stream->parent(), fallback_results);

      TF_ASSIGN_OR_RETURN(autotune_entry,
                          BestCudnnConvAlgorithm<se::dnn::FusedMatmulOp>(
                              fallback_results, std::move(fallback_runners)));
    }

    autotune_map->Insert(params, autotune_entry);
  }
  return autotune_entry;
}

}  // namespace

template <typename T>
struct LaunchFusedMatMulOp<GPUDevice, T> {
  void operator()(
      OpKernelContext* context, const Tensor& a, const Tensor& b,
      const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
      FusedComputationType fusion, const FusedComputationArgs& fusion_args,
      Tensor* output, bool use_autotune) {
    OP_REQUIRES(
        context, DataTypeToEnum<T>::value != DT_BFLOAT16,
        errors::InvalidArgument("_FusedMatMul doesn't support "
                                "DT_BFLOAT16 data type on GPU devices."));
    auto* stream = context->op_device_context()->stream();
    OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));

    // All fusion patterns supported by GPU are in the form of MatMul + BiasAdd
    // + <other pointwise operations>. Therefore, the bias tensor is required.
    const Tensor& bias = context->input(2);

    OP_REQUIRES(context, bias.dims() == 1,
                errors::InvalidArgument("bias must be 1-dimensional: ",
                                        bias.shape().DebugString()));

    auto a_ptr = AsDeviceMemory(a.template flat<T>().data(),
                                a.template flat<T>().size());
    auto b_ptr = AsDeviceMemory(b.template flat<T>().data(),
                                b.template flat<T>().size());
    auto bias_ptr = AsDeviceMemory(bias.template flat<T>().data(),
                                   bias.template flat<T>().size());
    auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
                                output->template flat<T>().size());

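    // dim_pair[0] names the contraction axes: dimension `first` of `a`
    // contracts with dimension `second` of `b`. For a 2-D MatMul this means
    // `a` is transposed iff its contraction axis is 0, and `b` is transposed
    // iff its contraction axis is 1.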
    bool trans_a = dim_pair[0].first == 0;
    bool trans_b = dim_pair[0].second == 1;

    const int64_t m = a.dim_size(trans_a ? 1 : 0);
    const int64_t k = a.dim_size(trans_a ? 0 : 1);
    const int64_t n = b.dim_size(trans_b ? 0 : 1);

    bool use_cudnn = false;
    se::dnn::ActivationMode matmul_activation_mode;
    switch (fusion) {
      case FusedComputationType::kBiasAddWithGeluExact:
        matmul_activation_mode = se::dnn::ActivationMode::kGeluExact;
        use_cudnn = true;
        break;
      case FusedComputationType::kBiasAddWithTanh:
        matmul_activation_mode = se::dnn::ActivationMode::kTanh;
        use_cudnn = true;
        break;
      case FusedComputationType::kBiasAddWithSigmoid:
        matmul_activation_mode = se::dnn::ActivationMode::kSigmoid;
        use_cudnn = true;
        break;
      default:
        use_cudnn = false;
    }

    BlasScratchAllocator scratch_allocator(context);

    // The GeluExact, Tanh and Sigmoid fusions are executed through cuDNN;
    // everything else goes through cuBLASLt below.
    if (use_cudnn) {
      int device_id = stream->parent()->device_ordinal();
      DataType ab_dtype = a.dtype();
      DataType c_dtype = output->dtype();
      MatmulParameters cudnn_matmul_params = {/*ab_type=*/ab_dtype,
                                              /*c_type=*/c_dtype,
                                              trans_a,
                                              trans_b,
                                              static_cast<uint64_t>(m),
                                              static_cast<uint64_t>(n),
                                              static_cast<uint64_t>(k),
                                              a.dim_size(1),
                                              b.dim_size(1),
                                              output->dim_size(1),
                                              matmul_activation_mode,
                                              device_id};

      auto entry_or = AutotuneFusedMatmul<T>(
          use_autotune, FusedMatmulAutotuneMap::GetInstance(),
          cudnn_matmul_params, context, trans_a, trans_b, m, n, k,
          a.dim_size(1), b.dim_size(1), output->dim_size(1),
          matmul_activation_mode, a_ptr, b_ptr, c_ptr, bias_ptr,
          GetDnnWorkspaceLimitOrDefault());
      OP_REQUIRES_OK(context, entry_or.status());
      auto autotune_entry = std::move(entry_or).value();

      auto& runners = autotune_entry.GetOpRunners();
      se::dnn::FusedMatmulOp::Config config;
      auto primary_or = runners.primary->GetOrCreateRunner(config, stream);
      OP_REQUIRES_OK(context, primary_or.status());
      auto* primary = primary_or.value();

      const se::dnn::FusedMatmulRunner* no_scratch_fallback = nullptr;
      if (runners.no_scratch_fallback) {
        auto no_scratch_fallback_or =
            runners.no_scratch_fallback->GetOrCreateRunner(config, stream);
        OP_REQUIRES_OK(context, no_scratch_fallback_or.status());
        no_scratch_fallback = no_scratch_fallback_or.value();
      }

      auto runner_and_scratch_or =
          AllocateScratchOrFallback<se::dnn::FusedMatmulOp::Signature>(
              &scratch_allocator, primary, no_scratch_fallback);
      OP_REQUIRES_OK(context, runner_and_scratch_or.status());
      auto runner_and_scratch = std::move(runner_and_scratch_or).value();
      auto& runner =
          *std::get<const se::dnn::FusedMatmulRunner*>(runner_and_scratch);
      Status cudnn_launch_status = runner(
          stream, nullptr, std::get<se::DeviceMemoryBase>(runner_and_scratch),
          a_ptr, b_ptr, bias_ptr, c_ptr);
      OP_REQUIRES_OK(context, cudnn_launch_status);
      return;
    }

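    // Non-cuDNN fusions (BiasAdd, BiasAdd+Relu, BiasAdd+GeluApproximate) are
    // executed with a cuBLASLt matmul plan whose epilogue applies the bias and
    // activation.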
    auto epilog_op_or = GetBlasLtEpilogOp(fusion);
    OP_REQUIRES_OK(context, epilog_op_or.status());
    se::cuda::BlasLt::Epilogue epilog_op = epilog_op_or.value();

    se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
                                   se::blas::Transpose::kTranspose};

    BlasLtMatmulPlanParams matmul_params{se::blas::ToDataType<T>::value,
                                         static_cast<size_t>(m),
                                         static_cast<size_t>(n),
                                         static_cast<size_t>(k),
                                         trans[trans_a ? 1 : 0],
                                         trans[trans_b ? 1 : 0],
                                         /*batch_size=*/1,
                                         /*broadcast_a=*/false,
                                         /*broadcast_b=*/false,
                                         epilog_op};

    auto plan_and_algorithms_or = GetPlanAndAlgorithms(stream, matmul_params);
    OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
    const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value();
    const auto& plan = plan_and_algorithms->plan;
    const auto& algorithms = plan_and_algorithms->algorithms;
    OP_REQUIRES(context, algorithms.size() > 0,
                errors::InvalidArgument("No matmul algorithm returned!"));

    auto launch_func = [&](BlasScratchAllocator& scratch_allocator,
                           const se::cuda::BlasLt::MatmulAlgorithm& algorithm,
                           se::blas::ProfileResult* profile_result) {
      return DoBlasLtMatmul(stream, plan, a_ptr, b_ptr, c_ptr, algorithm,
                            scratch_allocator, bias_ptr, profile_result);
    };

    se::cuda::BlasLt::MatmulAlgorithm algorithm = algorithms[0];
    if (use_autotune) {
      se::blas::AlgorithmConfig algorithm_config =
          AutotuneMatmul(algorithms, matmul_params, context, launch_func);

      se::blas::AlgorithmType algorithm_idx = algorithm_config.algorithm();
      algorithm = algorithms[algorithm_idx];
    }

    OP_REQUIRES_OK(context, launch_func(scratch_allocator, algorithm, nullptr));
  }
};

#endif  // GOOGLE_CUDA

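// OpKernel for the _FusedMatMul op. The constructor validates the requested
// fusion against the patterns supported on the target device; Compute()
// validates shapes, allocates the output and dispatches to
// LaunchFusedMatMulOp<Device, T>.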
template <typename Device, typename T>
class FusedMatMulOp : public OpKernel {
 public:
  explicit FusedMatMulOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("transpose_a", &transpose_a_));
    OP_REQUIRES_OK(context, context->GetAttr("transpose_b", &transpose_b_));

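    // Each pattern maps the node's "fused_ops" attribute (a sequence of op
    // names) to the FusedComputationType understood by the launcher; the set
    // of supported patterns differs between CPU and GPU.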
    std::vector<FusedComputationPattern> patterns;

    using FCT = FusedComputationType;
    if (std::is_same<Device, CPUDevice>::value) {
      patterns = {
          {FCT::kBiasAdd, {"BiasAdd"}},
          {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
          {FCT::kBiasAddWithRelu6, {"BiasAdd", "Relu6"}},
          {FCT::kBiasAddWithElu, {"BiasAdd", "Elu"}},
          {FCT::kBiasAddWithLeakyRelu, {"BiasAdd", "LeakyRelu"}},
      };
    } else if (std::is_same<Device, GPUDevice>::value) {
      patterns = {
          {FCT::kBiasAdd, {"BiasAdd"}},
          {FCT::kBiasAddWithRelu, {"BiasAdd", "Relu"}},
          {FCT::kBiasAddWithTanh, {"BiasAdd", "Tanh"}},
          {FCT::kBiasAddWithSigmoid, {"BiasAdd", "Sigmoid"}},
          {FCT::kBiasAddWithGeluApproximate, {"BiasAdd", "GeluApproximate"}},
          {FCT::kBiasAddWithGeluExact, {"BiasAdd", "GeluExact"}}};
    }

    OP_REQUIRES_OK(context, InitializeFusedComputation(
                                context, "MatMul", patterns,
                                &fused_computation_, &fused_computation_args_));
    if (std::is_same<Device, GPUDevice>::value &&
        (fused_computation_ == FCT::kBiasAddWithGeluExact ||
         fused_computation_ == FCT::kBiasAddWithTanh ||
         fused_computation_ == FCT::kBiasAddWithSigmoid)) {
      OP_REQUIRES(context, DataTypeToEnum<T>::value == DT_HALF,
                  errors::InvalidArgument(
                      "Matmul with BiasAdd+GeluExact|Tanh|Sigmoid supports "
                      "only DT_HALF data type."));
    }
    use_autotune_ = MatmulAutotuneEnable();
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& a = ctx->input(0);
    const Tensor& b = ctx->input(1);

    // Check that the dimensions of the two matrices are valid.
    OP_REQUIRES(ctx, a.dims() == b.dims(),
                errors::InvalidArgument("In[0] and In[1] have different ndims: ",
                                        a.shape().DebugString(), " vs. ",
                                        b.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(a.shape()),
        errors::InvalidArgument("In[0] is not a matrix. Instead it has shape ",
                                a.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsMatrix(b.shape()),
        errors::InvalidArgument("In[1] is not a matrix. Instead it has shape ",
                                b.shape().DebugString()));
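    // Contract over the "inner" dimension of each operand: dimension 0 of a
    // when transpose_a, otherwise dimension 1, and dimension 1 of b when
    // transpose_b, otherwise dimension 0.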
    Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair;
    dim_pair[0].first = transpose_a_ ? 0 : 1;
    dim_pair[0].second = transpose_b_ ? 1 : 0;

    OP_REQUIRES(
        ctx, a.dim_size(dim_pair[0].first) == b.dim_size(dim_pair[0].second),
        errors::InvalidArgument(
            "Matrix size-incompatible: In[0]: ", a.shape().DebugString(),
            ", In[1]: ", b.shape().DebugString()));
    int a_dim_remaining = 1 - dim_pair[0].first;
    int b_dim_remaining = 1 - dim_pair[0].second;
    TensorShape out_shape(
        {a.dim_size(a_dim_remaining), b.dim_size(b_dim_remaining)});
    Tensor* out = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));

    if (out->NumElements() == 0) {
      // If a has shape [0, x] or b has shape [x, 0], the output shape
      // is a 0-element matrix, so there is nothing to do.
      return;
    }

    if (a.NumElements() == 0 && b.NumElements() == 0) {
      // If a has shape [x, 0] and b has shape [0, y], the
      // output shape is [x, y] where x and y are non-zero, so we fill
      // the output with zeros.
      functor::SetZeroFunctor<Device, T> f;
      f(ctx->eigen_device<Device>(), out->flat<T>());
      return;
    }

    auto launch = LaunchFusedMatMulOp<Device, T>();
    launch(ctx, a, b, dim_pair, fused_computation_, fused_computation_args_,
           out, use_autotune_);
  }

 private:
  bool transpose_a_;
  bool transpose_b_;
  bool use_autotune_;

  FusedComputationType fused_computation_ = FusedComputationType::kUndefined;
  FusedComputationArgs fused_computation_args_;

  TF_DISALLOW_COPY_AND_ASSIGN(FusedMatMulOp);
};

// Registration of the CPU implementations.
#define REGISTER_FUSED_CPU_MATMUL(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedMatMul").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
      FusedMatMulOp<CPUDevice, T>);

TF_CALL_float(REGISTER_FUSED_CPU_MATMUL);

#undef REGISTER_FUSED_CPU_MATMUL

#if GOOGLE_CUDA

// Registration of the GPU implementations.
#define REGISTER_FUSED_GPU_MATMUL(T)                                  \
  REGISTER_KERNEL_BUILDER(                                            \
      Name("_FusedMatMul").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
      FusedMatMulOp<GPUDevice, T>);

TF_CALL_float(REGISTER_FUSED_GPU_MATMUL);
TF_CALL_half(REGISTER_FUSED_GPU_MATMUL);

#undef REGISTER_FUSED_GPU_MATMUL

#endif  // GOOGLE_CUDA

}  // namespace tensorflow
#endif  // TENSORFLOW_CORE_KERNELS_MATMUL_OP_FUSED_H_