1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_CORE_TPU_TPU_OPS_C_API_H_
16#define TENSORFLOW_CORE_TPU_TPU_OPS_C_API_H_
17
#include <stddef.h>

#include <cstdint>
#include <optional>

#include "absl/types/optional.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h"
#include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h"
#include "tensorflow/core/tpu/libtftpu.h"
27
// Serialized-proto handle used throughout this API. Its definition is not in
// this file (presumably in proto_helper.h or c_api_decl.h; confirm).
typedef struct TpuSerializedProto TpuSerializedProto;

// Forward declarations of C++ classes referenced by pointer in this header.
// Because this uses a C++ namespace, the header is C++-only despite its
// `extern "C"` API. (TpuEmbeddingEngineState is not referenced by any
// declaration in this header.)
namespace tensorflow {

class TpuMeshCommonState;
class TpuEmbeddingEngineState;
class ResourceMgr;

}  // namespace tensorflow
37
38extern "C" {
39
// Opaque handle to a TPU program. See `TpuProgram_New` / `TpuProgram_Free`.
typedef struct XLA_TpuProgram XLA_TpuProgram;

// Enum for choosing the main, sharding, or unsharding program from an
// `XLA_TpuProgram` object (consumed by `TpuProgram_GetTpuProgram`).
// NOTE(review): this is an unscoped enum inside `extern "C"`, so the
// enumerator names (`kInvalid`, `kMain`, ...) live in the global namespace.
enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding };
44
// Each struct below is a (pointer, length) pair describing a byte buffer.

// Returned by `TpuProgram_GetFingerprint`; released with
// `TpuProgram_DestroyFingerprint`.
struct TpuProgramFingerprint {
  const char* bytes;
  size_t size;
};

// Filled in by `TpuProgram_SerializeTpuExecutable`.
struct TpuExecutableSerializedProto {
  const char* bytes;
  size_t size;
};

// Filled in by `TpuProgram_SerializeCompilerMetadata`.
struct CompilerMetadataSerializedProto {
  const char* bytes;
  size_t size;
};

// Not referenced by any declaration in this header.
struct HostComputeMetadataSerializedProto {
  const char* bytes;
  size_t size;
};
64
// Opaque TPU mesh state. See `TpuMeshState_Create` / `TpuMeshState_Free`.
typedef struct XLA_TpuMeshState XLA_TpuMeshState;

// Opaque embedding engine state. See `TpuEmbeddingEngineState_Create` /
// `TpuEmbeddingEngineState_Free`.
typedef struct XLA_TpuEmbeddingEngineState XLA_TpuEmbeddingEngineState;

// Opaque state for `TpuEmbeddingEngine_EnqueueTensorBatch`. See
// `TpuEmbeddingTensorBatchFixedState_Create` / `_Destroy`.
typedef struct TpuEmbedding_TensorBatchFixedState
    TpuEmbedding_TensorBatchFixedState;

// Opaque profiler handle. See `TpuProfiler_Create` / `TpuProfiler_Destroy`.
typedef struct TpuProfiler TpuProfiler;

// (pointer, length) byte buffer passed to
// `TpuExecutable_LoadProgramAndEnqueueToStream`. Presumably a serialized
// DeviceAssignment proto (confirm).
typedef struct XLA_DeviceAssignment {
  const char* bytes;
  size_t size;
} XLA_DeviceAssignment;
78
// Properties from which `TpuCompile_CreateCompilationCacheKey` creates a
// compilation cache key.
struct CompilationCacheKeyProperty {
  const char* config_prefix;
  const char* shapes_prefix;
  const char* function_name;
  uint64_t mlir_module_fingerprint;
  // Presumably an array of `device_ids_size` elements (confirm).
  const int32_t* device_ids;
  size_t device_ids_size;
  int32_t guaranteed_constants_size;
  uint64_t function_library_fingerprint;
  int32_t num_cores_per_replica;
  int32_t num_replicas;
  const XLA_TpuMeshState* mesh_state;
  uint64_t session_id;
  tensorflow::ResourceMgr* resource_mgr;
};

