1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15#ifndef TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
16#define TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
17
18#include "tensorflow/c/c_api.h"
19#include "tensorflow/c/eager/c_api.h"
20
21#ifdef __cplusplus
22extern "C" {
23#endif
24
25// Resets `op_to_reset` with `op_or_function_name` and `raw_device_name`. This
26// is for performance optimization by reusing an exiting unused op rather than
27// creating a new op every time. If `raw_device_name` is `NULL` or empty, it
28// does not set the device name. If it's not `NULL`, then it attempts to parse
29// and set the device name. It's effectively `TFE_OpSetDevice`, but it is faster
30// than separately calling it because if the existing op has the same
31// `raw_device_name`, it skips parsing and just leave as it is.
32TF_CAPI_EXPORT extern void TFE_OpReset(TFE_Op* op_to_reset,
33 const char* op_or_function_name,
34 const char* raw_device_name,
35 TF_Status* status);
36
37// Enables only graph collection in RunMetadata on the functions executed from
38// this context.
39TF_CAPI_EXPORT extern void TFE_ContextEnableGraphCollection(TFE_Context* ctx);
40
41// Disables only graph collection in RunMetadata on the functions executed from
42// this context.
43TF_CAPI_EXPORT extern void TFE_ContextDisableGraphCollection(TFE_Context* ctx);
44
45// TODO(fishx): Move these monitoring APIs into a separate file.
46// -----------------------------------------------------------------------------
47// Monitoring Counter APIs.
// These APIs are de-templated versions of the monitoring Counter, for SWIG.
49
50typedef struct TFE_MonitoringCounterCell TFE_MonitoringCounterCell;
51
52// Atomically increments the value of the cell. The value must be non-negative.
53TF_CAPI_EXPORT extern void TFE_MonitoringCounterCellIncrementBy(
54 TFE_MonitoringCounterCell* cell, int64_t value);
55
56// Retrieves the current value of the cell.
57TF_CAPI_EXPORT extern int64_t TFE_MonitoringCounterCellValue(
58 TFE_MonitoringCounterCell* cell);
59
60// APIs for Counter without label.
61typedef struct TFE_MonitoringCounter0 TFE_MonitoringCounter0;
62// Returns a new Counter metric object. The caller should manage lifetime of
63// the object. Using duplicate metric name will crash the program with fatal
64// error.
65TF_CAPI_EXPORT extern TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(
66 const char* name, TF_Status* status, const char* description);
67// Deletes the Counter object.
68TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter0(
69 TFE_MonitoringCounter0* counter);
70// Retrieves the cell from the Counter object. The Counter object will manage
71// lifetime of the cell.
72TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
73 TFE_MonitoringCounter0* counter);
74
75// APIs for Counter with 1 label.
76typedef struct TFE_MonitoringCounter1 TFE_MonitoringCounter1;
77TF_CAPI_EXPORT extern TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(
78 const char* name, TF_Status* status, const char* description,
79 const char* label1);
80TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter1(
81 TFE_MonitoringCounter1* counter);
82TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
83 TFE_MonitoringCounter1* counter, const char* label1);
84
85// APIs for Counter with 2 labels.
86typedef struct TFE_MonitoringCounter2 TFE_MonitoringCounter2;
87TF_CAPI_EXPORT extern TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(
88 const char* name, TF_Status* status, const char* description,
89 const char* label1, const char* label2);
90TF_CAPI_EXPORT extern void TFE_MonitoringDeleteCounter2(
91 TFE_MonitoringCounter2* counter);
92TF_CAPI_EXPORT extern TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
93 TFE_MonitoringCounter2* counter, const char* label1, const char* label2);
94
95// -----------------------------------------------------------------------------
96// Monitoring Gauge APIs.
// These APIs are de-templated versions of the monitoring Gauge, for SWIG.
98
99typedef struct TFE_MonitoringIntGaugeCell TFE_MonitoringIntGaugeCell;
100
101// Atomically set the value of the cell.
102TF_CAPI_EXPORT extern void TFE_MonitoringIntGaugeCellSet(
103 TFE_MonitoringIntGaugeCell* cell, int64_t value);
104
105// Retrieves the current value of the cell.
106TF_CAPI_EXPORT extern int64_t TFE_MonitoringIntGaugeCellValue(
107 TFE_MonitoringIntGaugeCell* cell);
108
109// APIs for Int Gauge without label.
110typedef struct TFE_MonitoringIntGauge0 TFE_MonitoringIntGauge0;
111TF_CAPI_EXPORT extern TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(
112 const char* name, TF_Status* out_status, const char* description);
113TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge0(
114 TFE_MonitoringIntGauge0* gauge);
115TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
116TFE_MonitoringGetCellIntGauge0(TFE_MonitoringIntGauge0* gauge);
117
118// APIs for Int Gauge with 1 label.
119typedef struct TFE_MonitoringIntGauge1 TFE_MonitoringIntGauge1;
120TF_CAPI_EXPORT extern TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(
121 const char* name, TF_Status* out_status, const char* description,
122 const char* label1);
123TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge1(
124 TFE_MonitoringIntGauge1* gauge);
125TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
126TFE_MonitoringGetCellIntGauge1(TFE_MonitoringIntGauge1* gauge,
127 const char* label1);
128
129// APIs for Int Gauge with 2 label.
130typedef struct TFE_MonitoringIntGauge2 TFE_MonitoringIntGauge2;
131TF_CAPI_EXPORT extern TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(
132 const char* name, TF_Status* out_status, const char* description,
133 const char* label1, const char* label2);
134TF_CAPI_EXPORT extern void TFE_MonitoringDeleteIntGauge2(
135 TFE_MonitoringIntGauge2* gauge);
136TF_CAPI_EXPORT extern TFE_MonitoringIntGaugeCell*
137TFE_MonitoringGetCellIntGauge2(TFE_MonitoringIntGauge2* gauge,
138 const char* label1, const char* label2);
139
140typedef struct TFE_MonitoringStringGaugeCell TFE_MonitoringStringGaugeCell;
141TF_CAPI_EXPORT extern void TFE_MonitoringStringGaugeCellSet(
142 TFE_MonitoringStringGaugeCell* cell, const char* value);
143// Retrieves the string value and saves it in the buffer.
144TF_CAPI_EXPORT extern const void TFE_MonitoringStringGaugeCellValue(
145 TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf);
146
147// APIs for String Gauge without label.
148typedef struct TFE_MonitoringStringGauge0 TFE_MonitoringStringGauge0;
149TF_CAPI_EXPORT extern TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
150 const char* name, TF_Status* out_status, const char* description);
151TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge0(
152 TFE_MonitoringStringGauge0* gauge);
153TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
154TFE_MonitoringGetCellStringGauge0(TFE_MonitoringStringGauge0* gauge);
155
156// APIs for String Gauge with 1 label.
157typedef struct TFE_MonitoringStringGauge1 TFE_MonitoringStringGauge1;
158TF_CAPI_EXPORT extern TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
159 const char* name, TF_Status* out_status, const char* description,
160 const char* label1);
161TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge1(
162 TFE_MonitoringStringGauge1* gauge);
163TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
164TFE_MonitoringGetCellStringGauge1(TFE_MonitoringStringGauge1* gauge,
165 const char* label1);
166
167// APIs for String Gauge with 2 label.
168typedef struct TFE_MonitoringStringGauge2 TFE_MonitoringStringGauge2;
169TF_CAPI_EXPORT extern TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
170 const char* name, TF_Status* out_status, const char* description,
171 const char* label1, const char* label2);
172TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge2(
173 TFE_MonitoringStringGauge2* gauge);
174TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
175TFE_MonitoringGetCellStringGauge2(TFE_MonitoringStringGauge2* gauge,
176 const char* label1, const char* label2);
177
178// APIs for String Gauge with 3 labels.
179typedef struct TFE_MonitoringStringGauge3 TFE_MonitoringStringGauge3;
180TF_CAPI_EXPORT extern TFE_MonitoringStringGauge3* TFE_MonitoringNewStringGauge3(
181 const char* name, TF_Status* out_status, const char* description,
182 const char* label1, const char* label2, const char* label3);
183TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge3(
184 TFE_MonitoringStringGauge3* gauge);
185TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
186TFE_MonitoringGetCellStringGauge3(TFE_MonitoringStringGauge3* gauge,
187 const char* label1, const char* label2,
188 const char* label3);
189
190// APIs for String Gauge with 4 labels.
191typedef struct TFE_MonitoringStringGauge4 TFE_MonitoringStringGauge4;
192TF_CAPI_EXPORT extern TFE_MonitoringStringGauge4* TFE_MonitoringNewStringGauge4(
193 const char* name, TF_Status* out_status, const char* description,
194 const char* label1, const char* label2, const char* label3,
195 const char* label4);
196TF_CAPI_EXPORT extern void TFE_MonitoringDeleteStringGauge4(
197 TFE_MonitoringStringGauge4* gauge);
198TF_CAPI_EXPORT extern TFE_MonitoringStringGaugeCell*
199TFE_MonitoringGetCellStringGauge4(TFE_MonitoringStringGauge4* gauge,
200 const char* label1, const char* label2,
201 const char* label3, const char* label4);
202
203typedef struct TFE_MonitoringBoolGaugeCell TFE_MonitoringBoolGaugeCell;
204TF_CAPI_EXPORT extern void TFE_MonitoringBoolGaugeCellSet(
205 TFE_MonitoringBoolGaugeCell* cell, bool value);
206TF_CAPI_EXPORT extern bool TFE_MonitoringBoolGaugeCellValue(
207 TFE_MonitoringBoolGaugeCell* cell);
208
209// APIs for Bool Gauge without label.
210typedef struct TFE_MonitoringBoolGauge0 TFE_MonitoringBoolGauge0;
211TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(
212 const char* name, TF_Status* out_status, const char* description);
213TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge0(
214 TFE_MonitoringBoolGauge0* gauge);
215TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
216TFE_MonitoringGetCellBoolGauge0(TFE_MonitoringBoolGauge0* gauge);
217
218// APIs for Bool Gauge with 1 label.
219typedef struct TFE_MonitoringBoolGauge1 TFE_MonitoringBoolGauge1;
220TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(
221 const char* name, TF_Status* out_status, const char* description,
222 const char* label1);
223TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge1(
224 TFE_MonitoringBoolGauge1* gauge);
225TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
226TFE_MonitoringGetCellBoolGauge1(TFE_MonitoringBoolGauge1* gauge,
227 const char* label1);
228
229// APIs for Bool Gauge with 2 label.
230typedef struct TFE_MonitoringBoolGauge2 TFE_MonitoringBoolGauge2;
231TF_CAPI_EXPORT extern TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(
232 const char* name, TF_Status* out_status, const char* description,
233 const char* label1, const char* label2);
234TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBoolGauge2(
235 TFE_MonitoringBoolGauge2* gauge);
236TF_CAPI_EXPORT extern TFE_MonitoringBoolGaugeCell*
237TFE_MonitoringGetCellBoolGauge2(TFE_MonitoringBoolGauge2* gauge,
238 const char* label1, const char* label2);
239
240// -----------------------------------------------------------------------------
241// Monitoring Sampler APIs.
// These APIs are de-templated versions of the monitoring Sampler, for SWIG.
243
244typedef struct TFE_MonitoringSamplerCell TFE_MonitoringSamplerCell;
245
246// Atomically add the value of the cell.
247TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellAdd(
248 TFE_MonitoringSamplerCell* cell, double value);
249
250// Retrieves the current value of the cell. The return value is a HistogramProto
251// saved in the buffer.
252TF_CAPI_EXPORT extern void TFE_MonitoringSamplerCellValue(
253 TFE_MonitoringSamplerCell* cell, TF_Buffer* buf);
254
255// APIs for sampler buckets
256typedef struct TFE_MonitoringBuckets TFE_MonitoringBuckets;
257TF_CAPI_EXPORT extern TFE_MonitoringBuckets*
258TFE_MonitoringNewExponentialBuckets(double scale, double growth_factor,
259 int bucket_count);
260TF_CAPI_EXPORT extern void TFE_MonitoringDeleteBuckets(
261 TFE_MonitoringBuckets* buckets);
262
263// APIs for Sampler without label.
264typedef struct TFE_MonitoringSampler0 TFE_MonitoringSampler0;
265TF_CAPI_EXPORT extern TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
266 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
267 const char* description);
268TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler0(
269 TFE_MonitoringSampler0* sampler);
270TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
271 TFE_MonitoringSampler0* sampler);
272
273// APIs for Sampler with 1 label.
274typedef struct TFE_MonitoringSampler1 TFE_MonitoringSampler1;
275TF_CAPI_EXPORT extern TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
276 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
277 const char* description, const char* label1);
278TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler1(
279 TFE_MonitoringSampler1* sampler);
280TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
281 TFE_MonitoringSampler1* sampler, const char* label1);
282
283// APIs for Sampler with 2 label.
284typedef struct TFE_MonitoringSampler2 TFE_MonitoringSampler2;
285TF_CAPI_EXPORT extern TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
286 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* out_status,
287 const char* description, const char* label1, const char* label2);
288TF_CAPI_EXPORT extern void TFE_MonitoringDeleteSampler2(
289 TFE_MonitoringSampler2* sampler);
290TF_CAPI_EXPORT extern TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
291 TFE_MonitoringSampler2* sampler, const char* label1, const char* label2);
292
293// Sets whether to use TFRT
294TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrt(TFE_ContextOptions*,
295 bool use_tfrt);
296
297// Sets whether to use TFRT distributed runtime
298TF_CAPI_EXPORT extern void TFE_ContextOptionsSetTfrtDistributedRuntime(
299 TFE_ContextOptions* options, bool use_tfrt_distributed_runtime);
300
301// Returns the context_id from the EagerContext which is used by the
302// EagerService to maintain consistency between client and worker. The
303// context_id is initialized with a dummy value and is later set when the worker
304// is initialized (either locally or remotely). The context_id can change during
305// the process lifetime although this should cause the worker to be
306// reinitialized (e.g. cleared caches) as well.
307TF_CAPI_EXPORT extern uint64_t TFE_GetContextId(TFE_Context* ctx);
308
309// -----------------------------------------------------------------------------
310// Cancellation APIs.
311
312typedef struct TFE_CancellationManager TFE_CancellationManager;
313TF_CAPI_EXPORT extern TFE_CancellationManager* TFE_NewCancellationManager();
314TF_CAPI_EXPORT extern bool TFE_CancellationManagerIsCancelled(
315 TFE_CancellationManager*);
316TF_CAPI_EXPORT extern void TFE_CancellationManagerStartCancel(
317 TFE_CancellationManager*);
318TF_CAPI_EXPORT extern void TFE_DeleteCancellationManager(
319 TFE_CancellationManager*);
320
321// Associates the given `cancellation_manager` with `op`, so that invoking
322// `TFE_CancellationManagerStartCancel(cancellation_manager)` will cancel the
323// execution of `op`.
324typedef struct TFE_CancellationManager TFE_CancellationManager;
325TF_CAPI_EXPORT extern void TFE_OpSetCancellationManager(
326 TFE_Op* op, TFE_CancellationManager* cancellation_manager,
327 TF_Status* status);
328
329// -----------------------------------------------------------------------------
330// Eager Executor APIs.
331typedef struct TFE_Executor TFE_Executor;
332
333// Creates a new eager Executor. Nodes in one executor are guaranteed to be
334// executed in sequence. Assigning nodes to different executors allows executing
335// nodes in parallel.
336TF_CAPI_EXPORT extern TFE_Executor* TFE_NewExecutor(
337 bool is_async, bool enable_streaming_enqueue);
338
339// Deletes the eager Executor without waiting for enqueued nodes. Please call
340// TFE_ExecutorWaitForAllPendingNodes before calling this API if you want to
341// make sure all nodes are finished.
342TF_CAPI_EXPORT extern void TFE_DeleteExecutor(TFE_Executor*);
343
344// Returns true if the executor is in async mode.
345TF_CAPI_EXPORT extern bool TFE_ExecutorIsAsync(TFE_Executor*);
346
347// Causes the calling thread to block till all ops dispatched in this executor
348// have been executed. Note that "execution" here refers to kernel execution /
349// scheduling of copies, etc. Similar to sync execution, it doesn't guarantee
350// that lower level device queues (like GPU streams) have been flushed.
351//
352// This call may not block for execution of ops enqueued concurrently with this
353// call.
354TF_CAPI_EXPORT extern void TFE_ExecutorWaitForAllPendingNodes(
355 TFE_Executor*, TF_Status* status);
356
357// When an error happens, any pending operations are discarded, and newly issued
358// ops return an error. This call clears the error state and re-enables
359// execution of newly issued ops.
360//
361// Note that outputs of discarded ops remain in a corrupt state and should not
362// be used for future calls.
363// TODO(agarwal): mark the affected handles and raise errors if they are used.
364TF_CAPI_EXPORT extern void TFE_ExecutorClearError(TFE_Executor*);
365
366// Sets a custom Executor for the current thread. All nodes created by this
367// thread will be added to this Executor. It will override the current executor.
368TF_CAPI_EXPORT extern void TFE_ContextSetExecutorForThread(TFE_Context*,
369 TFE_Executor*);
370
371// Returns the Executor for the current thread.
372TF_CAPI_EXPORT extern TFE_Executor* TFE_ContextGetExecutorForThread(
373 TFE_Context*);
374
375// -----------------------------------------------------------------------------
376// Dynamic cluster API.
377
378// Update an existing context with a new set of servers defined in a ServerDef
379// proto. Servers can be added to and removed from the list of remote workers
380// in the context. A New set of servers identified by the ServerDef must be up
381// when the context is updated.
382//
383// This API is for experimental usage and may be subject to change.
384TF_CAPI_EXPORT extern void TFE_ContextUpdateServerDef(TFE_Context* ctx,
385 int keep_alive_secs,
386 const void* proto,
387 size_t proto_len,
388 TF_Status* status);
389
390// Checks whether a remote worker is alive or not. This will return true even if
391// the context doesn't exist on the remote worker.
392TF_CAPI_EXPORT extern bool TFE_ContextCheckAlive(TFE_Context* ctx,
393 const char* worker_name,
394 TF_Status* status);
395
396// Sync pending nodes in local executors (including the context default executor
397// and thread executors) and streaming requests to remote executors, and get the
398// combined status.
399TF_CAPI_EXPORT extern void TFE_ContextAsyncWait(TFE_Context* ctx,
400 TF_Status* status);
401
402// This function will block till the operation that produces `h` has
403// completed. This is only valid on local TFE_TensorHandles. The pointer
404// returned will be on the device in which the TFE_TensorHandle resides (so e.g.
405// for a GPU tensor this will return a pointer to GPU memory). The pointer is
406// only guaranteed to be valid until TFE_DeleteTensorHandle is called on this
407// TensorHandle. Only supports POD data types.
408TF_CAPI_EXPORT extern void* TFE_TensorHandleDevicePointer(TFE_TensorHandle*,
409 TF_Status*);
410
411// This function will block till the operation that produces `h` has
412// completed. This is only valid on local TFE_TensorHandles. Returns the size in
413// bytes of the memory pointed to by the device pointer returned above.
414TF_CAPI_EXPORT extern size_t TFE_TensorHandleDeviceMemorySize(TFE_TensorHandle*,
415 TF_Status*);
416
417// Creates a new TensorHandle from memory residing in the physical device
418// device_name. Takes ownership of the memory, and will call deleter to release
419// it after TF no longer needs it or in case of error.
420//
421// Custom devices must use TFE_NewCustomDeviceTensorHandle instead.
422TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewTensorHandleFromDeviceMemory(
423 TFE_Context* ctx, const char* device_name, TF_DataType, const int64_t* dims,
424 int num_dims, void* data, size_t len,
425 void (*deallocator)(void* data, size_t len, void* arg),
426 void* deallocator_arg, TF_Status* status);
427
428// Retrieves the address space (i.e. job, replia, task) of the local host and
429// saves it in the buffer.
430TF_CAPI_EXPORT extern void TFE_HostAddressSpace(TFE_Context* ctx,
431 TF_Buffer* buf);
432
433// APIs for generically dealing with op attributes (e.g. when forwarding them
434// through custom device implementations).
435//
436// TODO(allenl): Currently these are black boxes, but we should have some way to
437// inspect values. This would let people e.g. copy over most attributes and then
438// modify some based on their values.
439
440// A reference to an op's name -> attribute mapping
441typedef struct TFE_OpAttrs TFE_OpAttrs;
442
443// Fetch a reference to `op`'s attributes. The returned reference is only valid
444// while `op` is alive.
445TF_CAPI_EXPORT extern const TFE_OpAttrs* TFE_OpGetAttrs(const TFE_Op* op);
446// Add attributes in `attrs` to `op`.
447//
448// Does not overwrite or update existing attributes, but adds new ones.
449TF_CAPI_EXPORT extern void TFE_OpAddAttrs(TFE_Op* op, const TFE_OpAttrs* attrs);
450
451// Serialize `attrs` as a tensorflow::NameAttrList protocol buffer (into `buf`),
452// containing the op name and a map of its attributes.
453TF_CAPI_EXPORT extern void TFE_OpAttrsSerialize(const TFE_OpAttrs* attrs,
454 TF_Buffer* buf,
455 TF_Status* status);
456
457// Set an op's attribute from a serialized AttrValue protocol buffer.
458//
459// Analogous to TF_SetAttrValueProto for building graph operations.
460TF_CAPI_EXPORT extern void TFE_OpSetAttrValueProto(const TFE_Op* op,
461 const char* attr_name,
462 const void* proto,
463 size_t proto_len,
464 TF_Status* status);
465
466// TODO(b/166642410): It would be nice, for custom devices and for other users,
467// to have a non-string representation of devices (TF_Device) extracted from
468// tensors/ops/etc. and usable in APIs like OpSetDevice/ResetOp/etc.
469
470#define TFE_CUSTOM_DEVICE_VERSION 4
471
472// Struct to be filled in. Functions are required except where indicated.
473typedef struct TFE_CustomDevice {
474 int version = TFE_CUSTOM_DEVICE_VERSION;
475 // Method to copy a tensor to the custom device.
476 TFE_TensorHandle* (*copy_tensor_to_device)(TFE_Context* context,
477 TFE_TensorHandle* tensor,
478 TF_Status* status,
479 void* device_info);
480
481 // Method to copy a tensor from the custom device to a target device.
482 TFE_TensorHandle* (*copy_tensor_from_device)(TFE_Context* context,
483 TFE_TensorHandle* tensor,
484 const char* target_device_name,
485 TF_Status* status,
486 void* device_info);
487
488 // Method to execute an operation.
489 //
490 // Arguments provide enough information to reconstruct the original `TFE_Op`,
491 // or construct a transformed version, by inspecting the passed `op`.
492 //
493 // TFE_OpGetDevice(op) records the original placement of the operation. It may
494 // be an empty string if no device was explicitly requested, but will
495 // otherwise be the name of this custom device. Ops are placed onto a custom
496 // device if any of their inputs are on that custom device, but custom devices
497 // are free to set a bad status in order to require explicit placement.
498 void (*execute)(const TFE_Op* op, int* num_outputs,
499 TFE_TensorHandle** outputs, TF_Status* s, void* device_info);
500
501 // Method to delete a device.
502 void (*delete_device)(void* device_info);
503
504 // Implements TFE_CreatePackedTensorHandle when one of `handles` is on this
505 // custom device.
506 //
507 // Many devices will want to simply return an "unimplemented" status
508 // here. This is the default behavior if `pack` is null when passed to
509 // TFE_RegisterCustomDevice.
510 TFE_TensorHandle* (*pack)(TFE_Context* context, TFE_TensorHandle** handles,
511 int num_handles, TF_Status* s,
512 void* device_info) = nullptr;
513
514 // Pins the op to `device` based on inputs to `op`. Returns true
515 // signifying to pin to the current custom device. Returns false
516 // to pin to the physical device.
517 //
518 // This function is guaranteed to be called only when all of the custom-device
519 // inputs are on this device.
520 bool (*shall_pin_to_this_device)(const TFE_Op* op, TF_Status* s) = nullptr;
521} TFE_CustomDevice;
522
523// Registers a custom device for use with eager execution.
524//
525// Eager operations may be placed on this device, e.g. `with
526// tf.device("CUSTOM"):` from Python if `device_name` for this call is
527// "/job:localhost/replica:0/task:0/device:CUSTOM:0".
528//
529// The custom device defines copy operations for moving TensorHandles on and
530// off, and an execution operation for named operations. Often execution will
531// simply wrap op execution on one or more physical devices.
532//
533// device_info is an opaque caller-defined type stored with the custom device
534// which is passed to the functions referenced in the TFE_CustomDevice struct
535// `device` (execute, delete_device, etc.). It can for example contain the
536// names of wrapped devices.
537//
538// There are currently no graph semantics implemented for registered custom
539// devices, so executing tf.functions which contain operations placed on the
540// custom devices will fail.
541//
542// `device_name` must not name an existing physical or custom device. It must
543// follow the format:
544//
545// /job:<name>/replica:<replica>/task:<task>/device:<type>:<device_num>
546//
547// If the device is successfully registered, `status` is set to TF_OK. Otherwise
548// the device is not usable. In case of a bad status, `device.delete_device` is
549// still called on `device_info` (i.e. the caller does not retain ownership).
550//
551// This API is highly experimental, and in particular is expected to change when
552// it starts supporting operations with attributes and when tf.function support
553// is added.
554TF_CAPI_EXPORT extern void TFE_RegisterCustomDevice(TFE_Context* ctx,
555 TFE_CustomDevice device,
556 const char* device_name,
557 void* device_info,
558 TF_Status* status);
559
560// Returns whether `device_name` maps to a registered custom device.
561TF_CAPI_EXPORT extern bool TFE_IsCustomDevice(TFE_Context* ctx,
562 const char* device_name);
563
564// Struct to be filled in to define a custom device tensor handle. Fields are
565// required except where indicated.
566typedef struct TFE_CustomDeviceTensorHandleMethods {
567 int version = TFE_CUSTOM_DEVICE_VERSION;
568
569 // Computes the rank of the tensor handle.
570 //
571 // Shapes are specified via callbacks because retrieving the shape of a tensor
572 // is a blocking operation for async eager; custom devices should avoid
573 // retrieving shapes of tensors they wrap until the custom device tensor's
574 // shape is explicitly requested where possible.
575 int (*num_dims)(void* data, TF_Status* status);
576
577 // Computes the axis length at `dim_index`.
578 int64_t (*dim)(void* data, int dim_index, TF_Status* status);
579
580 void (*deallocator)(void* data);
581
582 // Summarizes the value of this tensor. The caller takes ownership of the
583 // returned buffer. If `status` is not TF_OK, instead returns a null pointer.
584 //
585 // Does not include the shape and dtype of the tensor (which is generally
586 // appended later), but should include any information specific to this custom
587 // device which would be useful for debugging.
588 //
589 // Optional. If null, defaults to resolving the TFE_TensorHandle into a
590 // TF_Tensor and summarizing that.
591 TF_Buffer* (*summarize)(void* data, TF_Status* status) = nullptr;
592} TFE_CustomDeviceTensorHandle;
593
594// Creates a new TensorHandle from memory residing in a custom device. Takes
595// ownership of the memory pointed to by `tensor_handle_data`, and calls
596// `methods.deallocator` to release it after TF no longer needs it or in case of
597// an error.
598//
599// This call is similar to `TFE_NewTensorHandleFromDeviceMemory`, but supports
600// custom devices instead of physical devices and does not require blocking
601// waiting for exact shapes.
602TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_NewCustomDeviceTensorHandle(
603 TFE_Context*, const char* device_name, TF_DataType, void* data,
604 TFE_CustomDeviceTensorHandle methods, TF_Status* status);
605
606TF_CAPI_EXPORT extern void TFE_ContextGetFunctionDef(TFE_Context* ctx,
607 const char* function_name,
608 TF_Buffer* buf,
609 TF_Status* status);
610
611// Allocate and return a new Tensor on the host.
612//
613// The caller must set the Tensor values by writing them to the pointer returned
614// by TF_TensorData with length TF_TensorByteSize.
615TF_CAPI_EXPORT extern TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx,
616 TF_DataType dtype,
617 const int64_t* dims,
618 int num_dims,
619 TF_Status* status);
620
621// Given a Tensor, wrap it with a TensorHandle
622//
623// Similar to TFE_NewTensorHandle, but includes a pointer to the TFE_Context.
624// The context should be identical to that of the Tensor.
625TF_CAPI_EXPORT TFE_TensorHandle* TFE_NewTensorHandleFromTensor(
626 TFE_Context* ctx, TF_Tensor* t, TF_Status* status);
627
628// Create a packed TensorHandle with the given list of TensorHandles.
629// If `handles` are on the same device, assign the same device to the packed
630// handle; if `handles` are on different deivces, assign a CompositeDevice to
631// it.
632TF_CAPI_EXPORT extern TFE_TensorHandle* TFE_CreatePackedTensorHandle(
633 TFE_Context* ctx, TFE_TensorHandle** handles, int* num_handles,
634 TF_Status* status);
635
636// Configure soft device placement policy for the eager executor. Note this
637// policy is applied to any subsequent op executions.
638TF_CAPI_EXPORT void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx,
639 unsigned char enable,
640 TF_Status* status);
641
642// Configure device placement policy logging for the eager executor. Note this
643// policy is applied to any subsequent op executions.
644TF_CAPI_EXPORT void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx,
645 unsigned char enable,
646 TF_Status* status);
647
648// Enables running eager ops as function.
649TF_CAPI_EXPORT void TFE_ContextSetRunEagerOpAsFunction(TFE_Context* ctx,
650 unsigned char enable,
651 TF_Status* status);
652
653// Enables rewrite jit_compile functions.
654TF_CAPI_EXPORT void TFE_ContextSetJitCompileRewrite(TFE_Context* ctx,
655 unsigned char enable,
656 TF_Status* status);
657
658// Returns the device type of the operation that produced `h`.
659TF_CAPI_EXPORT extern const char* TFE_TensorHandleDeviceType(
660 TFE_TensorHandle* h, TF_Status* status);
661
662// Returns the device ID of the operation that produced `h`.
663TF_CAPI_EXPORT extern int TFE_TensorHandleDeviceID(TFE_TensorHandle* h,
664 TF_Status* status);
665
666// Returns the status for the tensor handle. In TFRT, a tensor handle can carry
667// error info if error happens. If so, the status will be set with the error
668// info. If not, status will be set as OK.
669TF_CAPI_EXPORT extern void TFE_TensorHandleGetStatus(TFE_TensorHandle* h,
670 TF_Status* status);
671
672// Get a comma-separated list of op names executed in graph functions dispatched
673// to `ctx`. This feature is currently only enabled for TFRT debug builds, for
674// performance and simplicity reasons.
675TF_CAPI_EXPORT extern void TFE_GetExecutedOpNames(TFE_Context* ctx,
676 TF_Buffer* buf,
677 TF_Status* status);
678
679// Set logical devices to the context's device manager.
680// If logical devices are already configured at context initialization
681// through TFE_ContextOptions, this method should not be called.
682TF_CAPI_EXPORT extern void TFE_SetLogicalCpuDevices(TFE_Context* ctx,
683 int num_cpus,
684 const char* prefix,
685 TF_Status* status);
686
687// Set configuration key and value using coordination service.
688// If coordination service is enabled, the key-value will be stored on the
689// leader and become accessible to all workers in the cluster.
690// Currently, a config key can only be set with one value, and subsequently
691// setting the same key will lead to errors.
692//
693// Note that the key-values are only expected to be used for cluster
694// configuration data, and should not be used for storing a large amount of data
695// or being accessed very frequently.
696TF_CAPI_EXPORT extern void TFE_InsertConfigKeyValue(TFE_Context* ctx,
697 const char* key,
698 const char* value,
699 TF_Status* status);
700
701// Get configuration key and value using coordination service.
702// The config key must be set before getting its value. Getting value of
703// non-existing config keys will result in errors.
704TF_CAPI_EXPORT extern void TFE_GetConfigKeyValue(TFE_Context* ctx,
705 const char* key,
706 TF_Buffer* value_buf,
707 TF_Status* status);
708
709// Delete configuration key-value. If `key` is a directory, recursively clean up
710// all key-values under the path specified by `key`.
711TF_CAPI_EXPORT extern void TFE_DeleteConfigKeyValue(TFE_Context* ctx,
712 const char* key,
713 TF_Status* status);
714
715// Report error (specified by error_code and error_message) to other tasks in
716// the cluster.
717TF_CAPI_EXPORT extern void TFE_ReportErrorToCluster(TFE_Context* ctx,
718 int error_code,
719 const char* error_message,
720 TF_Status* status);
721
722#ifdef __cplusplus
723} /* end extern "C" */
724#endif
725
726#endif // TENSORFLOW_C_EAGER_C_API_EXPERIMENTAL_H_
727