1/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
#include <memory>
#include <vector>

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/c_api_macros_internal.h"
#include "tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device/device_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/profiler_factory.h"
#include "tensorflow/core/profiler/lib/profiler_interface.h"
#include "tensorflow/core/profiler/profiler_options.pb.h"
29
30namespace tensorflow {
31namespace profiler {
32
33namespace {
34
35Status ValidateTPProfilerRegistrationParams(
36 const TF_ProfilerRegistrationParams& params) {
37 TF_VALIDATE_STRUCT_SIZE(TF_ProfilerRegistrationParams, params,
38 TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE);
39 TF_VALIDATE_NOT_NULL(TF_ProfilerRegistrationParams, params, destroy_profiler);
40 TF_VALIDATE_NOT_NULL(TF_ProfilerRegistrationParams, params,
41 destroy_profiler_fns);
42 return OkStatus();
43}
44
45Status ValidateTPProfiler(const TP_Profiler& profiler) {
46 TF_VALIDATE_STRUCT_SIZE(TP_Profiler, profiler, TP_PROFILER_STRUCT_SIZE);
47 TF_VALIDATE_NOT_NULL(TP_Profiler, profiler, device_type);
48 TF_RETURN_IF_ERROR(
49 tensorflow::device_utils::ValidateDeviceType(profiler.device_type));
50 return OkStatus();
51}
52
53Status ValidateTPProfilerFns(const TP_ProfilerFns& profiler_fns) {
54 TF_VALIDATE_STRUCT_SIZE(TP_ProfilerFns, profiler_fns,
55 TF_PROFILER_FNS_STRUCT_SIZE);
56 TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, start);
57 TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, stop);
58 TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, collect_data_xspace);
59 return OkStatus();
60}
61
62class PluggableProfiler : public tensorflow::profiler::ProfilerInterface {
63 public:
64 // The caller must have validated profiler_fns and profiler.
65 static std::unique_ptr<tensorflow::profiler::ProfilerInterface>
66 CreatePluggableProfiler(const ProfileOptions& options, TP_Profiler profiler,
67 TP_ProfilerFns profiler_fns) {
68 if (options.device_tracer_level() == 0) {
69 return nullptr;
70 }
71 if (options.device_type() != ProfileOptions::PLUGGABLE_DEVICE &&
72 options.device_type() != ProfileOptions::UNSPECIFIED) {
73 return nullptr;
74 }
75 return absl::WrapUnique(new PluggableProfiler(profiler_fns, profiler));
76 }
77
78 Status Start() override {
79 tensorflow::TF_StatusPtr status(TF_NewStatus());
80 profiler_fns_.start(&profiler_, status.get());
81 return tensorflow::StatusFromTF_Status(status.get());
82 }
83
84 Status Stop() override {
85 tensorflow::TF_StatusPtr status(TF_NewStatus());
86 profiler_fns_.stop(&profiler_, status.get());
87 return tensorflow::StatusFromTF_Status(status.get());
88 }
89
90 Status CollectData(XSpace* space) override {
91 tensorflow::TF_StatusPtr status(TF_NewStatus());
92 // Get size of buffer required for Plugin to serialize XSpace into it.
93 size_t size_in_bytes;
94 profiler_fns_.collect_data_xspace(&profiler_, /*buffer=*/nullptr,
95 &size_in_bytes, status.get());
96
97 if (size_in_bytes <= 0)
98 return tensorflow::StatusFromTF_Status(status.get());
99
100 // Prepare an appropriately sized buffer.
101 std::vector<uint8_t> buffer(size_in_bytes);
102 profiler_fns_.collect_data_xspace(&profiler_, buffer.data(), &size_in_bytes,
103 status.get());
104 // Deserialize XSpace from the buffer and return it.
105 XSpace plugin_space;
106 plugin_space.ParseFromArray(buffer.data(), buffer.size());
107 for (XPlane& plugin_plane : *plugin_space.mutable_planes()) {
108 XPlane* plane = space->add_planes();
109 plane->Swap(&plugin_plane);
110 }
111 return tensorflow::StatusFromTF_Status(status.get());
112 }
113
114 private:
115 PluggableProfiler(TP_ProfilerFns profiler_fns, TP_Profiler profiler)
116 : profiler_fns_(profiler_fns), profiler_(profiler) {}
117 TP_ProfilerFns profiler_fns_;
118 TP_Profiler profiler_;
119};
120
121class PluggableProfilerFactory {
122 public:
123 PluggableProfilerFactory(TP_Profiler profiler,
124 void (*destroy_profiler)(TP_Profiler*),
125 TP_ProfilerFns profiler_fns,
126 void (*destroy_profiler_fns)(TP_ProfilerFns*))
127 : profiler_(std::move(profiler)),
128 destroy_profiler_(destroy_profiler),
129 profiler_fns_(std::move(profiler_fns)),
130 destroy_profiler_fns_(destroy_profiler_fns) {}
131
132 ~PluggableProfilerFactory() {
133 destroy_profiler_(&profiler_);
134 destroy_profiler_fns_(&profiler_fns_);
135 }
136
137 std::unique_ptr<tensorflow::profiler::ProfilerInterface>
138 CreatePluggableProfiler(const ProfileOptions& options) {
139 return PluggableProfiler::CreatePluggableProfiler(options, profiler_,
140 profiler_fns_);
141 }
142
143 private:
144 TP_Profiler profiler_{TP_PROFILER_STRUCT_SIZE};
145 void (*destroy_profiler_)(TP_Profiler*);
146 TP_ProfilerFns profiler_fns_{TP_PROFILER_FNS_STRUCT_SIZE};
147 void (*destroy_profiler_fns_)(TP_ProfilerFns*);
148};
149
150} // namespace
151
152Status InitPluginProfiler(TFInitProfilerFn init_fn) {
153 TF_ProfilerRegistrationParams params{
154 TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE};
155 TP_Profiler profiler{TP_PROFILER_STRUCT_SIZE};
156 TP_ProfilerFns profiler_fns{TP_PROFILER_FNS_STRUCT_SIZE};
157 params.major_version = TP_MAJOR;
158 params.minor_version = TP_MINOR;
159 params.patch_version = TP_PATCH;
160 params.profiler = &profiler;
161 params.profiler_fns = &profiler_fns;
162 tensorflow::TF_StatusPtr status(TF_NewStatus());
163 init_fn(&params, status.get());
164 TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(status.get()));
165 TF_RETURN_IF_ERROR(ValidateTPProfilerRegistrationParams(params));
166 TF_RETURN_IF_ERROR(ValidateTPProfiler(profiler));
167 TF_RETURN_IF_ERROR(ValidateTPProfilerFns(profiler_fns));
168
169 PluggableProfilerFactory factory(std::move(profiler), params.destroy_profiler,
170 std::move(profiler_fns),
171 params.destroy_profiler_fns);
172 std::function<std::unique_ptr<ProfilerInterface>(const ProfileOptions&)>
173 create_func = [factory = std::move(factory)](
174 const ProfileOptions& options) mutable {
175 return factory.CreatePluggableProfiler(options);
176 };
177
178 tensorflow::profiler::RegisterProfilerFactory(std::move(create_func));
179 return OkStatus();
180}
181
182} // namespace profiler
183} // namespace tensorflow
184