1/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/c/eager/c_api_experimental.h"
17
18#include <vector>
19
20#include "absl/strings/match.h"
21#include "tensorflow/c/c_api.h"
22#include "tensorflow/c/eager/c_api_internal.h"
23#include "tensorflow/c/eager/tfe_context_internal.h"
24#include "tensorflow/c/eager/tfe_op_internal.h"
25#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
26#include "tensorflow/c/tf_status_helper.h"
27#include "tensorflow/core/common_runtime/composite_device.h"
28#include "tensorflow/core/common_runtime/device.h"
29#include "tensorflow/core/common_runtime/eager/eager_operation.h"
30#include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h"
31#include "tensorflow/core/lib/monitoring/counter.h"
32#include "tensorflow/core/lib/monitoring/gauge.h"
33#include "tensorflow/core/lib/monitoring/sampler.h"
34#include "tensorflow/core/platform/casts.h"
35#include "tensorflow/core/platform/errors.h"
36#include "tensorflow/core/platform/mutex.h"
37#include "tensorflow/core/platform/strcat.h"
38
39using tensorflow::string;
40
41void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name,
42 const char* raw_device_name, TF_Status* status) {
43 if (op_to_reset) {
44 tensorflow::ImmediateExecutionOperation* op =
45 tensorflow::unwrap(op_to_reset);
46 op->Clear();
47 status->status = op->Reset(op_or_function_name, raw_device_name);
48 } else {
49 TF_SetStatus(status, TF_INVALID_ARGUMENT,
50 "op_to_reset should not be nullptr");
51 }
52}
53
54void TFE_ContextEnableGraphCollection(TFE_Context* ctx) {
55 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true);
56}
57
58void TFE_ContextDisableGraphCollection(TFE_Context* ctx) {
59 tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false);
60}
61
62uint64_t TFE_GetContextId(TFE_Context* ctx) {
63 tensorflow::EagerContext* context =
64 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
65 return context->GetContextId();
66}
67
68void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell,
69 int64_t value) {
70 cell->cell.IncrementBy(value);
71}
72
73int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) {
74 return cell->cell.value();
75}
76
77TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name,
78 TF_Status* status,
79 const char* description) {
80 auto* result = new TFE_MonitoringCounter0({name, description});
81 Set_TF_Status_from_Status(status, result->counter->GetStatus());
82 if (!result->counter->GetStatus().ok()) {
83 delete result;
84 return nullptr;
85 }
86 return result;
87}
88
89void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) {
90 delete counter;
91}
92
93TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0(
94 TFE_MonitoringCounter0* counter) {
95 return static_cast<TFE_MonitoringCounterCell*>(
96 static_cast<void*>(counter->counter->GetCell()));
97}
98
99TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name,
100 TF_Status* status,
101 const char* description,
102 const char* label1) {
103 auto* result = new TFE_MonitoringCounter1({name, description, label1});
104 Set_TF_Status_from_Status(status, result->counter->GetStatus());
105 if (!result->counter->GetStatus().ok()) {
106 delete result;
107 return nullptr;
108 }
109 return result;
110}
111
112void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) {
113 delete counter;
114}
115
116TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1(
117 TFE_MonitoringCounter1* counter, const char* label1) {
118 return static_cast<TFE_MonitoringCounterCell*>(
119 static_cast<void*>(counter->counter->GetCell(label1)));
120}
121
122TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name,
123 TF_Status* status,
124 const char* description,
125 const char* label1,
126 const char* label2) {
127 auto* result =
128 new TFE_MonitoringCounter2({name, description, label1, label2});
129 Set_TF_Status_from_Status(status, result->counter->GetStatus());
130 if (!result->counter->GetStatus().ok()) {
131 delete result;
132 return nullptr;
133 }
134 return result;
135}
136
137void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) {
138 delete counter;
139}
140
141TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2(
142 TFE_MonitoringCounter2* counter, const char* label1, const char* label2) {
143 return static_cast<TFE_MonitoringCounterCell*>(
144 static_cast<void*>(counter->counter->GetCell(label1, label2)));
145}
146
147void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell,
148 int64_t value) {
149 cell->cell.Set(value);
150}
151
152int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) {
153 return cell->cell.value();
154}
155
156TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name,
157 TF_Status* status,
158 const char* description) {
159 auto* result = new TFE_MonitoringIntGauge0({name, description});
160 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
161 if (!result->gauge->GetStatus().ok()) {
162 delete result;
163 return nullptr;
164 }
165 return result;
166}
167
168void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) {
169 delete gauge;
170}
171
172TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0(
173 TFE_MonitoringIntGauge0* gauge) {
174 return static_cast<TFE_MonitoringIntGaugeCell*>(
175 static_cast<void*>(gauge->gauge->GetCell()));
176}
177
178TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name,
179 TF_Status* status,
180 const char* description,
181 const char* label1) {
182 auto* result = new TFE_MonitoringIntGauge1({name, description, label1});
183 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
184 if (!result->gauge->GetStatus().ok()) {
185 delete result;
186 return nullptr;
187 }
188 return result;
189}
190
191void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) {
192 delete gauge;
193}
194
195TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1(
196 TFE_MonitoringIntGauge1* gauge, const char* label1) {
197 return static_cast<TFE_MonitoringIntGaugeCell*>(
198 static_cast<void*>(gauge->gauge->GetCell(label1)));
199}
200
201TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name,
202 TF_Status* status,
203 const char* description,
204 const char* label1,
205 const char* label2) {
206 auto* result =
207 new TFE_MonitoringIntGauge2({name, description, label1, label2});
208 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
209 if (!result->gauge->GetStatus().ok()) {
210 delete result;
211 return nullptr;
212 }
213 return result;
214}
215
216void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) {
217 delete gauge;
218}
219
220TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2(
221 TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) {
222 return static_cast<TFE_MonitoringIntGaugeCell*>(
223 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
224}
225
226void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell,
227 const char* value) {
228 cell->cell.Set({value});
229}
230
231const void TFE_MonitoringStringGaugeCellValue(
232 TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) {
233 tensorflow::string value = cell->cell.value();
234 void* data = tensorflow::port::Malloc(value.length());
235 value.copy(static_cast<char*>(data), value.length(), 0);
236 buf->data = data;
237 buf->length = value.length();
238 buf->data_deallocator = [](void* data, size_t length) {
239 tensorflow::port::Free(data);
240 };
241}
242
243TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0(
244 const char* name, TF_Status* status, const char* description) {
245 auto* result = new TFE_MonitoringStringGauge0({name, description});
246 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
247 if (!result->gauge->GetStatus().ok()) {
248 delete result;
249 return nullptr;
250 }
251 return result;
252}
253
254void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) {
255 delete gauge;
256}
257
258TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0(
259 TFE_MonitoringStringGauge0* gauge) {
260 return static_cast<TFE_MonitoringStringGaugeCell*>(
261 static_cast<void*>(gauge->gauge->GetCell()));
262}
263
264TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1(
265 const char* name, TF_Status* status, const char* description,
266 const char* label1) {
267 auto* result = new TFE_MonitoringStringGauge1({name, description, label1});
268 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
269 if (!result->gauge->GetStatus().ok()) {
270 delete result;
271 return nullptr;
272 }
273 return result;
274}
275
276void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) {
277 delete gauge;
278}
279
280TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1(
281 TFE_MonitoringStringGauge1* gauge, const char* label1) {
282 return static_cast<TFE_MonitoringStringGaugeCell*>(
283 static_cast<void*>(gauge->gauge->GetCell(label1)));
284}
285
286TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2(
287 const char* name, TF_Status* status, const char* description,
288 const char* label1, const char* label2) {
289 auto* result =
290 new TFE_MonitoringStringGauge2({name, description, label1, label2});
291 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
292 if (!result->gauge->GetStatus().ok()) {
293 delete result;
294 return nullptr;
295 }
296 return result;
297}
298
299void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) {
300 delete gauge;
301}
302
303TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2(
304 TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) {
305 return static_cast<TFE_MonitoringStringGaugeCell*>(
306 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
307}
308
309TFE_MonitoringStringGauge3* TFE_MonitoringNewStringGauge3(
310 const char* name, TF_Status* status, const char* description,
311 const char* label1, const char* label2, const char* label3) {
312 auto* result = new TFE_MonitoringStringGauge3(
313 {name, description, label1, label2, label3});
314 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
315 if (!result->gauge->GetStatus().ok()) {
316 delete result;
317 return nullptr;
318 }
319 return result;
320}
321
322void TFE_MonitoringDeleteStringGauge3(TFE_MonitoringStringGauge3* gauge) {
323 delete gauge;
324}
325
326TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge3(
327 TFE_MonitoringStringGauge3* gauge, const char* label1, const char* label2,
328 const char* label3) {
329 return static_cast<TFE_MonitoringStringGaugeCell*>(
330 static_cast<void*>(gauge->gauge->GetCell(label1, label2, label3)));
331}
332
333TFE_MonitoringStringGauge4* TFE_MonitoringNewStringGauge4(
334 const char* name, TF_Status* status, const char* description,
335 const char* label1, const char* label2, const char* label3,
336 const char* label4) {
337 auto* result = new TFE_MonitoringStringGauge4(
338 {name, description, label1, label2, label3, label4});
339 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
340 if (!result->gauge->GetStatus().ok()) {
341 delete result;
342 return nullptr;
343 }
344 return result;
345}
346
347void TFE_MonitoringDeleteStringGauge4(TFE_MonitoringStringGauge4* gauge) {
348 delete gauge;
349}
350
351TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge4(
352 TFE_MonitoringStringGauge4* gauge, const char* label1, const char* label2,
353 const char* label3, const char* label4) {
354 return static_cast<TFE_MonitoringStringGaugeCell*>(static_cast<void*>(
355 gauge->gauge->GetCell(label1, label2, label3, label4)));
356}
357
358void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell,
359 bool value) {
360 cell->cell.Set(value);
361}
362
363bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) {
364 return cell->cell.value();
365}
366
367TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name,
368 TF_Status* status,
369 const char* description) {
370 auto* result = new TFE_MonitoringBoolGauge0({name, description});
371 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
372 if (!result->gauge->GetStatus().ok()) {
373 delete result;
374 return nullptr;
375 }
376 return result;
377}
378
379void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) {
380 delete gauge;
381}
382
383TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0(
384 TFE_MonitoringBoolGauge0* gauge) {
385 return static_cast<TFE_MonitoringBoolGaugeCell*>(
386 static_cast<void*>(gauge->gauge->GetCell()));
387}
388
389TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name,
390 TF_Status* status,
391 const char* description,
392 const char* label1) {
393 auto* result = new TFE_MonitoringBoolGauge1({name, description, label1});
394 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
395 if (!result->gauge->GetStatus().ok()) {
396 delete result;
397 return nullptr;
398 }
399 return result;
400}
401
402void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) {
403 delete gauge;
404}
405
406TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1(
407 TFE_MonitoringBoolGauge1* gauge, const char* label1) {
408 return static_cast<TFE_MonitoringBoolGaugeCell*>(
409 static_cast<void*>(gauge->gauge->GetCell(label1)));
410}
411
412TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name,
413 TF_Status* status,
414 const char* description,
415 const char* label1,
416 const char* label2) {
417 auto* result =
418 new TFE_MonitoringBoolGauge2({name, description, label1, label2});
419 Set_TF_Status_from_Status(status, result->gauge->GetStatus());
420 if (!result->gauge->GetStatus().ok()) {
421 delete result;
422 return nullptr;
423 }
424 return result;
425}
426
427void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) {
428 delete gauge;
429}
430
431TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2(
432 TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) {
433 return static_cast<TFE_MonitoringBoolGaugeCell*>(
434 static_cast<void*>(gauge->gauge->GetCell(label1, label2)));
435}
436
437void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell,
438 double value) {
439 cell->cell.Add(value);
440}
441
442void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell,
443 TF_Buffer* buf) {
444 string content;
445 cell->cell.value().SerializeToString(&content);
446 void* data = tensorflow::port::Malloc(content.length());
447 content.copy(static_cast<char*>(data), content.length(), 0);
448 buf->data = data;
449 buf->length = content.length();
450 buf->data_deallocator = [](void* data, size_t length) {
451 tensorflow::port::Free(data);
452 };
453}
454
455TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale,
456 double growth_factor,
457 int bucket_count) {
458 return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() {
459 return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor,
460 bucket_count);
461 });
462}
463
464void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) {
465 delete buckets;
466}
467
468TFE_MonitoringSampler0* TFE_MonitoringNewSampler0(
469 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
470 const char* description) {
471 auto* result = new TFE_MonitoringSampler0(
472 {name, buckets->create_buckets(), description});
473 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
474 if (!result->sampler->GetStatus().ok()) {
475 delete result;
476 return nullptr;
477 }
478 return result;
479}
480
481void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) {
482 delete sampler;
483}
484
485TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0(
486 TFE_MonitoringSampler0* sampler) {
487 return static_cast<TFE_MonitoringSamplerCell*>(
488 static_cast<void*>(sampler->sampler->GetCell()));
489}
490
491TFE_MonitoringSampler1* TFE_MonitoringNewSampler1(
492 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
493 const char* description, const char* label1) {
494 auto* result = new TFE_MonitoringSampler1(
495 {name, buckets->create_buckets(), description, label1});
496 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
497 if (!result->sampler->GetStatus().ok()) {
498 delete result;
499 return nullptr;
500 }
501 return result;
502}
503
504void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) {
505 delete sampler;
506}
507
508TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1(
509 TFE_MonitoringSampler1* sampler, const char* label1) {
510 return static_cast<TFE_MonitoringSamplerCell*>(
511 static_cast<void*>(sampler->sampler->GetCell(label1)));
512}
513
514TFE_MonitoringSampler2* TFE_MonitoringNewSampler2(
515 const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status,
516 const char* description, const char* label1, const char* label2) {
517 auto* result = new TFE_MonitoringSampler2(
518 {name, buckets->create_buckets(), description, label1, label2});
519 Set_TF_Status_from_Status(status, result->sampler->GetStatus());
520 if (!result->sampler->GetStatus().ok()) {
521 delete result;
522 return nullptr;
523 }
524 return result;
525}
526
527void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) {
528 delete sampler;
529}
530
531TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2(
532 TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) {
533 return static_cast<TFE_MonitoringSamplerCell*>(
534 static_cast<void*>(sampler->sampler->GetCell(label1, label2)));
535}
536
537void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) {
538 options->use_tfrt = use_tfrt;
539}
540
541void TFE_ContextOptionsSetTfrtDistributedRuntime(
542 TFE_ContextOptions* options, bool use_tfrt_distributed_runtime) {
543 options->use_tfrt_distributed_runtime = use_tfrt_distributed_runtime;
544}
545
546TFE_CancellationManager* TFE_NewCancellationManager() {
547 return tensorflow::wrap(new tensorflow::CancellationManager);
548}
549
550void TFE_CancellationManagerStartCancel(
551 TFE_CancellationManager* cancellation_manager) {
552 tensorflow::unwrap(cancellation_manager)->StartCancel();
553}
554
555bool TFE_CancellationManagerIsCancelled(
556 TFE_CancellationManager* cancellation_manager) {
557 return tensorflow::unwrap(cancellation_manager)->IsCancelled();
558}
559
560void TFE_DeleteCancellationManager(
561 TFE_CancellationManager* cancellation_manager) {
562 delete tensorflow::unwrap(cancellation_manager);
563}
564
565void TFE_OpSetCancellationManager(TFE_Op* op,
566 TFE_CancellationManager* cancellation_manager,
567 TF_Status* status) {
568 tensorflow::unwrap(op)->SetCancellationManager(
569 tensorflow::unwrap(cancellation_manager));
570 status->status = ::tensorflow::OkStatus();
571}
572
573TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue) {
574 return new TFE_Executor(is_async, enable_streaming_enqueue);
575}
576
577void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; }
578
579bool TFE_ExecutorIsAsync(TFE_Executor* executor) {
580 return executor->executor()->Async();
581}
582
583void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor,
584 TF_Status* status) {
585 status->status = executor->executor()->WaitForAllPendingNodes();
586}
587
588void TFE_ExecutorClearError(TFE_Executor* executor) {
589 executor->executor()->ClearError();
590}
591
592void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) {
593 tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor());
594}
595
596TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) {
597 return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor());
598}
599
600void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) {
601 auto address_space = tensorflow::DeviceNameUtils::AddressSpace(
602 tensorflow::unwrap(ctx)->HostCPUParsedName());
603 auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space);
604 void* data = tensorflow::port::Malloc(str.length());
605 str.copy(static_cast<char*>(data), str.length(), 0);
606 buf->data = data;
607 buf->length = str.length();
608 buf->data_deallocator = [](void* data, size_t length) {
609 tensorflow::port::Free(data);
610 };
611}
612
613void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name,
614 TF_Buffer* buf, TF_Status* status) {
615 auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name);
616 if (function_def == nullptr) {
617 status->status = tensorflow::errors::NotFound(
618 "Unable to find FunctionDef with name: ", function_name);
619 return;
620 }
621 string str = function_def->SerializeAsString();
622 void* data = tensorflow::port::Malloc(str.length());
623 str.copy(static_cast<char*>(data), str.length(), 0);
624 buf->data = data;
625 buf->length = str.length();
626 buf->data_deallocator = [](void* data, size_t length) {
627 tensorflow::port::Free(data);
628 };
629 status->status = ::tensorflow::OkStatus();
630}
631
632TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype,
633 const int64_t* dims, int num_dims,
634 TF_Status* status) {
635 std::vector<int64_t> dimvec(num_dims);
636 for (int i = 0; i < num_dims; ++i) {
637 dimvec[i] = static_cast<int64_t>(dims[i]);
638 }
639
640 if (ctx == nullptr) {
641 status->status = tensorflow::errors::InvalidArgument("Invalid Context");
642 return nullptr;
643 }
644
645 tensorflow::AbstractTensorInterface* t =
646 tensorflow::unwrap(ctx)->CreateTensor(
647 static_cast<tensorflow::DataType>(dtype), dimvec);
648
649 if (t == nullptr) {
650 status->status =
651 tensorflow::errors::InvalidArgument("Unsupported dtype: ", dtype);
652 return nullptr;
653 }
654
655 return new TF_Tensor{t};
656}
657
658TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t,
659 TF_Status* status) {
660 return tensorflow::wrap(
661 tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor));
662}
663
664TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx,
665 TFE_TensorHandle** handles,
666 int* num_handles,
667 TF_Status* status) {
668 std::vector<tensorflow::TensorHandle*> tensor_handles;
669 tensor_handles.reserve(*num_handles);
670 for (int i = 0; i < *num_handles; ++i) {
671 tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle =
672 tensorflow::unwrap(handles[i]);
673 if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) {
674 // One of the inputs we're trying to pack is on a custom device. We'll let
675 // the first custom device we see handle all of the packing.
676 auto* custom_device_handle =
677 tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>(
678 unwrapped_handle);
679 tensorflow::ImmediateExecutionTensorHandle* result;
680 status->status = custom_device_handle->device()->Pack(
681 absl::Span<tensorflow::ImmediateExecutionTensorHandle*>(
682 tensorflow::unwrap(handles), *num_handles),
683 &result);
684 return tensorflow::wrap(result);
685 }
686 tensor_handles.push_back(
687 tensorflow::TensorHandleFromInterface(unwrapped_handle));
688 }
689 tensorflow::EagerContext* context =
690 tensorflow::ContextFromInterface(tensorflow::unwrap(ctx));
691 tensorflow::TensorHandle* handle = nullptr;
692 status->status = tensorflow::TensorHandle::CreatePackedHandle(
693 std::move(tensor_handles), context, &handle);
694 return tensorflow::wrap(handle);
695}
696
697void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable,
698 TF_Status* status) {
699 tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable);
700}
701
702void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable,
703 TF_Status* status) {
704 tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable);
705}
706
707void TFE_ContextSetRunEagerOpAsFunction(TFE_Context* ctx, unsigned char enable,
708 TF_Status* status) {
709 tensorflow::unwrap(ctx)->SetRunEagerOpAsFunction(enable);
710}
711
712void TFE_ContextSetJitCompileRewrite(TFE_Context* ctx, unsigned char enable,
713 TF_Status* status) {
714 tensorflow::unwrap(ctx)->SetJitCompileRewrite(enable);
715}
716
717const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) {
718 if (h == nullptr) {
719 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
720 return nullptr;
721 }
722 return tensorflow::unwrap(h)->DeviceType(&status->status);
723}
724
725int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) {
726 if (h == nullptr) {
727 status->status = tensorflow::errors::InvalidArgument("Invalid handle");
728 return -1;
729 }
730 return tensorflow::unwrap(h)->DeviceId(&status->status);
731}
732
733TF_CAPI_EXPORT extern void TFE_TensorHandleGetStatus(TFE_TensorHandle* h,
734 TF_Status* status) {
735 status->status = tensorflow::unwrap(h)->TensorHandleStatus();
736}
737
738void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf,
739 TF_Status* status) {
740 const std::vector<std::string>& op_names =
741 tensorflow::unwrap(ctx)->GetLoggedOpsTestonly();
742
743 std::ostringstream op_names_oss;
744 for (const auto& op : op_names) {
745 op_names_oss << op << ", ";
746 }
747 const std::string& op_names_str = op_names_oss.str();
748 void* data = tensorflow::port::Malloc(op_names_str.length());
749 op_names_str.copy(static_cast<char*>(data), op_names_str.length(), 0);
750 buf->data = data;
751 buf->length = op_names_str.length();
752 buf->data_deallocator = [](void* data, size_t length) {
753 tensorflow::port::Free(data);
754 };
755 status->status = ::tensorflow::OkStatus();
756}
757
758void TFE_SetLogicalCpuDevices(TFE_Context* ctx, int num_cpus,
759 const char* prefix, TF_Status* status) {
760 std::vector<std::unique_ptr<tensorflow::Device>> devices;
761
762 if (prefix == nullptr || strlen(prefix) == 0)
763 prefix = "/job:localhost/replica:0/task:0";
764
765 tensorflow::SessionOptions sess_options;
766 (*sess_options.config.mutable_device_count())["CPU"] = num_cpus;
767 status->status =
768 tensorflow::DeviceFactory::AddCpuDevices(sess_options, prefix, &devices);
769
770 // Remove the device that has the host device name since host device is alreay
771 // in an initialized context.
772 for (auto d = devices.begin(); d != devices.end();) {
773 if (absl::StrContains(d->get()->name(), "CPU:0")) {
774 d = devices.erase(d);
775 } else {
776 ++d;
777 }
778 }
779
780 status->status = tensorflow::unwrap(ctx)->AddDevices(std::move(devices));
781}
782
783void TFE_InsertConfigKeyValue(TFE_Context* ctx, const char* key,
784 const char* value, TF_Status* status) {
785 tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
786 tensorflow::unwrap(ctx)->GetDistributedManager();
787 tensorflow::CoordinationServiceAgent* coord_agent =
788 dist_mgr->GetCoordinationServiceAgent();
789 if (coord_agent == nullptr) {
790 status->status = tensorflow::errors::FailedPrecondition(
791 "Coordination service agent is not enabled.");
792 return;
793 }
794 status->status = coord_agent->InsertKeyValue(key, value);
795}
796
797void TFE_GetConfigKeyValue(TFE_Context* ctx, const char* key,
798 TF_Buffer* value_buf, TF_Status* status) {
799 tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
800 tensorflow::unwrap(ctx)->GetDistributedManager();
801 tensorflow::CoordinationServiceAgent* coord_agent =
802 dist_mgr->GetCoordinationServiceAgent();
803 if (coord_agent == nullptr) {
804 status->status = tensorflow::errors::FailedPrecondition(
805 "Coordination service is not enabled.");
806 return;
807 }
808 auto status_or_value = coord_agent->GetKeyValue(key);
809 status->status = status_or_value.status();
810 if (!status_or_value.ok()) return;
811
812 const std::string& value_string = status_or_value.value();
813 void* data = tensorflow::port::Malloc(value_string.length());
814 value_string.copy(static_cast<char*>(data), value_string.length(), 0);
815 value_buf->data = data;
816 value_buf->length = value_string.length();
817 value_buf->data_deallocator = [](void* data, size_t length) {
818 tensorflow::port::Free(data);
819 };
820}
821
822void TFE_DeleteConfigKeyValue(TFE_Context* ctx, const char* key,
823 TF_Status* status) {
824 tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
825 tensorflow::unwrap(ctx)->GetDistributedManager();
826 tensorflow::CoordinationServiceAgent* coord_agent =
827 dist_mgr->GetCoordinationServiceAgent();
828 if (coord_agent == nullptr) {
829 status->status = tensorflow::errors::FailedPrecondition(
830 "Coordination service is not enabled.");
831 return;
832 }
833 status->status = coord_agent->DeleteKeyValue(key);
834}
835
836void TFE_ReportErrorToCluster(TFE_Context* ctx, int error_code,
837 const char* error_message, TF_Status* status) {
838 tensorflow::ImmediateExecutionDistributedManager* dist_mgr =
839 tensorflow::unwrap(ctx)->GetDistributedManager();
840 tensorflow::CoordinationServiceAgent* coord_agent =
841 dist_mgr->GetCoordinationServiceAgent();
842 if (coord_agent == nullptr) {
843 status->status = tensorflow::errors::FailedPrecondition(
844 "Coordination service is not enabled.");
845 return;
846 }
847 tensorflow::Status s(static_cast<tensorflow::error::Code>(error_code),
848 error_message);
849 status->status = coord_agent->ReportError(s);
850}
851