1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/eager/pywrap_tensor.h" |
17 | |
18 | #include <stdlib.h> |
19 | #include <string.h> |
20 | |
21 | #include <cmath> |
22 | |
23 | #include "structmember.h" // NOLINT // For PyMemberDef |
24 | #include "pybind11/pybind11.h" |
25 | #include "tensorflow/c/c_api.h" |
26 | #include "tensorflow/c/eager/c_api.h" |
27 | #include "tensorflow/c/eager/c_api_internal.h" |
28 | #include "tensorflow/c/eager/tfe_context_internal.h" |
29 | #include "tensorflow/c/eager/tfe_tensorhandle_internal.h" |
30 | #include "tensorflow/c/tf_status.h" |
31 | #include "tensorflow/core/framework/types.h" |
32 | #include "tensorflow/core/framework/types.pb.h" |
33 | #include "tensorflow/core/lib/strings/strcat.h" |
34 | #include "tensorflow/python/eager/pywrap_tensor_conversion.h" |
35 | #include "tensorflow/python/eager/pywrap_tfe.h" |
36 | #include "tensorflow/python/lib/core/ndarray_tensor.h" |
37 | #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" |
38 | #include "tensorflow/python/lib/core/numpy.h" |
39 | #include "tensorflow/python/lib/core/py_exception_registry.h" |
40 | #include "tensorflow/python/lib/core/py_seq_tensor.h" |
41 | #include "tensorflow/python/lib/core/pybind11_status.h" |
42 | #include "tensorflow/python/lib/core/safe_ptr.h" |
43 | |
44 | // forward declare |
45 | struct EagerTensor; |
46 | namespace tensorflow { |
47 | |
48 | // Convert a TFE_TensorHandle to a Python numpy.ndarray object. |
49 | // The two may share underlying storage so changes to one may reflect in the |
50 | // other. |
51 | PyObject* TFE_TensorHandleToNumpy(TFE_TensorHandle* handle, TF_Status* status) { |
52 | if (TFE_TensorHandleDataType(handle) == TF_RESOURCE) { |
53 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
54 | "Cannot convert a Tensor of dtype resource to a NumPy array." ); |
55 | return nullptr; |
56 | } |
57 | |
58 | if (TFE_TensorHandleDataType(handle) == TF_VARIANT) { |
59 | TF_SetStatus(status, TF_INVALID_ARGUMENT, |
60 | "Cannot convert a Tensor of dtype variant to a NumPy array." ); |
61 | return nullptr; |
62 | } |
63 | tensorflow::Safe_TF_TensorPtr tensor = nullptr; |
64 | Py_BEGIN_ALLOW_THREADS; |
65 | tensor = tensorflow::make_safe(TFE_TensorHandleResolve(handle, status)); |
66 | Py_END_ALLOW_THREADS; |
67 | if (!status->status.ok()) { |
68 | return nullptr; |
69 | } |
70 | |
71 | PyObject* ret = nullptr; |
72 | auto cppstatus = |
73 | tensorflow::TF_TensorToMaybeAliasedPyArray(std::move(tensor), &ret); |
74 | tensorflow::Set_TF_Status_from_Status(status, cppstatus); |
75 | if (!status->status.ok()) { |
76 | Py_XDECREF(ret); |
77 | return nullptr; |
78 | } |
79 | CHECK_NE(ret, nullptr); |
80 | return ret; |
81 | } |
82 | } // namespace tensorflow |
83 | namespace { |
84 | |
85 | using tensorflow::TFE_TensorHandleToNumpy; |
86 | |
87 | // An instance of _EagerTensorProfiler that will receive callbacks about |
88 | // events on eager tensors. This is set by TFE_Py_InitEagerTensor, if at all. |
89 | PyObject* eager_tensor_profiler = nullptr; |
90 | |
91 | // Read-only dict. Please don't use this in any setting where the dict might |
92 | // actually get mutated. This is only used to pass empty kwargs when creating a |
93 | // new EagerTensor. |
94 | PyObject* EmptyDict() { |
95 | static PyObject* empty_dict = PyDict_New(); |
96 | return empty_dict; |
97 | } |
98 | |
99 | PyObject* EmptyTuple() { |
100 | static PyObject* empty_tuple = PyTuple_New(0); |
101 | return empty_tuple; |
102 | } |
103 | |
104 | TFE_Context* GetContextHandle(PyObject* py_context) { |
105 | tensorflow::Safe_PyObjectPtr py_context_handle( |
106 | PyObject_GetAttrString(py_context, "_handle" )); |
107 | if (py_context_handle == nullptr) { |
108 | // Current Python code makes sure this never happens. If it does, or |
109 | // becomes hard to maintain, we can call the ensure_initialized() method |
110 | // here. |
111 | PyErr_SetString( |
112 | PyExc_TypeError, |
113 | "Expected `context` argument in EagerTensor constructor to have a " |
114 | "`_handle` attribute but it did not. Was eager Context initialized?" ); |
115 | return nullptr; |
116 | } |
117 | |
118 | auto* ctx = reinterpret_cast<TFE_Context*>( |
119 | PyCapsule_GetPointer(py_context_handle.get(), nullptr)); |
120 | if (ctx == nullptr) { |
121 | PyErr_SetString(PyExc_TypeError, |
122 | tensorflow::strings::StrCat( |
123 | "Expected context._handle to contain a PyCapsule " |
124 | "encoded pointer to TFE_Context. Got " , |
125 | Py_TYPE(py_context_handle.get())->tp_name) |
126 | .c_str()); |
127 | } |
128 | return ctx; |
129 | } |
130 | |
131 | |
132 | // Helper function to convert `v` to a tensorflow::DataType and store it in |
133 | // `*out`. Returns true on success, false otherwise. |
134 | // Note that we assume that v is a python int (not long) representing a |
135 | // TF_DataType/tensorflow::DataType value. |
136 | bool PyIntToDataType(PyObject* v, tensorflow::DataType* out) { |
137 | #if PY_MAJOR_VERSION < 3 |
138 | if (PyInt_Check(v)) { |
139 | *out = static_cast<tensorflow::DataType>(PyInt_AS_LONG(v)); |
140 | return true; |
141 | } |
142 | #else |
143 | if (PyLong_Check(v)) { |
144 | *out = static_cast<tensorflow::DataType>(PyLong_AsLong(v)); |
145 | return true; |
146 | } |
147 | #endif |
148 | return false; |
149 | } |
150 | |
151 | // Helper function to create a python integer from TF_DataType. |
152 | PyObject* PyIntFromDataType(TF_DataType l) { |
153 | #if PY_MAJOR_VERSION < 3 |
154 | return PyInt_FromLong(l); |
155 | #else |
156 | return PyLong_FromLong(l); |
157 | #endif |
158 | } |
159 | |
160 | // PyObject->tensorflow::DataType conversion function to be used with |
161 | // PyArg_Parse* APIs. |
162 | int ConvertDataType(PyObject* obj, tensorflow::DataType* dst) { |
163 | if (obj == Py_None) { |
164 | *dst = tensorflow::DataType::DT_INVALID; |
165 | } else if (!PyIntToDataType(obj, dst)) { |
166 | PyErr_SetString( |
167 | PyExc_TypeError, |
168 | tensorflow::strings::StrCat( |
169 | "Expecting a DataType value for dtype. Got " , Py_TYPE(obj)->tp_name) |
170 | .c_str()); |
171 | return 0; |
172 | } |
173 | |
174 | return 1; |
175 | } |
176 | |
177 | // Conversion function extracting a const char** device name from a PyObject. |
178 | // The function should be used with PyArg_Parse* APIs. |
179 | int ConvertDeviceName(PyObject* obj, const char** dst) { |
180 | if (obj == Py_None) { |
181 | *dst = nullptr; |
182 | } else { |
183 | auto device_name = TFE_GetPythonString(obj); |
184 | if (device_name == nullptr) { |
185 | PyErr_Clear(); |
186 | PyErr_SetString(PyExc_TypeError, "Error parsing device argument." ); |
187 | return 0; |
188 | } |
189 | *dst = device_name; |
190 | } |
191 | |
192 | return 1; |
193 | } |
194 | |
195 | void RaiseExceptionTypeFromTFStatus(TF_Status* tf_status) { |
196 | auto status = tensorflow::StatusFromTF_Status(tf_status); |
197 | SetRegisteredErrFromStatus(status); |
198 | } |
199 | |
200 | } // namespace |
201 | |
202 | namespace tensorflow { |
203 | // This function checks whether the desired type is "compatible" with the |
204 | // inferred type. At a high level, compatibility means that all integral types |
205 | // are compatible with each other, and all floating types are compatible with |
206 | // each other. |
207 | // |
208 | // Type compatibility doesn't consider overflows (i.e. int64 is *always* |
209 | // compatible with int32). This is intended to match graph behavior. |
210 | bool IsCompatible(DataType desired, DataType returned) { |
211 | if (desired == returned) return true; |
212 | |
213 | if (DataTypeIsInteger(desired) && DataTypeIsInteger(returned)) { |
214 | return true; |
215 | } else if (DataTypeIsFloating(desired) && |
216 | (DataTypeIsFloating(returned) || DataTypeIsInteger(returned))) { |
217 | return true; |
218 | } else if (DataTypeIsComplex(desired) && |
219 | (DataTypeIsComplex(returned) || DataTypeIsInteger(returned) || |
220 | DataTypeIsFloating(returned))) { |
221 | return true; |
222 | } else if (DataTypeIsQuantized(desired) && DataTypeIsInteger(returned)) { |
223 | return true; |
224 | } |
225 | return false; |
226 | } |
227 | |
228 | // TODO(nareshmodi): Move EagerCast and ReadVariableOp (which use the C API to |
229 | // execute TFE Ops) to a separate common library. |
230 | // Casts data referred to by `handle` from type `src_type_enum` to type |
231 | // `dst_type_enum`. |
232 | TFE_TensorHandle* EagerCast(TFE_Context* ctx, TFE_TensorHandle* handle, |
233 | TF_DataType src_type_enum, |
234 | TF_DataType dst_type_enum, TF_Status* out_status) { |
235 | if (ctx == nullptr) return nullptr; |
236 | const char* op_name = "Cast" ; |
237 | const char* device_name = "/device:CPU:0" ; |
238 | TFE_Op* op = TFE_NewOp(ctx, op_name, out_status); |
239 | #define RETURN_ERROR \ |
240 | { \ |
241 | TFE_DeleteOp(op); \ |
242 | return nullptr; \ |
243 | } |
244 | if (!out_status->status.ok()) RETURN_ERROR |
245 | TFE_OpSetDevice(op, device_name, out_status); |
246 | if (!out_status->status.ok()) RETURN_ERROR |
247 | TFE_OpAddInput(op, handle, out_status); |
248 | if (!out_status->status.ok()) RETURN_ERROR |
249 | TFE_OpSetAttrType(op, "SrcT" , src_type_enum); |
250 | TFE_OpSetAttrType(op, "DstT" , dst_type_enum); |
251 | TFE_OpSetAttrBool(op, "Truncate" , false); |
252 | TFE_TensorHandle* output = nullptr; |
253 | int num_outputs = 1; |
254 | TFE_Execute(op, &output, &num_outputs, out_status); |
255 | if (!out_status->status.ok() || num_outputs != 1 || output == nullptr) { |
256 | if (output != nullptr) { |
257 | TFE_DeleteTensorHandle(output); |
258 | } |
259 | RETURN_ERROR |
260 | } |
261 | TFE_DeleteOp(op); |
262 | return output; |
263 | #undef RETURN_ERROR |
264 | } |
265 | |
266 | Safe_TFE_TensorHandlePtr EagerConst(TFE_Context* ctx, TFE_TensorHandle* handle, |
267 | const char* device_name, |
268 | TF_Status* out_status) { |
269 | const char* op_name = "_EagerConst" ; |
270 | std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op( |
271 | TFE_NewOp(ctx, op_name, out_status), TFE_DeleteOp); |
272 | if (!out_status->status.ok()) return nullptr; |
273 | TFE_OpSetDevice(op.get(), device_name, out_status); |
274 | if (!out_status->status.ok()) return nullptr; |
275 | TFE_OpAddInput(op.get(), handle, out_status); |
276 | if (!out_status->status.ok()) return nullptr; |
277 | TFE_OpSetAttrType(op.get(), "T" , TFE_TensorHandleDataType(handle)); |
278 | TFE_TensorHandle* output = nullptr; |
279 | int num_outputs = 1; |
280 | TFE_Execute(op.get(), &output, &num_outputs, out_status); |
281 | Safe_TFE_TensorHandlePtr result(output); |
282 | if (!out_status->status.ok() || num_outputs != 1) { |
283 | return nullptr; |
284 | } |
285 | return result; |
286 | } |
287 | |
288 | TFE_TensorHandle* ConvertToEagerTensorUncached(TFE_Context* ctx, |
289 | PyObject* value, |
290 | tensorflow::DataType dtype, |
291 | const char* device_name) { |
292 | tensorflow::Safe_PyObjectPtr value_decrefer; |
293 | if (PyArray_IsScalar(value, Generic)) { |
294 | // Convert numpy scalars to numpy arrays. |
295 | value = PyArray_FromScalar(value, nullptr); |
296 | // The returned value needs to be DECREF'd, but the original value was |
297 | // created in python code, and doesn't need to be DECREF'd. |
298 | value_decrefer.reset(value); |
299 | } |
300 | |
301 | Safe_TFE_TensorHandlePtr handle = |
302 | make_safe(PySeqToTFE_TensorHandle(ctx, value, dtype)); |
303 | |
304 | if (handle == nullptr) return nullptr; |
305 | |
306 | Safe_TF_StatusPtr status = make_safe(TF_NewStatus()); |
307 | TF_DataType handle_dtype = TFE_TensorHandleDataType(handle.get()); |
308 | if (dtype != tensorflow::DT_INVALID && |
309 | dtype != static_cast<DataType>(handle_dtype)) { |
310 | if (tensorflow::IsCompatible(dtype, static_cast<DataType>(handle_dtype))) { |
311 | handle = tensorflow::make_safe( |
312 | tensorflow::EagerCast(ctx, handle.get(), handle_dtype, |
313 | static_cast<TF_DataType>(dtype), status.get())); |
314 | if (!status->status.ok()) { |
315 | PyErr_SetString(PyExc_TypeError, |
316 | absl::StrCat("Error while casting from dtype " , |
317 | tensorflow::DataTypeString( |
318 | static_cast<DataType>(handle_dtype)), |
319 | " to " , tensorflow::DataTypeString(dtype), |
320 | ". " , TF_Message(status.get())) |
321 | .c_str()); |
322 | return nullptr; |
323 | } |
324 | } else { |
325 | tensorflow::Safe_PyObjectPtr value_str(PyObject_Repr(value)); |
326 | PyErr_SetString( |
327 | PyExc_TypeError, |
328 | absl::StrCat("Cannot convert " , TFE_GetPythonString(value_str.get()), |
329 | " to EagerTensor of dtype " , |
330 | tensorflow::DataTypeString(dtype)) |
331 | .c_str()); |
332 | return nullptr; |
333 | } |
334 | } |
335 | |
336 | // We always initially generate CPU:0 tensors. Copy to the current device. |
337 | if (device_name != nullptr) { |
338 | if (strstr(device_name, "/device:CPU:0" ) != nullptr) { |
339 | // We always generate CPU:0 tensors, but we may need to change the device |
340 | // slightly, as for example from /job:localhost/... to /job:worker/... |
341 | // |
342 | // Note that this is a shallow copy and will share the underlying buffer, |
343 | // because we are copying to the same device. |
344 | handle = make_safe(TFE_TensorHandleCopyToDevice( |
345 | handle.get(), ctx, device_name, status.get())); |
346 | const TF_Code code = TF_GetCode(status.get()); |
347 | if (code != TF_OK) { |
348 | RaiseExceptionTypeFromTFStatus(status.get()); |
349 | return nullptr; |
350 | } |
351 | } else { |
352 | /*Copy the constant to the current device. Identity is sometimes |
353 | overloaded to allow copies like this, but using a different op allows |
354 | devices to support constant creation without allowing copies via |
355 | identity ops. |
356 | |
357 | Note that running this _EagerConst op limits mirroring of cached Python |
358 | literals somewhat. Mirroring of constants themselves works: |
359 | |
360 | with tf.device("GPU:0"): |
361 | tf.constant(1.) # Cached on CPU:0, mirrored to GPU:0 |
362 | with tf.device("GPU:1"): |
363 | tf.constant(1.) # Cache hit for the CPU version, new mirror to GPU:1. |
364 | with tf.device("GPU:1"): |
365 | tf.constant(1.) # Cache hit for the CPU version, cached mirror |
366 | |
367 | But mirrors for the output of `tf.constant` are not shared just because |
368 | there was a cache hit for the input literal, because of _EagerConst: |
369 | |
370 | x = tf.constant(2.) # Cached on CPU:0 |
371 | with tf.device("GPU:1"): |
372 | tf.identity(x) # `x` now mirrored to GPU:1 |
373 | y = tf.constant(2.) # Cache hit for CPU version |
374 | with tf.device("GPU:1"): |
375 | tf.identity(y) # `y` now mirrored on GPU:1 (new copy!)*/ |
376 | handle = |
377 | tensorflow::EagerConst(ctx, handle.get(), device_name, status.get()); |
378 | const TF_Code code = TF_GetCode(status.get()); |
379 | if (code != TF_OK) { |
380 | RaiseExceptionTypeFromTFStatus(status.get()); |
381 | return nullptr; |
382 | } |
383 | } |
384 | } |
385 | |
386 | return handle.release(); |
387 | } |
388 | |
389 | TFE_TensorHandle* ConvertToEagerTensor(TFE_Context* ctx, PyObject* value, |
390 | DataType dtype, |
391 | const char* device_name) { |
392 | // Reduce the overhead of allocation/transfer-to-device for scalars by |
393 | // caching the corresponding handles. Note that currently only Python |
394 | // scalars are cached. |
395 | // TODO(slebedev): also cache singleton NumPy arrays and scalars? |
396 | if (PyArray_IsPythonNumber(value)) { |
397 | auto* cache = TFE_TensorHandleCache::Get(); |
398 | TFE_TensorHandle* handle = cache->Lookup(value, dtype, ctx, device_name); |
399 | if (handle != nullptr) return handle; |
400 | handle = ConvertToEagerTensorUncached(ctx, value, dtype, device_name); |
401 | if (handle == nullptr) return nullptr; |
402 | if (!PyFloat_Check(value) || std::isfinite(PyFloat_AS_DOUBLE(value))) { |
403 | cache->Insert(value, dtype, ctx, device_name, handle); |
404 | } |
405 | return handle; |
406 | } else { |
407 | return ConvertToEagerTensorUncached(ctx, value, dtype, device_name); |
408 | } |
409 | } |
410 | |
411 | } // namespace tensorflow |
412 | |
413 | extern "C" { |
414 | |
// Number of bytes reserved at the start of EagerTensor (right after the
// object header) for the state of the runtime-assigned parent class.
static const int kMaxEagerTensorParentSize = 64;

// TODO(agarwal): store context handle in EagerTensor.
// Instance layout of the Python EagerTensor type.
typedef struct EagerTensor {
  PyObject_HEAD;
  // Note that we leave kMaxEagerTensorParentSize bytes here for use by the
  // parent class. The parent class is set at runtime, so we don't know the
  // exact size at compile time.
  char unused[kMaxEagerTensorParentSize];
  // The underlying tensor handle. It is deleted in EagerTensor_dealloc.
  TFE_TensorHandle* handle;
  // Identifier assigned from get_uid() in EagerTensor_init; exposed to Python
  // as `_id`.
  int64_t id;
  // Indicates whether it's a packed tensor or not.
  bool is_packed;
  // This mirrors tensorflow.core.framework.ops.Tensor._handle_data, which
  // will be None for tensors of type other than DT_RESOURCE. For DT_RESOURCE
  // tensors, this will contain a serialized HandleData proto with shape
  // inference metadata about shapes and dtypes of resources accessible from
  // this handle.
  // Note that we assume that handle_data cannot participate in reference
  // cycles, and hence don't provide GC support for it.
  PyObject* handle_data;

  // This stores `_tensor_shape`, a cached `TensorShape` object, and is set the
  // first time that `_EagerTensorBase`'s `shape` property is called.
  PyObject* tensor_shape;

  // We store a status object here as an optimization to avoid allocating a
  // new Status object in each function that operates on an EagerTensor and
  // needs a TF_Status. However, note that accesses to `status` are not
  // thread-safe.
  TF_Status status;

  // The eager Context (from eager/context.py) used by this Tensor.
  // This is currently used only to make sure context outlives TensorHandles.
  PyObject* context;

  PyObject* weakreflist; /* List of weak references */

  // Per-instance attribute dictionary, to support monkey patching
  // (e.g. EagerTensor.assign when slicing variables). This dictionary is
  // created by CPython the first time an attribute is assigned, pointed to by
  // tp_dictoffset. Note that garbage collection is not enabled for
  // EagerTensors, so assigning objects to EagerTensor attributes which require
  // garbage collection is likely to cause issues.
  PyObject* dict;
} EagerTensor;
461 | |
462 | namespace { |
463 | |
464 | // Returns true on success - successfully invoked or no profiler registered. |
465 | // Returns false if some error occurred. |
466 | bool MaybeInvokeCreatedOnEagerTensorProfiler(EagerTensor* created_tensor) { |
467 | if (eager_tensor_profiler != nullptr) { |
468 | #if PY_MAJOR_VERSION < 3 |
469 | PyObject* created_method_name = PyString_InternFromString("created" ); |
470 | #else |
471 | PyObject* created_method_name = PyUnicode_InternFromString("created" ); |
472 | #endif |
473 | if (created_method_name == nullptr) { |
474 | return false; |
475 | } |
476 | PyObject* result = PyObject_CallMethodObjArgs( |
477 | eager_tensor_profiler, created_method_name, created_tensor, NULL); |
478 | if (result == nullptr) { |
479 | LOG(ERROR) << "Invoking created() on EagerTensor profiler failed" ; |
480 | // While we can potentially continue because the error is related to |
481 | // profiling, we choose to return an error because: |
482 | // - If profiling is used, the user likely wants to stop execution on |
483 | // profiling errors. |
484 | // - Error in profiling code might have left some state in an invalid |
485 | // form that can lead to an error later on. Better to fail fast. |
486 | Py_DECREF(created_method_name); |
487 | return false; |
488 | } |
489 | Py_DECREF(created_method_name); |
490 | Py_DECREF(result); |
491 | } |
492 | return true; |
493 | } |
494 | |
495 | } // namespace |
496 | |
497 | // tp_init for EagerTensor. |
498 | int EagerTensor_init(EagerTensor* self, PyObject* args, PyObject* kwds) { |
499 | self->id = get_uid(); |
500 | self->handle = nullptr; |
501 | self->is_packed = false; |
502 | Py_INCREF(Py_None); |
503 | self->handle_data = Py_None; |
504 | Py_INCREF(Py_None); |
505 | self->tensor_shape = Py_None; |
506 | self->status.status = ::tensorflow::OkStatus(); |
507 | self->dict = nullptr; |
508 | self->weakreflist = nullptr; |
509 | self->context = nullptr; |
510 | PyObject* value; |
511 | const char* device_name = nullptr; |
512 | tensorflow::DataType dtype = tensorflow::DataType::DT_INVALID; |
513 | const char* kwlist[] = {"value" , "device" , "dtype" , nullptr}; |
514 | if (!PyArg_ParseTupleAndKeywords( |
515 | args, kwds, "OO&|O&" , const_cast<char**>(kwlist), &value, |
516 | ConvertDeviceName, &device_name, ConvertDataType, &dtype)) { |
517 | return -1; |
518 | } |
519 | |
520 | PyObject* py_context = GetPyEagerContext(); |
521 | if (py_context == nullptr) return -1; |
522 | self->context = py_context; |
523 | |
524 | auto* handle = tensorflow::ConvertToEagerTensor(GetContextHandle(py_context), |
525 | value, dtype, device_name); |
526 | if (handle == nullptr) return -1; |
527 | self->handle = handle; |
528 | |
529 | if (!MaybeInvokeCreatedOnEagerTensorProfiler(self)) { |
530 | return -1; |
531 | } |
532 | |
533 | return 0; |
534 | } |
535 | |
536 | // tp_dealloc for EagerTensor. |
537 | void EagerTensor_dealloc(EagerTensor* self) { |
538 | // Unhook the object from python's GC so that the weakref deleter doesn't |
539 | // try to re-delete this. |
540 | PyObject_GC_UnTrack((PyObject*)self); |
541 | |
542 | // Clear weak references to self. |
543 | // Needs to happen before any actual destruction. |
544 | PyObject_ClearWeakRefs((PyObject*)self); |
545 | |
546 | Py_DECREF(self->handle_data); |
547 | Py_DECREF(self->tensor_shape); |
548 | // If an attribute dictionary has been created, release it. Note that this |
549 | // is only ever created by CPython's attribute setting methods; we don't |
550 | // create it ourselves. |
551 | Py_CLEAR(self->dict); |
552 | if (self->handle != nullptr) { |
553 | // Destructor may call arbitrary functions that end up calling into |
554 | // Python from another thread. |
555 | Py_BEGIN_ALLOW_THREADS; |
556 | TFE_DeleteTensorHandle(self->handle); |
557 | Py_END_ALLOW_THREADS; |
558 | self->handle = nullptr; |
559 | } |
560 | |
561 | // Decref context after deleting the tensor handle. |
562 | Py_XDECREF(self->context); |
563 | |
564 | // We have the global interpreter lock, so use this chance to perform delayed |
565 | // refcount decrements. |
566 | tensorflow::ClearDecrefCache(); |
567 | auto id = self->id; |
568 | Py_TYPE(self)->tp_free(self); |
569 | TFE_Py_TapeSetDeleteTrace(id); |
570 | } |
571 | |
572 | // Getter for `_id`. |
573 | static PyObject* EagerTensor_getid(EagerTensor* self, void* closure) { |
574 | return PyLong_FromLongLong(self->id); |
575 | } |
576 | |
577 | // Getter for `_datatype_enum`. |
578 | static PyObject* EagerTensor_datatype_enum(EagerTensor* self) { |
579 | return PyIntFromDataType(TFE_TensorHandleDataType(self->handle)); |
580 | } |
581 | |
582 | // Getter for `_shape_tuple`. |
583 | static PyObject* EagerTensor_shape_tuple(EagerTensor* self) { |
584 | auto handle = self->handle; |
585 | int n = TFE_TensorHandleNumDims(handle, &self->status); |
586 | TF_Code code = TF_GetCode(&self->status); |
587 | if (code != TF_OK) { |
588 | RaiseExceptionTypeFromTFStatus(&self->status); |
589 | // Cleanup self->status before returning. |
590 | self->status.status = ::tensorflow::OkStatus(); |
591 | return nullptr; |
592 | } |
593 | PyObject* shape = PyTuple_New(n); |
594 | if (PyErr_Occurred()) return nullptr; |
595 | for (int i = 0; i < n; ++i) { |
596 | int64_t dim_c_value = TFE_TensorHandleDim(handle, i, &self->status); |
597 | PyObject* dim; |
598 | // The C++ convention is -1 for unknown/variable axis lengths. Translate |
599 | // that to the Python "None" convention. Unknown axis lengths are unusual |
600 | // for eager tensors. |
601 | if (dim_c_value < 0) { |
602 | Py_IncRef(Py_None); |
603 | dim = Py_None; |
604 | } else { |
605 | dim = PyLong_FromLongLong(dim_c_value); |
606 | } |
607 | code = TF_GetCode(&self->status); |
608 | if (code != TF_OK || dim == nullptr || |
609 | PyTuple_SetItem(shape, i, dim) != 0) { |
610 | if (code != TF_OK) { |
611 | RaiseExceptionTypeFromTFStatus(&self->status); |
612 | } else { |
613 | PyErr_SetString(PyExc_RuntimeError, "Error while creating shape" ); |
614 | } |
615 | // Cleanup self->status before returning. |
616 | self->status.status = ::tensorflow::OkStatus(); |
617 | Py_DECREF(shape); |
618 | if (dim != nullptr) Py_DECREF(dim); |
619 | return nullptr; |
620 | } |
621 | } |
622 | return shape; |
623 | } |
624 | |
625 | // Getter for `_rank`. |
626 | static PyObject* EagerTensor_rank(EagerTensor* self) { |
627 | int num_dims = TFE_TensorHandleNumDims(self->handle, &self->status); |
628 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) { |
629 | // Cleanup self->status before returning. |
630 | self->status.status = ::tensorflow::OkStatus(); |
631 | return nullptr; |
632 | } |
633 | #if PY_MAJOR_VERSION < 3 |
634 | return PyInt_FromLong(num_dims); |
635 | #else |
636 | return PyLong_FromLong(num_dims); |
637 | #endif |
638 | } |
639 | |
640 | // Getter for `_num_elements`. |
641 | static PyObject* EagerTensor_num_elements(EagerTensor* self) { |
642 | auto handle = self->handle; |
643 | int n = TFE_TensorHandleNumElements(handle, &self->status); |
644 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) { |
645 | // Cleanup self->status before returning. |
646 | self->status.status = ::tensorflow::OkStatus(); |
647 | return nullptr; |
648 | } |
649 | return PyLong_FromLongLong(n); |
650 | } |
651 | |
652 | static PyObject* EagerTensor_handle_data(EagerTensor* self, void* unused) { |
653 | Py_INCREF(self->handle_data); |
654 | return self->handle_data; |
655 | } |
656 | |
657 | static int EagerTensor_sethandle_data(EagerTensor* self, PyObject* value, |
658 | void* unused) { |
659 | Py_DECREF(self->handle_data); |
660 | Py_INCREF(value); |
661 | self->handle_data = value; |
662 | return 0; |
663 | } |
664 | |
665 | static PyObject* EagerTensor_tensor_shape(EagerTensor* self, void* unused) { |
666 | Py_INCREF(self->tensor_shape); |
667 | return self->tensor_shape; |
668 | } |
669 | |
670 | static int EagerTensor_settensor_shape(EagerTensor* self, PyObject* value, |
671 | void* unused) { |
672 | Py_DECREF(self->tensor_shape); |
673 | Py_INCREF(value); |
674 | self->tensor_shape = value; |
675 | return 0; |
676 | } |
677 | |
678 | // Function `_copy_to_device`. |
679 | static PyObject* EagerTensor_copy_to_device(EagerTensor* self, PyObject* args, |
680 | PyObject* kwds) { |
681 | if (!_PyArg_NoKeywords("copy_to_device" , kwds)) return nullptr; |
682 | |
683 | const char* device_name = nullptr; |
684 | if (!PyArg_ParseTuple(args, "O&:copy_to_device" , ConvertDeviceName, |
685 | &device_name)) { |
686 | return nullptr; |
687 | } |
688 | |
689 | // Note that this is a shallow copy and will share the underlying buffer |
690 | // if copying to the same device. |
691 | TFE_TensorHandle* handle = TFE_TensorHandleCopyToDevice( |
692 | self->handle, GetContextHandle(self->context), device_name, |
693 | &self->status); |
694 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, |
695 | PyExc_RuntimeError)) { |
696 | // Cleanup self->status before returning. |
697 | self->status.status = ::tensorflow::OkStatus(); |
698 | return nullptr; |
699 | } |
700 | |
701 | return EagerTensorFromHandle(handle); |
702 | } |
703 | |
704 | // Function `_numpy_internal`. |
705 | // Convert an EagerTensor to a Python numpy.ndarray object. |
706 | // The two may share underlying storage so changes to one may reflect in the |
707 | // other. |
708 | // Note that if `self` is not on CPU, we raise an Exception. |
709 | static PyObject* EagerTensor_numpy_internal(EagerTensor* self) { |
710 | auto* py_array = TFE_TensorHandleToNumpy(self->handle, &self->status); |
711 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, nullptr)) { |
712 | Py_XDECREF(py_array); |
713 | // Cleanup self->status before returning. |
714 | self->status.status = ::tensorflow::OkStatus(); |
715 | return nullptr; |
716 | } else { |
717 | return PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array)); |
718 | } |
719 | } |
720 | |
721 | // Function `_prefer_custom_summarizer`. |
722 | // |
723 | // A hint that callers should prefer `SummarizeValue` to resolving this handle |
724 | // and formatting the tensor. |
725 | static PyObject* EagerTensor_prefer_custom_summarizer(EagerTensor* self) { |
726 | if (tensorflow::unwrap(self->handle)->PreferCustomSummarizer()) { |
727 | Py_RETURN_TRUE; |
728 | } else { |
729 | Py_RETURN_FALSE; |
730 | } |
731 | } |
732 | |
733 | // Function `_summarize_value`. |
734 | // |
735 | // Returns a string PyObject which summarizes the value of this tensor. It does |
736 | // not include a shape or dtype. |
737 | static PyObject* EagerTensor_summarize_value(EagerTensor* self) { |
738 | std::string summary; |
739 | tensorflow::Status status = |
740 | tensorflow::unwrap(self->handle)->SummarizeValue(summary); |
741 | if (MaybeRaiseExceptionFromStatus(status, nullptr)) { |
742 | return nullptr; |
743 | } |
744 | return PyUnicode_FromString(summary.c_str()); |
745 | } |
746 | |
747 | // Getter `device`. |
748 | static PyObject* EagerTensor_device(EagerTensor* self) { |
749 | const char* device = TFE_TensorHandleDeviceName(self->handle, &self->status); |
750 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, |
751 | PyExc_ValueError)) { |
752 | // Cleanup self->status before returning. |
753 | self->status.status = ::tensorflow::OkStatus(); |
754 | return nullptr; |
755 | } |
756 | #if PY_MAJOR_VERSION >= 3 |
757 | return PyUnicode_FromString(device); |
758 | #else |
759 | return PyBytes_FromString(device); |
760 | #endif |
761 | } |
762 | |
763 | // Getter `backing_device`. |
764 | static PyObject* EagerTensor_backing_device(EagerTensor* self) { |
765 | const char* device = |
766 | TFE_TensorHandleBackingDeviceName(self->handle, &self->status); |
767 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, |
768 | PyExc_ValueError)) { |
769 | // Cleanup self->status before returning. |
770 | self->status.status = ::tensorflow::OkStatus(); |
771 | return nullptr; |
772 | } |
773 | #if PY_MAJOR_VERSION >= 3 |
774 | return PyUnicode_FromString(device); |
775 | #else |
776 | return PyBytes_FromString(device); |
777 | #endif |
778 | } |
779 | |
780 | // Getter `is_packed`. |
781 | static PyObject* EagerTensor_is_packed(EagerTensor* self) { |
782 | return PyBool_FromLong(self->is_packed); |
783 | } |
784 | |
// Attribute table for EagerTensor. Entries with a null setter are read-only
// from Python; `_handle_data` and `_tensor_shape` are also writable. Each
// entry is {name, getter, setter, docstring, closure}.
static PyGetSetDef EagerTensor_getsetters[] = {
    {const_cast<char*>("_id"), (getter)EagerTensor_getid, nullptr,
     const_cast<char*>("Tensor ID."), nullptr},
    {const_cast<char*>("device"), (getter)EagerTensor_device, nullptr,
     const_cast<char*>("Device of op that produced the tensor."), nullptr},
    {const_cast<char*>("backing_device"), (getter)EagerTensor_backing_device,
     nullptr, const_cast<char*>("Device on which tensor's memory is resident."),
     nullptr},
    {const_cast<char*>("is_packed"), (getter)EagerTensor_is_packed, nullptr,
     const_cast<char*>("Whether the EagerTensor is a packed tensor or not."),
     nullptr},
    {const_cast<char*>("_handle_data"), (getter)EagerTensor_handle_data,
     (setter)EagerTensor_sethandle_data,
     const_cast<char*>("Shape/DType data if the EagerTensor is a DT_RESOURCE"),
     nullptr},
    {const_cast<char*>("_tensor_shape"), (getter)EagerTensor_tensor_shape,
     (setter)EagerTensor_settensor_shape,
     const_cast<char*>("Shape of the tensor."), nullptr},
    {nullptr} /* Sentinel */
};
805 | |
806 | #if PY_MAJOR_VERSION < 3 |
807 | // Only used for Python2 since Python3 seems to set the __dict__ correctly. |
808 | static PyMemberDef EagerTensor_members[] = { |
809 | {const_cast<char*>("__dict__" ), T_OBJECT, offsetof(EagerTensor, dict), |
810 | READONLY}, |
811 | {nullptr}, |
812 | }; |
813 | #endif |
814 | |
815 | static PyMethodDef EagerTensor_methods[] = { |
816 | {"_numpy_internal" , (PyCFunction)EagerTensor_numpy_internal, METH_NOARGS, |
817 | PyDoc_STR("Internal method to get a NumPy array for the tensor." )}, |
818 | {"_datatype_enum" , (PyCFunction)EagerTensor_datatype_enum, METH_NOARGS, |
819 | PyDoc_STR("The DType of the tensor as an enum." )}, |
820 | {"_shape_tuple" , (PyCFunction)EagerTensor_shape_tuple, METH_NOARGS, |
821 | PyDoc_STR("The shape of the tensor as a python tuple." )}, |
822 | {"_rank" , (PyCFunction)EagerTensor_rank, METH_NOARGS, |
823 | PyDoc_STR("The rank of the tensor." )}, |
824 | {"_copy_to_device" , (PyCFunction)EagerTensor_copy_to_device, |
825 | METH_VARARGS | METH_KEYWORDS, |
826 | PyDoc_STR("Copies the tensor to the desired device." )}, |
827 | {"_num_elements" , (PyCFunction)EagerTensor_num_elements, METH_NOARGS, |
828 | PyDoc_STR("Number of elements in the tensor." )}, |
829 | {"_prefer_custom_summarizer" , |
830 | (PyCFunction)EagerTensor_prefer_custom_summarizer, METH_NOARGS, |
831 | PyDoc_STR("Indicates whether _numpy_internal loses information." )}, |
832 | {"_summarize_value" , (PyCFunction)EagerTensor_summarize_value, METH_NOARGS, |
833 | PyDoc_STR("A string which summarizes the value of this tensor." )}, |
834 | {nullptr, nullptr}, |
835 | }; |
836 | |
837 | static int EagerTensor_getbuffer(EagerTensor* self, Py_buffer* view, |
838 | int flags) { |
839 | if ((flags & PyBUF_WRITABLE) == PyBUF_WRITABLE) { |
840 | PyErr_SetString(PyExc_BufferError, "EagerTensor is not writable." ); |
841 | return -1; |
842 | } |
843 | |
844 | // TensorHandleToNumpy is zero-copy for everything but DT_RESOURCE and |
845 | // DT_STRING so the following is only slightly slower than a NumPy-free |
846 | // implementation. |
847 | auto py_array = tensorflow::make_safe( |
848 | TFE_TensorHandleToNumpy(self->handle, &self->status)); |
849 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(&self->status, |
850 | PyExc_BufferError)) { |
851 | // Cleanup self->status before returning. |
852 | self->status.status = ::tensorflow::OkStatus(); |
853 | return -1; |
854 | } |
855 | if (PyObject_GetBuffer(py_array.get(), view, flags) < 0) { |
856 | return -1; |
857 | } |
858 | view->readonly = 1; |
859 | return 0; |
860 | } |
861 | |
862 | static PyBufferProcs EagerTensor_as_buffer = { |
863 | #if PY_MAJOR_VERSION < 3 |
864 | nullptr, nullptr, nullptr, nullptr, |
865 | #endif |
866 | (getbufferproc)EagerTensor_getbuffer, |
867 | // Never called because getbufferproc delegates to NumPy. |
868 | (releasebufferproc) nullptr}; |
869 | |
// Note that here we are trying to dynamically create a new class as a subclass
// of a "HEAPTYPE" class that is itself created in Python code and passed in at
// runtime. This is fairly atypical and undocumented.
//
// We use the following strategy for this. Unfortunately, we have to use
// different approaches for Python 2.x and Python 3.x.
// For Python 2.x, we create the class as a static type and set its tp_base to
// the passed-in type. Unfortunately, setting tp_flags to include
// Py_TPFLAGS_HEAPTYPE does not work by itself, since it needs some more
// initialization of the underlying PyHeapTypeObject, and not doing that leads
// to some random crashes, especially during garbage collection.
// Python 3.x explicitly disables a static subclass of a HEAPTYPE base class.
// However, it provides a new function, PyType_FromSpecWithBases, to create
// types dynamically.
884 | |
// Type object for EagerTensor. This is set by TFE_Py_InitEagerTensor and is
// initially nullptr. EagerTensor_CheckExact compares against it and
// EagerTensorFromHandle allocates instances through its tp_new.
PyTypeObject* EagerTensorType = nullptr;
887 | |
888 | #if PY_MAJOR_VERSION >= 3 |
// Slot table for the Python 3 EagerTensor type. It is referenced by the
// PyType_Spec that TFE_Py_InitEagerTensor passes to PyType_FromSpecWithBases,
// and wires up the deallocator, method table, property table and initializer.
static PyType_Slot EagerTensor_Type_slots[] = {
    {Py_tp_dealloc, reinterpret_cast<void*>(EagerTensor_dealloc)},
    {Py_tp_methods, reinterpret_cast<void*>(EagerTensor_methods)},
    {Py_tp_getset, reinterpret_cast<void*>(EagerTensor_getsetters)},
    {Py_tp_init, reinterpret_cast<void*>(EagerTensor_init)},
    // Sentinel marking the end of the table.
    {0, nullptr},
};
896 | #else |
897 | |
// Type flags for the statically defined (Python 2) EagerTensor type: the
// default flags plus the flag announcing new-style buffer support.
#define EAGER_TENSOR_TPFLAGS (Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HAVE_NEWBUFFER)

