1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/python/client/tf_session_helper.h" |
17 | |
18 | #include <cstring> |
19 | |
20 | #include "tensorflow/c/c_api.h" |
21 | #include "tensorflow/c/c_api_internal.h" |
22 | #include "tensorflow/c/tf_buffer_internal.h" |
23 | #include "tensorflow/c/tf_status_helper.h" |
24 | #include "tensorflow/core/framework/allocator.h" |
25 | #include "tensorflow/core/framework/attr_value.pb.h" |
26 | #include "tensorflow/core/framework/attr_value_util.h" |
27 | #include "tensorflow/core/framework/log_memory.h" |
28 | #include "tensorflow/core/framework/op_kernel.h" |
29 | #include "tensorflow/core/graph/tensor_id.h" |
30 | #include "tensorflow/core/lib/core/coding.h" |
31 | #include "tensorflow/core/lib/strings/stringprintf.h" |
32 | #include "tensorflow/core/platform/types.h" |
33 | #include "tensorflow/core/util/equal_graph_def.h" |
34 | #include "tensorflow/python/client/session_ref.h" |
35 | #include "tensorflow/python/lib/core/ndarray_tensor.h" |
36 | #include "tensorflow/python/lib/core/ndarray_tensor_bridge.h" |
37 | #include "tensorflow/python/lib/core/safe_ptr.h" |
38 | |
39 | namespace tensorflow { |
40 | |
namespace {

// Error message used whenever `feed_dict` fails validation as a mapping from
// byte-string tensor names to NumPy arrays.
// (No `static` needed: the anonymous namespace already gives internal
// linkage; a constexpr array avoids the mutable pointer.)
constexpr char kFeedDictErrorMsg[] =
    "feed_dict must be a dictionary mapping strings to NumPy arrays.";

}  // end namespace
46 | |
47 | TF_Session* TF_NewSessionRef(TF_Graph* graph, const TF_SessionOptions* opts, |
48 | TF_Status* status) { |
49 | TF_Session* tf_session = TF_NewSession(graph, opts, status); |
50 | if (tf_session == nullptr) { |
51 | return nullptr; |
52 | } |
53 | |
54 | Session* session = reinterpret_cast<Session*>(tf_session->session); |
55 | SessionRef* session_ref = new SessionRef(session); |
56 | tf_session->session = session_ref; |
57 | return tf_session; |
58 | } |
59 | |
60 | void TF_Run_wrapper_helper(TF_DeprecatedSession* session, const char* handle, |
61 | const TF_Buffer* run_options, PyObject* feed_dict, |
62 | const NameVector& output_names, |
63 | const NameVector& target_nodes, |
64 | TF_Status* out_status, PyObjectVector* out_values, |
65 | TF_Buffer* run_outputs) { |
66 | // 1. Convert the feed inputs to the appropriate form for TF_Run. |
67 | if (!PyDict_Check(feed_dict)) { |
68 | Set_TF_Status_from_Status(out_status, |
69 | errors::InvalidArgument(kFeedDictErrorMsg)); |
70 | return; |
71 | } |
72 | |
73 | NameVector input_names; |
74 | std::vector<Safe_TF_TensorPtr> inputs_safe; // Used to delete tensors. |
75 | TF_TensorVector inputs_unsafe; // Used to contain the arg to TF_Run. |
76 | |
77 | PyObject* key; |
78 | PyObject* value; |
79 | Py_ssize_t pos = 0; |
80 | int index = 0; |
81 | Status s; |
82 | |
83 | while (PyDict_Next(feed_dict, &pos, &key, &value)) { |
84 | char* key_string = PyBytes_AsString(key); |
85 | if (!key_string) { |
86 | Set_TF_Status_from_Status(out_status, |
87 | errors::InvalidArgument(kFeedDictErrorMsg)); |
88 | return; |
89 | } |
90 | input_names.push_back(key_string); |
91 | |
92 | inputs_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr))); |
93 | s = NdarrayToTensor(nullptr /*ctx*/, value, &inputs_safe.back()); |
94 | if (!s.ok()) { |
95 | Set_TF_Status_from_Status(out_status, s); |
96 | return; |
97 | } |
98 | inputs_unsafe.push_back(inputs_safe.back().get()); |
99 | ++index; |
100 | } |
101 | |
102 | // 2. Allocate a container for the output data. |
103 | TF_TensorVector outputs(output_names.size()); |
104 | |
105 | // In case any tensors were leftover from previous runs we might as well clear |
106 | // them here. |
107 | ClearDecrefCache(); |
108 | |
109 | // 3. Actually call TF_Run(). |
110 | Py_BEGIN_ALLOW_THREADS; |
111 | if (handle == nullptr) { |
112 | TF_Run(session, run_options, input_names.data(), inputs_unsafe.data(), |
113 | input_names.size(), const_cast<const char**>(output_names.data()), |
114 | outputs.data(), output_names.size(), |
115 | const_cast<const char**>(target_nodes.data()), target_nodes.size(), |
116 | run_outputs, out_status); |
117 | } else { |
118 | TF_PRun(session, handle, input_names.data(), inputs_unsafe.data(), |
119 | input_names.size(), const_cast<const char**>(output_names.data()), |
120 | outputs.data(), output_names.size(), |
121 | const_cast<const char**>(target_nodes.data()), target_nodes.size(), |
122 | out_status); |
123 | } |
124 | |
125 | Py_END_ALLOW_THREADS; |
126 | |
127 | // Decref any numpy arrays we are not using anymore. |
128 | ClearDecrefCache(); |
129 | |
130 | if (TF_GetCode(out_status) != TF_OK) { |
131 | return; |
132 | } |
133 | |
134 | // 4. We now own the fetched tensors, so set up a safe container to |
135 | // delete them when we exit this scope. |
136 | std::vector<Safe_TF_TensorPtr> tf_outputs_safe; |
137 | for (const auto& output : outputs) { |
138 | tf_outputs_safe.emplace_back(make_safe(output)); |
139 | } |
140 | |
141 | // 5. Convert the fetched tensors into numpy ndarrays. Store them in a safe |
142 | // container so that we do not leak |
143 | std::vector<Safe_PyObjectPtr> py_outputs_safe; |
144 | for (size_t i = 0; i < output_names.size(); ++i) { |
145 | PyObject* py_array; |
146 | s = TF_TensorToPyArray(std::move(tf_outputs_safe[i]), &py_array); |
147 | if (!s.ok()) { |
148 | Set_TF_Status_from_Status(out_status, s); |
149 | return; |
150 | } |
151 | py_outputs_safe.emplace_back( |
152 | make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array)))); |
153 | } |
154 | |
155 | // 6. If we reach this point, we have successfully built a list of objects |
156 | // so we can release them from the safe container. |
157 | for (auto& output : py_outputs_safe) { |
158 | out_values->push_back(output.release()); |
159 | } |
160 | } |
161 | |
// Wrapper for TF_Run that converts the arguments to appropriate types.
// If *out_status is OK, the caller becomes the owner of the PyObjects
// in *out_values.
void TF_Run_wrapper(TF_DeprecatedSession* session, const TF_Buffer* run_options,
                    PyObject* feed_dict, const NameVector& output_names,
                    const NameVector& target_nodes, TF_Status* out_status,
                    PyObjectVector* out_values, TF_Buffer* run_outputs) {
  // handle == nullptr selects the full-run (TF_Run) path in the helper.
  TF_Run_wrapper_helper(session, nullptr, run_options, feed_dict, output_names,
                        target_nodes, out_status, out_values, run_outputs);
  // Release any ndarray references queued by tensor deallocators (must be
  // done while holding the GIL, which is the case here).
  ClearDecrefCache();
}
173 | |
174 | namespace { |
175 | void MakeCallableHelper(tensorflow::Session* session, |
176 | const TF_Buffer* callable_options, int64_t* out_handle, |
177 | TF_Status* out_status) { |
178 | tensorflow::CallableOptions callable_options_proto; |
179 | if (callable_options != nullptr && |
180 | !callable_options_proto.ParseFromArray(callable_options->data, |
181 | callable_options->length)) { |
182 | Set_TF_Status_from_Status( |
183 | out_status, |
184 | errors::InvalidArgument("Unparseable CallableOptions proto" )); |
185 | return; |
186 | } |
187 | tensorflow::Session::CallableHandle handle; |
188 | Status s = session->MakeCallable(callable_options_proto, &handle); |
189 | if (!s.ok()) { |
190 | Set_TF_Status_from_Status(out_status, s); |
191 | return; |
192 | } |
193 | *out_handle = handle; |
194 | } |
195 | } // namespace |
196 | |
// MakeCallable wrapper for TF_DeprecatedSession; delegates to
// MakeCallableHelper. On success *out_handle receives the callable's handle.
void TF_DeprecatedSessionMakeCallable(TF_DeprecatedSession* session,
                                      const TF_Buffer* callable_options,
                                      int64_t* out_handle, TF_Status* status) {
  MakeCallableHelper(session->session, callable_options, out_handle, status);
}
// MakeCallable wrapper for TF_Session; delegates to MakeCallableHelper.
// On success *out_handle receives the callable's handle.
void TF_SessionMakeCallable(TF_Session* session,
                            const TF_Buffer* callable_options,
                            int64_t* out_handle, TF_Status* status) {
  MakeCallableHelper(session->session, callable_options, out_handle, status);
}
207 | |
namespace {
// Shared implementation for the RunCallable entry points.
//
// Converts `feed_values` (a Python sequence of ndarrays) into Tensors, runs
// session->RunCallable(handle, ...), optionally serializes the RunMetadata
// into *run_metadata, and appends the fetched outputs to *out_values as
// ndarrays. On success the caller owns the PyObjects in *out_values.
void RunCallableHelper(tensorflow::Session* session, int64_t handle,
                       PyObject* feed_values, TF_Status* out_status,
                       PyObjectVector* out_values, TF_Buffer* run_metadata) {
  // Convert feed values to a vector of tensorflow::Tensor objects.
  std::vector<Tensor> input_tensors;
  Status s;
  {
    feed_values =
        PySequence_Fast(feed_values, "feed_values must be a sequence");
    // NOTE(review): this early return leaves *out_status untouched; the
    // Python exception raised by PySequence_Fast presumably signals the
    // error instead -- confirm callers handle that case.
    if (feed_values == nullptr) return;
    // PySequence_Fast returned a new reference; release it at scope exit.
    Safe_PyObjectPtr feed_values_holder(make_safe(feed_values));
    Py_ssize_t len = PySequence_Fast_GET_SIZE(feed_values);
    input_tensors.reserve(len);
    for (Py_ssize_t i = 0; i < len; ++i) {
      PyObject* elem = PySequence_Fast_GET_ITEM(feed_values, i);
      if (!elem) {
        Set_TF_Status_from_Status(
            out_status, errors::Internal("Could not get feed value ", i));
        return;
      }
      Tensor t;
      s = NdarrayToTensor(elem, &t);
      if (!s.ok()) {
        Set_TF_Status_from_Status(out_status, s);
        return;
      }
      input_tensors.push_back(std::move(t));
    }
  }

  RunMetadata run_metadata_proto;

  // Run the callable with the GIL released.
  std::vector<Tensor> output_tensors;
  Py_BEGIN_ALLOW_THREADS;
  s = session->RunCallable(handle, input_tensors, &output_tensors,
                           &run_metadata_proto);
  Py_END_ALLOW_THREADS;

  if (!s.ok()) {
    Set_TF_Status_from_Status(out_status, s);
    return;
  }

  // If requested, serialize the RunMetadata to pass it back to the caller.
  if (run_metadata != nullptr) {
    s = MessageToBuffer(run_metadata_proto, run_metadata);
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
  }

  // Convert results to NumPy arrays. Since this can fail, stage the
  // results via a safe container that takes care of decreasing the
  // reference count on failure.
  std::vector<Safe_PyObjectPtr> py_outputs_safe;
  py_outputs_safe.reserve(output_tensors.size());
  for (const Tensor& output : output_tensors) {
    PyObject* py_array;
    s = TensorToNdarray(output, &py_array);
    if (!s.ok()) {
      Set_TF_Status_from_Status(out_status, s);
      return;
    }
    // PyArray_Return converts 0-d arrays to Python scalars.
    py_outputs_safe.push_back(
        make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array))));
  }

  // If we reach this point, we have successfully built a list of objects
  // so we can release them from the safe container.
  out_values->reserve(py_outputs_safe.size());
  for (auto& output : py_outputs_safe) {
    out_values->push_back(output.release());
  }
}
}  // namespace
286 | |
// RunCallable wrapper for TF_DeprecatedSession; delegates to
// RunCallableHelper and then flushes the deferred-decref cache (which
// requires the GIL, held here).
void TF_DeprecatedSessionRunCallable(TF_DeprecatedSession* session,
                                     int64_t handle, PyObject* feed_values,
                                     PyObjectVector* out_values,
                                     TF_Buffer* run_metadata,
                                     TF_Status* status) {
  RunCallableHelper(session->session, handle, feed_values, status, out_values,
                    run_metadata);
  ClearDecrefCache();
}
// RunCallable wrapper for TF_Session; delegates to RunCallableHelper and
// then flushes the deferred-decref cache (requires the GIL, held here).
void TF_SessionRunCallable(TF_Session* session, int64_t handle,
                           PyObject* feed_values, PyObjectVector* out_values,
                           TF_Buffer* run_metadata, TF_Status* status) {
  RunCallableHelper(session->session, handle, feed_values, status, out_values,
                    run_metadata);
  ClearDecrefCache();
}
303 | |
// Releases the callable identified by `handle` on the deprecated session;
// the result of Session::ReleaseCallable is reported via *status.
void TF_DeprecatedSessionReleaseCallable(TF_DeprecatedSession* session,
                                         int64_t handle, TF_Status* status) {
  Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
}
// Releases the callable identified by `handle` on the session; the result of
// Session::ReleaseCallable is reported via *status.
void TF_SessionReleaseCallable(TF_Session* session, int64_t handle,
                               TF_Status* status) {
  Set_TF_Status_from_Status(status, session->session->ReleaseCallable(handle));
}
312 | |
// Wrapper for TF_PRunSetup that converts the arguments to appropriate types.
// If *out_status is OK, the caller becomes the owner of *out_handle.
void TF_PRunSetup_wrapper(TF_DeprecatedSession* session,
                          const NameVector& input_names,
                          const NameVector& output_names,
                          const NameVector& target_nodes, TF_Status* out_status,
                          const char** out_handle) {
  // The GIL is released while the C API sets up the partial run.
  Py_BEGIN_ALLOW_THREADS;
  TF_PRunSetup(
      session, const_cast<const char**>(input_names.data()), input_names.size(),
      const_cast<const char**>(output_names.data()), output_names.size(),
      const_cast<const char**>(target_nodes.data()), target_nodes.size(),
      out_handle, out_status);
  Py_END_ALLOW_THREADS;
}
328 | |
// Wrapper for TF_PRun that converts the arguments to appropriate types.
// If *out_status is OK, the caller becomes the owner of the PyObjects
// in *out_values.
void TF_PRun_wrapper(TF_DeprecatedSession* session, const char* handle,
                     PyObject* feed_dict, const NameVector& output_names,
                     TF_Status* out_status, PyObjectVector* out_values) {
  // Partial runs pass no run_options, no target nodes, and no run_outputs;
  // a non-null handle selects the TF_PRun path in the helper.
  TF_Run_wrapper_helper(session, handle, nullptr, feed_dict, output_names,
                        NameVector(), out_status, out_values, nullptr);
  // Release deferred ndarray references (requires the GIL, held here).
  ClearDecrefCache();
}
339 | |
// Wrapper for TF_Reset that converts the string vectors to character arrays.
void TF_Reset_wrapper(const TF_SessionOptions* opt,
                      const NameVector& containers, TF_Status* status) {
  TF_Reset(opt, const_cast<const char**>(containers.data()), containers.size(),
           status);
}
346 | |
347 | void TF_SessionRun_wrapper_helper(TF_Session* session, const char* handle, |
348 | const TF_Buffer* run_options, |
349 | const std::vector<TF_Output>& inputs, |
350 | const std::vector<PyObject*>& input_ndarrays, |
351 | const std::vector<TF_Output>& outputs, |
352 | const std::vector<TF_Operation*>& targets, |
353 | TF_Buffer* run_metadata, |
354 | TF_Status* out_status, |
355 | std::vector<PyObject*>* py_outputs) { |
356 | DCHECK_EQ(inputs.size(), input_ndarrays.size()); |
357 | DCHECK(py_outputs != nullptr); |
358 | DCHECK(py_outputs->empty()); |
359 | Status s; |
360 | |
361 | // Convert input ndarray PyObjects to TF_Tensors. We maintain a continuous |
362 | // array of TF_Tensor*s as well as scoped containers to make sure they're |
363 | // cleaned up properly. |
364 | // |
365 | // Memory management: |
366 | // NdarrayToTensor() creates a new ndarray PyObject from the input |
367 | // ndarray. We manage the new ndarray's lifetime in order to keep the |
368 | // underlying data buffer alive (the new ndarray also guarantees a contiguous |
369 | // data buffer). The new ndarray's data buffer is used to create the |
370 | // corresponding TF_Tensor. The TF_Tensor's deallocator will queue the new |
371 | // ndarray to be decref'd by the next ClearDecrefCache() call (we can't call |
372 | // Py_DECREF in the deallocator directly because the GIL must be held). |
373 | // |
374 | // Note that TF_Tensor may directly delegate its data and deallocator to a |
375 | // TensorBuffer, which may outlive the TF_Tensor (e.g. if the tensor gets |
376 | // queued or assigned to a variable). |
377 | TF_TensorVector input_vals; |
378 | std::vector<Safe_TF_TensorPtr> input_vals_safe; |
379 | for (PyObject* ndarray : input_ndarrays) { |
380 | input_vals_safe.emplace_back(make_safe(static_cast<TF_Tensor*>(nullptr))); |
381 | s = NdarrayToTensor(nullptr, ndarray, &input_vals_safe.back()); |
382 | if (!s.ok()) { |
383 | Set_TF_Status_from_Status(out_status, s); |
384 | return; |
385 | } |
386 | input_vals.push_back(input_vals_safe.back().get()); |
387 | } |
388 | |
389 | // Allocate space for output TF_Tensor*s |
390 | TF_TensorVector output_vals(outputs.size()); |
391 | |
392 | // Clear up any unused memory leftover from previous runs |
393 | ClearDecrefCache(); |
394 | |
395 | // Call TF_SessionRun() (and release GIL during execution) |
396 | Py_BEGIN_ALLOW_THREADS; |
397 | if (handle == nullptr) { |
398 | TF_SessionRun(session, run_options, inputs.data(), input_vals.data(), |
399 | inputs.size(), outputs.data(), output_vals.data(), |
400 | outputs.size(), targets.data(), targets.size(), run_metadata, |
401 | out_status); |
402 | } else { |
403 | TF_SessionPRun(session, handle, inputs.data(), input_vals.data(), |
404 | inputs.size(), outputs.data(), output_vals.data(), |
405 | outputs.size(), targets.data(), targets.size(), out_status); |
406 | } |
407 | Py_END_ALLOW_THREADS; |
408 | |
409 | // Create scoped containers for output tensors |
410 | std::vector<Safe_TF_TensorPtr> output_vals_safe; |
411 | for (TF_Tensor* output : output_vals) { |
412 | output_vals_safe.emplace_back(make_safe(output)); |
413 | } |
414 | |
415 | // Convert outputs to ndarrays (in scoped containers) |
416 | std::vector<Safe_PyObjectPtr> py_outputs_safe; |
417 | for (size_t i = 0; i < outputs.size(); ++i) { |
418 | PyObject* py_array; |
419 | s = TF_TensorToPyArray(std::move(output_vals_safe[i]), &py_array); |
420 | if (!s.ok()) { |
421 | Set_TF_Status_from_Status(out_status, s); |
422 | return; |
423 | } |
424 | py_outputs_safe.emplace_back( |
425 | make_safe(PyArray_Return(reinterpret_cast<PyArrayObject*>(py_array)))); |
426 | } |
427 | |
428 | // If we reach this point, we have successfully built a list of objects so we |
429 | // can release them from the safe container into the return vector. |
430 | for (size_t i = 0; i < outputs.size(); ++i) { |
431 | py_outputs->push_back(py_outputs_safe[i].release()); |
432 | } |
433 | } |
434 | |
// Full-run entry point for the TF_Session-based API: delegates to
// TF_SessionRun_wrapper_helper with handle == nullptr. On success the caller
// owns the ndarrays appended to *py_outputs.
void TF_SessionRun_wrapper(TF_Session* session, const TF_Buffer* run_options,
                           const std::vector<TF_Output>& inputs,
                           const std::vector<PyObject*>& input_ndarrays,
                           const std::vector<TF_Output>& outputs,
                           const std::vector<TF_Operation*>& targets,
                           TF_Buffer* run_metadata, TF_Status* out_status,
                           std::vector<PyObject*>* py_outputs) {
  TF_SessionRun_wrapper_helper(session, nullptr, run_options, inputs,
                               input_ndarrays, outputs, targets, run_metadata,
                               out_status, py_outputs);
  // Release any unused ndarray references (see memory management comment in
  // TF_SessionRun_wrapper_helper)
  ClearDecrefCache();
}
449 | |
450 | string EqualGraphDefWrapper(const string& actual, const string& expected) { |
451 | GraphDef actual_def; |
452 | if (!actual_def.ParseFromString(actual)) { |
453 | return "actual is not a valid serialized GraphDef" ; |
454 | } |
455 | GraphDef expected_def; |
456 | if (!expected_def.ParseFromString(expected)) { |
457 | return "expected is not a valid serialized GraphDef" ; |
458 | } |
459 | string diff; |
460 | return EqualGraphDef(actual_def, expected_def, &diff) ? "" : diff; |
461 | } |
462 | |
463 | string EqualAttrValueWrapper(const string& actual, const string& expected) { |
464 | AttrValue actual_attr_value; |
465 | if (!actual_attr_value.ParseFromString(actual)) { |
466 | return "actual is not a valid serialized AttrValue" ; |
467 | } |
468 | |
469 | AttrValue expected_attr_value; |
470 | if (!expected_attr_value.ParseFromString(expected)) { |
471 | return "expected is not a valid serialized AttrValue" ; |
472 | } |
473 | |
474 | string diff; |
475 | if (!AreAttrValuesEqual(actual_attr_value, expected_attr_value)) { |
476 | diff = strings::Printf( |
477 | "Actual AttrValue %s does not match Expected AttrValue %s." , |
478 | SummarizeAttrValue(actual_attr_value).c_str(), |
479 | SummarizeAttrValue(expected_attr_value).c_str()); |
480 | } |
481 | return diff; |
482 | } |
483 | |
484 | // Return value set to 6 inlined elements so it fits in a 64-byte cache line. |
485 | tensorflow::gtl::InlinedVector<int64_t, 6> TF_GraphGetTensorShapeHelper( |
486 | TF_Graph* graph, TF_Output output, TF_Status* out_status, |
487 | bool* unknown_shape) { |
488 | // Allocate a single variable for holding the result for RVO. |
489 | tensorflow::gtl::InlinedVector<int64_t, 6> result; |
490 | *unknown_shape = false; |
491 | int num_dims = TF_GraphGetTensorNumDims(graph, output, out_status); |
492 | if (TF_GetCode(out_status) != TF_OK) { |
493 | return result; |
494 | } |
495 | // If shape is unknown, set boolean and return. |
496 | if (num_dims == -1) { |
497 | *unknown_shape = true; |
498 | return result; |
499 | } |
500 | |
501 | // If shape is a scalar, avoid another C call and just return {}. |
502 | if (num_dims == 0) { |
503 | return result; |
504 | } |
505 | |
506 | result.resize(num_dims); |
507 | TF_GraphGetTensorShape(graph, output, result.data(), num_dims, out_status); |
508 | return result; |
509 | } |
510 | |
// Wrapper for TF_SessionPRunSetup that converts the C++ vectors to the C
// API's pointer+length arguments. If *out_status is OK, *out_handle receives
// the partial-run handle.
void TF_SessionPRunSetup_wrapper(TF_Session* session,
                                 const std::vector<TF_Output>& inputs,
                                 const std::vector<TF_Output>& outputs,
                                 const std::vector<TF_Operation*>& targets,
                                 const char** out_handle,
                                 TF_Status* out_status) {
  // Call TF_SessionPRunSetup() (and release GIL during execution)
  Py_BEGIN_ALLOW_THREADS;
  TF_SessionPRunSetup(session, inputs.data(), inputs.size(), outputs.data(),
                      outputs.size(), targets.data(), targets.size(),
                      out_handle, out_status);
  Py_END_ALLOW_THREADS;
}
524 | |
// Partial-run entry point for the TF_Session-based API: delegates to
// TF_SessionRun_wrapper_helper with a non-null handle, no run options, no
// targets, and no run metadata.
void TF_SessionPRun_wrapper(TF_Session* session, const char* handle,
                            const std::vector<TF_Output>& inputs,
                            const std::vector<PyObject*>& input_ndarrays,
                            const std::vector<TF_Output>& outputs,
                            TF_Status* out_status,
                            std::vector<PyObject*>* py_outputs) {
  const std::vector<TF_Operation*> targets;  // Partial runs take no targets.
  TF_SessionRun_wrapper_helper(session, handle,
                               nullptr,  // run_options
                               inputs, input_ndarrays, outputs, targets,
                               nullptr,  // run_metadata
                               out_status, py_outputs);
  // Release any unused ndarray references (see memory management comment in
  // TF_SessionRun_wrapper_helper)
  ClearDecrefCache();
}
541 | |
542 | std::vector<TF_Output> GetOperationInputs(TF_Operation* oper) { |
543 | int num_inputs = TF_OperationNumInputs(oper); |
544 | std::vector<TF_Output> inputs(num_inputs); |
545 | TF_OperationAllInputs(oper, inputs.data(), inputs.size()); |
546 | return inputs; |
547 | } |
548 | |
549 | std::vector<TF_Operation*> TF_OperationGetControlInputs_wrapper( |
550 | TF_Operation* oper) { |
551 | std::vector<TF_Operation*> control_inputs(TF_OperationNumControlInputs(oper)); |
552 | TF_OperationGetControlInputs(oper, control_inputs.data(), |
553 | control_inputs.size()); |
554 | return control_inputs; |
555 | } |
556 | |
557 | std::vector<TF_Operation*> TF_OperationGetControlOutputs_wrapper( |
558 | TF_Operation* oper) { |
559 | std::vector<TF_Operation*> control_outputs( |
560 | TF_OperationNumControlOutputs(oper)); |
561 | TF_OperationGetControlOutputs(oper, control_outputs.data(), |
562 | control_outputs.size()); |
563 | return control_outputs; |
564 | } |
565 | |
566 | std::vector<const char*> TF_OperationOutputConsumers_wrapper( |
567 | TF_Output oper_out) { |
568 | int num_consumers = TF_OperationOutputNumConsumers(oper_out); |
569 | std::vector<TF_Input> consumers(num_consumers); |
570 | TF_OperationOutputConsumers(oper_out, consumers.data(), num_consumers); |
571 | |
572 | std::vector<const char*> consumer_names(num_consumers); |
573 | for (int i = 0; i < num_consumers; ++i) { |
574 | consumer_names[i] = TF_OperationName(consumers[i].oper); |
575 | } |
576 | return consumer_names; |
577 | } |
578 | |
// Wrapper for TF_GraphToFunctionWithControlOutputs that adapts C++ vectors
// to the C API's pointer+length calling convention.
//
// Returns the new TF_Function, or nullptr on error (with *out_status set).
// `output_names`, if non-empty, must have the same length as `outputs`.
TF_Function* TF_GraphToFunction_wrapper(
    const TF_Graph* fn_body, const char* fn_name, bool append_hash_to_fn_name,
    const std::vector<TF_Operation*>* opers,
    const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
    const NameVector& output_names,
    const std::vector<TF_Operation*>* control_outputs,
    const NameVector& control_output_names, const TF_FunctionOptions* opts,
    const char* description, TF_Status* out_status) {
  // Validate up front: output names are either omitted entirely or provided
  // for every output.
  if (!output_names.empty() && output_names.size() != outputs.size()) {
    Set_TF_Status_from_Status(
        out_status,
        errors::InvalidArgument(
            "output names must be either empty or equal in size to outputs. ",
            "output names size = ", output_names.size(),
            " outputs size = ", outputs.size()));
    return nullptr;
  }

  // opers == nullptr is forwarded as nopers == -1 (the C API's sentinel for
  // "no explicit op list"; see TF_GraphToFunction in c_api.h).
  int nopers = -1;
  const TF_Operation* const* opers_array = nullptr;
  if (opers != nullptr) {
    nopers = opers->size();
    opers_array = opers->data();
  }

  // Empty name vectors are passed through as nullptr rather than as empty
  // arrays.
  const char** output_names_ptr =
      output_names.empty() ? nullptr
                           : const_cast<const char**>(output_names.data());

  const char** control_output_names_ptr =
      control_output_names.empty()
          ? nullptr
          : const_cast<const char**>(control_output_names.data());

  return TF_GraphToFunctionWithControlOutputs(
      fn_body, fn_name, append_hash_to_fn_name, nopers, opers_array,
      inputs.size(), inputs.data(), outputs.size(), outputs.data(),
      output_names_ptr,
      control_outputs == nullptr ? 0 : control_outputs->size(),
      control_outputs == nullptr ? nullptr : control_outputs->data(),
      control_output_names_ptr, opts, description, out_status);
}
621 | |
622 | void TF_GraphSetOutputHandleShapesAndTypes_wrapper( |
623 | TF_Graph* graph, TF_Output output, |
624 | const std::vector<std::vector<int64_t>>& shapes, |
625 | const std::vector<int>& ranks, const std::vector<TF_DataType>& types, |
626 | TF_Status* status) { |
627 | std::vector<const int64_t*> shapes_pointers(shapes.size()); |
628 | for (int i = 0; i < shapes.size(); ++i) { |
629 | shapes_pointers[i] = ranks[i] <= 0 ? nullptr : &shapes[i][0]; |
630 | } |
631 | TF_GraphSetOutputHandleShapesAndTypes(graph, output, shapes.size(), |
632 | shapes_pointers.data(), ranks.data(), |
633 | types.data(), status); |
634 | } |
635 | |
636 | void CreatePlaceholder(TF_Graph* graph, TF_Status* s, string&& name, |
637 | TF_DataType dtype, TF_Output* output) { |
638 | TF_OperationDescription* desc = |
639 | TF_NewOperation(graph, "Placeholder" , name.data()); |
640 | TF_SetAttrType(desc, "dtype" , dtype); |
641 | TF_Operation* op = TF_FinishOperation(desc, s); |
642 | output->oper = op; |
643 | output->index = 0; |
644 | } |
645 | |
646 | std::vector<TF_Output> TF_CreatePlaceholders(TF_Graph* graph, PyObject* dtypes, |
647 | const char* prefix, |
648 | TF_Status* status) { |
649 | std::vector<TF_Output> outputs; |
650 | dtypes = PySequence_Fast(dtypes, "dtypes must be a sequence" ); |
651 | if (dtypes == nullptr) { |
652 | Set_TF_Status_from_Status(status, errors::Internal("dtypes is nullptr" )); |
653 | return outputs; |
654 | } |
655 | Safe_PyObjectPtr dtypes_holder(make_safe(dtypes)); |
656 | Py_ssize_t len = PySequence_Fast_GET_SIZE(dtypes); |
657 | outputs.reserve(len); |
658 | for (size_t i = 0; i < len; i++) { |
659 | PyObject* dtype = PySequence_Fast_GET_ITEM(dtypes, i); |
660 | if (!dtype) { |
661 | Set_TF_Status_from_Status(status, |
662 | errors::Internal("Could not get dtype " , i)); |
663 | return outputs; |
664 | } |
665 | #if PY_MAJOR_VERSION >= 3 |
666 | TF_DataType tf_datatype = static_cast<TF_DataType>(PyLong_AsLong(dtype)); |
667 | #else |
668 | TF_DataType tf_datatype = static_cast<TF_DataType>(PyInt_AsLong(dtype)); |
669 | #endif |
670 | outputs.push_back(TF_Output()); |
671 | CreatePlaceholder(graph, status, strings::StrCat(prefix, i), tf_datatype, |
672 | &outputs.back()); |
673 | if (!status->status.ok()) break; |
674 | } |
675 | return outputs; |
676 | } |
677 | |
678 | void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output, |
679 | const std::vector<int64_t>& dims, |
680 | bool unknown_shape, TF_Status* status) { |
681 | if (unknown_shape) { |
682 | TF_GraphSetTensorShape(graph, output, nullptr, -1, status); |
683 | return; |
684 | } |
685 | TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status); |
686 | } |
687 | |
688 | std::vector<string> TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper( |
689 | TF_ImportGraphDefResults* results) { |
690 | int num_missing_unused_input_mappings; |
691 | const char** src_names; |
692 | int* src_indexes; |
693 | TF_ImportGraphDefResultsMissingUnusedInputMappings( |
694 | results, &num_missing_unused_input_mappings, &src_names, &src_indexes); |
695 | std::vector<string> input_strs(num_missing_unused_input_mappings); |
696 | for (int i = 0; i < num_missing_unused_input_mappings; ++i) { |
697 | input_strs[i] = TensorId(src_names[i], src_indexes[i]).ToString(); |
698 | } |
699 | return input_strs; |
700 | } |
701 | |
702 | PyObject* TF_TryEvaluateConstant_wrapper(TF_Graph* graph, TF_Output output, |
703 | TF_Status* status) { |
704 | TF_Tensor* result_tensor; |
705 | bool evaluated = |
706 | TF_TryEvaluateConstant(graph, output, &result_tensor, status); |
707 | if (!evaluated || TF_GetCode(status) != TF_OK) Py_RETURN_NONE; |
708 | |
709 | Safe_TF_TensorPtr safe_result_tensor(result_tensor); |
710 | PyObject* out; |
711 | Status s = TF_TensorToPyArray(std::move(safe_result_tensor), &out); |
712 | Set_TF_Status_from_Status(status, s); |
713 | if (!s.ok()) Py_RETURN_NONE; |
714 | return PyArray_Return(reinterpret_cast<PyArrayObject*>(out)); |
715 | } |
716 | |
717 | } // namespace tensorflow |
718 | |