// Compilation cache key result, holding both the key and a more verbose debug
// version. Returned by `TpuCompile_CreateCompilationCacheKey` (whose returned
// buffers are heap-allocated) and released with
// `TpuCompile_DestroyCompilationCacheKey`.
struct CompilationCacheKeyResult {
  const char* key;
  const char* debug_string;
};
102
// Opaque per-device node context. See `TpuNodeContext_Create` /
// `TpuNodeContext_Free`.
typedef struct XLA_TpuNodeContext XLA_TpuNodeContext;

// Opaque ordinal selector. Note that the struct tag (`TfTpu_OrdinalSelector`)
// and the typedef name (`TfTpuOrdinalSelector`) are spelled differently.
typedef struct TfTpu_OrdinalSelector TfTpuOrdinalSelector;

// Partitioned-call options, filled in by `TfTpu_GetTpuPartitionedCallParams`.
struct TpuPartitionedCall_Params {
  bool input_shape_opt;
  bool group_tensors_for_packing;
  int32_t minimum_input_tensors_packing;
  int32_t minimum_output_tensors_packing;

  // Whether to attempt to automatically shard inputs by adding an
  // XlaSharding op after each input.
  bool enable_auto_xla_input_sharding;

  // The dimension of each input to shard if
  // enable_auto_xla_input_sharding is set to true. Negative numbers are
  // allowed and refer to dimensions counted from the end.
  int32_t auto_xla_input_sharding_dim;

  // If true, only create one variable on the TPU for each variable on the CPU.
  bool enable_variable_deduplication;
};
125
// Compiles an MLIR or TF function computation by lowering it into HLO IR and
// returns `count` TPU programs ready for execution.
// The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates
// the `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is
// responsible for deallocating both the `XLA_TpuProgram*[]` array and the
// `XLA_TpuProgram` object(s), using `TpuProgram_FreeArray` and
// `TpuProgram_Free` respectively. Errors are reported through `status`.
TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild(
    TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state,
    XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);

// Compiles HLO IR and returns `count` TPU programs ready for execution.
// The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates
// the `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is
// responsible for deallocating both the `XLA_TpuProgram*[]` array and the
// `XLA_TpuProgram` object(s), using `TpuProgram_FreeArray` and
// `TpuProgram_Free` respectively. Errors are reported through `status`.
TFTPU_CAPI_EXPORT void TpuCompile_XrtCompileAndBuild(
    TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state,
    XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status);
146
// Creates a TPU profiler that is ready to start profiling and stores it in
// `*tpu_profiler`. Release it with `TpuProfiler_Destroy`.
TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler,
                                          TF_Status* status);
// Destroys the given TPU profiler.
TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler);
// Starts profiling if not already started; returns an error otherwise.
TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler,
                                         TF_Status* status);
// Stops profiling if not already stopped; returns an error otherwise.
TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler,
                                        TF_Status* status);
// Serializes profiled data into `buffer` and reports the required size through
// `size_in_bytes`. The profile data held by the TPU driver is cleared after
// retrieval. Note that `size_in_bytes` is a pointer in both calls below.
//
// Step 1. Query the required buffer size into `size_in_bytes`.
//
//   size_t size_in_bytes;
//   TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes);
//
// Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`.
// Afterwards, the TPU driver clears its copy of the profile data.
//
//   std::vector<uint8_t> buffer(size_in_bytes);
//   TpuProfiler_CollectData(profiler, status, buffer.data(), &size_in_bytes);
//
// Step 3. Unpack the data into an XSpace.
//
//   tensorflow::profiler::XSpace space;
//   space.ParseFromArray(buffer.data(), size_in_bytes);
//
TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler,
                                               TF_Status* status,
                                               uint8_t* buffer,
                                               size_t* size_in_bytes);
181
// Creates a new TPU mesh state object. Release it with `TpuMeshState_Free`.
TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create();

// Deletes the given TPU `mesh_state` object. Once deleted, the object is
// unusable.
TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state);

