1 | /* Copyright 2021 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
#include <functional>
#include <memory>
#include <utility>
#include <vector>

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/c_api_macros_internal.h"
#include "tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device/device_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/profiler_factory.h"
#include "tensorflow/core/profiler/lib/profiler_interface.h"
#include "tensorflow/core/profiler/profiler_options.pb.h"
29 | |
30 | namespace tensorflow { |
31 | namespace profiler { |
32 | |
33 | namespace { |
34 | |
// Validates the TF_ProfilerRegistrationParams after the plugin's init function
// has filled them in.
//
// Checks the `struct_size` field (via TF_VALIDATE_STRUCT_SIZE) and requires
// both destroy callbacks to be non-null: they are later used to release the
// plugin-owned TP_Profiler and TP_ProfilerFns.
//
// Returns OkStatus() on success, otherwise the first failing check's error.
Status ValidateTPProfilerRegistrationParams(
    const TF_ProfilerRegistrationParams& params) {
  TF_VALIDATE_STRUCT_SIZE(TF_ProfilerRegistrationParams, params,
                          TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE);
  TF_VALIDATE_NOT_NULL(TF_ProfilerRegistrationParams, params, destroy_profiler);
  TF_VALIDATE_NOT_NULL(TF_ProfilerRegistrationParams, params,
                       destroy_profiler_fns);
  return OkStatus();
}
44 | |
// Validates the TP_Profiler filled in by the plugin.
//
// Checks `struct_size` (via TF_VALIDATE_STRUCT_SIZE), requires a non-null
// `device_type`, and then delegates to device_utils::ValidateDeviceType to
// check the device type string itself.
//
// Returns OkStatus() on success, otherwise the first failing check's error.
Status ValidateTPProfiler(const TP_Profiler& profiler) {
  TF_VALIDATE_STRUCT_SIZE(TP_Profiler, profiler, TP_PROFILER_STRUCT_SIZE);
  // Must precede ValidateDeviceType, which dereferences device_type.
  TF_VALIDATE_NOT_NULL(TP_Profiler, profiler, device_type);
  TF_RETURN_IF_ERROR(
      tensorflow::device_utils::ValidateDeviceType(profiler.device_type));
  return OkStatus();
}
52 | |
53 | Status ValidateTPProfilerFns(const TP_ProfilerFns& profiler_fns) { |
54 | TF_VALIDATE_STRUCT_SIZE(TP_ProfilerFns, profiler_fns, |
55 | TF_PROFILER_FNS_STRUCT_SIZE); |
56 | TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, start); |
57 | TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, stop); |
58 | TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, collect_data_xspace); |
59 | return OkStatus(); |
60 | } |
61 | |
62 | class PluggableProfiler : public tensorflow::profiler::ProfilerInterface { |
63 | public: |
64 | // The caller must have validated profiler_fns and profiler. |
65 | static std::unique_ptr<tensorflow::profiler::ProfilerInterface> |
66 | CreatePluggableProfiler(const ProfileOptions& options, TP_Profiler profiler, |
67 | TP_ProfilerFns profiler_fns) { |
68 | if (options.device_tracer_level() == 0) { |
69 | return nullptr; |
70 | } |
71 | if (options.device_type() != ProfileOptions::PLUGGABLE_DEVICE && |
72 | options.device_type() != ProfileOptions::UNSPECIFIED) { |
73 | return nullptr; |
74 | } |
75 | return absl::WrapUnique(new PluggableProfiler(profiler_fns, profiler)); |
76 | } |
77 | |
78 | Status Start() override { |
79 | tensorflow::TF_StatusPtr status(TF_NewStatus()); |
80 | profiler_fns_.start(&profiler_, status.get()); |
81 | return tensorflow::StatusFromTF_Status(status.get()); |
82 | } |
83 | |
84 | Status Stop() override { |
85 | tensorflow::TF_StatusPtr status(TF_NewStatus()); |
86 | profiler_fns_.stop(&profiler_, status.get()); |
87 | return tensorflow::StatusFromTF_Status(status.get()); |
88 | } |
89 | |
90 | Status CollectData(XSpace* space) override { |
91 | tensorflow::TF_StatusPtr status(TF_NewStatus()); |
92 | // Get size of buffer required for Plugin to serialize XSpace into it. |
93 | size_t size_in_bytes; |
94 | profiler_fns_.collect_data_xspace(&profiler_, /*buffer=*/nullptr, |
95 | &size_in_bytes, status.get()); |
96 | |
97 | if (size_in_bytes <= 0) |
98 | return tensorflow::StatusFromTF_Status(status.get()); |
99 | |
100 | // Prepare an appropriately sized buffer. |
101 | std::vector<uint8_t> buffer(size_in_bytes); |
102 | profiler_fns_.collect_data_xspace(&profiler_, buffer.data(), &size_in_bytes, |
103 | status.get()); |
104 | // Deserialize XSpace from the buffer and return it. |
105 | XSpace plugin_space; |
106 | plugin_space.ParseFromArray(buffer.data(), buffer.size()); |
107 | for (XPlane& plugin_plane : *plugin_space.mutable_planes()) { |
108 | XPlane* plane = space->add_planes(); |
109 | plane->Swap(&plugin_plane); |
110 | } |
111 | return tensorflow::StatusFromTF_Status(status.get()); |
112 | } |
113 | |
114 | private: |
115 | PluggableProfiler(TP_ProfilerFns profiler_fns, TP_Profiler profiler) |
116 | : profiler_fns_(profiler_fns), profiler_(profiler) {} |
117 | TP_ProfilerFns profiler_fns_; |
118 | TP_Profiler profiler_; |
119 | }; |
120 | |
121 | class PluggableProfilerFactory { |
122 | public: |
123 | PluggableProfilerFactory(TP_Profiler profiler, |
124 | void (*destroy_profiler)(TP_Profiler*), |
125 | TP_ProfilerFns profiler_fns, |
126 | void (*destroy_profiler_fns)(TP_ProfilerFns*)) |
127 | : profiler_(std::move(profiler)), |
128 | destroy_profiler_(destroy_profiler), |
129 | profiler_fns_(std::move(profiler_fns)), |
130 | destroy_profiler_fns_(destroy_profiler_fns) {} |
131 | |
132 | ~PluggableProfilerFactory() { |
133 | destroy_profiler_(&profiler_); |
134 | destroy_profiler_fns_(&profiler_fns_); |
135 | } |
136 | |
137 | std::unique_ptr<tensorflow::profiler::ProfilerInterface> |
138 | CreatePluggableProfiler(const ProfileOptions& options) { |
139 | return PluggableProfiler::CreatePluggableProfiler(options, profiler_, |
140 | profiler_fns_); |
141 | } |
142 | |
143 | private: |
144 | TP_Profiler profiler_{TP_PROFILER_STRUCT_SIZE}; |
145 | void (*destroy_profiler_)(TP_Profiler*); |
146 | TP_ProfilerFns profiler_fns_{TP_PROFILER_FNS_STRUCT_SIZE}; |
147 | void (*destroy_profiler_fns_)(TP_ProfilerFns*); |
148 | }; |
149 | |
150 | } // namespace |
151 | |
152 | Status InitPluginProfiler(TFInitProfilerFn init_fn) { |
153 | TF_ProfilerRegistrationParams params{ |
154 | TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE}; |
155 | TP_Profiler profiler{TP_PROFILER_STRUCT_SIZE}; |
156 | TP_ProfilerFns profiler_fns{TP_PROFILER_FNS_STRUCT_SIZE}; |
157 | params.major_version = TP_MAJOR; |
158 | params.minor_version = TP_MINOR; |
159 | params.patch_version = TP_PATCH; |
160 | params.profiler = &profiler; |
161 | params.profiler_fns = &profiler_fns; |
162 | tensorflow::TF_StatusPtr status(TF_NewStatus()); |
163 | init_fn(¶ms, status.get()); |
164 | TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(status.get())); |
165 | TF_RETURN_IF_ERROR(ValidateTPProfilerRegistrationParams(params)); |
166 | TF_RETURN_IF_ERROR(ValidateTPProfiler(profiler)); |
167 | TF_RETURN_IF_ERROR(ValidateTPProfilerFns(profiler_fns)); |
168 | |
169 | PluggableProfilerFactory factory(std::move(profiler), params.destroy_profiler, |
170 | std::move(profiler_fns), |
171 | params.destroy_profiler_fns); |
172 | std::function<std::unique_ptr<ProfilerInterface>(const ProfileOptions&)> |
173 | create_func = [factory = std::move(factory)]( |
174 | const ProfileOptions& options) mutable { |
175 | return factory.CreatePluggableProfiler(options); |
176 | }; |
177 | |
178 | tensorflow::profiler::RegisterProfilerFactory(std::move(create_func)); |
179 | return OkStatus(); |
180 | } |
181 | |
182 | } // namespace profiler |
183 | } // namespace tensorflow |
184 | |