1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/python/eager/pywrap_tensor.h"
17
18#include <stdlib.h>
19#include <string.h>
20
21#include <cmath>
22
23#include "structmember.h" // NOLINT // For PyMemberDef
24#include "pybind11/pybind11.h"
25#include "tensorflow/c/c_api.h"
26#include "tensorflow/c/eager/c_api.h"
27#include "tensorflow/c/eager/c_api_internal.h"
28#include "tensorflow/c/eager/tfe_context_internal.h"
29#include "tensorflow/c/eager/tfe_tensorhandle_internal.h"
30#include "tensorflow/c/tf_status.h"
31#include "tensorflow/core/framework/types.h"
32#include "tensorflow/core/framework/types.pb.h"
33#include "tensorflow/core/lib/strings/strcat.h"
34#include "tensorflow/python/eager/pywrap_tensor_conversion.h"
35#include "tensorflow/python/eager/pywrap_tfe.h"
36#include "tensorflow/python/lib/core/ndarray_tensor.h"
37#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
38#include "tensorflow/python/lib/core/numpy.h"
39#include "tensorflow/python/lib/core/py_exception_registry.h"
40#include "tensorflow/python/lib/core/py_seq_tensor.h"
41#include "tensorflow/python/lib/core/pybind11_status.h"
42#include "tensorflow/python/lib/core/safe_ptr.h"
43
44// forward declare
45struct EagerTensor;
46namespace tensorflow {
47
48// Convert a TFE_TensorHandle to a Python numpy.ndarray object.
49// The two may share underlying storage so changes to one may reflect in the
50// other.
51PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) {
52 if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) {
53 TF_SetStatus(status, TF_INVALID_ARGUMENT,
54 "Cannot convert a Tensor of dtype resource to a NumPy array.");
55 return nullptr;
56 }
57
58 if (TFE_TensorHandleDataType(handle) == TF_VARIANT) {
59 TF_SetStatus(status, TF_INVALID_ARGUMENT,
60 "Cannot convert a Tensor of dtype variant to a NumPy array.");
61 return nullptr;
62 }
63 tensorflow::Safe_TF_TensorPtr tensor = nullptr;
64 Py_BEGIN_ALLOW_THREADS;
65 tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status));
66 Py_END_ALLOW_THREADS;
67 if (!status->status.ok()) {
68 return nullptr;
69 }
70
71 PyObject* ret = nullptr;
72 auto cppstatus =
73 tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret);
74 tensorflow::Set_TF_Status_from_Status(status, cppstatus);
75 if (!status->status.ok()) {
76 Py_XDECREF(ret);
77 return nullptr;
78 }
79 CHECK_NE(ret, nullptr);
80 return ret;
81}
82} // namespace tensorflow
83namespace {
84
85using tensorflow::TFE_TensorHandleToNumpy;
86
// An instance of _EagerTensorProfiler that will receive callbacks about
// events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all.
// While it stays nullptr, MaybeInvokeCreatedOnEagerTensorProfiler does nothing.
PyObject* eager_tensor_profiler = nullptr;
90
91// Read-only dict. Please don't use this in any setting where the dict might
92// actually get mutated. This is only used to pass empty kwargs when creating a
93// new EagerTensor.
94PyObject* EmptyDict() {
95 static PyObject* empty_dict = PyDict_New();
96 return empty_dict;
97}
98
99PyObject* EmptyTuple() {
100 static PyObject* empty_tuple = PyTuple_New(0);
101 return empty_tuple;
102}
103
104TFE_Context* GetContextHandle(PyObject* py_context) {
105 tensorflow::Safe_PyObjectPtr py_context_handle(
106 PyObject_GetAttrString(py_context, "_handle"));
107 if (py_context_handle == nullptr) {
108 // Current Python code makes sure this never happens. If it does, or
109 // becomes hard to maintain, we can call the ensure_initialized() method
110 // here.
111 PyErr_SetString(
112 PyExc_TypeError,
113 "Expected `context` argument in EagerTensor constructor to have a "
114 "`_handle` attribute but it did not. Was eager Context initialized?");
115 return nullptr;
116 }
117
118 auto* ctx = reinterpret_cast<TFE_Context*>(
119 PyCapsule_GetPointer(py_context_handle.get(), nullptr));
120 if (ctx == nullptr) {
121 PyErr_SetString(PyExc_TypeError,
122 tensorflow::strings::StrCat(
123 "Expected context._handle to contain a PyCapsule "
124 "encoded pointer to TFE_Context. Got ",
125 Py_TYPE(py_context_handle.get())->tp_name)
126 .c_str());
127 }
128 return ctx;
129}
130
131
132// Helper function to convert `v` to a tensorflow::DataType and store it in
133// `*out`. Returns true on success, false otherwise.
134// Note that we assume that v is a python int (not long) representing a
135// TF_DataType/tensorflow::DataType value.
136bool PyIntToDataType(PyObject* v, tensorflow::DataType* out) {
137#if PY_MAJOR_VERSION < 3
138 if (PyInt_Check(v)) {
139 *out = static_cast<tensorflow::DataType>(PyInt_AS_LONG(v));
140 return true;
141 }
142#else
143 if (PyLong_Check(v)) {
144 *out = static_cast<tensorflow::DataType>(PyLong_AsLong(v));
145 return true;
146 }
147#endif
148 return false;
149}
150
151// Helper function to create a python integer from TF_DataType.
152PyObject* PyIntFromDataType(TF_DataType l) {
153#if PY_MAJOR_VERSION < 3
154 return PyInt_FromLong(l);
155#else
156 return PyLong_FromLong(l);
157#endif
158}
159
160// PyObject->tensorflow::DataType conversion function to be used with
161// PyArg_Parse* APIs.
162int ConvertDataType(PyObject* obj, tensorflow::DataType* dst) {
163 if (obj == Py_None) {
164 *dst = tensorflow::DataType::DT_INVALID;
165 } else if (!PyIntToDataType(obj, dst)) {
166 PyErr_SetString(
167 PyExc_TypeError,
168 tensorflow::strings::StrCat(
169 "Expecting a DataType value for dtype. Got ", Py_TYPE(obj)->tp_name)
170 .c_str());
171 return 0;
172 }
173
174 return 1;
175}
176
177// Conversion function extracting a const char** device name from a PyObject.
178// The function should be used with PyArg_Parse* APIs.
179int ConvertDeviceName(PyObject* obj, const char** dst) {
180 if (obj == Py_None) {
181 *dst = nullptr;
182 } else {
183 auto device_name = TFE_GetPythonString(obj);
184 if (device_name == nullptr) {
185 PyErr_Clear();
186 PyErr_SetString(PyExc_TypeError, "Error parsing device argument.");
187 return 0;
188 }
189 *dst = device_name;
190 }
191
192 return 1;
193}
194
195void RaiseExceptionTypeFromTFStatus(TF_Status* tf_status) {
196 auto status = tensorflow::StatusFromTF_Status(tf_status);
197 SetRegisteredErrFromStatus(status);
198}
199
200} // namespace
201
202namespace tensorflow {
203// This function checks whether the desired type is "compatible" with the
204// inferred type. At a high level, compatibility means that all integral types
205// are compatible with each other, and all floating types are compatible with
206// each other.
207//
208// Type compatibility doesn't consider overflows (i.e. int64 is *always*
209// compatible with int32). This is intended to match graph behavior.
210bool IsCompatible(DataType desired, DataType returned) {
211 if (desired == returned) return true;
212
213 if (DataTypeIsInteger(desired) && DataTypeIsInteger(returned)) {
214 return true;
215 } else if (DataTypeIsFloating(desired) &&
216 (DataTypeIsFloating(returned) || DataTypeIsInteger(returned))) {
217 return true;
218 } else if (DataTypeIsComplex(desired) &&
219 (DataTypeIsComplex(returned) || DataTypeIsInteger(returned) ||
220 DataTypeIsFloating(returned))) {
221 return true;
222 } else if (DataTypeIsQuantized(desired) && DataTypeIsInteger(returned)) {
223 return true;
224 }
225 return false;
226}
227
228// TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to
229// execute TFE Ops) to a separate common library.
230// Casts data referred to by `handle` from type `src_type_enum` to type
231// `dst_type_enum`.
232TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle,
233 TF_DataType src_type_enum,
234 TF_DataType dst_type_enum, TF_Status* out_status) {
235 if (ctx == nullptr) return nullptr;
236 const char* op_name = "Cast";
237 const char* device_name = "/device:CPU:0";
238 TFE_Op* op = TFE_NewOp(ctx, op_name, out_status);
239#define RETURN_ERROR \
240 { \
241 TFE_DeleteOp(op); \
242 return nullptr; \
243 }
244 if (!out_status->status.ok()) RETURN_ERROR
245 TFE_OpSetDevice(op, device_name, out_status);
246 if (!out_status->status.ok()) RETURN_ERROR
247 TFE_OpAddInput(op, handle, out_status);
248 if (!out_status->status.ok()) RETURN_ERROR
249 TFE_OpSetAttrType(op, "SrcT", src_type_enum);
250 TFE_OpSetAttrType(op, "DstT", dst_type_enum);
251 TFE_OpSetAttrBool(op, "Truncate", false);
252 TFE_TensorHandle* output = nullptr;
253 int num_outputs = 1;
254 TFE_Execute(op, &output, &num_outputs, out_status);
255 if (!out_status->status.ok() || num_outputs != 1 || output == nullptr) {
256 if (output != nullptr) {
257 TFE_DeleteTensorHandle(output);
258 }
259 RETURN_ERROR
260 }
261 TFE_DeleteOp(op);
262 return output;
263#undef RETURN_ERROR
264}
265
266Safe_TFE_TensorHandlePtr EagerConst(TFE_Context* ctx, TFE_TensorHandle* handle,
267 const char* device_name,
268 TF_Status* out_status) {
269 const char* op_name = "_EagerConst";
270 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
271 TFE_NewOp(ctx, op_name, out_status), TFE_DeleteOp);
272 if (!out_status->status.ok()) return nullptr;
273 TFE_OpSetDevice(op.get(), device_name, out_status);
274 if (!out_status->status.ok()) return nullptr;
275 TFE_OpAddInput(op.get(), handle, out_status);
276 if (!out_status->status.ok()) return nullptr;
277 TFE_OpSetAttrType(op.get(), "T", TFE_TensorHandleDataType(handle));
278 TFE_TensorHandle* output = nullptr;
279 int num_outputs = 1;
280 TFE_Execute(op.get(), &output, &num_outputs, out_status);
281 Safe_TFE_TensorHandlePtr result(output);
282 if (!out_status->status.ok() || num_outputs != 1) {
283 return nullptr;
284 }
285 return result;
286}
287
288TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx,
289 PyObject* value,
290 tensorflow::DataType dtype,
291 const char* device_name) {
292 tensorflow::Safe_PyObjectPtr value_decrefer;
293 if (PyArray_IsScalar(value, Generic)) {
294 // Convert numpy scalars to numpy arrays.
295 value = PyArray_FromScalar(value, nullptr);
296 // The returned value needs to be DECREF'd, but the original value was
297 // created in python code, and doesn't need to be DECREF'd.
298 value_decrefer.reset(value);
299 }
300
301 Safe_TFE_TensorHandlePtr handle =
302 make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype));
303
304 if (handle == nullptr) return nullptr;
305
306 Safe_TF_StatusPtr status = make_safe(TF_NewStatus());
307 TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get());
308 if (dtype != tensorflow::DT_INVALID &&
309 dtype != static_cast<DataType>(handle_dtype)) {
310 if (tensorflow::IsCompatible(dtype, static_cast<DataType>(handle_dtype))) {
311 handle = tensorflow::make_safe(
312 tensorflow::EagerCast(ctx, handle.get(), handle_dtype,
313 static_cast<TF_DataType>(dtype), status.get()));
314 if (!status->status.ok()) {
315 PyErr_SetString(PyExc_TypeError,
316 absl::StrCat("Error while casting from dtype ",
317 tensorflow::DataTypeString(
318 static_cast<DataType>(handle_dtype)),
319 " to ", tensorflow::DataTypeString(dtype),
320 ". ", TF_Message(status.get()))
321 .c_str());
322 return nullptr;
323 }
324 } else {
325 tensorflow::Safe_PyObjectPtr value_str(PyObject_Repr(value));
326 PyErr_SetString(
327 PyExc_TypeError,
328 absl::StrCat("Cannot convert ", TFE_GetPythonString(value_str.get()),
329 " to EagerTensor of dtype ",
330 tensorflow::DataTypeString(dtype))
331 .c_str());
332 return nullptr;
333 }
334 }
335
336 // We always initially generate CPU:0 tensors. Copy to the current device.
337 if (device_name != nullptr) {
338 if (strstr(device_name, "/device:CPU:0") != nullptr) {
339 // We always generate CPU:0 tensors, but we may need to change the device
340 // slightly, as for example from /job:localhost/... to /job:worker/...
341 //
342 // Note that this is a shallow copy and will share the underlying buffer,
343 // because we are copying to the same device.
344 handle = make_safe(TFE_TensorHandleCopyToDevice(
345 handle.get(), ctx, device_name, status.get()));
346 const TF_Code code = TF_GetCode(status.get());
347 if (code != TF_OK) {
348 RaiseExceptionTypeFromTFStatus(status.get());
349 return nullptr;
350 }
351 } else {
352 /*Copy the constant to the current device. Identity is sometimes
353 overloaded to allow copies like this, but using a different op allows
354 devices to support constant creation without allowing copies via
355 identity ops.
356
357 Note that running this _EagerConst op limits mirroring of cached Python
358 literals somewhat. Mirroring of constants themselves works:
359
360 with tf.device("GPU:0"):
361 tf.constant(1.) # Cached on CPU:0, mirrored to GPU:0
362 with tf.device("GPU:1"):
363 tf.constant(1.) # Cache hit for the CPU version, new mirror to GPU:1.
364 with tf.device("GPU:1"):
365 tf.constant(1.) # Cache hit for the CPU version, cached mirror
366
367 But mirrors for the output of `tf.constant` are not shared just because
368 there was a cache hit for the input literal, because of _EagerConst:
369
370 x = tf.constant(2.) # Cached on CPU:0
371 with tf.device("GPU:1"):
372 tf.identity(x) # `x` now mirrored to GPU:1
373 y = tf.constant(2.) # Cache hit for CPU version
374 with tf.device("GPU:1"):
375 tf.identity(y) # `y` now mirrored on GPU:1 (new copy!)*/
376 handle =
377 tensorflow::EagerConst(ctx, handle.get(), device_name, status.get());
378 const TF_Code code = TF_GetCode(status.get());
379 if (code != TF_OK) {
380 RaiseExceptionTypeFromTFStatus(status.get());
381 return nullptr;
382 }
383 }
384 }
385
386 return handle.release();
387}
388
389TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value,
390 DataType dtype,
391 const char* device_name) {
392 // Reduce the overhead of allocation/transfer-to-device for scalars by
393 // caching the corresponding handles. Note that currently only Python
394 // scalars are cached.
395 // TODO(slebedev): also cache singleton NumPy arrays and scalars?
396 if (PyArray_IsPythonNumber(value)) {
397 auto* cache = TFE_TensorHandleCache::Get();
398 TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name);
399 if (handle != nullptr) return handle;
400 handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
401 if (handle == nullptr) return nullptr;
402 if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) {
403 cache->Insert(value, dtype, ctx, device_name, handle);
404 }
405 return handle;
406 } else {
407 return ConvertToEagerTensorUncached(ctx, value, dtype, device_name);
408 }
409}
410
411} // namespace tensorflow
412
413extern "C" {
414
// Size in bytes of the scratch area that every EagerTensor reserves for the
// instance data of its runtime-chosen Python parent class (see `unused`).
static const int kMaxEagerTensorParentSize = 64;

// TODO(agarwal): store context handle in EagerTensor.
// C instance layout of an EagerTensor object.
// NOTE(review): `dict` is located through tp_dictoffset (see the comment on
// that field), and other members may also be addressed by offset in the type
// setup; check it before reordering or resizing fields.
typedef struct EagerTensor {
  PyObject_HEAD;
  // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
  // parent class. The parent class is set at runtime, so we don't know the
  // exact size at compile time.
  char unused[kMaxEagerTensorParentSize];
  // The underlying eager tensor handle. Owned by this object and deleted in
  // EagerTensor_dealloc.
  TFE_TensorHandle* handle;
  // Integer id obtained from get_uid() and exposed to Python as `_id`.
  int64_t id;
  // Indicates whether it's a packed tensor or not (exposed as `is_packed`).
  bool is_packed;
  // This mirrors tensorflow.core.framework.ops.Tensor._handle_data, which will
  // be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
  // tensors, this will contain a serialized HandleData proto with shape
  // inference metadata about shapes and dtypes of resources accessible from
  // this handle.
  // Note that we assume that handle_data cannot participate in reference
  // cycles, and hence don't provide GC support for it.
  PyObject* handle_data;

  // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
  // first time that `_EagerTensorBase`'s `shape` property is called.
  PyObject* tensor_shape;

  // We store a status object here as an optimization to avoid allocating a new
  // Status object in each function that operates on an EagerTensor and needs
  // a TF_Status. Error paths reset it to OK after raising. Note that accesses
  // to `status` are not thread-safe.
  TF_Status status;

  // The eager Context (from eager/context.py) used by this Tensor. A strong
  // reference, released in tp_dealloc only after `handle` has been deleted.
  // This is currently used only to make sure context outlives TensorHandles.
  PyObject* context;

  PyObject* weakreflist; /* List of weak references */

  // Per-instance attribute dictionary, to support monkey patching
  // (e.g. EagerTensor.assign when slicing variables). This dictionary is
  // created by CPython the first time an attribute is assigned, pointed to by
  // tp_dictoffset. Note that garbage collection is not enabled for
  // EagerTensors, so assigning objects to EagerTensor attributes which require
  // garbage collection is likely to cause issues.
  PyObject* dict;
} EagerTensor;
461
462namespace {
463
464// Returns true on success - successfully invoked or no profiler registered.
465// Returns false if some error occurred.
466bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) {
467 if (eager_tensor_profiler != nullptr) {
468#if PY_MAJOR_VERSION < 3
469 PyObject* created_method_name = PyString_InternFromString("created");
470#else
471 PyObject* created_method_name = PyUnicode_InternFromString("created");
472#endif
473 if (created_method_name == nullptr) {
474 return false;
475 }
476 PyObject* result = PyObject_CallMethodObjArgs(
477 eager_tensor_profiler, created_method_name, created_tensor, NULL);
478 if (result == nullptr) {
479 LOG(ERROR) << "Invoking created() on EagerTensor profiler failed";
480 // While we can potentially continue because the error is related to
481 // profiling, we choose to return an error because:
482 // - If profiling is used, the user likely wants to stop execution on
483 // profiling errors.
484 // - Error in profiling code might have left some state in an invalid
485 // form that can lead to an error later on. Better to fail fast.
486 Py_DECREF(created_method_name);
487 return false;
488 }
489 Py_DECREF(created_method_name);
490 Py_DECREF(result);
491 }
492 return true;
493}
494
495} // namespace
496
497// tp_init for EagerTensor.
498int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) {
499 self->id = get_uid();
500 self->handle = nullptr;
501 self->is_packed = false;
502 Py_INCREF(Py_None);
503 self->handle_data = Py_None;
504 Py_INCREF(Py_None);
505 self->tensor_shape = Py_None;
506 self->status.status = ::tensorflow::OkStatus();
507 self->dict = nullptr;
508 self->weakreflist = nullptr;
509 self->context = nullptr;
510 PyObject* value;
511 const char* device_name = nullptr;
512 tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID;
513 const char* kwlist[] = {"value", "device", "dtype", nullptr};
514 if (!PyArg_ParseTupleAndKeywords(
515 args, kwds, "OO&|O&", const_cast<char**>(kwlist), &value,
516 ConvertDeviceName, &device_name, ConvertDataType, &dtype)) {
517 return -1;
518 }
519
520 PyObject* py_context = GetPyEagerContext();
521 if (py_context == nullptr) return -1;
522 self->context = py_context;
523
524 auto* handle = tensorflow::ConvertToEagerTensor(GetContextHandle(py_context),
525 value, dtype, device_name);
526 if (handle == nullptr) return -1;
527 self->handle = handle;
528
529 if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) {
530 return -1;
531 }
532
533 return 0;
534}
535
536// tp_dealloc for EagerTensor.
537void EagerTensor_dealloc(EagerTensor* self) {
538 // Unhook the object from python's GC so that the weakref deleter doesn't
539 // try to re-delete this.
540 PyObject_GC_UnTrack((PyObject*)self);
541
542 // Clear weak references to self.
543 // Needs to happen before any actual destruction.
544 PyObject_ClearWeakRefs((PyObject*)self);
545
546 Py_DECREF(self->handle_data);
547 Py_DECREF(self->tensor_shape);
548 // If an attribute dictionary has been created, release it. Note that this
549 // is only ever created by CPython's attribute setting methods; we don't
550 // create it ourselves.
551 Py_CLEAR(self->dict);
552 if (self->handle != nullptr) {
553 // Destructor may call arbitrary functions that end up calling into
554 // Python from another thread.
555 Py_BEGIN_ALLOW_THREADS;
556 TFE_DeleteTensorHandle(self->handle);
557 Py_END_ALLOW_THREADS;
558 self->handle = nullptr;
559 }
560
561 // Decref context after deleting the tensor handle.
562 Py_XDECREF(self->context);
563
564 // We have the global interpreter lock, so use this chance to perform delayed
565 // refcount decrements.
566 tensorflow::ClearDecrefCache();
567 auto id = self->id;
568 Py_TYPE(self)->tp_free(self);
569 TFE_Py_TapeSetDeleteTrace(id);
570}
571
572// Getter for `_id`.
573static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) {
574 return PyLong_FromLongLong(self->id);
575}
576
577// Getter for `_datatype_enum`.
578static PyObject* EagerTensor_datatype_enum(EagerTensor* self) {
579 return PyIntFromDataType(TFE_TensorHandleDataType(self->handle));
580}
581
582// Getter for `_shape_tuple`.
583static PyObject* EagerTensor_shape_tuple(EagerTensor* self) {
584 auto handle = self->handle;
585 int n = TFE_TensorHandleNumDims(handle, &self->status);
586 TF_Code code = TF_GetCode(&self->status);
587 if (code != TF_OK) {
588 RaiseExceptionTypeFromTFStatus(&self->status);
589 // Cleanup self->status before returning.
590 self->status.status = ::tensorflow::OkStatus();
591 return nullptr;
592 }
593 PyObject* shape = PyTuple_New(n);
594 if (PyErr_Occurred()) return nullptr;
595 for (int i = 0; i < n; ++i) {
596 int64_t dim_c_value = TFE_TensorHandleDim(handle, i, &self->status);
597 PyObject* dim;
598 // The C++ convention is -1 for unknown/variable axis lengths. Translate
599 // that to the Python "None" convention. Unknown axis lengths are unusual
600 // for eager tensors.
601 if (dim_c_value < 0) {
602 Py_IncRef(Py_None);
603 dim = Py_None;
604 } else {
605 dim = PyLong_FromLongLong(dim_c_value);
606 }
607 code = TF_GetCode(&self->status);
608 if (code != TF_OK || dim == nullptr ||
609 PyTuple_SetItem(shape, i, dim) != 0) {
610 if (code != TF_OK) {
611 RaiseExceptionTypeFromTFStatus(&self->status);
612 } else {
613 PyErr_SetString(PyExc_RuntimeError, "Error while creating shape");
614 }
615 // Cleanup self->status before returning.
616 self->status.status = ::tensorflow::OkStatus();
617 Py_DECREF(shape);
618 if (dim != nullptr) Py_DECREF(dim);
619 return nullptr;
620 }
621 }
622 return shape;
623}
624
625// Getter for `_rank`.
626static PyObject* EagerTensor_rank(EagerTensor* self) {
627 int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status);
628 if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
629 // Cleanup self->status before returning.
630 self->status.status = ::tensorflow::OkStatus();
631 return nullptr;
632 }
633#if PY_MAJOR_VERSION < 3
634 return PyInt_FromLong(num_dims);
635#else
636 return PyLong_FromLong(num_dims);
637#endif
638}
639
640// Getter for `_num_elements`.
641static PyObject* EagerTensor_num_elements(EagerTensor* self) {
642 auto handle = self->handle;
643 int n = TFE_TensorHandleNumElements(handle, &self->status);
644 if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
645 // Cleanup self->status before returning.
646 self->status.status = ::tensorflow::OkStatus();
647 return nullptr;
648 }
649 return PyLong_FromLongLong(n);
650}
651
652static PyObject* EagerTensor_handle_data(EagerTensor* self, void* unused) {
653 Py_INCREF(self->handle_data);
654 return self->handle_data;
655}
656
657static int EagerTensor_sethandle_data(EagerTensor* self, PyObject* value,
658 void* unused) {
659 Py_DECREF(self->handle_data);
660 Py_INCREF(value);
661 self->handle_data = value;
662 return 0;
663}
664
665static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) {
666 Py_INCREF(self->tensor_shape);
667 return self->tensor_shape;
668}
669
670static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value,
671 void* unused) {
672 Py_DECREF(self->tensor_shape);
673 Py_INCREF(value);
674 self->tensor_shape = value;
675 return 0;
676}
677
678// Function `_copy_to_device`.
679static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args,
680 PyObject* kwds) {
681 if (!_PyArg_NoKeywords("copy_to_device", kwds)) return nullptr;
682
683 const char* device_name = nullptr;
684 if (!PyArg_ParseTuple(args, "O&:copy_to_device", ConvertDeviceName,
685 &device_name)) {
686 return nullptr;
687 }
688
689 // Note that this is a shallow copy and will share the underlying buffer
690 // if copying to the same device.
691 TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice(
692 self->handle, GetContextHandle(self->context), device_name,
693 &self->status);
694 if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status,
695 PyExc_RuntimeError)) {
696 // Cleanup self->status before returning.
697 self->status.status = ::tensorflow::OkStatus();
698 return nullptr;
699 }
700
701 return EagerTensorFromHandle(handle);
702}
703
704// Function `_numpy_internal`.
705// Convert an EagerTensor to a Python numpy.ndarray object.
706// The two may share underlying storage so changes to one may reflect in the
707// other.
708// Note that if `self` is not on CPU, we raise an Exception.
709static PyObject* EagerTensor_numpy_internal(EagerTensor* self) {
710 auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status);
711 if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) {
712 Py_XDECREF(py_array);
713 // Cleanup self->status before returning.
714 self->status.status = ::tensorflow::OkStatus();
715 return nullptr;
716 } else {
717 return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array));
718 }
719}
720
721// Function `_prefer_custom_summarizer`.
722//
723// A hint that callers should prefer `SummarizeValue` to resolving this handle
724// and formatting the tensor.
725static PyObject* EagerTensor_prefer_custom_summarizer(EagerTensor* self) {
726 if (tensorflow::unwrap(self->handle)->PreferCustomSummarizer()) {
727 Py_RETURN_TRUE;
728 } else {
729 Py_RETURN_FALSE;
730 }
731}
732
733// Function `_summarize_value`.
734//
735// Returns a string PyObject which summarizes the value of this tensor. It does
736// not include a shape or dtype.
737static PyObject* EagerTensor_summarize_value(EagerTensor* self) {
738 std::string summary;
739 tensorflow::Status status =
740 tensorflow::unwrap(self->handle)->SummarizeValue(summary);
741 if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
742 return nullptr;
743 }
744 return PyUnicode_FromString(summary.c_str());
745}
746
747// Getter `device`.
748static PyObject* EagerTensor_device(EagerTensor* self) {
749 const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status);
750 if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status,
751 PyExc_ValueError)) {
752 // Cleanup self->status before returning.
753 self->status.status = ::tensorflow::OkStatus();
754 return nullptr;
755 }
756#if PY_MAJOR_VERSION >= 3
757 return PyUnicode_FromString(device);
758#else
759 return PyBytes_FromString(device);
760#endif
761}
762
763// Getter `backing_device`.
764static PyObject* EagerTensor_backing_device(EagerTensor* self) {
765 const char* device =
766 TFE_TensorHandleBackingDeviceName(self->handle, &self->status);
767 if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status,
768 PyExc_ValueError)) {
769 // Cleanup self->status before returning.
770 self->status.status = ::tensorflow::OkStatus();
771 return nullptr;
772 }
773#if PY_MAJOR_VERSION >= 3
774 return PyUnicode_FromString(device);
775#else
776 return PyBytes_FromString(device);
777#endif
778}
779
780// Getter `is_packed`.
781static PyObject* EagerTensor_is_packed(EagerTensor* self) {
782 return PyBool_FromLong(self->is_packed);
783}
784
// Attribute descriptors for EagerTensor: {name, getter, setter, doc, closure}.
// Only `_handle_data` and `_tensor_shape` are writable.
// NOTE(review): `device`, `backing_device` and `is_packed` point at functions
// declared with a single parameter, while a `getter` is called with
// (object, closure). This relies on the platform ABI tolerating the extra
// argument; confirm, or give those functions a `void*` second parameter.
static PyGetSetDef EagerTensor_getsetters[] = {
    {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
     const_cast<char*>("Tensor ID."), nullptr},
    {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
     const_cast<char*>("Device of op that produced the tensor."), nullptr},
    {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
     nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
     nullptr},
    {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr,
     const_cast<char*>("Whether the EagerTensor is a packed tensor or not."),
     nullptr},
    {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
     (setter)EagerTensor_sethandle_data,
     const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
     nullptr},
    {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
     (setter)EagerTensor_settensor_shape,
     const_cast<char*>("Shape of the tensor."), nullptr},
    {nullptr} /* Sentinel */
};
805
#if PY_MAJOR_VERSION < 3
// Only used for Python2 since Python3 seems to set the __dict__ correctly.
// Exposes the per-instance attribute dict stored in EagerTensor::dict as a
// read-only `__dict__` member.
static PyMemberDef EagerTensor_members[] = {
    {const_cast<char*>("__dict__"), T_OBJECT, offsetof(EagerTensor, dict),
     READONLY},
    {nullptr},
};
#endif
814
// Methods exposed on EagerTensor instances: {name, function, flags, doc}.
// NOTE(review): the METH_NOARGS entries point at functions declared with only
// a `self` parameter, whereas CPython calls them with (self, NULL). This
// relies on the calling convention tolerating the extra argument; confirm, or
// add a `PyObject* /*unused*/` second parameter to those functions.
static PyMethodDef EagerTensor_methods[] = {
    {"_numpy_internal", (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS,
     PyDoc_STR("Internal method to get a NumPy array for the tensor.")},
    {"_datatype_enum", (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS,
     PyDoc_STR("The DType of the tensor as an enum.")},
    {"_shape_tuple", (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS,
     PyDoc_STR("The shape of the tensor as a python tuple.")},
    {"_rank", (PyCFunction)EagerTensor_rank, METH_NOARGS,
     PyDoc_STR("The rank of the tensor.")},
    // Takes positional args only; keywords are rejected inside the function.
    {"_copy_to_device", (PyCFunction)EagerTensor_copy_to_device,
     METH_VARARGS | METH_KEYWORDS,
     PyDoc_STR("Copies the tensor to the desired device.")},
    {"_num_elements", (PyCFunction)EagerTensor_num_elements, METH_NOARGS,
     PyDoc_STR("Number of elements in the tensor.")},
    {"_prefer_custom_summarizer",
     (PyCFunction)EagerTensor_prefer_custom_summarizer, METH_NOARGS,
     PyDoc_STR("Indicates whether _numpy_internal loses information.")},
    {"_summarize_value", (PyCFunction)EagerTensor_summarize_value, METH_NOARGS,
     PyDoc_STR("A string which summarizes the value of this tensor.")},
    {nullptr, nullptr},  // Sentinel.
};
836
837static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view,
838 int flags) {
839 if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE) {
840 PyErr_SetString(PyExc_BufferError, "EagerTensor is not writable.");
841 return -1;
842 }
843
844 // TensorHandleToNumpy is zero-copy for everything but DT_RESOURCE and
845 // DT_STRING so the following is only slightly slower than a NumPy-free
846 // implementation.
847 auto py_array = tensorflow::make_safe(
848 TFE_TensorHandleToNumpy(self->handle, &self->status));
849 if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status,
850 PyExc_BufferError)) {
851 // Cleanup self->status before returning.
852 self->status.status = ::tensorflow::OkStatus();
853 return -1;
854 }
855 if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) {
856 return -1;
857 }
858 view->readonly = 1;
859 return 0;
860}
861
// Buffer-protocol slots for EagerTensor. Only bf_getbuffer is implemented.
// On Python 2 the four leading legacy buffer slots are left null.
static PyBufferProcs EagerTensor_as_buffer = {
#if PY_MAJOR_VERSION < 3
    nullptr, nullptr, nullptr, nullptr,
#endif
    (getbufferproc)EagerTensor_getbuffer,
    // Never called because getbufferproc delegates to NumPy.
    (releasebufferproc) nullptr};