// Returns a pointer to an opaque mesh data structure used internally.
// NOTE(review): presumably a `tensorflow::TpuMeshCommonState*`; confirm.
TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState(
    XLA_TpuMeshState* mesh_state);

// Creates a new TPU embedding engine state object. Release it with
// `TpuEmbeddingEngineState_Free`.
TFTPU_CAPI_EXPORT XLA_TpuEmbeddingEngineState* TpuEmbeddingEngineState_Create();

// Deletes the given TPU embedding engine state object. Once deleted, the object
// is unusable.
TFTPU_CAPI_EXPORT void TpuEmbeddingEngineState_Free(
    XLA_TpuEmbeddingEngineState* engine_state);

// Returns a pointer to an opaque embedding engine state data structure used
// internally.
TFTPU_CAPI_EXPORT void* TpuEmbeddingEngineState_GetState(
    XLA_TpuEmbeddingEngineState* engine_state);
205
// Creates an ordinal selector for `num_cores_per_replica` and stores it in
// `*ordinal_selector`. Release it with `TfTpuOrdinalSelector_Destroy`.
TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Create(
    TfTpuOrdinalSelector** ordinal_selector, int num_cores_per_replica);

// Destroys an ordinal selector created by `TfTpuOrdinalSelector_Create`.
TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Destroy(
    TfTpuOrdinalSelector* ordinal_selector);

// Selects a device ordinal, written to `*ordinal`, together with a request id
// written to `*req_id`.
// NOTE(review): `std::optional` is a C++ type in an `extern "C"` signature, so
// this entry point is not callable from C and depends on both sides agreeing
// on std::optional's layout.
TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_GetOrdinal(
    TfTpuOrdinalSelector* ordinal_selector, std::optional<uint64_t> key,
    int64_t* req_id, int64_t* ordinal);

// Presumably releases the request `req_id` obtained from
// `TfTpuOrdinalSelector_GetOrdinal` for `device_ordinal` (confirm).
TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_DequeueFromCoreSelector(
    TfTpuOrdinalSelector* ordinal_selector, int32_t device_ordinal,
    int64_t req_id);

// Fills `*params` with the partitioned-call settings.
TFTPU_CAPI_EXPORT void TfTpu_GetTpuPartitionedCallParams(
    TpuPartitionedCall_Params* params);
222
// Parameters for `TpuExecutable_LoadProgramAndEnqueueToStream`.
// NOTE(review): `struct_size` is presumably set by the caller to
// `TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE`; confirm.
typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params {
  int32_t struct_size;
  void* priv;

  const XLA_TpuProgram* program;
  // Presumably an array of `arguments_len` device buffers (confirm).
  SE_DeviceMemoryBase* arguments;
  size_t arguments_len;
  SE_DeviceMemoryBase* result;
  // Presumably indicates whether `cross_program_prefetch_addr` is set.
  bool has_cross_program_prefetch_addr;
  SE_DeviceMemoryBase* cross_program_prefetch_addr;
  int32_t rng_seed;
  XLA_DeviceAssignment* device_assignment;
  SE_Stream* stream;

  TF_Status* status;  // out
} TpuExecutable_LoadProgramAndEnqueueToStream_Params;

// Size of the params struct as seen by this header.
#define TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE \
  (sizeof(struct TpuExecutable_LoadProgramAndEnqueueToStream_Params))

// Loads `params->program` and enqueues it onto `params->stream`; errors are
// reported through `params->status`.
TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream(
    TpuExecutable_LoadProgramAndEnqueueToStream_Params* params);
245
// Converts `host_shape` into its on-device layout, written to `device_shape`.
TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape(
    XLA_Shape* host_shape, XLA_Shape* device_shape);
// Shape-size queries for `shape`. The exact differences between the plain,
// compact and compact-raw variants are defined by the implementation (not
// visible here).
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape);
TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape);
251
// Parameters for `TpuExecute_RuntimeInputToPaddedData`. Each `*_ptr` field is
// paired with the `*_size` field that follows it.
typedef struct TpuExecute_RuntimeInputToPaddedData_Params {
  int32_t struct_size;
  void* priv;

  uint32_t* runtime_input_ptr;
  size_t runtime_input_size;
  int8_t* padded_data_ptr;
  size_t padded_data_size;
  XLA_Shape* runtime_shape;
  XLA_Shape* compile_time_shape;

  TF_Status* status;  // out
} TpuExecute_RuntimeInputToPaddedData_Params;

