1 | /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/c/eager/c_api_experimental.h" |
17 | |
18 | #include <vector> |
19 | |
20 | #include "absl/strings/match.h" |
21 | #include "tensorflow/c/c_api.h" |
22 | #include "tensorflow/c/eager/c_api_internal.h" |
23 | #include "tensorflow/c/eager/tfe_context_internal.h" |
24 | #include "tensorflow/c/eager/tfe_op_internal.h" |
25 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
26 | #include "tensorflow/c/tf_status_helper.h" |
27 | #include "tensorflow/core/common_runtime/composite_device.h" |
28 | #include "tensorflow/core/common_runtime/device.h" |
29 | #include "tensorflow/core/common_runtime/eager/eager_operation.h" |
30 | #include "tensorflow/core/distributed_runtime/coordination/coordination_service_agent.h" |
31 | #include "tensorflow/core/lib/monitoring/counter.h" |
32 | #include "tensorflow/core/lib/monitoring/gauge.h" |
33 | #include "tensorflow/core/lib/monitoring/sampler.h" |
34 | #include "tensorflow/core/platform/casts.h" |
35 | #include "tensorflow/core/platform/errors.h" |
36 | #include "tensorflow/core/platform/mutex.h" |
37 | #include "tensorflow/core/platform/strcat.h" |
38 | |
39 | using tensorflow::string; |
40 | |
41 | void TFE_OpReset(TFE_Op* op_to_reset, const char* op_or_function_name, |
42 | const char* raw_device_name, TF_Status* status) { |
43 | if (op_to_reset) { |
44 | tensorflow::ImmediateExecutionOperation* op = |
45 | tensorflow::unwrap(op_to_reset); |
46 | op->Clear(); |
47 | status->status = op->Reset(op_or_function_name, raw_device_name); |
48 | } else { |
49 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
50 | "op_to_reset should not be nullptr" ); |
51 | } |
52 | } |
53 | |
54 | void TFE_ContextEnableGraphCollection(TFE_Context* ctx) { |
55 | tensorflow::unwrap(ctx)->SetShouldStoreGraphs(true); |
56 | } |
57 | |
58 | void TFE_ContextDisableGraphCollection(TFE_Context* ctx) { |
59 | tensorflow::unwrap(ctx)->SetShouldStoreGraphs(false); |
60 | } |
61 | |
62 | uint64_t TFE_GetContextId(TFE_Context* ctx) { |
63 | tensorflow::EagerContext* context = |
64 | tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); |
65 | return context->GetContextId(); |
66 | } |
67 | |
68 | void TFE_MonitoringCounterCellIncrementBy(TFE_MonitoringCounterCell* cell, |
69 | int64_t value) { |
70 | cell->cell.IncrementBy(value); |
71 | } |
72 | |
73 | int64_t TFE_MonitoringCounterCellValue(TFE_MonitoringCounterCell* cell) { |
74 | return cell->cell.value(); |
75 | } |
76 | |
77 | TFE_MonitoringCounter0* TFE_MonitoringNewCounter0(const char* name, |
78 | TF_Status* status, |
79 | const char* description) { |
80 | auto* result = new TFE_MonitoringCounter0({name, description}); |
81 | Set_TF_Status_from_Status(status, result->counter->GetStatus()); |
82 | if (!result->counter->GetStatus().ok()) { |
83 | delete result; |
84 | return nullptr; |
85 | } |
86 | return result; |
87 | } |
88 | |
89 | void TFE_MonitoringDeleteCounter0(TFE_MonitoringCounter0* counter) { |
90 | delete counter; |
91 | } |
92 | |
93 | TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter0( |
94 | TFE_MonitoringCounter0* counter) { |
95 | return static_cast<TFE_MonitoringCounterCell*>( |
96 | static_cast<void*>(counter->counter->GetCell())); |
97 | } |
98 | |
99 | TFE_MonitoringCounter1* TFE_MonitoringNewCounter1(const char* name, |
100 | TF_Status* status, |
101 | const char* description, |
102 | const char* label1) { |
103 | auto* result = new TFE_MonitoringCounter1({name, description, label1}); |
104 | Set_TF_Status_from_Status(status, result->counter->GetStatus()); |
105 | if (!result->counter->GetStatus().ok()) { |
106 | delete result; |
107 | return nullptr; |
108 | } |
109 | return result; |
110 | } |
111 | |
112 | void TFE_MonitoringDeleteCounter1(TFE_MonitoringCounter1* counter) { |
113 | delete counter; |
114 | } |
115 | |
116 | TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter1( |
117 | TFE_MonitoringCounter1* counter, const char* label1) { |
118 | return static_cast<TFE_MonitoringCounterCell*>( |
119 | static_cast<void*>(counter->counter->GetCell(label1))); |
120 | } |
121 | |
122 | TFE_MonitoringCounter2* TFE_MonitoringNewCounter2(const char* name, |
123 | TF_Status* status, |
124 | const char* description, |
125 | const char* label1, |
126 | const char* label2) { |
127 | auto* result = |
128 | new TFE_MonitoringCounter2({name, description, label1, label2}); |
129 | Set_TF_Status_from_Status(status, result->counter->GetStatus()); |
130 | if (!result->counter->GetStatus().ok()) { |
131 | delete result; |
132 | return nullptr; |
133 | } |
134 | return result; |
135 | } |
136 | |
137 | void TFE_MonitoringDeleteCounter2(TFE_MonitoringCounter2* counter) { |
138 | delete counter; |
139 | } |
140 | |
141 | TFE_MonitoringCounterCell* TFE_MonitoringGetCellCounter2( |
142 | TFE_MonitoringCounter2* counter, const char* label1, const char* label2) { |
143 | return static_cast<TFE_MonitoringCounterCell*>( |
144 | static_cast<void*>(counter->counter->GetCell(label1, label2))); |
145 | } |
146 | |
147 | void TFE_MonitoringIntGaugeCellSet(TFE_MonitoringIntGaugeCell* cell, |
148 | int64_t value) { |
149 | cell->cell.Set(value); |
150 | } |
151 | |
152 | int64_t TFE_MonitoringIntGaugeCellValue(TFE_MonitoringIntGaugeCell* cell) { |
153 | return cell->cell.value(); |
154 | } |
155 | |
156 | TFE_MonitoringIntGauge0* TFE_MonitoringNewIntGauge0(const char* name, |
157 | TF_Status* status, |
158 | const char* description) { |
159 | auto* result = new TFE_MonitoringIntGauge0({name, description}); |
160 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
161 | if (!result->gauge->GetStatus().ok()) { |
162 | delete result; |
163 | return nullptr; |
164 | } |
165 | return result; |
166 | } |
167 | |
168 | void TFE_MonitoringDeleteIntGauge0(TFE_MonitoringIntGauge0* gauge) { |
169 | delete gauge; |
170 | } |
171 | |
172 | TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge0( |
173 | TFE_MonitoringIntGauge0* gauge) { |
174 | return static_cast<TFE_MonitoringIntGaugeCell*>( |
175 | static_cast<void*>(gauge->gauge->GetCell())); |
176 | } |
177 | |
178 | TFE_MonitoringIntGauge1* TFE_MonitoringNewIntGauge1(const char* name, |
179 | TF_Status* status, |
180 | const char* description, |
181 | const char* label1) { |
182 | auto* result = new TFE_MonitoringIntGauge1({name, description, label1}); |
183 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
184 | if (!result->gauge->GetStatus().ok()) { |
185 | delete result; |
186 | return nullptr; |
187 | } |
188 | return result; |
189 | } |
190 | |
191 | void TFE_MonitoringDeleteIntGauge1(TFE_MonitoringIntGauge1* gauge) { |
192 | delete gauge; |
193 | } |
194 | |
195 | TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge1( |
196 | TFE_MonitoringIntGauge1* gauge, const char* label1) { |
197 | return static_cast<TFE_MonitoringIntGaugeCell*>( |
198 | static_cast<void*>(gauge->gauge->GetCell(label1))); |
199 | } |
200 | |
201 | TFE_MonitoringIntGauge2* TFE_MonitoringNewIntGauge2(const char* name, |
202 | TF_Status* status, |
203 | const char* description, |
204 | const char* label1, |
205 | const char* label2) { |
206 | auto* result = |
207 | new TFE_MonitoringIntGauge2({name, description, label1, label2}); |
208 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
209 | if (!result->gauge->GetStatus().ok()) { |
210 | delete result; |
211 | return nullptr; |
212 | } |
213 | return result; |
214 | } |
215 | |
216 | void TFE_MonitoringDeleteIntGauge2(TFE_MonitoringIntGauge2* gauge) { |
217 | delete gauge; |
218 | } |
219 | |
220 | TFE_MonitoringIntGaugeCell* TFE_MonitoringGetCellIntGauge2( |
221 | TFE_MonitoringIntGauge2* gauge, const char* label1, const char* label2) { |
222 | return static_cast<TFE_MonitoringIntGaugeCell*>( |
223 | static_cast<void*>(gauge->gauge->GetCell(label1, label2))); |
224 | } |
225 | |
226 | void TFE_MonitoringStringGaugeCellSet(TFE_MonitoringStringGaugeCell* cell, |
227 | const char* value) { |
228 | cell->cell.Set({value}); |
229 | } |
230 | |
231 | const void TFE_MonitoringStringGaugeCellValue( |
232 | TFE_MonitoringStringGaugeCell* cell, TF_Buffer* buf) { |
233 | tensorflow::string value = cell->cell.value(); |
234 | void* data = tensorflow::port::Malloc(value.length()); |
235 | value.copy(static_cast<char*>(data), value.length(), 0); |
236 | buf->data = data; |
237 | buf->length = value.length(); |
238 | buf->data_deallocator = [](void* data, size_t length) { |
239 | tensorflow::port::Free(data); |
240 | }; |
241 | } |
242 | |
243 | TFE_MonitoringStringGauge0* TFE_MonitoringNewStringGauge0( |
244 | const char* name, TF_Status* status, const char* description) { |
245 | auto* result = new TFE_MonitoringStringGauge0({name, description}); |
246 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
247 | if (!result->gauge->GetStatus().ok()) { |
248 | delete result; |
249 | return nullptr; |
250 | } |
251 | return result; |
252 | } |
253 | |
254 | void TFE_MonitoringDeleteStringGauge0(TFE_MonitoringStringGauge0* gauge) { |
255 | delete gauge; |
256 | } |
257 | |
258 | TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge0( |
259 | TFE_MonitoringStringGauge0* gauge) { |
260 | return static_cast<TFE_MonitoringStringGaugeCell*>( |
261 | static_cast<void*>(gauge->gauge->GetCell())); |
262 | } |
263 | |
264 | TFE_MonitoringStringGauge1* TFE_MonitoringNewStringGauge1( |
265 | const char* name, TF_Status* status, const char* description, |
266 | const char* label1) { |
267 | auto* result = new TFE_MonitoringStringGauge1({name, description, label1}); |
268 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
269 | if (!result->gauge->GetStatus().ok()) { |
270 | delete result; |
271 | return nullptr; |
272 | } |
273 | return result; |
274 | } |
275 | |
276 | void TFE_MonitoringDeleteStringGauge1(TFE_MonitoringStringGauge1* gauge) { |
277 | delete gauge; |
278 | } |
279 | |
280 | TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge1( |
281 | TFE_MonitoringStringGauge1* gauge, const char* label1) { |
282 | return static_cast<TFE_MonitoringStringGaugeCell*>( |
283 | static_cast<void*>(gauge->gauge->GetCell(label1))); |
284 | } |
285 | |
286 | TFE_MonitoringStringGauge2* TFE_MonitoringNewStringGauge2( |
287 | const char* name, TF_Status* status, const char* description, |
288 | const char* label1, const char* label2) { |
289 | auto* result = |
290 | new TFE_MonitoringStringGauge2({name, description, label1, label2}); |
291 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
292 | if (!result->gauge->GetStatus().ok()) { |
293 | delete result; |
294 | return nullptr; |
295 | } |
296 | return result; |
297 | } |
298 | |
299 | void TFE_MonitoringDeleteStringGauge2(TFE_MonitoringStringGauge2* gauge) { |
300 | delete gauge; |
301 | } |
302 | |
303 | TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge2( |
304 | TFE_MonitoringStringGauge2* gauge, const char* label1, const char* label2) { |
305 | return static_cast<TFE_MonitoringStringGaugeCell*>( |
306 | static_cast<void*>(gauge->gauge->GetCell(label1, label2))); |
307 | } |
308 | |
309 | TFE_MonitoringStringGauge3* TFE_MonitoringNewStringGauge3( |
310 | const char* name, TF_Status* status, const char* description, |
311 | const char* label1, const char* label2, const char* label3) { |
312 | auto* result = new TFE_MonitoringStringGauge3( |
313 | {name, description, label1, label2, label3}); |
314 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
315 | if (!result->gauge->GetStatus().ok()) { |
316 | delete result; |
317 | return nullptr; |
318 | } |
319 | return result; |
320 | } |
321 | |
322 | void TFE_MonitoringDeleteStringGauge3(TFE_MonitoringStringGauge3* gauge) { |
323 | delete gauge; |
324 | } |
325 | |
326 | TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge3( |
327 | TFE_MonitoringStringGauge3* gauge, const char* label1, const char* label2, |
328 | const char* label3) { |
329 | return static_cast<TFE_MonitoringStringGaugeCell*>( |
330 | static_cast<void*>(gauge->gauge->GetCell(label1, label2, label3))); |
331 | } |
332 | |
333 | TFE_MonitoringStringGauge4* TFE_MonitoringNewStringGauge4( |
334 | const char* name, TF_Status* status, const char* description, |
335 | const char* label1, const char* label2, const char* label3, |
336 | const char* label4) { |
337 | auto* result = new TFE_MonitoringStringGauge4( |
338 | {name, description, label1, label2, label3, label4}); |
339 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
340 | if (!result->gauge->GetStatus().ok()) { |
341 | delete result; |
342 | return nullptr; |
343 | } |
344 | return result; |
345 | } |
346 | |
347 | void TFE_MonitoringDeleteStringGauge4(TFE_MonitoringStringGauge4* gauge) { |
348 | delete gauge; |
349 | } |
350 | |
351 | TFE_MonitoringStringGaugeCell* TFE_MonitoringGetCellStringGauge4( |
352 | TFE_MonitoringStringGauge4* gauge, const char* label1, const char* label2, |
353 | const char* label3, const char* label4) { |
354 | return static_cast<TFE_MonitoringStringGaugeCell*>(static_cast<void*>( |
355 | gauge->gauge->GetCell(label1, label2, label3, label4))); |
356 | } |
357 | |
358 | void TFE_MonitoringBoolGaugeCellSet(TFE_MonitoringBoolGaugeCell* cell, |
359 | bool value) { |
360 | cell->cell.Set(value); |
361 | } |
362 | |
363 | bool TFE_MonitoringBoolGaugeCellValue(TFE_MonitoringBoolGaugeCell* cell) { |
364 | return cell->cell.value(); |
365 | } |
366 | |
367 | TFE_MonitoringBoolGauge0* TFE_MonitoringNewBoolGauge0(const char* name, |
368 | TF_Status* status, |
369 | const char* description) { |
370 | auto* result = new TFE_MonitoringBoolGauge0({name, description}); |
371 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
372 | if (!result->gauge->GetStatus().ok()) { |
373 | delete result; |
374 | return nullptr; |
375 | } |
376 | return result; |
377 | } |
378 | |
379 | void TFE_MonitoringDeleteBoolGauge0(TFE_MonitoringBoolGauge0* gauge) { |
380 | delete gauge; |
381 | } |
382 | |
383 | TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge0( |
384 | TFE_MonitoringBoolGauge0* gauge) { |
385 | return static_cast<TFE_MonitoringBoolGaugeCell*>( |
386 | static_cast<void*>(gauge->gauge->GetCell())); |
387 | } |
388 | |
389 | TFE_MonitoringBoolGauge1* TFE_MonitoringNewBoolGauge1(const char* name, |
390 | TF_Status* status, |
391 | const char* description, |
392 | const char* label1) { |
393 | auto* result = new TFE_MonitoringBoolGauge1({name, description, label1}); |
394 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
395 | if (!result->gauge->GetStatus().ok()) { |
396 | delete result; |
397 | return nullptr; |
398 | } |
399 | return result; |
400 | } |
401 | |
402 | void TFE_MonitoringDeleteBoolGauge1(TFE_MonitoringBoolGauge1* gauge) { |
403 | delete gauge; |
404 | } |
405 | |
406 | TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge1( |
407 | TFE_MonitoringBoolGauge1* gauge, const char* label1) { |
408 | return static_cast<TFE_MonitoringBoolGaugeCell*>( |
409 | static_cast<void*>(gauge->gauge->GetCell(label1))); |
410 | } |
411 | |
412 | TFE_MonitoringBoolGauge2* TFE_MonitoringNewBoolGauge2(const char* name, |
413 | TF_Status* status, |
414 | const char* description, |
415 | const char* label1, |
416 | const char* label2) { |
417 | auto* result = |
418 | new TFE_MonitoringBoolGauge2({name, description, label1, label2}); |
419 | Set_TF_Status_from_Status(status, result->gauge->GetStatus()); |
420 | if (!result->gauge->GetStatus().ok()) { |
421 | delete result; |
422 | return nullptr; |
423 | } |
424 | return result; |
425 | } |
426 | |
427 | void TFE_MonitoringDeleteBoolGauge2(TFE_MonitoringBoolGauge2* gauge) { |
428 | delete gauge; |
429 | } |
430 | |
431 | TFE_MonitoringBoolGaugeCell* TFE_MonitoringGetCellBoolGauge2( |
432 | TFE_MonitoringBoolGauge2* gauge, const char* label1, const char* label2) { |
433 | return static_cast<TFE_MonitoringBoolGaugeCell*>( |
434 | static_cast<void*>(gauge->gauge->GetCell(label1, label2))); |
435 | } |
436 | |
437 | void TFE_MonitoringSamplerCellAdd(TFE_MonitoringSamplerCell* cell, |
438 | double value) { |
439 | cell->cell.Add(value); |
440 | } |
441 | |
442 | void TFE_MonitoringSamplerCellValue(TFE_MonitoringSamplerCell* cell, |
443 | TF_Buffer* buf) { |
444 | string content; |
445 | cell->cell.value().SerializeToString(&content); |
446 | void* data = tensorflow::port::Malloc(content.length()); |
447 | content.copy(static_cast<char*>(data), content.length(), 0); |
448 | buf->data = data; |
449 | buf->length = content.length(); |
450 | buf->data_deallocator = [](void* data, size_t length) { |
451 | tensorflow::port::Free(data); |
452 | }; |
453 | } |
454 | |
455 | TFE_MonitoringBuckets* TFE_MonitoringNewExponentialBuckets(double scale, |
456 | double growth_factor, |
457 | int bucket_count) { |
458 | return new TFE_MonitoringBuckets([scale, growth_factor, bucket_count]() { |
459 | return tensorflow::monitoring::Buckets::Exponential(scale, growth_factor, |
460 | bucket_count); |
461 | }); |
462 | } |
463 | |
464 | void TFE_MonitoringDeleteBuckets(TFE_MonitoringBuckets* buckets) { |
465 | delete buckets; |
466 | } |
467 | |
468 | TFE_MonitoringSampler0* TFE_MonitoringNewSampler0( |
469 | const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status, |
470 | const char* description) { |
471 | auto* result = new TFE_MonitoringSampler0( |
472 | {name, buckets->create_buckets(), description}); |
473 | Set_TF_Status_from_Status(status, result->sampler->GetStatus()); |
474 | if (!result->sampler->GetStatus().ok()) { |
475 | delete result; |
476 | return nullptr; |
477 | } |
478 | return result; |
479 | } |
480 | |
481 | void TFE_MonitoringDeleteSampler0(TFE_MonitoringSampler0* sampler) { |
482 | delete sampler; |
483 | } |
484 | |
485 | TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler0( |
486 | TFE_MonitoringSampler0* sampler) { |
487 | return static_cast<TFE_MonitoringSamplerCell*>( |
488 | static_cast<void*>(sampler->sampler->GetCell())); |
489 | } |
490 | |
491 | TFE_MonitoringSampler1* TFE_MonitoringNewSampler1( |
492 | const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status, |
493 | const char* description, const char* label1) { |
494 | auto* result = new TFE_MonitoringSampler1( |
495 | {name, buckets->create_buckets(), description, label1}); |
496 | Set_TF_Status_from_Status(status, result->sampler->GetStatus()); |
497 | if (!result->sampler->GetStatus().ok()) { |
498 | delete result; |
499 | return nullptr; |
500 | } |
501 | return result; |
502 | } |
503 | |
504 | void TFE_MonitoringDeleteSampler1(TFE_MonitoringSampler1* sampler) { |
505 | delete sampler; |
506 | } |
507 | |
508 | TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler1( |
509 | TFE_MonitoringSampler1* sampler, const char* label1) { |
510 | return static_cast<TFE_MonitoringSamplerCell*>( |
511 | static_cast<void*>(sampler->sampler->GetCell(label1))); |
512 | } |
513 | |
514 | TFE_MonitoringSampler2* TFE_MonitoringNewSampler2( |
515 | const char* name, TFE_MonitoringBuckets* buckets, TF_Status* status, |
516 | const char* description, const char* label1, const char* label2) { |
517 | auto* result = new TFE_MonitoringSampler2( |
518 | {name, buckets->create_buckets(), description, label1, label2}); |
519 | Set_TF_Status_from_Status(status, result->sampler->GetStatus()); |
520 | if (!result->sampler->GetStatus().ok()) { |
521 | delete result; |
522 | return nullptr; |
523 | } |
524 | return result; |
525 | } |
526 | |
527 | void TFE_MonitoringDeleteSampler2(TFE_MonitoringSampler2* sampler) { |
528 | delete sampler; |
529 | } |
530 | |
531 | TFE_MonitoringSamplerCell* TFE_MonitoringGetCellSampler2( |
532 | TFE_MonitoringSampler2* sampler, const char* label1, const char* label2) { |
533 | return static_cast<TFE_MonitoringSamplerCell*>( |
534 | static_cast<void*>(sampler->sampler->GetCell(label1, label2))); |
535 | } |
536 | |
537 | void TFE_ContextOptionsSetTfrt(TFE_ContextOptions* options, bool use_tfrt) { |
538 | options->use_tfrt = use_tfrt; |
539 | } |
540 | |
541 | void TFE_ContextOptionsSetTfrtDistributedRuntime( |
542 | TFE_ContextOptions* options, bool use_tfrt_distributed_runtime) { |
543 | options->use_tfrt_distributed_runtime = use_tfrt_distributed_runtime; |
544 | } |
545 | |
546 | TFE_CancellationManager* TFE_NewCancellationManager() { |
547 | return tensorflow::wrap(new tensorflow::CancellationManager); |
548 | } |
549 | |
550 | void TFE_CancellationManagerStartCancel( |
551 | TFE_CancellationManager* cancellation_manager) { |
552 | tensorflow::unwrap(cancellation_manager)->StartCancel(); |
553 | } |
554 | |
555 | bool TFE_CancellationManagerIsCancelled( |
556 | TFE_CancellationManager* cancellation_manager) { |
557 | return tensorflow::unwrap(cancellation_manager)->IsCancelled(); |
558 | } |
559 | |
560 | void TFE_DeleteCancellationManager( |
561 | TFE_CancellationManager* cancellation_manager) { |
562 | delete tensorflow::unwrap(cancellation_manager); |
563 | } |
564 | |
565 | void TFE_OpSetCancellationManager(TFE_Op* op, |
566 | TFE_CancellationManager* cancellation_manager, |
567 | TF_Status* status) { |
568 | tensorflow::unwrap(op)->SetCancellationManager( |
569 | tensorflow::unwrap(cancellation_manager)); |
570 | status->status = ::tensorflow::OkStatus(); |
571 | } |
572 | |
573 | TFE_Executor* TFE_NewExecutor(bool is_async, bool enable_streaming_enqueue) { |
574 | return new TFE_Executor(is_async, enable_streaming_enqueue); |
575 | } |
576 | |
577 | void TFE_DeleteExecutor(TFE_Executor* executor) { delete executor; } |
578 | |
579 | bool TFE_ExecutorIsAsync(TFE_Executor* executor) { |
580 | return executor->executor()->Async(); |
581 | } |
582 | |
583 | void TFE_ExecutorWaitForAllPendingNodes(TFE_Executor* executor, |
584 | TF_Status* status) { |
585 | status->status = executor->executor()->WaitForAllPendingNodes(); |
586 | } |
587 | |
588 | void TFE_ExecutorClearError(TFE_Executor* executor) { |
589 | executor->executor()->ClearError(); |
590 | } |
591 | |
592 | void TFE_ContextSetExecutorForThread(TFE_Context* ctx, TFE_Executor* executor) { |
593 | tensorflow::unwrap(ctx)->SetExecutorForThread(executor->executor()); |
594 | } |
595 | |
596 | TFE_Executor* TFE_ContextGetExecutorForThread(TFE_Context* ctx) { |
597 | return new TFE_Executor(&tensorflow::unwrap(ctx)->Executor()); |
598 | } |
599 | |
600 | void TFE_HostAddressSpace(TFE_Context* ctx, TF_Buffer* buf) { |
601 | auto address_space = tensorflow::DeviceNameUtils::AddressSpace( |
602 | tensorflow::unwrap(ctx)->HostCPUParsedName()); |
603 | auto str = tensorflow::DeviceNameUtils::ParsedNameToString(address_space); |
604 | void* data = tensorflow::port::Malloc(str.length()); |
605 | str.copy(static_cast<char*>(data), str.length(), 0); |
606 | buf->data = data; |
607 | buf->length = str.length(); |
608 | buf->data_deallocator = [](void* data, size_t length) { |
609 | tensorflow::port::Free(data); |
610 | }; |
611 | } |
612 | |
613 | void TFE_ContextGetFunctionDef(TFE_Context* ctx, const char* function_name, |
614 | TF_Buffer* buf, TF_Status* status) { |
615 | auto* function_def = tensorflow::unwrap(ctx)->FindFunctionDef(function_name); |
616 | if (function_def == nullptr) { |
617 | status->status = tensorflow::errors::NotFound( |
618 | "Unable to find FunctionDef with name: " , function_name); |
619 | return; |
620 | } |
621 | string str = function_def->SerializeAsString(); |
622 | void* data = tensorflow::port::Malloc(str.length()); |
623 | str.copy(static_cast<char*>(data), str.length(), 0); |
624 | buf->data = data; |
625 | buf->length = str.length(); |
626 | buf->data_deallocator = [](void* data, size_t length) { |
627 | tensorflow::port::Free(data); |
628 | }; |
629 | status->status = ::tensorflow::OkStatus(); |
630 | } |
631 | |
632 | TF_Tensor* TFE_AllocateHostTensor(TFE_Context* ctx, TF_DataType dtype, |
633 | const int64_t* dims, int num_dims, |
634 | TF_Status* status) { |
635 | std::vector<int64_t> dimvec(num_dims); |
636 | for (int i = 0; i < num_dims; ++i) { |
637 | dimvec[i] = static_cast<int64_t>(dims[i]); |
638 | } |
639 | |
640 | if (ctx == nullptr) { |
641 | status->status = tensorflow::errors::InvalidArgument("Invalid Context" ); |
642 | return nullptr; |
643 | } |
644 | |
645 | tensorflow::AbstractTensorInterface* t = |
646 | tensorflow::unwrap(ctx)->CreateTensor( |
647 | static_cast<tensorflow::DataType>(dtype), dimvec); |
648 | |
649 | if (t == nullptr) { |
650 | status->status = |
651 | tensorflow::errors::InvalidArgument("Unsupported dtype: " , dtype); |
652 | return nullptr; |
653 | } |
654 | |
655 | return new TF_Tensor{t}; |
656 | } |
657 | |
658 | TFE_TensorHandle* TFE_NewTensorHandleFromTensor(TFE_Context* ctx, TF_Tensor* t, |
659 | TF_Status* status) { |
660 | return tensorflow::wrap( |
661 | tensorflow::unwrap(ctx)->CreateLocalHandle(t->tensor)); |
662 | } |
663 | |
664 | TFE_TensorHandle* TFE_CreatePackedTensorHandle(TFE_Context* ctx, |
665 | TFE_TensorHandle** handles, |
666 | int* num_handles, |
667 | TF_Status* status) { |
668 | std::vector<tensorflow::TensorHandle*> tensor_handles; |
669 | tensor_handles.reserve(*num_handles); |
670 | for (int i = 0; i < *num_handles; ++i) { |
671 | tensorflow::ImmediateExecutionTensorHandle* unwrapped_handle = |
672 | tensorflow::unwrap(handles[i]); |
673 | if (tensorflow::CustomDeviceTensorHandle::classof(unwrapped_handle)) { |
674 | // One of the inputs we're trying to pack is on a custom device. We'll let |
675 | // the first custom device we see handle all of the packing. |
676 | auto* custom_device_handle = |
677 | tensorflow::down_cast<tensorflow::CustomDeviceTensorHandle*>( |
678 | unwrapped_handle); |
679 | tensorflow::ImmediateExecutionTensorHandle* result; |
680 | status->status = custom_device_handle->device()->Pack( |
681 | absl::Span<tensorflow::ImmediateExecutionTensorHandle*>( |
682 | tensorflow::unwrap(handles), *num_handles), |
683 | &result); |
684 | return tensorflow::wrap(result); |
685 | } |
686 | tensor_handles.push_back( |
687 | tensorflow::TensorHandleFromInterface(unwrapped_handle)); |
688 | } |
689 | tensorflow::EagerContext* context = |
690 | tensorflow::ContextFromInterface(tensorflow::unwrap(ctx)); |
691 | tensorflow::TensorHandle* handle = nullptr; |
692 | status->status = tensorflow::TensorHandle::CreatePackedHandle( |
693 | std::move(tensor_handles), context, &handle); |
694 | return tensorflow::wrap(handle); |
695 | } |
696 | |
697 | void TFE_ContextSetSoftDevicePlacement(TFE_Context* ctx, unsigned char enable, |
698 | TF_Status* status) { |
699 | tensorflow::unwrap(ctx)->SetAllowSoftPlacement(enable); |
700 | } |
701 | |
702 | void TFE_ContextSetLogDevicePlacement(TFE_Context* ctx, unsigned char enable, |
703 | TF_Status* status) { |
704 | tensorflow::unwrap(ctx)->SetLogDevicePlacement(enable); |
705 | } |
706 | |
707 | void TFE_ContextSetRunEagerOpAsFunction(TFE_Context* ctx, unsigned char enable, |
708 | TF_Status* status) { |
709 | tensorflow::unwrap(ctx)->SetRunEagerOpAsFunction(enable); |
710 | } |
711 | |
712 | void TFE_ContextSetJitCompileRewrite(TFE_Context* ctx, unsigned char enable, |
713 | TF_Status* status) { |
714 | tensorflow::unwrap(ctx)->SetJitCompileRewrite(enable); |
715 | } |
716 | |
717 | const char* TFE_TensorHandleDeviceType(TFE_TensorHandle* h, TF_Status* status) { |
718 | if (h == nullptr) { |
719 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
720 | return nullptr; |
721 | } |
722 | return tensorflow::unwrap(h)->DeviceType(&status->status); |
723 | } |
724 | |
725 | int TFE_TensorHandleDeviceID(TFE_TensorHandle* h, TF_Status* status) { |
726 | if (h == nullptr) { |
727 | status->status = tensorflow::errors::InvalidArgument("Invalid handle" ); |
728 | return -1; |
729 | } |
730 | return tensorflow::unwrap(h)->DeviceId(&status->status); |
731 | } |
732 | |
733 | TF_CAPI_EXPORT extern void TFE_TensorHandleGetStatus(TFE_TensorHandle* h, |
734 | TF_Status* status) { |
735 | status->status = tensorflow::unwrap(h)->TensorHandleStatus(); |
736 | } |
737 | |
738 | void TFE_GetExecutedOpNames(TFE_Context* ctx, TF_Buffer* buf, |
739 | TF_Status* status) { |
740 | const std::vector<std::string>& op_names = |
741 | tensorflow::unwrap(ctx)->GetLoggedOpsTestonly(); |
742 | |
743 | std::ostringstream op_names_oss; |
744 | for (const auto& op : op_names) { |
745 | op_names_oss << op << ", " ; |
746 | } |
747 | const std::string& op_names_str = op_names_oss.str(); |
748 | void* data = tensorflow::port::Malloc(op_names_str.length()); |
749 | op_names_str.copy(static_cast<char*>(data), op_names_str.length(), 0); |
750 | buf->data = data; |
751 | buf->length = op_names_str.length(); |
752 | buf->data_deallocator = [](void* data, size_t length) { |
753 | tensorflow::port::Free(data); |
754 | }; |
755 | status->status = ::tensorflow::OkStatus(); |
756 | } |
757 | |
758 | void TFE_SetLogicalCpuDevices(TFE_Context* ctx, int num_cpus, |
759 | const char* prefix, TF_Status* status) { |
760 | std::vector<std::unique_ptr<tensorflow::Device>> devices; |
761 | |
762 | if (prefix == nullptr || strlen(prefix) == 0) |
763 | prefix = "/job:localhost/replica:0/task:0" ; |
764 | |
765 | tensorflow::SessionOptions sess_options; |
766 | (*sess_options.config.mutable_device_count())["CPU" ] = num_cpus; |
767 | status->status = |
768 | tensorflow::DeviceFactory::AddCpuDevices(sess_options, prefix, &devices); |
769 | |
770 | // Remove the device that has the host device name since host device is alreay |
771 | // in an initialized context. |
772 | for (auto d = devices.begin(); d != devices.end();) { |
773 | if (absl::StrContains(d->get()->name(), "CPU:0" )) { |
774 | d = devices.erase(d); |
775 | } else { |
776 | ++d; |
777 | } |
778 | } |
779 | |
780 | status->status = tensorflow::unwrap(ctx)->AddDevices(std::move(devices)); |
781 | } |
782 | |
783 | void TFE_InsertConfigKeyValue(TFE_Context* ctx, const char* key, |
784 | const char* value, TF_Status* status) { |
785 | tensorflow::ImmediateExecutionDistributedManager* dist_mgr = |
786 | tensorflow::unwrap(ctx)->GetDistributedManager(); |
787 | tensorflow::CoordinationServiceAgent* coord_agent = |
788 | dist_mgr->GetCoordinationServiceAgent(); |
789 | if (coord_agent == nullptr) { |
790 | status->status = tensorflow::errors::FailedPrecondition( |
791 | "Coordination service agent is not enabled." ); |
792 | return; |
793 | } |
794 | status->status = coord_agent->InsertKeyValue(key, value); |
795 | } |
796 | |
797 | void TFE_GetConfigKeyValue(TFE_Context* ctx, const char* key, |
798 | TF_Buffer* value_buf, TF_Status* status) { |
799 | tensorflow::ImmediateExecutionDistributedManager* dist_mgr = |
800 | tensorflow::unwrap(ctx)->GetDistributedManager(); |
801 | tensorflow::CoordinationServiceAgent* coord_agent = |
802 | dist_mgr->GetCoordinationServiceAgent(); |
803 | if (coord_agent == nullptr) { |
804 | status->status = tensorflow::errors::FailedPrecondition( |
805 | "Coordination service is not enabled." ); |
806 | return; |
807 | } |
808 | auto status_or_value = coord_agent->GetKeyValue(key); |
809 | status->status = status_or_value.status(); |
810 | if (!status_or_value.ok()) return; |
811 | |
812 | const std::string& value_string = status_or_value.value(); |
813 | void* data = tensorflow::port::Malloc(value_string.length()); |
814 | value_string.copy(static_cast<char*>(data), value_string.length(), 0); |
815 | value_buf->data = data; |
816 | value_buf->length = value_string.length(); |
817 | value_buf->data_deallocator = [](void* data, size_t length) { |
818 | tensorflow::port::Free(data); |
819 | }; |
820 | } |
821 | |
822 | void TFE_DeleteConfigKeyValue(TFE_Context* ctx, const char* key, |
823 | TF_Status* status) { |
824 | tensorflow::ImmediateExecutionDistributedManager* dist_mgr = |
825 | tensorflow::unwrap(ctx)->GetDistributedManager(); |
826 | tensorflow::CoordinationServiceAgent* coord_agent = |
827 | dist_mgr->GetCoordinationServiceAgent(); |
828 | if (coord_agent == nullptr) { |
829 | status->status = tensorflow::errors::FailedPrecondition( |
830 | "Coordination service is not enabled." ); |
831 | return; |
832 | } |
833 | status->status = coord_agent->DeleteKeyValue(key); |
834 | } |
835 | |
836 | void TFE_ReportErrorToCluster(TFE_Context* ctx, int error_code, |
837 | const char* error_message, TF_Status* status) { |
838 | tensorflow::ImmediateExecutionDistributedManager* dist_mgr = |
839 | tensorflow::unwrap(ctx)->GetDistributedManager(); |
840 | tensorflow::CoordinationServiceAgent* coord_agent = |
841 | dist_mgr->GetCoordinationServiceAgent(); |
842 | if (coord_agent == nullptr) { |
843 | status->status = tensorflow::errors::FailedPrecondition( |
844 | "Coordination service is not enabled." ); |
845 | return; |
846 | } |
847 | tensorflow::Status s(static_cast<tensorflow::error::Code>(error_code), |
848 | error_message); |
849 | status->status = coord_agent->ReportError(s); |
850 | } |
851 | |