1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15// Classes and utilities that work with StreamExecutor C API for internal use.
16// This includes functions used for device registration and interfaces needed
17// for testing.
18#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
19#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
20
21#include "tensorflow/c/experimental/stream_executor/stream_executor.h"
22#include "tensorflow/c/tf_status_helper.h"
23#include "tensorflow/compiler/xla/stream_executor/executor_cache.h"
24#include "tensorflow/compiler/xla/stream_executor/lib/status.h"
25#include "tensorflow/compiler/xla/stream_executor/platform.h"
26
27namespace stream_executor {
28
29// Plugin initialization function that a device plugin
30// must define.
31typedef void (*SEInitPluginFn)(SE_PlatformRegistrationParams* const,
32 TF_Status* const);
33
34// Registers StreamExecutor platform. `device_type` and `platform_name` are
35// output parameters.
36port::Status InitStreamExecutorPlugin(void* dso_handle,
37 std::string* device_type,
38 std::string* platform_name);
39
40// Allow registering a StreamExecutor plugin using a function (used for
41// testing).
42port::Status InitStreamExecutorPlugin(SEInitPluginFn init_fn,
43 std::string* device_type,
44 std::string* platform_name);
45
46// This file implements core stream executor base classes in terms of
47// the C API defined in stream_executor.h. A class "CSomething" represents a
48// "Something" that can be manipulated via calls in the C interface.
49class CPlatform : public Platform {
50 public:
51 explicit CPlatform(SP_Platform platform,
52 void (*destroy_platform)(SP_Platform*),
53 SP_PlatformFns platform_fns,
54 void (*destroy_platform_fns)(SP_PlatformFns*),
55 SP_DeviceFns device_fns, SP_StreamExecutor stream_executor,
56 SP_TimerFns timer_fns);
57 ~CPlatform() override;
58
59 Id id() const override { return const_cast<int*>(&plugin_id_value_); }
60 const std::string& Name() const override { return name_; }
61 int VisibleDeviceCount() const override {
62 int visible_device_count = 0;
63 tensorflow::TF_StatusPtr c_status(TF_NewStatus());
64 platform_fns_.get_device_count(&platform_, &visible_device_count,
65 c_status.get());
66 if (TF_GetCode(c_status.get()) != TF_OK) {
67 LOG(ERROR) << TF_Message(c_status.get());
68 return 0;
69 }
70 return visible_device_count;
71 }
72 bool UseBfcAllocator() const { return platform_.use_bfc_allocator; }
73 bool ForceMemoryGrowth() const { return platform_.force_memory_growth; }
74 port::StatusOr<std::unique_ptr<DeviceDescription>> DescriptionForDevice(
75 int ordinal) const override;
76 port::StatusOr<StreamExecutor*> ExecutorForDevice(int ordinal) override;
77 port::StatusOr<StreamExecutor*> ExecutorForDeviceWithPluginConfig(
78 int ordinal, const PluginConfig& plugin_config) override;
79 port::StatusOr<StreamExecutor*> GetExecutor(
80 const StreamExecutorConfig& config) override;
81 port::StatusOr<std::unique_ptr<StreamExecutor>> GetUncachedExecutor(
82 const StreamExecutorConfig& config) override;
83
84 // Trace listener is not supported
85 void RegisterTraceListener(std::unique_ptr<TraceListener> listener) override {
86 LOG(FATAL) << "RegisterTraceListener is not supported by pluggable device";
87 }
88 void UnregisterTraceListener(TraceListener* listener) override {}
89
90 void DestroyAllExecutors() { executor_cache_.DestroyAllExecutors(); }
91
92 private:
93 SP_Platform platform_;
94 void (*destroy_platform_)(SP_Platform*);
95 SP_PlatformFns platform_fns_;
96 void (*destroy_platform_fns_)(SP_PlatformFns*);
97 SP_DeviceFns device_fns_;
98 SP_StreamExecutor stream_executor_;
99 SP_TimerFns timer_fns_;
100 const std::string name_;
101 int plugin_id_value_;
102 stream_executor::ExecutorCache executor_cache_;
103};
104
105class CStream : public internal::StreamInterface {
106 public:
107 CStream(SP_Device* device, SP_StreamExecutor* stream_executor)
108 : device_(device),
109 stream_executor_(stream_executor),
110 stream_handle_(nullptr) {}
111 ~CStream() override { Destroy(); }
112
113 port::Status Create() {
114 tensorflow::TF_StatusPtr c_status(TF_NewStatus());
115 stream_executor_->create_stream(device_, &stream_handle_, c_status.get());
116 port::Status s = tensorflow::StatusFromTF_Status(c_status.get());
117 return s;
118 }
119
120 void Destroy() {
121 if (stream_handle_ != nullptr) {
122 stream_executor_->destroy_stream(device_, stream_handle_);
123 stream_handle_ = nullptr;
124 }
125 }
126
127 SP_Stream Handle() { return stream_handle_; }
128
129 private:
130 SP_Device* device_;
131 SP_StreamExecutor* stream_executor_;
132 SP_Stream stream_handle_;
133};
134
135class CEvent : public internal::EventInterface {
136 public:
137 CEvent(SP_Device* device, SP_StreamExecutor* stream_executor)
138 : device_(device),
139 stream_executor_(stream_executor),
140 event_handle_(nullptr) {}
141 ~CEvent() override { Destroy(); }
142
143 port::Status Create() {
144 tensorflow::TF_StatusPtr c_status(TF_NewStatus());
145 stream_executor_->create_event(device_, &event_handle_, c_status.get());
146 return tensorflow::StatusFromTF_Status(c_status.get());
147 }
148
149 port::Status Record(SP_Stream stream_handle) {
150 tensorflow::TF_StatusPtr c_status(TF_NewStatus());
151 stream_executor_->record_event(device_, stream_handle, event_handle_,
152 c_status.get());
153 return tensorflow::StatusFromTF_Status(c_status.get());
154 }
155
156 void Destroy() {
157 if (event_handle_ != nullptr) {
158 stream_executor_->destroy_event(device_, event_handle_);
159 event_handle_ = nullptr;
160 }
161 }
162
163 SP_Event Handle() { return event_handle_; }
164
165 private:
166 SP_Device* device_;
167 SP_StreamExecutor* stream_executor_;
168 SP_Event event_handle_;
169};
170
171class CTimer : public internal::TimerInterface {
172 public:
173 CTimer(SP_Device* device, SP_StreamExecutor* stream_executor,
174 SP_TimerFns* timer_fns)
175 : device_(device),
176 stream_executor_(stream_executor),
177 timer_handle_(nullptr),
178 timer_fns_(timer_fns) {}
179 ~CTimer() override { Destroy(); }
180
181 port::Status Create() {
182 tensorflow::TF_StatusPtr c_status(TF_NewStatus());
183 stream_executor_->create_timer(device_, &timer_handle_, c_status.get());
184 return tensorflow::StatusFromTF_Status(c_status.get());
185 }
186
187 void Destroy() {
188 if (timer_handle_ != nullptr) {
189 stream_executor_->destroy_timer(device_, timer_handle_);
190 timer_handle_ = nullptr;
191 }
192 }
193
194 SP_Timer Handle() { return timer_handle_; }
195
196 uint64 Microseconds() const override {
197 return timer_fns_->nanoseconds(timer_handle_) / 1000;
198 }
199
200 uint64 Nanoseconds() const override {
201 return timer_fns_->nanoseconds(timer_handle_);
202 }
203
204 private:
205 SP_Device* device_;
206 SP_StreamExecutor* stream_executor_;
207 SP_Timer timer_handle_;
208 SP_TimerFns* timer_fns_;
209};
210
211} // namespace stream_executor
212#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_INTERNAL_H_
213