// Size of the params struct as seen by this header.
#define TpuExecute_RuntimeInputToPaddedData_Params_SIZE \
  (sizeof(struct TpuExecute_RuntimeInputToPaddedData_Params))

// Presumably converts runtime input data laid out for `runtime_shape` into
// the padded buffer expected for `compile_time_shape` (confirm).
TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData(
    TpuExecute_RuntimeInputToPaddedData_Params* params);
271
// Parameters for `ConfigureDistributedTpuOp_DoWork`.
typedef struct ConfigureDistributedTpuOp_DoWork_Params {
  int32_t struct_size;
  void* priv;

  // Array of `num_cores_per_host_size` entries.
  size_t num_cores_per_host_size;
  const int32_t* num_cores_per_host;
  size_t server_address_size;
  const char* server_address;

  // NOTE(review): `*host_config_output` is allocated by the library and is
  // presumably released with `TpuConfigurationApi_FreeCharArray`; confirm.
  size_t* host_config_output_size;  // out
  char** host_config_output;        // out
  TF_Status* status;                // out
} ConfigureDistributedTpuOp_DoWork_Params;

// Size of the params struct as seen by this header.
#define ConfigureDistributedTpuOp_DoWork_Params_SIZE \
  (sizeof(struct ConfigureDistributedTpuOp_DoWork_Params))

TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork(
    ConfigureDistributedTpuOp_DoWork_Params* params);
291
// Parameters for `WaitForDistributedTpuOp_DoWork`.
typedef struct WaitForDistributedTpuOp_DoWork_Params {
  int32_t struct_size;
  void* priv;

  // Presumably `num_hosts` rows of `num_cores_per_host` core ids (confirm).
  size_t num_hosts;
  size_t num_cores_per_host;
  const int32_t** host_ordinal_to_global_core_id_map;
  tensorflow::TpuMeshCommonState* tpu_mesh_common_state;

  // NOTE(review): `*tpu_topology_output` is presumably released with
  // `TpuConfigurationApi_FreeCharArray`; confirm.
  size_t* tpu_topology_output_size;  // out
  char** tpu_topology_output;        // out
  TF_Status* status;                 // out
} WaitForDistributedTpuOp_DoWork_Params;

// Size of the params struct as seen by this header.
#define WaitForDistributedTpuOp_DoWork_Params_SIZE \
  (sizeof(struct WaitForDistributedTpuOp_DoWork_Params))

TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork(
    WaitForDistributedTpuOp_DoWork_Params* params);
311
// Parameters for `InitializeHostForDistributedTpuOp_DoWork`.
typedef struct InitializeHostForDistributedTpuOp_DoWork_Params {
  int32_t struct_size;
  void* priv;

  size_t tpu_host_config_size;
  const char* tpu_host_config;
  bool enable_whole_mesh_compilations;
  bool is_master_worker;

  // NOTE(review): `*core_id_output` is presumably released with
  // `TpuConfigurationApi_FreeInt32Array`; confirm.
  size_t* core_id_output_size;  // out
  int32_t** core_id_output;     // out
  TF_Status* status;            // out
} InitializeHostForDistributedTpuOp_DoWork_Params;

// Size of the params struct as seen by this header.
#define InitializeHostForDistributedTpuOp_DoWork_Params_SIZE \
  (sizeof(struct InitializeHostForDistributedTpuOp_DoWork_Params))

TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork(
    InitializeHostForDistributedTpuOp_DoWork_Params* params);
331
// Sets the global TPU array from a `tpu_topology` buffer of
// `tpu_topology_size` bytes.
TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork(
    const size_t tpu_topology_size, const char* tpu_topology,
    TF_Status* status);

