1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
16#define TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
17#include <stddef.h>
18#include <stdint.h>
19
20#include "tensorflow/c/c_api_macros.h"
21#include "tensorflow/c/tf_status.h"
22
// --------------------------------------------------------------------------
// C API for StreamExecutor. The API is under active development and eventually
// should allow registering a pluggable device with TensorFlow.
//
// Conventions:
//   * Struct prefix indicates whether struct fields should be filled by the
//     plugin or core implementation:
//     * SE_ : set/filled by core unless explicitly marked otherwise.
//     * SP_ : set/filled by plugin unless explicitly marked otherwise.
//   * We use `struct_size` for version checking. It is exempt from the `SE/SP`
//     rule above and should be set both by core and the plugin.
//     * For example, the `create_device` function receives an `SP_Device*` as
//       input with `struct_size` populated by core. The plugin is responsible
//       for setting `struct_size` as well, along with all other fields.
//     * Refer to "TensorFlow Versioning Strategy" section at
//       https://github.com/tensorflow/community/pull/257/files.
//     * Note that the API is still under active development and doesn't have
//       versioning guarantees yet.
//   * `void* ext` is a free-form field that can be populated by a plugin in
//     `SP_*` structs; in `SE_*` structs it is a potential future extension
//     point.
44//
// Example usage:
//
//   /* Sample TensorFlow code below, exact implementation might differ. */
//   // Version checking uses `struct_size`. It is exempt from the `SE/SP` rule
//   // above and should be set both by core and the plugin.
//   SP_Device device { SP_DEVICE_STRUCT_SIZE };
//   SE_CreateDeviceParams params { SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE };
//   params.device = &device;
53//
//   /* Plugin code below */
//   constexpr char DEVICE_NAME[] = "MY_DEVICE";
//   constexpr char DEVICE_TYPE[] = "GPU";
//
//   void create_device(const SP_Platform* platform,
//                      SE_CreateDeviceParams* params, TF_Status* status) {
//     // Custom actions based on TensorFlow's view of SP_Device.
//     OnTFDeviceView(params->device->struct_size);
//     // `params->device` is a pointer: assign through it, not to it.
//     *params->device = { SP_DEVICE_STRUCT_SIZE };
//     params->device->device_handle = get_my_device_handle(params->ordinal);
//     params->device->ordinal = params->ordinal;
//     ...
//   }
//
//   void destroy_device(const SP_Platform* platform, SP_Device* device) {
//     delete_my_device_handle(device->device_handle);
//   }
//
//   void SE_InitPlugin(
//       SE_PlatformRegistrationParams* params,
//       TF_Status* status) {
//     // `params->platform` is a pointer: assign through it, not to it.
//     *params->platform = { SP_PLATFORM_STRUCT_SIZE };
//     // Values such as `name` and `type` must outlive SE_InitPlugin call.
//     params->platform->name = DEVICE_NAME;
//     params->platform->type = DEVICE_TYPE;
//     params->platform_fns->get_device_count = get_device_count;
//     params->platform_fns->create_device = create_device;
//     params->platform_fns->destroy_device = destroy_device;
//     ...
//   }
84
// StreamExecutor C API version, split into major, minor and patch components.
// NOTE(review): presumably these are the values core reports through
// SE_PlatformRegistrationParams::{major,minor,patch}_version -- confirm.
#define SE_MAJOR 0
#define SE_MINOR 0
#define SE_PATCH 1
88
89#ifdef __cplusplus
90extern "C" {
91#endif
92
// Opaque handles for streams, events and timers. The `*_st` structs are not
// defined in this header, so a handle can only be stored and passed around;
// the objects behind them are created and destroyed through the
// `create_*` / `destroy_*` callbacks of SP_StreamExecutor.
typedef struct SP_Stream_st* SP_Stream;
typedef struct SP_Event_st* SP_Event;
typedef struct SP_Timer_st* SP_Timer;
96// Takes `callback_arg` passed to `host_callback` as the first argument.
97typedef void (*SE_StatusCallbackFn)(void* const, TF_Status* const);
98
99typedef struct SP_TimerFns {
100 size_t struct_size;
101 void* ext; // reserved for future use
102 uint64_t (*nanoseconds)(SP_Timer timer);
103} SP_TimerFns;
104
105#define SP_TIMER_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_TimerFns, nanoseconds)
106
// Allocator statistics, filled in by the `get_allocator_stats` callback of
// SP_StreamExecutor. Unlike the other structs in this header it carries no
// `void* ext` field.
// NOTE(review): the per-field descriptions below are inferred from the field
// names only -- confirm against the consumer in core before relying on them.
typedef struct SP_AllocatorStats {
  // Used for version checking; set to SP_ALLOCATORSTATS_STRUCT_SIZE.
  size_t struct_size;
  int64_t num_allocs;          // presumably: number of allocations made
  int64_t bytes_in_use;        // presumably: bytes currently allocated
  int64_t peak_bytes_in_use;   // presumably: high-water mark of bytes_in_use
  int64_t largest_alloc_size;  // presumably: largest single allocation

  // Flag paired with `bytes_limit`; presumably non-zero when it is valid.
  int8_t has_bytes_limit;
  int64_t bytes_limit;

  int64_t bytes_reserved;       // presumably: bytes currently reserved
  int64_t peak_bytes_reserved;  // presumably: high-water mark of the above

  // Flag paired with `bytes_reservable_limit`; presumably non-zero when it is
  // valid.
  int8_t has_bytes_reservable_limit;
  int64_t bytes_reservable_limit;

  int64_t largest_free_block_bytes;  // presumably: largest free block, bytes
} SP_AllocatorStats;

// Expected `struct_size` of SP_AllocatorStats; `largest_free_block_bytes` is
// its last field.
#define SP_ALLOCATORSTATS_STRUCT_SIZE \
  TF_OFFSET_OF_END(SP_AllocatorStats, largest_free_block_bytes)
128
// Potential states for an SP_Event, as returned by `get_event_status`. Any
// value other than SE_EVENT_PENDING or SE_EVENT_COMPLETE means an error has
// occurred; SE_EVENT_UNKNOWN is a bad state.
typedef enum SE_EventStatus {
  SE_EVENT_UNKNOWN,
  SE_EVENT_ERROR,
  SE_EVENT_PENDING,
  SE_EVENT_COMPLETE,
} SE_EventStatus;
137
// Memory allocation information.
// This matches DeviceMemoryBase defined here:
// https://cs.opensource.google/tensorflow/tensorflow/+/refs/tags/v2.3.0:tensorflow/compiler/xla/stream_executor/device_memory.h;l=57
typedef struct SP_DeviceMemoryBase {
  // Used for version checking; set to SP_DEVICE_MEMORY_BASE_STRUCT_SIZE.
  size_t struct_size;
  void* ext;  // Reserved for future use
  // Platform-dependent value representing allocated memory.
  // Note that the pointer does not have to be to the virtual address itself.
  void* opaque;
  uint64_t size;     // Size in bytes of this allocation.
  uint64_t payload;  // Value for plugin's use
} SP_DeviceMemoryBase;

// Expected `struct_size` of SP_DeviceMemoryBase; `payload` is its last field.
#define SP_DEVICE_MEMORY_BASE_STRUCT_SIZE \
  TF_OFFSET_OF_END(SP_DeviceMemoryBase, payload)
153
// Description of a single device. Core populates `struct_size` before calling
// `create_device`; the plugin then fills in the whole struct, `struct_size`
// included.
typedef struct SP_Device {
  // Used for version checking; set to SP_DEVICE_STRUCT_SIZE.
  size_t struct_size;
  void* ext;        // free-form data set by plugin
  int32_t ordinal;  // device index

  // Device vendor can store handle to their device representation
  // here.
  void* device_handle;

  // [Optional]
  // Device hardware name. Used for printing.
  // Must be null-terminated.
  const char* hardware_name;

  // [Optional]
  // Device vendor name. Used for printing.
  // Must be null-terminated.
  const char* device_vendor;

  // [Optional]
  // PCI bus identifier for this device, of the form
  // [domain]:[bus]:[device].[function]
  // where domain number is usually 0000.
  // Example: 0000:00:02.1
  // For more information see:
  //   https://en.wikipedia.org/wiki/PCI_configuration_space
  //   https://www.oreilly.com/library/view/linux-device-drivers/0596005903/ch12.html
  // Used for printing. Must be null-terminated.
  const char* pci_bus_id;
} SP_Device;

// Expected `struct_size` of SP_Device; `pci_bus_id` is its last field.
#define SP_DEVICE_STRUCT_SIZE TF_OFFSET_OF_END(SP_Device, pci_bus_id)
186
187typedef struct SE_CreateDeviceParams {
188 size_t struct_size;
189 void* ext; // reserved for future use
190 int32_t ordinal; // device index
191
192 SP_Device* device; // Input/output, struct_size set by TF for plugin to read.
193 // Subsequently plugin fills the entire struct.
194} SE_CreateDeviceParams;
195
196#define SE_CREATE_DEVICE_PARAMS_STRUCT_SIZE \
197 TF_OFFSET_OF_END(SE_CreateDeviceParams, device)
198
199typedef struct SP_DeviceFns {
200 size_t struct_size;
201 void* ext; // reserved for future use
202
203 // [Optional]
204 // Returns the NUMA node associated with this device, for use in
205 // determining socket locality. If the NUMA node could not be determined, -1
206 // is returned.
207 // Negative values are treated as "unset".
208 int32_t (*get_numa_node)(const SP_Device* device);
209
210 // [Optional]
211 // Device's memory bandwidth in bytes/sec. (This is for reads/writes to/from
212 // the device's own memory, not for transfers between the host and device.)
213 // Negative values are treated as "unset".
214 int64_t (*get_memory_bandwidth)(const SP_Device* device);
215
216 // [Optional]
217 // Estimate of average number of floating point operations per second for
218 // this device * 10e-9.
219 // Negative values are treated as "unset".
220 double (*get_gflops)(const SP_Device* device);
221} SP_DeviceFns;
222
223#define SP_DEVICE_FNS_STRUCT_SIZE TF_OFFSET_OF_END(SP_DeviceFns, get_gflops)
224
225typedef struct SE_CreateDeviceFnsParams {
226 size_t struct_size;
227 void* ext; // reserved for future use
228
229 SP_DeviceFns* device_fns; // output, to be filled by plugin
230} SE_CreateDeviceFnsParams;
231
232#define SE_CREATE_DEVICE_FNS_PARAMS_STRUCT_SIZE \
233 TF_OFFSET_OF_END(SE_CreateDeviceFnsParams, device_fns)
234
235typedef struct SP_StreamExecutor {
236 size_t struct_size;
237 void* ext; // reserved for future use
238
239 /*** ALLOCATION CALLBACKS ***/
240 // Synchronously allocates `size` bytes on the underlying platform and returns
241 // `SP_DeviceMemoryBase` representing that allocation. In the case of failure,
242 // nullptr is returned.
243 // `memory_space` is reserved for a potential future usage and should be set
244 // to 0.
245 void (*allocate)(const SP_Device* device, uint64_t size, int64_t memory_space,
246 SP_DeviceMemoryBase* mem);
247
248 // Deallocate the device memory previously allocated via this interface.
249 // Deallocation of a nullptr-representative value is permitted.
250 void (*deallocate)(const SP_Device* device, SP_DeviceMemoryBase* memory);
251
252 // Allocates a region of host memory and registers it with the platform API.
253 // Memory allocated in this manner is required for use in asynchronous memcpy
254 // operations, such as `memcpy_dtoh`.
255 void* (*host_memory_allocate)(const SP_Device* device, uint64_t size);
256
257 // Deallocates a region of host memory allocated by `host_memory_allocate`.
258 void (*host_memory_deallocate)(const SP_Device* device, void* mem);
259
260 // Allocates unified memory space of the given size, if supported. Unified
261 // memory support should be added by setting `supports_unified_memory` field
262 // in `SP_Platform`.
263 void* (*unified_memory_allocate)(const SP_Device* device, uint64_t bytes);
264
265 // Deallocates unified memory space previously allocated with
266 // `unified_memory_allocate`. Unified
267 // memory support should be added by setting `supports_unified_memory` field
268 // in `SP_Platform`.
269 void (*unified_memory_deallocate)(const SP_Device* device, void* location);
270
271 // Fills SP_AllocatorStats with allocator statistics, if it is available.
272 // If it is not available, return false.
273 TF_Bool (*get_allocator_stats)(const SP_Device* device,
274 SP_AllocatorStats* stats);
275 // Fills the underlying device memory usage information, if it is
276 // available. If it is not available (false is returned), free/total need not
277 // be initialized.
278 TF_Bool (*device_memory_usage)(const SP_Device* device, int64_t* free,
279 int64_t* total);
280
281 /*** STREAM CALLBACKS ***/
282 // Creates SP_Stream. This call should also allocate stream
283 // resources on the underlying platform and initializes its
284 // internals.
285 void (*create_stream)(const SP_Device* device, SP_Stream* stream,
286 TF_Status* status);
287
288 // Destroys SP_Stream and deallocates any underlying resources.
289 void (*destroy_stream)(const SP_Device* device, SP_Stream stream);
290
291 // Causes `dependent` to not begin execution until `other` has finished its
292 // last-enqueued work.
293 void (*create_stream_dependency)(const SP_Device* device, SP_Stream dependent,
294 SP_Stream other, TF_Status* status);
295
296 // Without blocking the device, retrieve the current stream status.
297 void (*get_stream_status)(const SP_Device* device, SP_Stream stream,
298 TF_Status* status);
299
300 /*** EVENT CALLBACKS ***/
301 // Create SP_Event. Performs platform-specific allocation and initialization
302 // of an event.
303 void (*create_event)(const SP_Device* device, SP_Event* event,
304 TF_Status* status);
305
306 // Destroy SE_Event and perform any platform-specific deallocation and
307 // cleanup of an event.
308 void (*destroy_event)(const SP_Device* device, SP_Event event);
309
310 // Requests the current status of the event from the underlying platform.
311 SE_EventStatus (*get_event_status)(const SP_Device* device, SP_Event event);
312 // Inserts the specified event at the end of the specified stream.
313 void (*record_event)(const SP_Device* device, SP_Stream stream,
314 SP_Event event, TF_Status* status);
315
316 // Wait for the specified event at the end of the specified stream.
317 void (*wait_for_event)(const SP_Device* const device, SP_Stream stream,
318 SP_Event event, TF_Status* const status);
319
320 /*** TIMER CALLBACKS ***/
321 // Creates SP_Timer. Allocates timer resources on the underlying platform
322 // and initializes its internals, setting `timer` output variable. Sets
323 // values in `timer_fns` struct.
324 void (*create_timer)(const SP_Device* device, SP_Timer* timer,
325 TF_Status* status);
326
327 // Destroy timer and deallocates timer resources on the underlying platform.
328 void (*destroy_timer)(const SP_Device* device, SP_Timer timer);
329
330 // Records a start event for an interval timer.
331 void (*start_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer,
332 TF_Status* status);
333
334 // Records a stop event for an interval timer.
335 void (*stop_timer)(const SP_Device* device, SP_Stream stream, SP_Timer timer,
336 TF_Status* status);
337
338 /*** MEMCPY CALLBACKS ***/
339 // Enqueues a memcpy operation onto stream, with a host destination location
340 // `host_dst` and a device memory source, with target size `size`.
341 void (*memcpy_dtoh)(const SP_Device* device, SP_Stream stream, void* host_dst,
342 const SP_DeviceMemoryBase* device_src, uint64_t size,
343 TF_Status* status);
344
345 // Enqueues a memcpy operation onto stream, with a device destination
346 // location and a host memory source, with target size `size`.
347 void (*memcpy_htod)(const SP_Device* device, SP_Stream stream,
348 SP_DeviceMemoryBase* device_dst, const void* host_src,
349 uint64_t size, TF_Status* status);
350
351 // Enqueues a memcpy operation onto stream, with a device destination
352 // location and a device memory source, with target size `size`.
353 void (*memcpy_dtod)(const SP_Device* device, SP_Stream stream,
354 SP_DeviceMemoryBase* device_dst,
355 const SP_DeviceMemoryBase* device_src, uint64_t size,
356 TF_Status* status);
357
358 // Blocks the caller while a data segment of the given size is
359 // copied from the device source to the host destination.
360 void (*sync_memcpy_dtoh)(const SP_Device* device, void* host_dst,
361 const SP_DeviceMemoryBase* device_src, uint64_t size,
362 TF_Status* status);
363
364 // Blocks the caller while a data segment of the given size is
365 // copied from the host source to the device destination.
366 void (*sync_memcpy_htod)(const SP_Device* device,
367 SP_DeviceMemoryBase* device_dst,
368 const void* host_src, uint64_t size,
369 TF_Status* status);
370
371 // Blocks the caller while a data segment of the given size is copied from the
372 // device source to the device destination.
373 void (*sync_memcpy_dtod)(const SP_Device* device,
374 SP_DeviceMemoryBase* device_dst,
375 const SP_DeviceMemoryBase* device_src, uint64_t size,
376 TF_Status* status);
377
378 // Causes the host code to synchronously wait for the event to complete.
379 void (*block_host_for_event)(const SP_Device* device, SP_Event event,
380 TF_Status* status);
381
382 // [Optional]
383 // Causes the host code to synchronously wait for operations entrained onto
384 // stream to complete. Effectively a join on the asynchronous device
385 // operations enqueued on the stream before this program point.
386 // If not set, then corresponding functionality will be implemented
387 // by registering an event on the `stream` and waiting for it using
388 // `block_host_for_event`.
389 void (*block_host_until_done)(const SP_Device* device, SP_Stream stream,
390 TF_Status* status);
391
392 // Synchronizes all activity occurring in the StreamExecutor's context (most
393 // likely a whole device).
394 void (*synchronize_all_activity)(const SP_Device* device, TF_Status* status);
395
396 // Zero out `size` bytes starting at the location.
397 void (*mem_zero)(const SP_Device* device, SP_Stream stream,
398 SP_DeviceMemoryBase* location, uint64_t size,
399 TF_Status* status);
400
401 // Set the 8-bit patterns starting at the location with `size` bytes.
402 void (*memset)(const SP_Device* device, SP_Stream stream,
403 SP_DeviceMemoryBase* location, uint8_t pattern, uint64_t size,
404 TF_Status* status);
405
406 // Set the 32-bit patterns starting at the location with `size` bytes.
407 void (*memset32)(const SP_Device* device, SP_Stream stream,
408 SP_DeviceMemoryBase* location, uint32_t pattern,
409 uint64_t size, TF_Status* status);
410
411 // Enqueues on a stream a user-specified function to be run on the host.
412 // `callback_arg` should be passed as the first argument to `callback_fn`.
413 TF_Bool (*host_callback)(const SP_Device* device, SP_Stream stream,
414 SE_StatusCallbackFn callback_fn, void* callback_arg);
415} SP_StreamExecutor;
416
417#define SP_STREAMEXECUTOR_STRUCT_SIZE \
418 TF_OFFSET_OF_END(SP_StreamExecutor, host_callback)
419
420typedef struct SE_CreateStreamExecutorParams {
421 size_t struct_size;
422 void* ext; // reserved for future use
423
424 SP_StreamExecutor* stream_executor; // output, to be filled by plugin
425} SE_CreateStreamExecutorParams;
426
427#define SE_CREATE_STREAM_EXECUTOR_PARAMS_STRUCT_SIZE \
428 TF_OFFSET_OF_END(SE_CreateStreamExecutorParams, stream_executor)
429
430typedef struct SP_Platform {
431 size_t struct_size;
432
433 void* ext; // free-form data set by plugin
434
435 // Platform name (also referred to as subtype), for example MY_DEVICE.
436 // The name must start with a capital letter and consist of
437 // capital letters and underscores.
438 // Must be null-terminated.
439 const char* name;
440
441 // Device type name, for example GPU. Must be null-terminated.
442 // The name must start with a capital letter and consist of
443 // capital letters and underscores.
444 const char* type;
445
446 // Whether this platform supports unified memory.
447 // Unified memory is a single memory address space accessible from any device.
448 TF_Bool supports_unified_memory;
449
450 // Whether to wrap allocator for this device with an allocator that uses BFC
451 // (best-fit with coalescing) strategy.
452 TF_Bool use_bfc_allocator;
453
454 // Whether to force the memory allocations to grow over time instead of
455 // allocating it all at once. When this is set to true, the value of
456 // allow_growth is ignored.
457 TF_Bool force_memory_growth;
458} SP_Platform;
459
460#define SP_PLATFORM_STRUCT_SIZE \
461 TF_OFFSET_OF_END(SP_Platform, force_memory_growth)
462
463typedef struct SP_PlatformFns {
464 size_t struct_size;
465
466 void* ext; // reserved for future use
467
468 // Callbacks for getting device count
469 void (*get_device_count)(const SP_Platform* platform, int* device_count,
470 TF_Status* status);
471 // Callbacks for creating/destroying SP_Device.
472 void (*create_device)(const SP_Platform* platform,
473 SE_CreateDeviceParams* params, TF_Status* status);
474
475 // Clean up fields inside SP_Device that were allocated
476 // by the plugin. `device` itself should not be deleted here.
477 void (*destroy_device)(const SP_Platform* platform, SP_Device* device);
478
479 // Callbacks for creating/destroying SP_DeviceFns.
480 void (*create_device_fns)(const SP_Platform* platform,
481 SE_CreateDeviceFnsParams* params,
482 TF_Status* status);
483
484 // Clean up fields inside SP_DeviceFns that were allocated
485 // by the plugin. `device_fns` itself should not be deleted here.
486 void (*destroy_device_fns)(const SP_Platform* platform,
487 SP_DeviceFns* device_fns);
488
489 // Callbacks for creating/destroying SP_StreamExecutor.
490 void (*create_stream_executor)(const SP_Platform* platform,
491 SE_CreateStreamExecutorParams* params,
492 TF_Status* status);
493 // Clean up fields inside SP_StreamExecutor that were allocated
494 // by the plugin. `stream_executor` itself should not be deleted here.
495 void (*destroy_stream_executor)(const SP_Platform* platform,
496 SP_StreamExecutor* stream_executor);
497
498 // Callbacks for creating/destroying SP_TimerFns.
499 void (*create_timer_fns)(const SP_Platform* platform, SP_TimerFns* timer,
500 TF_Status* status);
501
502 void (*destroy_timer_fns)(const SP_Platform* platform,
503 SP_TimerFns* timer_fns);
504} SP_PlatformFns;
505
506#define SP_PLATFORM_FNS_STRUCT_SIZE \
507 TF_OFFSET_OF_END(SP_PlatformFns, destroy_timer_fns)
508
509typedef struct SE_PlatformRegistrationParams {
510 size_t struct_size;
511 void* ext; // reserved for future use
512
513 // StreamExecutor C API version.
514 int32_t major_version;
515 int32_t minor_version;
516 int32_t patch_version;
517
518 SP_Platform* platform; // output, set by plugin
519 SP_PlatformFns* platform_fns; // output, set by plugin
520 // Clean up fields inside SP_Platform that were allocated
521 // by the plugin. `platform` itself should not be deleted here.
522 void (*destroy_platform)(SP_Platform* platform); // out, set by plugin
523 void (*destroy_platform_fns)(
524 SP_PlatformFns* platform_fns); // out, set by plugin
525} SE_PlatformRegistrationParams;
526
527#define SE_PLATFORM_REGISTRATION_PARAMS_STRUCT_SIZE \
528 TF_OFFSET_OF_END(SE_PlatformRegistrationParams, destroy_platform_fns)
529
530void SE_InitPlugin(SE_PlatformRegistrationParams* params, TF_Status* status);
531
532#ifdef __cplusplus
533} // extern "C"
534#endif
535
536#endif // TENSORFLOW_C_EXPERIMENTAL_STREAM_EXECUTOR_STREAM_EXECUTOR_H_
537