1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | // Classes and utilities that work with StreamExecutor C API for internal use. |
16 | // This includes functions used for device registration and interfaces needed |
17 | // for testing. |
18 | #ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ |
19 | #define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ |
20 | |
21 | #include "tensorflow/c/experimental/stream_executor/stream_executor.h" |
22 | #include "tensorflow/c/tf_status_helper.h" |
23 | #include "tensorflow/compiler/xla/stream_executor/executor_cache.h" |
24 | #include "tensorflow/compiler/xla/stream_executor/lib/status.h" |
25 | #include "tensorflow/compiler/xla/stream_executor/platform.h" |
26 | |
27 | namespace stream_executor { |
28 | |
29 | // Plugin initialization function that a device plugin |
30 | // must define. |
31 | typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const, |
32 | TF_Status* const); |
33 | |
34 | // Registers StreamExecutor platform. `device_type` and `platform_name` are |
35 | // output parameters. |
36 | port::Status InitStreamExecutorPlugin(void* dso_handle, |
37 | std::string* device_type, |
38 | std::string* platform_name); |
39 | |
40 | // Allow registering a StreamExecutor plugin using a function (used for |
41 | // testing). |
42 | port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn, |
43 | std::string* device_type, |
44 | std::string* platform_name); |
45 | |
// The classes below implement core StreamExecutor interfaces in terms of the
// C API defined in stream_executor.h. A class named "CSomething" represents a
// "Something" that is manipulated via calls through the C interface.
49 | class CPlatform : public Platform { |
50 | public: |
51 | explicit CPlatform(SP_Platform platform, |
52 | void (*destroy_platform)(SP_Platform*), |
53 | SP_PlatformFns platform_fns, |
54 | void (*destroy_platform_fns)(SP_PlatformFns*), |
55 | SP_DeviceFns device_fns, SP_StreamExecutor stream_executor, |
56 | SP_TimerFns timer_fns); |
57 | ~CPlatform() override; |
58 | |
59 | Id id() const override { return const_cast<int*>(&plugin_id_value_); } |
60 | const std::string& Name() const override { return name_; } |
61 | int VisibleDeviceCount() const override { |
62 | int visible_device_count = 0; |
63 | tensorflow::TF_StatusPtr c_status(TF_NewStatus()); |
64 | platform_fns_.get_device_count(&platform_, &visible_device_count, |
65 | c_status.get()); |
66 | if (TF_GetCode(c_status.get()) != TF_OK) { |
67 | LOG(ERROR) << TF_Message(c_status.get()); |
68 | return 0; |
69 | } |
70 | return visible_device_count; |
71 | } |
72 | bool UseBfcAllocator() const { return platform_.use_bfc_allocator; } |
73 | bool ForceMemoryGrowth() const { return platform_.force_memory_growth; } |
74 | port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice( |
75 | int ordinal) const override; |
76 | port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override; |
77 | port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig( |
78 | int ordinal, const PluginConfig& plugin_config) override; |
79 | port::StatusOr<StreamExecutor*> GetExecutor( |
80 | const StreamExecutorConfig& config) override; |
81 | port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor( |
82 | const StreamExecutorConfig& config) override; |
83 | |
84 | // Trace listener is not supported |
85 | void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override { |
86 | LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device" ; |
87 | } |
88 | void UnregisterTraceListener(TraceListener* listener) override {} |
89 | |
90 | void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); } |
91 | |
92 | private: |
93 | SP_Platform platform_; |
94 | void (*destroy_platform_)(SP_Platform*); |
95 | SP_PlatformFns platform_fns_; |
96 | void (*destroy_platform_fns_)(SP_PlatformFns*); |
97 | SP_DeviceFns device_fns_; |
98 | SP_StreamExecutor stream_executor_; |
99 | SP_TimerFns timer_fns_; |
100 | const std::string name_; |
101 | int plugin_id_value_; |
102 | stream_executor::ExecutorCache executor_cache_; |
103 | }; |
104 | |
105 | class CStream : public internal::StreamInterface { |
106 | public: |
107 | CStream(SP_Device* device, SP_StreamExecutor* stream_executor) |
108 | : device_(device), |
109 | stream_executor_(stream_executor), |
110 | stream_handle_(nullptr) {} |
111 | ~CStream() override { Destroy(); } |
112 | |
113 | port::Status Create() { |
114 | tensorflow::TF_StatusPtr c_status(TF_NewStatus()); |
115 | stream_executor_->create_stream(device_, &stream_handle_, c_status.get()); |
116 | port::Status s = tensorflow::StatusFromTF_Status(c_status.get()); |
117 | return s; |
118 | } |
119 | |
120 | void Destroy() { |
121 | if (stream_handle_ != nullptr) { |
122 | stream_executor_->destroy_stream(device_, stream_handle_); |
123 | stream_handle_ = nullptr; |
124 | } |
125 | } |
126 | |
127 | SP_Stream Handle() { return stream_handle_; } |
128 | |
129 | private: |
130 | SP_Device* device_; |
131 | SP_StreamExecutor* stream_executor_; |
132 | SP_Stream stream_handle_; |
133 | }; |
134 | |
135 | class CEvent : public internal::EventInterface { |
136 | public: |
137 | CEvent(SP_Device* device, SP_StreamExecutor* stream_executor) |
138 | : device_(device), |
139 | stream_executor_(stream_executor), |
140 | event_handle_(nullptr) {} |
141 | ~CEvent() override { Destroy(); } |
142 | |
143 | port::Status Create() { |
144 | tensorflow::TF_StatusPtr c_status(TF_NewStatus()); |
145 | stream_executor_->create_event(device_, &event_handle_, c_status.get()); |
146 | return tensorflow::StatusFromTF_Status(c_status.get()); |
147 | } |
148 | |
149 | port::Status Record(SP_Stream stream_handle) { |
150 | tensorflow::TF_StatusPtr c_status(TF_NewStatus()); |
151 | stream_executor_->record_event(device_, stream_handle, event_handle_, |
152 | c_status.get()); |
153 | return tensorflow::StatusFromTF_Status(c_status.get()); |
154 | } |
155 | |
156 | void Destroy() { |
157 | if (event_handle_ != nullptr) { |
158 | stream_executor_->destroy_event(device_, event_handle_); |
159 | event_handle_ = nullptr; |
160 | } |
161 | } |
162 | |
163 | SP_Event Handle() { return event_handle_; } |
164 | |
165 | private: |
166 | SP_Device* device_; |
167 | SP_StreamExecutor* stream_executor_; |
168 | SP_Event event_handle_; |
169 | }; |
170 | |
171 | class CTimer : public internal::TimerInterface { |
172 | public: |
173 | CTimer(SP_Device* device, SP_StreamExecutor* stream_executor, |
174 | SP_TimerFns* timer_fns) |
175 | : device_(device), |
176 | stream_executor_(stream_executor), |
177 | timer_handle_(nullptr), |
178 | timer_fns_(timer_fns) {} |
179 | ~CTimer() override { Destroy(); } |
180 | |
181 | port::Status Create() { |
182 | tensorflow::TF_StatusPtr c_status(TF_NewStatus()); |
183 | stream_executor_->create_timer(device_, &timer_handle_, c_status.get()); |
184 | return tensorflow::StatusFromTF_Status(c_status.get()); |
185 | } |
186 | |
187 | void Destroy() { |
188 | if (timer_handle_ != nullptr) { |
189 | stream_executor_->destroy_timer(device_, timer_handle_); |
190 | timer_handle_ = nullptr; |
191 | } |
192 | } |
193 | |
194 | SP_Timer Handle() { return timer_handle_; } |
195 | |
196 | uint64 Microseconds() const override { |
197 | return timer_fns_->nanoseconds(timer_handle_) / 1000; |
198 | } |
199 | |
200 | uint64 Nanoseconds() const override { |
201 | return timer_fns_->nanoseconds(timer_handle_); |
202 | } |
203 | |
204 | private: |
205 | SP_Device* device_; |
206 | SP_StreamExecutor* stream_executor_; |
207 | SP_Timer timer_handle_; |
208 | SP_TimerFns* timer_fns_; |
209 | }; |
210 | |
211 | } // namespace stream_executor |
212 | #endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_ |
213 | |