// Disconnects TPU chips; the chip count is written to `*number_of_chips_output`.
TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork(
    int32_t* number_of_chips_output, TF_Status* status);

// Release helpers for arrays returned through the configuration APIs.
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output);

TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState();

// Simple configuration queries; results are written through the pointer
// arguments.
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus,
                                                       TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit,
                                                          TF_Status* status);
TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes(
    int64_t* cache_size_in_bytes);
350
351typedef struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params {
352 int32_t struct_size;
353 void* priv;
354
355 size_t tpu_host_config_size;
356 const char* tpu_host_config;
357
358 size_t* server_address_output_size; // out
359 char** server_address_output; // out
360 TF_Status* status; // out
361} TpuConfigurationApi_CompilationCacheServerAddressFromConfig_Params;
362
363#define TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE \
364 (sizeof( \
365 struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params))
366
367TFTPU_CAPI_EXPORT
368void TpuConfigurationApi_CompilationCacheServerAddressFromConfig(
369 TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params* params);
370
// Parameters for `TpuConfigurationApi_GetServerAddressAndPort`. All fields
// after `priv` are outputs.
typedef struct TpuConfigurationApi_GetServerAddressAndPort_Params {
  int32_t struct_size;
  void* priv;

  size_t* server_address_output_size;  // out
  char** server_address_output;        // out
  int* port_output;                    // out
  TF_Status* status;                   // out
} TpuConfigurationApi_GetServerAddressAndPort_Params;

// Size of the params struct as seen by this header.
#define TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE \
  (sizeof(struct TpuConfigurationApi_GetServerAddressAndPort_Params))

TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort(
    TpuConfigurationApi_GetServerAddressAndPort_Params* params);
386
// Creates a new TPU program. Release it with `TpuProgram_Free`.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New();

// Destroys the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program);

// Creates an array of `count` `XLA_TpuProgram*` entries. Release it with
// `TpuProgram_FreeArray`.
TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count);

// Destroys an array of `XLA_TpuProgram*` created by `TpuProgram_NewArray`.
// The programs the array points to are released separately with
// `TpuProgram_Free`.
TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]);

// Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and
// destroyed, it is in an unusable state.
TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program,
                                                   TF_Status* status);

// Gets the TPU program size in bytes from the `tpu_program`.
TFTPU_CAPI_EXPORT int64_t
TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program);

// Logs a summary of the current memory state snapshot of the `tpu_program`.
TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary(
    const XLA_TpuProgram* tpu_program);

// Gets the TPU program executable info from the `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info,
    TF_Status* status);

// Gets the host transfer info proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info,
    TF_Status* status);

// Gets the HLO metadata proto.
TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata(
    const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata,
    TF_Status* status);

// Writes whether the program may modify variables to `*may_modify_variables`.
TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables(
    const XLA_TpuProgram* tpu_program, bool* may_modify_variables);

// Checks whether the TPU program has sharding.
TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding(
    const XLA_TpuProgram* tpu_program);

// Returns the TPU program selected by `type` from `tpu_program`.
// NOTE(review): this function takes no `TF_Status`, so callers cannot check a
// status. Whether the result may be null (e.g. for an invalid `type`, or when
// `TpuProgram_HasSharding` is false) is not visible here; confirm against
// the implementation.
TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram(
    XLA_TpuProgram* tpu_program, TpuProgramShardingType type);

// Gets the TPU executable proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable(
    const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable,
    TF_Status* status);

// Gets the compilation metadata proto from a `tpu_program`.
TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata(
    const XLA_TpuProgram* tpu_program,
    CompilerMetadataSerializedProto* compiler_metadata, TF_Status* status);

// Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`.
TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto(
    TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program,
    TF_Status* status);

// Returns the fingerprint of `tpu_program`. It is presumably released with
// `TpuProgram_DestroyFingerprint` (confirm ownership).
TFTPU_CAPI_EXPORT TpuProgramFingerprint
TpuProgram_GetFingerprint(const XLA_TpuProgram* tpu_program);

TFTPU_CAPI_EXPORT void TpuProgram_DestroyFingerprint(
    TpuProgramFingerprint fingerprint);
