1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ------------------------------------------------------------------------------*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_KERNELS_AUTOTUNE_CONV_IMPL_H_ |
17 | #define TENSORFLOW_CORE_KERNELS_AUTOTUNE_CONV_IMPL_H_ |
18 | |
19 | #if GOOGLE_CUDA |
20 | #define EIGEN_USE_THREADS |
21 | |
22 | #include "tensorflow/core/framework/op_kernel.h" |
23 | #include "tensorflow/core/kernels/conv_ops_gpu.h" |
24 | #include "tensorflow/core/util/proto/proto_utils.h" |
25 | |
26 | #include "tensorflow/compiler/xla/stream_executor/gpu/redzone_allocator.h" |
27 | #include "tensorflow/compiler/xla/stream_executor/tf_allocator_adapter.h" |
28 | |
29 | namespace tensorflow::internal { |
30 | |
31 | template <typename LaunchFunc, typename Sig> |
32 | StatusOr<std::vector<tensorflow::AutotuneResult>> AutotuneConvImpl( |
33 | OpKernelContext* ctx, |
34 | std::vector<std::unique_ptr<const se::dnn::OpRunner<Sig>>>& runners, |
35 | bool actually_do_autotune, const LaunchFunc& launch_func, |
36 | size_t scratch_size_limit, const se::RedzoneAllocator& rz_allocator) { |
37 | auto* stream = ctx->op_device_context()->stream(); |
38 | |
39 | se::TfAllocatorAdapter tf_allocator_adapter(ctx->device()->GetAllocator({}), |
40 | stream); |
41 | |
42 | std::vector<tensorflow::AutotuneResult> results; |
43 | // TODO(reedwm): Warn if determinism is enabled after autotune is run |
44 | for (auto& runner : runners) { |
45 | // TODO(zhengxq): profile each algorithm multiple times to better |
46 | // accuracy. |
47 | se::RedzoneAllocator rz_scratch_allocator( |
48 | stream, &tf_allocator_adapter, se::GpuAsmOpts(), |
49 | /*memory_limit=*/scratch_size_limit); |
50 | DnnScratchAllocator scratch_allocator(scratch_size_limit, ctx); |
51 | se::ScratchAllocator* allocator_used = |
52 | !RedzoneCheckDisabled() |
53 | ? static_cast<se::ScratchAllocator*>(&rz_scratch_allocator) |
54 | : static_cast<se::ScratchAllocator*>(&scratch_allocator); |
55 | |
56 | TF_ASSIGN_OR_RETURN(auto desc, runner->ToAlgorithmDesc()); |
57 | se::dnn::ProfileResult profile_result; |
58 | Status cudnn_launch_status = |
59 | actually_do_autotune |
60 | ? launch_func(allocator_used, runner, &profile_result) |
61 | : OkStatus(); |
62 | if (!actually_do_autotune) { |
63 | // Make the result valid according to `is_valid`. |
64 | profile_result.set_algorithm(desc); |
65 | profile_result.set_elapsed_time_in_ms(0); |
66 | } |
67 | |
68 | // We need to make sure the profiling results are one-to-one with the |
69 | // "runners". So, we insert dummy results when the execution fails. |
70 | results.emplace_back(); |
71 | auto& result = results.back(); |
72 | *result.mutable_algorithm() = desc.ToProto(); |
73 | if (cudnn_launch_status.ok() && profile_result.is_valid()) { |
74 | result.set_scratch_bytes( |
75 | !RedzoneCheckDisabled() |
76 | ? rz_scratch_allocator.TotalAllocatedBytesExcludingRedzones() |
77 | : scratch_allocator.TotalByteSize()); |
78 | *result.mutable_run_time() = proto_utils::ToDurationProto( |
79 | absl::Milliseconds(profile_result.elapsed_time_in_ms())); |
80 | |
81 | CheckRedzones(rz_scratch_allocator, &result); |
82 | CheckRedzones(rz_allocator, &result); |
83 | } else { |
84 | result.mutable_failure()->set_kind(AutotuneResult::UNKNOWN); |
85 | result.mutable_failure()->set_msg( |
86 | absl::StrCat("Profiling failure on CUDNN engine " , desc.ToString(), |
87 | ": " , cudnn_launch_status.ToString())); |
88 | } |
89 | } |
90 | |
91 | return results; |
92 | } |
93 | |
94 | } // namespace tensorflow::internal |
95 | |
96 | #endif // GOOGLE_CUDA |
97 | |
98 | #endif // TENSORFLOW_CORE_KERNELS_AUTOTUNE_CONV_IMPL_H_ |
99 | |