1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | #ifndef TENSORFLOW_CORE_TPU_TPU_OPS_C_API_H_ |
16 | #define TENSORFLOW_CORE_TPU_TPU_OPS_C_API_H_ |
17 | |
18 | #include <stddef.h> |
19 | |
20 | #include <cstdint> |
21 | |
22 | #include "absl/types/optional.h" |
23 | #include "tensorflow/c/tf_tensor.h" |
24 | #include "tensorflow/compiler/xla/stream_executor/tpu/c_api_decl.h" |
25 | #include "tensorflow/compiler/xla/stream_executor/tpu/proto_helper.h" |
26 | #include "tensorflow/core/tpu/libtftpu.h" |
27 | |
// Forward declaration of the `TpuSerializedProto` struct tag together with a
// same-named typedef, so the type can be spelled without the `struct` keyword.
// No members are declared in this header.
28 | typedef struct TpuSerializedProto TpuSerializedProto; |
29 | |
// Forward declarations of TensorFlow C++ classes. This header only ever uses
// pointers to them (e.g. the `resource_mgr` and `tpu_mesh_common_state` struct
// fields), so the full class definitions are not required here.
30 | namespace tensorflow { |
31 | |
32 | class TpuMeshCommonState; |
33 | class TpuEmbeddingEngineState; |
34 | class ResourceMgr; |
35 | |
36 | } // namespace tensorflow |
37 | |
38 | extern "C" { |
39 | |
// Handle type for a TPU program. Only the struct tag is declared here; the
// functions in this header pass it around by pointer.
40 | typedef struct XLA_TpuProgram XLA_TpuProgram; |
41 | |
42 | // Enum for choosing the main, sharding or unsharding program from an |
// `XLA_TpuProgram` object. `kInvalid` is explicitly 0; the remaining
// enumerators take the values 1, 2 and 3 in declaration order.
43 | enum TpuProgramShardingType { kInvalid = 0, kMain, kSharding, kUnsharding }; |
44 | |
// Pointer/length pair describing a TPU program fingerprint. Returned by value
// from `TpuProgram_GetFingerprint` and passed by value to
// `TpuProgram_DestroyFingerprint`.
45 | struct TpuProgramFingerprint { |
46 | const char* bytes; |
// NOTE(review): presumably the number of bytes readable at `bytes` -- confirm.
47 | size_t size; |
48 | }; |
49 | |
// Pointer/length pair filled in by `TpuProgram_SerializeTpuExecutable`.
// NOTE(review): who owns and releases `bytes` is not stated in this header --
// confirm against the implementation.
50 | struct TpuExecutableSerializedProto { |
51 | const char* bytes; |
52 | size_t size; |
53 | }; |
54 | |
// Pointer/length pair filled in by `TpuProgram_SerializeCompilerMetadata`.
// NOTE(review): who owns and releases `bytes` is not stated in this header --
// confirm against the implementation.
55 | struct CompilerMetadataSerializedProto { |
56 | const char* bytes; |
57 | size_t size; |
58 | }; |
59 | |
// Pointer/length pair, by its name intended for serialized host-compute
// metadata. No function declared in this header takes or returns it.
60 | struct HostComputeMetadataSerializedProto { |
61 | const char* bytes; |
62 | size_t size; |
63 | }; |
64 | |
// Handle types for the mesh state, embedding engine state, embedding tensor
// batch fixed state and profiler. Only the struct tags are declared in this
// header; the API below passes these objects around by pointer.
65 | typedef struct XLA_TpuMeshState XLA_TpuMeshState; |
66 | |
67 | typedef struct XLA_TpuEmbeddingEngineState XLA_TpuEmbeddingEngineState; |
68 | |
69 | typedef struct TpuEmbedding_TensorBatchFixedState |
70 | TpuEmbedding_TensorBatchFixedState; |
71 | |
72 | typedef struct TpuProfiler TpuProfiler; |
73 | |
// Pointer/length pair for a device assignment. Referenced by pointer from
// `TpuExecutable_LoadProgramAndEnqueueToStream_Params::device_assignment`.
// NOTE(review): presumably `bytes` holds `size` bytes of a serialized device
// assignment -- confirm against the producer.
74 | typedef struct XLA_DeviceAssignment { |
75 | const char* bytes; |
76 | size_t size; |
77 | } XLA_DeviceAssignment; |
78 | |
79 | // Properties used for creating a compilation cache key. Passed by value to |
// `TpuCompile_CreateCompilationCacheKey`.
80 | struct CompilationCacheKeyProperty { |
81 | const char* config_prefix; |
82 | const char* shapes_prefix; |
83 | const char* function_name; |
84 | uint64_t mlir_module_fingerprint; |
// NOTE(review): presumably `device_ids` points at `device_ids_size` entries --
// confirm against the caller.
85 | const int32_t* device_ids; |
86 | size_t device_ids_size; |
87 | int32_t guaranteed_constants_size; |
88 | uint64_t function_library_fingerprint; |
89 | int32_t num_cores_per_replica; |
90 | int32_t num_replicas; |
91 | const XLA_TpuMeshState* mesh_state; |
92 | uint64_t session_id; |
// C++ class pointer (forward-declared in namespace tensorflow above).
93 | tensorflow::ResourceMgr* resource_mgr; |
94 | }; |
95 | |
96 | // Compilation cache key result carrying both the key and a more verbose debug |
97 | // version. Returned by `TpuCompile_CreateCompilationCacheKey` and released |
// with `TpuCompile_DestroyCompilationCacheKey`.
98 | struct CompilationCacheKeyResult { |
99 | const char* key; |
100 | const char* debug_string; |
101 | }; |
102 | |
// Handle type used by the TpuNodeContext_* functions; only the tag is declared.
103 | typedef struct XLA_TpuNodeContext XLA_TpuNodeContext; |
104 | |
// Handle type used by the TfTpuOrdinalSelector_* functions. Note that the
// struct tag (`TfTpu_OrdinalSelector`) and the typedef name
// (`TfTpuOrdinalSelector`) are spelled differently.
105 | typedef struct TfTpu_OrdinalSelector TfTpuOrdinalSelector; |
106 | |
// Settings for TPU partitioned calls. A pointer to this struct is passed to
// `TfTpu_GetTpuPartitionedCallParams`.
107 | struct TpuPartitionedCall_Params { |
108 | bool input_shape_opt; |
109 | bool group_tensors_for_packing; |
110 | int32_t minimum_input_tensors_packing; |
111 | int32_t minimum_output_tensors_packing; |
112 | |
113 | // Whether to attempt to automatically shard inputs by adding an |
114 | // XlaSharding op after each input. |
115 | bool enable_auto_xla_input_sharding; |
116 | |
117 | // The dimension of each input to shard if |
118 | // enable_auto_xla_input_sharding is set to true. Negative numbers are |
119 | // allowed and refer to dimensions counted from the end. |
120 | int32_t auto_xla_input_sharding_dim; |
121 | |
122 | // If true, only create one variable on the TPU for each variable on the CPU. |
123 | bool enable_variable_deduplication; |
124 | }; |
125 | |
126 | // Compiles an MLIR or TF function computation by lowering it into HLO IR and |
127 | // returns `count` TPU programs ready for execution. |
128 | // The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and creates |
129 | // `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller is |
130 | // responsible for deallocating both the `XLA_TpuProgram*[]` array and the |
131 | // `XLA_TpuProgram` object(s), using the `TpuProgram_FreeArray` and |
132 | // `TpuProgram_Free` APIs respectively. |
// The function itself returns void; `tpu_programs`, `count` and `status` are
// the output channels.
133 | TFTPU_CAPI_EXPORT void TpuCompile_CompileAndBuild( |
134 | TpuSerializedProto compilation_request, const XLA_TpuMeshState* mesh_state, |
135 | XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status); |
136 | |
137 | // Compiles an HLO IR and returns `count` TPU programs ready for |
138 | // execution. The API allocates the `XLA_TpuProgram*[]` array `tpu_programs` and |
139 | // creates `XLA_TpuProgram` object(s) using the `TpuProgram_New` API. The caller |
140 | // is responsible for deallocating both the `XLA_TpuProgram*[]` array and the |
141 | // `XLA_TpuProgram` object(s), using the `TpuProgram_FreeArray` and |
142 | // `TpuProgram_Free` APIs respectively. |
// The function itself returns void; `tpu_programs`, `count` and `status` are
// the output channels.
143 | TFTPU_CAPI_EXPORT void TpuCompile_XrtCompileAndBuild( |
144 | TpuSerializedProto xrt_computation, const XLA_TpuMeshState* mesh_state, |
145 | XLA_TpuProgram** tpu_programs[], size_t* count, TF_Status* status); |
146 | |
147 | // Creates a TPU profiler that is ready to start profiling. The function |
// returns void: the profiler comes back through the `tpu_profiler`
// out-parameter and `status` is available for error reporting.
148 | TFTPU_CAPI_EXPORT void TpuProfiler_Create(TpuProfiler** tpu_profiler, |
149 | TF_Status* status); |
150 | // Destroys the given TPU profiler. Takes no status argument. |
151 | TFTPU_CAPI_EXPORT void TpuProfiler_Destroy(TpuProfiler* tpu_profiler); |
152 | // Starts profiling if not already started; otherwise reports an error |
// (the function returns void, so the error travels through `status`).
153 | TFTPU_CAPI_EXPORT void TpuProfiler_Start(TpuProfiler* tpu_profiler, |
154 | TF_Status* status); |
155 | // Stops profiling if not already stopped; otherwise reports an error |
// (the function returns void, so the error travels through `status`).
156 | TFTPU_CAPI_EXPORT void TpuProfiler_Stop(TpuProfiler* tpu_profiler, |
157 | TF_Status* status); |
158 | // Serializes profiled data into `buffer` and reports its size through |
159 | // `size_in_bytes`. The profile data held by the TPU driver is cleared after retrieval. |
160 | // |
161 | // Step 1. Query the size of buffer required into `size_in_bytes`. |
162 | // |
163 | // size_t size_in_bytes; |
164 | // TpuProfiler_CollectData(profiler, status, nullptr, &size_in_bytes); |
165 | // |
166 | // Step 2. Retrieve the data into a `buffer` of size `size_in_bytes`. |
167 | // Subsequently, the TPU driver clears its copy of the profile data. |
168 | // |
169 | // uint8_t* buffer = new uint8_t[size_in_bytes]; |
170 | // TpuProfiler_CollectData(profiler, status, buffer, &size_in_bytes); |
171 | // |
172 | // Step 3. Unpack the data into an XSpace. |
173 | // |
174 | // tensorflow::profiler::XSpace space; |
175 | // space.ParseFromArray(buffer, size_in_bytes); |
176 | // |
177 | TFTPU_CAPI_EXPORT void TpuProfiler_CollectData(TpuProfiler* tpu_profiler, |
178 | TF_Status* status, |
179 | uint8_t* buffer, |
180 | size_t* size_in_bytes); |
181 | |
182 | // Creates a new TPU mesh state object and returns a pointer to it. |
// `TpuMeshState_Free` is the matching release function.
183 | TFTPU_CAPI_EXPORT XLA_TpuMeshState* TpuMeshState_Create(); |
184 | |
185 | // Deletes the given TPU `mesh_state` object. Once deleted, the object must |
186 | // not be used again. |
187 | TFTPU_CAPI_EXPORT void TpuMeshState_Free(XLA_TpuMeshState* mesh_state); |
188 | |
189 | // Returns a pointer to an opaque mesh data structure used internally. |
// NOTE(review): the result is an untyped `void*`; presumably it refers to a
// `tensorflow::TpuMeshCommonState` -- confirm before casting.
190 | TFTPU_CAPI_EXPORT void* TpuMeshState_MeshCommonState( |
191 | XLA_TpuMeshState* mesh_state); |
192 | |
193 | // Creates a new TPU embedding engine state object and returns a pointer to it. |
// `TpuEmbeddingEngineState_Free` is the matching release function.
194 | TFTPU_CAPI_EXPORT XLA_TpuEmbeddingEngineState* TpuEmbeddingEngineState_Create(); |
195 | |
196 | // Deletes the given TPU embedding engine state object. Once deleted, the |
197 | // object must not be used again. |
198 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngineState_Free( |
199 | XLA_TpuEmbeddingEngineState* engine_state); |
200 | |
201 | // Returns a pointer to an opaque embedding engine state data structure used |
202 | // internally. |
// NOTE(review): the result is an untyped `void*`; presumably it refers to a
// `tensorflow::TpuEmbeddingEngineState` -- confirm before casting.
203 | TFTPU_CAPI_EXPORT void* TpuEmbeddingEngineState_GetState( |
204 | XLA_TpuEmbeddingEngineState* engine_state); |
205 | |
// Creates an ordinal selector. The function returns void, so the new selector
// comes back through the `ordinal_selector` out-parameter.
// NOTE(review): no status argument exists, so failure reporting (if any) is
// not visible here.
206 | TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Create( |
207 | TfTpuOrdinalSelector** ordinal_selector, int num_cores_per_replica); |
208 | |
// Destroys an ordinal selector, presumably one obtained from
// `TfTpuOrdinalSelector_Create`.
209 | TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_Destroy( |
210 | TfTpuOrdinalSelector* ordinal_selector); |
211 | |
// Asks `ordinal_selector` for an ordinal. `key` may be empty; results come
// back through the `req_id` and `ordinal` out-parameters.
// NOTE(review): `std::optional` is a C++ class type used inside this
// `extern "C"` block, and this header includes "absl/types/optional.h" rather
// than <optional> directly -- confirm both are intended.
212 | TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_GetOrdinal( |
213 | TfTpuOrdinalSelector* ordinal_selector, std::optional<uint64_t> key, |
214 | int64_t* req_id, int64_t* ordinal); |
215 | |
// NOTE(review): presumably removes the request identified by `req_id` (as
// produced by `TfTpuOrdinalSelector_GetOrdinal`) from the selector for
// `device_ordinal` -- confirm against the implementation.
216 | TFTPU_CAPI_EXPORT void TfTpuOrdinalSelector_DequeueFromCoreSelector( |
217 | TfTpuOrdinalSelector* ordinal_selector, int32_t device_ordinal, |
218 | int64_t req_id); |
219 | |
// Takes a pointer to a `TpuPartitionedCall_Params` struct and returns void.
// NOTE(review): presumably fills `*params` with the current settings --
// confirm against the implementation.
220 | TFTPU_CAPI_EXPORT void TfTpu_GetTpuPartitionedCallParams( |
221 | TpuPartitionedCall_Params* params); |
222 | |
// Parameter block for `TpuExecutable_LoadProgramAndEnqueueToStream`. Every
// input, and the `status` output, travels through this struct.
223 | typedef struct TpuExecutable_LoadProgramAndEnqueueToStream_Params { |
// NOTE(review): presumably set by the caller to the *_Params_SIZE macro
// below; the purpose of `priv` is not visible here -- confirm.
224 | int32_t struct_size; |
225 | void* priv; |
226 | |
227 | const XLA_TpuProgram* program; |
// NOTE(review): presumably an array of `arguments_len` entries -- confirm.
228 | SE_DeviceMemoryBase* arguments; |
229 | size_t arguments_len; |
230 | SE_DeviceMemoryBase* result; |
231 | bool has_cross_program_prefetch_addr; |
232 | SE_DeviceMemoryBase* cross_program_prefetch_addr; |
233 | int32_t rng_seed; |
234 | XLA_DeviceAssignment* device_assignment; |
235 | SE_Stream* stream; |
236 | |
237 | TF_Status* status; // out |
238 | } TpuExecutable_LoadProgramAndEnqueueToStream_Params; |
239 | |
// Size in bytes of the parameter struct as seen by the code being compiled.
240 | #define TpuExecutable_LoadProgramAndEnqueueToStream_Params_SIZE \ |
241 | (sizeof(struct TpuExecutable_LoadProgramAndEnqueueToStream_Params)) |
242 | |
// Entry point; returns void, with everything carried by `params`.
243 | TFTPU_CAPI_EXPORT void TpuExecutable_LoadProgramAndEnqueueToStream( |
244 | TpuExecutable_LoadProgramAndEnqueueToStream_Params* params); |
245 | |
// Shape helpers operating on `XLA_Shape`.
// NOTE(review): the descriptions below are inferred from the names only --
// confirm against the implementation.
// Presumably derives the device shape for `host_shape` and writes it to
// `device_shape` (the function returns void).
246 | TFTPU_CAPI_EXPORT void HardwareLayout_HostShapeToDeviceShape( |
247 | XLA_Shape* host_shape, XLA_Shape* device_shape); |
// Each of these returns an int64_t size for `shape`. The units and the exact
// difference between the plain, "Compact" and "CompactRaw" variants are not
// visible in this header.
248 | TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSize(XLA_Shape* shape); |
249 | TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompact(XLA_Shape* shape); |
250 | TFTPU_CAPI_EXPORT int64_t HardwareLayout_ShapeSizeCompactRaw(XLA_Shape* shape); |
251 | |
// Parameter block for `TpuExecute_RuntimeInputToPaddedData`.
252 | typedef struct TpuExecute_RuntimeInputToPaddedData_Params { |
// NOTE(review): presumably set by the caller to the *_Params_SIZE macro
// below; the purpose of `priv` is not visible here -- confirm.
253 | int32_t struct_size; |
254 | void* priv; |
255 | |
// Runtime input buffer and its size. Note the element types differ:
// `uint32_t*` here versus `int8_t*` for the padded buffer.
256 | uint32_t* runtime_input_ptr; |
257 | size_t runtime_input_size; |
258 | int8_t* padded_data_ptr; |
259 | size_t padded_data_size; |
260 | XLA_Shape* runtime_shape; |
261 | XLA_Shape* compile_time_shape; |
262 | |
263 | TF_Status* status; // out |
264 | } TpuExecute_RuntimeInputToPaddedData_Params; |
265 | |
// Size in bytes of the parameter struct as seen by the code being compiled.
266 | #define TpuExecute_RuntimeInputToPaddedData_Params_SIZE \ |
267 | (sizeof(struct TpuExecute_RuntimeInputToPaddedData_Params)) |
268 | |
// Entry point; returns void, with everything carried by `params`.
269 | TFTPU_CAPI_EXPORT void TpuExecute_RuntimeInputToPaddedData( |
270 | TpuExecute_RuntimeInputToPaddedData_Params* params); |
271 | |
// Parameter block for `ConfigureDistributedTpuOp_DoWork`.
272 | typedef struct ConfigureDistributedTpuOp_DoWork_Params { |
// NOTE(review): presumably set by the caller to the *_Params_SIZE macro
// below; the purpose of `priv` is not visible here -- confirm.
273 | int32_t struct_size; |
274 | void* priv; |
275 | |
// Inputs: each pointer is declared next to a size field.
276 | size_t num_cores_per_host_size; |
277 | const int32_t* num_cores_per_host; |
278 | size_t server_address_size; |
279 | const char* server_address; |
280 | |
// NOTE(review): `host_config_output` is a `char**`; presumably the returned
// buffer is released with `TpuConfigurationApi_FreeCharArray` -- confirm.
281 | size_t* host_config_output_size; // out |
282 | char** host_config_output; // out |
283 | TF_Status* status; // out |
284 | } ConfigureDistributedTpuOp_DoWork_Params; |
285 | |
// Size in bytes of the parameter struct as seen by the code being compiled.
286 | #define ConfigureDistributedTpuOp_DoWork_Params_SIZE \ |
287 | (sizeof(struct ConfigureDistributedTpuOp_DoWork_Params)) |
288 | |
// Entry point; returns void, with everything carried by `params`.
289 | TFTPU_CAPI_EXPORT void ConfigureDistributedTpuOp_DoWork( |
290 | ConfigureDistributedTpuOp_DoWork_Params* params); |
291 | |
// Parameter block for `WaitForDistributedTpuOp_DoWork`.
292 | typedef struct WaitForDistributedTpuOp_DoWork_Params { |
// NOTE(review): presumably set by the caller to the *_Params_SIZE macro
// below; the purpose of `priv` is not visible here -- confirm.
293 | int32_t struct_size; |
294 | void* priv; |
295 | |
296 | size_t num_hosts; |
297 | size_t num_cores_per_host; |
// NOTE(review): presumably `num_hosts` rows of `num_cores_per_host` core ids
// each -- confirm against the caller.
298 | const int32_t** host_ordinal_to_global_core_id_map; |
// C++ class pointer (forward-declared in namespace tensorflow above).
299 | tensorflow::TpuMeshCommonState* tpu_mesh_common_state; |
300 | |
// NOTE(review): `tpu_topology_output` is a `char**`; presumably the returned
// buffer is released with `TpuConfigurationApi_FreeCharArray` -- confirm.
301 | size_t* tpu_topology_output_size; // out |
302 | char** tpu_topology_output; // out |
303 | TF_Status* status; // out |
304 | } WaitForDistributedTpuOp_DoWork_Params; |
305 | |
// Size in bytes of the parameter struct as seen by the code being compiled.
306 | #define WaitForDistributedTpuOp_DoWork_Params_SIZE \ |
307 | (sizeof(struct WaitForDistributedTpuOp_DoWork_Params)) |
308 | |
// Entry point; returns void, with everything carried by `params`.
309 | TFTPU_CAPI_EXPORT void WaitForDistributedTpuOp_DoWork( |
310 | WaitForDistributedTpuOp_DoWork_Params* params); |
311 | |
// Parameter block for `InitializeHostForDistributedTpuOp_DoWork`.
312 | typedef struct InitializeHostForDistributedTpuOp_DoWork_Params { |
// NOTE(review): presumably set by the caller to the *_Params_SIZE macro
// below; the purpose of `priv` is not visible here -- confirm.
313 | int32_t struct_size; |
314 | void* priv; |
315 | |
316 | size_t tpu_host_config_size; |
317 | const char* tpu_host_config; |
318 | bool enable_whole_mesh_compilations; |
319 | bool is_master_worker; |
320 | |
// NOTE(review): `core_id_output` is an `int32_t**`; presumably the returned
// buffer is released with `TpuConfigurationApi_FreeInt32Array` -- confirm.
321 | size_t* core_id_output_size; // out |
322 | int32_t** core_id_output; // out |
323 | TF_Status* status; // out |
324 | } InitializeHostForDistributedTpuOp_DoWork_Params; |
325 | |
// Size in bytes of the parameter struct as seen by the code being compiled.
326 | #define InitializeHostForDistributedTpuOp_DoWork_Params_SIZE \ |
327 | (sizeof(struct InitializeHostForDistributedTpuOp_DoWork_Params)) |
328 | |
// Entry point; returns void, with everything carried by `params`.
329 | TFTPU_CAPI_EXPORT void InitializeHostForDistributedTpuOp_DoWork( |
330 | InitializeHostForDistributedTpuOp_DoWork_Params* params); |
331 | |
// Takes a topology buffer (`tpu_topology`, with its size passed separately in
// `tpu_topology_size`) and a `status`. Unlike the other *_DoWork functions in
// this header, the arguments are passed directly rather than via a params
// struct.
332 | TFTPU_CAPI_EXPORT void SetGlobalTPUArrayOp_DoWork( |
333 | const size_t tpu_topology_size, const char* tpu_topology, |
334 | TF_Status* status); |
335 | |
// Returns void; a chip count is written through `number_of_chips_output` and
// `status` is available for error reporting.
// NOTE(review): what exactly the count covers is not visible here -- confirm.
336 | TFTPU_CAPI_EXPORT void DisconnectDistributedTpuChipsOp_DoWork( |
337 | int32_t* number_of_chips_output, TF_Status* status); |
338 | |
// Release functions for `char` and `int32_t` arrays.
// NOTE(review): presumably these free the buffers returned through the
// `char**` / `int32_t**` out fields of the params structs in this header --
// confirm which outputs each one applies to.
339 | TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeCharArray(char* output); |
340 | TFTPU_CAPI_EXPORT void TpuConfigurationApi_FreeInt32Array(int32_t* output); |
341 | |
// Boolean query taking no arguments; by its name, reports whether TPU pod
// state exists. NOTE(review): exact semantics are not visible here.
342 | TFTPU_CAPI_EXPORT bool TpuConfigurationApi_HasTPUPodState(); |
343 | |
// Configuration getters. Each returns void and writes its value through the
// pointer argument. The first two also take a `status`; the cache-size getter
// does not.
// NOTE(review): units of `memory_limit` are not stated here -- confirm.
344 | TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpusPerHost(int32_t* tpus, |
345 | TF_Status* status); |
346 | TFTPU_CAPI_EXPORT void TpuConfigurationApi_TpuMemoryLimit(int64_t* memory_limit, |
347 | TF_Status* status); |
348 | TFTPU_CAPI_EXPORT void TpuConfigurationApi_RemoteCompilationCacheSizeInBytes( |
349 | int64_t* cache_size_in_bytes); |
350 | |
// Parameter block for
// `TpuConfigurationApi_CompilationCacheServerAddressFromConfig`.
// Naming caveat: the struct TAG is spelled "...ServerAddrFromConfig_Params"
// while the typedef name at the closing brace is spelled
// "...ServerAddressFromConfig_Params". The SIZE macro and the function
// prototype below both use the tag spelling.
351 | typedef struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params { |
// NOTE(review): presumably set by the caller to the *_Params_SIZE macro
// below; the purpose of `priv` is not visible here -- confirm.
352 | int32_t struct_size; |
353 | void* priv; |
354 | |
355 | size_t tpu_host_config_size; |
356 | const char* tpu_host_config; |
357 | |
358 | size_t* server_address_output_size; // out |
359 | char** server_address_output; // out |
360 | TF_Status* status; // out |
361 | } TpuConfigurationApi_CompilationCacheServerAddressFromConfig_Params; |
362 | |
363 | #define TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params_SIZE \ |
364 | (sizeof( \ |
365 | struct TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params)) |
366 | |
// Entry point; returns void, with everything carried by `params`.
367 | TFTPU_CAPI_EXPORT |
368 | void TpuConfigurationApi_CompilationCacheServerAddressFromConfig( |
369 | TpuConfigurationApi_CompilationCacheServerAddrFromConfig_Params* params); |
370 | |
// Parameter block for `TpuConfigurationApi_GetServerAddressAndPort`. Apart
// from `struct_size` and `priv`, every field is marked as an output.
371 | typedef struct TpuConfigurationApi_GetServerAddressAndPort_Params { |
// NOTE(review): presumably set by the caller to the *_Params_SIZE macro
// below; the purpose of `priv` is not visible here -- confirm.
372 | int32_t struct_size; |
373 | void* priv; |
374 | |
375 | size_t* server_address_output_size; // out |
376 | char** server_address_output; // out |
377 | int* port_output; // out |
378 | TF_Status* status; // out |
379 | } TpuConfigurationApi_GetServerAddressAndPort_Params; |
380 | |
// Size in bytes of the parameter struct as seen by the code being compiled.
381 | #define TpuConfigurationApi_GetServerAddressAndPort_Params_SIZE \ |
382 | (sizeof(struct TpuConfigurationApi_GetServerAddressAndPort_Params)) |
383 | |
// Entry point; returns void, with everything carried by `params`.
384 | TFTPU_CAPI_EXPORT void TpuConfigurationApi_GetServerAddressAndPort( |
385 | TpuConfigurationApi_GetServerAddressAndPort_Params* params); |
386 | |
387 | // Creates a new TPU program and returns a pointer to it. |
388 | TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_New(); |
389 | |
390 | // Destroys the `tpu_program`. |
391 | TFTPU_CAPI_EXPORT void TpuProgram_Free(XLA_TpuProgram* tpu_program); |
392 | |
393 | // Creates an array of `XLA_TpuProgram*` with room for `count` entries. |
// NOTE(review): whether the entries are initialized is not visible here.
394 | TFTPU_CAPI_EXPORT XLA_TpuProgram** TpuProgram_NewArray(size_t count); |
395 | |
396 | // Destroys an array of `XLA_TpuProgram*`. Per the comments on the compile |
// functions above, the programs stored in the array are released separately
// with `TpuProgram_Free`.
397 | TFTPU_CAPI_EXPORT void TpuProgram_FreeArray(XLA_TpuProgram* tpu_program[]); |
398 | |
399 | // Unloads and destroys the `tpu_program`. Once the TPU program is unloaded and |
400 | // destroyed, it must not be used again. `status` is available for errors. |
401 | TFTPU_CAPI_EXPORT void TpuProgram_UnloadAndDestroy(XLA_TpuProgram* tpu_program, |
402 | TF_Status* status); |
403 | |
404 | // Returns the TPU program size in bytes for `tpu_program`. |
405 | TFTPU_CAPI_EXPORT int64_t |
406 | TpuProgram_GetProgramSize(const XLA_TpuProgram* tpu_program); |
407 | |
408 | // Logs the summary of current memory state snapshot of the `tpu_program`. |
// NOTE(review): the meaning of the bool result is not documented here;
// presumably it indicates whether logging succeeded -- confirm.
409 | TFTPU_CAPI_EXPORT bool TpuProgram_LogProgramMemorySummary( |
410 | const XLA_TpuProgram* tpu_program); |
411 | |
// The three getters below share one shape: they read from `tpu_program`,
// write a serialized proto into the `TpuSerializedProto*` argument and take a
// `status`.
// NOTE(review): how the returned proto buffers are released is not stated in
// this header -- confirm.
412 | // Gets TPU program executable info from the `tpu_program`. |
413 | TFTPU_CAPI_EXPORT void TpuProgram_GetExecutableInfo( |
414 | const XLA_TpuProgram* tpu_program, TpuSerializedProto* executable_info, |
415 | TF_Status* status); |
416 | |
417 | // Gets the host transfer info proto from the `tpu_program`. |
418 | TFTPU_CAPI_EXPORT void TpuProgram_GetHostTransferInfo( |
419 | const XLA_TpuProgram* tpu_program, TpuSerializedProto* host_transfer_info, |
420 | TF_Status* status); |
421 | |
422 | // Gets the HLO metadata proto from the `tpu_program`. |
423 | TFTPU_CAPI_EXPORT void TpuProgram_GetHloMetadata( |
424 | const XLA_TpuProgram* tpu_program, TpuSerializedProto* hlo_metadata, |
425 | TF_Status* status); |
426 | |
427 | // Gets the "may modify variables" flag; written to `may_modify_variables`. |
428 | TFTPU_CAPI_EXPORT void TpuProgram_GetMayModifyVariables( |
429 | const XLA_TpuProgram* tpu_program, bool* may_modify_variables); |
430 | |
431 | // Checks if the TPU program has sharding; the answer is the return value. |
432 | TFTPU_CAPI_EXPORT bool TpuProgram_HasSharding( |
433 | const XLA_TpuProgram* tpu_program); |
434 | |
435 | // Gets the TPU program selected by the sharding `type` from `tpu_program`. |
436 | // NOTE(review): this declaration takes no status argument, so how an |
// unavailable program is signalled (e.g. a null return) cannot be determined
// from this header -- confirm against the implementation.
437 | TFTPU_CAPI_EXPORT XLA_TpuProgram* TpuProgram_GetTpuProgram( |
438 | XLA_TpuProgram* tpu_program, TpuProgramShardingType type); |
439 | |
440 | // Gets the TPU executable proto from a `tpu_program`; written to `executable`. |
// NOTE(review): ownership of `executable->bytes` is not stated here.
441 | TFTPU_CAPI_EXPORT void TpuProgram_SerializeTpuExecutable( |
442 | const XLA_TpuProgram* tpu_program, TpuExecutableSerializedProto* executable, |
443 | TF_Status* status); |
444 | |
445 | // Gets the compilation metadata proto from a `tpu_program`; written to |
// `compiler_metadata`.
// NOTE(review): ownership of `compiler_metadata->bytes` is not stated here.
446 | TFTPU_CAPI_EXPORT void TpuProgram_SerializeCompilerMetadata( |
447 | const XLA_TpuProgram* tpu_program, |
448 | CompilerMetadataSerializedProto* compiler_metadata, TF_Status* status); |
449 | |
450 | // Deserializes the `GetTpuProgramResponse` proto into an `XLA_TpuProgram`. |
// The proto is passed by value and the destination program by pointer; the
// caller supplies the `tpu_program` object.
451 | TFTPU_CAPI_EXPORT void TpuProgram_DeserializeFromGetTpuProgramResponseProto( |
452 | TpuSerializedProto get_tpu_program_response, XLA_TpuProgram* tpu_program, |
453 | TF_Status* status); |
454 | |
// Returns the fingerprint of `tpu_program` as a pointer/length pair.
// NOTE(review): presumably the result must be handed to
// `TpuProgram_DestroyFingerprint` when no longer needed -- confirm.
455 | TFTPU_CAPI_EXPORT TpuProgramFingerprint |
456 | TpuProgram_GetFingerprint(const XLA_TpuProgram* tpu_program); |
457 | |
// Destroys a fingerprint, presumably one returned by
// `TpuProgram_GetFingerprint`.
458 | TFTPU_CAPI_EXPORT void TpuProgram_DestroyFingerprint( |
459 | TpuProgramFingerprint fingerprint); |
460 | |
461 | // Checks whether TPU compilation is enabled. |
462 | TFTPU_CAPI_EXPORT bool TpuCompile_IsTpuCompilationEnabled(); |
463 | |
464 | // XLA compilation cannot be cancelled. To avoid hanging, the TF worker will exit |
465 | // when cancellation is requested for an XLA compile op. Some tests require this |
466 | // behavior to be disabled, and we test for this condition with the following |
467 | // flag function. |
468 | TFTPU_CAPI_EXPORT bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation(); |
469 | |
470 | // Returns the number of available TPU cores for `mesh_state` and `tpu_core_type`. |
471 | TFTPU_CAPI_EXPORT int TpuTopology_AvailableCoreCount( |
472 | const XLA_TpuMeshState* mesh_state, TpuCoreTypeEnum tpu_core_type); |
473 | |
474 | // Recycles the unused service port `port`. |
475 | TFTPU_CAPI_EXPORT void TpuNetUtil_RecycleUnusedPort(int port); |
476 | |
477 | // Creates a unique compilation cache `key` used for `put` and `get` operations. |
478 | // The returned buffers are heap-allocated and owned by the caller, who releases |
// them with `TpuCompile_DestroyCompilationCacheKey`.
479 | TFTPU_CAPI_EXPORT CompilationCacheKeyResult |
480 | TpuCompile_CreateCompilationCacheKey(CompilationCacheKeyProperty property); |
481 | |
482 | // Destroys the CompilationCacheKeyResult returned by calling the |
483 | // `TpuCompile_CreateCompilationCacheKey` API. |
484 | TFTPU_CAPI_EXPORT void TpuCompile_DestroyCompilationCacheKey( |
485 | CompilationCacheKeyResult result); |
486 | |
487 | // Creates a guaranteed const fingerprint. Guaranteed const is normally used in |
488 | // TPU inference to avoid re-copying unchanged variables onto the TPU device. |
489 | // It promises the value is identical for every execution in the same session |
490 | // even if the actual value changes in later executions. |
// NOTE(review): presumably the incoming `fingerprint` is combined with the
// `size` bytes at `data` and the new fingerprint is returned -- confirm.
491 | TFTPU_CAPI_EXPORT uint64_t TpuCompile_CreateGuaranteedConstFingerprint( |
492 | uint64_t fingerprint, const char* data, size_t size); |
493 | |
// TPU node context entry points.
// NOTE(review): unlike every other function declared in this header, these six
// declarations carry no TFTPU_CAPI_EXPORT, although they are listed in
// `TfTpu_OpsApiFn` below -- confirm whether the omission is intentional.
//
// Returns a node context for `device_ordinal`; `status` is available for
// error reporting.
494 | XLA_TpuNodeContext* TpuNodeContext_Create(int device_ordinal, |
495 | TF_Status* status); |
// Frees a node context, presumably one obtained from `TpuNodeContext_Create`.
496 | void TpuNodeContext_Free(XLA_TpuNodeContext* node_context); |
497 | |
// Takes only a `status`, i.e. is not tied to a particular node context.
498 | void TpuNodeContext_StopChipHeartbeats(TF_Status* status); |
499 | |
// Takes only a `status`, i.e. is not tied to a particular node context.
500 | void TpuNodeContext_CloseTpuHost(TF_Status* status); |
501 | |
// Operates on `device_ordinal` and reports through `status`.
502 | void TpuNodeContext_Initialize(int device_ordinal, TF_Status* status); |
503 | |
// Boolean query for `device_ordinal`; by its name, whether compaction is
// supported there.
504 | bool TpuNodeContext_CompactionSupported(int device_ordinal); |
505 | |
506 | // Globally initializes the TPU system for inference. Takes no arguments and |
// returns nothing, so no error channel is visible in this declaration.
507 | TFTPU_CAPI_EXPORT void TfTpu_InitializeTpuModelServer(); |
508 | |
// Parameter block for `TpuEmbeddingEngine_ExecutePartitioner`. Unlike the
// params structs earlier in this header, no *_Params_SIZE macro is defined
// for it.
509 | typedef struct TpuEmbeddingEngine_ExecutePartitioner_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
510 | int32_t struct_size; |
511 | void* priv; |
// Input, passed by value.
512 | TpuSerializedProto tpu_embedding_config; |
513 | |
514 | // out |
515 | size_t* common_config_size; |
516 | char** common_config; |
517 | TF_Status* status; |
518 | } TpuEmbeddingEngine_ExecutePartitioner_Params; |
519 | |
// Entry point; returns void, with everything carried by `params`.
520 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ExecutePartitioner( |
521 | TpuEmbeddingEngine_ExecutePartitioner_Params* params); |
522 | |
// Parameter block for `TpuEmbeddingEngine_ConfigureMemory`. No *_Params_SIZE
// macro is defined for it.
523 | typedef struct TpuEmbeddingEngine_ConfigureMemory_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
524 | int32_t struct_size; |
525 | void* priv; |
526 | |
// Inputs.
527 | int num_inputs; |
528 | size_t common_config_size; |
529 | const char* common_config; |
530 | |
531 | // out |
532 | size_t* memory_config_size; |
533 | char** memory_config; |
534 | TF_Status* status; |
535 | } TpuEmbeddingEngine_ConfigureMemory_Params; |
536 | |
// Entry point; returns void, with everything carried by `params`.
537 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConfigureMemory( |
538 | TpuEmbeddingEngine_ConfigureMemory_Params* params); |
539 | |
// Parameter block for `TpuEmbeddingEngine_CollateMemory`. No *_Params_SIZE
// macro is defined for it.
540 | typedef struct TpuEmbeddingEngine_CollateMemory_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
541 | int32_t struct_size; |
542 | void* priv; |
543 | |
// NOTE(review): presumably `memory_configs` is an array of
// `memory_configs_size` protos -- confirm.
544 | size_t memory_configs_size; |
545 | const TpuSerializedProto* memory_configs; |
546 | |
547 | // out |
548 | size_t* merged_memory_config_size; |
549 | char** merged_memory_config; |
550 | TF_Status* status; |
551 | } TpuEmbeddingEngine_CollateMemory_Params; |
552 | |
// Entry point; returns void, with everything carried by `params`.
553 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_CollateMemory( |
554 | TpuEmbeddingEngine_CollateMemory_Params* params); |
555 | |
// Parameter block for `TpuEmbeddingEngine_ConfigureHost`. No *_Params_SIZE
// macro is defined for it.
556 | typedef struct TpuEmbeddingEngine_ConfigureHost_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
557 | int32_t struct_size; |
558 | void* priv; |
559 | |
// Inputs: the common and memory configs as pointer/size pairs, plus the
// embedding config proto by value.
560 | int num_inputs; |
561 | size_t common_config_size; |
562 | const char* common_config; |
563 | size_t memory_config_size; |
564 | const char* memory_config; |
565 | TpuSerializedProto tpu_embedding_config; |
566 | |
567 | // out |
568 | size_t* network_config_size; |
569 | char** network_config; |
570 | TF_Status* status; |
571 | } TpuEmbeddingEngine_ConfigureHost_Params; |
572 | |
// Entry point; returns void, with everything carried by `params`.
573 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConfigureHost( |
574 | TpuEmbeddingEngine_ConfigureHost_Params* params); |
575 | |
// Parameter block for `TpuEmbeddingEngine_ConnectHosts`. The only output is
// `status`. No *_Params_SIZE macro is defined for it.
576 | typedef struct TpuEmbeddingEngine_ConnectHosts_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
577 | int32_t struct_size; |
578 | void* priv; |
579 | |
// NOTE(review): presumably `network_configs` is an array of
// `network_configs_size` protos -- confirm.
580 | size_t network_configs_size; |
581 | const TpuSerializedProto* network_configs; |
582 | |
583 | // out |
584 | TF_Status* status; |
585 | } TpuEmbeddingEngine_ConnectHosts_Params; |
586 | |
// Entry point; returns void, with everything carried by `params`.
587 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ConnectHosts( |
588 | TpuEmbeddingEngine_ConnectHosts_Params* params); |
589 | |
// Parameter block for `TpuEmbeddingEngine_Finalize`. The only output is
// `status`. No *_Params_SIZE macro is defined for it.
590 | typedef struct TpuEmbeddingEngine_Finalize_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
591 | int32_t struct_size; |
592 | void* priv; |
593 | const XLA_TpuMeshState* tpu_mesh_state; |
594 | |
// Inputs: the common and memory configs as pointer/size pairs.
595 | size_t common_config_size; |
596 | const char* common_config; |
597 | size_t memory_config_size; |
598 | const char* memory_config; |
599 | |
600 | // out |
601 | TF_Status* status; |
602 | } TpuEmbeddingEngine_Finalize_Params; |
603 | |
// Entry point; returns void, with everything carried by `params`.
604 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_Finalize( |
605 | TpuEmbeddingEngine_Finalize_Params* params); |
606 | |
// Parameter block for `TpuEmbeddingEngine_IsInitialized`. No *_Params_SIZE
// macro is defined for it.
607 | typedef struct TpuEmbeddingEngine_IsInitialized_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
608 | int32_t struct_size; |
609 | void* priv; |
610 | |
611 | size_t config_string_size; |
612 | const char* config_string; |
613 | |
614 | // out |
// The answer is written through this pointer; the function returns void.
615 | bool* is_tpu_embedding_initialized; |
616 | TF_Status* status; |
617 | } TpuEmbeddingEngine_IsInitialized_Params; |
618 | |
// Entry point; returns void, with everything carried by `params`.
619 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_IsInitialized( |
620 | TpuEmbeddingEngine_IsInitialized_Params* params); |
621 | |
// Write and read counterparts sharing one signature: a
// `TpuEmbeddingEngineParameters*` plus a `status`. The parameter type is not
// declared in this header.
// NOTE(review): presumably Write pushes `params` to the embedding engine and
// Read fills `params` from it -- confirm the direction of each.
622 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_WriteParameters( |
623 | TpuEmbeddingEngineParameters* params, TF_Status* status); |
624 | |
625 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_ReadParameters( |
626 | TpuEmbeddingEngineParameters* params, TF_Status* status); |
627 | |
// Parameter block for `TpuEmbeddingEngine_EnqueueTensorBatch`. No
// *_Params_SIZE macro is defined for it.
628 | typedef struct TpuEmbeddingEngine_EnqueueTensorBatch_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
629 | int32_t struct_size; |
630 | void* priv; |
631 | |
// NOTE(review): `mode` is a plain int32_t; its valid values are not defined
// in this header.
632 | int32_t mode; |
633 | int32_t local_device_ordinal; |
634 | TpuEmbedding_TensorBatchFixedState* fixed_state; |
635 | |
// Three tensor-pointer arrays, each declared with its own element count.
636 | TF_Tensor** sample_indices_tensors; |
637 | size_t sample_indices_tensors_size; |
638 | TF_Tensor** embedding_indices_tensors; |
639 | size_t embedding_indices_tensors_size; |
640 | TF_Tensor** aggregation_weights_tensors; |
641 | size_t aggregation_weights_tensors_size; |
642 | TF_Status* status; |
643 | } TpuEmbeddingEngine_EnqueueTensorBatch_Params; |
644 | |
// Entry point; returns void, with everything carried by `params`.
645 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_EnqueueTensorBatch( |
646 | TpuEmbeddingEngine_EnqueueTensorBatch_Params* params); |
647 | |
// Parameter block for `TpuEmbeddingTensorBatchFixedState_Create`. No
// *_Params_SIZE macro is defined for it.
648 | typedef struct TpuEmbedding_TensorBatchFixedState_Create_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
649 | int32_t struct_size; |
650 | void* priv; |
651 | |
// NOTE(review): presumably `combiners` is an array of `combiners_size`
// C strings -- confirm.
652 | size_t combiners_size; |
653 | char** combiners; |
654 | |
655 | // out |
656 | TF_Status* status; |
657 | } TpuEmbedding_TensorBatchFixedState_Create_Params; |
658 | |
// Creates a fixed-state object and returns a pointer to it; `params->status`
// is the error channel.
659 | TFTPU_CAPI_EXPORT TpuEmbedding_TensorBatchFixedState* |
660 | TpuEmbeddingTensorBatchFixedState_Create( |
661 | TpuEmbedding_TensorBatchFixedState_Create_Params* params); |
// Destroys a fixed-state object, presumably one returned by the create
// function above.
662 | TFTPU_CAPI_EXPORT void TpuEmbeddingTensorBatchFixedState_Destroy( |
663 | TpuEmbedding_TensorBatchFixedState* fixed_state); |
664 | |
// Parameter block for `TpuEmbeddingEngine_RecvActivationsComputation`. No
// *_Params_SIZE macro is defined for it.
665 | typedef struct TpuEmbeddingEngine_RecvActivationsComputation_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
666 | int32_t struct_size; |
667 | void* priv; |
668 | |
// NOTE(review): this size has no matching `config_string` pointer in the
// struct (compare `TpuEmbeddingEngine_IsInitialized_Params`) -- confirm
// whether the field is still meaningful.
669 | size_t config_string_size; |
670 | XLA_Shape* deduplication_data_shape; |
671 | const XLA_TpuMeshState* tpu_mesh_state; |
672 | |
673 | // out |
674 | TpuSerializedProto* xla_computation; |
675 | TF_Status* status; |
676 | } TpuEmbeddingEngine_RecvActivationsComputation_Params; |
677 | |
// Entry point; returns void, with everything carried by `params`.
678 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_RecvActivationsComputation( |
679 | TpuEmbeddingEngine_RecvActivationsComputation_Params* params); |
680 | |
// Parameter block for
// `TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation`. No
// *_Params_SIZE macro is defined for it.
681 | typedef struct |
682 | TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
683 | int32_t struct_size; |
684 | void* priv; |
685 | |
686 | const XLA_TpuMeshState* tpu_mesh_state; |
687 | // out |
688 | TpuSerializedProto* xla_computation; |
689 | TF_Status* status; |
690 | } TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params; |
691 | |
// Entry point; returns void, with everything carried by `params`.
692 | TFTPU_CAPI_EXPORT void |
693 | TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation( |
694 | TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation_Params* |
695 | params); |
696 | |
// Parameter block for
// `TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation`. No *_Params_SIZE
// macro is defined for it.
697 | typedef struct TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params { |
// NOTE(review): the expected values of `struct_size` and `priv` are not
// visible here -- confirm.
698 | int32_t struct_size; |
699 | void* priv; |
700 | |
701 | int32_t num_inputs; |
702 | const XLA_TpuMeshState* tpu_mesh_state; |
// Shapes of the learning-rate tuple, deduplication data and gradient tuple.
703 | XLA_Shape* learning_rate_tuple_shape; |
704 | XLA_Shape* deduplication_data_shape; |
705 | XLA_Shape* gradient_tuple_shape; |
706 | // out |
707 | TpuSerializedProto* xla_computation; |
708 | TF_Status* status; |
709 | } TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params; |
710 | |
// Entry point; returns void, with everything carried by `params`.
711 | TFTPU_CAPI_EXPORT void TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation( |
712 | TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation_Params* params); |
713 | |
// Table of the C API entry points declared in this header: one
// TFTPU_ADD_FN_IN_STRUCT entry per function.
// NOTE(review): TFTPU_ADD_FN_IN_STRUCT is not defined in this file
// (presumably it comes from "tensorflow/core/tpu/libtftpu.h") and is assumed to
// declare a function-pointer member named after its argument. If so, the order
// of the entries determines the struct layout, so new entries should be added
// with care -- confirm before reordering.
714 | struct TfTpu_OpsApiFn { |
715 | TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CompileAndBuild); |
716 | TFTPU_ADD_FN_IN_STRUCT(TpuCompile_XrtCompileAndBuild); |
717 | |
718 | TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Create); |
719 | TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_Free); |
720 | TFTPU_ADD_FN_IN_STRUCT(TpuMeshState_MeshCommonState); |
721 | |
722 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_Create); |
723 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_Free); |
724 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngineState_GetState); |
725 | |
726 | TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Create); |
727 | TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Destroy); |
728 | TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Start); |
729 | TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_Stop); |
730 | TFTPU_ADD_FN_IN_STRUCT(TpuProfiler_CollectData); |
731 | |
732 | TFTPU_ADD_FN_IN_STRUCT(TpuExecutable_LoadProgramAndEnqueueToStream); |
733 | TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_HostShapeToDeviceShape); |
734 | TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSize); |
735 | TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompact); |
736 | TFTPU_ADD_FN_IN_STRUCT(HardwareLayout_ShapeSizeCompactRaw); |
737 | |
738 | TFTPU_ADD_FN_IN_STRUCT(TpuExecute_RuntimeInputToPaddedData); |
739 | |
740 | TFTPU_ADD_FN_IN_STRUCT(ConfigureDistributedTpuOp_DoWork); |
741 | TFTPU_ADD_FN_IN_STRUCT(WaitForDistributedTpuOp_DoWork); |
742 | TFTPU_ADD_FN_IN_STRUCT(InitializeHostForDistributedTpuOp_DoWork); |
743 | TFTPU_ADD_FN_IN_STRUCT(SetGlobalTPUArrayOp_DoWork); |
744 | TFTPU_ADD_FN_IN_STRUCT(DisconnectDistributedTpuChipsOp_DoWork); |
745 | TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeCharArray); |
746 | TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_FreeInt32Array); |
747 | TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_HasTPUPodState); |
748 | TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpusPerHost); |
749 | TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_TpuMemoryLimit); |
750 | TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_RemoteCompilationCacheSizeInBytes); |
751 | TFTPU_ADD_FN_IN_STRUCT( |
752 | TpuConfigurationApi_CompilationCacheServerAddressFromConfig); |
753 | TFTPU_ADD_FN_IN_STRUCT(TpuConfigurationApi_GetServerAddressAndPort); |
754 | |
755 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_New); |
756 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_Free); |
757 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_NewArray); |
758 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_FreeArray); |
759 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_UnloadAndDestroy); |
760 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetProgramSize); |
761 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_LogProgramMemorySummary); |
762 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetExecutableInfo); |
763 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHostTransferInfo); |
764 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetHloMetadata); |
765 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetMayModifyVariables); |
766 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_HasSharding); |
767 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetTpuProgram); |
768 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeTpuExecutable); |
769 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_SerializeCompilerMetadata); |
770 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DeserializeFromGetTpuProgramResponseProto); |
771 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_GetFingerprint); |
772 | TFTPU_ADD_FN_IN_STRUCT(TpuProgram_DestroyFingerprint); |
773 | |
774 | TFTPU_ADD_FN_IN_STRUCT(TpuCompile_IsTpuCompilationEnabled); |
775 | TFTPU_ADD_FN_IN_STRUCT(TpuCompile_ShouldTpuCompileOpIgnoreCancellation); |
776 | TFTPU_ADD_FN_IN_STRUCT(TpuTopology_AvailableCoreCount); |
777 | TFTPU_ADD_FN_IN_STRUCT(TpuNetUtil_RecycleUnusedPort); |
778 | TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateCompilationCacheKey); |
779 | TFTPU_ADD_FN_IN_STRUCT(TpuCompile_DestroyCompilationCacheKey); |
780 | TFTPU_ADD_FN_IN_STRUCT(TpuCompile_CreateGuaranteedConstFingerprint); |
781 | |
782 | TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Create); |
783 | TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Free); |
784 | TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_StopChipHeartbeats); |
785 | TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CloseTpuHost); |
786 | TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_Initialize); |
787 | TFTPU_ADD_FN_IN_STRUCT(TpuNodeContext_CompactionSupported); |
788 | |
789 | TFTPU_ADD_FN_IN_STRUCT(TfTpu_InitializeTpuModelServer); |
790 | |
791 | TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Create); |
792 | TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_Destroy); |
793 | TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_GetOrdinal); |
794 | TFTPU_ADD_FN_IN_STRUCT(TfTpuOrdinalSelector_DequeueFromCoreSelector); |
795 | TFTPU_ADD_FN_IN_STRUCT(TfTpu_GetTpuPartitionedCallParams); |
796 | |
797 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ExecutePartitioner); |
798 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConfigureMemory); |
799 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_CollateMemory); |
800 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConfigureHost); |
801 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ConnectHosts); |
802 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_Finalize); |
803 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_IsInitialized); |
804 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_WriteParameters); |
805 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_ReadParameters); |
806 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Create); |
807 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingTensorBatchFixedState_Destroy); |
808 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_EnqueueTensorBatch); |
809 | TFTPU_ADD_FN_IN_STRUCT(TpuEmbeddingEngine_RecvActivationsComputation); |
810 | TFTPU_ADD_FN_IN_STRUCT( |
811 | TpuEmbeddingEngine_RecvTPUEmbeddingDeduplicationDataComputation); |
812 | TFTPU_ADD_FN_IN_STRUCT( |
813 | TpuEmbeddingEngine_SendTPUEmbeddingGradientsComputation); |
814 | }; |
815 | |
816 | } // extern "C" |
817 | |
818 | #endif // TENSORFLOW_CORE_TPU_TPU_OPS_C_API_H_ |
819 | |