460
// Checks whether TPU compilation is enabled.
TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled();

// XLA compilation cannot be cancelled. To avoid hanging, the TF worker exits
// when cancellation is requested for an XLA compile op. Some tests require this
// behavior to be disabled, and this flag function reports that condition.
TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();

// Returns the number of available TPU cores of type `tpu_core_type`.
TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount(
    const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type);

// Recycles an unused service port.
TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port);

// Creates a unique compilation cache `key` used for `put` and `get` operations.
// The returned buffers are heap-allocated and owned by the caller, who must
// release them with `TpuCompile_DestroyCompilationCacheKey`.
TFTPU_CAPI_EXPORT CompilationCacheKeyResult
TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property);

// Destroys the CompilationCacheKeyResult returned by
// `TpuCompile_CreateCompilationCacheKey`.
TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey(
    CompilationCacheKeyResult result);

// Creates a guaranteed-const fingerprint by combining `fingerprint` with the
// `size` bytes at `data`. Guaranteed-const is normally used in TPU inference to
// avoid re-copying unchanged variables onto the TPU device: it promises that
// the value is identical for every execution in the same session, even if the
// actual value changes in later executions.
TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint(
    uint64_t fingerprint, const char* data, size_t size);
493
// TpuNodeContext API. Errors are reported through `status`.
// NOTE(review): unlike every other function listed in `TfTpu_OpsApiFn`, these
// declarations lack `TFTPU_CAPI_EXPORT`. Confirm whether the omission is
// intentional (e.g. visibility handled by a linker script) or an oversight.

// Creates a node context for `device_ordinal`. Release it with
// `TpuNodeContext_Free`.
XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal,
                                          TF_Status* status);
void TpuNodeContext_Free(XLA_TpuNodeContext* node_context);

void TpuNodeContext_StopChipHeartbeats(TF_Status* status);

void TpuNodeContext_CloseTpuHost(TF_Status* status);

void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status);

bool TpuNodeContext_CompactionSupported(int device_ordinal);

// Globally initializes the TPU system for inference.
TFTPU_CAPI_EXPORT void TfTpu_InitializeTpuModelServer();
508
// Parameters for `TpuEmbeddingEngine_ExecutePartitioner`, which produces a
// serialized common config from `tpu_embedding_config`.
// NOTE(review): unlike the configuration params above, no `_SIZE` macro is
// defined for this or the other TpuEmbeddingEngine params structs.
typedef struct TpuEmbeddingEngine_ExecutePartitioner_Params {
  int32_t struct_size;
  void* priv;
  TpuSerializedProto tpu_embedding_config;

  // out
  size_t* common_config_size;
  char** common_config;
  TF_Status* status;
} TpuEmbeddingEngine_ExecutePartitioner_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ExecutePartitioner(
    TpuEmbeddingEngine_ExecutePartitioner_Params* params);
522
// Parameters for `TpuEmbeddingEngine_ConfigureMemory`. It consumes a common
// config (e.g. the output of `TpuEmbeddingEngine_ExecutePartitioner`) and
// produces a memory config.
typedef struct TpuEmbeddingEngine_ConfigureMemory_Params {
  int32_t struct_size;
  void* priv;

  int num_inputs;
  size_t common_config_size;
  const char* common_config;

  // out
  size_t* memory_config_size;
  char** memory_config;
  TF_Status* status;
} TpuEmbeddingEngine_ConfigureMemory_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConfigureMemory(
    TpuEmbeddingEngine_ConfigureMemory_Params* params);

// Parameters for `TpuEmbeddingEngine_CollateMemory`, which merges
// `memory_configs_size` memory configs into a single merged memory config.
typedef struct TpuEmbeddingEngine_CollateMemory_Params {
  int32_t struct_size;
  void* priv;

  size_t memory_configs_size;
  const TpuSerializedProto* memory_configs;

  // out
  size_t* merged_memory_config_size;
  char** merged_memory_config;
  TF_Status* status;
} TpuEmbeddingEngine_CollateMemory_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_CollateMemory(
    TpuEmbeddingEngine_CollateMemory_Params* params);