// Static type object used when building against Python 2. Its tp_base is left
// null here; TFE_Py_InitEagerTensor fills it in with the base class passed in
// at runtime before calling PyType_Ready. The initializer is positional, so
// the order of the entries below must match the PyTypeObject field order.
// TODO(agarwal): support active_trace.
static PyTypeObject _EagerTensorType = {
    // clang-format off
    PyVarObject_HEAD_INIT(nullptr, 0)
    // clang-format on
    "EagerTensor",                      /* tp_name */
    sizeof(EagerTensor),                /* tp_basicsize */
    0,                                  /* tp_itemsize */
    (destructor)EagerTensor_dealloc,    /* tp_dealloc */
    // The slot following tp_dealloc is selected by Python version.
#if PY_VERSION_HEX < 0x03080000
    nullptr, /* tp_print */
#else
    0, /* tp_vectorcall_offset */
#endif
    nullptr,                            /* tp_getattr */
    nullptr,                            /* tp_setattr */
    nullptr,                            /* tp_compare */
    nullptr,                            /* tp_repr */
    nullptr,                            /* tp_as_number */
    nullptr,                            /* tp_as_sequence */
    nullptr,                            /* tp_as_mapping */
    nullptr,                            /* tp_hash */
    nullptr,                            /* tp_call */
    nullptr,                            /* tp_str */
    nullptr,                            /* tp_getattro */
    nullptr,                            /* tp_setattro */
    &EagerTensor_as_buffer,             /* tp_as_buffer */
    EAGER_TENSOR_TPFLAGS,               /* tp_flags */
    nullptr,                            /* tp_doc */
    nullptr,                            /* tp_traverse */
    nullptr,                            /* tp_clear */
    nullptr,                            /* tp_richcompare */
    offsetof(EagerTensor, weakreflist), /* tp_weaklistoffset */
    nullptr,                            /* tp_iter */
    nullptr,                            /* tp_iternext */
    EagerTensor_methods,                /* tp_methods */
    EagerTensor_members,                /* tp_members */
    EagerTensor_getsetters,             /* tp_getset */
    nullptr,                            /* tp_base */
    nullptr,                            /* tp_dict */
    nullptr,                            /* tp_descr_get */
    nullptr,                            /* tp_descr_set */
    offsetof(EagerTensor, dict),        /* tp_dictoffset */
    (initproc)EagerTensor_init,         /* tp_init */
    nullptr,                            /* tp_alloc */
    nullptr,                            /* tp_new */
};
947 | |
948 | #endif |
949 | |
950 | } // extern "C" |
951 | |
952 | bool EagerTensor_CheckExact(const PyObject* o) { |
953 | return Py_TYPE(o) == EagerTensorType; |
954 | } |
955 | |
956 | TFE_TensorHandle* EagerTensor_Handle(const PyObject* o) { |
957 | return reinterpret_cast<const EagerTensor*>(o)->handle; |
958 | } |
959 | |
960 | PyObject* EagerTensorFromHandle(TFE_TensorHandle* handle, |
961 | const bool is_packed) { |
962 | if (handle == nullptr) { |
963 | return nullptr; |
964 | } |
965 | EagerTensor* t = reinterpret_cast<EagerTensor*>( |
966 | EagerTensorType->tp_new(EagerTensorType, EmptyTuple(), EmptyDict())); |
967 | if (t != nullptr) { |
968 | t->id = get_uid(); |
969 | t->is_packed = is_packed; |
970 | Py_INCREF(Py_None); |
971 | t->handle_data = Py_None; |
972 | Py_INCREF(Py_None); |
973 | t->tensor_shape = Py_None; |
974 | t->handle = handle; |
975 | t->status.status = ::tensorflow::OkStatus(); |
976 | t->weakreflist = nullptr; |
977 | PyObject* py_context = GetPyEagerContext(); |
978 | if (py_context == nullptr) { |
979 | LOG(ERROR) << "Cannot create an eager tensor before eager context has " |
980 | "been set or after it has been deleted" ; |
981 | return nullptr; |
982 | } |
983 | t->context = py_context; |
984 | |
985 | if (!MaybeInvokeCreatedOnEagerTensorProfiler(t)) { |
986 | return nullptr; |
987 | } |
988 | } |
989 | return reinterpret_cast<PyObject*>(t); |
990 | } |
991 | |
992 | int64_t PyEagerTensor_ID(const PyObject* tensor) { |
993 | DCHECK(EagerTensor_CheckExact(tensor)); |
994 | return reinterpret_cast<const EagerTensor*>(tensor)->id; |
995 | } |
996 | |
997 | tensorflow::DataType PyEagerTensor_Dtype(const PyObject* tensor) { |
998 | DCHECK(EagerTensor_CheckExact(tensor)); |
999 | return static_cast<tensorflow::DataType>(TFE_TensorHandleDataType( |
1000 | reinterpret_cast<const EagerTensor*>(tensor)->handle)); |
1001 | } |
1002 | |
1003 | int64_t PyEagerTensor_NumElements(PyObject* tensor) { |
1004 | DCHECK(EagerTensor_CheckExact(tensor)); |
1005 | EagerTensor* as_c_eager_tensor = reinterpret_cast<EagerTensor*>(tensor); |
1006 | int64_t result = TFE_TensorHandleNumElements(as_c_eager_tensor->handle, |
1007 | &as_c_eager_tensor->status); |
1008 | |
1009 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(&as_c_eager_tensor->status, |
1010 | PyExc_ValueError)) { |
1011 | // Cleanup status before returning. |
1012 | as_c_eager_tensor->status.status = ::tensorflow::OkStatus(); |
1013 | return -1; |
1014 | } |
1015 | |
1016 | return result; |
1017 | } |
1018 | |
1019 | PyObject* TFE_Py_InitEagerTensor(PyObject* base_class) { |
1020 | if (!PyType_Check(base_class)) { |
1021 | PyErr_SetString( |
1022 | PyExc_TypeError, |
1023 | tensorflow::strings::StrCat( |
1024 | "Expecting a class definition for `base_class` passed to " , |
1025 | "TFE_InitEagerTensor. Got " , Py_TYPE(base_class)->tp_name) |
1026 | .c_str()); |
1027 | return nullptr; |
1028 | } |
1029 | // Note that we allocated kMaxEagerTensorParentSize bytes of unused space in |
1030 | // EagerTensor to allow for the space usage of the base class. |
1031 | PyTypeObject* base_class_type = reinterpret_cast<PyTypeObject*>(base_class); |
1032 | if (base_class_type->tp_basicsize > kMaxEagerTensorParentSize) { |
1033 | PyErr_SetString( |
1034 | PyExc_TypeError, |
1035 | tensorflow::strings::StrCat( |
1036 | "Unable to create subclass EagerTensor from base class " , |
1037 | Py_TYPE(base_class)->tp_name, |
1038 | ". Need its size to be <= " , kMaxEagerTensorParentSize) |
1039 | .c_str()); |
1040 | return nullptr; |
1041 | } |
1042 | if (base_class_type->tp_itemsize != 0) { |
1043 | PyErr_SetString( |
1044 | PyExc_TypeError, |
1045 | tensorflow::strings::StrCat( |
1046 | "Unable to create subclass EagerTensor from base class " , |
1047 | Py_TYPE(base_class)->tp_name, |
1048 | " which supports variable length instances." ) |
1049 | .c_str()); |
1050 | return nullptr; |
1051 | } |
1052 | Py_INCREF(base_class); |
1053 | #if PY_MAJOR_VERSION >= 3 |
1054 | PyObject* bases = PyTuple_New(1); |
1055 | PyTuple_SET_ITEM(bases, 0, base_class); |
1056 | |
1057 | tensorflow::Safe_PyObjectPtr base_class_module( |
1058 | PyObject_GetAttrString(base_class, "__module__" )); |
1059 | const char* module = nullptr; |
1060 | if (PyErr_Occurred()) { |
1061 | PyErr_Clear(); |
1062 | module = "__builtin__" ; |
1063 | } else { |
1064 | module = PyBytes_AsString(base_class_module.get()); |
1065 | if (module == nullptr) { |
1066 | PyErr_Clear(); |
1067 | module = PyUnicode_AsUTF8(base_class_module.get()); |
1068 | if (module == nullptr) { |
1069 | PyErr_Clear(); |
1070 | module = "__builtin__" ; |
1071 | } |
1072 | } |
1073 | } |
1074 | |
1075 | // NOTE: The c_str from this string needs to outlast the function, hence is |
1076 | // static. |
1077 | static tensorflow::string fully_qualified_name = |
1078 | tensorflow::strings::StrCat(module, ".EagerTensor" ); |
1079 | |
1080 | static PyType_Spec EagerTensor_Type_spec = { |
1081 | fully_qualified_name.c_str(), sizeof(EagerTensor), 0, |
1082 | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_HEAPTYPE, EagerTensor_Type_slots}; |
1083 | |
1084 | EagerTensorType = reinterpret_cast<PyTypeObject*>( |
1085 | PyType_FromSpecWithBases(&EagerTensor_Type_spec, bases)); |
1086 | if (PyErr_Occurred()) { |
1087 | return nullptr; |
1088 | } |
1089 | if (EagerTensorType == nullptr) { |
1090 | PyErr_SetString(PyExc_RuntimeError, "Error while creating EagerTensorType" ); |
1091 | return nullptr; |
1092 | } |
1093 | EagerTensorType->tp_dictoffset = offsetof(EagerTensor, dict); |
1094 | EagerTensorType->tp_as_buffer = &EagerTensor_as_buffer; |
1095 | #else |
1096 | _EagerTensorType.tp_base = base_class_type; |
1097 | |
1098 | if (PyType_Ready(&_EagerTensorType) < 0) { |
1099 | if (PyErr_Occurred()) return nullptr; |
1100 | PyErr_SetString(PyExc_RuntimeError, |
1101 | "Error while creating EagerTensor type." ); |
1102 | return nullptr; |
1103 | } |
1104 | EagerTensorType = &_EagerTensorType; |
1105 | Py_INCREF(EagerTensorType); |
1106 | #endif |
1107 | return reinterpret_cast<PyObject*>(EagerTensorType); |
1108 | } |
1109 | |
1110 | PyObject* TFE_Py_SetEagerTensorProfiler(PyObject* profiler) { |
1111 | Py_XDECREF(eager_tensor_profiler); |
1112 | |
1113 | if (profiler == Py_None) { |
1114 | eager_tensor_profiler = nullptr; |
1115 | } else { |
1116 | eager_tensor_profiler = profiler; |
1117 | Py_INCREF(eager_tensor_profiler); |
1118 | } |
1119 | Py_RETURN_NONE; |
1120 | } |
1121 | |
1122 | PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim) { |
1123 | if (!PyList_Check(tensors) && !PyTuple_Check(tensors)) { |
1124 | PyErr_SetString(PyExc_TypeError, |
1125 | tensorflow::strings::StrCat( |
1126 | "tensors argument must be a list or a tuple. Got \"" , |
1127 | Py_TYPE(tensors)->tp_name, "\"" ) |
1128 | .c_str()); |
1129 | return nullptr; |
1130 | } |
1131 | if (slice_dim < 0) { |
1132 | PyErr_SetString( |
1133 | PyExc_ValueError, |
1134 | tensorflow::strings::StrCat("Slice dimension must be non-negative. " |
1135 | "Got " , |
1136 | slice_dim) |
1137 | .c_str()); |
1138 | return nullptr; |
1139 | } |
1140 | |
1141 | PyObject* py_context = GetPyEagerContext(); |
1142 | if (py_context == nullptr) { |
1143 | PyErr_SetString(PyExc_RuntimeError, tensorflow::strings::StrCat( |
1144 | "Cannot create EagerTensor when " |
1145 | "EagerContext is not valid" ) |
1146 | .c_str()); |
1147 | return nullptr; |
1148 | } |
1149 | |
1150 | TFE_Context* ctx = GetContextHandle(py_context); |
1151 | |
1152 | Py_ssize_t num_tensors = PySequence_Fast_GET_SIZE(tensors); |
1153 | PyObject** tensors_array = PySequence_Fast_ITEMS(tensors); |
1154 | int64_t num_tensors_int = static_cast<int64_t>(num_tensors); |
1155 | |
1156 | auto status = tensorflow::make_safe(TF_NewStatus()); |
1157 | |
1158 | // Create an empty tensor. |
1159 | auto* tensor = tensorflow::unwrap(ctx)->CreateTensor( |
1160 | tensorflow::DT_INT32, /*dim_sizes=*/{num_tensors_int}); |
1161 | |
1162 | if (num_tensors_int > 0) { |
1163 | int32_t* data = reinterpret_cast<int32_t*>(tensor->Data()); |
1164 | |
1165 | // Fill the tensor with dims. |
1166 | for (Py_ssize_t i = 0; i < num_tensors; ++i) { |
1167 | PyObject* tensor_obj = tensors_array[i]; |
1168 | if (!EagerTensor_CheckExact(tensor_obj)) { |
1169 | PyErr_SetString( |
1170 | PyExc_TypeError, |
1171 | tensorflow::strings::StrCat("Expected a list of EagerTensors but " |
1172 | "element " , |
1173 | i, " has type \"" , |
1174 | Py_TYPE(tensor_obj)->tp_name, "\"" ) |
1175 | .c_str()); |
1176 | return nullptr; |
1177 | } |
1178 | |
1179 | EagerTensor* t = reinterpret_cast<EagerTensor*>(tensor_obj); |
1180 | TFE_TensorHandle* handle = t->handle; |
1181 | int num_dims = TFE_TensorHandleNumDims(handle, status.get()); |
1182 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), |
1183 | PyExc_ValueError)) { |
1184 | return nullptr; |
1185 | } |
1186 | if (slice_dim >= num_dims) { |
1187 | PyErr_SetString( |
1188 | PyExc_IndexError, |
1189 | tensorflow::strings::StrCat("Slice dimension (" , slice_dim, |
1190 | ") must be smaller than rank of all " |
1191 | "tensors, but tensor at index " , |
1192 | i, " has rank " , num_dims) |
1193 | .c_str()); |
1194 | return nullptr; |
1195 | } |
1196 | int64_t dim = TFE_TensorHandleDim(handle, slice_dim, status.get()); |
1197 | if (tensorflow::MaybeRaiseExceptionFromTFStatus(status.get(), |
1198 | PyExc_ValueError)) { |
1199 | return nullptr; |
1200 | } |
1201 | data[i] = dim; |
1202 | } |
1203 | } |
1204 | |
1205 | TFE_TensorHandle* handle = |
1206 | tensorflow::wrap(tensorflow::unwrap(ctx)->CreateLocalHandle(tensor)); |
1207 | |
1208 | if (!status->status.ok()) { |
1209 | PyErr_SetString( |
1210 | PyExc_RuntimeError, |
1211 | tensorflow::strings::StrCat("Failed to construct new tensor handle: " , |
1212 | TF_Message(status.get())) |
1213 | .c_str()); |
1214 | return nullptr; |
1215 | } |
1216 | |
1217 | return EagerTensorFromHandle(handle); |
1218 | } |
1219 | |
1220 | PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor) { |
1221 | if (!EagerTensor_CheckExact(tensor)) { |
1222 | PyErr_SetString( |
1223 | PyExc_TypeError, |
1224 | tensorflow::strings::StrCat("Expected an EagerTensors but got type \"" , |
1225 | Py_TYPE(tensor)->tp_name, "\"" ) |
1226 | .c_str()); |
1227 | return nullptr; |
1228 | } |
1229 | TFE_TensorHandle* handle = EagerTensor_Handle(tensor); |
1230 | |
1231 | auto status = tensorflow::make_safe(TF_NewStatus()); |
1232 | TFE_TensorDebugInfo* debug_info = |
1233 | TFE_TensorHandleTensorDebugInfo(handle, status.get()); |
1234 | if (!status->status.ok()) { |
1235 | PyErr_SetString( |
1236 | PyExc_RuntimeError, |
1237 | tensorflow::strings::StrCat("Error retrieving tensor's device shape: " , |
1238 | TF_Message(status.get())) |
1239 | .c_str()); |
1240 | return nullptr; |
1241 | } |
1242 | |
1243 | int rank = TFE_TensorDebugInfoOnDeviceNumDims(debug_info); |
1244 | PyObject* shape = PyTuple_New(rank); |
1245 | for (int i = 0; i < rank; ++i) { |
1246 | int64_t dim_size = TFE_TensorDebugInfoOnDeviceDim(debug_info, i); |
1247 | PyTuple_SET_ITEM(shape, i, PyLong_FromLongLong(dim_size)); |
1248 | } |
1249 | TFE_DeleteTensorDebugInfo(debug_info); |
1250 | |
1251 | return shape; |
1252 | } |
1253 | |