869
870// Note that here we are trying to dynamically create a new class as a subclass
871// of a "HEAPTYPE" class that is itself created in python code and passed in at
872// runtime. This is fairly atypical and undocumented.
873//
// We use the following strategy for this. Unfortunately, we have to use
// different approaches for Python 2.x and Python 3.x.
876// For python2.x, we create the class as a static type and set its tp_base to
877// the passed in type. Unfortunately setting tp_flags to include
878// Py_TPFLAGS_HEAPTYPE does not work by itself since it needs some more
879// initialization of the underlying PyHeapTypeObject and not doing that leads to
880// some random crashes especially during garbage collection.
881// python3.x explicitly disables a static subclass of a HEAPTYPE base class.
882// However it provides a new function, PyType_FromSpecWithBases, to create
883// types dynamically.
884
// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor and is
// nullptr until then. EagerTensorFromHandle dereferences it without a null
// check, so TFE_Py_InitEagerTensor must have succeeded before any EagerTensor
// is created from a handle.
PyTypeObject* EagerTensorType = nullptr;
887
#if PY_MAJOR_VERSION >= 3
// Slot table for the Python 3 EagerTensor heap type; it is passed (via
// EagerTensor_Type_spec) to PyType_FromSpecWithBases in
// TFE_Py_InitEagerTensor. tp_dictoffset and tp_as_buffer are not given here;
// TFE_Py_InitEagerTensor assigns them directly on the created type.
static PyType_Slot EagerTensor_Type_slots[] = {
    {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
    {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
    {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getsetters)},
    {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
    // A zero slot id terminates the table.
    {0, nullptr},
};
#else

// Python 2 type flags: the defaults plus the flag advertising the new-style
// buffer slots provided by EagerTensor_as_buffer.
#define EAGER_TENSOR_TPFLAGS (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_NEWBUFFER)

// Statically initialized EagerTensor type used under Python 2.
// TFE_Py_InitEagerTensor sets its tp_base and calls PyType_Ready on it.
// The initializer is positional, so entries must stay in PyTypeObject field
// order; the trailing /* ... */ comments name each field.
// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                      /* tp_name */
    sizeof(EagerTensor),                /* tp_basicsize */
    0,                                  /* tp_itemsize */
    (destructor)EagerTensor_dealloc,    /* tp_dealloc */
// This arm is only compiled for Python 2, so the condition below always holds
// and the tp_vectorcall_offset alternative is never used.
#if PY_VERSION_HEX < 0x03080000
    nullptr,                            /* tp_print */
#else
    0,                                  /* tp_vectorcall_offset */
#endif
    nullptr,                            /* tp_getattr */
    nullptr,                            /* tp_setattr */
    nullptr,                            /* tp_compare */
    nullptr,                            /* tp_repr */
    nullptr,                            /* tp_as_number */
    nullptr,                            /* tp_as_sequence */
    nullptr,                            /* tp_as_mapping */
    nullptr,                            /* tp_hash */
    nullptr,                            /* tp_call */
    nullptr,                            /* tp_str */
    nullptr,                            /* tp_getattro */
    nullptr,                            /* tp_setattro */
    &EagerTensor_as_buffer,             /* tp_as_buffer */
    EAGER_TENSOR_TPFLAGS,               /* tp_flags */
    nullptr,                            /* tp_doc */
    nullptr,                            /* tp_traverse */
    nullptr,                            /* tp_clear */
    nullptr,                            /* tp_richcompare */
    offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
    nullptr,                            /* tp_iter */
    nullptr,                            /* tp_iternext */
    EagerTensor_methods,                /* tp_methods */
    EagerTensor_members,                /* tp_members */
    EagerTensor_getsetters,             /* tp_getset */
    nullptr,                            /* tp_base */
    nullptr,                            /* tp_dict */
    nullptr,                            /* tp_descr_get */
    nullptr,                            /* tp_descr_set */
    offsetof(EagerTensor, dict),        /* tp_dictoffset */
    (initproc)EagerTensor_init,         /* tp_init */
    nullptr,                            /* tp_alloc */
    nullptr,                            /* tp_new */
};

#endif
949
950} // extern "C"
951
952bool EagerTensor_CheckExact(const PyObject* o) {
953 return Py_TYPE(o) == EagerTensorType;
954}
955
956TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) {
957 return reinterpret_cast<const EagerTensor*>(o)->handle;
958}
959
960PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle,
961 const bool is_packed) {
962 if (handle == nullptr) {
963 return nullptr;
964 }
965 EagerTensor* t = reinterpret_cast<EagerTensor*>(
966 EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict()));
967 if (t != nullptr) {
968 t->id = get_uid();
969 t->is_packed = is_packed;
970 Py_INCREF(Py_None);
971 t->handle_data = Py_None;
972 Py_INCREF(Py_None);
973 t->tensor_shape = Py_None;
974 t->handle = handle;
975 t->status.status = ::tensorflow::OkStatus();
976 t->weakreflist = nullptr;
977 PyObject* py_context = GetPyEagerContext();
978 if (py_context == nullptr) {
979 LOG(ERROR) << "Cannot create an eager tensor before eager context has "
980 "been set or after it has been deleted";
981 return nullptr;
982 }
983 t->context = py_context;
984
985 if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) {
986 return nullptr;
987 }
988 }
989 return reinterpret_cast<PyObject*>(t);
990}
991
992int64_t PyEagerTensor_ID(const PyObject* tensor) {
993 DCHECK(EagerTensor_CheckExact(tensor));
994 return reinterpret_cast<const EagerTensor*>(tensor)->id;
995}
996
997tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) {
998 DCHECK(EagerTensor_CheckExact(tensor));
999 return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType(
1000 reinterpret_cast<const EagerTensor*>(tensor)->handle));
1001}
1002
1003int64_t PyEagerTensor_NumElements(PyObject* tensor) {
1004 DCHECK(EagerTensor_CheckExact(tensor));
1005 EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor);
1006 int64_t result = TFE_TensorHandleNumElements(as_c_eager_tensor->handle,
1007 &as_c_eager_tensor->status);
1008
1009 if (tensorflow::MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status,
1010 PyExc_ValueError)) {
1011 // Cleanup status before returning.
1012 as_c_eager_tensor->status.status = ::tensorflow::OkStatus();
1013 return -1;
1014 }
1015
1016 return result;
1017}
1018
1019PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) {
1020 if (!PyType_Check(base_class)) {
1021 PyErr_SetString(
1022 PyExc_TypeError,
1023 tensorflow::strings::StrCat(
1024 "Expecting a class definition for `base_class` passed to ",
1025 "TFE_InitEagerTensor. Got ", Py_TYPE(base_class)->tp_name)
1026 .c_str());
1027 return nullptr;
1028 }
1029 // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in
1030 // EagerTensor to allow for the space usage of the base class.
1031 PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class);
1032 if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) {
1033 PyErr_SetString(
1034 PyExc_TypeError,
1035 tensorflow::strings::StrCat(
1036 "Unable to create subclass EagerTensor from base class ",
1037 Py_TYPE(base_class)->tp_name,
1038 ". Need its size to be <= ", kMaxEagerTensorParentSize)
1039 .c_str());
1040 return nullptr;
1041 }
1042 if (base_class_type->tp_itemsize != 0) {
1043 PyErr_SetString(
1044 PyExc_TypeError,
1045 tensorflow::strings::StrCat(
1046 "Unable to create subclass EagerTensor from base class ",
1047 Py_TYPE(base_class)->tp_name,
1048 " which supports variable length instances.")
1049 .c_str());
1050 return nullptr;
1051 }
1052 Py_INCREF(base_class);
1053#if PY_MAJOR_VERSION >= 3
1054 PyObject* bases = PyTuple_New(1);
1055 PyTuple_SET_ITEM(bases, 0, base_class);
1056
1057 tensorflow::Safe_PyObjectPtr base_class_module(
1058 PyObject_GetAttrString(base_class, "__module__"));
1059 const char* module = nullptr;
1060 if (PyErr_Occurred()) {
1061 PyErr_Clear();
1062 module = "__builtin__";
1063 } else {
1064 module = PyBytes_AsString(base_class_module.get());
1065 if (module == nullptr) {
1066 PyErr_Clear();
1067 module = PyUnicode_AsUTF8(base_class_module.get());
1068 if (module == nullptr) {
1069 PyErr_Clear();
1070 module = "__builtin__";
1071 }
1072 }
1073 }
1074
1075 // NOTE: The c_str from this string needs to outlast the function, hence is
1076 // static.
1077 static tensorflow::string fully_qualified_name =
1078 tensorflow::strings::StrCat(module, ".EagerTensor");
1079
1080 static PyType_Spec EagerTensor_Type_spec = {
1081 fully_qualified_name.c_str(), sizeof(EagerTensor), 0,
1082 Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots};
1083
1084 EagerTensorType = reinterpret_cast<PyTypeObject*>(
1085 PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases));
1086 if (PyErr_Occurred()) {
1087 return nullptr;
1088 }
1089 if (EagerTensorType == nullptr) {
1090 PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType");
1091 return nullptr;
1092 }
1093 EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict);
1094 EagerTensorType->tp_as_buffer = &EagerTensor_as_buffer;
1095#else
1096 _EagerTensorType.tp_base = base_class_type;
1097
1098 if (PyType_Ready(&_EagerTensorType) < 0) {
1099 if (PyErr_Occurred()) return nullptr;
1100 PyErr_SetString(PyExc_RuntimeError,
1101 "Error while creating EagerTensor type.");
1102 return nullptr;
1103 }
1104 EagerTensorType = &_EagerTensorType;
1105 Py_INCREF(EagerTensorType);
1106#endif
1107 return reinterpret_cast<PyObject*>(EagerTensorType);
1108}
1109
1110PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) {
1111 Py_XDECREF(eager_tensor_profiler);
1112
1113 if (profiler == Py_None) {
1114 eager_tensor_profiler = nullptr;
1115 } else {
1116 eager_tensor_profiler = profiler;
1117 Py_INCREF(eager_tensor_profiler);
1118 }
1119 Py_RETURN_NONE;
1120}
1121
1122PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) {
1123 if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) {
1124 PyErr_SetString(PyExc_TypeError,
1125 tensorflow::strings::StrCat(
1126 "tensors argument must be a list or a tuple. Got \"",
1127 Py_TYPE(tensors)->tp_name, "\"")
1128 .c_str());
1129 return nullptr;
1130 }
1131 if (slice_dim < 0) {
1132 PyErr_SetString(
1133 PyExc_ValueError,
1134 tensorflow::strings::StrCat("Slice dimension must be non-negative. "
1135 "Got ",
1136 slice_dim)
1137 .c_str());
1138 return nullptr;
1139 }
1140
1141 PyObject* py_context = GetPyEagerContext();
1142 if (py_context == nullptr) {
1143 PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat(
1144 "Cannot create EagerTensor when "
1145 "EagerContext is not valid")
1146 .c_str());
1147 return nullptr;
1148 }
1149
1150 TFE_Context* ctx = GetContextHandle(py_context);
1151
1152 Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors);
1153 PyObject** tensors_array = PySequence_Fast_ITEMS(tensors);
1154 int64_t num_tensors_int = static_cast<int64_t>(num_tensors);
1155
1156 auto status = tensorflow::make_safe(TF_NewStatus());
1157
1158 // Create an empty tensor.
1159 auto* tensor = tensorflow::unwrap(ctx)->CreateTensor(
1160 tensorflow::DT_INT32, /*dim_sizes=*/{num_tensors_int});
1161
1162 if (num_tensors_int > 0) {
1163 int32_t* data = reinterpret_cast<int32_t*>(tensor->Data());
1164
1165 // Fill the tensor with dims.
1166 for (Py_ssize_t i = 0; i < num_tensors; ++i) {
1167 PyObject* tensor_obj = tensors_array[i];
1168 if (!EagerTensor_CheckExact(tensor_obj)) {
1169 PyErr_SetString(
1170 PyExc_TypeError,
1171 tensorflow::strings::StrCat("Expected a list of EagerTensors but "
1172 "element ",
1173 i, " has type \"",
1174 Py_TYPE(tensor_obj)->tp_name, "\"")
1175 .c_str());
1176 return nullptr;
1177 }
1178
1179 EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj);
1180 TFE_TensorHandle* handle = t->handle;
1181 int num_dims = TFE_TensorHandleNumDims(handle, status.get());
1182 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(),
1183 PyExc_ValueError)) {
1184 return nullptr;
1185 }
1186 if (slice_dim >= num_dims) {
1187 PyErr_SetString(
1188 PyExc_IndexError,
1189 tensorflow::strings::StrCat("Slice dimension (", slice_dim,
1190 ") must be smaller than rank of all "
1191 "tensors, but tensor at index ",
1192 i, " has rank ", num_dims)
1193 .c_str());
1194 return nullptr;
1195 }
1196 int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get());
1197 if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(),
1198 PyExc_ValueError)) {
1199 return nullptr;
1200 }
1201 data[i] = dim;
1202 }
1203 }
1204
1205 TFE_TensorHandle* handle =
1206 tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(tensor));
1207
1208 if (!status->status.ok()) {
1209 PyErr_SetString(
1210 PyExc_RuntimeError,
1211 tensorflow::strings::StrCat("Failed to construct new tensor handle: ",
1212 TF_Message(status.get()))
1213 .c_str());
1214 return nullptr;
1215 }
1216
1217 return EagerTensorFromHandle(handle);
1218}
1219
1220PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) {
1221 if (!EagerTensor_CheckExact(tensor)) {
1222 PyErr_SetString(
1223 PyExc_TypeError,
1224 tensorflow::strings::StrCat("Expected an EagerTensors but got type \"",
1225 Py_TYPE(tensor)->tp_name, "\"")
1226 .c_str());
1227 return nullptr;
1228 }
1229 TFE_TensorHandle* handle = EagerTensor_Handle(tensor);
1230
1231 auto status = tensorflow::make_safe(TF_NewStatus());
1232 TFE_TensorDebugInfo* debug_info =
1233 TFE_TensorHandleTensorDebugInfo(handle, status.get());
1234 if (!status->status.ok()) {
1235 PyErr_SetString(
1236 PyExc_RuntimeError,
1237 tensorflow::strings::StrCat("Error retrieving tensor's device shape: ",
1238 TF_Message(status.get()))
1239 .c_str());
1240 return nullptr;
1241 }
1242
1243 int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info);
1244 PyObject* shape = PyTuple_New(rank);
1245 for (int i = 0; i < rank; ++i) {
1246 int64_t dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i);
1247 PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size));
1248 }
1249 TFE_DeleteTensorDebugInfo(debug_info);
1250
1251 return shape;
1252}
1253