555
// Parameters for `TpuEmbeddingEngine_ConfigureHost`, which produces a network
// config from the common and memory configs.
typedef struct TpuEmbeddingEngine_ConfigureHost_Params {
  int32_t struct_size;
  void* priv;

  int num_inputs;
  size_t common_config_size;
  const char* common_config;
  size_t memory_config_size;
  const char* memory_config;
  TpuSerializedProto tpu_embedding_config;

  // out
  size_t* network_config_size;
  char** network_config;
  TF_Status* status;
} TpuEmbeddingEngine_ConfigureHost_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConfigureHost(
    TpuEmbeddingEngine_ConfigureHost_Params* params);

// Parameters for `TpuEmbeddingEngine_ConnectHosts`, which takes
// `network_configs_size` network configs.
typedef struct TpuEmbeddingEngine_ConnectHosts_Params {
  int32_t struct_size;
  void* priv;

  size_t network_configs_size;
  const TpuSerializedProto* network_configs;

  // out
  TF_Status* status;
} TpuEmbeddingEngine_ConnectHosts_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConnectHosts(
    TpuEmbeddingEngine_ConnectHosts_Params* params);
589
// Parameters for `TpuEmbeddingEngine_Finalize`.
typedef struct TpuEmbeddingEngine_Finalize_Params {
  int32_t struct_size;
  void* priv;
  const XLA_TpuMeshState* tpu_mesh_state;

  size_t common_config_size;
  const char* common_config;
  size_t memory_config_size;
  const char* memory_config;

  // out
  TF_Status* status;
} TpuEmbeddingEngine_Finalize_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_Finalize(
    TpuEmbeddingEngine_Finalize_Params* params);

// Parameters for `TpuEmbeddingEngine_IsInitialized`. The result is written to
// `*is_tpu_embedding_initialized`.
typedef struct TpuEmbeddingEngine_IsInitialized_Params {
  int32_t struct_size;
  void* priv;

  size_t config_string_size;
  const char* config_string;

  // out
  bool* is_tpu_embedding_initialized;
  TF_Status* status;
} TpuEmbeddingEngine_IsInitialized_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_IsInitialized(
    TpuEmbeddingEngine_IsInitialized_Params* params);

// Writes / reads embedding parameters described by `params`; errors are
// reported through `status`.
TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_WriteParameters(
    TpuEmbeddingEngineParameters* params, TF_Status* status);

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ReadParameters(
    TpuEmbeddingEngineParameters* params, TF_Status* status);
627
// Parameters for `TpuEmbeddingEngine_EnqueueTensorBatch`. Each `TF_Tensor**`
// array is paired with the `*_size` field that follows it.
typedef struct TpuEmbeddingEngine_EnqueueTensorBatch_Params {
  int32_t struct_size;
  void* priv;

  int32_t mode;
  int32_t local_device_ordinal;
  // Created with `TpuEmbeddingTensorBatchFixedState_Create`.
  TpuEmbedding_TensorBatchFixedState* fixed_state;

  TF_Tensor** sample_indices_tensors;
  size_t sample_indices_tensors_size;
  TF_Tensor** embedding_indices_tensors;
  size_t embedding_indices_tensors_size;
  TF_Tensor** aggregation_weights_tensors;
  size_t aggregation_weights_tensors_size;
  TF_Status* status;
} TpuEmbeddingEngine_EnqueueTensorBatch_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_EnqueueTensorBatch(
    TpuEmbeddingEngine_EnqueueTensorBatch_Params* params);

// Parameters for `TpuEmbeddingTensorBatchFixedState_Create`; `combiners` is an
// array of `combiners_size` strings.
typedef struct TpuEmbedding_TensorBatchFixedState_Create_Params {
  int32_t struct_size;
  void* priv;

  size_t combiners_size;
  char** combiners;

  // out
  TF_Status* status;
} TpuEmbedding_TensorBatchFixedState_Create_Params;

// Creates fixed state for tensor-batch enqueues. Release it with
// `TpuEmbeddingTensorBatchFixedState_Destroy`.
TFTPU_CAPI_EXPORT TpuEmbedding_TensorBatchFixedState*
TpuEmbeddingTensorBatchFixedState_Create(
    TpuEmbedding_TensorBatchFixedState_Create_Params* params);
TFTPU_CAPI_EXPORT void TpuEmbeddingTensorBatchFixedState_Destroy(
    TpuEmbedding_TensorBatchFixedState* fixed_state);
664
// Parameters for `TpuEmbeddingEngine_RecvActivationsComputation`. The
// resulting XLA computation is written to `*xla_computation`.
// NOTE(review): `config_string_size` has no matching `config_string` pointer
// in this struct (compare `TpuEmbeddingEngine_IsInitialized_Params`).
// Presumably vestigial; it is kept as-is for ABI compatibility. Confirm.
typedef struct TpuEmbeddingEngine_RecvActivationsComputation_Params {
  int32_t struct_size;
  void* priv;

  size_t config_string_size;
  XLA_Shape* deduplication_data_shape;
  const XLA_TpuMeshState* tpu_mesh_state;

  // out
  TpuSerializedProto* xla_computation;
  TF_Status* status;
} TpuEmbeddingEngine_RecvActivationsComputation_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_RecvActivationsComputation(
    TpuEmbeddingEngine_RecvActivationsComputation_Params* params);

// Parameters for
// `TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation`. The
// resulting XLA computation is written to `*xla_computation`.
typedef struct
    TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params {
  int32_t struct_size;
  void* priv;

  const XLA_TpuMeshState* tpu_mesh_state;
  // out
  TpuSerializedProto* xla_computation;
  TF_Status* status;
} TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params;

TFTPU_CAPI_EXPORT void
TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation(
    TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params*
        params);

// Parameters for `TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation`.
// The resulting XLA computation is written to `*xla_computation`.
typedef struct TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params {
  int32_t struct_size;
  void* priv;

  int32_t num_inputs;
  const XLA_TpuMeshState* tpu_mesh_state;
  XLA_Shape* learning_rate_tuple_shape;
  XLA_Shape* deduplication_data_shape;
  XLA_Shape* gradient_tuple_shape;
  // out
  TpuSerializedProto* xla_computation;
  TF_Status* status;
} TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params;

TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation(
    TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params* params);
713
// Table of the TPU ops C API entry points declared in this header. Each
// TFTPU_ADD_FN_IN_STRUCT(Fn) line adds a member for `Fn`. The macro is defined
// outside this file (presumably in libtftpu.h); confirm its expansion there.
// When adding a new API function above, also add it here.
struct TfTpu_OpsApiFn {
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_XrtCompileAndBuild);

  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState);

  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_GetState);

  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy);
  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start);
  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop);
  TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData);

  TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact);
  TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw);

  TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData);

  TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes);
  TFTPU_ADD_FN_IN_STRUCT(
      TpuConfigurationApi_CompilationCacheServerAddressFromConfig);
  TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort);

  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetFingerprint);
  TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DestroyFingerprint);

  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation);
  TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount);
  TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey);
  TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint);

  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize);
  TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CompactionSupported);

  TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer);

  TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Create);
  TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Destroy);
  TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_GetOrdinal);
  TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_DequeueFromCoreSelector);
  TFTPU_ADD_FN_IN_STRUCT(TfTpu_GetTpuPartitionedCallParams);

  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ExecutePartitioner);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConfigureMemory);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_CollateMemory);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConfigureHost);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConnectHosts);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_Finalize);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_IsInitialized);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_WriteParameters);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ReadParameters);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Create);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Destroy);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_EnqueueTensorBatch);
  TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_RecvActivationsComputation);
  TFTPU_ADD_FN_IN_STRUCT(
      TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation);
  TFTPU_ADD_FN_IN_STRUCT(
      TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation);
};
815
816} // extern "C"
817
818#endif // TENSORFLOW_CORE_TPU_TPU_OPS_C_API_H_
819